From 3edfc58669646ffc8fe4914fe3dc443e95a046a9 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Tue, 30 Apr 2024 13:52:07 +0800 Subject: [PATCH 01/28] [Bugfix][Kernel] allow non-power-of-two head sizes in prefix prefill with alibi --- vllm/attention/ops/prefix_prefill.py | 45 ++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 4896cf3909c6e..900f88c9b5759 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -437,7 +437,8 @@ def _fwd_kernel_alibi( stride_v_cache_bl, num_queries_per_kv: int, BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 BLOCK_N: tl.constexpr, ): # attn_bias[] @@ -458,21 +459,34 @@ def _fwd_kernel_alibi( # initialize offsets offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) + # offs_d = tl.arange(0, BLOCK_DMODEL) + # NOTE: + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) off_q = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd) - + + # NOTE: + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) + q = tl.load( Q + off_q, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), other=0.0) + + # q = tl.load( + # Q + off_q, + # mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + # other=0.0) # # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) alibi_slope = tl.load(Alibi_slopes + cur_head) alibi_start_q = tl.arange( @@ -531,8 +545,12 @@ def _fwd_kernel_alibi( # acc_scale = l_i / l_i_new * alpha acc = acc * acc_scale[:, None] # update acc + # v = tl.load(V_cache + off_v, + # mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + # other=0.0) v = tl.load(V_cache + off_v, - mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), other=0.0) p = p.to(v.dtype) @@ -600,10 +618,15 @@ def _fwd_kernel_alibi( # acc_scale = l_i / l_i_new * alpha acc = acc * acc_scale[:, None] # update acc + # v = tl.load(v_ptrs + + # (cur_batch_in_all_start_index + start_n) * stride_vbs, + # mask=(start_n + offs_n[:, None]) < + # cur_batch_seq_len - cur_batch_ctx_len, + # other=0.0) v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < - cur_batch_seq_len - cur_batch_ctx_len, + mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len), other=0.0) p = p.to(v.dtype) @@ -619,9 +642,13 @@ def _fwd_kernel_alibi( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od) out_ptrs = Out + off_o + # tl.store(out_ptrs, + # acc, + # mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) tl.store(out_ptrs, acc, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + mask=dim_mask[None, :] & (offs_m[:, None] < + cur_batch_seq_len - cur_batch_ctx_len)) return @torch.inference_mode() From 2a8f222b31fa30e0b18d79320b7b1f3029f45989 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Tue, 30 Apr 2024 13:59:35 +0800 Subject: [PATCH 02/28] remove un-need comments --- vllm/attention/ops/prefix_prefill.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 900f88c9b5759..b0d678243963e 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -459,15 +459,12 @@ def _fwd_kernel_alibi( # initialize offsets offs_n = tl.arange(0, BLOCK_N) - # offs_d = tl.arange(0, BLOCK_DMODEL) - # NOTE: offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) off_q = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd) - # NOTE: dim_mask = tl.where( tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) @@ -477,15 +474,9 @@ def _fwd_kernel_alibi( (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), other=0.0) - # q = tl.load( - # Q + off_q, - # mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, - # other=0.0) - # # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - # acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) alibi_slope = tl.load(Alibi_slopes + cur_head) @@ -545,9 +536,6 @@ def _fwd_kernel_alibi( # acc_scale = l_i / l_i_new * alpha acc = acc * acc_scale[:, None] # update acc - # v = tl.load(V_cache + off_v, - # mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, - # other=0.0) v = tl.load(V_cache + off_v, mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) < cur_batch_ctx_len), @@ -618,11 +606,6 @@ def _fwd_kernel_alibi( # acc_scale = l_i / l_i_new * alpha acc = acc * acc_scale[:, None] # update acc - # v = tl.load(v_ptrs + - # (cur_batch_in_all_start_index + start_n) * stride_vbs, - # mask=(start_n + offs_n[:, None]) < - # cur_batch_seq_len - cur_batch_ctx_len, - # other=0.0) v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) < @@ -642,9 +625,6 @@ def _fwd_kernel_alibi( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od) out_ptrs = Out + off_o - # tl.store(out_ptrs, - # acc, - # mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) tl.store(out_ptrs, acc, mask=dim_mask[None, :] & (offs_m[:, None] < From 7c5ea9b7f8977a97e6cb86e0d52a4477acb6b7e8 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Tue, 30 Apr 2024 14:03:48 +0800 Subject: [PATCH 03/28] Update context_attention_fwd --- vllm/attention/ops/prefix_prefill.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index b0d678243963e..61205c31d9038 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -661,7 +661,6 @@ def context_attention_fwd(q, num_warps = 8 if Lk <= 64 else 8 if alibi_slopes is not None: - assert Lk == Lk_padded _fwd_kernel_alibi[grid]( q, k, @@ -706,6 +705,7 @@ def context_attention_fwd(q, num_queries_per_kv=num_queries_per_kv, BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1, From 9a6b936c73b15a870f20d5ed5f76d0b701398ff3 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Wed, 1 May 2024 11:15:42 +0800 Subject: [PATCH 04/28] remove blanks --- vllm/attention/ops/prefix_prefill.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 61205c31d9038..190f4c0285c08 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -473,7 +473,7 @@ def _fwd_kernel_alibi( mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), other=0.0) - + # # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) From 2a536a4715d09b63b8eaeb963ce06cfac1f1f9ec Mon Sep 17 00:00:00 2001 From: DefTruth Date: Fri, 3 May 2024 11:12:10 +0800 Subject: [PATCH 05/28] format code --- vllm/attention/ops/prefix_prefill.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 26e0f4b1028ba..a61a11fce4d01 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -499,15 +499,14 @@ def _fwd_kernel_alibi( off_q = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd) - + dim_mask = tl.where( tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) - - q = tl.load( - Q + off_q, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) # # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") @@ -643,8 +642,9 @@ def _fwd_kernel_alibi( # update acc v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) < - cur_batch_seq_len - cur_batch_ctx_len), + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len), other=0.0) p = p.to(v.dtype) @@ -662,8 +662,8 @@ def _fwd_kernel_alibi( out_ptrs = Out + off_o tl.store(out_ptrs, acc, - mask=dim_mask[None, :] & (offs_m[:, None] < - cur_batch_seq_len - cur_batch_ctx_len)) + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) return @torch.inference_mode() From c62409c1521b458acd85d6087cf94b53b2690ca2 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Sat, 4 May 2024 20:14:08 +0800 Subject: [PATCH 06/28] add prefix prefill alibi test --- tests/kernels/test_prefix_prefill_alibi.py | 247 +++++++++++++++++++++ 1 file changed, 247 insertions(+) create mode 100644 tests/kernels/test_prefix_prefill_alibi.py diff --git a/tests/kernels/test_prefix_prefill_alibi.py b/tests/kernels/test_prefix_prefill_alibi.py new file mode 100644 index 0000000000000..a7a0a25685055 --- /dev/null +++ b/tests/kernels/test_prefix_prefill_alibi.py @@ -0,0 +1,247 @@ +import random +import time +import math +import pytest +import torch +from xformers import ops as xops + +from vllm.attention.ops.prefix_prefill import context_attention_fwd + +NUM_HEADS = [32, 64] +NUM_QUERIES_PER_KV = [1] +HEAD_SIZES = [128, 96, 80] +DTYPES = [torch.float16] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + + +def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + base = torch.tensor( + 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32, + ) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32, + ) + num_remaining_heads = min(closest_power_of_2, + total_num_heads - closest_power_of_2) + extra_powers = torch.arange(start=1, + end=1 + 2 * num_remaining_heads, + step=2, + dtype=torch.int32) + slopes = torch.cat( + [slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_contexted_kv_attention_alibi( + num_heads: int, + num_queries_per_kv: int, + head_size: int, + dtype: torch.dtype, + device: str, +) -> None: + random.seed(0) + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + torch.set_default_device(device) + + # Need this, otherwise when we capture the graph the process + # for GPU 1 would run on both GPU0 and GPU1 and things would hang + # + # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523 + torch.cuda.set_device(device) + alibi_slopes = _get_alibi_slopes(num_heads).to(device) + + MAX_SEQ_LEN = 1024 + MAX_CTX_LEN = 1024 + BS = 10 + cache_size = 640 + block_size = 32 + max_block_per_request = 64 + query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] + ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] + seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] + num_kv_heads = num_heads // num_queries_per_kv + + num_tokens = sum(query_lens) + query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) + query.uniform_(-1e-3, 1e-3) + output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) + + kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) + kv.uniform_(-1e-3, 1e-3) + key, value = kv.unbind(dim=1) + + k_cache = torch.zeros(cache_size, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + v_cache = torch.zeros(cache_size, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + values = torch.arange(0, cache_size, dtype=torch.long) + values = values[torch.randperm(cache_size)] + block_table = values[:BS * max_block_per_request].view( + BS, max_block_per_request) + b_seq_len = torch.tensor(seq_lens, dtype=torch.long) + b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], + dtype=torch.long), + dim=0) + max_input_len = MAX_SEQ_LEN + # copy kv to cache + b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], + dtype=torch.long), + dim=0) + for i in range(BS): + for j in range(query_lens[i]): + k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + + j]) + v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + + b_ctx_len[i] + j]) + cur_ctx = 0 + block_id = 0 + while cur_ctx < b_ctx_len[i]: + start_loc = b_seq_start_loc[i] + cur_ctx + if cur_ctx + block_size > b_ctx_len[i]: + end_loc = b_seq_start_loc[i] + b_ctx_len[i] + else: + end_loc = start_loc + block_size + start_slot = block_table[i, block_id] * block_size + end_slot = start_slot + end_loc - start_loc + k_cache.view(-1, num_kv_heads, + head_size)[start_slot:end_slot].copy_( + key[start_loc:end_loc]) + v_cache.view(-1, num_kv_heads, + head_size)[start_slot:end_slot].copy_( + value[start_loc:end_loc]) + cur_ctx += block_size + block_id += 1 + # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] + # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] + k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8, + 8).permute(0, 2, 3, 1, 4).contiguous() + # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] + # to V_cache[num_blocks, num_kv_heads, head_size, block_size] + v_cache = v_cache.view(-1, block_size, num_kv_heads, + head_size).permute(0, 2, 3, 1).contiguous() + + # Warm up the Triton kernel by calling it once before actually measuring + # generation time + context_attention_fwd(query, + k, + v, + output, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + alibi_slopes=alibi_slopes) + torch.cuda.synchronize() + start_time = time.time() + context_attention_fwd(query, + k, + v, + output, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + alibi_slopes=alibi_slopes) + torch.cuda.synchronize() + end_time = time.time() + print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") + + scale = float(1.0 / (head_size**0.5)) + + if num_kv_heads != num_heads: + # As of Nov 2023, xformers only supports MHA. For MQA/GQA, + # project the key and value tensors to the desired number of + # heads. + # + # see also: vllm/model_executor/layers/attention.py + query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv, + query.shape[-1]) + key = key[:, :, None, :].expand(key.shape[0], num_kv_heads, + num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], num_kv_heads, + num_queries_per_kv, value.shape[-1]) + + # NOTE: We have to pad query tensor in order to reuse + # _make_alibi_bias function. + if query.shape[0] != key.shape[0]: + query_pad = torch.empty(sum(seq_lens), num_heads, head_size, dtype=dtype) + query_pad.uniform_(-1e-3, 1e-3) + seq_start = 0 + query_start = 0 + for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): + seq_end = seq_start + seq_len + query_end = query_start + query_len + query_pad[seq_start:seq_end, ...] = torch.cat( + [torch.zeros(seq_len - query_len, num_heads, head_size, dtype=dtype), + query[query_start:query_end, ...]], dim=0 + ) + seq_start += seq_len + query_start += query_len + query = query_pad + + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + + from vllm.attention.backends.xformers import _make_alibi_bias + attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, + dtype, seq_lens) + output_ref = torch.empty_like(output) + seq_start = 0 + query_start = 0 + start_time = time.time() + # Attention with alibi slopes. + # FIXME(DefTruth): Because xformers does not support dynamic sequence + # lengths with custom attention bias, we process each prompt one by + # one. This is inefficient, especially when we have many short prompts. + # reference: vllm/attention/backends/xformers.py#343 + for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): + seq_end = seq_start + seq_len + query_end = query_start + query_len + out = xops.memory_efficient_attention_forward( + query[:, seq_start:seq_end], + key[:, seq_start:seq_end], + value[:, seq_start:seq_end], + attn_bias=attn_bias[i], + p=0.0, + scale=scale) + out = out.view_as(query[:, seq_start:seq_end]) + output_ref[query_start:query_end, ...].copy_( + out[:, seq_len-query_len:, ...].squeeze(0)) + seq_start += seq_len + query_start += query_len + end_time = time.time() + print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") + assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) \ No newline at end of file From e6bdcaf22f52be4f8e8dab7ded2af83eaaf3b203 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Sat, 4 May 2024 20:15:04 +0800 Subject: [PATCH 07/28] add prefix prefill alibi test --- tests/kernels/test_prefix_prefill_alibi.py | 54 +++++++++++++--------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/tests/kernels/test_prefix_prefill_alibi.py b/tests/kernels/test_prefix_prefill_alibi.py index a7a0a25685055..3a43e1ed94adc 100644 --- a/tests/kernels/test_prefix_prefill_alibi.py +++ b/tests/kernels/test_prefix_prefill_alibi.py @@ -1,6 +1,7 @@ +import math import random import time -import math + import pytest import torch from xformers import ops as xops @@ -40,6 +41,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: [slopes, torch.pow(extra_base, extra_powers)], dim=0) return slopes + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -81,10 +83,10 @@ def test_contexted_kv_attention_alibi( query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) query.uniform_(-1e-3, 1e-3) output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) - + kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) kv.uniform_(-1e-3, 1e-3) - key, value = kv.unbind(dim=1) + key, value = kv.unbind(dim=1) k_cache = torch.zeros(cache_size, block_size, @@ -101,7 +103,7 @@ def test_contexted_kv_attention_alibi( values = torch.arange(0, cache_size, dtype=torch.long) values = values[torch.randperm(cache_size)] block_table = values[:BS * max_block_per_request].view( - BS, max_block_per_request) + BS, max_block_per_request) b_seq_len = torch.tensor(seq_lens, dtype=torch.long) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], @@ -193,20 +195,25 @@ def test_contexted_kv_attention_alibi( None, :].expand(value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]) - # NOTE: We have to pad query tensor in order to reuse + # NOTE: We have to pad query tensor in order to reuse # _make_alibi_bias function. if query.shape[0] != key.shape[0]: - query_pad = torch.empty(sum(seq_lens), num_heads, head_size, dtype=dtype) + query_pad = torch.empty(sum(seq_lens), + num_heads, + head_size, + dtype=dtype) query_pad.uniform_(-1e-3, 1e-3) seq_start = 0 query_start = 0 for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): seq_end = seq_start + seq_len query_end = query_start + query_len - query_pad[seq_start:seq_end, ...] = torch.cat( - [torch.zeros(seq_len - query_len, num_heads, head_size, dtype=dtype), - query[query_start:query_end, ...]], dim=0 - ) + query_pad[seq_start:seq_end, ...] = torch.cat([ + torch.zeros( + seq_len - query_len, num_heads, head_size, dtype=dtype), + query[query_start:query_end, ...] + ], + dim=0) seq_start += seq_len query_start += query_len query = query_pad @@ -214,10 +221,9 @@ def test_contexted_kv_attention_alibi( query = query.unsqueeze(0) key = key.unsqueeze(0) value = value.unsqueeze(0) - + from vllm.attention.backends.xformers import _make_alibi_bias - attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, - dtype, seq_lens) + attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) output_ref = torch.empty_like(output) seq_start = 0 query_start = 0 @@ -230,18 +236,20 @@ def test_contexted_kv_attention_alibi( for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): seq_end = seq_start + seq_len query_end = query_start + query_len - out = xops.memory_efficient_attention_forward( - query[:, seq_start:seq_end], - key[:, seq_start:seq_end], - value[:, seq_start:seq_end], - attn_bias=attn_bias[i], - p=0.0, - scale=scale) + out = xops.memory_efficient_attention_forward(query[:, + seq_start:seq_end], + key[:, + seq_start:seq_end], + value[:, + seq_start:seq_end], + attn_bias=attn_bias[i], + p=0.0, + scale=scale) out = out.view_as(query[:, seq_start:seq_end]) - output_ref[query_start:query_end, ...].copy_( - out[:, seq_len-query_len:, ...].squeeze(0)) + output_ref[query_start:query_end, + ...].copy_(out[:, seq_len - query_len:, ...].squeeze(0)) seq_start += seq_len query_start += query_len end_time = time.time() print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") - assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) \ No newline at end of file + assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) From 6bb88236a4f5051fab2dc9255ccddf282b542363 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Sun, 5 May 2024 14:17:43 +0800 Subject: [PATCH 08/28] add comments --- tests/kernels/test_prefix_prefill_alibi.py | 26 +++++++++------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/tests/kernels/test_prefix_prefill_alibi.py b/tests/kernels/test_prefix_prefill_alibi.py index 3a43e1ed94adc..b28ae4e0834ff 100644 --- a/tests/kernels/test_prefix_prefill_alibi.py +++ b/tests/kernels/test_prefix_prefill_alibi.py @@ -18,6 +18,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + # Fork from: vllm/vllm/model_executor/models/bloom.py#L44 closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) base = torch.tensor( 2**(-(2**-(math.log2(closest_power_of_2) - 3))), @@ -208,12 +209,7 @@ def test_contexted_kv_attention_alibi( for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): seq_end = seq_start + seq_len query_end = query_start + query_len - query_pad[seq_start:seq_end, ...] = torch.cat([ - torch.zeros( - seq_len - query_len, num_heads, head_size, dtype=dtype), - query[query_start:query_end, ...] - ], - dim=0) + query_pad[seq_start:seq_end, ...] = torch.cat([torch.zeros(seq_len - query_len, num_heads, head_size, dtype=dtype), query[query_start:query_end, ...]], dim=0) seq_start += seq_len query_start += query_len query = query_pad @@ -232,19 +228,17 @@ def test_contexted_kv_attention_alibi( # FIXME(DefTruth): Because xformers does not support dynamic sequence # lengths with custom attention bias, we process each prompt one by # one. This is inefficient, especially when we have many short prompts. - # reference: vllm/attention/backends/xformers.py#343 + # modified from: vllm/attention/backends/xformers.py#L343 for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): seq_end = seq_start + seq_len query_end = query_start + query_len - out = xops.memory_efficient_attention_forward(query[:, - seq_start:seq_end], - key[:, - seq_start:seq_end], - value[:, - seq_start:seq_end], - attn_bias=attn_bias[i], - p=0.0, - scale=scale) + out = xops.memory_efficient_attention_forward( + query[:, seq_start:seq_end], + key[:, seq_start:seq_end], + value[:, seq_start:seq_end], + attn_bias=attn_bias[i], + p=0.0, + scale=scale) out = out.view_as(query[:, seq_start:seq_end]) output_ref[query_start:query_end, ...].copy_(out[:, seq_len - query_len:, ...].squeeze(0)) From cf78a94801a60b3896462bf0446257344f4ac9f0 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Sun, 5 May 2024 14:18:22 +0800 Subject: [PATCH 09/28] add comments --- tests/kernels/test_prefix_prefill_alibi.py | 23 ++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/tests/kernels/test_prefix_prefill_alibi.py b/tests/kernels/test_prefix_prefill_alibi.py index b28ae4e0834ff..2e11789b48fb1 100644 --- a/tests/kernels/test_prefix_prefill_alibi.py +++ b/tests/kernels/test_prefix_prefill_alibi.py @@ -209,7 +209,12 @@ def test_contexted_kv_attention_alibi( for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): seq_end = seq_start + seq_len query_end = query_start + query_len - query_pad[seq_start:seq_end, ...] = torch.cat([torch.zeros(seq_len - query_len, num_heads, head_size, dtype=dtype), query[query_start:query_end, ...]], dim=0) + query_pad[seq_start:seq_end, ...] = torch.cat([ + torch.zeros( + seq_len - query_len, num_heads, head_size, dtype=dtype), + query[query_start:query_end, ...] + ], + dim=0) seq_start += seq_len query_start += query_len query = query_pad @@ -232,13 +237,15 @@ def test_contexted_kv_attention_alibi( for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): seq_end = seq_start + seq_len query_end = query_start + query_len - out = xops.memory_efficient_attention_forward( - query[:, seq_start:seq_end], - key[:, seq_start:seq_end], - value[:, seq_start:seq_end], - attn_bias=attn_bias[i], - p=0.0, - scale=scale) + out = xops.memory_efficient_attention_forward(query[:, + seq_start:seq_end], + key[:, + seq_start:seq_end], + value[:, + seq_start:seq_end], + attn_bias=attn_bias[i], + p=0.0, + scale=scale) out = out.view_as(query[:, seq_start:seq_end]) output_ref[query_start:query_end, ...].copy_(out[:, seq_len - query_len:, ...].squeeze(0)) From 6bcfb755900839d2f6bb8004adc16c0fb42c1289 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Sun, 5 May 2024 15:22:20 +0800 Subject: [PATCH 10/28] update prefix prefill alibi tests --- tests/kernels/test_prefix_prefill_alibi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_prefix_prefill_alibi.py b/tests/kernels/test_prefix_prefill_alibi.py index 2e11789b48fb1..1d52798a305aa 100644 --- a/tests/kernels/test_prefix_prefill_alibi.py +++ b/tests/kernels/test_prefix_prefill_alibi.py @@ -8,9 +8,9 @@ from vllm.attention.ops.prefix_prefill import context_attention_fwd -NUM_HEADS = [32, 64] +NUM_HEADS = [32] NUM_QUERIES_PER_KV = [1] -HEAD_SIZES = [128, 96, 80] +HEAD_SIZES = [128, 96] DTYPES = [torch.float16] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) From f00be264d56248542584c077e1bc82c2e6c20b07 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Sun, 5 May 2024 17:36:53 +0800 Subject: [PATCH 11/28] move alibi test into prefix prefill tests --- tests/kernels/test_prefix_prefill.py | 244 ++++++++++++++++++++ tests/kernels/test_prefix_prefill_alibi.py | 256 --------------------- 2 files changed, 244 insertions(+), 256 deletions(-) delete mode 100644 tests/kernels/test_prefix_prefill_alibi.py diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 5a5987e2242fa..da94866702ccc 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -207,3 +207,247 @@ def test_contexted_kv_attention( print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") output_ref = output_ref.reshape(output.shape) assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) + + +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_contexted_kv_attention_alibi( + num_heads: int, + num_queries_per_kv: int, + head_size: int, + dtype: torch.dtype, + device: str, +) -> None: + # NOTE(DefTruth): head size=96 with num_queries_per_kv=8 will encounter: + # Triton Error [CUDA]: an illegal memory access was encountered. When + # I figure out what's going on. I'll turn it on again. + if head_size == 96 and num_queries_per_kv == 8: + pytest.skip() + + random.seed(0) + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + torch.set_default_device(device) + + # Need this, otherwise when we capture the graph the process + # for GPU 1 would run on both GPU0 and GPU1 and things would hang + # + # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523 + torch.cuda.set_device(device) + def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + import math + # Fork from: vllm/vllm/model_executor/models/bloom.py#L44 + closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + base = torch.tensor( + 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32, + ) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32, + ) + num_remaining_heads = min(closest_power_of_2, + total_num_heads - closest_power_of_2) + extra_powers = torch.arange(start=1, + end=1 + 2 * num_remaining_heads, + step=2, + dtype=torch.int32) + slopes = torch.cat( + [slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + alibi_slopes = _get_alibi_slopes(num_heads).to(device) + + MAX_SEQ_LEN = 1024 + MAX_CTX_LEN = 1024 + BS = 10 + cache_size = 640 + block_size = 32 + max_block_per_request = 64 + query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] + ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] + seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] + num_kv_heads = num_heads // num_queries_per_kv + + num_tokens = sum(query_lens) + query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) + query.uniform_(-1e-3, 1e-3) + output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) + + kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) + kv.uniform_(-1e-3, 1e-3) + key, value = kv.unbind(dim=1) + + k_cache = torch.zeros(cache_size, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + v_cache = torch.zeros(cache_size, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + values = torch.arange(0, cache_size, dtype=torch.long) + values = values[torch.randperm(cache_size)] + block_table = values[:BS * max_block_per_request].view( + BS, max_block_per_request) + b_seq_len = torch.tensor(seq_lens, dtype=torch.long) + b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], + dtype=torch.long), + dim=0) + max_input_len = MAX_SEQ_LEN + # copy kv to cache + b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], + dtype=torch.long), + dim=0) + for i in range(BS): + for j in range(query_lens[i]): + k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + + j]) + v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + + b_ctx_len[i] + j]) + cur_ctx = 0 + block_id = 0 + while cur_ctx < b_ctx_len[i]: + start_loc = b_seq_start_loc[i] + cur_ctx + if cur_ctx + block_size > b_ctx_len[i]: + end_loc = b_seq_start_loc[i] + b_ctx_len[i] + else: + end_loc = start_loc + block_size + start_slot = block_table[i, block_id] * block_size + end_slot = start_slot + end_loc - start_loc + k_cache.view(-1, num_kv_heads, + head_size)[start_slot:end_slot].copy_( + key[start_loc:end_loc]) + v_cache.view(-1, num_kv_heads, + head_size)[start_slot:end_slot].copy_( + value[start_loc:end_loc]) + cur_ctx += block_size + block_id += 1 + # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] + # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] + k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8, + 8).permute(0, 2, 3, 1, 4).contiguous() + # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] + # to V_cache[num_blocks, num_kv_heads, head_size, block_size] + v_cache = v_cache.view(-1, block_size, num_kv_heads, + head_size).permute(0, 2, 3, 1).contiguous() + + # Warm up the Triton kernel by calling it once before actually measuring + # generation time + context_attention_fwd(query, + k, + v, + output, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + alibi_slopes=alibi_slopes) + torch.cuda.synchronize() + start_time = time.time() + context_attention_fwd(query, + k, + v, + output, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + alibi_slopes=alibi_slopes) + torch.cuda.synchronize() + end_time = time.time() + print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") + scale = float(1.0 / (head_size**0.5)) + + # NOTE: In order to reuse _make_alibi_bias function, + # We have to pad query tensor before MQA/GQA expanding, + if query.shape[0] != key.shape[0]: + query_pad = torch.empty(sum(seq_lens), + num_heads, + head_size, + dtype=dtype) + query_pad.uniform_(-1e-3, 1e-3) + seq_start = 0 + query_start = 0 + for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): + seq_end = seq_start + seq_len + query_end = query_start + query_len + query_pad[seq_start:seq_end, ...] = torch.cat([ + torch.zeros( + seq_len - query_len, num_heads, head_size, dtype=dtype), + query[query_start:query_end, ...] + ], + dim=0) + seq_start += seq_len + query_start += query_len + query = query_pad + + if num_kv_heads != num_heads: + # As of Nov 2023, xformers only supports MHA. For MQA/GQA, + # project the key and value tensors to the desired number of + # heads. + # + # see also: vllm/model_executor/layers/attention.py + query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv, + query.shape[-1]) + key = key[:, :, None, :].expand(key.shape[0], num_kv_heads, + num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], num_kv_heads, + num_queries_per_kv, value.shape[-1]) + + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + + from vllm.attention.backends.xformers import _make_alibi_bias + attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) + output_ref = torch.empty_like(output) + seq_start = 0 + query_start = 0 + start_time = time.time() + # Attention with alibi slopes. + # FIXME(DefTruth): Because xformers does not support dynamic sequence + # lengths with custom attention bias, we process each prompt one by + # one. This is inefficient, especially when we have many short prompts. + # modified from: vllm/attention/backends/xformers.py#L343 + for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): + seq_end = seq_start + seq_len + query_end = query_start + query_len + out = xops.memory_efficient_attention_forward(query[:, + seq_start:seq_end], + key[:, + seq_start:seq_end], + value[:, + seq_start:seq_end], + attn_bias=attn_bias[i], + p=0.0, + scale=scale) + out = out.view_as(query[:, seq_start:seq_end]).view( + seq_len, num_heads, head_size) + output_ref[query_start:query_end, ...].copy_( + out[seq_len - query_len:, ...]) + seq_start += seq_len + query_start += query_len + end_time = time.time() + print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") + assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) \ No newline at end of file diff --git a/tests/kernels/test_prefix_prefill_alibi.py b/tests/kernels/test_prefix_prefill_alibi.py deleted file mode 100644 index 1d52798a305aa..0000000000000 --- a/tests/kernels/test_prefix_prefill_alibi.py +++ /dev/null @@ -1,256 +0,0 @@ -import math -import random -import time - -import pytest -import torch -from xformers import ops as xops - -from vllm.attention.ops.prefix_prefill import context_attention_fwd - -NUM_HEADS = [32] -NUM_QUERIES_PER_KV = [1] -HEAD_SIZES = [128, 96] -DTYPES = [torch.float16] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] - - -def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: - # Fork from: vllm/vllm/model_executor/models/bloom.py#L44 - closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) - base = torch.tensor( - 2**(-(2**-(math.log2(closest_power_of_2) - 3))), - dtype=torch.float32, - ) - powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) - slopes = torch.pow(base, powers) - - if closest_power_of_2 != total_num_heads: - extra_base = torch.tensor( - 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), - dtype=torch.float32, - ) - num_remaining_heads = min(closest_power_of_2, - total_num_heads - closest_power_of_2) - extra_powers = torch.arange(start=1, - end=1 + 2 * num_remaining_heads, - step=2, - dtype=torch.int32) - slopes = torch.cat( - [slopes, torch.pow(extra_base, extra_powers)], dim=0) - return slopes - - -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_contexted_kv_attention_alibi( - num_heads: int, - num_queries_per_kv: int, - head_size: int, - dtype: torch.dtype, - device: str, -) -> None: - random.seed(0) - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed(0) - torch.set_default_device(device) - - # Need this, otherwise when we capture the graph the process - # for GPU 1 would run on both GPU0 and GPU1 and things would hang - # - # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523 - torch.cuda.set_device(device) - alibi_slopes = _get_alibi_slopes(num_heads).to(device) - - MAX_SEQ_LEN = 1024 - MAX_CTX_LEN = 1024 - BS = 10 - cache_size = 640 - block_size = 32 - max_block_per_request = 64 - query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] - ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] - seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] - num_kv_heads = num_heads // num_queries_per_kv - - num_tokens = sum(query_lens) - query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) - query.uniform_(-1e-3, 1e-3) - output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) - - kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) - kv.uniform_(-1e-3, 1e-3) - key, value = kv.unbind(dim=1) - - k_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - v_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) - v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) - values = torch.arange(0, cache_size, dtype=torch.long) - values = values[torch.randperm(cache_size)] - block_table = values[:BS * max_block_per_request].view( - BS, max_block_per_request) - b_seq_len = torch.tensor(seq_lens, dtype=torch.long) - b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], - dtype=torch.long), - dim=0) - max_input_len = MAX_SEQ_LEN - # copy kv to cache - b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], - dtype=torch.long), - dim=0) - for i in range(BS): - for j in range(query_lens[i]): - k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + - j]) - v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + - b_ctx_len[i] + j]) - cur_ctx = 0 - block_id = 0 - while cur_ctx < b_ctx_len[i]: - start_loc = b_seq_start_loc[i] + cur_ctx - if cur_ctx + block_size > b_ctx_len[i]: - end_loc = b_seq_start_loc[i] + b_ctx_len[i] - else: - end_loc = start_loc + block_size - start_slot = block_table[i, block_id] * block_size - end_slot = start_slot + end_loc - start_loc - k_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - key[start_loc:end_loc]) - v_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - value[start_loc:end_loc]) - cur_ctx += block_size - block_id += 1 - # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] - # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] - k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8, - 8).permute(0, 2, 3, 1, 4).contiguous() - # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] - # to V_cache[num_blocks, num_kv_heads, head_size, block_size] - v_cache = v_cache.view(-1, block_size, num_kv_heads, - head_size).permute(0, 2, 3, 1).contiguous() - - # Warm up the Triton kernel by calling it once before actually measuring - # generation time - context_attention_fwd(query, - k, - v, - output, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - b_ctx_len, - max_input_len, - alibi_slopes=alibi_slopes) - torch.cuda.synchronize() - start_time = time.time() - context_attention_fwd(query, - k, - v, - output, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - b_ctx_len, - max_input_len, - alibi_slopes=alibi_slopes) - torch.cuda.synchronize() - end_time = time.time() - print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") - - scale = float(1.0 / (head_size**0.5)) - - if num_kv_heads != num_heads: - # As of Nov 2023, xformers only supports MHA. For MQA/GQA, - # project the key and value tensors to the desired number of - # heads. - # - # see also: vllm/model_executor/layers/attention.py - query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv, - query.shape[-1]) - key = key[:, :, None, :].expand(key.shape[0], num_kv_heads, - num_queries_per_kv, key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], num_kv_heads, - num_queries_per_kv, value.shape[-1]) - - # NOTE: We have to pad query tensor in order to reuse - # _make_alibi_bias function. - if query.shape[0] != key.shape[0]: - query_pad = torch.empty(sum(seq_lens), - num_heads, - head_size, - dtype=dtype) - query_pad.uniform_(-1e-3, 1e-3) - seq_start = 0 - query_start = 0 - for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): - seq_end = seq_start + seq_len - query_end = query_start + query_len - query_pad[seq_start:seq_end, ...] = torch.cat([ - torch.zeros( - seq_len - query_len, num_heads, head_size, dtype=dtype), - query[query_start:query_end, ...] - ], - dim=0) - seq_start += seq_len - query_start += query_len - query = query_pad - - query = query.unsqueeze(0) - key = key.unsqueeze(0) - value = value.unsqueeze(0) - - from vllm.attention.backends.xformers import _make_alibi_bias - attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) - output_ref = torch.empty_like(output) - seq_start = 0 - query_start = 0 - start_time = time.time() - # Attention with alibi slopes. - # FIXME(DefTruth): Because xformers does not support dynamic sequence - # lengths with custom attention bias, we process each prompt one by - # one. This is inefficient, especially when we have many short prompts. - # modified from: vllm/attention/backends/xformers.py#L343 - for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): - seq_end = seq_start + seq_len - query_end = query_start + query_len - out = xops.memory_efficient_attention_forward(query[:, - seq_start:seq_end], - key[:, - seq_start:seq_end], - value[:, - seq_start:seq_end], - attn_bias=attn_bias[i], - p=0.0, - scale=scale) - out = out.view_as(query[:, seq_start:seq_end]) - output_ref[query_start:query_end, - ...].copy_(out[:, seq_len - query_len:, ...].squeeze(0)) - seq_start += seq_len - query_start += query_len - end_time = time.time() - print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") - assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) From 590a6c732b7d523e66ed5c70da9f350872b51abc Mon Sep 17 00:00:00 2001 From: DefTruth Date: Sun, 5 May 2024 17:37:44 +0800 Subject: [PATCH 12/28] move alibi test into prefix prefill tests --- tests/kernels/test_prefix_prefill.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index da94866702ccc..f7f71f4b5eee3 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -222,8 +222,8 @@ def test_contexted_kv_attention_alibi( dtype: torch.dtype, device: str, ) -> None: - # NOTE(DefTruth): head size=96 with num_queries_per_kv=8 will encounter: - # Triton Error [CUDA]: an illegal memory access was encountered. When + # NOTE(DefTruth): head size=96 with num_queries_per_kv=8 will encounter: + # Triton Error [CUDA]: an illegal memory access was encountered. When # I figure out what's going on. I'll turn it on again. if head_size == 96 and num_queries_per_kv == 8: pytest.skip() @@ -239,8 +239,10 @@ def test_contexted_kv_attention_alibi( # # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523 torch.cuda.set_device(device) + def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: import math + # Fork from: vllm/vllm/model_executor/models/bloom.py#L44 closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) base = torch.tensor( @@ -256,7 +258,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: dtype=torch.float32, ) num_remaining_heads = min(closest_power_of_2, - total_num_heads - closest_power_of_2) + total_num_heads - closest_power_of_2) extra_powers = torch.arange(start=1, end=1 + 2 * num_remaining_heads, step=2, @@ -264,6 +266,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: slopes = torch.cat( [slopes, torch.pow(extra_base, extra_powers)], dim=0) return slopes + alibi_slopes = _get_alibi_slopes(num_heads).to(device) MAX_SEQ_LEN = 1024 @@ -377,9 +380,9 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: end_time = time.time() print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") scale = float(1.0 / (head_size**0.5)) - - # NOTE: In order to reuse _make_alibi_bias function, - # We have to pad query tensor before MQA/GQA expanding, + + # NOTE: In order to reuse _make_alibi_bias function, + # We have to pad query tensor before MQA/GQA expanding, if query.shape[0] != key.shape[0]: query_pad = torch.empty(sum(seq_lens), num_heads, @@ -444,10 +447,10 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: scale=scale) out = out.view_as(query[:, seq_start:seq_end]).view( seq_len, num_heads, head_size) - output_ref[query_start:query_end, ...].copy_( - out[seq_len - query_len:, ...]) + output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:, + ...]) seq_start += seq_len query_start += query_len end_time = time.time() print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") - assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) \ No newline at end of file + assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) From 1eccf6e95f07977aaf73287719ea6b01a6652845 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Sun, 5 May 2024 19:24:57 +0800 Subject: [PATCH 13/28] update prefix prefill tests --- tests/kernels/test_prefix_prefill.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index f7f71f4b5eee3..c76c0ef89d5ff 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -211,7 +211,7 @@ def test_contexted_kv_attention( @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) -@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("head_size", [128, 80]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() @@ -225,7 +225,7 @@ def test_contexted_kv_attention_alibi( # NOTE(DefTruth): head size=96 with num_queries_per_kv=8 will encounter: # Triton Error [CUDA]: an illegal memory access was encountered. When # I figure out what's going on. I'll turn it on again. - if head_size == 96 and num_queries_per_kv == 8: + if head_size == 96: pytest.skip() random.seed(0) From 15f3bf0caa60b0b1f20b32936eb7824a26a9fc0f Mon Sep 17 00:00:00 2001 From: DefTruth Date: Sun, 5 May 2024 20:06:17 +0800 Subject: [PATCH 14/28] update prefix prefill tests --- tests/kernels/test_prefix_prefill.py | 249 +------------------- tests/kernels/test_prefix_prefill_alibi.py | 256 +++++++++++++++++++++ 2 files changed, 257 insertions(+), 248 deletions(-) create mode 100644 tests/kernels/test_prefix_prefill_alibi.py diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index c76c0ef89d5ff..def06e5c7606d 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -206,251 +206,4 @@ def test_contexted_kv_attention( end_time = time.time() print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") output_ref = output_ref.reshape(output.shape) - assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) - - -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) -@pytest.mark.parametrize("head_size", [128, 80]) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_contexted_kv_attention_alibi( - num_heads: int, - num_queries_per_kv: int, - head_size: int, - dtype: torch.dtype, - device: str, -) -> None: - # NOTE(DefTruth): head size=96 with num_queries_per_kv=8 will encounter: - # Triton Error [CUDA]: an illegal memory access was encountered. When - # I figure out what's going on. I'll turn it on again. - if head_size == 96: - pytest.skip() - - random.seed(0) - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed(0) - torch.set_default_device(device) - - # Need this, otherwise when we capture the graph the process - # for GPU 1 would run on both GPU0 and GPU1 and things would hang - # - # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523 - torch.cuda.set_device(device) - - def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: - import math - - # Fork from: vllm/vllm/model_executor/models/bloom.py#L44 - closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) - base = torch.tensor( - 2**(-(2**-(math.log2(closest_power_of_2) - 3))), - dtype=torch.float32, - ) - powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) - slopes = torch.pow(base, powers) - - if closest_power_of_2 != total_num_heads: - extra_base = torch.tensor( - 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), - dtype=torch.float32, - ) - num_remaining_heads = min(closest_power_of_2, - total_num_heads - closest_power_of_2) - extra_powers = torch.arange(start=1, - end=1 + 2 * num_remaining_heads, - step=2, - dtype=torch.int32) - slopes = torch.cat( - [slopes, torch.pow(extra_base, extra_powers)], dim=0) - return slopes - - alibi_slopes = _get_alibi_slopes(num_heads).to(device) - - MAX_SEQ_LEN = 1024 - MAX_CTX_LEN = 1024 - BS = 10 - cache_size = 640 - block_size = 32 - max_block_per_request = 64 - query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] - ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] - seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] - num_kv_heads = num_heads // num_queries_per_kv - - num_tokens = sum(query_lens) - query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) - query.uniform_(-1e-3, 1e-3) - output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) - - kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) - kv.uniform_(-1e-3, 1e-3) - key, value = kv.unbind(dim=1) - - k_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - v_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) - v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) - values = torch.arange(0, cache_size, dtype=torch.long) - values = values[torch.randperm(cache_size)] - block_table = values[:BS * max_block_per_request].view( - BS, max_block_per_request) - b_seq_len = torch.tensor(seq_lens, dtype=torch.long) - b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], - dtype=torch.long), - dim=0) - max_input_len = MAX_SEQ_LEN - # copy kv to cache - b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], - dtype=torch.long), - dim=0) - for i in range(BS): - for j in range(query_lens[i]): - k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + - j]) - v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + - b_ctx_len[i] + j]) - cur_ctx = 0 - block_id = 0 - while cur_ctx < b_ctx_len[i]: - start_loc = b_seq_start_loc[i] + cur_ctx - if cur_ctx + block_size > b_ctx_len[i]: - end_loc = b_seq_start_loc[i] + b_ctx_len[i] - else: - end_loc = start_loc + block_size - start_slot = block_table[i, block_id] * block_size - end_slot = start_slot + end_loc - start_loc - k_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - key[start_loc:end_loc]) - v_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - value[start_loc:end_loc]) - cur_ctx += block_size - block_id += 1 - # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] - # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] - k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8, - 8).permute(0, 2, 3, 1, 4).contiguous() - # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] - # to V_cache[num_blocks, num_kv_heads, head_size, block_size] - v_cache = v_cache.view(-1, block_size, num_kv_heads, - head_size).permute(0, 2, 3, 1).contiguous() - - # Warm up the Triton kernel by calling it once before actually measuring - # generation time - context_attention_fwd(query, - k, - v, - output, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - b_ctx_len, - max_input_len, - alibi_slopes=alibi_slopes) - torch.cuda.synchronize() - start_time = time.time() - context_attention_fwd(query, - k, - v, - output, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - b_ctx_len, - max_input_len, - alibi_slopes=alibi_slopes) - torch.cuda.synchronize() - end_time = time.time() - print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") - scale = float(1.0 / (head_size**0.5)) - - # NOTE: In order to reuse _make_alibi_bias function, - # We have to pad query tensor before MQA/GQA expanding, - if query.shape[0] != key.shape[0]: - query_pad = torch.empty(sum(seq_lens), - num_heads, - head_size, - dtype=dtype) - query_pad.uniform_(-1e-3, 1e-3) - seq_start = 0 - query_start = 0 - for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): - seq_end = seq_start + seq_len - query_end = query_start + query_len - query_pad[seq_start:seq_end, ...] = torch.cat([ - torch.zeros( - seq_len - query_len, num_heads, head_size, dtype=dtype), - query[query_start:query_end, ...] - ], - dim=0) - seq_start += seq_len - query_start += query_len - query = query_pad - - if num_kv_heads != num_heads: - # As of Nov 2023, xformers only supports MHA. For MQA/GQA, - # project the key and value tensors to the desired number of - # heads. - # - # see also: vllm/model_executor/layers/attention.py - query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv, - query.shape[-1]) - key = key[:, :, None, :].expand(key.shape[0], num_kv_heads, - num_queries_per_kv, key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], num_kv_heads, - num_queries_per_kv, value.shape[-1]) - - query = query.unsqueeze(0) - key = key.unsqueeze(0) - value = value.unsqueeze(0) - - from vllm.attention.backends.xformers import _make_alibi_bias - attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) - output_ref = torch.empty_like(output) - seq_start = 0 - query_start = 0 - start_time = time.time() - # Attention with alibi slopes. - # FIXME(DefTruth): Because xformers does not support dynamic sequence - # lengths with custom attention bias, we process each prompt one by - # one. This is inefficient, especially when we have many short prompts. - # modified from: vllm/attention/backends/xformers.py#L343 - for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): - seq_end = seq_start + seq_len - query_end = query_start + query_len - out = xops.memory_efficient_attention_forward(query[:, - seq_start:seq_end], - key[:, - seq_start:seq_end], - value[:, - seq_start:seq_end], - attn_bias=attn_bias[i], - p=0.0, - scale=scale) - out = out.view_as(query[:, seq_start:seq_end]).view( - seq_len, num_heads, head_size) - output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:, - ...]) - seq_start += seq_len - query_start += query_len - end_time = time.time() - print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") - assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) + assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) \ No newline at end of file diff --git a/tests/kernels/test_prefix_prefill_alibi.py b/tests/kernels/test_prefix_prefill_alibi.py new file mode 100644 index 0000000000000..485586c827783 --- /dev/null +++ b/tests/kernels/test_prefix_prefill_alibi.py @@ -0,0 +1,256 @@ +import random +import time + +import pytest +import torch +from xformers import ops as xops + +from vllm.attention.ops.prefix_prefill import context_attention_fwd + +NUM_HEADS = [64] +NUM_QUERIES_PER_KV = [1, 8, 64] +HEAD_SIZES = [128, 80] +DTYPES = [torch.float16] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_contexted_kv_attention_alibi( + num_heads: int, + num_queries_per_kv: int, + head_size: int, + dtype: torch.dtype, + device: str, +) -> None: + random.seed(0) + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + torch.set_default_device(device) + + # Need this, otherwise when we capture the graph the process + # for GPU 1 would run on both GPU0 and GPU1 and things would hang + # + # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523 + torch.cuda.set_device(device) + + def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + import math + + # Fork from: vllm/vllm/model_executor/models/bloom.py#L44 + closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + base = torch.tensor( + 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32, + ) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32, + ) + num_remaining_heads = min(closest_power_of_2, + total_num_heads - closest_power_of_2) + extra_powers = torch.arange(start=1, + end=1 + 2 * num_remaining_heads, + step=2, + dtype=torch.int32) + slopes = torch.cat( + [slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + + alibi_slopes = _get_alibi_slopes(num_heads).to(device) + + MAX_SEQ_LEN = 1024 + MAX_CTX_LEN = 1024 + BS = 10 + cache_size = 640 + block_size = 32 + max_block_per_request = 64 + query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] + ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] + seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] + num_kv_heads = num_heads // num_queries_per_kv + + num_tokens = sum(query_lens) + query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) + query.uniform_(-1e-3, 1e-3) + output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) + + kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) + kv.uniform_(-1e-3, 1e-3) + key, value = kv.unbind(dim=1) + + k_cache = torch.zeros(cache_size, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + v_cache = torch.zeros(cache_size, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + values = torch.arange(0, cache_size, dtype=torch.long) + values = values[torch.randperm(cache_size)] + block_table = values[:BS * max_block_per_request].view( + BS, max_block_per_request) + b_seq_len = torch.tensor(seq_lens, dtype=torch.long) + b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], + dtype=torch.long), + dim=0) + max_input_len = MAX_SEQ_LEN + # copy kv to cache + b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], + dtype=torch.long), + dim=0) + for i in range(BS): + for j in range(query_lens[i]): + k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + + j]) + v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + + b_ctx_len[i] + j]) + cur_ctx = 0 + block_id = 0 + while cur_ctx < b_ctx_len[i]: + start_loc = b_seq_start_loc[i] + cur_ctx + if cur_ctx + block_size > b_ctx_len[i]: + end_loc = b_seq_start_loc[i] + b_ctx_len[i] + else: + end_loc = start_loc + block_size + start_slot = block_table[i, block_id] * block_size + end_slot = start_slot + end_loc - start_loc + k_cache.view(-1, num_kv_heads, + head_size)[start_slot:end_slot].copy_( + key[start_loc:end_loc]) + v_cache.view(-1, num_kv_heads, + head_size)[start_slot:end_slot].copy_( + value[start_loc:end_loc]) + cur_ctx += block_size + block_id += 1 + # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] + # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] + k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8, + 8).permute(0, 2, 3, 1, 4).contiguous() + # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] + # to V_cache[num_blocks, num_kv_heads, head_size, block_size] + v_cache = v_cache.view(-1, block_size, num_kv_heads, + head_size).permute(0, 2, 3, 1).contiguous() + + # Warm up the Triton kernel by calling it once before actually measuring + # generation time + context_attention_fwd(query, + k, + v, + output, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + alibi_slopes=alibi_slopes) + torch.cuda.synchronize() + start_time = time.time() + context_attention_fwd(query, + k, + v, + output, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + alibi_slopes=alibi_slopes) + torch.cuda.synchronize() + end_time = time.time() + print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") + scale = float(1.0 / (head_size**0.5)) + + # NOTE: In order to reuse _make_alibi_bias function, + # We have to pad query tensor before MQA/GQA expanding, + if query.shape[0] != key.shape[0]: + query_pad = torch.empty(sum(seq_lens), + num_heads, + head_size, + dtype=dtype) + query_pad.uniform_(-1e-3, 1e-3) + seq_start = 0 + query_start = 0 + for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): + seq_end = seq_start + seq_len + query_end = query_start + query_len + query_pad[seq_start:seq_end, ...] = torch.cat([ + torch.zeros( + seq_len - query_len, num_heads, head_size, dtype=dtype), + query[query_start:query_end, ...] + ], + dim=0) + seq_start += seq_len + query_start += query_len + query = query_pad + + if num_kv_heads != num_heads: + # As of Nov 2023, xformers only supports MHA. For MQA/GQA, + # project the key and value tensors to the desired number of + # heads. + # + # see also: vllm/model_executor/layers/attention.py + query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv, + query.shape[-1]) + key = key[:, :, None, :].expand(key.shape[0], num_kv_heads, + num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], num_kv_heads, + num_queries_per_kv, value.shape[-1]) + + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + + from vllm.attention.backends.xformers import _make_alibi_bias + attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) + output_ref = torch.empty_like(output) + seq_start = 0 + query_start = 0 + start_time = time.time() + # Attention with alibi slopes. + # FIXME(DefTruth): Because xformers does not support dynamic sequence + # lengths with custom attention bias, we process each prompt one by + # one. This is inefficient, especially when we have many short prompts. + # modified from: vllm/attention/backends/xformers.py#L343 + for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): + seq_end = seq_start + seq_len + query_end = query_start + query_len + out = xops.memory_efficient_attention_forward(query[:, + seq_start:seq_end], + key[:, + seq_start:seq_end], + value[:, + seq_start:seq_end], + attn_bias=attn_bias[i], + p=0.0, + scale=scale) + out = out.view_as(query[:, seq_start:seq_end]).view( + seq_len, num_heads, head_size) + output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:, + ...]) + seq_start += seq_len + query_start += query_len + end_time = time.time() + print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") + assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) From f91e3b0bab46f5e8626b5945552b90f7219392f0 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Sun, 5 May 2024 20:06:50 +0800 Subject: [PATCH 15/28] update prefix prefill tests --- tests/kernels/test_prefix_prefill_alibi.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/kernels/test_prefix_prefill_alibi.py b/tests/kernels/test_prefix_prefill_alibi.py index 485586c827783..2a634f93d588a 100644 --- a/tests/kernels/test_prefix_prefill_alibi.py +++ b/tests/kernels/test_prefix_prefill_alibi.py @@ -15,6 +15,7 @@ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) @pytest.mark.parametrize("head_size", HEAD_SIZES) From 2e899995affbecab3ac9f75fc760495761b66b81 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Mon, 6 May 2024 14:14:23 +0800 Subject: [PATCH 16/28] change random seed and fix CI --- tests/kernels/test_prefix_prefill_alibi.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/kernels/test_prefix_prefill_alibi.py b/tests/kernels/test_prefix_prefill_alibi.py index 2a634f93d588a..d57f795e8b5bd 100644 --- a/tests/kernels/test_prefix_prefill_alibi.py +++ b/tests/kernels/test_prefix_prefill_alibi.py @@ -9,7 +9,7 @@ NUM_HEADS = [64] NUM_QUERIES_PER_KV = [1, 8, 64] -HEAD_SIZES = [128, 80] +HEAD_SIZES = [128, 80, 96] DTYPES = [torch.float16] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) @@ -29,10 +29,13 @@ def test_contexted_kv_attention_alibi( dtype: torch.dtype, device: str, ) -> None: - random.seed(0) - torch.manual_seed(0) + # NOTE(DefTruth): The random seed here can not been set as the + # same one in test_prefix_prefill.py script to avoid illegal + # memory access error in CI. + random.seed(1) + torch.manual_seed(1) if torch.cuda.is_available(): - torch.cuda.manual_seed(0) + torch.cuda.manual_seed(1) torch.set_default_device(device) # Need this, otherwise when we capture the graph the process @@ -254,4 +257,4 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: query_start += query_len end_time = time.time() print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") - assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) + assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) \ No newline at end of file From ef7348a173d055a15c3c597bcf615803748de245 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Mon, 6 May 2024 14:15:30 +0800 Subject: [PATCH 17/28] change random seed and fix CI --- tests/kernels/test_prefix_prefill.py | 2 +- tests/kernels/test_prefix_prefill_alibi.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index def06e5c7606d..5a5987e2242fa 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -206,4 +206,4 @@ def test_contexted_kv_attention( end_time = time.time() print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") output_ref = output_ref.reshape(output.shape) - assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) \ No newline at end of file + assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) diff --git a/tests/kernels/test_prefix_prefill_alibi.py b/tests/kernels/test_prefix_prefill_alibi.py index d57f795e8b5bd..5e3cde564fc02 100644 --- a/tests/kernels/test_prefix_prefill_alibi.py +++ b/tests/kernels/test_prefix_prefill_alibi.py @@ -29,8 +29,8 @@ def test_contexted_kv_attention_alibi( dtype: torch.dtype, device: str, ) -> None: - # NOTE(DefTruth): The random seed here can not been set as the - # same one in test_prefix_prefill.py script to avoid illegal + # NOTE(DefTruth): The random seed here can not been set as the + # same one in test_prefix_prefill.py script to avoid illegal # memory access error in CI. random.seed(1) torch.manual_seed(1) @@ -257,4 +257,4 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: query_start += query_len end_time = time.time() print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") - assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) \ No newline at end of file + assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) From d7242d055d9e542a1a22198bffa3668391457fcd Mon Sep 17 00:00:00 2001 From: DefTruth Date: Tue, 7 May 2024 04:52:25 +0000 Subject: [PATCH 18/28] update tests --- tests/kernels/test_prefix_prefill_alibi.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/kernels/test_prefix_prefill_alibi.py b/tests/kernels/test_prefix_prefill_alibi.py index 5e3cde564fc02..a2f8d545347e2 100644 --- a/tests/kernels/test_prefix_prefill_alibi.py +++ b/tests/kernels/test_prefix_prefill_alibi.py @@ -106,7 +106,6 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) values = torch.arange(0, cache_size, dtype=torch.long) - values = values[torch.randperm(cache_size)] block_table = values[:BS * max_block_per_request].view( BS, max_block_per_request) b_seq_len = torch.tensor(seq_lens, dtype=torch.long) From d0259d8c26e06954be749a53efb9656696c2b136 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Tue, 7 May 2024 05:11:47 +0000 Subject: [PATCH 19/28] update tests --- tests/kernels/test_prefix_prefill_alibi.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/kernels/test_prefix_prefill_alibi.py b/tests/kernels/test_prefix_prefill_alibi.py index a2f8d545347e2..5e3cde564fc02 100644 --- a/tests/kernels/test_prefix_prefill_alibi.py +++ b/tests/kernels/test_prefix_prefill_alibi.py @@ -106,6 +106,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) values = torch.arange(0, cache_size, dtype=torch.long) + values = values[torch.randperm(cache_size)] block_table = values[:BS * max_block_per_request].view( BS, max_block_per_request) b_seq_len = torch.tensor(seq_lens, dtype=torch.long) From 02d4ebeadfa9128b3eb4db32f7439a560995b238 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Tue, 7 May 2024 05:13:33 +0000 Subject: [PATCH 20/28] update tests --- vllm/attention/ops/prefix_prefill.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index a61a11fce4d01..997b25e887e30 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -536,8 +536,9 @@ def _fwd_kernel_alibi( offs_d[None, :] * stride_v_cache_d + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) k = tl.load(K_cache + off_k, - mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, - other=0.0) + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) @@ -605,8 +606,9 @@ def _fwd_kernel_alibi( # -- compute qk ---- k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < - cur_batch_seq_len - cur_batch_ctx_len, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len), other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) From 51b86e9d73eb1a9b993b55e3ed104efa31d9fd14 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Tue, 7 May 2024 05:24:27 +0000 Subject: [PATCH 21/28] fix k load and update tests --- tests/kernels/test_prefix_prefill_alibi.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_prefix_prefill_alibi.py b/tests/kernels/test_prefix_prefill_alibi.py index 5e3cde564fc02..367e84a075e31 100644 --- a/tests/kernels/test_prefix_prefill_alibi.py +++ b/tests/kernels/test_prefix_prefill_alibi.py @@ -32,10 +32,10 @@ def test_contexted_kv_attention_alibi( # NOTE(DefTruth): The random seed here can not been set as the # same one in test_prefix_prefill.py script to avoid illegal # memory access error in CI. - random.seed(1) - torch.manual_seed(1) + random.seed(0) + torch.manual_seed(0) if torch.cuda.is_available(): - torch.cuda.manual_seed(1) + torch.cuda.manual_seed(0) torch.set_default_device(device) # Need this, otherwise when we capture the graph the process From 5b84d8c3c5f4628224b8d283075a5c23b82ba22a Mon Sep 17 00:00:00 2001 From: DefTruth Date: Tue, 7 May 2024 05:25:51 +0000 Subject: [PATCH 22/28] fix k load and update tests --- tests/kernels/test_prefix_prefill_alibi.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/kernels/test_prefix_prefill_alibi.py b/tests/kernels/test_prefix_prefill_alibi.py index 367e84a075e31..da516cfcd7182 100644 --- a/tests/kernels/test_prefix_prefill_alibi.py +++ b/tests/kernels/test_prefix_prefill_alibi.py @@ -29,9 +29,6 @@ def test_contexted_kv_attention_alibi( dtype: torch.dtype, device: str, ) -> None: - # NOTE(DefTruth): The random seed here can not been set as the - # same one in test_prefix_prefill.py script to avoid illegal - # memory access error in CI. random.seed(0) torch.manual_seed(0) if torch.cuda.is_available(): @@ -186,7 +183,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: scale = float(1.0 / (head_size**0.5)) # NOTE: In order to reuse _make_alibi_bias function, - # We have to pad query tensor before MQA/GQA expanding, + # We have to pad query tensor before MQA/GQA expanding. if query.shape[0] != key.shape[0]: query_pad = torch.empty(sum(seq_lens), num_heads, From a61a58bdf4730ad858fb9d950d0a638df11d34bd Mon Sep 17 00:00:00 2001 From: DefTruth Date: Tue, 7 May 2024 06:02:13 +0000 Subject: [PATCH 23/28] merged alibi test into prefix prefill test script --- tests/kernels/test_prefix_prefill.py | 242 +++++++++++++++++++++ tests/kernels/test_prefix_prefill_alibi.py | 1 + 2 files changed, 243 insertions(+) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 5a5987e2242fa..3a62ceac992ed 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -207,3 +207,245 @@ def test_contexted_kv_attention( print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") output_ref = output_ref.reshape(output.shape) assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) + + +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_contexted_kv_attention_alibi( + num_heads: int, + num_queries_per_kv: int, + head_size: int, + dtype: torch.dtype, + device: str, +) -> None: + random.seed(0) + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + torch.set_default_device(device) + + # Need this, otherwise when we capture the graph the process + # for GPU 1 would run on both GPU0 and GPU1 and things would hang + # + # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523 + torch.cuda.set_device(device) + + def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + import math + + # Fork from: vllm/vllm/model_executor/models/bloom.py#L44 + closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + base = torch.tensor( + 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32, + ) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32, + ) + num_remaining_heads = min(closest_power_of_2, + total_num_heads - closest_power_of_2) + extra_powers = torch.arange(start=1, + end=1 + 2 * num_remaining_heads, + step=2, + dtype=torch.int32) + slopes = torch.cat( + [slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + + alibi_slopes = _get_alibi_slopes(num_heads).to(device) + + MAX_SEQ_LEN = 1024 + MAX_CTX_LEN = 1024 + BS = 10 + cache_size = 640 + block_size = 32 + max_block_per_request = 64 + query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] + ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] + seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] + num_kv_heads = num_heads // num_queries_per_kv + + num_tokens = sum(query_lens) + query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) + query.uniform_(-1e-3, 1e-3) + output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) + + kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) + kv.uniform_(-1e-3, 1e-3) + key, value = kv.unbind(dim=1) + + k_cache = torch.zeros(cache_size, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + v_cache = torch.zeros(cache_size, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + values = torch.arange(0, cache_size, dtype=torch.long) + values = values[torch.randperm(cache_size)] + block_table = values[:BS * max_block_per_request].view( + BS, max_block_per_request) + b_seq_len = torch.tensor(seq_lens, dtype=torch.long) + b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], + dtype=torch.long), + dim=0) + max_input_len = MAX_SEQ_LEN + # copy kv to cache + b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], + dtype=torch.long), + dim=0) + for i in range(BS): + for j in range(query_lens[i]): + k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + + j]) + v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + + b_ctx_len[i] + j]) + cur_ctx = 0 + block_id = 0 + while cur_ctx < b_ctx_len[i]: + start_loc = b_seq_start_loc[i] + cur_ctx + if cur_ctx + block_size > b_ctx_len[i]: + end_loc = b_seq_start_loc[i] + b_ctx_len[i] + else: + end_loc = start_loc + block_size + start_slot = block_table[i, block_id] * block_size + end_slot = start_slot + end_loc - start_loc + k_cache.view(-1, num_kv_heads, + head_size)[start_slot:end_slot].copy_( + key[start_loc:end_loc]) + v_cache.view(-1, num_kv_heads, + head_size)[start_slot:end_slot].copy_( + value[start_loc:end_loc]) + cur_ctx += block_size + block_id += 1 + # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] + # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] + k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8, + 8).permute(0, 2, 3, 1, 4).contiguous() + # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] + # to V_cache[num_blocks, num_kv_heads, head_size, block_size] + v_cache = v_cache.view(-1, block_size, num_kv_heads, + head_size).permute(0, 2, 3, 1).contiguous() + + # Warm up the Triton kernel by calling it once before actually measuring + # generation time + context_attention_fwd(query, + k, + v, + output, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + alibi_slopes=alibi_slopes) + torch.cuda.synchronize() + start_time = time.time() + context_attention_fwd(query, + k, + v, + output, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + alibi_slopes=alibi_slopes) + torch.cuda.synchronize() + end_time = time.time() + print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") + scale = float(1.0 / (head_size**0.5)) + + # NOTE(DefTruth): In order to reuse _make_alibi_bias function, + # we have to pad query tensor before MQA/GQA expanding. + if query.shape[0] != key.shape[0]: + query_pad = torch.empty(sum(seq_lens), + num_heads, + head_size, + dtype=dtype) + query_pad.uniform_(-1e-3, 1e-3) + seq_start = 0 + query_start = 0 + for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): + seq_end = seq_start + seq_len + query_end = query_start + query_len + query_pad[seq_start:seq_end, ...] = torch.cat([ + torch.zeros( + seq_len - query_len, num_heads, head_size, dtype=dtype), + query[query_start:query_end, ...] + ], + dim=0) + seq_start += seq_len + query_start += query_len + query = query_pad + + if num_kv_heads != num_heads: + # As of Nov 2023, xformers only supports MHA. For MQA/GQA, + # project the key and value tensors to the desired number of + # heads. + # + # see also: vllm/model_executor/layers/attention.py + query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv, + query.shape[-1]) + key = key[:, :, None, :].expand(key.shape[0], num_kv_heads, + num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], num_kv_heads, + num_queries_per_kv, value.shape[-1]) + + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + + from vllm.attention.backends.xformers import _make_alibi_bias + attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) + output_ref = torch.empty_like(output) + seq_start = 0 + query_start = 0 + start_time = time.time() + # Attention with alibi slopes. + # FIXME(DefTruth): Because xformers does not support dynamic sequence + # lengths with custom attention bias, we process each prompt one by + # one. This is inefficient, especially when we have many short prompts. + # modified from: vllm/attention/backends/xformers.py#L343 + for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): + seq_end = seq_start + seq_len + query_end = query_start + query_len + out = xops.memory_efficient_attention_forward(query[:, + seq_start:seq_end], + key[:, + seq_start:seq_end], + value[:, + seq_start:seq_end], + attn_bias=attn_bias[i], + p=0.0, + scale=scale) + out = out.view_as(query[:, seq_start:seq_end]).view( + seq_len, num_heads, head_size) + output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:, + ...]) + seq_start += seq_len + query_start += query_len + torch.cuda.synchronize() + end_time = time.time() + print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") + assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) diff --git a/tests/kernels/test_prefix_prefill_alibi.py b/tests/kernels/test_prefix_prefill_alibi.py index da516cfcd7182..1abd812819565 100644 --- a/tests/kernels/test_prefix_prefill_alibi.py +++ b/tests/kernels/test_prefix_prefill_alibi.py @@ -252,6 +252,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: ...]) seq_start += seq_len query_start += query_len + torch.cuda.synchronize() end_time = time.time() print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) From 8705f30291522716ed638d2c5d727800718e6e81 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Tue, 7 May 2024 06:05:07 +0000 Subject: [PATCH 24/28] merged alibi test into prefix prefill test script --- tests/kernels/test_prefix_prefill_alibi.py | 258 --------------------- 1 file changed, 258 deletions(-) delete mode 100644 tests/kernels/test_prefix_prefill_alibi.py diff --git a/tests/kernels/test_prefix_prefill_alibi.py b/tests/kernels/test_prefix_prefill_alibi.py deleted file mode 100644 index 1abd812819565..0000000000000 --- a/tests/kernels/test_prefix_prefill_alibi.py +++ /dev/null @@ -1,258 +0,0 @@ -import random -import time - -import pytest -import torch -from xformers import ops as xops - -from vllm.attention.ops.prefix_prefill import context_attention_fwd - -NUM_HEADS = [64] -NUM_QUERIES_PER_KV = [1, 8, 64] -HEAD_SIZES = [128, 80, 96] -DTYPES = [torch.float16] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] - - -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_contexted_kv_attention_alibi( - num_heads: int, - num_queries_per_kv: int, - head_size: int, - dtype: torch.dtype, - device: str, -) -> None: - random.seed(0) - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed(0) - torch.set_default_device(device) - - # Need this, otherwise when we capture the graph the process - # for GPU 1 would run on both GPU0 and GPU1 and things would hang - # - # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523 - torch.cuda.set_device(device) - - def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: - import math - - # Fork from: vllm/vllm/model_executor/models/bloom.py#L44 - closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) - base = torch.tensor( - 2**(-(2**-(math.log2(closest_power_of_2) - 3))), - dtype=torch.float32, - ) - powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) - slopes = torch.pow(base, powers) - - if closest_power_of_2 != total_num_heads: - extra_base = torch.tensor( - 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), - dtype=torch.float32, - ) - num_remaining_heads = min(closest_power_of_2, - total_num_heads - closest_power_of_2) - extra_powers = torch.arange(start=1, - end=1 + 2 * num_remaining_heads, - step=2, - dtype=torch.int32) - slopes = torch.cat( - [slopes, torch.pow(extra_base, extra_powers)], dim=0) - return slopes - - alibi_slopes = _get_alibi_slopes(num_heads).to(device) - - MAX_SEQ_LEN = 1024 - MAX_CTX_LEN = 1024 - BS = 10 - cache_size = 640 - block_size = 32 - max_block_per_request = 64 - query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] - ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] - seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] - num_kv_heads = num_heads // num_queries_per_kv - - num_tokens = sum(query_lens) - query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) - query.uniform_(-1e-3, 1e-3) - output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) - - kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) - kv.uniform_(-1e-3, 1e-3) - key, value = kv.unbind(dim=1) - - k_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - v_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) - v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) - values = torch.arange(0, cache_size, dtype=torch.long) - values = values[torch.randperm(cache_size)] - block_table = values[:BS * max_block_per_request].view( - BS, max_block_per_request) - b_seq_len = torch.tensor(seq_lens, dtype=torch.long) - b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], - dtype=torch.long), - dim=0) - max_input_len = MAX_SEQ_LEN - # copy kv to cache - b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], - dtype=torch.long), - dim=0) - for i in range(BS): - for j in range(query_lens[i]): - k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + - j]) - v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + - b_ctx_len[i] + j]) - cur_ctx = 0 - block_id = 0 - while cur_ctx < b_ctx_len[i]: - start_loc = b_seq_start_loc[i] + cur_ctx - if cur_ctx + block_size > b_ctx_len[i]: - end_loc = b_seq_start_loc[i] + b_ctx_len[i] - else: - end_loc = start_loc + block_size - start_slot = block_table[i, block_id] * block_size - end_slot = start_slot + end_loc - start_loc - k_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - key[start_loc:end_loc]) - v_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - value[start_loc:end_loc]) - cur_ctx += block_size - block_id += 1 - # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] - # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] - k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8, - 8).permute(0, 2, 3, 1, 4).contiguous() - # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] - # to V_cache[num_blocks, num_kv_heads, head_size, block_size] - v_cache = v_cache.view(-1, block_size, num_kv_heads, - head_size).permute(0, 2, 3, 1).contiguous() - - # Warm up the Triton kernel by calling it once before actually measuring - # generation time - context_attention_fwd(query, - k, - v, - output, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - b_ctx_len, - max_input_len, - alibi_slopes=alibi_slopes) - torch.cuda.synchronize() - start_time = time.time() - context_attention_fwd(query, - k, - v, - output, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - b_ctx_len, - max_input_len, - alibi_slopes=alibi_slopes) - torch.cuda.synchronize() - end_time = time.time() - print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") - scale = float(1.0 / (head_size**0.5)) - - # NOTE: In order to reuse _make_alibi_bias function, - # We have to pad query tensor before MQA/GQA expanding. - if query.shape[0] != key.shape[0]: - query_pad = torch.empty(sum(seq_lens), - num_heads, - head_size, - dtype=dtype) - query_pad.uniform_(-1e-3, 1e-3) - seq_start = 0 - query_start = 0 - for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): - seq_end = seq_start + seq_len - query_end = query_start + query_len - query_pad[seq_start:seq_end, ...] = torch.cat([ - torch.zeros( - seq_len - query_len, num_heads, head_size, dtype=dtype), - query[query_start:query_end, ...] - ], - dim=0) - seq_start += seq_len - query_start += query_len - query = query_pad - - if num_kv_heads != num_heads: - # As of Nov 2023, xformers only supports MHA. For MQA/GQA, - # project the key and value tensors to the desired number of - # heads. - # - # see also: vllm/model_executor/layers/attention.py - query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv, - query.shape[-1]) - key = key[:, :, None, :].expand(key.shape[0], num_kv_heads, - num_queries_per_kv, key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], num_kv_heads, - num_queries_per_kv, value.shape[-1]) - - query = query.unsqueeze(0) - key = key.unsqueeze(0) - value = value.unsqueeze(0) - - from vllm.attention.backends.xformers import _make_alibi_bias - attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) - output_ref = torch.empty_like(output) - seq_start = 0 - query_start = 0 - start_time = time.time() - # Attention with alibi slopes. - # FIXME(DefTruth): Because xformers does not support dynamic sequence - # lengths with custom attention bias, we process each prompt one by - # one. This is inefficient, especially when we have many short prompts. - # modified from: vllm/attention/backends/xformers.py#L343 - for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): - seq_end = seq_start + seq_len - query_end = query_start + query_len - out = xops.memory_efficient_attention_forward(query[:, - seq_start:seq_end], - key[:, - seq_start:seq_end], - value[:, - seq_start:seq_end], - attn_bias=attn_bias[i], - p=0.0, - scale=scale) - out = out.view_as(query[:, seq_start:seq_end]).view( - seq_len, num_heads, head_size) - output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:, - ...]) - seq_start += seq_len - query_start += query_len - torch.cuda.synchronize() - end_time = time.time() - print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") - assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) From 319ca305f7888e7b3d61301a71b3251ae429c98e Mon Sep 17 00:00:00 2001 From: DefTruth Date: Wed, 8 May 2024 01:56:12 +0000 Subject: [PATCH 25/28] add more tests for small head sizes --- tests/kernels/test_prefix_prefill.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 3a62ceac992ed..e7450ad0b2910 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -1,5 +1,6 @@ import random import time +import math import pytest import torch @@ -7,10 +8,12 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask from vllm.attention.ops.prefix_prefill import context_attention_fwd +from vllm.attention.backends.xformers import _make_alibi_bias + NUM_HEADS = [64] NUM_QUERIES_PER_KV = [1, 8, 64] -HEAD_SIZES = [128, 96] +HEAD_SIZES = [128, 96, 80, 24] DTYPES = [torch.float16] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) @@ -235,8 +238,6 @@ def test_contexted_kv_attention_alibi( torch.cuda.set_device(device) def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: - import math - # Fork from: vllm/vllm/model_executor/models/bloom.py#L44 closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) base = torch.tensor( @@ -416,7 +417,6 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: key = key.unsqueeze(0) value = value.unsqueeze(0) - from vllm.attention.backends.xformers import _make_alibi_bias attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) output_ref = torch.empty_like(output) seq_start = 0 From b8d77e159ecab96642ac7e5db68d88389eab081d Mon Sep 17 00:00:00 2001 From: DefTruth Date: Wed, 8 May 2024 01:56:56 +0000 Subject: [PATCH 26/28] add more tests for small head sizes --- tests/kernels/test_prefix_prefill.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index e7450ad0b2910..4d12715139bac 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -1,15 +1,14 @@ +import math import random import time -import math import pytest import torch from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask -from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.attention.backends.xformers import _make_alibi_bias - +from vllm.attention.ops.prefix_prefill import context_attention_fwd NUM_HEADS = [64] NUM_QUERIES_PER_KV = [1, 8, 64] From b3c387660849276b059dfbeeec1c6107b7f22518 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Wed, 8 May 2024 06:08:29 +0000 Subject: [PATCH 27/28] Invoke CI From 9691423ae2cda0fbf79661141ae0e6c77cb3ea13 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Wed, 8 May 2024 11:16:53 +0000 Subject: [PATCH 28/28] update tests for small head sizes --- tests/kernels/test_prefix_prefill.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 4d12715139bac..99fda8364dc0e 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -12,7 +12,7 @@ NUM_HEADS = [64] NUM_QUERIES_PER_KV = [1, 8, 64] -HEAD_SIZES = [128, 96, 80, 24] +HEAD_SIZES = [128, 96, 24] DTYPES = [torch.float16] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)