Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve future compatibility of CLIPMultiHeadAttention #1975

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Improve future compatibility of CLIPMultiHeadAttention
  • Loading branch information
james77777778 committed Nov 7, 2024
commit 5787fcd45c43ab84ba3953fac8c814e0789f03c0
23 changes: 3 additions & 20 deletions keras_hub/src/models/clip/clip_encoder_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,11 @@ def quick_gelu(x):
# TODO: Deprecate this in favor of `keras.layers.MultiHeadAttention` once the
# dtype compatibility issue is resolved.
class CLIPMultiHeadAttention(layers.MultiHeadAttention):
def _compute_attention(
self, query, key, value, attention_mask=None, training=None
):
query = ops.multiply(
query, ops.cast(self._inverse_sqrt_key_dim, query.dtype)
)
attention_scores = ops.einsum(self._dot_product_equation, key, query)
attention_scores = self._masked_softmax(
def _masked_softmax(self, attention_scores, attention_mask=None):
attention_scores = super()._masked_softmax(
attention_scores, attention_mask
)
# Fix the dtype compatibility.
attention_scores = ops.cast(attention_scores, value.dtype)
if self.dropout:
final_attn_scores = self._dropout_layer(
attention_scores, training=training
)
else:
final_attn_scores = attention_scores
attention_output = ops.einsum(
self._combine_equation, final_attn_scores, value
)
return attention_output, attention_scores
return ops.cast(attention_scores, self._value_dense.compute_dtype)


class CLIPEncoderBlock(layers.Layer):
Expand Down
Loading