Skip to content

Commit

Permalink
Enable Flash Attention for SD3 MMDiT (#2014)
Browse files: browse the repository at this point in the history
* Enable flash attention for SD3 MMDiT

* Remove tf condition
  • Loading branch information
james77777778 authored Dec 12, 2024
1 parent 15564ca commit 821c014
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion keras_hub/src/models/stable_diffusion_3/mmdit.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,9 +595,28 @@ def build(self, inputs_shape, context_shape, timestep_embedding_shape):
self.context_block.build(context_shape, timestep_embedding_shape)

def _compute_attention(self, query, key, value):
batch_size = ops.shape(query)[0]

# Use the fast path when `ops.dot_product_attention` and flash attention
# are available.
if hasattr(ops, "dot_product_attention") and hasattr(
keras.config, "is_flash_attention_enabled"
):
# NOTE: `ops.dot_product_attention` is slower than the vanilla
# implementation in the tensorflow backend, but no backend check is
# applied here, so this path is taken on every backend.
encoded = ops.dot_product_attention(
query,
key,
value,
scale=self._inverse_sqrt_key_dim,
flash_attention=keras.config.is_flash_attention_enabled(),
)
return ops.reshape(
encoded, (batch_size, -1, self.num_heads * self.head_dim)
)

# Ref: jax.nn.dot_product_attention
# https://github.com/jax-ml/jax/blob/db89c245ac66911c98f265a05956fdfa4bc79d83/jax/_src/nn/functions.py#L846
batch_size = ops.shape(query)[0]
logits = ops.einsum("BTNH,BSNH->BNTS", query, key)
logits = ops.multiply(logits, self._inverse_sqrt_key_dim)
probs = self.softmax(logits)
Expand Down

0 comments on commit 821c014

Please sign in to comment.