test: add prefill/decode test for non-contiguous kv cache
Signed-off-by: LinHeLurking <[email protected]>
LinHeLurking committed Sep 27, 2024
1 parent f159a70 commit 0ca5ec6
Showing 2 changed files with 188 additions and 73 deletions.
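The new `contiguous_kv=False` cases below build the paged KV cache as a non-contiguous view: the buffer is allocated with a size-2 dummy dimension interleaved before each logical dimension after the first, and index 1 of every dummy dimension is then sliced away, leaving the intended logical shape with non-contiguous strides. A minimal sketch of the trick (shapes here are illustrative, not the tests' parametrized values):

    import torch

    # Logical HND page layout: [pages, 2 (k/v), heads, page_size, head_dim];
    # sizes are illustrative only.
    shape = [8, 2, 4, 16, 64]
    # Insert a size-2 dummy dimension before each logical dimension after the
    # first, then slice index 1 of every dummy dimension. The resulting view
    # has the intended logical shape but non-contiguous strides.
    padded = [shape[0]]
    for v in shape[1:]:
        padded += [2, v]
    buf = torch.randn(*padded)
    view = buf[:, 1, :, 1, :, 1, :, 1, :]
    assert view.shape == torch.Size(shape)
    assert not view.is_contiguous()

The tuple-cache test applies the same slicing with one fewer dummy dimension, since k and v live in separate tensors there.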
python/tests/test_batch_decode_kernels.py (136 changes: 98 additions & 38 deletions)
@@ -33,6 +33,7 @@
 @pytest.mark.parametrize(
     "kv_dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2]
 )
+@pytest.mark.parametrize("contiguous_kv", [True, False])
 def test_batch_decode_with_paged_kv_cache(
     batch_size,
     kv_len,
@@ -46,15 +47,33 @@ def test_batch_decode_with_paged_kv_cache(
     return_lse,
     q_dtype,
     kv_dtype,
+    contiguous_kv,
 ):
     q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).to(q_dtype)
     num_pages_per_seq = (kv_len + page_size - 1) // page_size
     total_num_pages = num_pages_per_seq * batch_size
-    kv_data = (
-        torch.randn(total_num_pages, 2, num_kv_heads, page_size, head_dim).to(0)
-        if kv_layout == "HND"
-        else torch.randn(total_num_pages, 2, page_size, num_kv_heads, head_dim).to(0)
-    )
+    if kv_layout == "HND":
+        kv_shape = [total_num_pages, 2, num_kv_heads, page_size, head_dim]
+    else:
+        kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim]
+    if not contiguous_kv:
+        tmp = [kv_shape[0]]
+        for v in kv_shape[1:]:
+            tmp.append(2)
+            tmp.append(v)
+        kv_shape = tmp
+        kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32).to(0)
+        kv_data = kv_data_fp32.to(kv_dtype)
+        kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :]
+        kv_data_fp32 = kv_data_fp32[:, 1, :, 1, :, 1, :, 1, :]
+        # actual data is stored in non-contiguous memory
+        assert (
+            kv_data.stride(-4)
+            != kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1]
+        )
+    else:
+        kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32).to(0)
+        kv_data = kv_data_fp32.to(kv_dtype)
     kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * num_pages_per_seq
     kv_indices = torch.arange(0, total_num_pages).to(0).int()
     kv_last_page_len = torch.full(
@@ -77,23 +96,23 @@ def test_batch_decode_with_paged_kv_cache(
         q_data_type=q_dtype,
     )
     if return_lse:
-        o, _ = wrapper.run_return_lse(q, kv_data.to(kv_dtype))
+        o, _ = wrapper.run_return_lse(q, kv_data)
     else:
-        o = wrapper.run(q, kv_data.to(kv_dtype))
+        o = wrapper.run(q, kv_data)

     for i in range(batch_size):
         perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3]
         perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2]
         qi = q[i]
         ki = torch.cat(
             [
-                kv_data[kv_indptr[i] : kv_indptr[i + 1] - 1, 0]
+                kv_data_fp32[kv_indptr[i] : kv_indptr[i + 1] - 1, 0]
                 .permute(*perm_dims)
                 .reshape(-1, num_kv_heads, head_dim),
                 (
-                    kv_data[kv_indptr[i + 1] - 1, 0, :, : kv_last_page_len[i]]
+                    kv_data_fp32[kv_indptr[i + 1] - 1, 0, :, : kv_last_page_len[i]]
                     if kv_layout == "HND"
-                    else kv_data[kv_indptr[i + 1] - 1, 0, : kv_last_page_len[i], :]
+                    else kv_data_fp32[kv_indptr[i + 1] - 1, 0, : kv_last_page_len[i], :]
                 )
                 .permute(*perm_dims_last)
                 .reshape(-1, num_kv_heads, head_dim),
@@ -102,13 +121,13 @@ def test_batch_decode_with_paged_kv_cache(
         ).to(kv_dtype)
         vi = torch.cat(
             [
-                kv_data[kv_indptr[i] : kv_indptr[i + 1] - 1, 1]
+                kv_data_fp32[kv_indptr[i] : kv_indptr[i + 1] - 1, 1]
                 .permute(*perm_dims)
                 .reshape(-1, num_kv_heads, head_dim),
                 (
-                    kv_data[kv_indptr[i + 1] - 1, 1, :, : kv_last_page_len[i]]
+                    kv_data_fp32[kv_indptr[i + 1] - 1, 1, :, : kv_last_page_len[i]]
                     if kv_layout == "HND"
-                    else kv_data[kv_indptr[i + 1] - 1, 1, : kv_last_page_len[i], :]
+                    else kv_data_fp32[kv_indptr[i + 1] - 1, 1, : kv_last_page_len[i], :]
                 )
                 .permute(*perm_dims_last)
                 .reshape(-1, num_kv_heads, head_dim),
@@ -141,6 +160,7 @@ def test_batch_decode_with_paged_kv_cache(
 @pytest.mark.parametrize(
     "kv_dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2]
 )
+@pytest.mark.parametrize("contiguous_kv", [True, False])
 def test_batch_decode_with_tuple_paged_kv_cache(
     batch_size,
     kv_len,
@@ -154,18 +174,39 @@ def test_batch_decode_with_tuple_paged_kv_cache(
     return_lse,
     q_dtype,
     kv_dtype,
+    contiguous_kv,
 ):
     q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).to(q_dtype)
     num_pages_per_seq = (kv_len + page_size - 1) // page_size
     total_num_pages = num_pages_per_seq * batch_size
-    kv_data = tuple(
-        (
-            torch.randn(total_num_pages, num_kv_heads, page_size, head_dim).to(0)
-            if kv_layout == "HND"
-            else torch.randn(total_num_pages, page_size, num_kv_heads, head_dim).to(0)
-        )
-        for _ in range(2)
-    )
+    if kv_layout == "HND":
+        kv_shape = [total_num_pages, num_kv_heads, page_size, head_dim]
+    else:
+        kv_shape = [total_num_pages, page_size, num_kv_heads, head_dim]
+    if not contiguous_kv:
+        tmp = [kv_shape[0]]
+        for v in kv_shape[1:]:
+            tmp.append(2)
+            tmp.append(v)
+        kv_shape = tmp
+        kv_data_fp32 = [
+            torch.randn(*kv_shape, dtype=torch.float32).to(0) for _ in range(2)
+        ]
+        kv_data = [kv_data_fp32[i].to(kv_dtype) for i in range(2)]
+        for i in range(2):
+            kv_data_fp32[i] = kv_data_fp32[i][:, 1, :, 1, :, 1, :]
+            kv_data[i] = kv_data[i][:, 1, :, 1, :, 1, :]
+            # actual data is stored in non-contiguous memory
+            assert (
+                kv_data[i].stride(-4)
+                != kv_data[i].shape[-3] * kv_data[i].shape[-2] * kv_data[i].shape[-1]
+            )
+    else:
+        kv_data_fp32 = [
+            torch.randn(*kv_shape, dtype=torch.float32).to(0) for _ in range(2)
+        ]
+        kv_data = [kv_data_fp32[i].to(kv_dtype) for i in range(2)]
+    kv_data = tuple(kv_data)
     kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * num_pages_per_seq
     kv_indices = torch.arange(0, total_num_pages).to(0).int()
     kv_last_page_len = torch.full(
@@ -188,11 +229,11 @@ def test_batch_decode_with_tuple_paged_kv_cache(
         q_data_type=q_dtype,
     )
     if return_lse:
-        o, _ = wrapper.run_return_lse(q, tuple(map(lambda _: _.to(kv_dtype), kv_data)))
+        o, _ = wrapper.run_return_lse(q, kv_data)
     else:
-        o = wrapper.run(q, tuple(map(lambda _: _.to(kv_dtype), kv_data)))
+        o = wrapper.run(q, kv_data)

-    k_cache, v_cache = kv_data
+    k_cache, v_cache = kv_data_fp32
     for i in range(batch_size):
         perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3]
         perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2]
@@ -215,6 +256,7 @@ def test_batch_decode_with_tuple_paged_kv_cache(
         vi = torch.cat(
             [
                 v_cache[kv_indptr[i] : kv_indptr[i + 1] - 1]
+                .to(torch.float32)  # torch.cat does not support some fp8 types
                 .permute(*perm_dims)
                 .reshape(-1, num_kv_heads, head_dim),
                 (
@@ -251,6 +293,7 @@
 @pytest.mark.parametrize(
     "kv_dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2]
 )
+@pytest.mark.parametrize("contiguous_kv", [True, False])
 def test_cuda_graph_batch_decode_with_paged_kv_cache(
     batch_size,
     kv_len,
@@ -262,16 +305,33 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache(
     pos_encoding_mode,
     q_dtype,
     kv_dtype,
+    contiguous_kv,
 ):
     q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).to(q_dtype)
     num_pages_per_seq = (kv_len + page_size - 1) // page_size
     total_num_pages = num_pages_per_seq * batch_size
-    kv_data = (
-        torch.randn(total_num_pages, 2, num_kv_heads, page_size, head_dim).to(0)
-        if kv_layout == "HND"
-        else torch.randn(total_num_pages, 2, page_size, num_kv_heads, head_dim).to(0)
-    )
-    kv_data_dtype = kv_data.to(kv_dtype)
+    if kv_layout == "HND":
+        kv_shape = [total_num_pages, 2, num_kv_heads, page_size, head_dim]
+    else:
+        kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim]
+    if not contiguous_kv:
+        tmp = [kv_shape[0]]
+        for v in kv_shape[1:]:
+            tmp.append(2)
+            tmp.append(v)
+        kv_shape = tmp
+        kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32).to(0)
+        kv_data = kv_data_fp32.to(kv_dtype)
+        kv_data_fp32 = kv_data_fp32[:, 1, :, 1, :, 1, :, 1, :]
+        kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :]
+        # actual data is stored in non-contiguous memory
+        assert (
+            kv_data.stride(-4)
+            != kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1]
+        )
+    else:
+        kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32).to(0)
+        kv_data = kv_data_fp32.to(kv_dtype)
     kv_indptr_host_warmup = torch.arange(0, batch_size + 1).int()
     kv_indices_host_warmup = torch.arange(0, batch_size).int()
     kv_last_page_len_host_warmup = torch.full(
@@ -308,13 +368,13 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache(
     s.wait_stream(torch.cuda.current_stream())
     with torch.cuda.stream(s):
         for _ in range(3):
-            o = wrapper.run(q, kv_data_dtype)
+            o = wrapper.run(q, kv_data)
     torch.cuda.current_stream().wait_stream(s)

     # capture
     g = torch.cuda.CUDAGraph()
     with torch.cuda.graph(g):
-        o = wrapper.run(q, kv_data_dtype)
+        o = wrapper.run(q, kv_data)

     # replay multiple times
     for i in range(1, min(4, num_pages_per_seq)):
@@ -367,13 +427,13 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache(
         qi = q[i]
         ki = torch.cat(
             [
-                kv_data[kv_indptr[i] : kv_indptr[i + 1] - 1, 0]
+                kv_data_fp32[kv_indptr[i] : kv_indptr[i + 1] - 1, 0]
                 .permute(*perm_dims)
                 .reshape(-1, num_kv_heads, head_dim),
                 (
-                    kv_data[kv_indptr[i + 1] - 1, 0, :, : kv_last_page_len[i]]
+                    kv_data_fp32[kv_indptr[i + 1] - 1, 0, :, : kv_last_page_len[i]]
                     if kv_layout == "HND"
-                    else kv_data[kv_indptr[i + 1] - 1, 0, : kv_last_page_len[i], :]
+                    else kv_data_fp32[kv_indptr[i + 1] - 1, 0, : kv_last_page_len[i], :]
                 )
                 .permute(*perm_dims_last)
                 .reshape(-1, num_kv_heads, head_dim),
@@ -382,13 +442,13 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache(
         ).to(kv_dtype)
         vi = torch.cat(
             [
-                kv_data[kv_indptr[i] : kv_indptr[i + 1] - 1, 1]
+                kv_data_fp32[kv_indptr[i] : kv_indptr[i + 1] - 1, 1]
                 .permute(*perm_dims)
                 .reshape(-1, num_kv_heads, head_dim),
                 (
-                    kv_data[kv_indptr[i + 1] - 1, 1, :, : kv_last_page_len[i]]
+                    kv_data_fp32[kv_indptr[i + 1] - 1, 1, :, : kv_last_page_len[i]]
                     if kv_layout == "HND"
-                    else kv_data[kv_indptr[i + 1] - 1, 1, : kv_last_page_len[i], :]
+                    else kv_data_fp32[kv_indptr[i + 1] - 1, 1, : kv_last_page_len[i], :]
                 )
                 .permute(*perm_dims_last)
                 .reshape(-1, num_kv_heads, head_dim),
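To exercise both the contiguous and non-contiguous cases locally, running the file under pytest is enough (this assumes a CUDA device is available, since the tests allocate on device 0):

    pytest python/tests/test_batch_decode_kernels.py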