added top k search util #232
Conversation
Looks good! Just a few comments
keras_nlp/utils/text_generation.py
Outdated
token_probability_fn,
prompt,
max_length,
k=10,
10 is fine, but how did we come up with this number?
I think this number should be related to the vocab size, so maybe let's make it a required arg?
Do you mean `vocab_size` should be another argument, or that `k` should be a required arg?
Talked on chat. Let's make `k` required (no default). No need to take in `vocab_size`; that can continue to be inferred.
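For reference, a minimal sketch of what the updated signature could look like with `k` required (parameter names follow the diffs in this thread; the exact defaults are assumptions):

def top_k_search(
    token_probability_fn,
    prompt,
    max_length,
    k,  # required, no default; vocab size is still inferred from predictions
    seed=None,
    end_token_id=None,
    pad_token_id=0,
):
    ...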
keras_nlp/utils/text_generation.py
Outdated
pad_token_id=0,
):
    """
    Text generation utility based on top k sampling.
`top-k` here and elsewhere in the docstring.
edited
input_is_1d = prompt.shape.rank == 1
if input_is_1d:
    prompt = prompt[tf.newaxis, :]
i = prompt.shape[1]
I think we are using both `shape[1]` and `shape[-1]` to read the last axis? This reads confusingly; let's choose one and be consistent, unless there is a case where these are different.
Edited to use `shape[1]` in both places. The shapes should strictly be `[batch_size, vocab_size]` for the predictions and `[batch_size, length]` for the prompt.
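As a quick illustration of why the two indexing styles agree once everything is rank 2 (a sanity-check snippet, not code from the PR):

import tensorflow as tf

prompt = tf.zeros([4, 7], dtype=tf.int64)  # [batch_size, length]
pred = tf.zeros([4, 100])                  # [batch_size, vocab_size]
# For rank-2 tensors, axis 1 is the last axis, so these always match.
assert prompt.shape[1] == prompt.shape[-1]
assert pred.shape[1] == pred.shape[-1]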
keras_nlp/utils/text_generation.py
Outdated
Args:
    token_probability_fn: a callable, which takes in input_sequence
        and output the probability distribution of the next token.
Here and elsewhere in this file, we should probably mention that this function should return the unnormalized logits and not softmax probabilities, right?
This is actually one confusing part where we should make a decision - currently there is no such enforcement on the return type. Shall we add it?
The return type should be probabilities here for this to work; if not, I would need to add a softmax over it.
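To make that contract concrete, a hedged sketch of a conforming `token_probability_fn` (the toy logits model is a hypothetical stand-in):

import tensorflow as tf

# A toy network that outputs raw logits (no activation on the head).
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(10, 16),
    tf.keras.layers.Dense(10),
])

def token_probability_fn(inputs):
    # Take the logits for the last position and normalize them, since
    # the utility currently expects probabilities rather than logits.
    logits = model(inputs)[:, -1, :]
    return tf.keras.activations.softmax(logits)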
keras_nlp/utils/text_generation.py
Outdated
# If k is greater than the vocabulary size, use the entire vocabulary.
k = min(k, pred.shape[-1])
# Filter out top k tokens.
sorted_pred, sorted_indices = tf.math.top_k(pred, k=k, sorted=True)
Why do we need sorted here? tf.random.categorical doesn't need a sort order. We just need to make sure we gather the correct indices from the top_k call, which you are already doing.
Yep, it doesn't need `sorted`; edited.
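A sketch of the sampling step with `sorted` dropped, wrapped as a standalone helper for illustration (variable names roughly follow the diff):

import tensorflow as tf

def sample_top_k(pred, k):
    # `pred` is a [batch_size, vocab_size] tensor of probabilities.
    # If k is greater than the vocabulary size, use the entire vocabulary.
    k = min(k, pred.shape[-1])
    # No sort order needed: tf.random.categorical ignores ordering; we
    # only have to gather the correct indices from the top_k call.
    top_k_pred, top_k_indices = tf.math.top_k(pred, k=k, sorted=False)
    # Sample a position within the top-k set, then map it back to ids.
    sampled = tf.random.categorical(tf.math.log(top_k_pred), 1)
    return tf.gather(top_k_indices, sampled, batch_dims=1)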
Mainly looks good!
outputs = top_k_search(
    token_probability_fn, inputs, k=2, max_length=max_length, seed=42
)
# Random sampling result with seed 42
top-k search result
Edited
rtol=0.2,
)

def test_assert_top_k_generation_is_correct(self):
The test asserts that only top-k tokens can appear, but the name does not suggest so. Let's rename it to something like `test_only_choose_from_top_k_tokens`.
edited
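For illustration, a sketch of what the renamed test could look like (the concentrated distribution and the import path are assumptions, not the PR's exact test):

import tensorflow as tf

from keras_nlp.utils.text_generation import top_k_search


class TopKSearchTest(tf.test.TestCase):
    def test_only_choose_from_top_k_tokens(self):
        # All probability mass sits on token ids 0 and 1, so with k=2
        # no other id should ever appear in the generated output.
        def token_probability_fn(inputs):
            prob = tf.constant([[0.5, 0.5, 0.0, 0.0]])
            return tf.repeat(prob, inputs.shape[0], axis=0)

        inputs = tf.constant([[0], [1]])
        outputs = top_k_search(
            token_probability_fn, inputs, k=2, max_length=5, seed=42
        )
        self.assertTrue(tf.reduce_all(outputs <= 1))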
Couple comments!
keras_nlp/utils/text_generation.py
Outdated
i = prompt.shape[1]
while i < max_length:
    # If the prompt has reached our desired length, exit while loop.
-   pred = token_probability_fn(prompt)
+   pred = token_fn(prompt)
Do we need to wrap in another function? This whole change is a little hard to read. I would find it simpler to just add a block here:
pred = token_probability_fn(prompt)
if from_logits:
    pred = tf.keras.activations.softmax(pred)
Makes sense, edited.
keras_nlp/utils/text_generation.py
Outdated
def token_probability_fn(inputs):
    return model(inputs)[:, -1, :]

prompt = tf.random.uniform(shape=[5, 5], maxval=VOCAB_SIZE, dtype=tf.int64)
This is kind of a weird prompt to show. Who would want to generate sequences after 5 totally random tokens? Maybe we should do something like:
BATCH_SIZE = 8
VOCAB_SIZE = 10
FEATURE_SIZE = 16
START_ID = 1
END_ID = 2
...
prompt = tf.fill((BATCH_SIZE, 1), START_ID)
keras_nlp.utils.top_k_search(
    token_probability_fn,
    prompt,
    k=10,
    max_length=10,
    end_token_id=END_ID)
We may want to update other examples if they have the same problem.
Also, to be clear, the ellipsis is me being lazy, not a suggestion that we put that in the docstring :)
edited docstring
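Putting that suggestion together, a runnable version of the example (the tiny softmax model is a hypothetical stand-in for whatever the docstring actually uses; the constants follow the suggestion above):

import tensorflow as tf
import keras_nlp

BATCH_SIZE = 8
VOCAB_SIZE = 10
FEATURE_SIZE = 16
START_ID = 1
END_ID = 2

# A toy model that maps token ids to next-token probabilities.
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(VOCAB_SIZE, FEATURE_SIZE),
    tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
])

def token_probability_fn(inputs):
    return model(inputs)[:, -1, :]

# Start every sequence from a start token instead of random ids.
prompt = tf.fill((BATCH_SIZE, 1), START_ID)
keras_nlp.utils.top_k_search(
    token_probability_fn,
    prompt,
    k=10,
    max_length=10,
    end_token_id=END_ID,
)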
Thanks! Last few nits!
keras_nlp/utils/text_generation.py
Outdated
# Print the generated sequence (token ids).
keras_nlp.utils.greedy_search(
    token_probability_fn,
    prompt,
    max_length=10,
-   end_token_id=0,)
+   end_token_id=END_ID
nit: trailing comma
keras_nlp/utils/text_generation.py
Outdated
token_probability_fn,
prompt,
max_length=10,
-   end_token_id=0,)
+   end_token_id=END_ID
trailing comma
keras_nlp/utils/text_generation.py
Outdated
prompt,
max_length=10,
k=4,
end_token_id=END_ID
trailing comma
keras_nlp/utils/text_generation.py
Outdated
prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
    append generated tokens.
max_length: int. The max length of generated text.
from_logits: bool. Indicates whether `token_probability_fn` outputs
You document this but didn't actually add it? Should we just leave it off, since for greedy search it is a no-op?
Yeah, I added it at first and realized greedy search isn't affected; I will remove it from the docstring.
Follow-up from random search.