
Enable offloading multi-query attention by Flash Attention #990

Merged (6 commits) on Oct 4, 2023

Conversation

masahi (Contributor) commented Sep 28, 2023

Following apache/tvm#15831, this is the mlc-side change to enable MQA offload.

To use the new option, use_flash_attn_mqa, ${TVM_HOME} must point to a TVM build that includes that commit. In particular, the TVM submodule pulled by mlc is a custom one that doesn't support flash attention at all, so it cannot use this new feature. The option defaults to False to avoid potential problems.
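
For context on what the offload changes: multi-query attention shares a single key/value head across all query heads. Below is a rough shape-level sketch in plain NumPy (an illustration only, not the actual TVM/Relax code from this PR) contrasting the default explicit-repeat path with the direct offload.

import numpy as np

batch, seqlen, num_q_heads, head_dim = 1, 16, 32, 128

# Multi-query attention: many query heads, one shared K/V head.
q = np.random.randn(batch, seqlen, num_q_heads, head_dim)
k = np.random.randn(batch, seqlen, 1, head_dim)
v = np.random.randn(batch, seqlen, 1, head_dim)

# Default path: explicitly repeat K/V so every query head gets its own copy,
# then hand the result to the regular fused MHA kernel (cutlass fMHA).
k_rep = np.repeat(k, num_q_heads, axis=2)  # (1, 16, 32, 128)
v_rep = np.repeat(v, num_q_heads, axis=2)

# With use_flash_attn_mqa=True, the rewrite instead passes q, k, v with
# mismatched head counts directly to the Flash Attention kernel, which
# broadcasts the single K/V head internally and skips materializing k_rep/v_rep.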

sunggg (Contributor) left a comment


Overall, LGTM. Thanks for the PR!

Two high-level comments:

  • Would be nice to have at least one test case.
  • Is this about mlc-ai/relax not having tvm/unity yet?

In particular, since the TVM submodule pulled by mlc is a custom one that doesn't support flash attention at all, it cannot use this new feature.

mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=True)
mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=True)

mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=False)
Contributor


For args.use_flash_attn_mqa==True, do we need to run rewrite_attention twice?

Contributor Author


Yes, this is for the case where both MQA and regular attention appear in the same model. I don't think it would come up in practice, but I added it for completeness.
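
For readers following along, the rewrite ordering under discussion looks roughly like the sketch below (a sketch only; the surrounding build-script details are assumed, and rewrite_attention is the helper added in this PR).

def apply_attention_rewrites(mod, args):
    # Sketch of the rewrite ordering discussed above, not the exact mlc code.
    if args.use_flash_attn_mqa:
        # First pass: rewrite MQA-shaped attention (single shared K/V head)
        # into direct Flash Attention calls.
        mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=True)
        mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=True)

    # Second pass: rewrite any remaining regular multi-head attention.
    # This only matters if a model mixes MQA and regular attention;
    # otherwise one of the two passes simply finds nothing to rewrite.
    mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=False)
    mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=False)
    return mod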

masahi (Contributor, Author) commented Oct 3, 2023

Is this about mlc-ai/relax not having tvm/unity yet?

No, when I checked the commit history of that fork at the time, they had explicitly reverted all the flash-attention-related PRs that went into apache/unity (probably to make compilation faster). That doesn't seem to be the case anymore.

sunggg (Contributor) commented Oct 4, 2023

Is this about mlc-ai/relax not having tvm/unity yet?

No, when I checked the commit history of that fork at the time, they had explicitly reverted all the flash-attention-related PRs that went into apache/unity (probably to make compilation faster). That doesn't seem to be the case anymore.

If this is the case, would it make sense to enable this MQA offload by default?

masahi (Contributor, Author) commented Oct 4, 2023

I don't want to require Flash Attention for mlc, since it is only needed for MQA, and Flash Attention can be problematic for packaging purposes, etc., due to its very long compilation time. Moreover, Flash Attention doesn't seem to be faster than cutlass fMHA for LLM decoding workloads, so unless the context length is very large this optimization doesn't give a meaningful speedup over the default explicit-repeat + cutlass fMHA path.
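
To make the decoding point concrete: during autoregressive decoding each step has a query length of 1 while K/V cover the whole accumulated context, so the kernel's work is dominated by streaming the K/V cache. A toy NumPy illustration of the decode-step shapes (an illustration of the workload shape only, not a benchmark of either kernel):

import numpy as np

num_q_heads, head_dim, context_len = 32, 128, 2048

# One decode step: a single new query token attends over the cached context.
q = np.random.randn(num_q_heads, 1, head_dim)        # query length is 1
k_cache = np.random.randn(1, context_len, head_dim)  # MQA: one shared K/V head
v_cache = np.random.randn(1, context_len, head_dim)

scores = (q @ k_cache.transpose(0, 2, 1)) / np.sqrt(head_dim)  # (32, 1, 2048)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
out = weights @ v_cache                               # (32, 1, 128)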

So for now this feature is experimental.
