From 96947a34c7ec5cbfcaff598554fdd69d19722b93 Mon Sep 17 00:00:00 2001
From: bob
Date: Mon, 17 Feb 2025 10:52:24 +0000
Subject: [PATCH 1/2] Test of kv_len not evenly divided by page_size.

---
 tests/test_deepseek_mla.py | 32 +++++++++++++++++++-------------
 1 file changed, 19 insertions(+), 13 deletions(-)

diff --git a/tests/test_deepseek_mla.py b/tests/test_deepseek_mla.py
index eafd2d718..205258ac7 100644
--- a/tests/test_deepseek_mla.py
+++ b/tests/test_deepseek_mla.py
@@ -151,6 +151,18 @@ def test_batch_prefill_with_ragged_kv_cache(
     torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3)
     torch.testing.assert_close(lse, lse_buffer, rtol=1e-3, atol=1e-3)
 
+def generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads):
+    bs_page_num, page_size, ckv_dim = ckv.shape
+    page_num = bs_page_num // batch_size
+    _,_,kpe_dim = kpe.shape
+    ckv = ckv.view(batch_size, page_num * page_size, ckv_dim)
+    kpe = kpe.view(batch_size, page_num * page_size, kpe_dim)
+    ckv = ckv[:, :kv_len, :]
+    kpe = kpe[:, :kv_len, :]
+    k = (torch.cat([ckv, kpe], dim = -1).view(-1, 1, ckv_dim + kpe_dim).repeat_interleave(num_heads, dim = 1))
+    v = ckv.repeat_interleave(num_heads, dim = 1)
+
+    return k, v
 
 @pytest.mark.parametrize("batch_size", [1, 17, 37])
 @pytest.mark.parametrize("kv_len", [17, 33, 96, 97, 114, 514, 1024])
@@ -171,8 +183,6 @@ def test_batch_mla_page_attention(
     if causal and qo_len > kv_len:
         pytest.skip("qo_len > kv_len not supported for causal attention")
     torch.manual_seed(42)
-    if kv_len % page_size != 0:
-        pytest.skip("kv_len not divisible by page_size")
     head_dim_ckv = 512
     head_dim_kpe = 64
     q_nope = torch.randn(
@@ -180,16 +190,17 @@ def test_batch_mla_page_attention(
     )
     q_pe = torch.randn(
         batch_size * qo_len, num_heads, head_dim_kpe, dtype=torch.half, device="cuda"
-    )
+    )
+    pages_num = math.ceil(kv_len // page_size)
     ckv = torch.randn(
-        batch_size * kv_len // page_size,
+        batch_size * pages_num,
         page_size,
         head_dim_ckv,
         dtype=torch.half,
         device="cuda",
     )
     kpe = torch.randn(
-        batch_size * kv_len // page_size,
+        batch_size * pages_num,
         page_size,
         head_dim_kpe,
         dtype=torch.half,
@@ -201,8 +212,8 @@ def test_batch_mla_page_attention(
         workspace_buffer, backend=backend
     )
     q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len
-    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len // page_size
-    kv_indices = torch.arange(0, batch_size * kv_len // page_size).to(0).int()
+    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * pages_num
+    kv_indices = torch.arange(0, batch_size * pages_num).to(0).int()
     kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
     wrapper.plan(
         q_indptr,
@@ -220,12 +231,7 @@ def test_batch_mla_page_attention(
     )
     o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
 
-    k = (
-        torch.cat([ckv, kpe], dim=-1)
-        .view(-1, 1, head_dim_ckv + head_dim_kpe)
-        .repeat_interleave(num_heads, dim=1)
-    )
-    v = ckv.repeat_interleave(num_heads, dim=1)
+    k, v = generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads)
     q = torch.cat([q_nope, q_pe], dim=-1)
     o_ref, lse_ref = attention_ref(batch_size, q, k, v, causal, sm_scale)
 

From 03c79165164078db2cdbef733bcd5d9b258c138c Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Mon, 17 Feb 2025 19:35:22 +0000
Subject: [PATCH 2/2] bugfix

---
 tests/test_deepseek_mla.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/tests/test_deepseek_mla.py b/tests/test_deepseek_mla.py
index 205258ac7..ba1667baf 100644
--- a/tests/test_deepseek_mla.py
+++ b/tests/test_deepseek_mla.py
@@ -151,19 +151,25 @@ def test_batch_prefill_with_ragged_kv_cache(
     torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3)
     torch.testing.assert_close(lse, lse_buffer, rtol=1e-3, atol=1e-3)
 
+
 def generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads):
     bs_page_num, page_size, ckv_dim = ckv.shape
     page_num = bs_page_num // batch_size
-    _,_,kpe_dim = kpe.shape
+    _, _, kpe_dim = kpe.shape
     ckv = ckv.view(batch_size, page_num * page_size, ckv_dim)
     kpe = kpe.view(batch_size, page_num * page_size, kpe_dim)
     ckv = ckv[:, :kv_len, :]
     kpe = kpe[:, :kv_len, :]
-    k = (torch.cat([ckv, kpe], dim = -1).view(-1, 1, ckv_dim + kpe_dim).repeat_interleave(num_heads, dim = 1))
-    v = ckv.repeat_interleave(num_heads, dim = 1)
+    k = (
+        torch.cat([ckv, kpe], dim=-1)
+        .view(-1, 1, ckv_dim + kpe_dim)
+        .repeat_interleave(num_heads, dim=1)
+    )
+    v = ckv.repeat_interleave(num_heads, dim=1)
 
     return k, v
 
+
 @pytest.mark.parametrize("batch_size", [1, 17, 37])
 @pytest.mark.parametrize("kv_len", [17, 33, 96, 97, 114, 514, 1024])
 @pytest.mark.parametrize("qo_len", [1, 17, 37, 77])
@@ -190,8 +196,8 @@ def test_batch_mla_page_attention(
     )
     q_pe = torch.randn(
         batch_size * qo_len, num_heads, head_dim_kpe, dtype=torch.half, device="cuda"
-    )
-    pages_num = math.ceil(kv_len // page_size)
+    )
+    pages_num = math.ceil(kv_len / page_size)
     ckv = torch.randn(
         batch_size * pages_num,
         page_size,
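
Note on the bugfix in PATCH 2/2: in PATCH 1/2, pages_num = math.ceil(kv_len // page_size) applies ceil() after floor division has already rounded down, so the ceil() is a no-op and the page count is under-allocated whenever kv_len % page_size != 0 — the very case the test series is meant to cover. True division, as in the second commit, lets ceil() round the partial page up. A minimal standalone sketch (not part of the patch) of the difference, using kv_len=97 from the test's parametrization and an illustrative page_size of 32:

    import math

    kv_len, page_size = 97, 32  # 97 tokens spill into a 4th, partially filled page

    # PATCH 1/2: floor division runs first, so ceil() has nothing left to round
    buggy = math.ceil(kv_len // page_size)  # ceil(3) == 3, partial page dropped
    # PATCH 2/2: true division keeps the fraction for ceil() to round up
    fixed = math.ceil(kv_len / page_size)   # ceil(3.03125) == 4

    assert (buggy, fixed) == (3, 4)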
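The generate_kv_from_cache helper introduced in PATCH 1/2 replaces the old inline k/v construction: it views the paged cache back into a contiguous (batch, pages * page_size, dim) layout, slices off the padding slots beyond kv_len that a partially filled last page introduces, and broadcasts the ckv/kpe entries (shared across all heads in MLA) to every head for the reference attention check. A small shape sketch, assuming the helper is in scope and using illustrative sizes (2 requests, 4 pages of 32 slots each, so 128 slots covering kv_len=97):

    import torch

    batch_size, pages_num, page_size = 2, 4, 32
    kv_len, num_heads, ckv_dim, kpe_dim = 97, 16, 512, 64
    ckv = torch.randn(batch_size * pages_num, page_size, ckv_dim)
    kpe = torch.randn(batch_size * pages_num, page_size, kpe_dim)

    k, v = generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads)

    # The 31 padding slots per request are gone; k is per-position, per-head.
    assert k.shape == (batch_size * kv_len, num_heads, ckv_dim + kpe_dim)
    # v keeps a (batch, kv_len * num_heads, ckv_dim) layout, which views to
    # the same per-position, per-head element order that k uses.
    assert v.shape == (batch_size, kv_len * num_heads, ckv_dim)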