diff --git a/tests/test_deepseek_mla.py b/tests/test_deepseek_mla.py
index eafd2d718..ba1667baf 100644
--- a/tests/test_deepseek_mla.py
+++ b/tests/test_deepseek_mla.py
@@ -152,6 +152,24 @@ def test_batch_prefill_with_ragged_kv_cache(
     torch.testing.assert_close(lse, lse_buffer, rtol=1e-3, atol=1e-3)
 
 
+def generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads):
+    bs_page_num, page_size, ckv_dim = ckv.shape
+    page_num = bs_page_num // batch_size
+    _, _, kpe_dim = kpe.shape
+    ckv = ckv.view(batch_size, page_num * page_size, ckv_dim)
+    kpe = kpe.view(batch_size, page_num * page_size, kpe_dim)
+    ckv = ckv[:, :kv_len, :]
+    kpe = kpe[:, :kv_len, :]
+    k = (
+        torch.cat([ckv, kpe], dim=-1)
+        .view(-1, 1, ckv_dim + kpe_dim)
+        .repeat_interleave(num_heads, dim=1)
+    )
+    v = ckv.repeat_interleave(num_heads, dim=1)
+
+    return k, v
+
+
 @pytest.mark.parametrize("batch_size", [1, 17, 37])
 @pytest.mark.parametrize("kv_len", [17, 33, 96, 97, 114, 514, 1024])
 @pytest.mark.parametrize("qo_len", [1, 17, 37, 77])
@@ -171,8 +189,6 @@ def test_batch_mla_page_attention(
     if causal and qo_len > kv_len:
         pytest.skip("qo_len > kv_len not supported for causal attention")
     torch.manual_seed(42)
-    if kv_len % page_size != 0:
-        pytest.skip("kv_len not divisible by page_size")
     head_dim_ckv = 512
     head_dim_kpe = 64
     q_nope = torch.randn(
@@ -181,15 +197,16 @@ def test_batch_mla_page_attention(
     q_pe = torch.randn(
         batch_size * qo_len, num_heads, head_dim_kpe, dtype=torch.half, device="cuda"
     )
+    pages_num = math.ceil(kv_len / page_size)
     ckv = torch.randn(
-        batch_size * kv_len // page_size,
+        batch_size * pages_num,
         page_size,
         head_dim_ckv,
         dtype=torch.half,
         device="cuda",
     )
     kpe = torch.randn(
-        batch_size * kv_len // page_size,
+        batch_size * pages_num,
         page_size,
         head_dim_kpe,
         dtype=torch.half,
@@ -201,8 +218,8 @@ def test_batch_mla_page_attention(
         workspace_buffer, backend=backend
     )
     q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len
-    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len // page_size
-    kv_indices = torch.arange(0, batch_size * kv_len // page_size).to(0).int()
+    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * pages_num
+    kv_indices = torch.arange(0, batch_size * pages_num).to(0).int()
     kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
     wrapper.plan(
         q_indptr,
@@ -220,12 +237,7 @@ def test_batch_mla_page_attention(
     )
     o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
 
-    k = (
-        torch.cat([ckv, kpe], dim=-1)
-        .view(-1, 1, head_dim_ckv + head_dim_kpe)
-        .repeat_interleave(num_heads, dim=1)
-    )
-    v = ckv.repeat_interleave(num_heads, dim=1)
+    k, v = generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads)
     q = torch.cat([q_nope, q_pe], dim=-1)
     o_ref, lse_ref = attention_ref(batch_size, q, k, v, causal, sm_scale)
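For context, here is a minimal hedged sketch of what the new `generate_kv_from_cache` helper does in the case this diff unblocks (`kv_len` not divisible by `page_size`). The sizes and the `tests.test_deepseek_mla` import path are illustrative assumptions, not part of the diff: the helper flattens the paged `ckv`/`kpe` cache per request, drops the padding in the final partially filled page, and broadcasts the shared latent cache across heads.

```python
import math

import torch

# Assumed import path for illustration; the helper is defined in the test file above.
from tests.test_deepseek_mla import generate_kv_from_cache

# Illustrative sizes only: kv_len is deliberately not a multiple of page_size.
batch_size, kv_len, page_size, num_heads = 2, 5, 4, 3
head_dim_ckv, head_dim_kpe = 16, 8
pages_num = math.ceil(kv_len / page_size)  # 2 pages per request, last one partially used

ckv = torch.randn(batch_size * pages_num, page_size, head_dim_ckv)
kpe = torch.randn(batch_size * pages_num, page_size, head_dim_kpe)

k, v = generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads)

# K is the compressed KV concatenated with the positional part, repeated per head;
# V is the compressed KV alone, also repeated per head. Both cover exactly kv_len
# tokens per request, with the tail of the last page discarded.
assert k.shape == (batch_size * kv_len, num_heads, head_dim_ckv + head_dim_kpe)
assert v.shape == (batch_size, kv_len * num_heads, head_dim_ckv)
```

The two outputs group their leading dimensions differently, but both enumerate tokens and heads in the same order, i.e. they are views of a dense per-token, per-head cache that the reference attention can consume.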