Skip to content

Commit

Permalink
Remove sliding window attention from Mistral's attention layer
Browse files Browse the repository at this point in the history
JAX complains about dynamic slicing when compiled with XLA. This is unavoidable
since, at runtime, the slice of the current key/value array to use for that iteration
is determined by `cache_update_index` which is itself a JAX `TracedArray`. Any workaround
would lead to using dynamic shapes at some point. Hence, I had to remove this and instead
use vanilla caching for now.

For some reason, TensorFlow doesn't complain with XLA. I think this might be because
TensorFlow is as stringent about statis shapes as JAX.

In any case, adding sliding window attention that is XLA compatible is a story for the
future.
  • Loading branch information
tirthasheshpatel committed Feb 8, 2024
1 parent 2e2e2e5 commit 19b0b89
Showing 1 changed file with 17 additions and 61 deletions.
78 changes: 17 additions & 61 deletions keras_nlp/models/mistral/mistral_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ def call(
cache_update_index=None,
training=None,
):
seq_len = ops.shape(hidden_states)[1]
start_index = (
cache_update_index if cache_update_index is not None else 0
)
Expand Down Expand Up @@ -170,67 +169,24 @@ def _compute_key_value(x):
return key, value

if cache is not None:
cache_k = cache[:, 0, ...]
cache_v = cache[:, 1, ...]

key_cache = cache[:, 0, ...]
value_cache = cache[:, 1, ...]
if cache_update_index is None:
key = key_cache
value = value_cache
else:
key_update, value_update = _compute_key_value(hidden_states)
start = [0, cache_update_index, 0, 0]
key = ops.slice_update(key_cache, start, key_update)
value = ops.slice_update(value_cache, start, value_update)
cache = ops.stack((key, value), axis=1)
else:
if cache_update_index is not None:
# Compute the new keys and values
key, value = _compute_key_value(hidden_states)

# Cache is a rotating buffer, we want to warp around if
# the sequence length exceeds the sliding window.
update_end_index = (
cache_update_index + seq_len - 1
) % self._sliding_window + 1
update_end_index = ops.cast(update_end_index, "int32")
cache_update_index = cache_update_index % self._sliding_window
update_start_index = ops.cond(
update_end_index > cache_update_index,
lambda: ops.cast(cache_update_index, "int32"),
lambda: ops.cast(0, "int32"),
)
# Also note that the update step below assumes that the
# sequence length is always one when `cache_update_index != 0`.
# This is necessary to support XLA compilation. Ideally, we
# would want to use
# `key[:, -(update_end_index - update_start_index):, ...]`
# as the update but updating using a dynamic slice gives an
# XLA compilation error in TensorFlow.
# Passing a sequence of length > 1 with cache update might give
# incorrect results (since there is no way to determine how
# many most recent tokens are to be saved if the tokens exceed
# the sliding window length).
cache_k = ops.slice_update(
cache_k,
[0, update_start_index, 0, 0],
# We slice the keys and values since if the user has passed
# a sequence of length > `self._sliding_window`. We want to
# prefill the cache using just the most recent values in the
# sliding window.
ops.cast(
key[:, -self._sliding_window :, ...], cache_k.dtype
),
)
cache_v = ops.slice_update(
cache_v,
[0, update_start_index, 0, 0],
ops.cast(
value[:, -self._sliding_window :, ...], cache_v.dtype
),
raise ValueError(
"`cache_update_index` should not be set if `cache` is "
f"`None`. Received: cache={cache}, "
f"cache_update_index={cache_update_index}"
)
cache = ops.stack([cache_k, cache_v], axis=1)

# Get the required keys and values from the cache.
# Since we expect the user to pass a fixed-size cache, we just
# pick the first few slices up-to and including the newly computed
# keys and values.
cache_k = cache_k[:, :update_end_index, ...]
cache_v = cache_v[:, :update_end_index, ...]

key = ops.cast(cache_k, dtype=self.compute_dtype)
value = ops.cast(cache_v, dtype=self.compute_dtype)
else:
# Compute keys and values
key, value = _compute_key_value(hidden_states)

# [batch_shape, seq_len, num_key_value_heads, head_dim]
Expand Down Expand Up @@ -260,7 +216,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
return self._softmax(attention_scores)

def _compute_attention(self, query, key, value, attention_mask=None):
attention_scores = ops.einsum(self._dot_product_equation, key, query)
attention_scores = ops.einsum(self._dot_product_equation, query, key)

norm_factor = ops.sqrt(ops.cast(self._head_dim, self.compute_dtype))

Expand Down

0 comments on commit 19b0b89

Please sign in to comment.