feat: support deepseek prefill attention shape #765
Merged
Conversation
Hi, I am looking forward to using this MQA kernel in the DeepSeek V3 model. Do you have a date or plan for the next PR?
yzh119 added a commit that referenced this pull request on Feb 1, 2025:

Followup of #765, fix the JIT warmup utility functions.
yzh119 pushed a commit that referenced this pull request on Feb 7, 2025:

#765 introduced changes to the API of `plan`, including renaming `head_dim` to `head_dim_qk` and adding `head_dim_vo`. However, some call sites were not updated to reflect these changes, resulting in failing unit tests. This PR addresses the issue by updating the relevant calls, which should resolve the following unit test failures after merging:

- `tests/test_block_sparse.py::test_block_sparse_attention`
- `tests/test_non_contiguous_prefill.py::test_batch_paged_prefill_packed_input`

Signed-off-by: abmfy <[email protected]>
DeepSeek requires a `head_dim_qk` of 192 and a `head_dim_vo` of 128; this PR implements that feature for prefill attention on ragged tensors. (We could also support paged kv-cache, but it is not urgent: we only use `head_dim_qk=192, head_dim_vo=128` on ragged tensors for DeepSeek MLA without matrix absorption, and we need another MQA kernel with `head_dim_qk=576, head_dim_vo=512` for DeepSeek MLA with matrix absorption. I'll upstream that kernel in the next PR.)

Checklist
Changes to the programming interface

We added an optional field `head_dim_vo` to the `plan` function, allowing users to specify different `head_dim_qk` and `head_dim_vo`:
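The code example that followed in the original PR description is not included in this excerpt. As a hedged illustration (a NumPy reference computation, not the flashinfer kernel or its `plan` API), the sketch below shows why `head_dim_qk` and `head_dim_vo` must be distinct parameters for the DeepSeek prefill shape: attention scores are computed over the 192-dim Q/K space, while the output inherits the 128-dim V space. All shapes here are illustrative.

```python
import numpy as np

# DeepSeek MLA prefill shape without matrix absorption:
# head_dim_qk=192, head_dim_vo=128 (per the PR description).
num_heads, seq_len, head_dim_qk, head_dim_vo = 4, 16, 192, 128
rng = np.random.default_rng(0)
q = rng.standard_normal((seq_len, num_heads, head_dim_qk))
k = rng.standard_normal((seq_len, num_heads, head_dim_qk))
v = rng.standard_normal((seq_len, num_heads, head_dim_vo))

# Scores contract over head_dim_qk; shape: (num_heads, seq_len, seq_len).
scores = np.einsum("qhd,khd->hqk", q, k) / np.sqrt(head_dim_qk)

# Numerically stable softmax over the key axis.
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

# Output contracts over keys against V, so it has head_dim_vo, not head_dim_qk.
out = np.einsum("hqk,khd->qhd", weights, v)
print(out.shape)  # (16, 4, 128)
```

A kernel that hard-codes a single `head_dim` cannot express this asymmetry, which is why the `plan` API was split into `head_dim_qk` and `head_dim_vo`.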