
Implements the attention kernel with vertical and slash sparse pattern described in Appendix C.4.2 of https://arxiv.org/abs/2407.02490 (as sparse_attn_func) #33

Merged · 12 commits · Jan 15, 2025

Conversation

@minminsun commented Dec 19, 2024

Implements the attention kernel with vertical and slash sparse pattern described in Appendix C.4.2 of https://arxiv.org/abs/2407.02490 (as sparse_attn_func).
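
For readers coming from vLLM without the paper at hand: in this pattern each query attends, under the causal mask, to (a) a small set of "vertical" key columns shared by all queries and (b) a set of "slash" diagonals, i.e. keys at fixed distances behind the query. The snippet below is a dense reference of that pattern for illustration only; the inputs vertical_idx and slash_offsets are hypothetical, and the kernel itself consumes precomputed sparse index metadata rather than a dense mask.

```python
# Dense reference of the vertical + slash sparsity pattern (illustration only;
# vertical_idx and slash_offsets are hypothetical inputs for this sketch).
import torch

def vertical_slash_mask(seq_len: int,
                        vertical_idx: torch.Tensor,    # key columns every query keeps
                        slash_offsets: torch.Tensor    # diagonal distances (0 = main diagonal)
                        ) -> torch.Tensor:
    q_pos = torch.arange(seq_len).unsqueeze(1)         # (seq_len, 1)
    k_pos = torch.arange(seq_len).unsqueeze(0)         # (1, seq_len)
    causal = k_pos <= q_pos
    vertical = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    vertical[:, vertical_idx] = True                   # "vertical" columns
    slash = torch.isin(q_pos - k_pos, slash_offsets)   # "slash" diagonals
    return (vertical | slash) & causal                 # True = compute this (q, k) pair

# e.g. keep key columns {0, 3} and the two most recent diagonals:
mask = vertical_slash_mask(8, torch.tensor([0, 3]), torch.tensor([0, 1]))
```

A dense reference like this is only useful for checking the pattern at small sizes; the point of sparse_attn_func is to exploit the sparsity without materializing anything of size seq_len × seq_len.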

@LucasWilkinson (Collaborator) left a comment

vllm-flash-attn has unfortunately diverged (in conflicting ways) from upstream, but we are trying to simplify diffs with upstream. In that spirit, I think it would be really helpful to push most of the additions in csrc/flash_attn/src/flash_fwd_kernel.h and csrc/flash_attn/src/flash_fwd_launch_template.h into their own files, i.e. move them to files like csrc/flash_attn/src/vllm_extensions/flash_fwd_sparse_kernel.h and csrc/flash_attn/src/vllm_extensions/flash_fwd_sparse_launch_template.h.

(I'm looking into ways to reduce diffs in csrc/flash_attn/flash_api.cpp, but this is trickier, so I think what's in this PR currently is fine; we can address it in a future PR.)

@minminsun (Author)

> vllm-flash-attn has unfortunately diverged (in conflicting ways) from upstream, but we are trying to simplify diffs with upstream. In that spirit, I think it would be really helpful to push most of the additions in csrc/flash_attn/src/flash_fwd_kernel.h and csrc/flash_attn/src/flash_fwd_launch_template.h into their own files, i.e. move them to files like csrc/flash_attn/src/vllm_extensions/flash_fwd_sparse_kernel.h and csrc/flash_attn/src/vllm_extensions/flash_fwd_sparse_launch_template.h.
>
> (I'm looking into ways to reduce diffs in csrc/flash_attn/flash_api.cpp, but this is trickier, so I think what's in this PR currently is fine; we can address it in a future PR.)

Thanks for your suggestion! I've moved code to new files.

@LucasWilkinson (Collaborator) left a comment

> I've moved code to new files.
Thank you!

Thanks for addressing my previous comments, and thanks for the contribution! I did another pass and left some more comments.

Overall the kernel seems quite cool, but there is a lot of commented-out code (that does not appear to be of the "un-comment for useful debug prints" variety) that could use cleaning up. I think I caught most of it, but another clean-up pass could be useful here.

@LucasWilkinson (Collaborator) left a comment

Thanks for all the changes! Overall this looks pretty good to me; I left a couple more (optional) nits.

My final concern is binary size (we are a bit sensitive to this in vLLM). Do you know which head dims are actually being used (since there's only a limited set of models using this currently)? Ideally we'd only build and ship those for now.

Can you get the DCO check to pass by signing off on the commits (https://github.com/apps/dco)? After that, I think everything is good from my side. @WoosukKwon, not sure if you want to take a look?

// flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgVBlock, tVsV, tKVcKV, tKVpKV);
#pragma unroll
for (int m = 0; m < size<1>(tVgVToken); ++m) {
if (true) { // Is_even_MN
Review comment (Collaborator):
nit: bump

Signed-off-by: Minmin Sun <[email protected]>
Signed-off-by: Minmin Sun <[email protected]>
Signed-off-by: Minmin Sun <[email protected]>
Signed-off-by: Minmin Sun <[email protected]>
@minminsun (Author)

> Thanks for all the changes! Overall this looks pretty good to me; I left a couple more (optional) nits.
>
> My final concern is binary size (we are a bit sensitive to this in vLLM). Do you know which head dims are actually being used (since there's only a limited set of models using this currently)? Ideally we'd only build and ship those for now.
>
> Can you get the DCO check to pass by signing off on the commits (https://github.com/apps/dco)? After that, I think everything is good from my side. @WoosukKwon, not sure if you want to take a look?

Keeping only head dim 128 is enough for us for now.

The DCO check has now passed. Thank you again for your thorough review and valuable suggestions!

@LucasWilkinson (Collaborator) left a comment

LGTM now, thanks!

@minminsun (Author)

> LGTM now, thanks!

Thank you! Could you please help get this PR merged? The vLLM PR vllm-project/vllm#11844 depends on it.

@LucasWilkinson (Collaborator)

@minminsun, could you please expand the PR description to something like:

"Implements the kernel described in Appendix C.4.2 https://arxiv.org/abs/2407.02490 (as sparse_attn_func)"

Just so there will be a more useful commit message.

@minminsun (Author)

> @minminsun, could you please expand the PR description to something like:
>
> "Implements the kernel described in Appendix C.4.2 https://arxiv.org/abs/2407.02490 (as sparse_attn_func)"
>
> Just so there will be a more useful commit message.

Sure!

@minminsun changed the title from "Add sparse attention with vertical and slash" to "Implements the attention kernel with vertical and slash sparse pattern described in Appendix C.4.2 of https://arxiv.org/abs/2407.02490 (as sparse_attn_func)" on Jan 13, 2025
@LucasWilkinson (Collaborator) left a comment

Thanks for adding the varlen API! It appears that minf.py is not used? Am I missing something, or can this be removed?

Signed-off-by: Minmin Sun <[email protected]>
@minminsun (Author)

> Thanks for adding the varlen API! It appears that minf.py is not used? Am I missing something, or can this be removed?

No, it's not used. Removed.
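
(For context on the varlen API mentioned above: like the other flash-attention varlen entry points, it packs all sequences of a batch into one tensor and describes the boundaries with cumulative sequence lengths. The sketch below only illustrates that packing convention; the sparse varlen function's exact name and signature are not reproduced here, so check the repo for the real entry point.)

```python
# Illustration of the flash-attention varlen packing convention (not the
# actual vllm-flash-attn call; check the repo for the real signature).
import torch

seq_lens = [5, 3, 7]                                   # three sequences in one batch
cu_seqlens = torch.tensor([0] + seq_lens).cumsum(0).to(torch.int32)
# cu_seqlens == tensor([0, 5, 8, 15], dtype=torch.int32); max_seqlen == 7

total_tokens, num_heads, head_dim = sum(seq_lens), 8, 128
q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
# k and v are packed the same way; the varlen kernels use cu_seqlens_q /
# cu_seqlens_k plus max_seqlen_q / max_seqlen_k to locate each sequence's slice.
```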

@LucasWilkinson merged commit 6e1f8b6 into vllm-project:main Jan 15, 2025
1 check passed
@yzh119 commented Jan 18, 2025

Hi @minminsun @LucasWilkinson @WoosukKwon, I just saw this PR. You might be interested in flashinfer's sparse attention implementation, which supports fine-grained block sizes and both FA2/FA3 templates. You can try this feature using the sparse attention API in flashinfer: https://docs.flashinfer.ai/api/sparse.html. It has been used in projects such as Quest.
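
(For anyone comparing approaches, below is a rough usage sketch of the block-sparse wrapper from the linked flashinfer docs. The wrapper/method names and the plan() argument order are my reading of those docs and should be treated as assumptions; see https://docs.flashinfer.ai/api/sparse.html for the authoritative API.)

```python
# Hedged sketch: drive flashinfer's block-sparse attention with a BSR-style
# block mask. Names and argument order are assumptions based on the docs above.
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim = 8, 8, 128
seq_len, R, C = 1024, 16, 16                       # R x C block granularity
M = N = seq_len
MB, NB = M // R, N // C

# Toy block mask: keep the lower-triangular blocks (block-level causal mask).
block_mask = torch.tril(torch.ones(MB, NB, dtype=torch.bool))
indptr = torch.zeros(MB + 1, dtype=torch.int32)
indptr[1:] = block_mask.sum(dim=1).cumsum(0)       # BSR row pointers
indices = block_mask.nonzero()[:, 1].to(torch.int32)  # kept block columns per row

q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn_like(k)

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BlockSparseAttentionWrapper(workspace)
wrapper.plan(indptr.cuda(), indices.cuda(), M, N, R, C,   # assumed plan() signature
             num_qo_heads, num_kv_heads, head_dim)
out = wrapper.run(q, k, v)                         # (seq_len, num_qo_heads, head_dim)
```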

@LucasWilkinson (Collaborator)

Cool, thanks for letting us know, I'll check it out!
