feat: support deepseek prefill attention shape (#765)
DeepSeek requires a `head_dim_qk` of 192 and a `head_dim_vo` of 128; this PR implements support for these shapes in prefill attention on ragged tensors. (We could also support paged kv-cache, but that is not urgent: we only use `head_dim_qk=192, head_dim_vo=128` on ragged tensors for DeepSeek MLA without matrix absorption, and we need a separate MQA kernel with `head_dim_qk=576, head_dim_vo=512` for DeepSeek MLA with matrix absorption. I'll upstream that kernel in a follow-up PR.)
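
For context, these dimensions follow from the published DeepSeek-V2/V3 MLA configuration; the sketch below is illustrative and not taken from code in this PR:

```python
# DeepSeek MLA head dimensions (published model config; illustrative only).
qk_nope_head_dim = 128   # "content" part of each query/key head
qk_rope_head_dim = 64    # decoupled RoPE part of each query/key head
v_head_dim = 128         # value/output head dimension

head_dim_qk = qk_nope_head_dim + qk_rope_head_dim  # 192, w/o matrix absorption
head_dim_vo = v_head_dim                           # 128

# With matrix absorption, queries/keys live in the compressed KV space:
kv_lora_rank = 512
head_dim_qk_absorbed = kv_lora_rank + qk_rope_head_dim  # 576
head_dim_vo_absorbed = kv_lora_rank                     # 512
```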

## Checklist
- [x] Make FA3 template compatible with deepseek model shape
- [x] Make FA2 template compatible with deepseek model shape
- [x] Fix AOT compilation scripts
- [x] Fix C++ tests/benchmarks

## Changes to the programming interface

We added an optional argument `head_dim_vo` to the `plan` function, allowing the user to specify different `head_dim_qk` and `head_dim_vo`:

```python
wrapper.plan(
    ...
    head_dim_qk,
    head_dim_vo=head_dim_vo,
    ...
)
```
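
For example, a DeepSeek-shaped ragged prefill plan would pass the two head dimensions explicitly. This is a sketch only; arguments other than `head_dim_qk`/`head_dim_vo` stand in for the wrapper's existing `plan()` arguments, elided with `...` above:

```python
# Sketch: DeepSeek MLA prefill (w/o matrix absorption) on ragged tensors.
# Arguments other than head_dim_qk / head_dim_vo are placeholders for the
# wrapper's existing plan() arguments.
wrapper.plan(
    qo_indptr,            # ragged query/output indptr
    kv_indptr,            # ragged key/value indptr
    num_qo_heads,
    num_kv_heads,
    head_dim_qk=192,      # 128 (nope) + 64 (rope)
    head_dim_vo=128,
    causal=True,
)
```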
yzh119 authored Feb 1, 2025
1 parent 44ee479 commit eb660de
Showing 66 changed files with 1,428 additions and 1,041 deletions.
51 changes: 37 additions & 14 deletions aot_build_utils/generate.py
@@ -24,6 +24,7 @@
generate_batch_paged_decode_inst,
generate_batch_paged_prefill_inst,
generate_batch_ragged_prefill_inst,
generate_dispatch_inc,
generate_single_decode_inst,
generate_single_prefill_inst,
)
@@ -47,6 +48,19 @@ def write_if_different(path: Path, content: str) -> None:

path.mkdir(parents=True, exist_ok=True)

write_if_different(
path / "dispatch.inc",
generate_dispatch_inc.get_dispatch_inc_str(
argparse.Namespace(
head_dims=head_dims,
head_dims_sm90=head_dims,
pos_encoding_modes=[0],
use_fp16_qk_reductions=[0],
mask_modes=mask_modes,
)
),
)

write_if_different(
path / "aot_default_additional_params.h",
generate_aot_default_additional_params_header.get_aot_default_additional_params_header_str(),
@@ -79,9 +93,10 @@ def write_if_different(path: Path, content: str) -> None:
product(fp16_dtypes, fp8_dtypes)
):
dtype_out = dtype_q
fname = f"single_decode_head_{head_dim}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu"
fname = f"single_decode_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu"
content = generate_single_decode_inst.get_cu_file_str(
head_dim,
head_dim, # head_dim_qk
head_dim, # head_dim_vo
pos_encoding_mode,
dtype_q,
dtype_kv,
@@ -93,7 +108,8 @@ def write_if_different(path: Path, content: str) -> None:
f"single_decode_with_kv_cache_dtype_q_{dtype_q}_"
f"dtype_kv_{dtype_kv}_"
f"dtype_o_{dtype_out}_"
f"head_dim_{head_dim}_"
f"head_dim_qk_{head_dim}_"
f"head_dim_vo_{head_dim}_"
f"posenc_{pos_encoding_mode}_"
f"use_swa_{use_sliding_window}_"
f"use_logits_cap_{use_logits_soft_cap}"
@@ -114,9 +130,10 @@ def write_if_different(path: Path, content: str) -> None:
product(fp16_dtypes, fp8_dtypes)
):
dtype_out = dtype_q
fname = f"batch_paged_decode_head_{head_dim}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}_idtype_{idtype}.cu"
fname = f"batch_paged_decode_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}_idtype_{idtype}.cu"
content = generate_batch_paged_decode_inst.get_cu_file_str(
head_dim,
head_dim, # head_dim_qk
head_dim, # head_dim_vo
pos_encoding_mode,
dtype_q,
dtype_kv,
@@ -130,7 +147,8 @@ def write_if_different(path: Path, content: str) -> None:
f"dtype_kv_{dtype_kv}_"
f"dtype_o_{dtype_out}_"
f"dtype_idx_{idtype}_"
f"head_dim_{head_dim}_"
f"head_dim_qk_{head_dim}_"
f"head_dim_vo_{head_dim}_"
f"posenc_{pos_encoding_mode}_"
f"use_swa_{use_sliding_window}_"
f"use_logits_cap_{use_logits_soft_cap}"
@@ -153,9 +171,10 @@ def write_if_different(path: Path, content: str) -> None:
for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)) + list(
product(prefill_dtypes, fp8_dtypes)
):
fname = f"single_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}.cu"
fname = f"single_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}.cu"
content = generate_single_prefill_inst.get_cu_file_str(
head_dim,
head_dim, # head_dim_qk
head_dim, # head_dim_vo
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode,
@@ -172,7 +191,8 @@ def write_if_different(path: Path, content: str) -> None:
f"single_prefill_with_kv_cache_dtype_q_{dtype_q}_"
f"dtype_kv_{dtype_kv}_"
f"dtype_o_{dtype_q}_"
f"head_dim_{head_dim}_"
f"head_dim_qk_{head_dim}_"
f"head_dim_vo_{head_dim}_"
f"posenc_{pos_encoding_mode}_"
f"use_swa_{use_sliding_window}_"
f"use_logits_cap_{use_logits_soft_cap}_"
@@ -198,9 +218,10 @@ def write_if_different(path: Path, content: str) -> None:
for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)) + list(
product(prefill_dtypes, fp8_dtypes)
):
fname = f"batch_paged_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu"
fname = f"batch_paged_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu"
content = generate_batch_paged_prefill_inst.get_cu_file_str(
head_dim,
head_dim, # head_dim_qk
head_dim, # head_dim_vo
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode,
@@ -211,9 +232,10 @@ def write_if_different(path: Path, content: str) -> None:
)
write_if_different(path / fname, content)

fname = f"batch_ragged_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu"
fname = f"batch_ragged_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu"
content = generate_batch_ragged_prefill_inst.get_cu_file_str(
head_dim,
head_dim, # head_dim_qk
head_dim, # head_dim_vo
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode,
@@ -234,7 +256,8 @@ def write_if_different(path: Path, content: str) -> None:
f"dtype_kv_{dtype_kv}_"
f"dtype_o_{dtype_q}_"
f"dtype_idx_{idtype}_"
f"head_dim_{head_dim}_"
f"head_dim_qk_{head_dim}_"
f"head_dim_vo_{head_dim}_"
f"posenc_{pos_encoding_mode}_"
f"use_swa_{sliding_window}_"
f"use_logits_cap_{logits_soft_cap}_"
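To make the renaming concrete, here is a standalone sketch of the new file-name scheme produced by the f-strings above; the values are illustrative, and the AOT generator still passes the same `head_dim` for both dimensions:

```python
# Illustrative reconstruction of the new kernel file-name scheme (not generator code).
head_dim_qk = head_dim_vo = 128          # generate.py passes head_dim for both
pos_encoding_mode, mask_mode = 0, 1
use_fp16_qk_reduction = "false"
dtype_q = dtype_kv = "f16"
idtype = "i32"

fname = (
    f"batch_ragged_prefill_head_qk_{head_dim_qk}_head_vo_{head_dim_vo}_"
    f"posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_"
    f"mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_"
    f"dtypeout_{dtype_q}_idtype_{idtype}.cu"
)
print(fname)
# batch_ragged_prefill_head_qk_128_head_vo_128_posenc_0_fp16qkred_false_mask_1_dtypeq_f16_dtypekv_f16_dtypeout_f16_idtype_i32.cu
```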
19 changes: 11 additions & 8 deletions aot_build_utils/generate_batch_paged_decode_inst.py
@@ -22,7 +22,8 @@


def get_cu_file_str(
head_dim,
head_dim_qk,
head_dim_vo,
pos_encoding_mode,
dtype_q,
dtype_kv,
@@ -35,25 +36,25 @@ def get_cu_file_str(
using Params = BatchDecodeParams<{dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>;
template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention<
template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim_qk}, {pos_encoding_mode}, DefaultAttention<
/*use_custom_mask=*/false, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>, Params>(
Params params,
{dtype_out}* tmp_v, float* tmp_s,
cudaStream_t stream);
template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention<
template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim_qk}, {pos_encoding_mode}, DefaultAttention<
/*use_custom_mask=*/false, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>, Params>(
Params params,
{dtype_out}* tmp_v, float* tmp_s,
cudaStream_t stream);
template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention<
template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim_qk}, {pos_encoding_mode}, DefaultAttention<
/*use_custom_mask=*/false, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>, Params>(
Params params,
{dtype_out}* tmp_v, float* tmp_s,
cudaStream_t stream);
template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention<
template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim_qk}, {pos_encoding_mode}, DefaultAttention<
/*use_custom_mask=*/false, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>, Params>(
Params params,
{dtype_out}* tmp_v, float* tmp_s,
@@ -69,20 +70,22 @@ def get_cu_file_str(
}}
""".format(
head_dim=head_dim,
head_dim_qk=head_dim_qk,
head_dim_vo=head_dim_vo,
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
dtype_q=dtype_literal[dtype_q],
dtype_kv=dtype_literal[dtype_kv],
dtype_out=dtype_literal[dtype_out],
idtype=idtype_literal[idtype],
head_dim_kpe=head_dim // 8,
head_dim=head_dim_vo, # NOTE(Zihao): for MLA instantiation, we should move them to a standalone file
head_dim_kpe=head_dim_vo // 8,
)
return content


if __name__ == "__main__":
pattern = (
r"batch_paged_decode_head_([0-9]+)_posenc_([0-9]+)_"
r"batch_paged_decode_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_"
r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu"
)

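As a quick sanity check, the updated pattern above can be exercised against one of the renamed files; this is a standalone sketch, not part of the build script:

```python
import re

# Updated filename pattern from generate_batch_paged_decode_inst.py above.
pattern = (
    r"batch_paged_decode_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_"
    r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu"
)
match = re.match(
    pattern,
    "batch_paged_decode_head_qk_128_head_vo_128_posenc_0_"
    "dtypeq_f16_dtypekv_f16_dtypeout_f16_idtype_i32.cu",
)
assert match is not None
head_dim_qk, head_dim_vo = int(match.group(1)), int(match.group(2))
print(head_dim_qk, head_dim_vo)  # 128 128
```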
10 changes: 6 additions & 4 deletions aot_build_utils/generate_batch_paged_prefill_inst.py
@@ -27,7 +27,8 @@


def get_cu_file_str(
head_dim,
head_dim_qk,
head_dim_vo,
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode,
@@ -41,13 +42,14 @@ def get_cu_file_str(
def get_insts(attention_variant, dtype_out):
return "\n".join(
[
"""template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{cta_tile_q}, {head_dim}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode}, {attention_variant}, Params>(
"""template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{cta_tile_q}, {head_dim_qk}, {head_dim_vo}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode}, {attention_variant}, Params>(
Params params,
{dtype_out}* tmp_v,
float* tmp_s, cudaStream_t stream);
""".format(
cta_tile_q=cta_tile_q,
head_dim=head_dim,
head_dim_qk=head_dim_qk,
head_dim_vo=head_dim_vo,
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
use_fp16_qk_reduction=use_fp16_qk_reduction,
mask_mode=mask_mode_literal[int(mask_mode)],
@@ -92,7 +94,7 @@ def get_insts(attention_variant, dtype_out):

if __name__ == "__main__":
pattern = (
r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_"
r"batch_paged_prefill_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_"
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu"
)
compiled_pattern = re.compile(pattern)
20 changes: 13 additions & 7 deletions aot_build_utils/generate_batch_paged_prefill_sm90_inst.py
@@ -22,7 +22,8 @@


def get_cu_file_str(
head_dim,
head_dim_qk,
head_dim_vo,
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode,
@@ -37,7 +38,8 @@ def get_cu_file_str(
def get_insts(attention_variant):
return """
template cudaError_t BatchPrefillWithPagedKVCacheDispatched
<{head_dim},
<{head_dim_qk},
{head_dim_vo},
{mask_mode},
/*USE_SLIDING_WINDOW=*/true,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/true,
@@ -46,7 +48,8 @@ def get_insts(attention_variant):
(Params& params, cudaStream_t stream);
template cudaError_t BatchPrefillWithPagedKVCacheDispatched
<{head_dim},
<{head_dim_qk},
{head_dim_vo},
{mask_mode},
/*USE_SLIDING_WINDOW=*/true,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/false,
@@ -55,7 +58,8 @@ def get_insts(attention_variant):
(Params& params, cudaStream_t stream);
template cudaError_t BatchPrefillWithPagedKVCacheDispatched
<{head_dim},
<{head_dim_qk},
{head_dim_vo},
{mask_mode},
/*USE_SLIDING_WINDOW=*/false,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/true,
@@ -64,15 +68,17 @@ def get_insts(attention_variant):
(Params& params, cudaStream_t stream);
template cudaError_t BatchPrefillWithPagedKVCacheDispatched
<{head_dim},
<{head_dim_qk},
{head_dim_vo},
{mask_mode},
/*USE_SLIDING_WINDOW=*/false,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/false,
{attention_variant},
Params>
(Params& params, cudaStream_t stream);
""".format(
head_dim=head_dim,
head_dim_qk=head_dim_qk,
head_dim_vo=head_dim_vo,
mask_mode=mask_mode_literal[int(mask_mode)],
attention_variant=attention_variant,
)
@@ -107,7 +113,7 @@ def get_insts(attention_variant):

if __name__ == "__main__":
pattern = (
r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_"
r"batch_paged_prefill_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_"
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_"
r"dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu"
)
10 changes: 6 additions & 4 deletions aot_build_utils/generate_batch_ragged_prefill_inst.py
@@ -27,7 +27,8 @@


def get_cu_file_str(
head_dim,
head_dim_qk,
head_dim_vo,
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode,
@@ -41,13 +42,14 @@ def get_cu_file_str(
def get_insts(attention_variant, dtype_out):
return "\n".join(
[
"""template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{cta_tile_q}, {head_dim}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode}, {attention_variant}, Params>(
"""template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{cta_tile_q}, {head_dim_qk}, {head_dim_vo}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode}, {attention_variant}, Params>(
Params params,
{dtype_out}* tmp_v,
float* tmp_s, cudaStream_t stream);
""".format(
cta_tile_q=cta_tile_q,
head_dim=head_dim,
head_dim_qk=head_dim_qk,
head_dim_vo=head_dim_vo,
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
use_fp16_qk_reduction=use_fp16_qk_reduction,
mask_mode=mask_mode_literal[int(mask_mode)],
@@ -94,7 +96,7 @@ def get_insts(attention_variant, dtype_out):

if __name__ == "__main__":
pattern = (
r"batch_ragged_prefill_head_([0-9]+)_posenc_([0-9]+)_"
r"batch_ragged_prefill_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_"
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu"
)
compiled_pattern = re.compile(pattern)