From f6285313434977891b81241cc51de4766be779c6 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Fri, 7 Feb 2025 00:46:13 +0800 Subject: [PATCH] [Kernel] Make rotary_embedding ops more flexible with input shape (#12777) --- csrc/pos_encoding_kernels.cu | 103 +++++++++++++++++++--- tests/kernels/test_pos_encoding.py | 31 +++++-- vllm/attention/backends/mla/utils.py | 25 +----- vllm/model_executor/models/deepseek_v2.py | 13 +-- 4 files changed, 115 insertions(+), 57 deletions(-) diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 97184a8735593..c085d31a3e9b1 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -124,18 +124,54 @@ __global__ void batched_rotary_embedding_kernel( void rotary_embedding( torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or - // [num_tokens, num_heads * head_size] + // [num_tokens, num_heads * head_size] or + // [batch_size, seq_len, num_heads, head_size] or + // [num_tokens, num_heads, head_size] torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or - // [num_tokens, num_kv_heads * head_size] + // [num_tokens, num_kv_heads * head_size] or + // [batch_size, seq_len, num_heads, head_size] or + // [num_tokens, num_heads, head_size] int64_t head_size, torch::Tensor& cos_sin_cache, // [max_position, rot_dim] bool is_neox) { - int64_t num_tokens = query.numel() / query.size(-1); + // num_tokens = batch_size * seq_len + int64_t num_tokens = positions.numel(); + int positions_ndim = positions.dim(); + + // Make sure num_tokens dim is consistent across positions, query, and key. + TORCH_CHECK( + positions_ndim == 1 || positions_ndim == 2, + "positions must have shape [num_tokens] or [batch_size, seq_len]"); + if (positions_ndim == 1) { + TORCH_CHECK( + query.size(0) == positions.size(0) && key.size(0) == positions.size(0), + "query, key and positions must have the same number of tokens"); + } + if (positions_ndim == 2) { + TORCH_CHECK( + query.size(0) == positions.size(0) && + key.size(0) == positions.size(0) && + query.size(1) == positions.size(1) && + key.size(1) == positions.size(1), + "query, key and positions must have the same batch_size and seq_len"); + } + + // Make sure head_size is valid for query and key + // hidden_size = num_heads * head_size + int query_hidden_size = query.numel() / num_tokens; + int key_hidden_size = key.numel() / num_tokens; + TORCH_CHECK(query_hidden_size % head_size == 0); + TORCH_CHECK(key_hidden_size % head_size == 0); + + // Make sure query and key have consistent number of heads + int num_heads = query_hidden_size / head_size; + int num_kv_heads = key_hidden_size / head_size; + TORCH_CHECK(num_heads % num_kv_heads == 0); + int rot_dim = cos_sin_cache.size(1); - int num_heads = query.size(-1) / head_size; - int num_kv_heads = key.size(-1) / head_size; - int64_t query_stride = query.stride(-2); - int64_t key_stride = key.stride(-2); + int seq_dim_idx = positions_ndim - 1; + int64_t query_stride = query.stride(seq_dim_idx); + int64_t key_stride = key.stride(seq_dim_idx); dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); @@ -165,19 +201,58 @@ and process in batched manner. 
void batched_rotary_embedding(
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
-                          // [num_tokens, num_heads * head_size]
+                          // [num_tokens, num_heads * head_size] or
+                          // [batch_size, seq_len, num_heads, head_size] or
+                          // [num_tokens, num_heads, head_size]
    torch::Tensor& key,    // [batch_size, seq_len, num_kv_heads * head_size] or
-                          // [num_tokens, num_kv_heads * head_size]
+                          // [num_tokens, num_kv_heads * head_size] or
+                          // [batch_size, seq_len, num_heads, head_size] or
+                          // [num_tokens, num_heads, head_size]
    int64_t head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox, int64_t rot_dim,
-    torch::Tensor& cos_sin_cache_offsets  // [num_tokens]
+    torch::Tensor& cos_sin_cache_offsets  // [num_tokens] or [batch_size]
) {
+  // num_tokens = batch_size * seq_len
  int64_t num_tokens = cos_sin_cache_offsets.size(0);
-  int num_heads = query.size(-1) / head_size;
-  int num_kv_heads = key.size(-1) / head_size;
-  int64_t query_stride = query.stride(-2);
-  int64_t key_stride = key.stride(-2);
+  TORCH_CHECK(
+      positions.size(0) == num_tokens || positions.numel() == num_tokens,
+      "positions must have the same num_tokens or batch_size as "
+      "cos_sin_cache_offsets");
+
+  int positions_ndim = positions.dim();
+  // Make sure num_tokens dim is consistent across positions, query, and key.
+  TORCH_CHECK(
+      positions_ndim == 1 || positions_ndim == 2,
+      "positions must have shape [num_tokens] or [batch_size, seq_len]");
+  if (positions_ndim == 1) {
+    TORCH_CHECK(
+        query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
+        "query, key and positions must have the same number of tokens");
+  }
+  if (positions_ndim == 2) {
+    TORCH_CHECK(
+        query.size(0) == positions.size(0) &&
+            key.size(0) == positions.size(0) &&
+            query.size(1) == positions.size(1) &&
+            key.size(1) == positions.size(1),
+        "query, key and positions must have the same batch_size and seq_len");
+  }
+
+  // Make sure head_size is valid for query and key
+  int query_hidden_size = query.numel() / num_tokens;
+  int key_hidden_size = key.numel() / num_tokens;
+  TORCH_CHECK(query_hidden_size % head_size == 0);
+  TORCH_CHECK(key_hidden_size % head_size == 0);
+
+  // Make sure query and key have consistent number of heads
+  int num_heads = query_hidden_size / head_size;
+  int num_kv_heads = key_hidden_size / head_size;
+  TORCH_CHECK(num_heads % num_kv_heads == 0);
+
+  int seq_dim_idx = positions_ndim - 1;
+  int64_t query_stride = query.stride(seq_dim_idx);
+  int64_t key_stride = key.stride(seq_dim_idx);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * rot_dim / 2, 512));
diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py
index 5b7b0fda2be6a..af9bfd2f0f521 100644
--- a/tests/kernels/test_pos_encoding.py
+++ b/tests/kernels/test_pos_encoding.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 from itertools import accumulate, product
-from typing import Dict, List, Optional
+from typing import Callable, Dict, List, Optional

 import pytest
 import torch
@@ -24,7 +24,21 @@
 ]


+def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
+                           head_size: int) -> tuple[int, ...]:
+    return (batch_size, seq_len, num_heads * head_size)
+
+
+def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
+                            head_size: int) -> tuple[int, ...]:
+    return (batch_size, seq_len, num_heads, head_size)
+
+
+TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
+
+
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) +@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("seq_len", SEQ_LENS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -36,6 +50,7 @@ @torch.inference_mode() def test_rotary_embedding( is_neox_style: bool, + tensor_shape_fn: Callable[[int, int, int, int], tuple[int]], batch_size: int, seq_len: int, num_heads: int, @@ -58,10 +73,8 @@ def test_rotary_embedding( rope = rope.to(dtype=dtype) positions = torch.randint(0, max_position, (batch_size, seq_len)) - query = torch.randn(batch_size, - seq_len, - num_heads * head_size, - dtype=dtype) + query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size) + query = torch.randn(query_shape, dtype=dtype) key = torch.randn_like(query) # NOTE(woosuk): The reference implementation should be executed first @@ -80,6 +93,7 @@ def test_rotary_embedding( @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) +@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("seq_len", SEQ_LENS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -91,6 +105,7 @@ def test_rotary_embedding( @torch.inference_mode() def test_batched_rotary_embedding( is_neox_style: bool, + tensor_shape_fn: Callable[[int, int, int, int], tuple[int]], batch_size: int, seq_len: int, num_heads: int, @@ -113,10 +128,8 @@ def test_batched_rotary_embedding( rope = rope.to(dtype=dtype) positions = torch.randint(0, max_position, (batch_size, seq_len)) - query = torch.randn(batch_size, - seq_len, - num_heads * head_size, - dtype=dtype) + query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size) + query = torch.randn(query_shape, dtype=dtype) key = torch.randn_like(query) # NOTE(woosuk): The reference implementation should be executed first diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index e1285d1fad3c1..c22f7e92103b8 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -424,24 +424,6 @@ def _forward_decode( ) -> torch.Tensor: raise NotImplementedError - def apply_pure_rope( - self, - input_positions: torch.Tensor, - q_pe: torch.Tensor, - k_pe: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - seq_len = input_positions.size(0) - ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape - - q_pe, k_pe = self.rotary_emb( - input_positions, - q_pe.reshape(seq_len, -1), - k_pe.reshape(seq_len, -1), - ) - q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape) - - return q_pe, k_pe - def forward( self, layer: AttentionLayer, @@ -466,14 +448,13 @@ def forward( # Restore head dim (for rotary embedding) k_pe = k_pe.unsqueeze(1) assert hasattr(attn_metadata, "input_positions") - rope_fn = (self.rotary_emb - if self.use_yarn_rope else self.apply_pure_rope) if is_decode: q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c) q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\ .view(-1, self.num_heads, self.qk_rope_head_dim) - q_pe, k_pe = rope_fn(attn_metadata.input_positions, q_pe, k_pe) + q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe, + k_pe) else: assert is_prefill q = self.q_proj(hidden_states_or_q_c)[0]\ @@ -481,7 +462,7 @@ def forward( # TODO(lucas): there must be a nicer way to write this line q[..., self.qk_nope_head_dim:], k_pe = \ - rope_fn( + self.rotary_emb( attn_metadata.input_positions, q[..., 
self.qk_nope_head_dim:], k_pe) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 773f5abe71dae..0c6f07ce7b112 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -257,9 +257,7 @@ def __init__( prefix=f"{prefix}.o_proj") if rope_scaling: rope_scaling["rope_type"] = 'deepseek_yarn' - self.use_normal_rope = False - else: - self.use_normal_rope = True + self.rotary_emb = get_rope(qk_rope_head_dim, rotary_dim=qk_rope_head_dim, max_position=max_position_embeddings, @@ -309,17 +307,8 @@ def forward( k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k_pe = latent_cache[:, :, self.kv_lora_rank:] - if self.use_normal_rope: - seq_len = positions.size(0) - ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape - q_pe = q_pe.reshape(seq_len, -1) - k_pe = k_pe.reshape(seq_len, -1) - q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - if self.use_normal_rope: - q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape) - q[..., self.qk_nope_head_dim:] = q_pe k = torch.empty_like(q) k[..., :self.qk_nope_head_dim] = k_nope
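
Usage note (editor's sketch, not part of the patch): the kernel now derives num_heads from query.numel() / num_tokens / head_size instead of query.size(-1), so query and key may be passed either flattened as [..., num_heads * head_size] or with an explicit head dimension as [..., num_heads, head_size]. That is what lets the apply_pure_rope wrapper and the use_normal_rope reshape path above be removed. The snippet below follows the pattern of tests/kernels/test_pos_encoding.py in this patch; the concrete sizes, dtype, and CUDA device placement are illustrative assumptions, not values taken from the patch.

# Editor's sketch with assumed sizes; mirrors tests/kernels/test_pos_encoding.py.
import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

head_size, rotary_dim, max_position, base = 64, 64, 8192, 10000
num_heads, batch_size, seq_len = 8, 2, 16

rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style=True)
rope = rope.to(dtype=torch.float16, device="cuda")

positions = torch.randint(0, max_position, (batch_size, seq_len),
                          device="cuda")

# Flattened layout: the only shape the custom op accepted before this patch.
q = torch.randn(batch_size, seq_len, num_heads * head_size,
                dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
q_out, k_out = rope(positions, q, k)

# Layout with an explicit head dimension: now accepted directly, so callers
# such as the MLA/DeepSeek-V2 paths no longer reshape q_pe/k_pe around
# rotary_emb.
q = torch.randn(batch_size, seq_len, num_heads, head_size,
                dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
q_out, k_out = rope(positions, q, k)

Both layouts reach the same kernel because the per-token stride is now taken from the sequence dimension (positions_ndim - 1) rather than from the second-to-last dimension of query/key.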