Fix for sliding window issues
mattdangerw committed Aug 6, 2024
1 parent a1d3e66 commit a2ccea3
Showing 1 changed file with 33 additions and 24 deletions.
keras_nlp/src/models/gemma/gemma_attention.py (57 changes: 33 additions & 24 deletions)
@@ -122,6 +122,7 @@ def _compute_attention(
         v,
         attention_mask,
         training=False,
+        cache_update_index=0,
     ):
         if self.query_head_dim_normalize:
             query_normalization = 1 / np.sqrt(self.head_dim)
@@ -152,29 +153,10 @@ def _compute_attention(
         )
 
         if self.use_sliding_window_attention:
-            all_ones = ops.ones_like(attention_mask)
-            if keras.config.backend() == "tensorflow":
-                import tensorflow as tf
-
-                sliding_window_size = ops.minimum(
-                    self.sliding_window_size - 1, q_len
-                )
-                sliding_window_size = ops.cast(
-                    sliding_window_size, dtype="int32"
-                )
-                sliding_mask = tf.linalg.band_part(
-                    all_ones, sliding_window_size - 1, sliding_window_size - 1
-                )
-                sliding_mask = ops.cast(sliding_mask, dtype="bool")
-                bool_attention_mask = ops.cast(attention_mask, dtype="bool")
-                attention_mask = tf.math.logical_and(
-                    sliding_mask, bool_attention_mask
-                )
-            else:
-                sliding_mask = ops.triu(
-                    all_ones, -1 * self.sliding_window_size + 1
-                ) * ops.tril(all_ones, self.sliding_window_size - 1)
-                attention_mask = sliding_mask * attention_mask
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index,
+            )
 
         attention_mask = attention_mask[:, None, None, :, :]
         orig_dtype = attention_logits.dtype
@@ -189,6 +171,28 @@ def _compute_attention(
         results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v)
         return ops.reshape(results, (b, q_len, self.num_query_heads, h))
 
+    def _mask_sliding_window(
+        self,
+        attention_mask,
+        cache_update_index=0,
+    ):
+        batch_size, query_len, key_len = ops.shape(attention_mask)
+        all_ones = ops.ones((key_len, key_len), "bool")
+        if keras.config.backend() == "tensorflow":
+            import tensorflow as tf
+
+            band_size = ops.minimum(key_len, self.sliding_window_size - 1)
+            band_size = ops.cast(band_size, "int32")
+            sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size)
+        else:
+            sliding_mask = ops.triu(
+                all_ones, -1 * self.sliding_window_size + 1
+            ) * ops.tril(all_ones, self.sliding_window_size - 1)
+        start = (cache_update_index, 0)
+        sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
+        sliding_mask = ops.expand_dims(sliding_mask, 0)
+        return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))
+
     def call(
         self,
         x,
@@ -216,7 +220,12 @@ def call(
         value = self.value_dense(x)
 
         attention_vec = self._compute_attention(
-            query, key, value, attention_mask, training=training
+            query,
+            key,
+            value,
+            attention_mask,
+            training=training,
+            cache_update_index=cache_update_index,
         )
 
         # Wipe attn vec if there are no attended tokens.

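For context on the change above: the removed inline logic built the sliding mask directly from attention_mask, whose query dimension is 1 during cached decoding, while the new _mask_sliding_window helper builds the band over the full (key_len, key_len) square and then slices it at cache_update_index, presumably so single-token decode steps pick up the row of the band that matches their absolute position. Below is a minimal, standalone NumPy sketch of that idea; the function name, arguments, and printed example are illustrative only and not part of the keras_nlp API.

# A standalone NumPy sketch of the band-then-slice idea; names and values
# here are illustrative, not keras_nlp code.
import numpy as np


def sliding_window_mask(query_len, key_len, window, cache_update_index=0):
    """Boolean mask of shape (query_len, key_len).

    Entry (i, j) is True when key j lies within `window` tokens of query i,
    where query i sits at absolute position `cache_update_index + i`.
    """
    # Full square band over all key positions: |row - col| <= window - 1.
    rows = np.arange(key_len)[:, None]
    cols = np.arange(key_len)[None, :]
    band = np.abs(rows - cols) <= window - 1
    # Slice out the rows for the current queries. During cached decoding
    # query_len is 1 and cache_update_index is the new token's absolute
    # position, so this picks the single matching row of the band.
    return band[cache_update_index : cache_update_index + query_len, :]


# Full-sequence pass: 6 tokens, window of 3.
print(sliding_window_mask(6, 6, 3).astype(int))
# Single-token decode step at absolute position 5 against a 6-entry cache.
print(sliding_window_mask(1, 6, 3, cache_update_index=5).astype(int))
# In the attention layer, this band is then AND-ed with the existing
# causal / padding attention mask, as the diff's logical_and call does.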