Greedy text generation util #154
Conversation
Thanks! Left some initial comments, mainly on the design front.
keras_nlp/utils/text_generation.py
Outdated
        next_token, end_token_received, end_token_id
    )
    # Append the next token to current sequence.
    input_ids = tf.concat([input_ids, next_token[:, tf.newaxis]], axis=-1)
How about testing with XLA on GPU at least, if you cannot test on TPU? :)
That's a good point.
- I am not sure how to add GitHub GPU/TPU tests. It also reminds me that we might want to test distributed training (not for this utility, but for other modules) in the future. We will check more on it.
- IIUC, XLA on GPU should be turned on manually? It requires @tf.function(jit_compile=True). For this specific utility, we probably don't need XLA testing?
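For reference, a minimal sketch of what opting in to XLA looks like (the next_token_fn below is a hypothetical stand-in, not part of this PR):

import tensorflow as tf

# Hypothetical next-token function; XLA on GPU is opt-in via jit_compile=True.
@tf.function(jit_compile=True)
def next_token_fn(prompt):
    logits = tf.one_hot(prompt[:, -1], depth=100)  # dummy "model"
    return tf.argmax(logits, axis=-1)

print(next_token_fn(tf.constant([[1, 2, 3], [4, 5, 6]])))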
I am not sure if you realize that the tf.concat will yield a dynamic shape, or whether you know the XLA requirements.
Interesting - so tf.concat is something that cannot work on TPU, or on GPU/CPU under XLA?
Our current plan is not to wrap this utility with tf.function(), but users can choose to wrap next_token_fn with tf.function(), as next_token_fn takes most of the computation time.
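Roughly, the split being described looks like this (a sketch only; names are placeholders and the real utility also handles end tokens, etc.):

import tensorflow as tf

# Users can opt in to compiling the per-step model call.
@tf.function(input_signature=[tf.TensorSpec([None, None], tf.int32)])
def next_token_fn(prompt):
    # Placeholder "model": always predict token id 7.
    return tf.fill([tf.shape(prompt)[0]], tf.constant(7, tf.int32))

def greedy_decode(prompt, max_length):
    # The decoding loop itself stays a plain Python while loop; tf.concat
    # grows the sequence, so its shape is dynamic across iterations.
    while prompt.shape[1] < max_length:
        next_token = next_token_fn(prompt)
        prompt = tf.concat([prompt, next_token[:, tf.newaxis]], axis=-1)
    return prompt

print(greedy_decode(tf.constant([[1, 2]], dtype=tf.int32), max_length=5))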
Won't the performance be bad, or do you only target this as a demo util function? Is that the reason you use a Python while loop rather than tf.while_loop for the decoding loop?
Got it, yes, this util is mainly useful for a demo such as a colab guide, so performance is not the focus now.
I am actually curious, how does model garden handle the token concatenation? Are you using a fixed-size tensor and changing the value at each iteration? I am not sure how much performance difference it would make, since the bottleneck was mostly the model call when I benchmarked on colab.
We need to allocate a buffer of the max decode sequence length and use in-place updates.
The padded_decode path is dedicated to XLA: https://github.com/tensorflow/models/blob/master/official/nlp/modeling/ops/beam_search.py#L109
BTW, I have read the fairseq code before, and I believe they use something similar for GPU performance optimization as well.
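For illustration only, a simplified version of that buffer-based approach might look like the following (this is not the model garden implementation, just a sketch of the idea):

import tensorflow as tf

def padded_greedy_decode(prompt, next_token_fn, max_length):
    # Pre-allocate a [batch, max_length] buffer so every shape stays static,
    # then overwrite one column per step instead of concatenating.
    batch_size, prompt_length = prompt.shape
    outputs = tf.concat(
        [prompt, tf.zeros([batch_size, max_length - prompt_length], prompt.dtype)],
        axis=-1,
    )
    for index in range(prompt_length, max_length):
        next_token = next_token_fn(outputs[:, :index])
        # In-place style update of column `index` via a scatter on the transpose.
        outputs = tf.transpose(
            tf.tensor_scatter_nd_update(
                tf.transpose(outputs), [[index]], next_token[tf.newaxis, :]
            )
        )
    return outputs

# Dummy next-token function that always predicts token id 9.
dummy_fn = lambda seq: tf.fill([tf.shape(seq)[0]], tf.constant(9, tf.int32))
print(padded_greedy_decode(tf.constant([[1, 2]], tf.int32), dummy_fn, max_length=5))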
Thanks Hongkun! I will open an issue to track the refactoring.
Yeah, this isn't just not XLA-compilable; it is not tf.function-compilable at all right now.
We definitely should look at making it so, both for usability (using this inside a Keras model) and for performance in any sort of bulk-inference job.
For now, this has been more about getting the API signature how we want it. I wouldn't say this should always be only a demo util function; that's just where we are at today.
from keras_nlp.utils.text_generation import generate_text_greedy


class TextGenerationTest(tf.test.TestCase):
interesting, do you need tf.test.main()?
Seems not - our tests are based on pytest, which automatically collects the test cases.
Left some more comments. But the big one: I couldn't figure out a way to actually do text generation with a plain-text seed and our tokenizers. Take a look:
I think if we want to support batches of input, we need to support batches of raggeds; otherwise I'm at a loss as to how this could be used for batched text generation where the seed input is plain text.
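To make that concrete, a batch of plain-text seeds of different lengths naturally tokenizes into a ragged batch (whitespace splitting and hashing below are only stand-ins for a real tokenizer):

import tensorflow as tf

seeds = tf.constant(["the quick brown fox", "hello"])
tokens = tf.strings.split(seeds)  # RaggedTensor: rows have different lengths.
prompt = tf.ragged.map_flat_values(
    tf.strings.to_hash_bucket_fast, tokens, num_buckets=1000
)
print(prompt)               # ragged batch of token ids, row lengths 4 and 1
print(prompt.to_tensor(0))  # dense, padded view, if padding is acceptable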
keras_nlp/utils/text_generation.py
Outdated
    ```

    """
    if 0 in input_ids.shape:
Looked at this more, and I really think we should not keep this check. It seems fully possible to support an empty input tensor input_ids=[] or a batch shape with input_ids=tf.zeros([bs, 0]).
The former would be really useful in guides when doing something really simple. There will not always be start tokens as a convention; see the main TF text generation guide as an example: https://www.tensorflow.org/text/tutorials/text_generation.
We should also call tf.convert_to_tensor on non-tensor input, so we can do things like input_ids=[] or input_ids=[start_id].
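A minimal sketch of that input handling, assuming the argument name from this discussion rather than the final API:

import tensorflow as tf

def normalize_prompt(input_ids):
    # Accept plain Python lists such as [] or [start_id] as well as tensors.
    input_ids = tf.convert_to_tensor(input_ids, dtype=tf.int32)
    # An empty prompt is allowed: generation simply starts from nothing, as in
    # the TF text generation tutorial, which has no start-token convention.
    return input_ids

print(normalize_prompt([]))         # shape (0,)
print(normalize_prompt([1, 2, 3]))  # shape (3,)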
One thing I am confused about: if there is no prompt at all, how shall we generate the next token? Maybe add an extra argument start_token?
keras_nlp/utils/text_generation.py
Outdated
Args:
    token_probability_fn: a callable, which takes in input_sequence
        and output the probability distribution of the next token.
    input_ids: a list, the initial tokens to append generated tokens.
prompt?
sg!
        sequence. If None, every sequence is generated up to `max_length`.

Returns:
    A 1D int Tensor, or 2D int RaggedTensor representing the generated
Most likely it should operate on a single sequence, never a batch. The user could map it to a batch.
Processing at the single-sequence level would make the code simpler, but execution would slow down: there are lots of model calls inside this utility, so without parallelism it could take much longer.
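For comparison, mapping a single-sequence routine over a batch would look roughly like this; each row then drives its own chain of model calls, which is the serialization concern above (generate_one is a hypothetical single-sequence decoder):

import tensorflow as tf

def generate_one(prompt):
    # Hypothetical single-sequence decoder: append two fixed tokens.
    return tf.concat([prompt, tf.constant([7, 8], prompt.dtype)], axis=0)

batch = tf.constant([[1, 2, 3], [4, 5, 6]])
# tf.map_fn runs generate_one row by row; with a real model this means one
# forward pass per sequence per step instead of a single batched pass.
outputs = tf.map_fn(generate_one, batch, fn_output_signature=tf.int32)
print(outputs)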
keras_nlp/utils/text_generation.py
Outdated
    return filtered_next_token, end_token_received


def generate_text_greedy(
This operates in integer token space, not text space. It should not be named "generate text" (also "generate text greedy" is not a sentence, unlike "greedy text generation" or "generate text greedily").
Maybe this should just be called greedy_search (vs beam_search).
sg!
        and output the probability distribution of the next token.
    input_ids: a list, the initial tokens to append generated tokens.
    max_length: int. The max length of generated text.
    end_token_id: int, defaults to None. The token marking the end of the
end_token?
We haven't really formalized this (I should add it to our API design guide). But I think if this operates strictly on ints, end_token_id is the correct name, so people don't pass "<eos>" by mistake.
If we switched this to strings, end_token would be the correct name. But it sounds like we will keep this in int space for now, and just rename to beam_search and greedy_search, which I think is a good call.
Colab example for text generation: https://colab.research.google.com/gist/chenmoneygithub/002633be87e440248870b43089f47530/kerasnlp-text-generation-model-util.ipynb (dataset: mini Shakespeare)
Colab example for machine translation (dataset: http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip). The accuracy is bad, but the main goal of this colab is to show how to use the text generation util.
Thanks @chenmoneygithub!
I'll play around with compiling this into functions and models. Curious where we currently stand there.
Left a few more comments. This is looking closer. We can simplify the end-of-sequence token logic, and we need to do something user-friendly in the compiled-function case.
keras_nlp/utils/text_generation.py
Outdated
        prompt = tf.concat([prompt, next_token[:, tf.newaxis]], axis=-1)
        return get_subsequent_tokens(prompt, end_token_id_received)

    generated_sequence = get_subsequent_tokens(prompt, end_token_id_received)
I think the end token logic could be much more readable (and probably more efficient) as a post-process. tf.sequence_mask could help. Something like:
if end_token_id is not None:
    # Find index of first end_token_id.
    end_indices = tf.math.argmax(outputs == end_token_id, -1)
    # Use max_length if none found.
    end_indices = tf.where(end_indices == 0, max_length, end_indices)
    # Build a mask including end_token and replace overflow with pad_token_id.
    valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length)
    outputs = tf.where(valid_indices, outputs, pad_token_id)
good point! done
Thanks! Just minor edits now.
keras_nlp/utils/text_generation.py
Outdated
prompt = tf.random.uniform(shape=[5, 5], maxval=VOCAB_SIZE, dtype=tf.int64)

# Print the generated sequence (token ids).
keras_nlp.greedy_search(
keras_nlp.utils.greedy_search right?
good catch!
Can prompt be a prefix list of tokens? Context: prefix-LM.
Approving - last few comments.
    end_token_id=2,
    pad_token_id=0,
)
self.assertAllEqual(outputs[0, 2:], tf.repeat(3, max_length - 2))
Just test the whole output here; that will be much more readable.
sg!
#108