Add a Causal LM model for Mistral #1429
Conversation
JAX complains about dynamic slicing when compiled with XLA. This is unavoidable since, at runtime, the slice of the current key/value array to use for that iteration is determined by `cache_update_index`, which is itself a JAX `TracedArray`. Any workaround would lead to using dynamic shapes at some point. Hence, I had to remove this and instead use vanilla caching for now. For some reason, TensorFlow doesn't complain with XLA; I think this might be because TensorFlow is not as stringent about static shapes as JAX. In any case, adding sliding window attention that is XLA compatible is a story for the future.
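For illustration only (not the actual cache code from this PR), here is a minimal JAX sketch of the kind of slicing that gets rejected under XLA, assuming a fixed window size and a traced `cache_update_index`:

```python
import jax
import jax.numpy as jnp

WINDOW = 4  # illustrative sliding-window size

@jax.jit
def sliding_window_slice(cache, cache_update_index):
    # Under jit, cache_update_index is a traced value. NumPy-style slice
    # bounds must be static, so this expression is rejected at trace time.
    return cache[:, cache_update_index - WINDOW : cache_update_index, :]

cache = jnp.zeros((1, 16, 8))  # (batch, max_len, head_dim); made-up shapes
sliding_window_slice(cache, jnp.array(8))  # errors: slice bounds must be static
```

(`jax.lax.dynamic_slice` accepts traced start indices but still requires static slice sizes, which is presumably why every workaround here ends up needing dynamic shapes at some point.)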
@mattdangerw Tested it with the 7B preset. The outputs of both the backbone and the generator match up. This is ready from my side! I can share the preset with you once this is merged.
Looks great! Just a couple minor comments.
**kwargs,
)

# Default compilation
Minor nit: we were styling this as its own heading. https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/models/gpt2/gpt2_causal_lm.py#L172
Done.
hidden_states = backbone(inputs)
outputs = backbone.token_embedding(hidden_states, reverse=True)

# Instantiate the Functional API Model constructor.
Delete this comment and the newline above; the header gives enough clues about what this is.
Done.
padding_mask = padding_mask.astype("bool")
# Strip any special tokens during detokenization (e.g. the start and
# end markers). In the future we could make this configurable.
padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
I think we want to also remove the `start_token_id` (as it is a different token). Just add a line like this below with `start_token_id` instead.
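Something along these lines, assuming the tokenizer exposes `start_token_id` the same way it exposes `end_token_id`:

```python
padding_mask = padding_mask & (token_ids != self.tokenizer.start_token_id)
```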
Done.
prompt. The generation strategy used is controlled by an additional
`sampler` argument on `compile()`. You can recompile the model with
different `keras_nlp.samplers` objects to control the generation. By
default, `"top_k"` sampling will be used.
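For reference, switching samplers on a keras_nlp causal LM looks like the sketch below; the preset name is a placeholder, since no Mistral presets are added in this PR:

```python
import keras_nlp

# "mistral_7b_en" is a placeholder; substitute however the model is loaded.
generator = keras_nlp.models.MistralCausalLM.from_preset("mistral_7b_en")

# The default compilation uses top-k sampling; recompile for greedy decoding.
generator.compile(sampler="greedy")
print(generator.generate("What is Keras?", max_length=100))

# Or pass a configured sampler object instead of a string.
generator.compile(sampler=keras_nlp.samplers.TopKSampler(k=10))
print(generator.generate("What is Keras?", max_length=100))
```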
Is this a good default? For these newer, larger models, we might just want to default to greedy if performance is good.
Maybe do a quick check: does it tend to get stuck in loops with greedy sampling?
This was the output with the "greedy" sampler:
>>> output = generator.generate("What is Keras?", max_length=100)
2024-02-13 06:42:36.336579: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 20952865944 exceeds 10% of free system memory.
>>> print(output)
What is Keras?
Keras is a high-level neural network API, written in Python and capable of running on top of TensorFlow, CNTK or Theano. It was designed with a focus on usability, modularity and extensibility.
Keras is a high-level neural network API, written in Python and capable of running on top of TensorFlow, CNTK or Theano. It was designed with a focus on usability, mod
Noticed the same output with HF. I guess, for most prompts, the model would get stuck in a loop eventually.
HF Output:
>>> print(tokenizer.batch_decode(generated_ids)[0])
<s> What is Keras?
Keras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, or Theano. It was developed with a focus on enabling fast experimentation. Being able to go from idea to result with the least possible delay is key to doing good research.
Keras is meant for quick prototyping and easy and fast training. It should not be used in production.
Keras is a high-level API, which means that it is designed to be used by developers who are not experts in machine learning. It is designed to be easy to use, and to make it easy to experiment with different ideas.
Keras is a high-level API, which means that it is designed to be used by developers who are not experts in machine learning. It is designed to be easy to use, and to make it easy to experiment with different ideas.
Keras is a high-
Thanks for checking! Let's stick with top-k then.
@mattdangerw I forgot to put the tokenizer and the LM preprocessor in the public API, will address that along with your comments.
This PR adds a Causal LM for Mistral called `MistralCausalLM` and a preprocessor for it called `MistralCausalLMPreprocessor`. Presets are not added yet but can be done in a follow-up PR.

Note that I removed the sliding window attention cache from the attention layer for Mistral. This is because JAX was complaining about dynamic slicing, which is required to make the caching work. More explanation in this commit: 19b0b89

I am in the process of testing whether this model matches the outputs of the original model after the weights transfer. Once that's done, I can open the PR up for reviews.
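As a rough sketch, the output-matching check described above boils down to comparing the two implementations on identical token ids after the weight transfer (the helper name and tolerance below are just placeholders):

```python
import numpy as np

def check_outputs_match(keras_outputs, reference_outputs, atol=1e-3):
    # Compare hidden states (or logits) from the KerasNLP port and the
    # original implementation, computed on the same token ids.
    np.testing.assert_allclose(
        np.asarray(keras_outputs), np.asarray(reference_outputs), atol=atol
    )
```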