
Stop on multiple end tokens #1518

Merged 19 commits into keras-team:master on Mar 27, 2024

Conversation

grasskin (Member)

No description provided.

@grasskin grasskin marked this pull request as draft March 21, 2024 16:39
mattdangerw (Member) commented Mar 22, 2024

I think we should first think about the overall API experience we want here. What about something like this?

# Default. Stop at gemma_lm.preprocessor.tokenizer.end_token_id, or error if
# self.preprocessor is None.
gemma_lm.generate(
    prompt,
    max_length=64,
    stop_token_ids="auto",
)
# Don't stop till max_length!
gemma_lm.generate(
    prompt,
    max_length=64,
    stop_token_ids=None,
)
# Custom. Provide multiple stop tokens; in this case we also stop on the literal word "stop".
gemma_lm.generate(
    prompt,
    max_length=64,
    stop_token_ids=[tokenizer.end_token_id, tokenizer.token_to_id("stop")],
)

I don't really like setting this on the tokenizer. Tokenizer special token ids are not generally set by a user, and every tokenizer.xx_token_id is just a single integer right now. Preprocessing can also be detached from the task, in which case the CausalLM does not even have a tokenizer to query.

mattdangerw (Member)

If we go with the above proposal, we should update the sampler API to also take stop_token_ids, but it does not need the "auto" value.

We can do this with Gemma at first, but we should eventually update all models to have a consistent API surface.

We might also want to refactor a helper into tensor_utils.py. It would help readability:

from keras import ops  # assuming the Keras ops namespace is in scope


def any_equal(inputs, values):
    """Return a mask that is True anywhere `inputs` has a value in `values`."""
    output = ops.equal(inputs, values[0])
    for value in values[1:]:
        output = ops.logical_or(output, ops.equal(inputs, value))
    return output
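
For context, a minimal sketch of how such a helper might drive a sampler's early-stopping check. The `all_sequences_finished` name and the `prompt_mask` semantics (True for positions belonging to the user's prompt) are illustrative assumptions, not the merged API:

from keras import ops


def all_sequences_finished(token_ids, prompt_mask, stop_token_ids):
    # Flag every generated position holding a stop token, using the
    # `any_equal` helper above; prompt positions are excluded so a stop
    # token typed by the user does not end generation immediately.
    end_mask = ops.logical_and(
        any_equal(token_ids, stop_token_ids),
        ops.logical_not(prompt_mask),
    )
    # Stop once every sequence in the batch contains a stop token.
    return ops.all(ops.any(end_mask, axis=-1))


token_ids = ops.convert_to_tensor([[5, 7, 2, 0], [5, 9, 9, 2]])
prompt_mask = ops.convert_to_tensor([[True, False, False, False]] * 2)
print(all_sequences_finished(token_ids, prompt_mask, (2,)))  # True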

grasskin (Member, Author)

We're currently defaulting to "auto" when a preprocessor is specified, and to None otherwise. Should we error out if no preprocessor is specified, or just switch to None?

grasskin (Member, Author)

Discussed offline: we're going to do a full refactor and go with the saner choice of erroring if "auto" is specified with no preprocessor. The API will be more consistent for multi-token requirements.
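
A minimal sketch of what that resolution could look like; the `resolve_stop_token_ids` helper and its error message are illustrative, not the merged implementation:

def resolve_stop_token_ids(stop_token_ids, preprocessor):
    """Normalize the `stop_token_ids` argument to a concrete tuple of ids."""
    if stop_token_ids == "auto":
        if preprocessor is None:
            # The decision above: "auto" with a detached preprocessor errors.
            raise ValueError(
                "`stop_token_ids='auto'` requires a preprocessor. Pass explicit "
                "token ids, or `stop_token_ids=None` to generate to max_length."
            )
        return (preprocessor.tokenizer.end_token_id,)
    if stop_token_ids is None:
        return ()  # Never stop early.
    # Force a tuple so the value is immutable (and hashable for jit caching).
    return tuple(stop_token_ids)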

@grasskin grasskin requested a review from mattdangerw March 25, 2024 18:15
grasskin (Member, Author)

@mattdangerw this works for Gemma; if the overall approach looks good to you, we can replicate it in the other models. Given that we're switching to stop_token_ids, we now explicitly require an iterable instead of a single int; the sampling tests are already fixed.

mattdangerw (Member) left a comment

Looks good generally! Dropped a few comments. We can probably start replicating to the other models.

Review threads on:
  keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py (outdated, resolved)
  keras_nlp/models/generative_task.py (outdated, resolved)
  keras_nlp/utils/tensor_utils.py (resolved)
  keras_nlp/utils/tensor_utils.py (outdated, resolved)
@@ -81,7 +81,7 @@ def wrapped_generate_function(
         import jax

         @jax.jit
-        def compiled_generate_function(inputs, end_token_id, state):
+        def compiled_generate_function(inputs, stop_token_ids, state):
mattdangerw (Member)
Fairly technical JAX question, but what happens if we pass lists of differing lengths here? Does JAX automatically recompile? We could possibly pass static_argnames here, but I'm not sure that's the right approach.

grasskin (Member, Author)
Every element in the list will be treated as an individual JAX variable and will trigger recompilation. I think static_argnames is the right call here; just checking that it works well with lists.

grasskin (Member, Author)
We will have to force this to be a tuple to guarantee that the stop token ids are immutable (static arguments to jax.jit must be hashable, which a list is not).
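
For illustration, a standalone sketch of the behavior under discussion; this is plain JAX, not the actual keras-nlp generate function:

from functools import partial

import jax
import jax.numpy as jnp


# Marking `stop_token_ids` static makes JAX specialize the compiled function
# on its value instead of tracing each element as an array argument. Static
# arguments must be hashable, hence the tuple requirement noted above.
@partial(jax.jit, static_argnames=["stop_token_ids"])
def contains_stop(tokens, stop_token_ids):
    mask = jnp.zeros(tokens.shape, dtype=bool)
    for stop_id in stop_token_ids:  # unrolled at trace time
        mask = mask | (tokens == stop_id)
    return mask.any()


tokens = jnp.array([5, 7, 2, 0])
print(contains_stop(tokens, stop_token_ids=(2,)))     # compiles, prints True
print(contains_stop(tokens, stop_token_ids=(2,)))     # cache hit, no recompile
print(contains_stop(tokens, stop_token_ids=(2, 11)))  # new static value: recompiles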

mattdangerw (Member) left a comment

Looks good to me! A few last comments.

Review threads on:
  keras_nlp/models/generative_task.py (resolved)
  keras_nlp/models/generative_task.py (outdated, resolved)
  keras_nlp/models/generative_task.py (resolved)
  keras_nlp/models/generative_task.py (outdated, resolved)
  keras_nlp/models/gemma/gemma_causal_lm.py (resolved)
@mattdangerw mattdangerw marked this pull request as ready for review March 27, 2024 16:44
@mattdangerw mattdangerw merged commit 4d1c883 into keras-team:master Mar 27, 2024
9 of 10 checks passed
abuelnasr0 pushed a commit to abuelnasr0/keras-nlp that referenced this pull request Apr 2, 2024
* Add multitoken stopping

* Update gemma_causal_lm.py

* Add further multitoken support

* Formatting

* Revert tokenizer changes

* Move multi token stop to generative task

* None check

* None check

* Error message

* Add stop_token_ids

* Util testing

* Fix sampler tests

* All multitoken stop to all models

* Sampler multi token

* Formatting

* Tuple required

* Tuple docstring

* Pytorch GPU fix

* Numpy fix