Stop on multiple end tokens #1518
Conversation
I think we should think first of the overall API experience we want here. What about something like this?

# Default. Stop at gemma_lm.preprocessor.tokenizer.end_token_id, or error if
# self.preprocessor is None.
gemma_lm.generate(
    prompt,
    max_length=64,
    stop_token_ids="auto",
)

# Don't stop till max_length!
gemma_lm.generate(
    prompt,
    max_length=64,
    stop_token_ids=None,
)

# Custom. Provide multiple stop tokens; in this case we also stop on the
# literal word "stop".
gemma_lm.generate(
    prompt,
    max_length=64,
    stop_token_ids=[tokenizer.end_token_id, tokenizer.token_to_id("stop")],
)

I don't really like setting this on the tokenizer. Tokenizer special token ids are not generally set by a user. Every …
If we go with the above proposal, we should update the sampler API to also take in stop_token_ids. We can do this with Gemma at first, but we should eventually update all models to have a consistent API surface. We also might want to refactor a helper like this:

from keras import ops  # Keras 3 ops; adjust to the project's ops module if needed

def any_equal(inputs, values):
    """Return a mask that is True anywhere `inputs` has a value in `values`."""
    output = ops.equal(inputs, values[0])
    for value in values[1:]:
        # Compare against each remaining stop value and accumulate with a logical OR.
        output = ops.logical_or(output, ops.equal(inputs, value))
    return output
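For illustration, a minimal usage sketch of the helper above; the sample token ids and the expected mask are made up for this example, not taken from the PR:

import numpy as np

token_ids = np.array([[1, 5, 9],
                      [2, 3, 1]])
# Mark every position holding token id 1 or 9.
mask = any_equal(token_ids, (1, 9))
# mask -> [[True, False, True],
#          [False, False, True]]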
We're currently defaulting to a mix: if a preprocessor is specified we use "auto", otherwise we go with None. Should we error out if no preprocessor is specified, or just switch to None?
Discussed offline: we're going to do a full refactor and go with the saner choice of erroring if "auto" is specified with no preprocessor. The API will be more consistent for multi-token requirements.
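As a rough sketch of the agreed behavior (hypothetical code, not the actual keras-nlp implementation; the helper name normalize_stop_token_ids is made up):

def normalize_stop_token_ids(stop_token_ids, preprocessor):
    # "auto" needs a preprocessor to look up the tokenizer's end token.
    if stop_token_ids == "auto":
        if preprocessor is None:
            raise ValueError(
                "`stop_token_ids='auto'` requires a preprocessor. Pass explicit "
                "token ids, or `None` to always generate to `max_length`."
            )
        return (preprocessor.tokenizer.end_token_id,)
    # `None` disables early stopping entirely.
    if stop_token_ids is None:
        return ()
    return tuple(stop_token_ids)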
@mattdangerw this works for Gemma; if the overall method lgty, we can replicate it in the other models. Given that we're switching to …
Looks good generally! Dropped a few comments. Can probably start replicating to other models.
@@ -81,7 +81,7 @@ def wrapped_generate_function(
     import jax

     @jax.jit
-    def compiled_generate_function(inputs, end_token_id, state):
+    def compiled_generate_function(inputs, stop_token_ids, state):
Fairly technical JAX question, but what happens if we pass lists of differing lengths here? Does JAX automatically recompile? We could possibly pass static_argnames here, but I'm not sure of the right approach.
Every element in the list will be treated as an individual JAX variable and will trigger recompilation. I think static_argnames is the right call here; just checking that it works well with lists.
We will have to force this to be a tuple to guarantee that the stop tokens are immutable (static arguments must be hashable, and lists are not).
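A minimal sketch of the approach being discussed, assuming stop_token_ids is passed as a hashable tuple so jax.jit can treat it as static; the function body here is illustrative, not the PR's actual code:

from functools import partial

import jax
import jax.numpy as jnp

# Static arguments are baked into the compiled function: a new tuple value
# triggers a recompile, and unhashable values (like lists) raise an error.
@partial(jax.jit, static_argnames="stop_token_ids")
def compiled_generate_function(inputs, stop_token_ids, state):
    mask = jnp.zeros(inputs.shape, dtype=bool)
    for token_id in stop_token_ids:  # unrolled at trace time
        mask = mask | (inputs == token_id)
    return mask

inputs = jnp.array([[1, 5, 9]])
compiled_generate_function(inputs, (1, 9), None)  # tuple: hashable, compiles once
# compiled_generate_function(inputs, [1, 9], None)  # list: unhashable -> error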
Looks good to me! Few last comments.
* Add multitoken stopping
* Update gemma_causal_lm.py
* Add further multitoken support
* Formatting
* Revert tokenizer changes
* Move multi token stop to generative task
* None check
* None check
* Error message
* Add stop_token_ids
* Util testing
* Fix sampler tests
* All multitoken stop to all models
* Sampler multi token
* Formatting
* Tuple required
* Tuple docstring
* Pytorch GPU fix
* Numpy fix