perf: memory efficient deepseek mla fused page-attention kernel #804
Merged
Conversation
zhyncs reviewed Feb 10, 2025
abcdabcd987 reviewed Feb 10, 2025
abcdabcd987 reviewed Feb 10, 2025
abcdabcd987 reviewed Feb 10, 2025
Co-authored-by: Lequn Chen <[email protected]>
Co-authored-by: Lequn Chen <[email protected]>
… into mla-prefill-fa2
This was referenced Feb 12, 2025
yzh119 added a commit that referenced this pull request on Feb 12, 2025
The previous PR (#804) only tested page_size=1; this PR fixes the behavior for other page sizes and adds corresponding unit tests.
Tested on H20 (peak bandwidth 4 TB/s). Is there any performance improvement plan for H20? Any suggestions would be helpful.
@seanxcwang I'm working on the FA3 version of this kernel (this PR still uses the FA2 template, which does not overlap softmax with computation and does not use wgmma); it should also benefit H20. Please stay tuned :)
yzh119 added a commit that referenced this pull request on Feb 12, 2025
This PR changes the MLA attention template to support sm89 GPUs, which have a small shared memory size (99 KB per SM), so we have to further reduce shared memory usage: `NUM_STAGES` can only be set to 1, and `CTA_TILE_KV` can be set to at most 16. We add an option `QK_SHARD` to the KernelTraits (our previous template only supports `QK_SHARD=true`):

1. If true, we use the schedule mentioned in #804 and shard the QK computation along the KV dimension: each warpgroup computes half of it, and a round of allgather through shared memory is needed to obtain the full P for the PV computation.
2. If false, we duplicate the QK computation on both warpgroups (redundant work) but save the allgather step for P.

We set `QK_SHARD=true` for A100/H100 (shared memory limits of 164 KB and 228 KB, respectively) and `QK_SHARD=false` for sm89.

## Reference

The effect of `QK_SHARD` on H100 SXM5 (3352 GB/s):

```
QK_SHARD=true (Allgather with shared memory)
=======================
Config: batch_size=768, seq_len=1024, num_heads=16 Memory bandwidth: 2010.78 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32 Memory bandwidth: 2036.13 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64 Memory bandwidth: 2085.52 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16 Memory bandwidth: 2068.62 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32 Memory bandwidth: 2085.84 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64 Memory bandwidth: 2080.85 GB/s

QK_SHARD=false (Duplicate P)
=======================
Config: batch_size=768, seq_len=1024, num_heads=16 Memory bandwidth: 1610.81 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32 Memory bandwidth: 1638.73 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64 Memory bandwidth: 1690.86 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16 Memory bandwidth: 1636.08 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32 Memory bandwidth: 1651.57 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64 Memory bandwidth: 1653.31 GB/s
```

The effect of `QK_SHARD` on A100 SXM 40GB (1555 GB/s):

```
QK_SHARD=true (Allgather with shared memory)
=======================
Config: batch_size=768, seq_len=1024, num_heads=16 Memory bandwidth: 891.30 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32 Memory bandwidth: 929.65 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64 Memory bandwidth: 954.24 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16 Memory bandwidth: 923.07 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32 Memory bandwidth: 933.77 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64 Memory bandwidth: 943.48 GB/s

QK_SHARD=false (Duplicate P)
=======================
Config: batch_size=768, seq_len=1024, num_heads=16 Memory bandwidth: 753.89 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32 Memory bandwidth: 780.96 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64 Memory bandwidth: 804.61 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16 Memory bandwidth: 785.70 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32 Memory bandwidth: 796.87 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64 Memory bandwidth: 808.83 GB/s
```
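For intuition, here is a toy NumPy model of the two `QK_SHARD` schedules described above (shapes and the two-"warpgroup" split are purely illustrative; the real allgather happens through shared memory inside the kernel):

```python
import numpy as np

# Illustrative tile sizes, not the kernel's actual tiles.
tile_q, tile_kv, head_dim = 8, 16, 576
q = np.random.randn(tile_q, head_dim)
kv = np.random.randn(tile_kv, head_dim)

# QK_SHARD=true: each warpgroup computes scores for half of the KV tile,
# then the halves are "allgathered" (here: concatenated) before PV.
s_wg0 = q @ kv[: tile_kv // 2].T
s_wg1 = q @ kv[tile_kv // 2 :].T
s_sharded = np.concatenate([s_wg0, s_wg1], axis=1)

# QK_SHARD=false: both warpgroups redundantly compute the full score tile,
# saving the allgather at the cost of duplicated QK work.
s_dup = q @ kv.T

assert np.allclose(s_sharded, s_dup)  # both schedules produce the same scores
```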
This was referenced Feb 13, 2025
yzh119 added a commit that referenced this pull request on Feb 17, 2025
#804 didn't implement split-k, which can cause performance degradation when concurrency is not large enough. This PR fixes that issue. We implement the v2 scheduler and the write-through optimization mentioned in [our paper](https://arxiv.org/pdf/2501.01005) (Section 3.3 and Appendix D.2) for load balancing. In an earlier PR (#72), we turned off `cudaLaunchCooperativeKernels` and `grid.sync()` because we were not sure whether they are compatible with CUDAGraph. This PR adds them back for grid synchronization, to save some kernel launch overhead.

## Benchmark

On H100 SXM5 80GB (3352 GB/s), this PR:

```
Config: batch_size=1, seq_len=1024, num_heads=16 Memory bandwidth: 22.33 GB/s
Config: batch_size=16, seq_len=1024, num_heads=16 Memory bandwidth: 330.72 GB/s
Config: batch_size=32, seq_len=1024, num_heads=16 Memory bandwidth: 638.73 GB/s
Config: batch_size=64, seq_len=1024, num_heads=16 Memory bandwidth: 1188.90 GB/s
Config: batch_size=1, seq_len=2048, num_heads=16 Memory bandwidth: 40.74 GB/s
Config: batch_size=16, seq_len=2048, num_heads=16 Memory bandwidth: 592.77 GB/s
Config: batch_size=32, seq_len=2048, num_heads=16 Memory bandwidth: 1112.83 GB/s
Config: batch_size=64, seq_len=2048, num_heads=16 Memory bandwidth: 1506.01 GB/s
Config: batch_size=1, seq_len=4096, num_heads=16 Memory bandwidth: 72.53 GB/s
Config: batch_size=16, seq_len=4096, num_heads=16 Memory bandwidth: 1007.80 GB/s
Config: batch_size=32, seq_len=4096, num_heads=16 Memory bandwidth: 1438.99 GB/s
Config: batch_size=64, seq_len=4096, num_heads=16 Memory bandwidth: 1730.62 GB/s
Config: batch_size=1, seq_len=8192, num_heads=16 Memory bandwidth: 120.74 GB/s
Config: batch_size=16, seq_len=8192, num_heads=16 Memory bandwidth: 1340.86 GB/s
Config: batch_size=32, seq_len=8192, num_heads=16 Memory bandwidth: 1689.36 GB/s
Config: batch_size=64, seq_len=8192, num_heads=16 Memory bandwidth: 1901.26 GB/s
Config: batch_size=1, seq_len=16384, num_heads=16 Memory bandwidth: 177.94 GB/s
Config: batch_size=16, seq_len=16384, num_heads=16 Memory bandwidth: 1619.51 GB/s
Config: batch_size=32, seq_len=16384, num_heads=16 Memory bandwidth: 1876.50 GB/s
Config: batch_size=64, seq_len=16384, num_heads=16 Memory bandwidth: 2010.58 GB/s
Config: batch_size=1, seq_len=32768, num_heads=16 Memory bandwidth: 231.70 GB/s
Config: batch_size=16, seq_len=32768, num_heads=16 Memory bandwidth: 1835.16 GB/s
Config: batch_size=32, seq_len=32768, num_heads=16 Memory bandwidth: 1997.24 GB/s
Config: batch_size=64, seq_len=32768, num_heads=16 Memory bandwidth: 2067.99 GB/s
```

Before this PR:

```
Config: batch_size=1, seq_len=1024, num_heads=16 Memory bandwidth: 15.46 GB/s
Config: batch_size=16, seq_len=1024, num_heads=16 Memory bandwidth: 238.49 GB/s
Config: batch_size=32, seq_len=1024, num_heads=16 Memory bandwidth: 472.44 GB/s
Config: batch_size=64, seq_len=1024, num_heads=16 Memory bandwidth: 929.12 GB/s
Config: batch_size=1, seq_len=2048, num_heads=16 Memory bandwidth: 15.47 GB/s
Config: batch_size=16, seq_len=2048, num_heads=16 Memory bandwidth: 250.71 GB/s
Config: batch_size=32, seq_len=2048, num_heads=16 Memory bandwidth: 500.21 GB/s
Config: batch_size=64, seq_len=2048, num_heads=16 Memory bandwidth: 996.37 GB/s
Config: batch_size=1, seq_len=4096, num_heads=16 Memory bandwidth: 16.36 GB/s
Config: batch_size=16, seq_len=4096, num_heads=16 Memory bandwidth: 257.59 GB/s
Config: batch_size=32, seq_len=4096, num_heads=16 Memory bandwidth: 515.88 GB/s
Config: batch_size=64, seq_len=4096, num_heads=16 Memory bandwidth: 1035.55 GB/s
Config: batch_size=1, seq_len=8192, num_heads=16 Memory bandwidth: 16.37 GB/s
Config: batch_size=16, seq_len=8192, num_heads=16 Memory bandwidth: 261.47 GB/s
Config: batch_size=32, seq_len=8192, num_heads=16 Memory bandwidth: 524.76 GB/s
Config: batch_size=64, seq_len=8192, num_heads=16 Memory bandwidth: 1054.54 GB/s
Config: batch_size=1, seq_len=16384, num_heads=16 Memory bandwidth: 16.50 GB/s
Config: batch_size=16, seq_len=16384, num_heads=16 Memory bandwidth: 263.69 GB/s
Config: batch_size=32, seq_len=16384, num_heads=16 Memory bandwidth: 528.89 GB/s
Config: batch_size=64, seq_len=16384, num_heads=16 Memory bandwidth: 1064.87 GB/s
Config: batch_size=1, seq_len=32768, num_heads=16 Memory bandwidth: 16.45 GB/s
Config: batch_size=16, seq_len=32768, num_heads=16 Memory bandwidth: 264.66 GB/s
Config: batch_size=32, seq_len=32768, num_heads=16 Memory bandwidth: 530.87 GB/s
Config: batch_size=64, seq_len=32768, num_heads=16 Memory bandwidth: 1070.93 GB/s
```
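For reference, split-k here means each KV chunk produces a partial output and a log-sum-exp that are merged afterwards; below is a generic NumPy sketch of this standard merge (not the kernel's actual reduction code):

```python
import numpy as np

def merge_states(o1, lse1, o2, lse2):
    """Merge partial attention outputs computed over two disjoint KV chunks.

    o1, o2:     [num_heads, head_dim] partial outputs
    lse1, lse2: [num_heads] log-sum-exp of the softmax denominator per chunk
    """
    lse = np.logaddexp(lse1, lse2)       # combined denominator in log space
    w1 = np.exp(lse1 - lse)[:, None]     # weight of chunk 1
    w2 = np.exp(lse2 - lse)[:, None]     # weight of chunk 2
    return w1 * o1 + w2 * o2, lse
```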
yzh119 added a commit that referenced this pull request on Feb 23, 2025
This PR is the follow-up to #804. We implement a FlashAttention-3 version of the warp specialization pattern (splitting on the head dimension) introduced in #804, for faster attention on Hopper GPUs. Compared to the previous version (FA2 style), this PR makes the following changes:

1. Use one warpgroup as producer and two warpgroups as consumers.
2. Use async wgmma instead of mma.
3. Use the software-pipelining algorithm from FlashAttention-3 to overlap CUDA-core and Tensor-core operations.
4. Unlike standard attention, MLA uses the same set of K and V (the ckv matrix). If we keep `CTA_TILE_KV=64` and `PIPE_STAGES=2`, the software pipeline would block the memory copy for the next KV tile (both pipe slots would be occupied); standard attention does not have this issue because it has both `pipeline_k` and `pipeline_v`, doubling the stages. This PR changes to `CTA_TILE_KV=32` and `PIPE_STAGES=4` so that we can compute the current KV tile while loading the next one.
5. Unlike standard attention, we can't reuse V's shared memory space for O. This PR designs a circular buffer for `o_smem` that reuses the KV slots; one KV slot is not large enough for `o_smem`, so we use two KV shared memory slots for one `o_smem`, with a barrier to guarantee memory ordering.

## Pipeline

This figure explains our pipeline design:

## Results

Benchmark results on H100 SXM5 (80GB). This PR (fa3 template), `page_size=1`:

```
Config: batch_size=64, seq_len=1024, num_heads=64 Memory bandwidth: 1305.40 GB/s
Config: batch_size=128, seq_len=1024, num_heads=64 Memory bandwidth: 2228.56 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64 Memory bandwidth: 2759.33 GB/s
Config: batch_size=64, seq_len=2048, num_heads=64 Memory bandwidth: 1766.33 GB/s
Config: batch_size=128, seq_len=2048, num_heads=64 Memory bandwidth: 2498.08 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64 Memory bandwidth: 2768.37 GB/s
```

#804 + #863 (fa2 template), `page_size=1`:

```
Config: batch_size=64, seq_len=1024, num_heads=64 Memory bandwidth: 1067.74 GB/s
Config: batch_size=128, seq_len=1024, num_heads=64 Memory bandwidth: 1761.25 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64 Memory bandwidth: 2065.78 GB/s
Config: batch_size=64, seq_len=2048, num_heads=64 Memory bandwidth: 1384.35 GB/s
Config: batch_size=128, seq_len=2048, num_heads=64 Memory bandwidth: 1892.64 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64 Memory bandwidth: 2075.97 GB/s
```

Using TMA and multicast could further improve performance for `page_size` larger than 1; we leave this for future work.
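A quick arithmetic check of points 4 and 5 above (a sketch assuming fp16 storage and a 64-row output tile; real layouts add padding and other buffers):

```python
# Shared-memory slots for the KV pipeline and the o_smem circular buffer.
head_dim_qk, head_dim_vo, elem = 576, 512, 2   # fp16 elements assumed

old_slot = 64 * head_dim_qk * elem   # CTA_TILE_KV=64 -> 73,728 B per stage
new_slot = 32 * head_dim_qk * elem   # CTA_TILE_KV=32 -> 36,864 B per stage
print(2 * old_slot == 4 * new_slot)  # True: same total smem, but 4 slots give the
                                     # pipeline room to prefetch the next KV tile.

o_smem = 64 * head_dim_vo * elem     # 65,536 B output tile (64 rows assumed)
print(o_smem > new_slot, o_smem <= 2 * new_slot)
# True True: o_smem doesn't fit in one KV slot but fits in two, matching the
# circular buffer that borrows two KV slots (with a barrier for ordering).
```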
Hi @seanxcwang, #887 should improve performance on H20.
Description:
This PR implements a memory-efficient fused DeepSeek MLA PageAttention kernel for decode, prefill, and chunked-prefill operations.
MLA computation after matrix absorption can be described as an MQA (multi-query attention) kernel with a shared K/V cache and special head dimensions: `head_dim_qk=576`, `head_dim_vo=512`.
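As a rough illustration of this MQA view (a sketch only, not the kernel: the ckv/kpe concatenation order and the softmax scale are assumptions here):

```python
import numpy as np

# Illustrative sizes only; the kernel tiles these dimensions.
num_heads, kv_len = 16, 128
head_dim_ckv, head_dim_kpe = 512, 64          # 512 + 64 = head_dim_qk = 576

q_nope = np.random.randn(num_heads, head_dim_ckv).astype(np.float32)  # absorbed q
q_pe   = np.random.randn(num_heads, head_dim_kpe).astype(np.float32)
ckv    = np.random.randn(kv_len, head_dim_ckv).astype(np.float32)     # shared across heads
kpe    = np.random.randn(kv_len, head_dim_kpe).astype(np.float32)     # shared across heads

# One 576-dim key per token shared by all heads; the value is the 512-dim ckv itself.
q = np.concatenate([q_nope, q_pe], axis=-1)   # [num_heads, 576]
k = np.concatenate([ckv, kpe], axis=-1)       # [kv_len, 576]
v = ckv                                       # [kv_len, 512]

sm_scale = 1.0 / np.sqrt(128 + 64)            # assumed scale; the kernel takes it as input
s = q @ k.T * sm_scale                        # [num_heads, kv_len]
p = np.exp(s - s.max(-1, keepdims=True))
p /= p.sum(-1, keepdims=True)                 # softmax over kv_len
o = p @ v                                     # [num_heads, head_dim_vo=512]
print(o.shape)
```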
Background:
For DeepSeek V2/V3, the large `head_dim` (512) makes it challenging to store the output tensor in registers when using tensor cores. A previous approach (#551) split `head_dim` across `gridDim.y`, which led to two main issues.
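A back-of-the-envelope estimate of why the output is hard to keep in registers (a sketch assuming fp32 accumulators, a 64-row query tile, and the 255-registers-per-thread limit):

```python
# Output-accumulator pressure for a 64 x 512 tile kept entirely in one warpgroup.
cta_tile_q, head_dim_vo = 64, 512
threads_per_warpgroup = 4 * 32                      # 4 warps

acc_regs_per_thread = cta_tile_q * head_dim_vo // threads_per_warpgroup
print(acc_regs_per_thread)                          # 256 fp32 registers per thread,
                                                    # already over the 255-register limit
                                                    # before counting Q/softmax state.
# Splitting the 512-wide output across two warpgroups halves this to 128.
```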

We use head-group fusion to increase the operational intensity of the kernel (Appendix A in the paper). To address the large head-dimension issue, we redesign the kernel (diagram below) to use two warp groups (WG1 and WG2, each with 4 warps) per CTA, using shared memory (the `kpe` buffer) to efficiently broadcast data for PV computation.

The maximum `CTA_TILE_Q` (bounded by the register file) is 64, and in this case the maximum `CTA_TILE_KV` (bounded by the shared memory limit) is 64 for H100 and 32 for A100 when the number of pipeline stages is set to 2.

For a large `num_local_heads` such as 128 (with no TP and no MTP), we create a cluster of size 2 and dispatch the upper half (64 heads) and lower half (64 heads) to the two SMs in the cluster; software-managed multicasting (useful for large page sizes) is left to a later PR.
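A rough sanity check of these tile sizes (a sketch assuming fp16 storage and that the Q tile plus the staged KV tiles dominate shared memory; smaller scratch buffers are ignored, and the 228 KB / 164 KB limits are the per-SM figures quoted in the sm89 follow-up commit above):

```python
def smem_bytes(cta_tile_q, cta_tile_kv, num_stages, head_dim_qk=576, elem=2):
    """Approximate shared-memory footprint: one Q tile + staged KV tiles (fp16)."""
    return (cta_tile_q + num_stages * cta_tile_kv) * head_dim_qk * elem

print(smem_bytes(64, 64, 2) / 1024)   # ~216 KB -> fits H100 (228 KB), not A100 (164 KB)
print(smem_bytes(64, 32, 2) / 1024)   # ~144 KB -> fits A100 (164 KB)
```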
Benchmark results

Decoding memory bandwidth on H100 SXM5 (peak bandwidth = 3352 GB/s):