Enable offloading multi-query attention by Flash Attention #990
Conversation
Overall, LGTM. Thanks for the PR!
Two high-level comments:
- Would be nice to have at least one test case.
- Is this about `mlc-ai/relax` not having `tvm/unity` yet?

  > In particular, since the TVM submodule pulled by mlc is a custom one that doesn't support flash attention at all, it cannot use this new feature.
mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=True) | ||
mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=True) | ||
|
||
mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=False) |
For `args.use_flash_attn_mqa == True`, do we need to run `rewrite_attention` twice?
Yes, this is for a case where there are both MQA and regular attention in the same model. I don't think it would come up in practice, but I added it for completeness.
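For reference, a minimal sketch of how the flag could gate the two passes, based on the diff lines quoted above. The wrapper function and the `args` object are assumptions for illustration, not code from this PR; `rewrite_attention` is the helper the PR modifies.

```python
def apply_attention_rewrites(mod, args):
    """Hypothetical wrapper sketching how the MQA offload could be gated."""
    if args.use_flash_attn_mqa:
        # Opt-in pass: pattern-match MQA-shaped attention and offload it to the
        # flash attention kernel (needs a TVM build with apache/tvm#15831).
        mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=True)
        mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=True)
    # The regular rewrite still runs, so a model that mixes MQA with ordinary
    # multi-head attention gets both attention patterns offloaded.
    mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=False)
    mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=False)
    return mod
```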
No, when I checked the commit history of that fork at that time, they had explicitly reverted all flash attention related PRs.
If this is the case, would it make sense to enable this MQA offload by default?
I don't want to require Flash Attention for mlc, since it is only needed for MQA, and flash attention can be problematic for packaging purposes etc. due to its insane compilation time. Moreover, flash attention doesn't seem to be faster than the cutlass fMHA for the LLM decoding workload, so unless the context length is very large, this optimization doesn't give a good speedup over the default explicit repeat + cutlass fMHA path. So for now this feature is experimental.
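To make the comparison above concrete, here is a small NumPy illustration of the shapes involved; the dimensions are made up for the example and none of this code comes from the PR.

```python
import numpy as np

# Hypothetical MQA shapes: many query heads sharing a few KV heads.
batch, seqlen, num_q_heads, num_kv_heads, head_dim = 1, 16, 32, 4, 64

q = np.random.randn(batch, seqlen, num_q_heads, head_dim).astype("float32")
k = np.random.randn(batch, seqlen, num_kv_heads, head_dim).astype("float32")
v = np.random.randn(batch, seqlen, num_kv_heads, head_dim).astype("float32")

# Default path described above: explicitly repeat the KV heads so that a regular
# fMHA kernel (the cutlass offload) sees matching head counts on q, k, and v.
factor = num_q_heads // num_kv_heads
k_repeated = np.repeat(k, factor, axis=2)  # shape (1, 16, 32, 64)
v_repeated = np.repeat(v, factor, axis=2)  # shape (1, 16, 32, 64)

# With use_flash_mqa=True, the rewrite instead targets a flash attention kernel
# that consumes k and v with the smaller head count directly, skipping the repeat.
```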
Following apache/tvm#15831, this is the mlc-side change to enable MQA offload.
To use the new option, `use_flash_attn_mqa`, `${TVM_HOME}` must point to a TVM build that includes that commit. In particular, since the TVM submodule pulled by mlc is a custom one that doesn't support flash attention at all, it cannot use this new feature. The option is set to `False` by default to avoid potential trouble.
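A hedged sketch of how the option might be exposed at build time; only the attribute name `use_flash_attn_mqa` and its `False` default come from this PR, while the command-line spelling and argparse wiring below are assumptions for illustration.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use-flash-attn-mqa",
    action="store_true",  # stays False unless explicitly requested, matching the PR default
    help="Offload multi-query attention to flash attention. Requires ${TVM_HOME} "
    "to point at a TVM build that includes apache/tvm#15831.",
)

args = parser.parse_args([])           # no flag passed -> option stays disabled
assert args.use_flash_attn_mqa is False
```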