
feat: support MLA decode, implemented by CuTe targeted to SM80 #766

Closed
wants to merge 3 commits

Conversation

Contributor

@tsu-bin tsu-bin commented Jan 31, 2025

The scheduling design is shown in the diagram below (the q_pe and kpe parts are omitted); you can tell where each tensor resides from the prefix of its name, such as smem, reg, and gmem.
[image: scheduling diagram]

Some design key points:

  • We want to load as many q-heads into smem as possible, so we set k_kv_tile_len to the minimum value of 8 to save smem space for q_nope.
  • I didn't concatenate q_nope with q_pe, nor ckv with kpe, in smem, because keeping them separate makes it easier to transpose ckv in the output matmul stage.
  • The softmax calculation stage only uses 2 warps, with each thread processing one row. I didn't use warp-shuffle intrinsics, both for simplicity of implementation and because the reduction axis of smem_att is only 8, which is rather small.
  • There is one very interesting trick using the CUTLASS CuTe Layout:
  using LayoutOScaleMat = Layout< Shape< Int<QO_TILE_LEN>, Int<HEAD_DIM_CKV>>, Stride<_1, _0>>;
  Tensor o_scale_broadcast_mat = make_tensor((ptr_o_scale), LayoutOScaleMat{});

o_scale_broadcast_mat shares the same smem data as smem_o_scale. The column stride of o_scale_broadcast_mat is set to 0, and o_scale_broadcast_mat is then partitioned by thr_mma_output:

Tensor o_scale_mat_part = thr_mma_output.partition_C(o_scale_broadcast_mat);

and then we apply it element-wise:

for (int i=0; i<cute::size(reg_output_part); ++i)
    reg_output_part(i) = reg_output_part(i) * o_scale_mat_part(i);

In this way we can easily implement row-wise broadcasting: each element of smem_o_scale multiplies one row of the output tensor, and we don't need to care about the complex register-value layout of the mma.
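Below is a minimal, self-contained host-side sketch of the same stride-0 broadcast idea (the tile sizes and variable names are illustrative, not the kernel's actual QO_TILE_LEN / HEAD_DIM_CKV partitioning):

```cpp
// Host-side illustration of the stride-0 broadcast trick (illustrative sizes, not the kernel's).
#include <cute/tensor.hpp>
#include <cstdio>
using namespace cute;

int main() {
  float o_scale[8];        // one scale per row, playing the role of smem_o_scale
  float output[8 * 4];     // a small stand-in for the output tile
  for (int r = 0; r < 8; ++r) o_scale[r] = 0.5f * r;
  for (int i = 0; i < 8 * 4; ++i) output[i] = 1.0f;

  // Column stride 0: every column of a row maps back to the same o_scale element.
  Tensor scale_bcast = make_tensor(&o_scale[0], Layout<Shape<_8, _4>, Stride<_1, _0>>{});
  Tensor out         = make_tensor(&output[0],  Layout<Shape<_8, _4>, Stride<_1, _8>>{});

  // A plain element-wise multiply now performs a row-wise broadcast; in the kernel the same
  // holds after both tensors are partitioned identically by the tiled mma.
  for (int i = 0; i < size(out); ++i) out(i) = out(i) * scale_bcast(i);

  printf("out(3, 2) = %f\n", float(out(3, 2)));  // prints 1.5 == o_scale[3]
  return 0;
}
```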

Benchmark data:

[image: benchmark results]
I benchmarked the kernel on a 4090, which can accommodate 64 q-heads at most, so for 128 q-heads we need to load all the kv-cache data twice. For 64 q-heads the BWUtil reaches 85%; for 128 q-heads it reaches 44% * 2 (88%, which is higher than 85%; I think that's due to the L2 cache benefit). You can also compare with the benchmark screenshot of the CUDA-core implementation:
#551

Limitations:

  • The current implementation only supports a q-head count that is a multiple of 64, which is fine for the normal DeepSeek models but not compatible with the Lite models, which have 16 q-heads. We would need a different tiled-mma layout design for 16 q-heads.
  • No support for SM75, since CuTe's mma wrappers for SM75 are rather limited.

@tsu-bin tsu-bin marked this pull request as draft January 31, 2025 16:49
att[i] = -flashinfer::math::inf;
}

float row_o_scale = expf(row_max_prev - row_max);
Collaborator

Do we have a sense of the input value range of exp?
For fp8, one solution to improve numerical stability is to multiply $p$ by a constant and divide by the constant at the end:

p = exp2(log2(448) + x - row_max)
o += p * v
o = o / 448

where 448 is the maximum value of e4m3.
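A small device-side sketch of this suggestion (the helper name is mine, and it assumes x and row_max are already in the log2 domain, i.e. pre-multiplied by log2e):

```cpp
// Sketch of the proposed fp8 rescaling (hypothetical helper, not code from this PR):
// p is scaled up toward the e4m3 maximum inside the exponent, and the scale is undone at the end.
__device__ __forceinline__ float fp8_scaled_p(float x, float row_max) {
  constexpr float kE4M3Max = 448.f;              // maximum representable value of e4m3
  return exp2f(log2f(kE4M3Max) + x - row_max);   // p = 448 * 2^(x - row_max)
}
// Accumulation (schematic):  o += fp8_scaled_p(x, row_max) * v;
// Final correction:          o /= 448.f;
```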

@@ -42,7 +42,7 @@ struct state_t {

__device__ __forceinline__ state_t() { init(); }

__device__ __forceinline__ float get_lse() const { return m + math::ptx_log2(d); }
__device__ __forceinline__ float get_lse() const { return m + logf(d); }
Collaborator

@yzh119 yzh119 Feb 3, 2025

Note that we enabled -use_fast_math, so expf will be compiled to (similarly for logf):

exp2(log2e * x)
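Roughly, the fast-math lowering amounts to the following substitution (my own illustration, not flashinfer code):

```cpp
// What -use_fast_math effectively does to expf / logf, written out explicitly (illustrative only).
__device__ __forceinline__ float fast_expf(float x) {
  return exp2f(x * 1.4426950408889634f);   // log2(e): exp(x) == 2^(x * log2(e))
}
__device__ __forceinline__ float fast_logf(float x) {
  return log2f(x) * 0.6931471805599453f;   // ln(2): log(x) == log2(x) * ln(2)
}
```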

Collaborator

@yzh119 yzh119 left a comment

I think the numerical inaccuracy might not be coming from the approximate exp2; instead, we should multiply sm_scale by log2e like other kernel implementations do.

@tsu-bin tsu-bin force-pushed the mla_decode_cute_dev branch from 5527c3f to 26200c8 Compare February 4, 2025 16:22
@tsu-bin
Contributor Author

tsu-bin commented Feb 4, 2025

> I think the numerical inaccuracy might not be coming from the approximate exp2; instead, we should multiply sm_scale by log2e like other kernel implementations do.

Hi @yzh119, you are right, the inaccuracy is not from the approximate exp2.
I realized that the inaccuracy issue is because I didn't multiply sm_scale by math::log2e while still using exp2, which effectively changed the exponent base of the softmax function. The reason for multiplying sm_scale by math::log2e is precisely that we want to change the softmax exponent base from e to 2 and still get the same result.
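A minimal sketch of that base change (variable and function names are illustrative, not the exact code in this PR):

```cpp
// Folding log2(e) into sm_scale so exp2 reproduces the natural-base softmax.
constexpr float kLog2e = 1.4426950408889634f;

// sm_scale_log2e = sm_scale * kLog2e is computed once up front; qk is the raw dot product
// and row_max is the running maximum of the already-scaled logits.
__device__ __forceinline__ float softmax_p(float qk, float row_max, float sm_scale_log2e) {
  // 2^(sm_scale*log2e*qk - row_max) == e^(sm_scale*qk - row_max/log2e), so the softmax is unchanged.
  return exp2f(qk * sm_scale_log2e - row_max);
}
```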

@tsu-bin tsu-bin force-pushed the mla_decode_cute_dev branch from 0625a04 to 1e0192b Compare February 10, 2025 17:24
@tsu-bin tsu-bin marked this pull request as ready for review February 10, 2025 17:25
@tsu-bin tsu-bin force-pushed the mla_decode_cute_dev branch from 1e0192b to bab7b50 Compare February 10, 2025 17:33
@yzh119 yzh119 mentioned this pull request Feb 12, 2025
@yzh119
Collaborator

yzh119 commented Feb 13, 2025

I checked this implementation, and it seems to be using fp16 (please correct me if I'm wrong) as the mma output data type (and the data type for the o register), which is not enough to guarantee correctness, especially for a long kv-cache.

If you use f32 for the o data type, then this implementation (which uses 4 warps) cannot fit reg_output_part into registers (64 * 512 / 128 = 256, while the maximum number of registers per thread is 255). So we need to use 8 warps and let each warp group handle part of the output registers, as in #804.
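A quick sanity check of that register arithmetic, as a sketch using the sizes discussed in this thread and assuming an fp32 accumulator:

```cpp
// Register-budget check for the fp32 output accumulator (sizes taken from the discussion above).
constexpr int kQHeads     = 64;                     // q-heads resident per block
constexpr int kHeadDimCkv = 512;                    // ckv head dimension
constexpr int kAccumRegs  = kQHeads * kHeadDimCkv;  // 32768 fp32 values to keep in registers

static_assert(kAccumRegs / (4 * 32) == 256, "4 warps -> 256 regs/thread, above the 255 limit");
static_assert(kAccumRegs / (8 * 32) == 128, "8 warps -> 128 regs/thread, which fits");
```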

Since #804 implements both decode and incremental prefill, and also supports a versatile number of heads, I think we should turn to maintaining that codebase together instead: https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/attention/mla_fa2.cuh

Let me know if you have other concerns. I really appreciate your work on this, and I believe we can learn from your findings, especially the o_scale_broadcast_mat trick, which I believe could benefit our MLA template for FA3.

Thank you again for making this contribution!

@yzh119 yzh119 closed this Feb 13, 2025
@tsu-bin
Contributor Author

tsu-bin commented Feb 13, 2025

No problem; the register-file size limitation is a practical issue. The scheduling design in this PR is the best I could do to accommodate 64 q-heads in one block, and indeed there is no way to use an f32 o register with it.
The advantage of CUTLASS CuTe is its development efficiency, such as the easy use of MMA and how easy it is to achieve bank-conflict-free smem access. Even if I change the scheduling design to fit the f32 o register, the performance should be at the same level as the implementation you currently have.
Looking forward to working with you next time.

yzh119 pushed a commit that referenced this pull request Feb 14, 2025
Hi @yzh119 , this is a follow-up to #766. An interesting idea came to my mind today, and I couldn't help changing a few lines to verify it.
We can use an asymmetric warp config to solve the register-file size limit issue. The solution is simply to use 8 warps for the output mma stage and keep the other parts unchanged, because the limitation is on the register count per CUDA block, not the whole SM; there are 64K 32-bit registers per SM, which is enough for the f32 output of 64 heads.
So we now have 4 warps for the att mma stage, 2 warps for the softmax stage, 8 warps for the output mma stage, and 4 warps for the data-load stage; the diagram is updated below:

![image](https://github.com/user-attachments/assets/2af8c5d9-d5a5-47e6-bd63-7e6b4305a529)

After the change, the output mma stage needs more computation, so the benchmark drops a little as expected, but it still looks good:

![image](https://github.com/user-attachments/assets/470ec576-ba91-4e71-9604-fcd6f0a9d691)

It seems the performance of this CuTe implementation is slightly better than the current FA2 implementation, according to #814:

![image](https://github.com/user-attachments/assets/9f61e2ff-4bb6-4581-a199-bb6176173192)


So I think this CuTe implementation still has its value. Considering its interesting scheduling design and better performance, maybe we can regard it as an ad hoc implementation for the (decode-only / 128 q-heads / SM80) case, and the JIT logic can accommodate this kernel.