[Bugfix][Kernel] allow non-power-of-2 for prefix prefill with alibi #4573

Merged · 37 commits · May 8, 2024

Commits (37)
3edfc58
[Bugfix][Kernel] allow non-power-of-two head sizes in prefix prefill …
DefTruth Apr 30, 2024
2a8f222
remove un-need comments
DefTruth Apr 30, 2024
7c5ea9b
Update context_attention_fwd
DefTruth Apr 30, 2024
9a6b936
remove blanks
DefTruth May 1, 2024
caf2bdb
Merge branch 'vllm-project:main' into prefill-pow-2-alibi
DefTruth May 2, 2024
34f9fd8
Merge branch 'vllm-project:main' into prefill-pow-2-alibi
DefTruth May 3, 2024
2a536a4
format code
DefTruth May 3, 2024
62887d1
Merge branch 'vllm-project:main' into prefill-pow-2-alibi
DefTruth May 4, 2024
c62409c
add prefix prefill alibi test
DefTruth May 4, 2024
e6bdcaf
add prefix prefill alibi test
DefTruth May 4, 2024
eeafc24
Merge branch 'vllm-project:main' into prefill-pow-2-alibi
DefTruth May 5, 2024
6bb8823
add comments
DefTruth May 5, 2024
cf78a94
add comments
DefTruth May 5, 2024
6bcfb75
update prefix prefill alibi tests
DefTruth May 5, 2024
f00be26
move alibi test into prefix prefill tests
DefTruth May 5, 2024
590a6c7
move alibi test into prefix prefill tests
DefTruth May 5, 2024
1eccf6e
update prefix prefill tests
DefTruth May 5, 2024
15f3bf0
update prefix prefill tests
DefTruth May 5, 2024
f91e3b0
update prefix prefill tests
DefTruth May 5, 2024
2e89999
change random seed and fix CI
DefTruth May 6, 2024
ef7348a
change random seed and fix CI
DefTruth May 6, 2024
e54492f
Merge branch 'vllm-project:main' into prefill-pow-2-alibi
DefTruth May 6, 2024
d7242d0
update tests
DefTruth May 7, 2024
ee22f7a
Merge branch 'vllm-project:main' into prefill-pow-2-alibi
DefTruth May 7, 2024
d0259d8
update tests
DefTruth May 7, 2024
de405b9
Merge branch 'prefill-pow-2-alibi' of github.com:DefTruth/vllm into p…
DefTruth May 7, 2024
02d4ebe
update tests
DefTruth May 7, 2024
51b86e9
fix k load and update tests
DefTruth May 7, 2024
5b84d8c
fix k load and update tests
DefTruth May 7, 2024
a61a58b
merged alibi test into prefix prefill test script
DefTruth May 7, 2024
8705f30
merged alibi test into prefix prefill test script
DefTruth May 7, 2024
319ca30
add more tests for small head sizes
DefTruth May 8, 2024
b8d77e1
add more tests for small head sizes
DefTruth May 8, 2024
7aa82b1
Merge branch 'vllm-project:main' into prefill-pow-2-alibi
DefTruth May 8, 2024
b3c3876
Invoke CI
DefTruth May 8, 2024
866e35b
Merge branch 'prefill-pow-2-alibi' of github.com:DefTruth/vllm into p…
DefTruth May 8, 2024
9691423
update tests for small head sizes
DefTruth May 8, 2024
243 changes: 242 additions & 1 deletion tests/kernels/test_prefix_prefill.py
@@ -1,3 +1,4 @@
import math
import random
import time

@@ -6,11 +7,12 @@
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask

from vllm.attention.backends.xformers import _make_alibi_bias
from vllm.attention.ops.prefix_prefill import context_attention_fwd

NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 8, 64]
-HEAD_SIZES = [128, 96]
+HEAD_SIZES = [128, 96, 24]
DTYPES = [torch.float16]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
@@ -207,3 +209,242 @@ def test_contexted_kv_attention(
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
output_ref = output_ref.reshape(output.shape)
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_contexted_kv_attention_alibi(
num_heads: int,
num_queries_per_kv: int,
head_size: int,
dtype: torch.dtype,
device: str,
) -> None:
random.seed(0)
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed(0)
torch.set_default_device(device)

# Need this, otherwise when we capture the graph the process
# for GPU 1 would run on both GPU0 and GPU1 and things would hang
#
# see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
torch.cuda.set_device(device)

def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
# Fork from: vllm/vllm/model_executor/models/bloom.py#L44
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32,
)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)

if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2,
total_num_heads - closest_power_of_2)
extra_powers = torch.arange(start=1,
end=1 + 2 * num_remaining_heads,
step=2,
dtype=torch.int32)
slopes = torch.cat(
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes

alibi_slopes = _get_alibi_slopes(num_heads).to(device)

MAX_SEQ_LEN = 1024
MAX_CTX_LEN = 1024
BS = 10
cache_size = 640
block_size = 32
max_block_per_request = 64
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
num_kv_heads = num_heads // num_queries_per_kv

num_tokens = sum(query_lens)
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
query.uniform_(-1e-3, 1e-3)
output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)

kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
kv.uniform_(-1e-3, 1e-3)
key, value = kv.unbind(dim=1)

k_cache = torch.zeros(cache_size,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
v_cache = torch.zeros(cache_size,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
values = torch.arange(0, cache_size, dtype=torch.long)
values = values[torch.randperm(cache_size)]
block_table = values[:BS * max_block_per_request].view(
BS, max_block_per_request)
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
dtype=torch.long),
dim=0)
max_input_len = MAX_SEQ_LEN
# copy kv to cache
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
dtype=torch.long),
dim=0)
for i in range(BS):
for j in range(query_lens[i]):
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
j])
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
b_ctx_len[i] + j])
cur_ctx = 0
block_id = 0
while cur_ctx < b_ctx_len[i]:
start_loc = b_seq_start_loc[i] + cur_ctx
if cur_ctx + block_size > b_ctx_len[i]:
end_loc = b_seq_start_loc[i] + b_ctx_len[i]
else:
end_loc = start_loc + block_size
start_slot = block_table[i, block_id] * block_size
end_slot = start_slot + end_loc - start_loc
k_cache.view(-1, num_kv_heads,
head_size)[start_slot:end_slot].copy_(
key[start_loc:end_loc])
v_cache.view(-1, num_kv_heads,
head_size)[start_slot:end_slot].copy_(
value[start_loc:end_loc])
cur_ctx += block_size
block_id += 1
# transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
# to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
8).permute(0, 2, 3, 1, 4).contiguous()
# transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache = v_cache.view(-1, block_size, num_kv_heads,
head_size).permute(0, 2, 3, 1).contiguous()

# Warm up the Triton kernel by calling it once before actually measuring
# generation time
context_attention_fwd(query,
k,
v,
output,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
b_ctx_len,
max_input_len,
alibi_slopes=alibi_slopes)
torch.cuda.synchronize()
start_time = time.time()
context_attention_fwd(query,
k,
v,
output,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
b_ctx_len,
max_input_len,
alibi_slopes=alibi_slopes)
torch.cuda.synchronize()
end_time = time.time()
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
scale = float(1.0 / (head_size**0.5))

# NOTE(DefTruth): In order to reuse _make_alibi_bias function,
# we have to pad query tensor before MQA/GQA expanding.
if query.shape[0] != key.shape[0]:
query_pad = torch.empty(sum(seq_lens),
num_heads,
head_size,
dtype=dtype)
query_pad.uniform_(-1e-3, 1e-3)
seq_start = 0
query_start = 0
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len
query_end = query_start + query_len
query_pad[seq_start:seq_end, ...] = torch.cat([
torch.zeros(
seq_len - query_len, num_heads, head_size, dtype=dtype),
query[query_start:query_end, ...]
],
dim=0)
seq_start += seq_len
query_start += query_len
query = query_pad

if num_kv_heads != num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
query.shape[-1])
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], num_kv_heads,
num_queries_per_kv, value.shape[-1])

query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)

attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
output_ref = torch.empty_like(output)
seq_start = 0
query_start = 0
start_time = time.time()
# Attention with alibi slopes.
# FIXME(DefTruth): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
# modified from: vllm/attention/backends/xformers.py#L343
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len
query_end = query_start + query_len
out = xops.memory_efficient_attention_forward(query[:,
seq_start:seq_end],
key[:,
seq_start:seq_end],
value[:,
seq_start:seq_end],
attn_bias=attn_bias[i],
p=0.0,
scale=scale)
out = out.view_as(query[:, seq_start:seq_end]).view(
seq_len, num_heads, head_size)
output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
...])
seq_start += seq_len
query_start += query_len
torch.cuda.synchronize()
end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
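
A note on the new test: `_get_alibi_slopes` (adapted from vllm's `bloom.py`) only takes its `extra_base` branch when the number of heads is not a power of two; with `NUM_HEADS = [64]` the test never exercises that branch. The standalone sketch below (assuming only PyTorch; the helper body mirrors the test code above, and the 12-head value is purely illustrative) shows what the helper computes in that case.

# Standalone sketch (assumes only PyTorch). The helper body mirrors
# _get_alibi_slopes in the test / bloom.py; 12 heads is an illustrative,
# non-power-of-2 head count that exercises the extra_base branch.
import math

import torch


def get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
                        dtype=torch.float32)
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)
    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32)
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2,
                                    dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)])
    return slopes


slopes = get_alibi_slopes(12)
# 8 "regular" slopes (powers of 1/2) plus 4 interpolated slopes from extra_base.
assert slopes.shape == (12,)
assert torch.allclose(slopes[:8], torch.tensor([0.5**i for i in range(1, 9)]))
assert torch.allclose(slopes[8:],
                      torch.tensor([2**-0.5, 2**-1.5, 2**-2.5, 2**-3.5]))
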
41 changes: 25 additions & 16 deletions vllm/attention/ops/prefix_prefill.py
@@ -472,7 +472,8 @@ def _fwd_kernel_alibi(
stride_v_cache_bl,
num_queries_per_kv: int,
BLOCK_M: tl.constexpr,
-BLOCK_DMODEL: tl.constexpr,
+BLOCK_DMODEL: tl.constexpr,  # head size
+BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
BLOCK_N: tl.constexpr,
):
# attn_bias[]
@@ -493,21 +494,24 @@

# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
-offs_d = tl.arange(0, BLOCK_DMODEL)
+offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)

-q = tl.load(
-Q + off_q,
-mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
-other=0.0)
+dim_mask = tl.where(
+tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
+
+q = tl.load(Q + off_q,
+mask=dim_mask[None, :] &
+(offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
+other=0.0)

# # initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
-acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)

alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange(
@@ -532,8 +536,9 @@
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
-mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
-other=0.0)
+mask=dim_mask[:, None] &
+((start_n + offs_n[None, :]) < cur_batch_ctx_len),
+other=0.0)  # [D,N]

qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
@@ -567,7 +572,8 @@
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
-mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
+mask=dim_mask[None, :] &
+((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0)

p = p.to(v.dtype)
@@ -600,8 +606,9 @@
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
-mask=(start_n + offs_n[None, :]) <
-cur_batch_seq_len - cur_batch_ctx_len,
+mask=dim_mask[:, None] &
+((start_n + offs_n[None, :]) <
+cur_batch_seq_len - cur_batch_ctx_len),
other=0.0)

qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -637,8 +644,9 @@
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
-mask=(start_n + offs_n[:, None]) <
-cur_batch_seq_len - cur_batch_ctx_len,
+mask=dim_mask[None, :] &
+((start_n + offs_n[:, None]) <
+cur_batch_seq_len - cur_batch_ctx_len),
other=0.0)

p = p.to(v.dtype)
@@ -656,7 +664,8 @@
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
-mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
+mask=dim_mask[None, :] &
+(offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))
return

@torch.inference_mode()
@@ -690,7 +699,6 @@ def context_attention_fwd(q,

num_warps = 8 if Lk <= 64 else 8
if alibi_slopes is not None:
-assert Lk == Lk_padded
_fwd_kernel_alibi[grid](
q,
k,
@@ -735,6 +743,7 @@
num_queries_per_kv=num_queries_per_kv,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
+BLOCK_DMODEL_PADDED=Lk_padded,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
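
Taken together, the kernel change is a masked-padding trick: the head dimension is padded up to a power of two (`BLOCK_DMODEL_PADDED`), `dim_mask` zeroes the padded lanes on every `tl.load` and on the final `tl.store`, and the old `assert Lk == Lk_padded` guard is removed. The rough PyTorch sketch below (illustrative only; the real kernel works per tile in Triton, and `Lk_padded` is presumably the next power of two at or above the head size) shows why zero-padded lanes leave the attention math unchanged.

# Illustrative sketch, not the kernel itself: pad the head dimension to the
# next power of two and zero the padded lanes, mirroring the
# tl.load(..., mask=dim_mask[...] & ..., other=0.0) pattern in _fwd_kernel_alibi.
import torch


def next_power_of_2(n: int) -> int:
    return 1 << (n - 1).bit_length()


head_size = 24                                  # non-power-of-2, as in HEAD_SIZES
head_size_padded = next_power_of_2(head_size)   # 32
dim_mask = torch.arange(head_size_padded) < head_size  # True on the real lanes

q = torch.randn(16, head_size)
k = torch.randn(16, head_size)

q_pad = torch.zeros(16, head_size_padded)
k_pad = torch.zeros(16, head_size_padded)
q_pad[:, dim_mask] = q
k_pad[:, dim_mask] = k

# Zero lanes contribute nothing to the dot products, so the attention logits
# (and hence the softmax weights and outputs) are unchanged.
assert torch.allclose(q @ k.T, q_pad @ k_pad.T, atol=1e-5)

Because the same `dim_mask` also guards the store of `acc`, the padded lanes never reach the output tensor.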