
Enable Flash Attention for SD3 MMDiT #2014

Merged: 2 commits merged into keras-team:master on Dec 12, 2024

Conversation

@james77777778 (Collaborator) commented on Dec 9, 2024

This PR utilizes ops.dot_product_attention to accelerate inference in SD3.

  • 800x800
  • SD3 medium
  • float16
Backend  Flash Attention  Cost Time  Improvement
jax      no               10.61s     -
jax      yes              5.45s      -48.7%
torch    no               24.10s     -
torch    yes              18.43s     -23.6%

I noticed that ops.dot_product_attention performed slower than the vanilla implementation on the tensorflow backend (vanilla: 10.55s vs. ops.dot_product_attention: 14.33s), so this optimization path is skipped for that backend.
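
For context, a minimal sketch of the substitution (not the actual diff): the vanilla einsum/softmax path versus the fused op. Shapes follow what keras.ops.dot_product_attention expects, i.e. (batch, seq_len, num_heads, head_dim); the function names here are illustrative only.

from keras import ops

def vanilla_attention(q, k, v, scale):
    # Plain scaled dot-product attention: logits -> softmax -> weighted sum.
    logits = ops.einsum("btnh,bsnh->bnts", q, k) * scale
    probs = ops.softmax(logits, axis=-1)
    return ops.einsum("bnts,bsnh->btnh", probs, v)

def fused_attention(q, k, v, scale):
    # Single fused op; core Keras may dispatch this to a flash-attention
    # kernel on supported backends/devices.
    return ops.dot_product_attention(q, k, v, scale=scale)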

EDITED:
With this change, jax now runs faster than diffusers out of the box:
diffusers.StableDiffusion3Pipeline: 6.15s

The benchmark script (KerasHub):

import time

from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import (
    StableDiffusion3TextToImage,
)

height, width = 800, 800
preset = "stable_diffusion_3_medium"
num_steps = 28
guidance_scale = 7.0
dtype = "float16"
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
prompt = [prompt]
text_to_image = StableDiffusion3TextToImage.from_preset(
    preset,
    image_shape=(height, width, 3),
    dtype=dtype,
)

# Warm-up run (compilation, weight loading, etc.).
for _ in range(1):
    _ = text_to_image.generate(
        prompt, num_steps=num_steps, guidance_scale=guidance_scale
    )
print("Finish warmup.")

# Timed run.
st = time.time()
images = text_to_image.generate(
    prompt, num_steps=num_steps, guidance_scale=guidance_scale
)
ed = time.time()
print(f"Cost time: {ed-st:.2f}s")

The benchmark script (diffusers):

import time

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    text_encoder_3=None,
    tokenizer_3=None,
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
height, width = 800, 800
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

# Warm-up run.
image = pipe(
    prompt,
    negative_prompt="",
    num_inference_steps=28,
    guidance_scale=7.0,
    height=height,
    width=width,
).images[0]
print("Finish warmup.")

# Timed run.
st = time.time()
image = pipe(
    prompt,
    negative_prompt="",
    num_inference_steps=28,
    guidance_scale=7.0,
    height=height,
    width=width,
).images[0]
print(time.time() - st)

@mattdangerw (Member) left a comment


Should we test this somehow?

if (
    hasattr(ops, "dot_product_attention")
    and hasattr(keras.config, "is_flash_attention_enabled")
    and keras.backend.backend() != "tensorflow"
@mattdangerw (Member) commented:

Maybe let's drop the tf part? And just do the same on all backends?

I don't think we want to be in the business of trying to outsmart core Keras. And layers.MultiHeadAttention isn't doing anything like this.
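
For illustration, the suggested simplification might look roughly like the sketch below (not the merged code; compute_attention, scale, and fallback are placeholder names). The backend check goes away and the feature checks alone gate the fused path.

import keras
from keras import ops

def compute_attention(query, key, value, scale, fallback):
    # Keep only the feature checks; let core Keras pick the best kernel
    # on every backend instead of special-casing tensorflow.
    if hasattr(ops, "dot_product_attention") and hasattr(
        keras.config, "is_flash_attention_enabled"
    ):
        return ops.dot_product_attention(query, key, value, scale=scale)
    # Otherwise fall back to the vanilla attention path.
    return fallback(query, key, value)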

@james77777778 (Collaborator, Author) commented:

Sure!

I have submitted a PR to core Keras:
keras-team/keras#20615
With that change, the cost time of the tensorflow backend will be comparable to jax's (w/o flash attention):

  • tensorflow: 10.57s
  • jax: 10.61s

@james77777778 (Collaborator, Author) commented:

> Should we test this somehow?

I’m not sure how to test this, since ops.dot_product_attention is intended to be a drop-in replacement.
Should I compare the numerics with and without ops.dot_product_attention?
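
For illustration, such a numerics check might look roughly like this sketch (not an existing test; vanilla_attention here is a stand-in for the layer's original path):

import numpy as np
from keras import ops

def vanilla_attention(q, k, v, scale):
    # Reference einsum/softmax attention to compare against.
    logits = ops.einsum("btnh,bsnh->bnts", q, k) * scale
    probs = ops.softmax(logits, axis=-1)
    return ops.einsum("bnts,bsnh->btnh", probs, v)

rng = np.random.default_rng(0)
shape = (2, 16, 4, 8)  # (batch, seq_len, num_heads, head_dim)
q, k, v = (rng.standard_normal(shape).astype("float32") for _ in range(3))
scale = 1.0 / np.sqrt(shape[-1])

fused = ops.dot_product_attention(q, k, v, scale=scale)
reference = vanilla_attention(q, k, v, scale)
np.testing.assert_allclose(
    ops.convert_to_numpy(fused), ops.convert_to_numpy(reference), atol=1e-4
)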

@mattdangerw (Member) commented:

Thanks! keras-team/keras#20615, yeah I think that's the way to go.

Hmm, yeah as long as the code is being exercised, probably fine to leave as is. Let's go with this!

@mattdangerw enabled auto-merge (squash) on December 10, 2024 19:16
@mattdangerw merged commit 821c014 into keras-team:master on Dec 12, 2024
7 checks passed
@james77777778 deleted the flash-attn-sd3 branch on December 25, 2024 03:33