diff --git a/keras_nlp/src/models/gemma/gemma_attention.py b/keras_nlp/src/models/gemma/gemma_attention.py
index 9e5d3adbe4..bbe2d6531d 100644
--- a/keras_nlp/src/models/gemma/gemma_attention.py
+++ b/keras_nlp/src/models/gemma/gemma_attention.py
@@ -122,6 +122,7 @@ def _compute_attention(
         v,
         attention_mask,
         training=False,
+        cache_update_index=0,
     ):
         if self.query_head_dim_normalize:
             query_normalization = 1 / np.sqrt(self.head_dim)
@@ -152,29 +153,10 @@ def _compute_attention(
             )
 
         if self.use_sliding_window_attention:
-            all_ones = ops.ones_like(attention_mask)
-            if keras.config.backend() == "tensorflow":
-                import tensorflow as tf
-
-                sliding_window_size = ops.minimum(
-                    self.sliding_window_size - 1, q_len
-                )
-                sliding_window_size = ops.cast(
-                    sliding_window_size, dtype="int32"
-                )
-                sliding_mask = tf.linalg.band_part(
-                    all_ones, sliding_window_size - 1, sliding_window_size - 1
-                )
-                sliding_mask = ops.cast(sliding_mask, dtype="bool")
-                bool_attention_mask = ops.cast(attention_mask, dtype="bool")
-                attention_mask = tf.math.logical_and(
-                    sliding_mask, bool_attention_mask
-                )
-            else:
-                sliding_mask = ops.triu(
-                    all_ones, -1 * self.sliding_window_size + 1
-                ) * ops.tril(all_ones, self.sliding_window_size - 1)
-                attention_mask = sliding_mask * attention_mask
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index,
+            )
 
         attention_mask = attention_mask[:, None, None, :, :]
         orig_dtype = attention_logits.dtype
@@ -189,6 +171,28 @@ def _compute_attention(
         results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v)
         return ops.reshape(results, (b, q_len, self.num_query_heads, h))
 
+    def _mask_sliding_window(
+        self,
+        attention_mask,
+        cache_update_index=0,
+    ):
+        batch_size, query_len, key_len = ops.shape(attention_mask)
+        all_ones = ops.ones((key_len, key_len), "bool")
+        if keras.config.backend() == "tensorflow":
+            import tensorflow as tf
+
+            band_size = ops.minimum(key_len, self.sliding_window_size - 1)
+            band_size = ops.cast(band_size, "int32")
+            sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size)
+        else:
+            sliding_mask = ops.triu(
+                all_ones, -1 * self.sliding_window_size + 1
+            ) * ops.tril(all_ones, self.sliding_window_size - 1)
+        start = (cache_update_index, 0)
+        sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
+        sliding_mask = ops.expand_dims(sliding_mask, 0)
+        return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))
+
     def call(
         self,
         x,
@@ -216,7 +220,12 @@ def call(
         value = self.value_dense(x)
 
         attention_vec = self._compute_attention(
-            query, key, value, attention_mask, training=training
+            query,
+            key,
+            value,
+            attention_mask,
+            training=training,
+            cache_update_index=cache_update_index,
        )
 
         # Wipe attn vec if there are no attended tokens.
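
For reference, below is a minimal standalone NumPy sketch of what the new `_mask_sliding_window` helper computes. It is illustrative only, not the patched Keras code: it mirrors just the `ops.triu`/`ops.tril` branch of the diff, and the function name, argument names, and shapes are assumptions for the example.

```python
import numpy as np


def mask_sliding_window(attention_mask, sliding_window_size, cache_update_index=0):
    # attention_mask: (batch, query_len, key_len) boolean causal/padding mask.
    _, query_len, key_len = attention_mask.shape
    all_ones = np.ones((key_len, key_len), dtype=bool)
    # Keep a band of half-width (sliding_window_size - 1) around the diagonal,
    # analogous to the triu/tril product in the patch.
    sliding_mask = np.triu(all_ones, -sliding_window_size + 1) & np.tril(
        all_ones, sliding_window_size - 1
    )
    # During cached decoding the query block starts at `cache_update_index`,
    # so take the matching rows of the full (key_len, key_len) band.
    sliding_mask = sliding_mask[cache_update_index : cache_update_index + query_len, :]
    return attention_mask & sliding_mask[None, :, :]


# Decoding one token at position 5 against an 8-slot cache with window 4:
# the causal mask allows keys 0..5, and the band narrows that to keys 2..5.
causal = (np.arange(8) <= 5)[None, None, :]
print(mask_sliding_window(causal, sliding_window_size=4, cache_update_index=5))
```

The slice at `cache_update_index` is the key difference from the removed code: during generation the query block covers only a few rows of the full key-length band, so the band must be built over `key_len` and then offset to the query's position rather than over the query length alone.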