bugfix: fix the JIT warmup arguments in unittests (#775)
Follow-up of #765: fix the JIT warmup utility functions.
yzh119 authored Feb 1, 2025
1 parent a0443d5 commit c04755e
Showing 2 changed files with 18 additions and 8 deletions.
tests/jit_utils.py (6 changes: 4 additions & 2 deletions)

@@ -53,7 +53,8 @@ def jit_decode_attention_func_args(
                 q_dtype,
                 kv_dtype,
                 q_dtype,
-                head_dim,
+                head_dim,  # head_dim_qk
+                head_dim,  # head_dim_vo
                 pos_encoding_mode,
                 use_sliding_window,
                 use_logits_soft_cap,
@@ -68,7 +69,8 @@
                 kv_dtype,
                 q_dtype,
                 torch.int32,
-                head_dim,
+                head_dim,  # head_dim_qk
+                head_dim,  # head_dim_vo
                 pos_encoding_mode,
                 use_sliding_window,
                 use_logits_soft_cap,
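The change in tests/jit_utils.py is mechanical: the JIT module-generation functions now take head_dim_qk and head_dim_vo as two separate positional arguments, so the single head_dim used by the tests is forwarded twice. Below is a condensed, hypothetical sketch of that pattern; the helper name is invented, the argument names follow the hunks above, the dtype-role comments are inferred rather than stated in the diff, and anything outside the visible hunk context is omitted.

import torch

# Hypothetical condensed helper (not the repository's actual code) showing the
# post-fix pattern from tests/jit_utils.py: one head_dim value is passed for
# both head_dim_qk and head_dim_vo.
def decode_module_args(q_dtype, kv_dtype, head_dim, pos_encoding_mode,
                       use_sliding_window, use_logits_soft_cap):
    return [
        kv_dtype,
        q_dtype,            # output dtype (role inferred from the first hunk)
        torch.int32,        # index dtype
        head_dim,           # head_dim_qk (previously a single head_dim)
        head_dim,           # head_dim_vo
        pos_encoding_mode,
        use_sliding_window,
        use_logits_soft_cap,
    ]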
tests/test_jit_warmup.py (20 changes: 14 additions & 6 deletions)

@@ -36,7 +36,8 @@ def test_warmpup_llama():
                 torch.float16,
                 torch.float16,
                 torch.int32,
-                128,
+                128,  # head_dim_qk
+                128,  # head_dim_vo
                 PosEncodingMode.NONE.value,
                 False,  # use_sliding_window
                 False,  # use_logits_soft_cap
@@ -45,11 +46,13 @@
         (
             flashinfer.prefill.gen_batch_prefill_module,
             [
+                "fa2",  # backend
                 torch.float16,
                 torch.float16,
                 torch.float16,
                 torch.int32,
-                128,
+                128,  # head_dim_qk
+                128,  # head_dim_vo
                 PosEncodingMode.NONE.value,
                 False,  # use_sliding_window
                 False,  # use_logits_soft_cap
@@ -75,7 +78,8 @@ def test_warmpup_llama_sm90():
                 torch.float16,
                 torch.float16,
                 torch.int32,
-                128,
+                128,  # head_dim_qk
+                128,  # head_dim_vo
                 PosEncodingMode.NONE.value,
                 False,  # use_sliding_window
                 False,  # use_logits_soft_cap
@@ -84,25 +88,29 @@
         (
             flashinfer.prefill.gen_batch_prefill_module,
             [
+                "fa2",  # backend
                 torch.float16,
                 torch.float16,
                 torch.float16,
                 torch.int32,
-                128,
+                128,  # head_dim_qk
+                128,  # head_dim_vo
                 PosEncodingMode.NONE.value,
                 False,  # use_sliding_window
                 False,  # use_logits_soft_cap
                 False,  # use_fp16_qk_reduction
             ],
         ),
         (
-            flashinfer.prefill.gen_batch_prefill_sm90_module,
+            flashinfer.prefill.gen_batch_prefill_module,
             [
+                "fa3",  # backend
                 torch.float16,
                 torch.float16,
                 torch.float16,
                 torch.int32,
-                128,
+                128,  # head_dim_qk
+                128,  # head_dim_vo
                 PosEncodingMode.NONE.value,
                 False,  # use_sliding_window
                 False,  # use_logits_soft_cap
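Taken together, the test changes route the prefill warmup builds through the single flashinfer.prefill.gen_batch_prefill_module entry point, selecting the backend via the new leading string argument instead of calling a separate gen_batch_prefill_sm90_module, and they split the 128 head dimension into head_dim_qk and head_dim_vo. A minimal sketch of one (module-generator, argument-list) pair in the post-fix layout follows; the values come from the diff above, while the dtype-role comments and the PosEncodingMode import path are assumptions, and the code that consumes these pairs lies outside the hunks and is not shown. In the sm90 test, an additional entry is registered through the same function with backend "fa3".

import torch
import flashinfer
from flashinfer.utils import PosEncodingMode  # import path assumed; only the enum value appears in the diff

# One (module-generator, argument-list) pair as the warmup test builds it after
# this fix: the backend string comes first, and head_dim is passed separately
# for QK and VO.
fa2_prefill_spec = (
    flashinfer.prefill.gen_batch_prefill_module,
    [
        "fa2",                        # backend
        torch.float16,                # query dtype (role inferred)
        torch.float16,                # key/value dtype (role inferred)
        torch.float16,                # output dtype (role inferred)
        torch.int32,                  # index dtype (role inferred)
        128,                          # head_dim_qk
        128,                          # head_dim_vo
        PosEncodingMode.NONE.value,
        False,                        # use_sliding_window
        False,                        # use_logits_soft_cap
        False,                        # use_fp16_qk_reduction
    ],
)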
