
feat: support deepseek prefill attention shape #765

Merged: 21 commits into main from deepseek-prefill on Feb 1, 2025
Conversation

@yzh119 (Collaborator) commented Jan 30, 2025

DeepSeek requires head_dim_qk of 192 and head_dim_vo of 128; this PR implements support for these shapes in prefill attention on ragged tensors. (We could also support paged kv-cache, but that is not urgent: head_dim_qk=192, head_dim_vo=128 is only used with ragged tensors in DeepSeek MLA without matrix absorption, while DeepSeek MLA with matrix absorption needs a separate MQA kernel with head_dim_qk=576, head_dim_vo=512. I'll upstream that kernel in the next PR.)
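For context, the two shape regimes follow directly from the DeepSeek MLA head dimensions. A quick sketch of the arithmetic (the config constants below are taken from the public DeepSeek-V2/V3 model configs and are assumptions of this note, not part of the PR):

# DeepSeek MLA dimensions (assumed from the public DeepSeek-V2/V3 configs)
qk_nope_head_dim = 128  # non-RoPE part of the query/key head
qk_rope_head_dim = 64   # RoPE part of the query/key head
v_head_dim = 128        # value head dimension
kv_lora_rank = 512      # rank of the compressed kv latent

# MLA w/o matrix absorption (this PR, ragged prefill):
head_dim_qk = qk_nope_head_dim + qk_rope_head_dim  # 192
head_dim_vo = v_head_dim                           # 128

# MLA w/ matrix absorption (the MQA kernel planned for the next PR):
head_dim_qk = kv_lora_rank + qk_rope_head_dim      # 576
head_dim_vo = kv_lora_rank                         # 512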

Checklist

  • Make FA3 template compatible with deepseek model shape
  • Make FA2 template compatible with deepseek model shape
  • Fix AOT compilation scripts
  • Fix C++ tests/benchmarks

Changes to the programming interface

We added an optional argument head_dim_vo to the plan function, allowing the user to specify different head_dim_qk and head_dim_vo:

wrapper.plan(
    ...
    head_dim_qk,
    head_dim_vo=head_dim_vo,
    ...
)
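For a fuller picture, here is a minimal end-to-end sketch against the ragged prefill wrapper. The wrapper and method names follow the flashinfer Python API around the time of this PR; the workspace size, sequence lengths, head counts, and dtypes are illustrative assumptions:

import torch
import flashinfer

seq_len, num_qo_heads, num_kv_heads = 1024, 128, 128
head_dim_qk, head_dim_vo = 192, 128  # DeepSeek prefill shape from this PR

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace, "NHD")

# Single-request batch: indptr arrays mark the ragged sequence boundaries.
qo_indptr = torch.tensor([0, seq_len], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, seq_len], dtype=torch.int32, device="cuda")

wrapper.plan(
    qo_indptr,
    kv_indptr,
    num_qo_heads,
    num_kv_heads,
    head_dim_qk,
    head_dim_vo=head_dim_vo,
    causal=True,
)

q = torch.randn(seq_len, num_qo_heads, head_dim_qk, dtype=torch.half, device="cuda")
k = torch.randn(seq_len, num_kv_heads, head_dim_qk, dtype=torch.half, device="cuda")
v = torch.randn(seq_len, num_kv_heads, head_dim_vo, dtype=torch.half, device="cuda")

o = wrapper.run(q, k, v)  # output shape: (seq_len, num_qo_heads, head_dim_vo)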

@yzh119 merged commit eb660de into main Feb 1, 2025
@lw921014 commented Feb 1, 2025

Hi, I am looking forward to using this MQA kernel with the DeepSeek V3 model. Do you have a date or plan for the next PR?

@zhyncs deleted the deepseek-prefill branch February 1, 2025 11:32
yzh119 added a commit that referenced this pull request Feb 1, 2025
Follow-up of #765: fix the JIT warmup utility functions.
yzh119 pushed a commit that referenced this pull request Feb 7, 2025
#765 introduced changes to the API of `plan`, including renaming
`head_dim` to `head_dim_qk` and adding `head_dim_vo`. However, some
calling sites were not updated to reflect these changes, resulting in
failing unit tests.

This PR addresses the issue by updating the relevant call sites (a sketch of the change follows the list below), which should resolve the following unit test failures after merging:

- `tests/test_block_sparse.py::test_block_sparse_attention`
- `tests/test_non_contiguous_prefill.py::test_batch_paged_prefill_packed_input`
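
As an illustration, here is the kind of call-site update involved (a hypothetical sketch of a caller, not the exact diff from this PR):

# Before #765: a single head_dim argument covered both q/k and v/o.
wrapper.plan(qo_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim=128)

# After #765: head_dim is renamed head_dim_qk, and head_dim_vo can be
# passed separately when the output head dimension differs.
wrapper.plan(qo_indptr, kv_indptr, num_qo_heads, num_kv_heads,
             head_dim_qk=128, head_dim_vo=128)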

---------

Signed-off-by: abmfy <[email protected]>