Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a Causal LM model for Mistral #1429

Merged
merged 11 commits into from
Feb 13, 2024
6 changes: 6 additions & 0 deletions keras_nlp/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@
from keras_nlp.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
from keras_nlp.models.llama.llama_backbone import LlamaBackbone
from keras_nlp.models.mistral.mistral_backbone import MistralBackbone
from keras_nlp.models.mistral.mistral_causal_lm import MistralCausalLM
from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import (
MistralCausalLMPreprocessor,
)
from keras_nlp.models.mistral.mistral_preprocessor import MistralPreprocessor
from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer
from keras_nlp.models.opt.opt_backbone import OPTBackbone
from keras_nlp.models.opt.opt_causal_lm import OPTCausalLM
from keras_nlp.models.opt.opt_causal_lm_preprocessor import (
Expand Down
96 changes: 20 additions & 76 deletions keras_nlp/models/mistral/mistral_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ def call(
cache_update_index=None,
training=None,
):
seq_len = ops.shape(hidden_states)[1]
start_index = (
cache_update_index if cache_update_index is not None else 0
)
Expand All @@ -148,89 +147,34 @@ def call(

query = self._query_dense(hidden_states)

# Note that the original PyTorch implementation uses
# view_as_complex/view_as_real while we use split/concatenate to
# convert to/from complex numbers. The transformations below make
# the rope computation numerically equivalent to the original
# implementation.
def _mistral_rope(x):
x = ops.concatenate([x[..., ::2], x[..., 1::2]], axis=-1)
x = self.rotary_embedding_layer(x, start_index=start_index)
x = ops.reshape(
ops.stack(ops.split(x, 2, axis=-1), axis=-1), ops.shape(x)
)
return x

# Compute RoPE for queries
query = _mistral_rope(query)
query = self.rotary_embedding_layer(query, start_index=start_index)

def _compute_key_value(x):
key, value = self._key_dense(x), self._value_dense(x)
key = _mistral_rope(key)
# Compute RoPE for keys
key = self.rotary_embedding_layer(key, start_index=start_index)
return key, value

if cache is not None:
cache_k = cache[:, 0, ...]
cache_v = cache[:, 1, ...]

key_cache = cache[:, 0, ...]
value_cache = cache[:, 1, ...]
if cache_update_index is None:
key = key_cache
value = value_cache
else:
key_update, value_update = _compute_key_value(hidden_states)
start = [0, cache_update_index, 0, 0]
key = ops.slice_update(key_cache, start, key_update)
value = ops.slice_update(value_cache, start, value_update)
cache = ops.stack((key, value), axis=1)
else:
if cache_update_index is not None:
# Compute the new keys and values
key, value = _compute_key_value(hidden_states)

# Cache is a rotating buffer, we want to warp around if
# the sequence length exceeds the sliding window.
update_end_index = (
cache_update_index + seq_len - 1
) % self._sliding_window + 1
update_end_index = ops.cast(update_end_index, "int32")
cache_update_index = cache_update_index % self._sliding_window
update_start_index = ops.cond(
update_end_index > cache_update_index,
lambda: ops.cast(cache_update_index, "int32"),
lambda: ops.cast(0, "int32"),
)
# Also note that the update step below assumes that the
# sequence length is always one when `cache_update_index != 0`.
# This is necessary to support XLA compilation. Ideally, we
# would want to use
# `key[:, -(update_end_index - update_start_index):, ...]`
# as the update but updating using a dynamic slice gives an
# XLA compilation error in TensorFlow.
# Passing a sequence of length > 1 with cache update might give
# incorrect results (since there is no way to determine how
# many most recent tokens are to be saved if the tokens exceed
# the sliding window length).
cache_k = ops.slice_update(
cache_k,
[0, update_start_index, 0, 0],
# We slice the keys and values since if the user has passed
# a sequence of length > `self._sliding_window`. We want to
# prefill the cache using just the most recent values in the
# sliding window.
ops.cast(
key[:, -self._sliding_window :, ...], cache_k.dtype
),
raise ValueError(
"`cache_update_index` should not be set if `cache` is "
f"`None`. Received: cache={cache}, "
f"cache_update_index={cache_update_index}"
)
cache_v = ops.slice_update(
cache_v,
[0, update_start_index, 0, 0],
ops.cast(
value[:, -self._sliding_window :, ...], cache_v.dtype
),
)
cache = ops.stack([cache_k, cache_v], axis=1)

# Get the required keys and values from the cache.
# Since we expect the user to pass a fixed-size cache, we just
# pick the first few slices up-to and including the newly computed
# keys and values.
cache_k = cache_k[:, :update_end_index, ...]
cache_v = cache_v[:, :update_end_index, ...]

key = ops.cast(cache_k, dtype=self.compute_dtype)
value = ops.cast(cache_v, dtype=self.compute_dtype)
else:
# Compute keys and values
key, value = _compute_key_value(hidden_states)

# [batch_shape, seq_len, num_key_value_heads, head_dim]
Expand Down Expand Up @@ -260,7 +204,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
return self._softmax(attention_scores)

def _compute_attention(self, query, key, value, attention_mask=None):
attention_scores = ops.einsum(self._dot_product_equation, key, query)
attention_scores = ops.einsum(self._dot_product_equation, query, key)
tirthasheshpatel marked this conversation as resolved.
Show resolved Hide resolved

norm_factor = ops.sqrt(ops.cast(self._head_dim, self.compute_dtype))

Expand Down
213 changes: 213 additions & 0 deletions keras_nlp/models/mistral/mistral_causal_lm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.models.generative_task import GenerativeTask
from keras_nlp.models.mistral.mistral_backbone import MistralBackbone
from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import (
MistralCausalLMPreprocessor,
)
from keras_nlp.utils.python_utils import classproperty


@keras_nlp_export("keras_nlp.models.MistralCausalLM")
class MistralCausalLM(GenerativeTask):
"""An end-to-end Mistral model for causal language modeling.
A causal language model (LM) predicts the next token based on previous
tokens. This task setup can be used to train the model unsupervised on
plain text input, or to autoregressively generate plain text similar to
the data used for training. This task can be used for pre-training or
fine-tuning a GPT-NeoX model, simply by calling `fit()`.
This model has a `generate()` method, which generates text based on a
prompt. The generation strategy used is controlled by an additional
`sampler` argument on `compile()`. You can recompile the model with
different `keras_nlp.samplers` objects to control the generation. By
default, `"top_k"` sampling will be used.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a good default? For these newer larger models, we might just want to default to greedy if performance is good.

Maybe quick check, does it tend to get stuck in loops with greedy sampling?

Copy link
Contributor Author

@tirthasheshpatel tirthasheshpatel Feb 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the output with "greedy" sampler:

>>> output = generator.generate("What is Keras?", max_length=100)
2024-02-13 06:42:36.336579: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 20952865944 exceeds 10% of free system memory.
>>> print(output)
What is Keras?

Keras is a high-level neural network API, written in Python and capable of running on top of TensorFlow, CNTK or Theano. It was designed with a focus on usability, modularity and extensibility.

Keras is a high-level neural network API, written in Python and capable of running on top of TensorFlow, CNTK or Theano. It was designed with a focus on usability, mod

Noticed the same output with HF. I guess, for most prompts, the model would get stuck in a loop eventually.

HF Output:

>>> print(tokenizer.batch_decode(generated_ids)[0])
<s> What is Keras?

Keras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, or Theano. It was developed with a focus on enabling fast experimentation. Being able to go from idea to result with the least possible delay is key to doing good research.

Keras is meant for quick prototyping and easy and fast training. It should not be used in production.

Keras is a high-level API, which means that it is designed to be used by developers who are not experts in machine learning. It is designed to be easy to use, and to make it easy to experiment with different ideas.

Keras is a high-level API, which means that it is designed to be used by developers who are not experts in machine learning. It is designed to be easy to use, and to make it easy to experiment with different ideas.

Keras is a high-

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for checking! Let's stick with top-k then.

Args:
backbone: A `keras_nlp.models.MistralBackbone` instance.
preprocessor: A `keras_nlp.models.MistralCausalLMPreprocessor` or `None`.
If `None`, this model will not apply preprocessing, and inputs
should be preprocessed before calling the model.
"""

def __init__(self, backbone, preprocessor=None, **kwargs):
# === Layers ===
self.backbone = backbone
self.preprocessor = preprocessor

# === Functional Model ===
inputs = backbone.inputs
hidden_states = backbone(inputs)
outputs = backbone.token_embedding(hidden_states, reverse=True)
super().__init__(
inputs=inputs,
outputs=outputs,
**kwargs,
)

# === Default compilation ===
self.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(2e-5),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
jit_compile=True,
)

@classproperty
def backbone_cls(cls):
return MistralBackbone

@classproperty
def preprocessor_cls(cls):
return MistralCausalLMPreprocessor

def call_with_cache(
self,
token_ids,
cache,
cache_update_index,
):
"""Forward pass of `MistralCausalLM` with cache.
`call_with_cache` adds an additional forward pass for the model for
autoregressive inference. Unlike calling the model directly, this method
allows caching previous key/value Tensors in multi-head attention layer,
and avoids recomputing the outputs of seen tokens.
Args:
token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
cache: a dense float Tensor, the cache of key and value.
cache_update_index: int, or int Tensor. The index of current inputs
in the whole sequence.
Returns:
A (logits, hidden_states, cache) tuple. Where `logits` is the
language model logits for the input token_ids, `hidden_states` is
the final hidden representation of the input tokens, and `cache` is
the decoding cache.
"""
x = self.backbone.token_embedding(token_ids)
# Each decoder layer has a cache; we update them separately.
updated_cache = []
for i in range(self.backbone.num_layers):
current_cache = cache[:, i, ...]
x, next_cache = self.backbone.transformer_layers[i](
x,
self_attention_cache=current_cache,
self_attention_cache_update_index=cache_update_index,
)
updated_cache.append(next_cache)
cache = ops.stack(updated_cache, axis=1)
hidden_states = x = self.backbone.layer_norm(x)
logits = self.backbone.token_embedding(x, reverse=True)
return logits, hidden_states, cache

def _build_cache(self, token_ids):
"""Build an empty cache for use with `call_with_cache()`."""
batch_size = ops.shape(token_ids)[0]
max_length = ops.shape(token_ids)[1]
num_layers = self.backbone.num_layers
num_key_value_heads = self.backbone.num_key_value_heads
head_dim = self.backbone.hidden_dim // self.backbone.num_query_heads
shape = [
batch_size,
num_layers,
2,
max_length,
num_key_value_heads,
head_dim,
]
cache = ops.zeros(shape, dtype=self.compute_dtype)
# Seed the cache.
_, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
return hidden_states, cache

def generate_step(
self,
inputs,
end_token_id=None,
):
"""A compilable generation function for a single batch of inputs.
This function represents the inner, XLA-compilable, generation function
for a single batch of inputs. Inputs should have the same structure as
model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.
Args:
inputs: A dictionary with two keys `"token_ids"` and
`"padding_mask"` and batched tensor values.
end_token_id: The id of the end token to stop on. If all
sequences have produced a new `end_token_id`, generation
will stop.
"""
token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
# Create and seed cache with a single forward pass.
hidden_states, cache = self._build_cache(token_ids)
# Compute the lengths of all user inputted tokens ids.
row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
# Start at the first index that has no user inputted id.
index = ops.min(row_lengths)

def next(prompt, cache, index):
# The cache index is the index of our previous token.
cache_update_index = index - 1
batch_size = ops.shape(prompt)[0]
prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
logits, hidden_states, cache = self.call_with_cache(
prompt,
cache,
cache_update_index,
)
return (
ops.squeeze(logits, axis=1),
ops.squeeze(hidden_states, axis=1),
cache,
)

token_ids = self._sampler(
next=next,
prompt=token_ids,
cache=cache,
index=index,
mask=padding_mask,
end_token_id=end_token_id,
hidden_states=hidden_states,
)

# Compute an output padding mask with the token ids we updated.
if end_token_id is not None:
# Build a mask of `end_token_id` locations not in the original
# prompt (not in locations where `padding_mask` is True).
end_locations = ops.logical_and(
ops.equal(token_ids, end_token_id),
ops.logical_not(padding_mask),
)
end_locations = ops.cast(end_locations, "int32")
# Use cumsum to get ones in all locations after end_locations.
cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
overflow = cumsum - end_locations
# Our padding mask is the inverse of these overflow locations.
padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
else:
# Without early stopping, all locations will have been updated.
padding_mask = ops.ones_like(token_ids, dtype="bool")
return {
"token_ids": token_ids,
"padding_mask": padding_mask,
}
Loading
Loading