add pos_encoding impl (vllm-project#11)
* add pos_encoding impl

* add benchmark and add OpenMP parallelism
jikunshang authored and bigPYJ1151 committed Sep 12, 2023
1 parent 06d9a3e commit d32add0
Showing 3 changed files with 121 additions and 2 deletions.
49 changes: 49 additions & 0 deletions benchmarks/kernels/pos_encoding.py
@@ -0,0 +1,49 @@
from threadpoolctl import threadpool_info
from pprint import pprint

import torch
from benchmark import KernelBenchmark
from vllm import pos_encoding_ops

class PosEncodingBench(KernelBenchmark):
    def __init__(
            self,
            loop_time,
            num_tokens: int,
            num_heads: int,
            head_size: int,
            max_position: int,
            rotary_dim: int,
            dtype: torch.dtype,
            device: torch.device) -> None:
        super().__init__(loop_time)
        base: int = 10000
        self.positions = torch.randint(0, max_position, (num_tokens, ), device=device)
        query = torch.randn(num_tokens,
                            num_heads * head_size,
                            dtype=dtype,
                            device=device)
        key = torch.randn(num_tokens,
                          num_heads * head_size,
                          dtype=dtype,
                          device=device)
        # Create the rotary embedding.
        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
        t = torch.arange(max_position).float()
        freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
        cos = freqs.cos()
        sin = freqs.sin()
        self.head_size = head_size
        self.cos_sin_cache = torch.cat((cos, sin), dim=-1)
        self.cos_sin_cache = self.cos_sin_cache.to(dtype=dtype, device=device)
        self.out_query = query.clone()
        self.out_key = key.clone()

    def _run(self):
        for i in range(self.loop_time):
            pos_encoding_ops.rotary_embedding_neox(self.positions, self.out_query, self.out_key, self.head_size, self.cos_sin_cache)

bench = PosEncodingBench(10, num_tokens=4096, num_heads=5, head_size=128, max_position=8192, rotary_dim=128, dtype=torch.float32, device=torch.device("cpu"))
bench.execute()

pprint(threadpool_info())
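
For reference, the rotation this benchmark exercises can be reproduced in pure PyTorch. The sketch below is not part of the commit; ref_rotary_neox is a hypothetical helper that assumes the NeoX-style layout built above (cos values in the first half of each cos_sin_cache row, sin values in the second, and each head rotated over its first rot_dim elements), and it can be used to sanity-check out_query/out_key.

import torch

def ref_rotary_neox(positions, query, key, head_size, cos_sin_cache):
    # Hypothetical pure-PyTorch reference, not part of this commit.
    num_tokens = positions.shape[0]
    rot_dim = cos_sin_cache.shape[1]
    embed_dim = rot_dim // 2
    # Gather the cos/sin rows for each token and broadcast over heads.
    cos = cos_sin_cache[positions, :embed_dim].unsqueeze(1)  # [num_tokens, 1, embed_dim]
    sin = cos_sin_cache[positions, embed_dim:].unsqueeze(1)
    q = query.view(num_tokens, -1, head_size).clone()
    k = key.view(num_tokens, -1, head_size).clone()
    for t in (q, k):
        x = t[..., :embed_dim].clone()
        y = t[..., embed_dim:rot_dim].clone()
        t[..., :embed_dim] = x * cos - y * sin
        t[..., embed_dim:rot_dim] = y * cos + x * sin
    return q.view(num_tokens, -1), k.view(num_tokens, -1)

Comparing its output against bench.out_query / bench.out_key with torch.allclose is only meaningful when loop_time=1, since the benchmark applies the in-place op on every iteration.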
70 changes: 69 additions & 1 deletion csrc/pos_encoding.cpp
@@ -1,10 +1,78 @@
#include <c10/util/Exception.h>
#include <torch/extension.h>

template <typename scalar_t>
void rotary_embedding_impl(const int64_t *__restrict__ positions,      // [num_tokens]
                           scalar_t *__restrict__ query,               // [num_tokens, num_heads, head_size]
                           scalar_t *__restrict__ key,                 // [num_tokens, num_kv_heads, head_size]
                           const scalar_t *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
                           const int rot_dim,
                           const int stride,
                           const int num_heads,
                           const int num_kv_heads,
                           const int head_size,
                           const int num_tokens) {
#pragma omp parallel for
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    int64_t pos = positions[token_idx];
    const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;

    const int embed_dim = rot_dim / 2;
    const int nq = num_heads * embed_dim;
    for (int i = 0; i < num_heads; ++i) {
      const int head_idx = i;
      const int token_head = token_idx * stride + head_idx * head_size;
      for (int j = 0; j < embed_dim; ++j) {
        const int rot_offset = j;
        const int x_index = rot_offset;
        const int y_index = embed_dim + rot_offset;

        const int out_x = token_head + x_index;
        const int out_y = token_head + y_index;

        const scalar_t cos = *(cache_ptr + x_index);
        const scalar_t sin = *(cache_ptr + y_index);

        const scalar_t q_x = query[token_head + x_index];
        const scalar_t q_y = query[token_head + y_index];
        query[out_x] = q_x * cos - q_y * sin;
        query[out_y] = q_y * cos + q_x * sin;

        if (head_idx < num_kv_heads) {
          const scalar_t k_x = key[token_head + x_index];
          const scalar_t k_y = key[token_head + y_index];
          key[out_x] = k_x * cos - k_y * sin;
          key[out_y] = k_y * cos + k_x * sin;
        }
      }
    }
  }
}

void rotary_embedding_cpu(torch::Tensor &positions, torch::Tensor &query,
                          torch::Tensor &key, int head_size,
                          torch::Tensor &cos_sin_cache) {
  TORCH_CHECK(query.scalar_type() == c10::ScalarType::Float);

  int num_tokens = query.size(0);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(1) / head_size;
  int num_kv_heads = key.size(1) / head_size;
  int stride = query.stride(0);
  TORCH_CHECK(stride == key.stride(0));

  AT_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding_impl", [&] {
    rotary_embedding_impl(positions.data_ptr<int64_t>(),
                          query.data_ptr<scalar_t>(),
                          key.data_ptr<scalar_t>(),
                          cos_sin_cache.data_ptr<scalar_t>(),
                          rot_dim,
                          stride,
                          num_heads,
                          num_kv_heads,
                          head_size,
                          num_tokens);
  });
}
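
Because the kernel relies on #pragma omp parallel for over tokens, benchmark numbers depend on the OpenMP thread pool. The following is a small usage sketch for pinning and inspecting the pool before running the benchmark above; the thread count of 8 is only an illustrative value, not something this commit sets.

import os
os.environ.setdefault("OMP_NUM_THREADS", "8")  # read by the OpenMP runtime when it initializes

import torch
from pprint import pprint
from threadpoolctl import threadpool_info

torch.set_num_threads(8)   # ATen intra-op threads; typically backed by the same OpenMP runtime
pprint(threadpool_info())  # reports the active OpenMP library and its thread count

Setting OMP_NUM_THREADS before torch is imported avoids any ambiguity about the pool size the OpenMP runtime picks up.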


4 changes: 3 additions & 1 deletion tests/kernels/test_pos_encoding.py
@@ -6,6 +6,7 @@
import torch.nn.functional as F

from vllm import pos_encoding_ops
import sys

IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -113,6 +114,7 @@ def test_rotary_embedding(
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    device: torch.device,
    seed: int,
    max_position: int = 8192,
    base: int = 10000,
@@ -140,7 +142,7 @@
    cos = freqs.cos()
    sin = freqs.sin()
    cos_sin_cache = torch.cat((cos, sin), dim=-1)
    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')
    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device=device)

    # Run the kernel. The kernel is in-place, so we need to clone the inputs.
    out_query = query.clone()
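
To exercise the updated test, which now takes a device parameter, the following is a hedged invocation sketch; it assumes pytest is installed, the extension is built, and the test's device parametrization covers the device you care about.

import pytest

# Run only the positional-encoding kernel tests; -q keeps the output short.
raise SystemExit(pytest.main(["-q", "tests/kernels/test_pos_encoding.py"]))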
