From d32add049b2d8f1838fb9c7f863aed9b7d345892 Mon Sep 17 00:00:00 2001
From: Kunshang Ji
Date: Thu, 10 Aug 2023 01:02:15 +0000
Subject: [PATCH] add pos_encoding impl (#11)

* add pos_encoding impl

* add benchmark and add open mp parallel
---
 benchmarks/kernels/pos_encoding.py | 49 +++++++++++++++++++++
 csrc/pos_encoding.cpp              | 70 +++++++++++++++++++++++++++++-
 tests/kernels/test_pos_encoding.py |  4 +-
 3 files changed, 121 insertions(+), 2 deletions(-)
 create mode 100644 benchmarks/kernels/pos_encoding.py

diff --git a/benchmarks/kernels/pos_encoding.py b/benchmarks/kernels/pos_encoding.py
new file mode 100644
index 0000000000000..68bfb0993fdbe
--- /dev/null
+++ b/benchmarks/kernels/pos_encoding.py
@@ -0,0 +1,49 @@
+from threadpoolctl import threadpool_info
+from pprint import pprint
+
+import torch
+from benchmark import KernelBenchmark
+from vllm import pos_encoding_ops
+
+class PosEncodingBench(KernelBenchmark):
+    def __init__(
+        self,
+        loop_time,
+        num_tokens: int,
+        num_heads: int,
+        head_size: int,
+        max_position: int,
+        rotary_dim: int,
+        dtype: torch.dtype,
+        device: torch.device) -> None:
+        super().__init__(loop_time)
+        base: int = 10000
+        self.positions = torch.randint(0, max_position, (num_tokens, ), device=device)
+        query = torch.randn(num_tokens,
+                            num_heads * head_size,
+                            dtype=dtype,
+                            device=device)
+        key = torch.randn(num_tokens,
+                          num_heads * head_size,
+                          dtype=dtype,
+                          device=device)
+        # Create the rotary embedding.
+        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
+        t = torch.arange(max_position).float()
+        freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
+        cos = freqs.cos()
+        sin = freqs.sin()
+        self.head_size = head_size
+        self.cos_sin_cache = torch.cat((cos, sin), dim=-1)
+        self.cos_sin_cache = self.cos_sin_cache.to(dtype=dtype, device=device)
+        self.out_query = query.clone()
+        self.out_key = key.clone()
+
+    def _run(self):
+        for i in range(self.loop_time):
+            pos_encoding_ops.rotary_embedding_neox(self.positions, self.out_query, self.out_key, self.head_size, self.cos_sin_cache)
+
+bench = PosEncodingBench(10, num_tokens=4096, num_heads=5, head_size=128, max_position=8192, rotary_dim=128, dtype=torch.float32, device=torch.device("cpu"))
+bench.execute()
+
+pprint(threadpool_info())
diff --git a/csrc/pos_encoding.cpp b/csrc/pos_encoding.cpp
index 2c58b8f05cb4b..1078d253cd6e2 100644
--- a/csrc/pos_encoding.cpp
+++ b/csrc/pos_encoding.cpp
@@ -1,10 +1,78 @@
 #include <torch/extension.h>
 #include <ATen/ATen.h>
 
+template <typename scalar_t>
+void rotary_embedding_impl(const int64_t *__restrict__ positions,      // [num_tokens]
+                           scalar_t *__restrict__ query,               // [num_tokens, num_heads, head_size]
+                           scalar_t *__restrict__ key,                 // [num_tokens, num_kv_heads, head_size]
+                           const scalar_t *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
+                           const int rot_dim,
+                           const int stride,
+                           const int num_heads,
+                           const int num_kv_heads,
+                           const int head_size,
+                           const int num_tokens) {
+#pragma omp parallel for
+  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    int64_t pos = positions[token_idx];
+    const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;
+
+    const int embed_dim = rot_dim / 2;
+    const int nq = num_heads * embed_dim;
+    for (int i = 0; i < num_heads; ++i) {
+      const int head_idx = i;
+      const int token_head = token_idx * stride + head_idx * head_size;
+      for (int j = 0; j < embed_dim; ++j) {
+        const int rot_offset = j;
+        const int x_index = rot_offset;
+        const int y_index = embed_dim + rot_offset;
+
+        const int out_x = token_head + x_index;
+        const int out_y = token_head + y_index;
+
+        const scalar_t cos = *(cache_ptr + x_index);
+        const scalar_t sin = *(cache_ptr + y_index);
+
+        const scalar_t q_x = query[token_head + x_index];
+        const scalar_t q_y = query[token_head + y_index];
+        query[out_x] = q_x * cos - q_y * sin;
+        query[out_y] = q_y * cos + q_x * sin;
+
+        if (head_idx < num_kv_heads) {
+          const scalar_t k_x = key[token_head + x_index];
+          const scalar_t k_y = key[token_head + y_index];
+          key[out_x] = k_x * cos - k_y * sin;
+          key[out_y] = k_y * cos + k_x * sin;
+        }
+      }
+    }
+  }
+}
+
 void rotary_embedding_cpu(torch::Tensor &positions, torch::Tensor &query,
                           torch::Tensor &key, int head_size,
                           torch::Tensor &cos_sin_cache) {
-  TORCH_CHECK(false, "Unsupported rotary_embedding_neox on cpu.")
+  TORCH_CHECK(query.scalar_type() == c10::ScalarType::Float);
+
+  int num_tokens = query.size(0);
+  int rot_dim = cos_sin_cache.size(1);
+  int num_heads = query.size(1) / head_size;
+  int num_kv_heads = key.size(1) / head_size;
+  int stride = query.stride(0);
+  TORCH_CHECK(stride == key.stride(0));
+
+  AT_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding_impl", [&] {
+    rotary_embedding_impl<scalar_t>(positions.data_ptr<int64_t>(),
+                                    query.data_ptr<scalar_t>(),
+                                    key.data_ptr<scalar_t>(),
+                                    cos_sin_cache.data_ptr<scalar_t>(),
+                                    rot_dim,
+                                    stride,
+                                    num_heads,
+                                    num_kv_heads,
+                                    head_size,
+                                    num_tokens);
+  });
 }
diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py
index 0d255900d4c11..5bf2936deb639 100644
--- a/tests/kernels/test_pos_encoding.py
+++ b/tests/kernels/test_pos_encoding.py
@@ -6,6 +6,7 @@
 import torch.nn.functional as F
 
 from vllm import pos_encoding_ops
+import sys
 
 IS_NEOX_STYLE = [True, False]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -113,6 +114,7 @@ def test_rotary_embedding(
     head_size: int,
     rotary_dim: Optional[int],
     dtype: torch.dtype,
+    device: torch.device,
     seed: int,
     max_position: int = 8192,
     base: int = 10000,
@@ -140,7 +142,7 @@ def test_rotary_embedding(
     cos = freqs.cos()
     sin = freqs.sin()
     cos_sin_cache = torch.cat((cos, sin), dim=-1)
-    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')
+    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device=device)
 
     # Run the kernel. The kernel is in-place, so we need to clone the inputs.
     out_query = query.clone()
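
Note (not part of the patch): the new kernel applies the NeoX-style rotation
x' = x*cos - y*sin, y' = y*cos + x*sin to the two halves of each head's first
rot_dim elements, reading cos/sin from the flat [max_position, rot_dim] cache
built in the benchmark. The snippet below is a minimal CPU sanity check of that
formula against the compiled op. It assumes the extension is built for CPU and
reuses the rotary_embedding_neox signature and cache layout from the benchmark
above (with rotary_dim == head_size), so treat it as a sketch rather than part
of the change.

import torch
from vllm import pos_encoding_ops

num_tokens, num_heads, head_size, max_position, base = 16, 5, 128, 8192, 10000
rot_dim = head_size  # the benchmark uses rotary_dim == head_size
embed_dim = rot_dim // 2

positions = torch.randint(0, max_position, (num_tokens,))
query = torch.randn(num_tokens, num_heads * head_size)
key = torch.randn(num_tokens, num_heads * head_size)

# Same cache construction as benchmarks/kernels/pos_encoding.py:
# each row is [cos_0 .. cos_{embed_dim-1}, sin_0 .. sin_{embed_dim-1}].
inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2).float() / rot_dim))
freqs = torch.einsum('i,j -> ij', torch.arange(max_position).float(), inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)

# Kernel output (the op rotates query/key in place).
out_query, out_key = query.clone(), key.clone()
pos_encoding_ops.rotary_embedding_neox(positions, out_query, out_key,
                                       head_size, cos_sin_cache)

def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    # Pure-PyTorch restatement of the per-element loop in rotary_embedding_impl.
    x = x.view(num_tokens, num_heads, head_size)
    x1, x2 = x[..., :embed_dim], x[..., embed_dim:]
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over heads
    out = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return out.reshape(num_tokens, num_heads * head_size)

assert torch.allclose(out_query, rotate_neox(query), atol=1e-5)
assert torch.allclose(out_key, rotate_neox(key), atol=1e-5)

Because the token loop is parallelized with #pragma omp parallel for, the
benchmark's throughput can be varied by setting the standard OMP_NUM_THREADS
environment variable before running it; the threadpool_info() dump at the end
of the benchmark reports the thread pools actually in use.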