Add beam search decoding util #237
Conversation
keras_nlp/utils/text_generation.py
Outdated
    prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
        append generated tokens. The initial beam for beam search.
    max_length: int. The max length of generated text.
    beam_width: int. The number of beams that should be kept at each
Why "width"? What about just `num_beams`?
If inputs are batched, inputs should be `tf.RaggedTensor`s with shape
`[batch_size, None]` and will be packed and converted to a dense tensor with
shape `[batch_size, sequence_length]`.
If inputs are batched, inputs should either be `tf.RaggedTensor`s with shape
This isn't part of this PR I think!
    respective beams, before beginning the next iteration.

    Args:
        token_probability_fn: a callable, which takes in input_sequence
We should document the input shape and output shape expected by this function, as it is a non-standard batch size.
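For illustration, a minimal sketch of the shape contract being requested here (the vocabulary size is invented, and numpy stands in for a real model; this is not the keras-nlp implementation):

```python
import numpy as np

VOCAB_SIZE = 8  # hypothetical vocabulary size

def token_probability_fn(inputs):
    # Takes token ids of shape [batch_size * beam_size, length] and
    # returns next-token probabilities of shape
    # [batch_size * beam_size, VOCAB_SIZE]. A real implementation would
    # call a language model here; this stand-in returns a uniform
    # distribution purely to demonstrate the shapes.
    return np.full((inputs.shape[0], VOCAB_SIZE), 1.0 / VOCAB_SIZE)

# With batch_size=2 and beam_size=3, the flattened input has 6 rows.
probs = token_probability_fn(np.zeros((2 * 3, 4), dtype=np.int32))
```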
keras_nlp/utils/text_generation.py
Outdated
""" | ||
Text generation utility based on beam search algorithm. | ||
|
||
At each time-step, beam search keeps the top `beam_width` beams (sequences), |
Let's include the information that the top `num_beams` beams means the `num_beams` sequences of highest probability.
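As a worked illustration of that pruning step (the probabilities below are made up, and numpy is used as a framework-agnostic stand-in):

```python
import numpy as np

num_beams, vocab_size = 2, 4

# Cumulative log-probability of each live beam so far.
beam_log_probs = np.log(np.array([0.6, 0.4]))

# Next-token probabilities predicted for each beam (one row per beam).
next_probs = np.array([
    [0.5, 0.2, 0.2, 0.1],
    [0.4, 0.3, 0.2, 0.1],
])

# Score every (beam, token) continuation, flatten, and keep the top
# `num_beams` candidates: the sequences of highest total probability.
total = (beam_log_probs[:, None] + np.log(next_probs)).reshape(-1)
top = np.argsort(total)[::-1][:num_beams]
parent_beam, token = top // vocab_size, top % vocab_size
# parent_beam -> [0, 1], token -> [0, 0]: the two best continuations.
```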
keras_nlp/utils/text_generation.py
Outdated
i = length
while i < max_length:
    beam_size = beams.shape[1]
    reshaped_beam = tf.reshape(beams, [batch_size * beam_size, i])
Per our offline discussion, let's retain the loop over beams.
Ah, I was saying we could either document this expectation or retain the loop. Either way, I think we need to fix something here: we currently document the fn input as shape `[batch_size, length]`, but it's actually `[batch_size * beam_size, length]`, right? Fine with either bringing back the loop or just documenting the current behavior correctly.
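To make the shape question concrete, here is a minimal sketch of the flatten/unflatten round trip being discussed (all dimensions are invented, and a uniform distribution stands in for the real probability fn):

```python
import numpy as np

batch_size, beam_size, length, vocab_size = 2, 3, 5, 7
beams = np.zeros((batch_size, beam_size, length), dtype=np.int32)

# Flatten so the probability fn sees [batch_size * beam_size, length],
# not the documented [batch_size, length].
flat_beams = beams.reshape(batch_size * beam_size, length)

# Stand-in for token_probability_fn: uniform next-token distribution.
probs = np.full((flat_beams.shape[0], vocab_size), 1.0 / vocab_size)

# Un-flatten so each beam's scores sit under its batch element again.
probs = probs.reshape(batch_size, beam_size, vocab_size)
```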
keras_nlp/utils/__init__.py
Outdated
from keras_nlp.utils.text_generation import greedy_search
from keras_nlp.utils.text_generation import random_search
from keras_nlp.utils.text_generation import top_k_search
from keras_nlp.utils.text_generation import top_p_search
This change already got in; make sure to rebase.
This looks good to me! Thank you!
Hi, I ran the "English-to-Spanish translation with KerasNLP" Colab notebook by Abheesht Sharma and changed the decoding util from "top_p" to "beam" and got all the dimension mismatch errors alluded to above. Is there a tidy example of how the token_probability_fn and beam_search decoder should be set up in this case? |
No description provided.