Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: memory efficient deepseek mla fused page-attention kernel #804

Merged
merged 16 commits into from
Feb 12, 2025

Conversation

yzh119
Copy link
Collaborator

@yzh119 yzh119 commented Feb 10, 2025

Description:

This PR implements a memory efficient fused Deepseek MLA PageAttention kernel for decode, prefill, and chunked prefill operations.
MLA computation after matrix absorption can be described as a (multi-query attention) MQA kernel, with same K/V cache and special head dimensions: head_dim_qk=576, head_dim_vo=512.

Background:
For Deepseek v2/3, the large head_dim (512) makes it challenging to store the output tensor in registers when using tensor cores. A previous approach (#551) split head_dim across gridDim.y, which led to two main issues:

  • Re-computation Overhead: Multiple blocks had to redundantly compute the Q*K operation.
  • Memory Access Latency: Using multiple blocks prevented shared memory usage, forcing reliance on slower L2 for KV-Cache accesses.

New Design:
We use head-group fusion to increase the operational intensity of the kernel (appendix A in the paper ). To address the large head-dimension issue, we redesign the kernel (diagram below) to use two warp groups (WG1 and WG2, each with 4 warps) per CTA:
image

  • QK Computation: Each warp group processes half of the CTA_TILE_KV dimension, eliminating redundant computations.
  • Shared Memory Broadcast: Local QK results are written to shared memory (reusing the kpe buffer) to efficiently broadcast data for PV computation.
  • PV Computation: The head dimension is split between the warp groups, with each computing half.

The maximum CTA_TILE_Q (bounded by register files) is 64, and in this case the maximum CTA_TILE_KV (bounded by shared memory limit) is 64 (for H100) and 32 (for A100) when number of pipeline stages is set to 2.

For large num_local_heads such as 128 (if no TP and no MTP), we create a cluster of size 2, and the upper half (64) and lower half (64) are dispatched to two SMs in a cluster, and we can use software-managed multicasting (for large page size), we leave it to later PR.

Benchmark results

Decoding Memory bandwidth on H100 SXM5 (Peak bandwidth = 3352 GB/s):

Config: batch_size=768, seq_len=1024, num_heads=16
Memory bandwidth: 2002.11 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32
Memory bandwidth: 2035.59 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 2082.20 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16
Memory bandwidth: 2064.97 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32
Memory bandwidth: 2080.99 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 2082.78 GB/s

@yzh119 yzh119 changed the title perf: memory efficient deepseek mla fused kernel perf: memory efficient deepseek mla fused page-attention kernel Feb 10, 2025
@yzh119 yzh119 merged commit 106e6fc into main Feb 12, 2025
yzh119 added a commit that referenced this pull request Feb 12, 2025
The previous PR #804 only tests page_size 1, this PR fixes the issue
with other page sizes and add corresponding unittests.
@seanxcwang
Copy link

test on H20(peak bandwidth 4TB/s)

Config: batch_size=768, seq_len=1024, num_heads=16
Memory bandwidth: 668.77 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32
Memory bandwidth: 685.71 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 718.93 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16
Memory bandwidth: 669.80 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32
Memory bandwidth: 678.41 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 695.46 GB/s

Is there any performance improvement plan for H20, any suggestion will be helpful

@yzh119
Copy link
Collaborator Author

yzh119 commented Feb 12, 2025

@seanxcwang I'm working on the FA3 version (this PR still uses FA2 template, not overlapping softmax and computation, and doesn't use wgmma) of this kernel, which should also benefit H20, please stay tuned :)

@zhyncs zhyncs deleted the mla-prefill-fa2 branch February 12, 2025 19:26
yzh119 added a commit that referenced this pull request Feb 12, 2025
This PR changes the MLA attention template to support sm89 GPUs, which
has small shared memory size (99kb per sm), so we have to further reduce
shared memory usage: the `NUM_STAGES` can only be set to 1, and
`CTA_TILE_KV` could only be set to atmost 16.

We add an option `QK_SHARD` in the KernelTraits (our previous template
only supports `QK_SHARD=true`):
1. If true, we use the schedule mentioned in #804, and shards the QK
computation on KV dimension, each warpgroup compute half of it, and we
need to perform a round of allgather on shared memory for getting the
full P in PV computation.
2. If false, we duplicate QK computation on two warpgroups (which is not
necessary) but we save the allgather step for P.

We set `QK_SHARD=true` for A100/H100 (shared memory limit is 164kb and
228kb, correspondingly), and `QK_SHARD=false` for sm89.

## Reference

The effect of `QK_SHARD` on H100 SXM5 (3352 GB/s):
```
QK_SHARD=true (Allgather with shared memory)
=======================
Config: batch_size=768, seq_len=1024, num_heads=16
Memory bandwidth: 2010.78 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32
Memory bandwidth: 2036.13 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 2085.52 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16
Memory bandwidth: 2068.62 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32
Memory bandwidth: 2085.84 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 2080.85 GB/s

QK_SHARD=false (Duplicate P)
=======================
Config: batch_size=768, seq_len=1024, num_heads=16
Memory bandwidth: 1610.81 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32
Memory bandwidth: 1638.73 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 1690.86 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16
Memory bandwidth: 1636.08 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32
Memory bandwidth: 1651.57 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 1653.31 GB/s
```

The effect of `QK_SHARD` on A100 SXM 40GB (1555 GB/s):
```
QK_SHARD=true (Allgather with shared memory)
=======================
Config: batch_size=768, seq_len=1024, num_heads=16
Memory bandwidth: 891.30 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32
Memory bandwidth: 929.65 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 954.24 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16
Memory bandwidth: 923.07 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32
Memory bandwidth: 933.77 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 943.48 GB/s

QK_SHARD=false (Duplicate P)
=======================
Config: batch_size=768, seq_len=1024, num_heads=16
Memory bandwidth: 753.89 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32
Memory bandwidth: 780.96 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 804.61 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16
Memory bandwidth: 785.70 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32
Memory bandwidth: 796.87 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 808.83 GB/s
```
yzh119 added a commit that referenced this pull request Feb 17, 2025
#804 didn't implement split-k, which might result in performance
degradation if concurrency is not large enough. This PR fixes issue.

We implemented the v2 scheduler and write-through optimization mentioned
in [our paper](https://arxiv.org/pdf/2501.01005) (section 3.3 and
appendix in D.2) for load-balancing.

In an early PR (#72), we
turned off `cudaLaunchCooperativeKernels` and `grid.sync()` because we
are not sure whether it's compatible with CUDAGraph. This PR adds them
back again for grid synchronization, to save some kernel launch
overhead.

## Benchmark

On H100 SXM5 80GB (3352 GB/s), this PR:
```
Config: batch_size=1, seq_len=1024, num_heads=16
Memory bandwidth: 22.33 GB/s
Config: batch_size=16, seq_len=1024, num_heads=16
Memory bandwidth: 330.72 GB/s
Config: batch_size=32, seq_len=1024, num_heads=16
Memory bandwidth: 638.73 GB/s
Config: batch_size=64, seq_len=1024, num_heads=16
Memory bandwidth: 1188.90 GB/s
Config: batch_size=1, seq_len=2048, num_heads=16
Memory bandwidth: 40.74 GB/s
Config: batch_size=16, seq_len=2048, num_heads=16
Memory bandwidth: 592.77 GB/s
Config: batch_size=32, seq_len=2048, num_heads=16
Memory bandwidth: 1112.83 GB/s
Config: batch_size=64, seq_len=2048, num_heads=16
Memory bandwidth: 1506.01 GB/s
Config: batch_size=1, seq_len=4096, num_heads=16
Memory bandwidth: 72.53 GB/s
Config: batch_size=16, seq_len=4096, num_heads=16
Memory bandwidth: 1007.80 GB/s
Config: batch_size=32, seq_len=4096, num_heads=16
Memory bandwidth: 1438.99 GB/s
Config: batch_size=64, seq_len=4096, num_heads=16
Memory bandwidth: 1730.62 GB/s
Config: batch_size=1, seq_len=8192, num_heads=16
Memory bandwidth: 120.74 GB/s
Config: batch_size=16, seq_len=8192, num_heads=16
Memory bandwidth: 1340.86 GB/s
Config: batch_size=32, seq_len=8192, num_heads=16
Memory bandwidth: 1689.36 GB/s
Config: batch_size=64, seq_len=8192, num_heads=16
Memory bandwidth: 1901.26 GB/s
Config: batch_size=1, seq_len=16384, num_heads=16
Memory bandwidth: 177.94 GB/s
Config: batch_size=16, seq_len=16384, num_heads=16
Memory bandwidth: 1619.51 GB/s
Config: batch_size=32, seq_len=16384, num_heads=16
Memory bandwidth: 1876.50 GB/s
Config: batch_size=64, seq_len=16384, num_heads=16
Memory bandwidth: 2010.58 GB/s
Config: batch_size=1, seq_len=32768, num_heads=16
Memory bandwidth: 231.70 GB/s
Config: batch_size=16, seq_len=32768, num_heads=16
Memory bandwidth: 1835.16 GB/s
Config: batch_size=32, seq_len=32768, num_heads=16
Memory bandwidth: 1997.24 GB/s
Config: batch_size=64, seq_len=32768, num_heads=16
Memory bandwidth: 2067.99 GB/s
```

Before this PR:
```
Config: batch_size=1, seq_len=1024, num_heads=16
Memory bandwidth: 15.46 GB/s
Config: batch_size=16, seq_len=1024, num_heads=16
Memory bandwidth: 238.49 GB/s
Config: batch_size=32, seq_len=1024, num_heads=16
Memory bandwidth: 472.44 GB/s
Config: batch_size=64, seq_len=1024, num_heads=16
Memory bandwidth: 929.12 GB/s
Config: batch_size=1, seq_len=2048, num_heads=16
Memory bandwidth: 15.47 GB/s
Config: batch_size=16, seq_len=2048, num_heads=16
Memory bandwidth: 250.71 GB/s
Config: batch_size=32, seq_len=2048, num_heads=16
Memory bandwidth: 500.21 GB/s
Config: batch_size=64, seq_len=2048, num_heads=16
Memory bandwidth: 996.37 GB/s
Config: batch_size=1, seq_len=4096, num_heads=16
Memory bandwidth: 16.36 GB/s
Config: batch_size=16, seq_len=4096, num_heads=16
Memory bandwidth: 257.59 GB/s
Config: batch_size=32, seq_len=4096, num_heads=16
Memory bandwidth: 515.88 GB/s
Config: batch_size=64, seq_len=4096, num_heads=16
Memory bandwidth: 1035.55 GB/s
Config: batch_size=1, seq_len=8192, num_heads=16
Memory bandwidth: 16.37 GB/s
Config: batch_size=16, seq_len=8192, num_heads=16
Memory bandwidth: 261.47 GB/s
Config: batch_size=32, seq_len=8192, num_heads=16
Memory bandwidth: 524.76 GB/s
Config: batch_size=64, seq_len=8192, num_heads=16
Memory bandwidth: 1054.54 GB/s
Config: batch_size=1, seq_len=16384, num_heads=16
Memory bandwidth: 16.50 GB/s
Config: batch_size=16, seq_len=16384, num_heads=16
Memory bandwidth: 263.69 GB/s
Config: batch_size=32, seq_len=16384, num_heads=16
Memory bandwidth: 528.89 GB/s
Config: batch_size=64, seq_len=16384, num_heads=16
Memory bandwidth: 1064.87 GB/s
Config: batch_size=1, seq_len=32768, num_heads=16
Memory bandwidth: 16.45 GB/s
Config: batch_size=16, seq_len=32768, num_heads=16
Memory bandwidth: 264.66 GB/s
Config: batch_size=32, seq_len=32768, num_heads=16
Memory bandwidth: 530.87 GB/s
Config: batch_size=64, seq_len=32768, num_heads=16
Memory bandwidth: 1070.93 GB/s
```
yzh119 added a commit that referenced this pull request Feb 23, 2025
This PR is the followup of #804 , we implemented a FlashAttention-3
version of warp specialization pattern (splitting on head-dimension) in
#804 for faster attention on Hopper GPUs. Compared to the previous
version (in FA2 style), this PR did the following changes:

1. use one warpgroup for producer, two warpgroup for consumer.
2. use async wgmma instead of mma.
3. use the software pipeline algorithm in FlashAttention-3, to overlap
CUDA-Cores and Tensor-Cores operations.
4. Compared to original attention, MLA uses the same set of K and V (the
ckv matrix), if we reuse the `CTA_TILE_KV=64` and `PIPE_STAGES=2`, the
software pipeline algorithm would block the memory copy for next KV-Tile
(both the pipe slots were be occupied), original attention do not have
this issue because it has both `pipeline_k` and `pipeline_v`, doubling
the stages. This PR changes `CTA_TILE_KV=32` and `PIPE_STAGES=4` to
ensure we can compute the current KV-tile while loading the next
KV-Tile, when using software pipeline.
5. Unlike original attention, we can't reuse V shared memory space for
O. This PR designed a circular buffer for `o_smem` that reuses the KV
slots, one KV-slot is not large enough for `o_smem` so we use two KV
shared memory slot for one `o_smem`, a barrier is required to guarantee
the memory order.

## Pipeline

This figures explains our pipeline design:

![pipeline-design-mla](https://github.com/user-attachments/assets/178e465e-e671-459f-a4ea-02e2eaf17343)

## Results

Benchmark result on H100 SXM3 (80GB).

This PR (fa3 template), `page_size=1`:
```
Config: batch_size=64, seq_len=1024, num_heads=64
Memory bandwidth: 1305.40 GB/s
Config: batch_size=128, seq_len=1024, num_heads=64
Memory bandwidth: 2228.56 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 2759.33 GB/s
Config: batch_size=64, seq_len=2048, num_heads=64
Memory bandwidth: 1766.33 GB/s
Config: batch_size=128, seq_len=2048, num_heads=64
Memory bandwidth: 2498.08 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 2768.37 GB/s
```

#804 + #863 (fa2 template), `page_size=1`:
```
Config: batch_size=64, seq_len=1024, num_heads=64
Memory bandwidth: 1067.74 GB/s
Config: batch_size=128, seq_len=1024, num_heads=64
Memory bandwidth: 1761.25 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 2065.78 GB/s
Config: batch_size=64, seq_len=2048, num_heads=64
Memory bandwidth: 1384.35 GB/s
Config: batch_size=128, seq_len=2048, num_heads=64
Memory bandwidth: 1892.64 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 2075.97 GB/s
```

Using TMA and multicast could further improve performance for
`page_size` larger than 1, we leave them for future work.
@yzh119
Copy link
Collaborator Author

yzh119 commented Feb 23, 2025

Hi @seanxcwang , #887 should improve performance on H20.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants