Enable Flash Attention for SD3 MMDiT #2014
Conversation
Should we test this somehow?
if (
    hasattr(ops, "dot_product_attention")
    and hasattr(keras.config, "is_flash_attention_enabled")
    and keras.backend.backend() != "tensorflow"
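For context, a minimal sketch of the kind of dispatch this guard sits in front of, assuming inputs of shape `(batch, seq_len, num_heads, head_dim)`; the helper name and the vanilla fallback are illustrative, not the PR's exact code:

```python
import keras
from keras import ops


def compute_attention(query, key, value, scale):
    # Illustrative helper: query/key/value are assumed to have shape
    # (batch, seq_len, num_heads, head_dim), the layout expected by
    # ops.dot_product_attention.
    if hasattr(ops, "dot_product_attention") and hasattr(
        keras.config, "is_flash_attention_enabled"
    ):
        # Let core Keras pick the fastest kernel (flash attention when
        # it is enabled and supported by the backend/hardware).
        return ops.dot_product_attention(query, key, value, scale=scale)

    # Vanilla fallback: softmax(Q K^T * scale) V.
    logits = ops.einsum("bqnh,bknh->bnqk", query, key) * scale
    probs = ops.softmax(logits, axis=-1)
    return ops.einsum("bnqk,bknh->bqnh", probs, value)
```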
Maybe let's drop the tf part? And just do the same on all backends?
I don't think we want to be in the business of trying to outsmart core Keras. And `layers.MultiHeadAttention` isn't doing anything like this.
Sure!
I have submitted a PR to core Keras:
keras-team/keras#20615
With that change, the runtime on the tensorflow backend will be comparable to jax (w/o flash attention):
- tensorflow: 10.57s
- jax: 10.61s
> Should we test this somehow?
I’m not sure how to test this, since `ops.dot_product_attention` is intended to be a drop-in replacement. Should I compare the numerics w/ and w/o `ops.dot_product_attention`?
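For what it's worth, a minimal sketch of such a numeric comparison, with assumed shapes and tolerances (not an existing test in this PR):

```python
import numpy as np
from keras import ops


def test_dot_product_attention_matches_vanilla():
    # Hypothetical check: the fused op should match the vanilla
    # softmax(Q K^T * scale) V computation within tolerance.
    rng = np.random.default_rng(0)
    shape = (2, 16, 4, 8)  # (batch, seq_len, num_heads, head_dim)
    query = rng.standard_normal(shape).astype("float32")
    key = rng.standard_normal(shape).astype("float32")
    value = rng.standard_normal(shape).astype("float32")
    scale = 1.0 / np.sqrt(shape[-1])

    fused = ops.dot_product_attention(query, key, value, scale=scale)

    logits = ops.einsum("bqnh,bknh->bnqk", query, key) * scale
    probs = ops.softmax(logits, axis=-1)
    vanilla = ops.einsum("bnqk,bknh->bqnh", probs, value)

    np.testing.assert_allclose(
        ops.convert_to_numpy(fused),
        ops.convert_to_numpy(vanilla),
        rtol=1e-5,
        atol=1e-5,
    )
```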
Thanks! keras-team/keras#20615, yeah I think that's the way to go.
Hmm, yeah as long as the code is being exercised, probably fine to leave as is. Let's go with this!
This PR utilizes `ops.dot_product_attention` to accelerate inference in SD3.

I noticed that `ops.dot_product_attention` performed slower than the vanilla impl in the tensorflow backend, so this optimization path is skipped for it. (vanilla: 10.55s vs. `ops.dot_product_attention`: 14.33s)

EDITED:
jax now runs faster than `diffusers` in an out-of-the-box manner:
- `diffusers.StableDiffusion3Pipeline`: 6.15s

The benchmark script (KerasHub):
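The original script isn't included here; a minimal sketch of what a KerasHub-side benchmark could look like (the preset name, dtype policy, and `generate()` usage are assumptions, not the exact script used):

```python
import time

import keras
import keras_hub

keras.config.set_dtype_policy("float16")

# Assumed preset name; swap in the checkpoint actually benchmarked.
text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3_medium"
)

prompt = "a photograph of an astronaut riding a horse"
text_to_image.generate(prompt)  # warm-up / compilation run

start = time.time()
text_to_image.generate(prompt)
print(f"KerasHub: {time.time() - start:.2f}s")
```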
The benchmark script (`diffusers`):
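Likewise, a minimal sketch of a `diffusers`-side benchmark (the model id and settings are assumptions, not the exact script used):

```python
import time

import torch
from diffusers import StableDiffusion3Pipeline

# Assumed model id; swap in the checkpoint actually benchmarked.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16,
).to("cuda")

prompt = "a photograph of an astronaut riding a horse"
pipe(prompt)  # warm-up run

start = time.time()
pipe(prompt)
print(f"diffusers: {time.time() - start:.2f}s")
```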