From c04755e21f4d6fb7813c703f2b00a7ef012be9b8 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Sat, 1 Feb 2025 16:19:36 -0500
Subject: [PATCH] bugfix: fix the JIT warmup arguments in unittests (#775)

Follow-up of #765: fix the JIT warmup utility functions.
---
 tests/jit_utils.py       |  6 ++++--
 tests/test_jit_warmup.py | 20 ++++++++++++++------
 2 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/tests/jit_utils.py b/tests/jit_utils.py
index 6cc8787bf..d6648bee6 100644
--- a/tests/jit_utils.py
+++ b/tests/jit_utils.py
@@ -53,7 +53,8 @@ def jit_decode_attention_func_args(
                 q_dtype,
                 kv_dtype,
                 q_dtype,
-                head_dim,
+                head_dim,  # head_dim_qk
+                head_dim,  # head_dim_vo
                 pos_encoding_mode,
                 use_sliding_window,
                 use_logits_soft_cap,
@@ -68,7 +69,8 @@ def jit_decode_attention_func_args(
                 kv_dtype,
                 q_dtype,
                 torch.int32,
-                head_dim,
+                head_dim,  # head_dim_qk
+                head_dim,  # head_dim_vo
                 pos_encoding_mode,
                 use_sliding_window,
                 use_logits_soft_cap,
diff --git a/tests/test_jit_warmup.py b/tests/test_jit_warmup.py
index e89f6e7af..930a70aeb 100644
--- a/tests/test_jit_warmup.py
+++ b/tests/test_jit_warmup.py
@@ -36,7 +36,8 @@ def test_warmpup_llama():
                 torch.float16,
                 torch.float16,
                 torch.int32,
-                128,
+                128,  # head_dim_qk
+                128,  # head_dim_vo
                 PosEncodingMode.NONE.value,
                 False,  # use_sliding_window
                 False,  # use_logits_soft_cap
@@ -45,11 +46,13 @@ def test_warmpup_llama():
         (
             flashinfer.prefill.gen_batch_prefill_module,
             [
+                "fa2",  # backend
                 torch.float16,
                 torch.float16,
                 torch.float16,
                 torch.int32,
-                128,
+                128,  # head_dim_qk
+                128,  # head_dim_vo
                 PosEncodingMode.NONE.value,
                 False,  # use_sliding_window
                 False,  # use_logits_soft_cap
@@ -75,7 +78,8 @@ def test_warmpup_llama_sm90():
                 torch.float16,
                 torch.float16,
                 torch.int32,
-                128,
+                128,  # head_dim_qk
+                128,  # head_dim_vo
                 PosEncodingMode.NONE.value,
                 False,  # use_sliding_window
                 False,  # use_logits_soft_cap
@@ -84,11 +88,13 @@ def test_warmpup_llama_sm90():
         (
             flashinfer.prefill.gen_batch_prefill_module,
             [
+                "fa2",  # backend
                 torch.float16,
                 torch.float16,
                 torch.float16,
                 torch.int32,
-                128,
+                128,  # head_dim_qk
+                128,  # head_dim_vo
                 PosEncodingMode.NONE.value,
                 False,  # use_sliding_window
                 False,  # use_logits_soft_cap
@@ -96,13 +102,15 @@ def test_warmpup_llama_sm90():
             ],
         ),
         (
-            flashinfer.prefill.gen_batch_prefill_sm90_module,
+            flashinfer.prefill.gen_batch_prefill_module,
             [
+                "fa3",  # backend
                 torch.float16,
                 torch.float16,
                 torch.float16,
                 torch.int32,
-                128,
+                128,  # head_dim_qk
+                128,  # head_dim_vo
                 PosEncodingMode.NONE.value,
                 False,  # use_sliding_window
                 False,  # use_logits_soft_cap
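
For context, a minimal sketch of what the corrected warmup call looks like after this patch: the prefill module generator now takes a backend tag ("fa2" or "fa3") as its first argument, and the single head_dim argument is split into head_dim_qk / head_dim_vo. This assumes the tests drive JIT compilation through flashinfer.jit.parallel_load_modules with (module_generator, args) pairs, as the hunks above suggest; the dtype role comments and the trailing use_fp16_qk_reduction flag are inferred, not shown in this diff.

    import torch

    import flashinfer
    from flashinfer.utils import PosEncodingMode

    # Warm up the FA2 batch-prefill JIT module with the post-#765
    # argument order for a LLaMA-style model (128-dim heads).
    flashinfer.jit.parallel_load_modules(
        [
            (
                flashinfer.prefill.gen_batch_prefill_module,
                [
                    "fa2",  # backend
                    torch.float16,  # query dtype (inferred role)
                    torch.float16,  # key/value dtype (inferred role)
                    torch.float16,  # output dtype (inferred role)
                    torch.int32,  # index dtype (inferred role)
                    128,  # head_dim_qk
                    128,  # head_dim_vo
                    PosEncodingMode.NONE.value,
                    False,  # use_sliding_window
                    False,  # use_logits_soft_cap
                    False,  # use_fp16_qk_reduction (assumed trailing flag)
                ],
            ),
        ]
    )

Swapping "fa2" for "fa3" selects the SM90 (Hopper) backend through the same generator, which is why the separate gen_batch_prefill_sm90_module call is removed above.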