I'm trying to run an RNN beam search on a tf.keras.Model in a vectorized way so it runs entirely on the GPU. However, despite everything being wrapped in tf.function and as vectorized as I can make it, it runs at exactly the same speed with or without a GPU. Attached is a minimal example with a fake model. In reality, for n=32, k=32, steps=128, which is what I want to work with, decoding takes 20 s (per n=32 samples), both on CPU and on GPU!
I must be missing something. When I train the model, a training iteration (128 steps) with batch size 512 takes 100 ms on GPU, and a training iteration with batch size 32 takes 1 s on CPU. The GPU isn't saturated at batch size 512. I understand that I incur overhead from running the steps individually and from one blocking operation per step, but computationally that overhead should be negligible compared to the rest of the model.
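For what it's worth, here is how I benchmark: a minimal timing sketch (with a hypothetical toy `step` function in place of the real model), tracing once before timing and calling `.numpy()` to block until the device result is ready, so asynchronous dispatch doesn't hide or inflate the cost:

```python
import time
import tensorflow as tf

@tf.function
def step(x):
    # Stand-in for one decode step; hypothetical toy computation
    return tf.linalg.matmul(x, x)

x = tf.random.normal((256, 256))
step(x)  # first call traces the function; exclude it from the timing

start = time.perf_counter()
for _ in range(10):
    step(x).numpy()  # .numpy() blocks until the result is available on host
elapsed = time.perf_counter() - start
```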
I also understand that using a tf.keras.Model in this way is probably not ideal, but is there another way to wire output tensors via a function back to the input tensors, and in particular also to rewire the states?
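To illustrate what I mean by rewiring states: a minimal sketch (hypothetical sizes, GRU cell as a stand-in for my decoder) of a single-step functional model where the recurrent state is an explicit input and output, so the caller can feed the gathered states back in each step:

```python
import tensorflow as tf

units, tokens = 8, 5

# State is an ordinary Input, so it can be rewired between steps
x_in = tf.keras.Input(shape=(tokens,))
s_in = tf.keras.Input(shape=(units,))

cell = tf.keras.layers.GRUCell(units)
out, new_states = cell(x_in, [s_in])  # cells return (output, [new_state])
y = tf.keras.layers.Dense(tokens, activation="softmax")(out)

# One decode step: (x, state) -> (token probabilities, next state)
step_model = tf.keras.Model([x_in, s_in], [y, new_states[0]])
```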
Full working example: https://gist.github.com/meowcat/e3eaa4b8543a7c8444f4a74a9074b9ae
```python
@tf.function
def decode_beam(states_init, scores_init, y_init, steps, k, n):
    states = states_init
    scores = scores_init
    xstep = embed_y_to_x(y_init)
    # Keep the results in TensorArrays
    y_chain = tf.TensorArray(dtype="int32", size=steps)
    sequences_chain = tf.TensorArray(dtype="int32", size=steps)
    scores_chain = tf.TensorArray(dtype="float32", size=steps)
    for i in range(steps):
        # model_decode is the trained model with 3.5 million trainable params.
        # Run a single step of the RNN model.
        y, states = model_decode([xstep, states])
        # Add scores of step i to previous scores
        # (I left out the sequence end killer for this demo)
        scores_y = tf.expand_dims(tf.reshape(scores, y.shape[:-1]), 2) + tm.log(y)
        # Reshape into (n, k*tokens) and find the best k sequences to continue
        # for each of n candidates
        scores_y = tf.reshape(scores_y, [n, -1])
        top_k = tm.top_k(scores_y, k, sorted=False)
        # Transform the indices. I was using tf.unravel_index but
        # `tf.debugging.set_log_device_placement(True)` indicated that this
        # would be placed on the CPU, thus I rewrote it
        top_k_index = tf.reshape(
            top_k[1] + tf.reshape(tf.range(n), (-1, 1)) * scores_y.shape[1], [-1])
        ysequence = top_k_index // y.shape[2]
        ymax = top_k_index % y.shape[2]
        # This gives us two (n*k,) tensors with parent sequence (ysequence)
        # and chosen character (ymax) per sequence.
        # For continuation, pick the states, and "return" the scores
        states = tf.gather(states, ysequence)
        scores = tf.reshape(top_k[0], [-1])
        # Write the results into the TensorArrays,
        # and embed for the next step
        xstep = embed_y_to_x(ymax)
        y_chain = y_chain.write(i, ymax)
        sequences_chain = sequences_chain.write(i, ysequence)
        scores_chain = scores_chain.write(i, scores)
    # Done: stack up the results and return them
    sequences_final = sequences_chain.stack()
    y_final = y_chain.stack()
    scores_final = scores_chain.stack()
    return sequences_final, y_final, scores_final
```
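In case it helps to see the index arithmetic in isolation: the replacement for tf.unravel_index can be checked with plain Python integers (hypothetical small sizes; `split_index` is just for illustration, not part of the real code):

```python
# Reconstructing (parent beam, chosen token) from a flat top-k index,
# mirroring the // and % arithmetic in decode_beam above.
# Sizes are hypothetical: n candidates, k beams each, `tokens` vocab entries.
n, k, tokens = 2, 3, 5
row_width = k * tokens  # scores_y.shape[1] after the (n, -1) reshape

def split_index(batch, flat):
    # `flat` indexes into the flattened (k * tokens) row of candidate `batch`
    global_index = batch * row_width + flat  # the tf.range(n) offset
    parent = global_index // tokens   # which of the n*k beams to continue
    token = global_index % tokens     # which vocabulary token was chosen
    return parent, token

# Candidate 1, beam 1, token 4: flat index within the row is 1*tokens + 4 = 9,
# and the parent is global beam 1*k + 1 = 4.
assert split_index(1, 1 * tokens + 4) == (4, 4)
```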