
Add beam search decoding util #237

Merged: 16 commits into keras-team:master on Jun 30, 2022

Conversation

jessechancy (Contributor)

No description provided.

jessechancy changed the title from "beam search" to "Add beam search decoding util" on Jun 28, 2022
prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
    append generated tokens to. The initial beam for beam search.
max_length: int. The max length of generated text.
beam_width: int. The number of beams that should be kept at each
    time step.
Member:

why width? what about just num_beams?
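
For reference, a minimal sketch of how these arguments fit together. Assumptions here: the util is exposed as `beam_search` alongside the other utils in this module, and the argument ends up with the `num_beams` name suggested above; the uniform `token_probability_fn` is just a stand-in.

```python
import tensorflow as tf

from keras_nlp.utils.text_generation import beam_search

VOCAB_SIZE = 10

def token_probability_fn(inputs):
    # Stand-in model: uniform next-token distribution for every sequence.
    return tf.ones((tf.shape(inputs)[0], VOCAB_SIZE)) / VOCAB_SIZE

prompt = tf.fill((8, 1), 1)  # 2D prompt: [batch_size, initial_length]
output = beam_search(
    token_probability_fn,
    prompt,
    max_length=10,
    num_beams=3,
)
```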

If inputs are batched, inputs should be `tf.RaggedTensor`s with shape
`[batch_size, None]` and will be packed and converted to a dense tensor with
shape `[batch_size, sequence_length]`.
If inputs are batched, inputs should either be `tf.RaggedTensor`s with shape
Member:

This isn't part of this PR I think!

respective beams, before beginning the next iteration.

Args:
    token_probability_fn: a callable, which takes in input_sequence
        and outputs the probability distribution of the next token.
Member:

We should document the input shape and output shape expected by this function, as it is a non-standard batch size.
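
For instance, the contract could be spelled out along these lines (a sketch only; the exact flattened shape is discussed further down in this thread):

```python
def token_probability_fn(inputs):
    """Stand-in documenting the expected shapes.

    Args:
        inputs: dense int tensor of shape `[batch_size * beam_width, length]`.
            Beam search flattens all beams into the batch dimension before
            each call, hence the non-standard batch size.

    Returns:
        A tensor of shape `[batch_size * beam_width, vocab_size]` holding the
        probability distribution over the next token for each sequence.
    """
```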

"""
Text generation utility based on beam search algorithm.

At each time-step, beam search keeps the top `beam_width` beams (sequences),
Contributor:

Let's include the information that the top `num_beams` beams means the `num_beams` sequences with the highest probability.
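
For illustration, one selection step might look like this toy sketch (not code from this PR): the surviving beams are exactly the `num_beams` candidate extensions with the highest cumulative probability.

```python
import tensorflow as tf

num_beams, vocab_size = 2, 3

beams = tf.constant([[5, 1], [5, 2]])        # [num_beams, length]
log_probs = tf.constant([-0.5, -1.2])        # cumulative log-prob per beam
next_probs = tf.constant([[0.6, 0.3, 0.1],
                          [0.2, 0.5, 0.3]])  # [num_beams, vocab_size]

# Score every (beam, next-token) candidate, then flatten.
scores = tf.reshape(log_probs[:, tf.newaxis] + tf.math.log(next_probs), [-1])

# Keep only the num_beams highest-probability candidates.
_, top_ids = tf.math.top_k(scores, k=num_beams)
beam_ids = top_ids // vocab_size   # which beam each winner extends
token_ids = top_ids % vocab_size   # which token each winner appends
beams = tf.concat(
    [tf.gather(beams, beam_ids), token_ids[:, tf.newaxis]], axis=-1
)  # [num_beams, length + 1]: the highest-probability sequences so far
```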

i = length
while i < max_length:
    beam_size = beams.shape[1]
    reshaped_beam = tf.reshape(beams, [batch_size * beam_size, i])
Contributor:

Per our offline discussion, let's retain the loop over beams.

Member:

Ah I was saying we could either document this expectation or retain the loop.

But either way I think we need to fix something here. We currently document the fn input as shape [batch_size, length] but it's currently [batch_size * beam_size, length] right?

Fine with either bringing back the loop or just documenting the current behavior correctly.
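
Concretely, the behavior in question (a toy sketch, with a stub standing in for the user's fn):

```python
import tensorflow as tf

batch_size, beam_size, length, vocab_size = 2, 3, 4, 7

def token_probability_fn(inputs):
    # Stub for the user's model: uniform next-token distribution.
    return tf.ones((tf.shape(inputs)[0], vocab_size)) / vocab_size

beams = tf.zeros((batch_size, beam_size, length), dtype=tf.int32)

# The reshape above flattens beams into the batch dimension, so the fn
# actually receives [batch_size * beam_size, length], not the documented
# [batch_size, length].
flat_beams = tf.reshape(beams, [batch_size * beam_size, length])
probs = token_probability_fn(flat_beams)  # [batch_size * beam_size, vocab_size]

# Un-flatten afterwards to rank candidates per original example.
probs = tf.reshape(probs, [batch_size, beam_size, vocab_size])
```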

from keras_nlp.utils.text_generation import greedy_search
from keras_nlp.utils.text_generation import random_search
from keras_nlp.utils.text_generation import top_k_search
from keras_nlp.utils.text_generation import top_p_search
Member:

this change already got in. make sure to rebase

mattdangerw (Member) left a comment:

This looks good to me! Thank you!

mattdangerw merged commit f9abc8f into keras-team:master on Jun 30, 2022
kevinerazoBSD:

Hi, I ran the "English-to-Spanish translation with KerasNLP" Colab notebook by Abheesht Sharma and changed the decoding util from "top_p" to "beam" and got all the dimension mismatch errors alluded to above. Is there a tidy example of how the token_probability_fn and beam_search decoder should be set up in this case?
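
One way to wire this up, given the shape discussion above, is sketched below. Everything here is a hypothetical stand-in for the notebook's objects (`transformer`, `encoder_input_tokens`, the token ids and sizes). The one load-bearing point from this thread is that `token_probability_fn` receives all beams flattened into the batch dimension, so any tensor it closes over, such as the encoder inputs, has to be repeated `num_beams` times to match; the dimension mismatch appears when the encoder inputs keep their original `batch_size` while the decoder inputs arrive with `batch_size * num_beams` rows.

```python
import tensorflow as tf

from keras_nlp.utils.text_generation import beam_search

num_beams = 5
batch_size, source_length, vocab_size = 4, 20, 15000

# Stand-ins for the notebook's objects (hypothetical names and shapes):
encoder_input_tokens = tf.zeros((batch_size, source_length), dtype=tf.int32)

def transformer(inputs):
    # Stub for the trained seq2seq model: returns per-position logits.
    encoder_tokens, decoder_tokens = inputs
    return tf.random.normal(
        (tf.shape(decoder_tokens)[0], tf.shape(decoder_tokens)[1], vocab_size)
    )

# Beam search calls the fn with decoder inputs of shape
# [batch_size * num_beams, length], so repeat the encoder inputs per beam.
tiled_encoder_inputs = tf.repeat(encoder_input_tokens, num_beams, axis=0)

def token_probability_fn(decoder_input_tokens):
    logits = transformer([tiled_encoder_inputs, decoder_input_tokens])
    # Next-token probabilities only: [batch_size * num_beams, vocab_size].
    return tf.nn.softmax(logits[:, -1, :])

prompt = tf.fill((batch_size, 1), 2)  # e.g. the "[start]" token id
translated = beam_search(
    token_probability_fn,
    prompt,
    max_length=40,
    num_beams=num_beams,
)
```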
