perf: FlashAttention-3 style MLA PageAttention #887

Merged (14 commits) on Feb 23, 2025

Conversation

yzh119 (Collaborator) commented on Feb 23, 2025

This PR is a follow-up to #804, where we implemented a warp-specialization pattern (splitting on the head dimension) for faster MLA attention on Hopper GPUs. Compared to the previous version (which used the FA2-style template), this PR makes the following changes:

  1. Use one warpgroup as the producer and two warpgroups as consumers.
  2. Use asynchronous wgmma instead of mma.
  3. Use the software-pipelining algorithm from FlashAttention-3 to overlap CUDA Core and Tensor Core operations.
  4. Unlike standard attention, MLA uses the same set of K and V (the ckv matrix). If we kept CTA_TILE_KV=64 and PIPE_STAGES=2, the software pipeline would block the memory copy for the next KV tile because both pipeline slots would already be occupied; standard attention does not hit this issue because it has separate pipeline_k and pipeline_v, which doubles the number of stages. This PR changes to CTA_TILE_KV=32 and PIPE_STAGES=4 so that the current KV tile can be computed while the next one is being loaded. A simplified sketch of this multi-stage producer/consumer pipeline is given after this list.
  5. Unlike standard attention, we cannot reuse V's shared memory for O. This PR adds a circular buffer for o_smem that reuses the KV slots; since a single KV slot is not large enough for o_smem, we use two KV shared-memory slots per o_smem, and a barrier guarantees the required memory ordering.
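
For readers less familiar with the warp-specialization pattern, here is a minimal, self-contained sketch of the general producer/consumer multi-stage pipeline idea using `cuda::pipeline`. It is not flashinfer's actual MLA kernel: the tile size and stage count are stand-ins for CTA_TILE_KV=32 and PIPE_STAGES=4, a single warp plays the producer role instead of a full warpgroup, and the per-tile "compute" is a plain sum rather than wgmma.

```cuda
// Illustrative sketch only (not flashinfer code): one producer warp prefetches
// tiles into a 4-stage shared-memory ring with cuda::memcpy_async while the
// remaining warps consume them, so loads of tiles t+1..t+3 overlap compute on t.
#include <cuda/pipeline>
#include <cooperative_groups.h>
#include <cstdio>

constexpr int TILE = 32;       // stand-in for CTA_TILE_KV
constexpr int NUM_STAGES = 4;  // stand-in for PIPE_STAGES

__global__ void pipelined_sum(const float* __restrict__ in, float* __restrict__ out,
                              int num_tiles) {
  __shared__ float smem[NUM_STAGES][TILE];
  __shared__ cuda::pipeline_shared_state<cuda::thread_scope_block, NUM_STAGES> state;

  auto block = cooperative_groups::this_thread_block();
  const bool is_producer = (threadIdx.x / 32) == 0;  // warp 0 produces, warps 1-3 consume
  auto role = is_producer ? cuda::pipeline_role::producer : cuda::pipeline_role::consumer;
  auto pipe = cuda::make_pipeline(block, &state, role);

  float acc = 0.f;
  for (int t = 0; t < num_tiles; ++t) {
    const int slot = t % NUM_STAGES;  // ring-buffer slot for this tile
    if (is_producer) {
      pipe.producer_acquire();  // wait until consumers have released this slot
      for (int i = threadIdx.x; i < TILE; i += 32) {
        cuda::memcpy_async(&smem[slot][i], &in[t * TILE + i], sizeof(float), pipe);
      }
      pipe.producer_commit();
    } else {
      pipe.consumer_wait();  // wait until the tile has landed in shared memory
      for (int i = 0; i < TILE; ++i) acc += smem[slot][i];  // "compute" on the tile
      pipe.consumer_release();  // let the producer reuse this slot
    }
  }
  if (threadIdx.x == 32) out[blockIdx.x] = acc;  // one consumer thread writes the result
}

int main() {
  const int num_tiles = 8;
  float *in, *out;
  cudaMallocManaged(&in, num_tiles * TILE * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < num_tiles * TILE; ++i) in[i] = 1.f;
  pipelined_sum<<<1, 128>>>(in, out, num_tiles);
  cudaDeviceSynchronize();
  printf("sum = %.0f (expected %d)\n", out[0], num_tiles * TILE);
  return 0;
}
```

With four stages and a 32-row tile, the producer can run up to three tiles ahead of the consumers, which is the property item 4 relies on; with only two stages and a single shared ckv pipeline, both slots would already be occupied and the copy of the next KV tile would stall.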

Pipeline

This figure illustrates our pipeline design:

[Figure: pipeline-design-mla]

Results

Benchmark results on H100 SXM5 (80 GB, 3352 GB/s):

This PR (FA3 template), page_size=1:

| batch_size | seq_len | num_heads | Memory bandwidth (GB/s) |
|---|---|---|---|
| 64 | 1024 | 64 | 1305.40 |
| 128 | 1024 | 64 | 2228.56 |
| 768 | 1024 | 64 | 2759.33 |
| 64 | 2048 | 64 | 1766.33 |
| 128 | 2048 | 64 | 2498.08 |
| 768 | 2048 | 64 | 2768.37 |

#804 + #863 (FA2 template), page_size=1:

| batch_size | seq_len | num_heads | Memory bandwidth (GB/s) |
|---|---|---|---|
| 64 | 1024 | 64 | 1067.74 |
| 128 | 1024 | 64 | 1761.25 |
| 768 | 1024 | 64 | 2065.78 |
| 64 | 2048 | 64 | 1384.35 |
| 128 | 2048 | 64 | 1892.64 |
| 768 | 2048 | 64 | 2075.97 |

The template is designed around Ampere-style LDGSTS instructions, which we prioritize for page_size=1 (though it also works for larger page sizes). Using TMA and multicast could further improve performance when page_size is greater than 1; we leave that for future work.
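
For context, here is a minimal standalone sketch of the Ampere-style LDGSTS (cp.async) copy referred to above; again, this is not flashinfer's actual code. Each thread issues its own asynchronous 16-byte global-to-shared copy from an arbitrary address, which is what makes this path a natural fit for the scattered loads of page_size=1, whereas TMA favors larger contiguous boxes. Requires sm_80 or newer.

```cuda
// Illustrative sketch of per-thread cp.async (LDGSTS): gather scattered
// 16-byte rows from global memory into shared memory asynchronously.
// Compile with -arch=sm_80 or newer.
#include <cstdio>
#include <cstdint>

__device__ __forceinline__ void ldgsts_16B(void* smem_dst, const void* gmem_src) {
  // Convert the generic shared-memory pointer to a 32-bit shared-space address.
  uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_dst));
  asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" ::"r"(smem_addr),
               "l"(gmem_src));
}

__device__ __forceinline__ void ldgsts_commit_and_wait_all() {
  asm volatile("cp.async.commit_group;\n");
  asm volatile("cp.async.wait_group 0;\n");
}

__global__ void gather_rows(const float4* src, const int* row_idx, float4* dst) {
  __shared__ float4 smem[128];
  // Each thread gathers one 16-byte chunk from a scattered (page_size=1 style) location.
  ldgsts_16B(&smem[threadIdx.x], &src[row_idx[threadIdx.x]]);
  ldgsts_commit_and_wait_all();
  __syncthreads();
  dst[threadIdx.x] = smem[threadIdx.x];
}

int main() {
  float4 *src, *dst;
  int* idx;
  cudaMallocManaged(&src, 128 * sizeof(float4));
  cudaMallocManaged(&dst, 128 * sizeof(float4));
  cudaMallocManaged(&idx, 128 * sizeof(int));
  for (int i = 0; i < 128; ++i) {
    src[i] = make_float4(i, i, i, i);
    idx[i] = 127 - i;  // a scattered gather pattern
  }
  gather_rows<<<1, 128>>>(src, idx, dst);
  cudaDeviceSynchronize();
  printf("dst[0].x = %.0f (expected 127)\n", dst[0].x);
  return 0;
}
```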

yzh119 merged commit 2b24293 into main on Feb 23, 2025.

MasterJH5574 added a commit that referenced this pull request on Feb 23, 2025: "This PR fixes the header include, following changes in #887."
yzh119 pushed a commit that referenced this pull request on Feb 23, 2025: "This PR fixes the header include, following changes in #887."