feat: support MLA decode, implemented by CuTe targeted to SM80 #766
Conversation
```cpp
att[i] = -flashinfer::math::inf;
}

float row_o_scale = expf(row_max_prev - row_max);
```
Do we have a sense of the input value range of `exp`?

For fp8, one solution to improve numerical stability is to multiply:

```
p = exp2(log2(448) + x - row_max)
o += p * v
o = o / 448
```

where 448 is the maximum value of e4m3.
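A minimal sketch of this suggestion (the identifiers and the log2-domain assumption are mine, not from the PR): scale `p` up into the top of the e4m3 range inside the exponential, then divide the factor back out of `o` once at the end.

```cuda
// Illustrative sketch of the suggested fp8 (e4m3) rescaling; not code from this PR.
// Assumes x and row_max are already in the log2 domain (exp2-based softmax).
__device__ __forceinline__ float scaled_p(float x, float row_max) {
  constexpr float kLog2E4M3 = 8.807354922057604f;  // log2(448), 448 = max finite e4m3
  // p now lies in (0, 448], so quantizing it to e4m3 wastes little dynamic range
  return exp2f(kLog2E4M3 + x - row_max);
}
// Accumulate as usual:    o += p * v;
// Undo the scaling once:  o = o / 448.f;
```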
```diff
@@ -42,7 +42,7 @@ struct state_t {

   __device__ __forceinline__ state_t() { init(); }

-  __device__ __forceinline__ float get_lse() const { return m + math::ptx_log2(d); }
+  __device__ __forceinline__ float get_lse() const { return m + logf(d); }
```
Note that we enabled `-use_fast_math`, so `expf` will be compiled to `exp2(log2e * x)` (and similarly for `logf`).
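Roughly, under fast math the libm calls lower to the hardware exp2/lg2 approximations plus a multiply by the conversion constant (an illustration, not the exact PTX nvcc emits):

```cuda
// What -use_fast_math effectively turns expf/logf into:
__device__ __forceinline__ float fast_expf(float x) {
  return exp2f(x * 1.4426950408889634f);  // log2(e)
}
__device__ __forceinline__ float fast_logf(float x) {
  return log2f(x) * 0.6931471805599453f;  // ln(2)
}
```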
I think the numerical inaccuracy might not be coming from the approximate exp2; instead, we should multiply `sm_scale` by `log2e` like other kernel implementations do.
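A minimal sketch of that trick (names are illustrative, not this PR's): fold `log2e` into `sm_scale` once, so the online-softmax update only ever calls `exp2f` on already-scaled scores.

```cuda
// Online-softmax step with log2(e) folded into sm_scale. Up to rounding,
// exp2f(x - m_new) equals expf(sm_scale * s - m_natural), so accuracy no
// longer depends on how expf itself is approximated.
__device__ __forceinline__ void online_softmax_step(
    float s,              // raw score, dot(q, k), before scaling
    float sm_scale_log2,  // sm_scale * 1.4426950408889634f, precomputed once
    float& m, float& d, float& row_o_scale) {
  float x = s * sm_scale_log2;     // score in the log2 domain
  float m_new = fmaxf(m, x);
  row_o_scale = exp2f(m - m_new);  // rescale factor for the running o accumulator
  float p = exp2f(x - m_new);
  d = d * row_o_scale + p;
  m = m_new;
}
```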
Hi @yzh119, you are right, the inaccuracy is not from the approximate exp2.
I checked this implementation, and it seems to be using fp16 (please correct me if I'm wrong) as the mma output data type (and as the data type for the o register), which is not enough to guarantee correctness, especially for a long kv-cache. If you use f32 for the o data type, then this implementation (which uses 4 warps) cannot fit in the register file.

Since #804 implements both decode and incremental prefill, and also supports a versatile number of heads, I think we should turn to maintaining this codebase together instead: https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/attention/mla_fa2.cuh

Let me know if you have other concerns. I really appreciate your work on this, and I believe we can learn from your findings. Thank you again for making this contribution!
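For intuition about why an fp16 `o` accumulator degrades on long kv-cache, here is a tiny illustrative snippet (not code from this PR): once the unnormalized accumulator grows past 2^11, its representable spacing is 2, and small per-tile contributions are rounded away entirely.

```cuda
#include <cuda_fp16.h>

// fp16 keeps a 10-bit mantissa, so in [2048, 4096) one ulp is 2 and any
// addend below 1.0 is lost to rounding.
__global__ void fp16_accumulation_demo(float* out) {
  __half acc = __float2half(2048.0f);
  acc = __hadd(acc, __float2half(0.9f));  // rounds back to 2048.0
  *out = __half2float(acc);               // 2048.0f: the 0.9 contribution vanished
}
```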
No problem. The limitation of register file size is a practical issue; the scheduling design of this PR is the best I can do to accommodate 64 q-heads in one block, and there is indeed no way to use an f32 o register.
Hi @yzh119, this is a follow-up of #766. An interesting idea came to my mind today, and I couldn't help changing a few lines to verify it.

We can use an asymmetric warp config to solve the register file size limit issue: the solution is simply to use 8 warps for the output mma stage and keep the other parts unchanged. The limitation is on the register count per CUDA block, not the whole SM, and there are 64K 32-bit registers per SM, which is enough for the f32 output of 64 heads. So we now have 4 warps for the att mma stage, 2 warps for the softmax stage, 8 warps for the output mma stage, and 4 warps for the data load stage; the diagram is updated below.

After the change, the output mma stage needs more computation, so the benchmark drops a little as expected, but it still looks good.

It seems the performance of this CuTe implementation is slightly better than the current FA2 implementation according to #814.

So I think this CuTe implementation still has its value. Considering such an interesting scheduling design and the better performance, maybe we can regard it as an ad hoc implementation for the (decode only / 128 q-heads / SM80) case, and the JIT logic can accommodate this kernel.
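A back-of-the-envelope register budget supporting this (my arithmetic; the 512-dim f32 o accumulator per q-head is an assumption based on MLA's compressed kv width, not something stated in this thread):

$$
64 \text{ heads} \times 512 = 32768 \text{ registers}, \qquad
\frac{32768}{4 \times 32} = 256 > 255, \qquad
\frac{32768}{8 \times 32} = 128
$$

With 65536 32-bit registers per SM, the o accumulator alone is half of the register file; spread over 4 output-mma warps it would already exceed the 255-registers-per-thread cap, while 8 warps bring it down to 128 registers per thread, leaving headroom for everything else.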
The scheduling design is shown in the diagram below (the q_pe and kpe parts are omitted); you can tell where each tensor resides by the prefix of its name, such as `smem`, `reg` and `gmem`.

Some design key points:

- Let `k_kv_tile_len` be the minimum value 8, to save the smem space for `q_nope`.
- Keep `q_nope` with `q_pe`, and `ckv` with `kpe`, in smem, because `ckv` can be transposed in the output matmul stage more easily.
- As a result, `smem_att` is only 8 wide in the kv dimension, rather small.
- `o_scale_broadcast_mat` shares the same smem data with `smem_o_scale`; the column stride of `o_scale_broadcast_mat` is set to 0, and `o_scale_broadcast_mat` is then partitioned by `thr_mma_output`. In this way we can easily implement row-wise broadcasting: each element of `smem_o_scale` multiplies one row of the output tensor, and we don't need to care about the complex register value layout of the mma (see the sketch after this list).
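A minimal CuTe sketch of that stride-0 broadcast trick (the tile sizes and names here are illustrative, not the PR's exact code):

```cuda
#include <cute/tensor.hpp>
using namespace cute;

// smem_o_scale holds one f32 scale per q-head (row). Viewing the same memory
// through a layout whose column stride is 0 makes every column of a row alias
// that single element, so a plain elementwise multiply becomes a row broadcast.
template <class ThrMma, class OFrag>
__device__ void rescale_output(float* smem_o_scale, ThrMma thr_mma_output, OFrag& o_frag) {
  // 64 q-heads (rows) x 512 output columns; column stride 0 => row-wise broadcast.
  Tensor o_scale_broadcast_mat = make_tensor(
      make_smem_ptr(smem_o_scale),
      make_layout(make_shape(Int<64>{}, Int<512>{}),
                  make_stride(Int<1>{}, Int<0>{})));

  // Partition the broadcast view the same way as the mma's C operand, so it
  // lines up with o_frag regardless of the mma's register value layout.
  Tensor t_o_scale = thr_mma_output.partition_C(o_scale_broadcast_mat);
  CUTE_UNROLL
  for (int i = 0; i < size(o_frag); ++i) {
    o_frag(i) *= t_o_scale(i);  // each row of o is scaled by its own o_scale
  }
}
```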
Benchmark data:

I benchmarked the kernel on a 4090, which can accommodate 64 q-heads at most, so for 128 q-heads we need to load all the kv-cache data twice. For 64 q-heads the BWUtil is up to 85%; for 128 q-heads the BWUtil is up to 44% * 2 (the resulting 88% is higher than 85%, which I think is due to the L2 cache benefit). You can also make a comparison with the benchmark screenshot of the cuda-core implementation:
#551
Limitation: