diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py
index 4ea5a14ffd69b..2646307674efb 100644
--- a/tests/kernels/test_prefix_prefill.py
+++ b/tests/kernels/test_prefix_prefill.py
@@ -1,3 +1,4 @@
+import math
 import random
 import time
 
@@ -6,11 +7,12 @@
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
 
+from vllm.attention.backends.xformers import _make_alibi_bias
 from vllm.attention.ops.prefix_prefill import context_attention_fwd
 
 NUM_HEADS = [64]
 NUM_QUERIES_PER_KV = [1, 8, 64]
-HEAD_SIZES = [128, 96]
+HEAD_SIZES = [128, 96, 24]
 DTYPES = [torch.float16]
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
@@ -211,3 +213,242 @@ def test_contexted_kv_attention(
     print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
     output_ref = output_ref.reshape(output.shape)
     assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
+
+
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_contexted_kv_attention_alibi(
+    num_heads: int,
+    num_queries_per_kv: int,
+    head_size: int,
+    dtype: torch.dtype,
+    device: str,
+) -> None:
+    random.seed(0)
+    torch.manual_seed(0)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(0)
+    torch.set_default_device(device)
+
+    # Need this, otherwise when we capture the graph the process
+    # for GPU 1 would run on both GPU0 and GPU1 and things would hang
+    #
+    # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
+    torch.cuda.set_device(device)
+
+    def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
+        # Fork from: vllm/vllm/model_executor/models/bloom.py#L44
+        closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
+        base = torch.tensor(
+            2**(-(2**-(math.log2(closest_power_of_2) - 3))),
+            dtype=torch.float32,
+        )
+        powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
+        slopes = torch.pow(base, powers)
+
+        if closest_power_of_2 != total_num_heads:
+            extra_base = torch.tensor(
+                2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
+                dtype=torch.float32,
+            )
+            num_remaining_heads = min(closest_power_of_2,
+                                      total_num_heads - closest_power_of_2)
+            extra_powers = torch.arange(start=1,
+                                        end=1 + 2 * num_remaining_heads,
+                                        step=2,
+                                        dtype=torch.int32)
+            slopes = torch.cat(
+                [slopes, torch.pow(extra_base, extra_powers)], dim=0)
+        return slopes
+
+    alibi_slopes = _get_alibi_slopes(num_heads).to(device)
+
+    MAX_SEQ_LEN = 1024
+    MAX_CTX_LEN = 1024
+    BS = 10
+    cache_size = 640
+    block_size = 32
+    max_block_per_request = 64
+    query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
+    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
+    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
+    num_kv_heads = num_heads // num_queries_per_kv
+
+    num_tokens = sum(query_lens)
+    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
+    query.uniform_(-1e-3, 1e-3)
+    output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
+
+    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
+    kv.uniform_(-1e-3, 1e-3)
+    key, value = kv.unbind(dim=1)
+
+    k_cache = torch.zeros(cache_size,
+                          block_size,
+                          num_kv_heads,
+                          head_size,
+                          dtype=dtype)
+    v_cache = torch.zeros(cache_size,
+                          block_size,
+                          num_kv_heads,
+                          head_size,
+                          dtype=dtype)
+    k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
+    v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
+    values = torch.arange(0, cache_size, dtype=torch.long)
+    values = values[torch.randperm(cache_size)]
+    block_table = values[:BS * max_block_per_request].view(
+        BS, max_block_per_request)
+    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
+    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
+    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
+                                            dtype=torch.long),
+                               dim=0)
+    max_input_len = MAX_SEQ_LEN
+    # copy kv to cache
+    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
+                                                dtype=torch.long),
+                                   dim=0)
+    for i in range(BS):
+        for j in range(query_lens[i]):
+            k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
+                                            j])
+            v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
+                                              b_ctx_len[i] + j])
+        cur_ctx = 0
+        block_id = 0
+        while cur_ctx < b_ctx_len[i]:
+            start_loc = b_seq_start_loc[i] + cur_ctx
+            if cur_ctx + block_size > b_ctx_len[i]:
+                end_loc = b_seq_start_loc[i] + b_ctx_len[i]
+            else:
+                end_loc = start_loc + block_size
+            start_slot = block_table[i, block_id] * block_size
+            end_slot = start_slot + end_loc - start_loc
+            k_cache.view(-1, num_kv_heads,
+                         head_size)[start_slot:end_slot].copy_(
+                             key[start_loc:end_loc])
+            v_cache.view(-1, num_kv_heads,
+                         head_size)[start_slot:end_slot].copy_(
+                             value[start_loc:end_loc])
+            cur_ctx += block_size
+            block_id += 1
+    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
+    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
+    k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
+                           8).permute(0, 2, 3, 1, 4).contiguous()
+    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
+    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
+    v_cache = v_cache.view(-1, block_size, num_kv_heads,
+                           head_size).permute(0, 2, 3, 1).contiguous()
+
+    # Warm up the Triton kernel by calling it once before actually measuring
+    # generation time
+    context_attention_fwd(query,
+                          k,
+                          v,
+                          output,
+                          k_cache,
+                          v_cache,
+                          block_table,
+                          b_start_loc,
+                          b_seq_len,
+                          b_ctx_len,
+                          max_input_len,
+                          alibi_slopes=alibi_slopes)
+    torch.cuda.synchronize()
+    start_time = time.time()
+    context_attention_fwd(query,
+                          k,
+                          v,
+                          output,
+                          k_cache,
+                          v_cache,
+                          block_table,
+                          b_start_loc,
+                          b_seq_len,
+                          b_ctx_len,
+                          max_input_len,
+                          alibi_slopes=alibi_slopes)
+    torch.cuda.synchronize()
+    end_time = time.time()
+    print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
+    scale = float(1.0 / (head_size**0.5))
+
+    # NOTE(DefTruth): In order to reuse _make_alibi_bias function,
+    # we have to pad query tensor before MQA/GQA expanding.
+    if query.shape[0] != key.shape[0]:
+        query_pad = torch.empty(sum(seq_lens),
+                                num_heads,
+                                head_size,
+                                dtype=dtype)
+        query_pad.uniform_(-1e-3, 1e-3)
+        seq_start = 0
+        query_start = 0
+        for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
+            seq_end = seq_start + seq_len
+            query_end = query_start + query_len
+            query_pad[seq_start:seq_end, ...] = torch.cat([
+                torch.zeros(
+                    seq_len - query_len, num_heads, head_size, dtype=dtype),
+                query[query_start:query_end, ...]
+            ],
+                                                          dim=0)
+            seq_start += seq_len
+            query_start += query_len
+        query = query_pad
+
+    if num_kv_heads != num_heads:
+        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
+        # project the key and value tensors to the desired number of
+        # heads.
+        #
+        # see also: vllm/model_executor/layers/attention.py
+        query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
+                           query.shape[-1])
+        key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
+                                        num_queries_per_kv, key.shape[-1])
+        value = value[:, :,
+                      None, :].expand(value.shape[0], num_kv_heads,
+                                      num_queries_per_kv, value.shape[-1])
+
+    query = query.unsqueeze(0)
+    key = key.unsqueeze(0)
+    value = value.unsqueeze(0)
+
+    attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
+    output_ref = torch.empty_like(output)
+    seq_start = 0
+    query_start = 0
+    start_time = time.time()
+    # Attention with alibi slopes.
+    # FIXME(DefTruth): Because xformers does not support dynamic sequence
+    # lengths with custom attention bias, we process each prompt one by
+    # one. This is inefficient, especially when we have many short prompts.
+    # modified from: vllm/attention/backends/xformers.py#L343
+    for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
+        seq_end = seq_start + seq_len
+        query_end = query_start + query_len
+        out = xops.memory_efficient_attention_forward(query[:,
+                                                            seq_start:seq_end],
+                                                      key[:,
+                                                          seq_start:seq_end],
+                                                      value[:,
+                                                            seq_start:seq_end],
+                                                      attn_bias=attn_bias[i],
+                                                      p=0.0,
+                                                      scale=scale)
+        out = out.view_as(query[:, seq_start:seq_end]).view(
+            seq_len, num_heads, head_size)
+        output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
+                                                         ...])
+        seq_start += seq_len
+        query_start += query_len
+    torch.cuda.synchronize()
+    end_time = time.time()
+    print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
+    assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py
index 79878b26c5294..997b25e887e30 100644
--- a/vllm/attention/ops/prefix_prefill.py
+++ b/vllm/attention/ops/prefix_prefill.py
@@ -472,7 +472,8 @@ def _fwd_kernel_alibi(
         stride_v_cache_bl,
         num_queries_per_kv: int,
         BLOCK_M: tl.constexpr,
-        BLOCK_DMODEL: tl.constexpr,
+        BLOCK_DMODEL: tl.constexpr,  # head size
+        BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
         BLOCK_N: tl.constexpr,
     ):
         # attn_bias[]
@@ -493,21 +494,24 @@ def _fwd_kernel_alibi(
 
         # initialize offsets
         offs_n = tl.arange(0, BLOCK_N)
-        offs_d = tl.arange(0, BLOCK_DMODEL)
+        offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
         offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
         off_q = (
             (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)
 
-        q = tl.load(
-            Q + off_q,
-            mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
-            other=0.0)
+        dim_mask = tl.where(
+            tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
+
+        q = tl.load(Q + off_q,
+                    mask=dim_mask[None, :] &
+                    (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
+                    other=0.0)
 
         # # initialize pointer to m and l
         m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
         l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
-        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
 
         alibi_slope = tl.load(Alibi_slopes + cur_head)
         alibi_start_q = tl.arange(
@@ -532,8 +536,9 @@ def _fwd_kernel_alibi(
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
             k = tl.load(K_cache + off_k,
-                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
-                        other=0.0)
+                        mask=dim_mask[:, None] &
+                        ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
+                        other=0.0)  # [D,N]
 
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
             qk += tl.dot(q, k)
@@ -567,7 +572,8 @@ def _fwd_kernel_alibi(
             acc = acc * acc_scale[:, None]
             # update acc
             v = tl.load(V_cache + off_v,
-                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
+                        mask=dim_mask[None, :] &
+                        ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
                         other=0.0)
 
             p = p.to(v.dtype)
@@ -600,8 +606,9 @@ def _fwd_kernel_alibi(
             # -- compute qk ----
             k = tl.load(k_ptrs +
                         (cur_batch_in_all_start_index + start_n) * stride_kbs,
-                        mask=(start_n + offs_n[None, :]) <
-                        cur_batch_seq_len - cur_batch_ctx_len,
+                        mask=dim_mask[:, None] &
+                        ((start_n + offs_n[None, :]) <
+                         cur_batch_seq_len - cur_batch_ctx_len),
                         other=0.0)
 
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -637,8 +644,9 @@ def _fwd_kernel_alibi(
             # update acc
             v = tl.load(v_ptrs +
                         (cur_batch_in_all_start_index + start_n) * stride_vbs,
-                        mask=(start_n + offs_n[:, None]) <
-                        cur_batch_seq_len - cur_batch_ctx_len,
+                        mask=dim_mask[None, :] &
+                        ((start_n + offs_n[:, None]) <
+                         cur_batch_seq_len - cur_batch_ctx_len),
                         other=0.0)
 
             p = p.to(v.dtype)
@@ -656,7 +664,8 @@ def _fwd_kernel_alibi(
         out_ptrs = Out + off_o
         tl.store(out_ptrs,
                  acc,
-                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
+                 mask=dim_mask[None, :] &
+                 (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))
         return
 
     @torch.inference_mode()
@@ -690,7 +699,6 @@ def context_attention_fwd(q,
 
         num_warps = 8 if Lk <= 64 else 8
         if alibi_slopes is not None:
-            assert Lk == Lk_padded
             _fwd_kernel_alibi[grid](
                 q,
                 k,
@@ -735,6 +743,7 @@ def context_attention_fwd(q,
                 num_queries_per_kv=num_queries_per_kv,
                 BLOCK_M=BLOCK,
                 BLOCK_DMODEL=Lk,
+                BLOCK_DMODEL_PADDED=Lk_padded,
                 BLOCK_N=BLOCK,
                 num_warps=num_warps,
                 num_stages=1,