Torch's dot_product_attention doesn't support bias.
james77777778 committed Jan 28, 2025
1 parent 0f74562 commit d769bc9
Showing 1 changed file with 14 additions and 34 deletions.
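For context, here is a minimal standalone sketch (not part of the commit) of the call pattern that the deleted lines in the diff below used: `keras.ops.dot_product_attention` receives the ALiBi bias through its `bias` argument. Per this commit, the Torch backend does not support that argument, which is why the call was removed. Shapes and values are illustrative, and the example assumes a backend whose `dot_product_attention` accepts `bias` (e.g. KERAS_BACKEND=jax).

# Sketch of the removed fused-attention call; illustrative shapes only.
# On the Torch backend this is the call that fails, motivating the commit.
import numpy as np
from keras import ops

batch, seq_len, num_heads, head_dim = 2, 4, 3, 8
rng = np.random.default_rng(0)
query = ops.convert_to_tensor(rng.standard_normal((batch, seq_len, num_heads, head_dim), dtype=np.float32))
key = ops.convert_to_tensor(rng.standard_normal((batch, seq_len, num_heads, head_dim), dtype=np.float32))
value = ops.convert_to_tensor(rng.standard_normal((batch, seq_len, num_heads, head_dim), dtype=np.float32))
# ALiBi bias, broadcastable to [batch, num_heads, q_len, kv_len].
alibi = ops.convert_to_tensor(rng.standard_normal((batch, num_heads, seq_len, seq_len), dtype=np.float32))
inv_norm_factor = 1.0 / np.sqrt(head_dim)

output = ops.dot_product_attention(
    query,
    key,
    value,
    # The removed code pre-scaled the bias by `inv_norm_factor` so the result
    # matches the fallback's (scores + alibi) * inv_norm_factor ordering.
    bias=ops.multiply(alibi, inv_norm_factor),
    scale=inv_norm_factor,
)
print(ops.shape(output))  # (2, 4, 3, 8)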
48 changes: 14 additions & 34 deletions keras_hub/src/models/falcon/falcon_attention.py
@@ -3,8 +3,6 @@
 import keras
 from keras import ops
 
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
-
 
 class FalconAttention(keras.layers.Layer):
     def __init__(
@@ -110,38 +108,20 @@ def call(
                     f"cache_update_index={cache_update_index}"
                 )
 
-        if has_flash_attention_support() and self.attention_dropout_rate == 0:
-            # Use `dot_product_attention` with Flash Attention support if
-            # available.
-            if attention_mask is not None:
-                attention_mask = ops.expand_dims(attention_mask, axis=1)
-                attention_mask = ops.cast(attention_mask, dtype="bool")
-            attention_output = ops.dot_product_attention(
-                query,
-                key,
-                value,
-                bias=ops.multiply(
-                    alibi,
-                    ops.cast(self.inv_norm_factor, self.compute_dtype),
-                ),
-                mask=attention_mask,
-                scale=self.inv_norm_factor,
-            )
-        else:
-            attention_scores = ops.einsum("bqnh,bknh->bnqk", query, key)
-            attention_scores = ops.add(attention_scores, alibi)
-            # [batch_size, num_heads, query_length, kv_length]
-            attention_scores = ops.multiply(
-                attention_scores,
-                ops.cast(self.inv_norm_factor, self.compute_dtype),
-            )
-            attention_scores = self.softmax(
-                attention_scores, ops.expand_dims(attention_mask, 1)
-            )
-            attention_scores = self.attention_dropout(attention_scores)
-            attention_output = ops.einsum(
-                "bnqk,bknh->bqnh", attention_scores, value
-            )
+        attention_scores = ops.einsum("bqnh,bknh->bnqk", query, key)
+        attention_scores = ops.add(attention_scores, alibi)
+        # [batch_size, num_heads, query_length, kv_length]
+        attention_scores = ops.multiply(
+            attention_scores,
+            ops.cast(self.inv_norm_factor, self.compute_dtype),
+        )
+        attention_scores = self.softmax(
+            attention_scores, ops.expand_dims(attention_mask, 1)
+        )
+        attention_scores = self.attention_dropout(attention_scores)
+        attention_output = ops.einsum(
+            "bnqk,bknh->bqnh", attention_scores, value
+        )
 
         attention_output = ops.reshape(
             attention_output,
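The backend-agnostic path that the commit keeps can be reproduced outside the layer with a short sketch like the one below (illustrative shapes; the layer's masked softmax and dropout are replaced with a plain `ops.softmax` and no dropout here): the ALiBi bias is added to the raw einsum scores and the result is scaled before the softmax, so no `bias` support is needed from the backend's fused attention op.

# Standalone sketch (random inputs, no mask, no dropout) of the manual
# attention path retained by this commit; it mirrors the einsum formulation
# in falcon_attention.py but is not the layer code itself.
import numpy as np
from keras import ops

batch, q_len, kv_len, num_heads, head_dim = 2, 4, 4, 3, 8
rng = np.random.default_rng(0)
query = ops.convert_to_tensor(rng.standard_normal((batch, q_len, num_heads, head_dim), dtype=np.float32))
key = ops.convert_to_tensor(rng.standard_normal((batch, kv_len, num_heads, head_dim), dtype=np.float32))
value = ops.convert_to_tensor(rng.standard_normal((batch, kv_len, num_heads, head_dim), dtype=np.float32))
alibi = ops.convert_to_tensor(rng.standard_normal((batch, num_heads, q_len, kv_len), dtype=np.float32))
inv_norm_factor = 1.0 / np.sqrt(head_dim)

# [batch, num_heads, q_len, kv_len]
scores = ops.einsum("bqnh,bknh->bnqk", query, key)
scores = ops.add(scores, alibi)                 # ALiBi bias folded into the scores
scores = ops.multiply(scores, inv_norm_factor)  # scaled after adding the bias, as in the layer
probs = ops.softmax(scores, axis=-1)            # the layer uses its own masked Softmax here
output = ops.einsum("bnqk,bknh->bqnh", probs, value)
print(ops.shape(output))  # (2, 4, 3, 8)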
