Add a Causal LM model for Mistral #1429

Merged: 11 commits, Feb 13, 2024

Conversation

tirthasheshpatel (Contributor):

This PR adds a Causal LM for Mistral called MistralCausalLM and a preprocessor for it called MistralCausalLMPreprocessor. Presets are not added yet but can be added in a follow-up PR.

Note that I removed the sliding window attention cache from the attention layer for Mistral, because JAX was complaining about the dynamic slicing required to make the caching work. More explanation in this commit: 19b0b89

I am in the process of testing if this model matches the outputs of the original model after the weights transfer. Once that's done, I can open the PR up for reviews.

JAX complains about dynamic slicing when compiled with XLA. This is unavoidable since, at runtime, the slice of the current key/value array to use for that iteration is determined by `cache_update_index`, which is itself a JAX `TracedArray`. Any workaround would lead to using dynamic shapes at some point. Hence, I had to remove this and instead use vanilla caching for now.

For some reason, TensorFlow doesn't complain with XLA. I think this might be because TensorFlow is not as stringent about static shapes as JAX.

In any case, adding sliding window attention that is XLA compatible is a story for the
future.
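
As a rough illustration of the constraint (a sketch, not code from this PR): under jit, plain slicing with a traced start index fails because the result shape would depend on a runtime value, while lax.dynamic_update_slice compiles fine because the update's shape is static even though its start index is traced.

import jax
import jax.numpy as jnp

# Hypothetical cache shapes: (batch, max_seq_len, head_dim).
cache = jnp.zeros((1, 8, 4))
update = jnp.ones((1, 1, 4))

@jax.jit
def sliding_slice(cache, start):
    # Fails under jit: `start` is a tracer, so this slice's shape would
    # depend on a runtime value, which XLA disallows.
    return cache[:, start:, :]

@jax.jit
def vanilla_update(cache, update, start):
    # Works: the start index is traced, but the slice size (update.shape)
    # is static, so XLA can compile it.
    return jax.lax.dynamic_update_slice(cache, update, (0, start, 0))

# sliding_slice(cache, 3)  # raises IndexError: slice indices must be static
print(vanilla_update(cache, update, 3).shape)  # (1, 8, 4)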
tirthasheshpatel self-assigned this on Feb 8, 2024
tirthasheshpatel added the type:feature (New feature or request) label on Feb 8, 2024
tirthasheshpatel marked this pull request as ready for review on Feb 13, 2024
tirthasheshpatel (Contributor Author):

> I am in the process of testing if this model matches the outputs of the original model after the weights transfer. Once that's done, I can open the PR up for reviews.

@mattdangerw Tested it with the 7B preset. The outputs of both the backbone and the generator match up. This is ready from my side! I can share the preset with you once this is merged.
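
A sketch of the kind of parity check implied here (the names are illustrative, not the actual verification script):

import numpy as np

def assert_outputs_match(keras_logits, reference_logits, atol=1e-3):
    # Compare logits from the KerasNLP backbone/generator and the reference
    # HF model on the same token ids after the weight transfer.
    np.testing.assert_allclose(keras_logits, reference_logits, atol=atol)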

mattdangerw (Member) left a comment:

Looks great! Just a couple minor comments.

**kwargs,
)

# Default compilation
mattdangerw (Member):

tirthasheshpatel (Contributor Author):

Done.

hidden_states = backbone(inputs)
outputs = backbone.token_embedding(hidden_states, reverse=True)

# Instantiate the Functional API Model constructor.
mattdangerw (Member):

Delete this comment, and the newline above; the header gives enough clues about what this is.

tirthasheshpatel (Contributor Author):

Done.
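
For context, token_embedding(..., reverse=True) reuses the embedding matrix to map hidden states back to vocabulary logits (weight tying). A minimal sketch of the idea, with illustrative shapes:

import numpy as np

# One matrix W of shape (vocab_size, hidden_dim) both embeds token ids
# (a lookup) and, with reverse=True, projects hidden states to logits via W^T.
W = np.random.rand(32000, 4096).astype("float32")  # Mistral-like sizes, illustrative
hidden_states = np.random.rand(1, 8, 4096).astype("float32")
logits = hidden_states @ W.T  # shape (1, 8, 32000)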

padding_mask = padding_mask.astype("bool")
# Strip any special tokens during detokenization (e.g. the start and
# end markers). In the future we could make this configurable.
padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
mattdangerw (Member):

I think we want to also remove the start_token_id (as it is a different token). Just add a line like this below, with start_token_id instead.

tirthasheshpatel (Contributor Author):

Done.
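
Presumably the resolved code mirrors the end-token line above (a sketch of the suggested fix, not the exact committed diff):

padding_mask = padding_mask & (token_ids != self.tokenizer.start_token_id)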

prompt. The generation strategy used is controlled by an additional
`sampler` argument on `compile()`. You can recompile the model with
different `keras_nlp.samplers` objects to control the generation. By
default, `"top_k"` sampling will be used.
mattdangerw (Member):

Is this a good default? For these newer, larger models, we might just want to default to greedy if performance is good.

Maybe a quick check: does it tend to get stuck in loops with greedy sampling?

tirthasheshpatel (Contributor Author) on Feb 13, 2024:

This was the output with the "greedy" sampler:

>>> output = generator.generate("What is Keras?", max_length=100)
2024-02-13 06:42:36.336579: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 20952865944 exceeds 10% of free system memory.
>>> print(output)
What is Keras?

Keras is a high-level neural network API, written in Python and capable of running on top of TensorFlow, CNTK or Theano. It was designed with a focus on usability, modularity and extensibility.

Keras is a high-level neural network API, written in Python and capable of running on top of TensorFlow, CNTK or Theano. It was designed with a focus on usability, mod

I noticed the same output with HF. I guess, for most prompts, the model would eventually get stuck in a loop.

HF Output:

>>> print(tokenizer.batch_decode(generated_ids)[0])
<s> What is Keras?

Keras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, or Theano. It was developed with a focus on enabling fast experimentation. Being able to go from idea to result with the least possible delay is key to doing good research.

Keras is meant for quick prototyping and easy and fast training. It should not be used in production.

Keras is a high-level API, which means that it is designed to be used by developers who are not experts in machine learning. It is designed to be easy to use, and to make it easy to experiment with different ideas.

Keras is a high-level API, which means that it is designed to be used by developers who are not experts in machine learning. It is designed to be easy to use, and to make it easy to experiment with different ideas.

Keras is a high-

mattdangerw (Member):

Thanks for checking! Let's stick with top-k then.
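
For reference, switching samplers looks like this (a usage sketch; the preset name is illustrative, since presets land in a follow-up PR):

import keras_nlp

causal_lm = keras_nlp.models.MistralCausalLM.from_preset("mistral_7b_en")
causal_lm.compile(sampler="greedy")  # switch decoding strategy at compile time
print(causal_lm.generate("What is Keras?", max_length=100))
causal_lm.compile(sampler=keras_nlp.samplers.TopKSampler(k=5))  # back to top-k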

tirthasheshpatel (Contributor Author):

@mattdangerw I forgot to put the tokenizer and the LM preprocessor in the public API; I will address that along with your comments.

mattdangerw merged commit 1951b5c into keras-team:master on Feb 13, 2024
10 checks passed
tirthasheshpatel deleted the mistral-lm branch on Feb 14, 2024