From 821c01459d2b24e5cbd8d3a29b521fd0a61031c3 Mon Sep 17 00:00:00 2001
From: "Hongyu, Chiu" <20734616+james77777778@users.noreply.github.com>
Date: Fri, 13 Dec 2024 06:48:13 +0800
Subject: [PATCH] Enable Flash Attention for SD3 MMDiT (#2014)

* Enable flash attention for SD3 MMDiT

* Remove tf condition
---
 .../src/models/stable_diffusion_3/mmdit.py | 21 ++++++++++++++++++-
 1 file changed, 20 insertions(+), 1 deletion(-)

diff --git a/keras_hub/src/models/stable_diffusion_3/mmdit.py b/keras_hub/src/models/stable_diffusion_3/mmdit.py
index 083e4a359a..2069d1595f 100644
--- a/keras_hub/src/models/stable_diffusion_3/mmdit.py
+++ b/keras_hub/src/models/stable_diffusion_3/mmdit.py
@@ -595,9 +595,28 @@ def build(self, inputs_shape, context_shape, timestep_embedding_shape):
         self.context_block.build(context_shape, timestep_embedding_shape)
 
     def _compute_attention(self, query, key, value):
+        batch_size = ops.shape(query)[0]
+
+        # Use the fast path when `ops.dot_product_attention` and flash
+        # attention are available.
+        if hasattr(ops, "dot_product_attention") and hasattr(
+            keras.config, "is_flash_attention_enabled"
+        ):
+            # `ops.dot_product_attention` is slower than the vanilla
+            # implementation in the tensorflow backend.
+            encoded = ops.dot_product_attention(
+                query,
+                key,
+                value,
+                scale=self._inverse_sqrt_key_dim,
+                flash_attention=keras.config.is_flash_attention_enabled(),
+            )
+            return ops.reshape(
+                encoded, (batch_size, -1, self.num_heads * self.head_dim)
+            )
+
         # Ref: jax.nn.dot_product_attention
         # https://github.com/jax-ml/jax/blob/db89c245ac66911c98f265a05956fdfa4bc79d83/jax/_src/nn/functions.py#L846
-        batch_size = ops.shape(query)[0]
         logits = ops.einsum("BTNH,BSNH->BNTS", query, key)
         logits = ops.multiply(logits, self._inverse_sqrt_key_dim)
         probs = self.softmax(logits)
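
Reviewer note: for trying the new code path outside the model, below is a minimal standalone sketch of the dispatch this patch adds to `_compute_attention`. The helper name `compute_attention`, the toy shapes, and the tail of the vanilla fallback (the `BNTS,BSNH->BTNH` einsum, which the hunk cuts off) are illustrative assumptions rather than code copied from the repository; the fast path mirrors the patch and uses `keras.ops.dot_product_attention` together with the `keras.config.is_flash_attention_enabled()` API available in recent Keras 3 releases.

import keras
from keras import ops


def compute_attention(query, key, value, num_heads, head_dim):
    """Toy mirror of the patched dispatch (hypothetical helper, not in the repo)."""
    batch_size = ops.shape(query)[0]
    inverse_sqrt_key_dim = 1.0 / float(head_dim) ** 0.5
    if hasattr(ops, "dot_product_attention") and hasattr(
        keras.config, "is_flash_attention_enabled"
    ):
        # Fast path: fused attention; flash attention is used when the
        # global Keras setting and the backend/device allow it.
        encoded = ops.dot_product_attention(
            query,
            key,
            value,
            scale=inverse_sqrt_key_dim,
            flash_attention=keras.config.is_flash_attention_enabled(),
        )
    else:
        # Fallback: vanilla attention in the style of the pre-existing
        # implementation (the closing einsum is assumed from the standard
        # formulation, not taken from the hunk).
        logits = ops.einsum("BTNH,BSNH->BNTS", query, key)
        logits = ops.multiply(logits, inverse_sqrt_key_dim)
        probs = ops.softmax(logits, axis=-1)
        encoded = ops.einsum("BNTS,BSNH->BTNH", probs, value)
    return ops.reshape(encoded, (batch_size, -1, num_heads * head_dim))


# Example usage with toy inputs in the (batch, seq, heads, head_dim) layout.
# On a supported GPU you could opt in globally with
# keras.config.enable_flash_attention(); the default (None) lets the
# backend decide automatically.
batch, seq, heads, dim = 2, 16, 4, 8
q = keras.random.normal((batch, seq, heads, dim))
k = keras.random.normal((batch, seq, heads, dim))
v = keras.random.normal((batch, seq, heads, dim))
out = compute_attention(q, k, v, num_heads=heads, head_dim=dim)
print(out.shape)  # (2, 16, 32)

Guarding on `hasattr` keeps the block usable with older Keras versions that ship neither `ops.dot_product_attention` nor the flash-attention config, which is presumably why the patch retains the vanilla einsum path instead of replacing it.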