[Kernel] Make rotary_embedding ops more flexible with input shape (vl…
Isotr0py authored and kerthcet committed Feb 21, 2025
1 parent b85c22c commit f628531
Showing 4 changed files with 115 additions and 57 deletions.
103 changes: 89 additions & 14 deletions csrc/pos_encoding_kernels.cu
@@ -124,18 +124,54 @@ __global__ void batched_rotary_embedding_kernel(
void rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
// [num_tokens, num_heads * head_size]
// [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size]
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_kv_heads, head_size] or
// [num_tokens, num_kv_heads, head_size]
int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) {
int64_t num_tokens = query.numel() / query.size(-1);
// num_tokens = batch_size * seq_len
int64_t num_tokens = positions.numel();
int positions_ndim = positions.dim();

// Make sure num_tokens dim is consistent across positions, query, and key.
TORCH_CHECK(
positions_ndim == 1 || positions_ndim == 2,
"positions must have shape [num_tokens] or [batch_size, seq_len]");
if (positions_ndim == 1) {
TORCH_CHECK(
query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
"query, key and positions must have the same number of tokens");
}
if (positions_ndim == 2) {
TORCH_CHECK(
query.size(0) == positions.size(0) &&
key.size(0) == positions.size(0) &&
query.size(1) == positions.size(1) &&
key.size(1) == positions.size(1),
"query, key and positions must have the same batch_size and seq_len");
}

// Make sure head_size is valid for query and key
// hidden_size = num_heads * head_size
int query_hidden_size = query.numel() / num_tokens;
int key_hidden_size = key.numel() / num_tokens;
TORCH_CHECK(query_hidden_size % head_size == 0);
TORCH_CHECK(key_hidden_size % head_size == 0);

// Make sure query and key have consistent number of heads
int num_heads = query_hidden_size / head_size;
int num_kv_heads = key_hidden_size / head_size;
TORCH_CHECK(num_heads % num_kv_heads == 0);

int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size;
int num_kv_heads = key.size(-1) / head_size;
int64_t query_stride = query.stride(-2);
int64_t key_stride = key.stride(-2);
int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.stride(seq_dim_idx);

dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
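For context on what the relaxed checks above allow, here is a minimal usage sketch. It assumes the op is exposed through vllm._custom_ops.rotary_embedding as in vLLM's Python bindings; shapes and values are illustrative only:

import torch
from vllm import _custom_ops as ops

num_tokens, num_heads, num_kv_heads, head_size = 8, 4, 2, 64
max_position, rot_dim = 1024, 64
positions = torch.randint(0, max_position, (num_tokens,), device="cuda")
cos_sin_cache = torch.randn(max_position, rot_dim, dtype=torch.float16, device="cuda")

# Previously required layout: heads flattened into the last dimension.
q_flat = torch.randn(num_tokens, num_heads * head_size, dtype=torch.float16, device="cuda")
k_flat = torch.randn(num_tokens, num_kv_heads * head_size, dtype=torch.float16, device="cuda")
ops.rotary_embedding(positions, q_flat, k_flat, head_size, cos_sin_cache, True)  # in-place

# Now also accepted: an explicit head dimension, so callers no longer need to
# reshape to 2-D before the op and back afterwards.
q = torch.randn(num_tokens, num_heads, head_size, dtype=torch.float16, device="cuda")
k = torch.randn(num_tokens, num_kv_heads, head_size, dtype=torch.float16, device="cuda")
ops.rotary_embedding(positions, q, k, head_size, cos_sin_cache, True)  # in-place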
@@ -165,19 +201,58 @@ and process in batched manner.
void batched_rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
// [num_tokens, num_heads * head_size]
// [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size]
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_kv_heads, head_size] or
// [num_tokens, num_kv_heads, head_size]
int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox, int64_t rot_dim,
torch::Tensor& cos_sin_cache_offsets // [num_tokens]
torch::Tensor& cos_sin_cache_offsets // [num_tokens] or [batch_size]
) {
// num_tokens = batch_size * seq_len
int64_t num_tokens = cos_sin_cache_offsets.size(0);
int num_heads = query.size(-1) / head_size;
int num_kv_heads = key.size(-1) / head_size;
int64_t query_stride = query.stride(-2);
int64_t key_stride = key.stride(-2);
TORCH_CHECK(
positions.size(0) == num_tokens || positions.numel() == num_tokens,
"positions must have the same num_tokens or batch_size as "
"cos_sin_cache_offsets");

int positions_ndim = positions.dim();
// Make sure num_tokens dim is consistent across positions, query, and key.
TORCH_CHECK(
positions_ndim == 1 || positions_ndim == 2,
"positions must have shape [num_tokens] or [batch_size, seq_len]");
if (positions_ndim == 1) {
TORCH_CHECK(
query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
"query, key and positions must have the same number of tokens");
}
if (positions_ndim == 2) {
TORCH_CHECK(
query.size(0) == positions.size(0) &&
key.size(0) == positions.size(0) &&
query.size(1) == positions.size(1) &&
key.size(1) == positions.size(1),
"query, key and positions must have the same batch_size and seq_len");
}

// Make sure head_size is valid for query and key
int query_hidden_size = query.numel() / num_tokens;
int key_hidden_size = key.numel() / num_tokens;
TORCH_CHECK(query_hidden_size % head_size == 0);
TORCH_CHECK(key_hidden_size % head_size == 0);

// Make sure query and key have consistent number of heads
int num_heads = query_hidden_size / head_size;
int num_kv_heads = key_hidden_size / head_size;
TORCH_CHECK(num_heads % num_kv_heads == 0);

int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.stride(seq_dim_idx);

dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
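A similar sketch for the batched variant, under the same assumption that the Python binding is vllm._custom_ops.batched_rotary_embedding with this argument order: cos_sin_cache_offsets supplies a per-token (or per-batch) offset into the cache, e.g. for per-sequence rope scaling, and the query/key layout may now include the head dimension:

import torch
from vllm import _custom_ops as ops

num_tokens, num_heads, num_kv_heads, head_size = 8, 4, 2, 64
max_position, rot_dim = 1024, 64
positions = torch.randint(0, max_position, (num_tokens,), device="cuda")
cos_sin_cache = torch.randn(max_position, rot_dim, dtype=torch.float16, device="cuda")
# One cache offset per token (all zeros here, i.e. no extra scaling applied).
offsets = torch.zeros(num_tokens, dtype=torch.long, device="cuda")

q = torch.randn(num_tokens, num_heads, head_size, dtype=torch.float16, device="cuda")
k = torch.randn(num_tokens, num_kv_heads, head_size, dtype=torch.float16, device="cuda")
ops.batched_rotary_embedding(positions, q, k, head_size, cos_sin_cache, True, rot_dim, offsets)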
31 changes: 22 additions & 9 deletions tests/kernels/test_pos_encoding.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

from itertools import accumulate, product
from typing import Dict, List, Optional
from typing import Callable, Dict, List, Optional

import pytest
import torch
@@ -24,7 +24,21 @@
]


def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
head_size: int) -> tuple[int, ...]:
return (batch_size, seq_len, num_heads * head_size)


def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
head_size: int) -> tuple[int, ...]:
return (batch_size, seq_len, num_heads, head_size)


TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@@ -36,6 +50,7 @@
@torch.inference_mode()
def test_rotary_embedding(
is_neox_style: bool,
tensor_shape_fn: Callable[[int, int, int, int], tuple[int, ...]],
batch_size: int,
seq_len: int,
num_heads: int,
@@ -58,10 +73,8 @@ def test_rotary_embedding(
rope = rope.to(dtype=dtype)

positions = torch.randint(0, max_position, (batch_size, seq_len))
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
dtype=dtype)
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
query = torch.randn(query_shape, dtype=dtype)
key = torch.randn_like(query)

# NOTE(woosuk): The reference implementation should be executed first
@@ -80,6 +93,7 @@ def test_rotary_embedding(


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@@ -91,6 +105,7 @@ def test_rotary_embedding(
@torch.inference_mode()
def test_batched_rotary_embedding(
is_neox_style: bool,
tensor_shape_fn: Callable[[int, int, int, int], tuple[int, ...]],
batch_size: int,
seq_len: int,
num_heads: int,
@@ -113,10 +128,8 @@ def test_batched_rotary_embedding(
rope = rope.to(dtype=dtype)

positions = torch.randint(0, max_position, (batch_size, seq_len))
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
dtype=dtype)
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
query = torch.randn(query_shape, dtype=dtype)
key = torch.randn_like(query)

# NOTE(woosuk): The reference implementation should be executed first
25 changes: 3 additions & 22 deletions vllm/attention/backends/mla/utils.py
@@ -424,24 +424,6 @@ def _forward_decode(
) -> torch.Tensor:
raise NotImplementedError

def apply_pure_rope(
self,
input_positions: torch.Tensor,
q_pe: torch.Tensor,
k_pe: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
seq_len = input_positions.size(0)
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape

q_pe, k_pe = self.rotary_emb(
input_positions,
q_pe.reshape(seq_len, -1),
k_pe.reshape(seq_len, -1),
)
q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)

return q_pe, k_pe

def forward(
self,
layer: AttentionLayer,
@@ -466,22 +448,21 @@ def forward(
# Restore head dim (for rotary embedding)
k_pe = k_pe.unsqueeze(1)
assert hasattr(attn_metadata, "input_positions")
rope_fn = (self.rotary_emb
if self.use_yarn_rope else self.apply_pure_rope)

if is_decode:
q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
.view(-1, self.num_heads, self.qk_rope_head_dim)
q_pe, k_pe = rope_fn(attn_metadata.input_positions, q_pe, k_pe)
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
k_pe)
else:
assert is_prefill
q = self.q_proj(hidden_states_or_q_c)[0]\
.view(-1, self.num_heads, self.qk_head_dim)

# TODO(lucas): there must be a nicer way to write this line
q[..., self.qk_nope_head_dim:], k_pe = \
rope_fn(
self.rotary_emb(
attn_metadata.input_positions,
q[..., self.qk_nope_head_dim:], k_pe)

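The net effect on callers (here and in deepseek_v2.py below) is that the reshape-to-2-D workaround around the rope call can be dropped. A schematic before/after, with hypothetical helper names standing in for the module code in the diff:

def apply_rope_old(rope, positions, q_pe, k_pe):
    # Old workaround: the kernel only accepted [num_tokens, hidden] inputs,
    # so callers flattened the head dim and restored the shape afterwards.
    num_tokens = positions.size(0)
    q_shape, k_shape = q_pe.shape, k_pe.shape
    q_pe, k_pe = rope(positions, q_pe.reshape(num_tokens, -1),
                      k_pe.reshape(num_tokens, -1))
    return q_pe.view(q_shape), k_pe.view(k_shape)

def apply_rope_new(rope, positions, q_pe, k_pe):
    # New: [num_tokens, num_heads, head_size] tensors are passed directly.
    return rope(positions, q_pe, k_pe)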
13 changes: 1 addition & 12 deletions vllm/model_executor/models/deepseek_v2.py
@@ -257,9 +257,7 @@ def __init__(
prefix=f"{prefix}.o_proj")
if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'
self.use_normal_rope = False
else:
self.use_normal_rope = True

self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
@@ -309,17 +307,8 @@ def forward(
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = latent_cache[:, :, self.kv_lora_rank:]

if self.use_normal_rope:
seq_len = positions.size(0)
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
q_pe = q_pe.reshape(seq_len, -1)
k_pe = k_pe.reshape(seq_len, -1)

q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

if self.use_normal_rope:
q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)

q[..., self.qk_nope_head_dim:] = q_pe
k = torch.empty_like(q)
k[..., :self.qk_nope_head_dim] = k_nope
