Implements the attention kernel with the vertical and slash sparse pattern described in Appendix C.4.2 of https://arxiv.org/abs/2407.02490 (as sparse_attn_func) #33
Conversation
vllm-flash-attn has unfortunately diverged (in conflicting ways) from upstream, but we are trying to simplify the diffs against upstream. In that spirit, I think it would be really helpful if we could push most of the additions in
csrc/flash_attn/src/flash_fwd_kernel.h
and
csrc/flash_attn/src/flash_fwd_launch_template.h
into their own files, i.e. move them to files like:
csrc/flash_attn/src/vllm_extensions/flash_fwd_sparse_kernel.h
and csrc/flash_attn/src/vllm_extensions/flash_fwd_sparse_launch_template.h
(I'm looking into ways to reduce the diffs in csrc/flash_attn/flash_api.cpp as well,
but that is trickier, so I think what's in this PR currently is fine; we can address it in a future PR.)
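A minimal sketch of what such a header split could look like; the exact file layout and the kernel name compute_sparse_attn_1rowblock are placeholders for illustration, not necessarily what the PR ends up with:

```cpp
// csrc/flash_attn/src/vllm_extensions/flash_fwd_sparse_kernel.h (illustrative sketch only)
#pragma once

// Pull in the unmodified upstream kernel utilities so the sparse additions
// live entirely under vllm_extensions/ and the upstream files stay diff-free.
#include "../flash_fwd_kernel.h"

namespace flash {

// Placeholder name: the vertical-and-slash sparse forward kernel, declared here
// instead of being added to the upstream flash_fwd_kernel.h.
template <typename Kernel_traits, bool Is_causal, bool Is_even_K, typename Params>
__device__ void compute_sparse_attn_1rowblock(const Params &params,
                                              const int bidb,
                                              const int bidh,
                                              const int m_block);

}  // namespace flash
```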
Thanks for your suggestion! I've moved the code to new files.
I've moved the code to new files.
Thank you!
Thanks for addressing my previous comments and thanks for the contribution! I did another pass and left some more comments.
Overall the kernel seems quite cool, but there is a lot of commented-out code (that does not appear to be of the "un-comment for useful debug prints" kind) that could use cleaning up. I think I caught most of it, but another clean-up pass could be useful here.
Thanks for all the changes! Overall this looks pretty good to me; I left a couple more (optional) nits.
My final concern is binary size (we are a bit sensitive to this in vLLM). Do you know which head dims are actually being used (since there is only a limited set of models using this currently)? Ideally we'd only build and ship those for now.
Can you get the DCO check to pass by signing off on the commits (https://github.com/apps/dco)? After that I think everything is good from my side. @WoosukKwon, not sure if you want to take a look?
// flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgVBlock, tVsV, tKVcKV, tKVpKV);
#pragma unroll
for (int m = 0; m < size<1>(tVgVToken); ++m) {
    if (true) { // Is_even_MN
nit: bump
Keeping headdim 128 only is enough for us for now. The DCO check now passes. Thank you again for your thorough review and valuable suggestions!
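For reference, a gate along these lines is one way to keep only the head-dim-128 instantiations in the binary. FP16_SWITCH is the usual flash-attention dispatch macro; run_mha_fwd_sparse_ and the surrounding names are placeholders rather than the exact entry points in this PR:

```cpp
// Illustrative dispatch only: compile and ship the sparse path for head dim 128
// alone, so only one set of kernel instantiations (one hdim128 .cu per dtype)
// ends up in the binary.
void run_mha_fwd_sparse(Flash_fwd_params &params, cudaStream_t stream) {
    TORCH_CHECK(params.d == 128,
                "sparse_attn_func is only compiled for head_dim == 128 in this build");
    FP16_SWITCH(!params.is_bf16, [&] {
        // run_mha_fwd_sparse_ is a placeholder for the per-head-dim launcher.
        run_mha_fwd_sparse_<elem_type, /*kHeadDim=*/128>(params, stream);
    });
}
```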
LGTM now, thanks!
Thank you! Could you please help get the PR merged? The vLLM PR vllm-project/vllm#11844 depends on this.
@minminsun could you please expand the PR description to something like: "Implements the kernel described in Appendix C.4.2 of https://arxiv.org/abs/2407.02490 (as sparse_attn_func)", just so there will be a more useful commit message?
Sure!
Thanks for adding the varlen API! It appears that minf.py
is not used? Am I missing something, or can it be removed?
No, it's not used. Removed.
Hi @minminsun @LucasWilkinson @WoosukKwon, I just saw this PR. You might be interested in flashinfer's sparse attention implementation (which supports fine-grained block sizes and both FA2/FA3 templates). You can try this feature using the sparse attention API in flashinfer: https://docs.flashinfer.ai/api/sparse.html. It was used in projects such as Quest.
Cool, thanks for letting us know, I'll check it out!
Implements the attention kernel with the vertical and slash sparse pattern described in Appendix C.4.2 of https://arxiv.org/abs/2407.02490 (as sparse_attn_func).
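As a rough picture of the pattern itself, the toy sketch below enumerates, for one query block, the KV positions such a kernel visits: a few contiguous KV blocks covering the selected "slash" diagonals, plus a list of individual "vertical" columns. The metadata names (block_offset, column_index) follow the MInference-style layout and are assumptions for illustration, not a verbatim copy of the sparse_attn_func signature.

```cpp
#include <cstdio>
#include <vector>

// Toy host-side illustration of the vertical + slash sparse pattern from
// Appendix C.4.2 of arXiv:2407.02490. Names and values are assumed for illustration.
int main() {
    const int kBlockN = 64;  // width of one KV tile processed per inner step

    // "Slash" part: starting offsets of the contiguous KV blocks that cover the
    // selected diagonal stripes for this query block.
    std::vector<int> block_offset = {0, 448, 960};
    // "Vertical" part: individual KV columns (token positions) every query row
    // in the block attends to, e.g. attention sinks and other high-score columns.
    std::vector<int> column_index = {0, 1, 100, 731};

    for (int start : block_offset) {
        std::printf("dense tile: KV range [%d, %d)\n", start, start + kBlockN);
    }
    for (int col : column_index) {
        // Non-contiguous columns are gathered one token at a time (cf. the
        // per-token copy loop quoted earlier in the review).
        std::printf("gathered column: KV position %d\n", col);
    }
    return 0;
}
```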