
Commit

more cleanup
Signed-off-by: Lucas Wilkinson <[email protected]>
LucasWilkinson committed Jan 30, 2025
1 parent 8cfa2f4 commit 506c932
Showing 5 changed files with 57 additions and 69 deletions.
@@ -1,3 +1,4 @@
+import pytest
 import torch
 
 from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
@@ -7,37 +8,48 @@ def cdiv(a, b):
     return (a + b - 1) // b
 
 
-def test_decode_attention(B, L, H_Q, H_KV, D, CACHE_SIZE, PAGE_SIZE):
+@pytest.mark.parametrize("B", [3, 5])
+@pytest.mark.parametrize("L", [1027, 1025])
+@pytest.mark.parametrize("H_Q", [32])
+@pytest.mark.parametrize("H_KV", [32, 8])
+@pytest.mark.parametrize("D_QK", [128, 192, 576])
+@pytest.mark.parametrize("D_V", [128, 512])
+@pytest.mark.parametrize("CACHE_SIZE", [16384])
+@pytest.mark.parametrize("PAGE_SIZE", [1, 16])
+def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
     assert CACHE_SIZE % PAGE_SIZE == 0
     dtype = torch.bfloat16
     seq_len = L  # This represents the number of tokens already in the sequence
-    sm_scale = 1.0 / (D**0.5)
+    sm_scale = 1.0 / (D_QK**0.5)
     num_kv_splits = 8
 
     num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
-    req_to_page = torch.randint(0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1), device="cuda")
+    req_to_page = torch.randint(0,
+                                CACHE_SIZE // PAGE_SIZE,
+                                (B, num_pages_per_batch, 1),
+                                device="cuda")
     req_to_token = req_to_page * PAGE_SIZE
     req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
-    req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(1, 1, -1)
+    req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(
+        1, 1, -1)
     req_to_token = req_to_token.view(B, -1)
     req_to_token = req_to_token[:, :seq_len].contiguous()
 
     # q represents the new token being generated, one per batch
-    q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
+    q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")
 
     # k_buffer and v_buffer represent all previous tokens
     # Page size is 1.
-    k_buffer = torch.randn(CACHE_SIZE, H_KV, D, dtype=dtype, device="cuda")
-    v_buffer = torch.randn(CACHE_SIZE, H_KV, D, dtype=dtype, device="cuda")
+    k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
+    v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")
 
     # o will have the same shape as q
-    o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
+    o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
 
     b_req_idx = torch.arange(B, device="cuda")
-    b_seq_len = torch.full((B,), seq_len, device="cuda")
+    b_seq_len = torch.full((B, ), seq_len, device="cuda")
 
     attn_logits = torch.empty(
-        (B, H_Q, num_kv_splits, D + 1),
+        (B, H_Q, num_kv_splits, D_V + 1),
         dtype=torch.float32,
         device="cuda",
     )
@@ -56,24 +68,15 @@ def test_decode_attention(B, L, H_Q, H_KV, D, CACHE_SIZE, PAGE_SIZE):
     )
 
     # Page size can be larger than 1.
-    k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D)
-    v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D)
+    k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
+    v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
 
     o1 = torch.zeros_like(o)
 
-    b_seq_len = torch.full((B,), seq_len, device="cuda")
-
-    attn_logits = torch.empty(
-        (B, H_Q, num_kv_splits, D + 1),
-        dtype=torch.float32,
-        device="cuda",
-    )
-
-    # Trick: Flatten the KV cache so that we use page_size = 1 inside the kernel.
     decode_attention_fwd(
         q,
-        k_buffer.flatten(0, 1),
-        v_buffer.flatten(0, 1),
+        k_buffer,
+        v_buffer,
         o1,
         req_to_page,
         b_seq_len,
@@ -82,14 +85,5 @@ def test_decode_attention(B, L, H_Q, H_KV, D, CACHE_SIZE, PAGE_SIZE):
         sm_scale,
         PAGE_SIZE,
     )
-    print(torch.allclose(o, o1))
-    assert torch.allclose(o, o1)
-
-
-if __name__ == "__main__":
-    # GQA
-    test_decode_attention(B=5, L=1027, H_Q=32, H_KV=8, D=128, CACHE_SIZE=16384, PAGE_SIZE=1)
-    test_decode_attention(B=5, L=1027, H_Q=32, H_KV=8, D=128, CACHE_SIZE=16384, PAGE_SIZE=16)
-    # MHA
-    test_decode_attention(B=3, L=1025, H_Q=32, H_KV=32, D=128, CACHE_SIZE=16384, PAGE_SIZE=1)
-    test_decode_attention(B=3, L=1025, H_Q=32, H_KV=32, D=128, CACHE_SIZE=16384, PAGE_SIZE=16)
+    assert torch.allclose(o, o1)
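For reference, the paged lookup this test builds maps each (batch, logical page) pair to a physical page id and then expands that to per-token cache slots. A minimal CPU-only sketch of the same arithmetic with toy sizes (the numbers here are illustrative, not the test's parameters):

import torch

B, PAGE_SIZE, CACHE_SIZE, seq_len = 2, 4, 32, 6
num_pages_per_batch = (seq_len + PAGE_SIZE - 1) // PAGE_SIZE  # cdiv

# One physical page id per (batch, logical page).
req_to_page = torch.randint(0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1))

# Token slot = page_id * PAGE_SIZE + offset within the page.
req_to_token = req_to_page * PAGE_SIZE + torch.arange(PAGE_SIZE).view(1, 1, -1)
req_to_token = req_to_token.view(B, -1)[:, :seq_len]  # trim to the true sequence length

print(req_to_token.shape)  # torch.Size([2, 6])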
20 changes: 10 additions & 10 deletions vllm/attention/backends/mla/utils.py
@@ -280,16 +280,16 @@ def _forward_decode(
     def forward(
         self,
         layer: AttentionLayer,
-        hidden_states_or_cq: torch.Tensor,  # query in unified attn
-        ckv_normed: torch.Tensor,  # key in unified attn
+        hidden_states_or_q_c: torch.Tensor,  # query in unified attn
+        k_c_normed: torch.Tensor,  # key in unified attn
         k_pe: torch.Tensor,  # value in unified attn
         kv_cache: torch.Tensor,
         attn_metadata: MLAMetadataCommon,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if output is not None:
             raise NotImplementedError(
-                "output is not yet supported for FlashInferMLAImpl")
+                "output is not yet supported for TritonMLAImpl")
 
         is_decode = attn_metadata.decode_metadata is not None
         is_prefill = attn_metadata.prefill_metadata is not None
@@ -302,14 +302,14 @@ def forward(
         k_pe = k_pe.unsqueeze(1)
 
         if is_decode:
-            q_nope = self._q_proj_and_k_up_proj(hidden_states_or_cq)
-            q_pe = torch.matmul(hidden_states_or_cq, self.W_QR)\
+            q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
+            q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
                 .view(-1, self.num_heads, self.qk_rope_head_dim)
             q_pe, k_pe = \
                 self.rotary_emb(attn_metadata.input_positions, q_pe, k_pe)
         else:
             assert is_prefill
-            q = self.q_proj(hidden_states_or_cq)[0]\
+            q = self.q_proj(hidden_states_or_q_c)[0]\
                 .view(-1, self.num_heads, self.qk_head_dim)
 
             # TODO(lucas): there must be a nicer way to write this line
@@ -321,7 +321,7 @@ def forward(
         # write the latent and rope to kv cache
         if kv_cache.numel() > 0:
             ops.concat_and_cache_mla(
-                ckv_normed,
+                k_c_normed,
                 k_pe.squeeze(1),
                 kv_cache,
                 attn_metadata.slot_mapping.flatten(),
@@ -330,7 +330,7 @@ def forward(
             )
 
         if attn_metadata.prefill_metadata is not None:
-            return self._forward_prefill(q, ckv_normed, k_pe, attn_metadata)
+            return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata)
 
         if attn_metadata.decode_metadata is not None:
             return self._forward_decode(q_nope, q_pe, kv_cache, attn_metadata)
@@ -339,13 +339,13 @@ def forward(
     def _forward_prefill_flash(
         self,
         q: torch.Tensor,
-        ckv_normed: torch.Tensor,
+        k_c_normed: torch.Tensor,
         k_pe: torch.Tensor,
         seq_start_loc: torch.Tensor,
         max_prefill_seq_len: int,
     ) -> torch.Tensor:
 
-        kv_nope = self.kv_b_proj(ckv_normed)[0]\
+        kv_nope = self.kv_b_proj(k_c_normed)[0]\
             .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
         k_nope, v = kv_nope\
             .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
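As a rough guide to what the renamed decode-path inputs hold: hidden_states_or_q_c carries the compressed query latent when q_lora_rank is set, k_c_normed the normalized compressed KV latent, and k_pe the decoupled rotary key. A hedged shape sketch with DeepSeek-V2-like sizes (the numbers and the stand-in W_QR are illustrative; the real values come from the model config and the MLA impl):

import torch

num_tokens, num_heads = 4, 16
q_lora_rank, kv_lora_rank = 1536, 512
qk_rope_head_dim = 64

# Compressed query latent, one row per decoded token.
q_c = torch.randn(num_tokens, q_lora_rank)

# The rope part of the query is projected out of the latent, one slice per head.
W_QR = torch.randn(q_lora_rank, num_heads * qk_rope_head_dim)
q_pe = (q_c @ W_QR).view(-1, num_heads, qk_rope_head_dim)

# Per-token latent and rotary key, the two pieces written to the MLA KV cache.
k_c_normed = torch.randn(num_tokens, kv_lora_rank)
k_pe = torch.randn(num_tokens, 1, qk_rope_head_dim)

print(q_pe.shape, k_c_normed.shape, k_pe.shape)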
4 changes: 2 additions & 2 deletions vllm/attention/backends/triton_mla.py
@@ -624,11 +624,11 @@ def __init__(
     def _forward_prefill(
         self,
         q: torch.Tensor,
-        ckv_normed: torch.Tensor,
+        kv_c_normed: torch.Tensor,
         k_pe: torch.Tensor,
         attn_metadata: TritonMLAMetadata,
     ) -> torch.Tensor:
-        return self._forward_prefill_flash(q, ckv_normed, k_pe,
+        return self._forward_prefill_flash(q, kv_c_normed, k_pe,
                                            attn_metadata.seq_start_loc,
                                            attn_metadata.max_prefill_seq_len)
 
28 changes: 11 additions & 17 deletions vllm/attention/ops/triton_decode_attention.py
@@ -184,7 +184,7 @@ def _decode_att_m_fwd(
     batch, head_num = q.shape[0], q.shape[1]
 
     grid = (batch, head_num, NUM_KV_SPLITS)
-    kv_group_num = q.shape[1] // k_buffer.shape[1]
+    kv_group_num = q.shape[1] // k_buffer.shape[-2]
 
     num_warps = 4 if kv_group_num == 1 else 2
 
@@ -202,10 +202,10 @@ def _decode_att_m_fwd(
         Req_to_tokens.stride(0),
         q.stride(0),
         q.stride(1),
-        k_buffer.stride(0),
-        k_buffer.stride(1),
-        v_buffer.stride(0),
-        v_buffer.stride(1),
+        k_buffer.stride(-2),
+        k_buffer.stride(-1),
+        v_buffer.stride(-2),
+        v_buffer.stride(-1),
         att_out.stride(0),
         att_out.stride(1),
         att_out.stride(2),
@@ -405,7 +405,7 @@ def _decode_grouped_att_m_fwd(
     BLOCK_DV = triton.next_power_of_2(Lv)
 
     batch, head_num = q.shape[0], q.shape[1]
-    kv_group_num = q.shape[1] // k_buffer.shape[1]
+    kv_group_num = q.shape[1] // k_buffer.shape[-2]
 
     BLOCK_H = 16
     NUM_KV_SPLITS = num_kv_splits
@@ -436,10 +436,10 @@ def _decode_grouped_att_m_fwd(
         Req_to_tokens.stride(0),
         q.stride(0),
         q.stride(1),
-        k_buffer.stride(0),
-        k_buffer.stride(1),
-        v_buffer.stride(0),
-        v_buffer.stride(1),
+        k_buffer.stride(-2),
+        k_buffer.stride(-1),
+        v_buffer.stride(-2),
+        v_buffer.stride(-1),
         att_out.stride(0),
         att_out.stride(1),
         att_out.stride(2),
@@ -633,13 +633,7 @@ def decode_attention_fwd(
     logit_cap=0.0,
 ):
     assert num_kv_splits == attn_logits.shape[2]
-    kv_group_num = q.shape[1] // v_buffer.shape[1]
-
-    if page_size > 1:
-        # Make the buffers look like page_size 1 since the original kernel only
-        # supported page size 1
-        k_buffer = k_buffer.flatten(0, 1)
-        v_buffer = v_buffer.flatten(0, 1)
+    kv_group_num = q.shape[1] // v_buffer.shape[-2]
 
     if kv_group_num == 1:
         # MHA
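The stride and shape changes above are what let the kernels accept the KV buffers either flattened to (num_tokens, H_KV, D) or paged as (num_pages, page_size, H_KV, D): indexing shapes and strides from the back picks out the head and feature dimensions in both layouts, so decode_attention_fwd no longer needs to flatten the buffers itself. A small illustration of that invariant (illustrative sizes, not kernel code):

import torch

H_KV, D, PAGE_SIZE, NUM_PAGES = 8, 128, 16, 4

flat = torch.empty(NUM_PAGES * PAGE_SIZE, H_KV, D)    # page_size-1 layout
paged = flat.view(NUM_PAGES, PAGE_SIZE, H_KV, D)      # paged layout

# The head count and the last two strides are the same either way,
# so shape[-2] and stride(-2)/stride(-1) work for both layouts.
assert flat.shape[-2] == paged.shape[-2] == H_KV
assert flat.stride(-2) == paged.stride(-2) == D
assert flat.stride(-1) == paged.stride(-1) == 1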
12 changes: 6 additions & 6 deletions vllm/model_executor/models/deepseek_v2.py
@@ -459,14 +459,14 @@ def forward(
     ) -> torch.Tensor:
         if self.q_lora_rank is not None:
             ckq = self.q_a_proj(hidden_states)[0]
-            hidden_states_or_ckq = self.q_a_layernorm(ckq)
+            hidden_states_or_q_c = self.q_a_layernorm(ckq)
         else:
-            hidden_states_or_ckq = hidden_states
-        ckv_nope, ck_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
+            hidden_states_or_q_c = hidden_states
+        kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
             [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        ckv_nope_normed = self.kv_a_layernorm(ckv_nope.contiguous())
-        return self.mla_attn(hidden_states_or_ckq, ckv_nope_normed, ck_pe,
-                             kv_cache, attn_metadata)
+        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+        return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
+                             attn_metadata)
 
 
 class DeepseekV2DecoderLayer(nn.Module):
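The new names in deepseek_v2.py follow the MLA notation: kv_c is the compressed KV latent and k_pe the decoupled rotary key, both produced by a single kv_a_proj_with_mqa projection and then split. A hedged sketch of that split with toy dimensions (the Linear here is a stand-in for the model's projection; sizes are illustrative):

import torch

num_tokens, hidden_size = 4, 5120
kv_lora_rank, qk_rope_head_dim = 512, 64

hidden_states = torch.randn(num_tokens, hidden_size)

# Stand-in for kv_a_proj_with_mqa: one projection yields both the latent and the rope key.
kv_a_proj = torch.nn.Linear(hidden_size, kv_lora_rank + qk_rope_head_dim, bias=False)

kv_c, k_pe = kv_a_proj(hidden_states).split([kv_lora_rank, qk_rope_head_dim], dim=-1)
print(kv_c.shape, k_pe.shape)  # torch.Size([4, 512]) torch.Size([4, 64])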
