-
Notifications
You must be signed in to change notification settings - Fork 251
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a Causal LM model for Mistral (#1429)
* Add Mistral Causal LM Preprocessor * Add the Causal LM for Mistral * Remove sliding window attention from Mistral's attention layer JAX complains about dynamic slicing when compiled with XLA. This is unavoidable since, at runtime, the slice of the current key/value array to use for that iteration is determined by `cache_update_index` which is itself a JAX `TracedArray`. Any workaround would lead to using dynamic shapes at some point. Hence, I had to remove this and instead use vanilla caching for now. For some reason, TensorFlow doesn't complain with XLA. I think this might be because TensorFlow is as stringent about statis shapes as JAX. In any case, adding sliding window attention that is XLA compatible is a story for the future. * Enable JIT compile in the Mistral LM model * Fix Mistral transformer decoder * Port the causal LM to the new infra * Fix a minor bug in sliding window attention caching * Fix a small bug in mistral transformer decoder * Remove the RoPE shenanigan in mistral attention layer * Address review comments and add mistral to the public API
- Loading branch information
1 parent
22c1e30
commit 1951b5c
Showing
7 changed files
with
640 additions
and
80 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,213 @@ | ||
# Copyright 2023 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from keras_nlp.api_export import keras_nlp_export | ||
from keras_nlp.backend import keras | ||
from keras_nlp.backend import ops | ||
from keras_nlp.models.generative_task import GenerativeTask | ||
from keras_nlp.models.mistral.mistral_backbone import MistralBackbone | ||
from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import ( | ||
MistralCausalLMPreprocessor, | ||
) | ||
from keras_nlp.utils.python_utils import classproperty | ||
|
||
|
||
@keras_nlp_export("keras_nlp.models.MistralCausalLM") | ||
class MistralCausalLM(GenerativeTask): | ||
"""An end-to-end Mistral model for causal language modeling. | ||
A causal language model (LM) predicts the next token based on previous | ||
tokens. This task setup can be used to train the model unsupervised on | ||
plain text input, or to autoregressively generate plain text similar to | ||
the data used for training. This task can be used for pre-training or | ||
fine-tuning a GPT-NeoX model, simply by calling `fit()`. | ||
This model has a `generate()` method, which generates text based on a | ||
prompt. The generation strategy used is controlled by an additional | ||
`sampler` argument on `compile()`. You can recompile the model with | ||
different `keras_nlp.samplers` objects to control the generation. By | ||
default, `"top_k"` sampling will be used. | ||
Args: | ||
backbone: A `keras_nlp.models.MistralBackbone` instance. | ||
preprocessor: A `keras_nlp.models.MistralCausalLMPreprocessor` or `None`. | ||
If `None`, this model will not apply preprocessing, and inputs | ||
should be preprocessed before calling the model. | ||
""" | ||
|
||
def __init__(self, backbone, preprocessor=None, **kwargs): | ||
# === Layers === | ||
self.backbone = backbone | ||
self.preprocessor = preprocessor | ||
|
||
# === Functional Model === | ||
inputs = backbone.inputs | ||
hidden_states = backbone(inputs) | ||
outputs = backbone.token_embedding(hidden_states, reverse=True) | ||
super().__init__( | ||
inputs=inputs, | ||
outputs=outputs, | ||
**kwargs, | ||
) | ||
|
||
# === Default compilation === | ||
self.compile( | ||
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), | ||
optimizer=keras.optimizers.Adam(2e-5), | ||
metrics=[keras.metrics.SparseCategoricalAccuracy()], | ||
jit_compile=True, | ||
) | ||
|
||
@classproperty | ||
def backbone_cls(cls): | ||
return MistralBackbone | ||
|
||
@classproperty | ||
def preprocessor_cls(cls): | ||
return MistralCausalLMPreprocessor | ||
|
||
def call_with_cache( | ||
self, | ||
token_ids, | ||
cache, | ||
cache_update_index, | ||
): | ||
"""Forward pass of `MistralCausalLM` with cache. | ||
`call_with_cache` adds an additional forward pass for the model for | ||
autoregressive inference. Unlike calling the model directly, this method | ||
allows caching previous key/value Tensors in multi-head attention layer, | ||
and avoids recomputing the outputs of seen tokens. | ||
Args: | ||
token_ids: a dense int Tensor with shape `(batch_size, max_length)`. | ||
cache: a dense float Tensor, the cache of key and value. | ||
cache_update_index: int, or int Tensor. The index of current inputs | ||
in the whole sequence. | ||
Returns: | ||
A (logits, hidden_states, cache) tuple. Where `logits` is the | ||
language model logits for the input token_ids, `hidden_states` is | ||
the final hidden representation of the input tokens, and `cache` is | ||
the decoding cache. | ||
""" | ||
x = self.backbone.token_embedding(token_ids) | ||
# Each decoder layer has a cache; we update them separately. | ||
updated_cache = [] | ||
for i in range(self.backbone.num_layers): | ||
current_cache = cache[:, i, ...] | ||
x, next_cache = self.backbone.transformer_layers[i]( | ||
x, | ||
self_attention_cache=current_cache, | ||
self_attention_cache_update_index=cache_update_index, | ||
) | ||
updated_cache.append(next_cache) | ||
cache = ops.stack(updated_cache, axis=1) | ||
hidden_states = x = self.backbone.layer_norm(x) | ||
logits = self.backbone.token_embedding(x, reverse=True) | ||
return logits, hidden_states, cache | ||
|
||
def _build_cache(self, token_ids): | ||
"""Build an empty cache for use with `call_with_cache()`.""" | ||
batch_size = ops.shape(token_ids)[0] | ||
max_length = ops.shape(token_ids)[1] | ||
num_layers = self.backbone.num_layers | ||
num_key_value_heads = self.backbone.num_key_value_heads | ||
head_dim = self.backbone.hidden_dim // self.backbone.num_query_heads | ||
shape = [ | ||
batch_size, | ||
num_layers, | ||
2, | ||
max_length, | ||
num_key_value_heads, | ||
head_dim, | ||
] | ||
cache = ops.zeros(shape, dtype=self.compute_dtype) | ||
# Seed the cache. | ||
_, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) | ||
return hidden_states, cache | ||
|
||
def generate_step( | ||
self, | ||
inputs, | ||
end_token_id=None, | ||
): | ||
"""A compilable generation function for a single batch of inputs. | ||
This function represents the inner, XLA-compilable, generation function | ||
for a single batch of inputs. Inputs should have the same structure as | ||
model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. | ||
Args: | ||
inputs: A dictionary with two keys `"token_ids"` and | ||
`"padding_mask"` and batched tensor values. | ||
end_token_id: The id of the end token to stop on. If all | ||
sequences have produced a new `end_token_id`, generation | ||
will stop. | ||
""" | ||
token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] | ||
# Create and seed cache with a single forward pass. | ||
hidden_states, cache = self._build_cache(token_ids) | ||
# Compute the lengths of all user inputted tokens ids. | ||
row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) | ||
# Start at the first index that has no user inputted id. | ||
index = ops.min(row_lengths) | ||
|
||
def next(prompt, cache, index): | ||
# The cache index is the index of our previous token. | ||
cache_update_index = index - 1 | ||
batch_size = ops.shape(prompt)[0] | ||
prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) | ||
logits, hidden_states, cache = self.call_with_cache( | ||
prompt, | ||
cache, | ||
cache_update_index, | ||
) | ||
return ( | ||
ops.squeeze(logits, axis=1), | ||
ops.squeeze(hidden_states, axis=1), | ||
cache, | ||
) | ||
|
||
token_ids = self._sampler( | ||
next=next, | ||
prompt=token_ids, | ||
cache=cache, | ||
index=index, | ||
mask=padding_mask, | ||
end_token_id=end_token_id, | ||
hidden_states=hidden_states, | ||
) | ||
|
||
# Compute an output padding mask with the token ids we updated. | ||
if end_token_id is not None: | ||
# Build a mask of `end_token_id` locations not in the original | ||
# prompt (not in locations where `padding_mask` is True). | ||
end_locations = ops.logical_and( | ||
ops.equal(token_ids, end_token_id), | ||
ops.logical_not(padding_mask), | ||
) | ||
end_locations = ops.cast(end_locations, "int32") | ||
# Use cumsum to get ones in all locations after end_locations. | ||
cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") | ||
overflow = cumsum - end_locations | ||
# Our padding mask is the inverse of these overflow locations. | ||
padding_mask = ops.logical_not(ops.cast(overflow, "bool")) | ||
else: | ||
# Without early stopping, all locations will have been updated. | ||
padding_mask = ops.ones_like(token_ids, dtype="bool") | ||
return { | ||
"token_ids": token_ids, | ||
"padding_mask": padding_mask, | ||
} |
Oops, something went wrong.