Skip to content

Commit

Permalink
Add a Causal LM model for Mistral (#1429)
Browse files Browse the repository at this point in the history
* Add Mistral Causal LM Preprocessor

* Add the Causal LM for Mistral

* Remove sliding window attention from Mistral's attention layer

JAX complains about dynamic slicing when compiled with XLA. This is unavoidable
since, at runtime, the slice of the current key/value array to use for that iteration
is determined by `cache_update_index` which is itself a JAX `TracedArray`. Any workaround
would lead to using dynamic shapes at some point. Hence, I had to remove this and instead
use vanilla caching for now.

For some reason, TensorFlow doesn't complain with XLA. I think this might be because
TensorFlow is as stringent about statis shapes as JAX.

In any case, adding sliding window attention that is XLA compatible is a story for the
future.

* Enable JIT compile in the Mistral LM model

* Fix Mistral transformer decoder

* Port the causal LM to the new infra

* Fix a minor bug in sliding window attention caching

* Fix a small bug in mistral transformer decoder

* Remove the RoPE shenanigan in mistral attention layer

* Address review comments and add mistral to the public API
  • Loading branch information
tirthasheshpatel authored Feb 13, 2024
1 parent 22c1e30 commit 1951b5c
Show file tree
Hide file tree
Showing 7 changed files with 640 additions and 80 deletions.
6 changes: 6 additions & 0 deletions keras_nlp/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@
from keras_nlp.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
from keras_nlp.models.llama.llama_backbone import LlamaBackbone
from keras_nlp.models.mistral.mistral_backbone import MistralBackbone
from keras_nlp.models.mistral.mistral_causal_lm import MistralCausalLM
from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import (
MistralCausalLMPreprocessor,
)
from keras_nlp.models.mistral.mistral_preprocessor import MistralPreprocessor
from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer
from keras_nlp.models.opt.opt_backbone import OPTBackbone
from keras_nlp.models.opt.opt_causal_lm import OPTCausalLM
from keras_nlp.models.opt.opt_causal_lm_preprocessor import (
Expand Down
96 changes: 20 additions & 76 deletions keras_nlp/models/mistral/mistral_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ def call(
cache_update_index=None,
training=None,
):
seq_len = ops.shape(hidden_states)[1]
start_index = (
cache_update_index if cache_update_index is not None else 0
)
Expand All @@ -153,89 +152,34 @@ def call(

query = self._query_dense(hidden_states)

# Note that the original PyTorch implementation uses
# view_as_complex/view_as_real while we use split/concatenate to
# convert to/from complex numbers. The transformations below make
# the rope computation numerically equivalent to the original
# implementation.
def _mistral_rope(x):
x = ops.concatenate([x[..., ::2], x[..., 1::2]], axis=-1)
x = self.rotary_embedding_layer(x, start_index=start_index)
x = ops.reshape(
ops.stack(ops.split(x, 2, axis=-1), axis=-1), ops.shape(x)
)
return x

# Compute RoPE for queries
query = _mistral_rope(query)
query = self.rotary_embedding_layer(query, start_index=start_index)

def _compute_key_value(x):
key, value = self._key_dense(x), self._value_dense(x)
key = _mistral_rope(key)
# Compute RoPE for keys
key = self.rotary_embedding_layer(key, start_index=start_index)
return key, value

if cache is not None:
cache_k = cache[:, 0, ...]
cache_v = cache[:, 1, ...]

key_cache = cache[:, 0, ...]
value_cache = cache[:, 1, ...]
if cache_update_index is None:
key = key_cache
value = value_cache
else:
key_update, value_update = _compute_key_value(hidden_states)
start = [0, cache_update_index, 0, 0]
key = ops.slice_update(key_cache, start, key_update)
value = ops.slice_update(value_cache, start, value_update)
cache = ops.stack((key, value), axis=1)
else:
if cache_update_index is not None:
# Compute the new keys and values
key, value = _compute_key_value(hidden_states)

# Cache is a rotating buffer, we want to warp around if
# the sequence length exceeds the sliding window.
update_end_index = (
cache_update_index + seq_len - 1
) % self._sliding_window + 1
update_end_index = ops.cast(update_end_index, "int32")
cache_update_index = cache_update_index % self._sliding_window
update_start_index = ops.cond(
update_end_index > cache_update_index,
lambda: ops.cast(cache_update_index, "int32"),
lambda: ops.cast(0, "int32"),
)
# Also note that the update step below assumes that the
# sequence length is always one when `cache_update_index != 0`.
# This is necessary to support XLA compilation. Ideally, we
# would want to use
# `key[:, -(update_end_index - update_start_index):, ...]`
# as the update but updating using a dynamic slice gives an
# XLA compilation error in TensorFlow.
# Passing a sequence of length > 1 with cache update might give
# incorrect results (since there is no way to determine how
# many most recent tokens are to be saved if the tokens exceed
# the sliding window length).
cache_k = ops.slice_update(
cache_k,
[0, update_start_index, 0, 0],
# We slice the keys and values since if the user has passed
# a sequence of length > `self._sliding_window`. We want to
# prefill the cache using just the most recent values in the
# sliding window.
ops.cast(
key[:, -self._sliding_window :, ...], cache_k.dtype
),
raise ValueError(
"`cache_update_index` should not be set if `cache` is "
f"`None`. Received: cache={cache}, "
f"cache_update_index={cache_update_index}"
)
cache_v = ops.slice_update(
cache_v,
[0, update_start_index, 0, 0],
ops.cast(
value[:, -self._sliding_window :, ...], cache_v.dtype
),
)
cache = ops.stack([cache_k, cache_v], axis=1)

# Get the required keys and values from the cache.
# Since we expect the user to pass a fixed-size cache, we just
# pick the first few slices up-to and including the newly computed
# keys and values.
cache_k = cache_k[:, :update_end_index, ...]
cache_v = cache_v[:, :update_end_index, ...]

key = ops.cast(cache_k, dtype=self.compute_dtype)
value = ops.cast(cache_v, dtype=self.compute_dtype)
else:
# Compute keys and values
key, value = _compute_key_value(hidden_states)

# [batch_shape, seq_len, num_key_value_heads, head_dim]
Expand Down Expand Up @@ -265,7 +209,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
return self._softmax(attention_scores)

def _compute_attention(self, query, key, value, attention_mask=None):
attention_scores = ops.einsum(self._dot_product_equation, key, query)
attention_scores = ops.einsum(self._dot_product_equation, query, key)

norm_factor = ops.sqrt(ops.cast(self._head_dim, self.compute_dtype))

Expand Down
213 changes: 213 additions & 0 deletions keras_nlp/models/mistral/mistral_causal_lm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.models.generative_task import GenerativeTask
from keras_nlp.models.mistral.mistral_backbone import MistralBackbone
from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import (
MistralCausalLMPreprocessor,
)
from keras_nlp.utils.python_utils import classproperty


@keras_nlp_export("keras_nlp.models.MistralCausalLM")
class MistralCausalLM(GenerativeTask):
"""An end-to-end Mistral model for causal language modeling.
A causal language model (LM) predicts the next token based on previous
tokens. This task setup can be used to train the model unsupervised on
plain text input, or to autoregressively generate plain text similar to
the data used for training. This task can be used for pre-training or
fine-tuning a GPT-NeoX model, simply by calling `fit()`.
This model has a `generate()` method, which generates text based on a
prompt. The generation strategy used is controlled by an additional
`sampler` argument on `compile()`. You can recompile the model with
different `keras_nlp.samplers` objects to control the generation. By
default, `"top_k"` sampling will be used.
Args:
backbone: A `keras_nlp.models.MistralBackbone` instance.
preprocessor: A `keras_nlp.models.MistralCausalLMPreprocessor` or `None`.
If `None`, this model will not apply preprocessing, and inputs
should be preprocessed before calling the model.
"""

def __init__(self, backbone, preprocessor=None, **kwargs):
# === Layers ===
self.backbone = backbone
self.preprocessor = preprocessor

# === Functional Model ===
inputs = backbone.inputs
hidden_states = backbone(inputs)
outputs = backbone.token_embedding(hidden_states, reverse=True)
super().__init__(
inputs=inputs,
outputs=outputs,
**kwargs,
)

# === Default compilation ===
self.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(2e-5),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
jit_compile=True,
)

@classproperty
def backbone_cls(cls):
return MistralBackbone

@classproperty
def preprocessor_cls(cls):
return MistralCausalLMPreprocessor

def call_with_cache(
self,
token_ids,
cache,
cache_update_index,
):
"""Forward pass of `MistralCausalLM` with cache.
`call_with_cache` adds an additional forward pass for the model for
autoregressive inference. Unlike calling the model directly, this method
allows caching previous key/value Tensors in multi-head attention layer,
and avoids recomputing the outputs of seen tokens.
Args:
token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
cache: a dense float Tensor, the cache of key and value.
cache_update_index: int, or int Tensor. The index of current inputs
in the whole sequence.
Returns:
A (logits, hidden_states, cache) tuple. Where `logits` is the
language model logits for the input token_ids, `hidden_states` is
the final hidden representation of the input tokens, and `cache` is
the decoding cache.
"""
x = self.backbone.token_embedding(token_ids)
# Each decoder layer has a cache; we update them separately.
updated_cache = []
for i in range(self.backbone.num_layers):
current_cache = cache[:, i, ...]
x, next_cache = self.backbone.transformer_layers[i](
x,
self_attention_cache=current_cache,
self_attention_cache_update_index=cache_update_index,
)
updated_cache.append(next_cache)
cache = ops.stack(updated_cache, axis=1)
hidden_states = x = self.backbone.layer_norm(x)
logits = self.backbone.token_embedding(x, reverse=True)
return logits, hidden_states, cache

def _build_cache(self, token_ids):
"""Build an empty cache for use with `call_with_cache()`."""
batch_size = ops.shape(token_ids)[0]
max_length = ops.shape(token_ids)[1]
num_layers = self.backbone.num_layers
num_key_value_heads = self.backbone.num_key_value_heads
head_dim = self.backbone.hidden_dim // self.backbone.num_query_heads
shape = [
batch_size,
num_layers,
2,
max_length,
num_key_value_heads,
head_dim,
]
cache = ops.zeros(shape, dtype=self.compute_dtype)
# Seed the cache.
_, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
return hidden_states, cache

def generate_step(
self,
inputs,
end_token_id=None,
):
"""A compilable generation function for a single batch of inputs.
This function represents the inner, XLA-compilable, generation function
for a single batch of inputs. Inputs should have the same structure as
model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.
Args:
inputs: A dictionary with two keys `"token_ids"` and
`"padding_mask"` and batched tensor values.
end_token_id: The id of the end token to stop on. If all
sequences have produced a new `end_token_id`, generation
will stop.
"""
token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
# Create and seed cache with a single forward pass.
hidden_states, cache = self._build_cache(token_ids)
# Compute the lengths of all user inputted tokens ids.
row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
# Start at the first index that has no user inputted id.
index = ops.min(row_lengths)

def next(prompt, cache, index):
# The cache index is the index of our previous token.
cache_update_index = index - 1
batch_size = ops.shape(prompt)[0]
prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
logits, hidden_states, cache = self.call_with_cache(
prompt,
cache,
cache_update_index,
)
return (
ops.squeeze(logits, axis=1),
ops.squeeze(hidden_states, axis=1),
cache,
)

token_ids = self._sampler(
next=next,
prompt=token_ids,
cache=cache,
index=index,
mask=padding_mask,
end_token_id=end_token_id,
hidden_states=hidden_states,
)

# Compute an output padding mask with the token ids we updated.
if end_token_id is not None:
# Build a mask of `end_token_id` locations not in the original
# prompt (not in locations where `padding_mask` is True).
end_locations = ops.logical_and(
ops.equal(token_ids, end_token_id),
ops.logical_not(padding_mask),
)
end_locations = ops.cast(end_locations, "int32")
# Use cumsum to get ones in all locations after end_locations.
cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
overflow = cumsum - end_locations
# Our padding mask is the inverse of these overflow locations.
padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
else:
# Without early stopping, all locations will have been updated.
padding_mask = ops.ones_like(token_ids, dtype="bool")
return {
"token_ids": token_ids,
"padding_mask": padding_mask,
}
Loading

0 comments on commit 1951b5c

Please sign in to comment.