From d6e4a130b028f42a7f413d99eb91a4395fa7a04a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 26 Feb 2024 15:00:54 -0800 Subject: [PATCH] [Minor] Remove gather_cached_kv kernel (#3043) --- csrc/cache.h | 7 -- csrc/cache_kernels.cu | 161 ------------------------------------------ csrc/pybind.cpp | 4 -- 3 files changed, 172 deletions(-) diff --git a/csrc/cache.h b/csrc/cache.h index 21c71830f7942..765e231abd26f 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -23,13 +23,6 @@ void reshape_and_cache( torch::Tensor& slot_mapping, const std::string& kv_cache_dtype); -void gather_cached_kv( - torch::Tensor& key, - torch::Tensor& value, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - torch::Tensor& slot_mapping); - // Just for unittest void convert_fp8_e5m2( torch::Tensor& src_cache, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index ceb7347d94670..7254010b8e3a9 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -269,167 +269,6 @@ void reshape_and_cache( namespace vllm { -// Grid: (num_blocks, block_size). -template -__global__ void gather_cached_kv_kernel( - scalar_t* __restrict__ key, // [num_tokens, [stride], num_heads, head_size] - scalar_t* __restrict__ value, // [num_tokens, [stride], num_heads, head_size] - const scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] - const int* __restrict__ slot_mapping, // [num_tokens] - const int key_stride, - const int value_stride, - const int num_heads, - const int head_size, - const int block_size, - const int x) { - const int token_idx = blockIdx.x; - const int slot_idx = slot_mapping[token_idx]; - const int block_idx = slot_idx / block_size; - const int block_offset = slot_idx % block_size; - - const int num_tokens = num_heads * head_size; - for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) { - const int tgt_key_idx = token_idx * key_stride + i; - const int tgt_value_idx = token_idx * value_stride + i; - - const int head_idx = i / head_size; - const int head_offset = i % head_size; - const int x_idx = head_offset / x; // the offset of the [head_size/x] dimension - const int x_offset = head_offset % x; - - const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x - + head_idx * (head_size / x) * block_size * x - + x_idx * block_size * x - + block_offset * x - + x_offset; - const int src_value_idx = block_idx * num_heads * head_size * block_size - + head_idx * head_size * block_size - + head_offset * block_size - + block_offset; - - key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]); - value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]); - } -} - -template -__global__ void gather_cached_kv_kernel_optimized( - scalar_t *__restrict__ key, // [num_tokens, [stride], num_heads, head_size] - scalar_t *__restrict__ value, // [num_tokens, [stride], num_heads, head_size] - const scalar_t *__restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - const scalar_t *__restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] - const int *__restrict__ slot_mapping, // [num_tokens] - const int key_stride, - const int value_stride, - const int num_heads, - const int head_size, - const int block_size, - const int x) -{ - const int token_idx = blockIdx.x; - const int slot_idx = slot_mapping[token_idx]; - const int block_idx = slot_idx / block_size; - const int block_offset = slot_idx % block_size; - - const int dim = num_heads * head_size; - assert(dim % 4 == 0); // this is true for known use cases - const int unroll_factor = 4; - const int unrolled_dim = dim / unroll_factor; - - for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x) - { - int tgt_key_indices[unroll_factor]; - int tgt_value_indices[unroll_factor]; - int src_key_indices[unroll_factor]; - int src_value_indices[unroll_factor]; - scalar_t keys_to_store[unroll_factor]; - scalar_t values_to_store[unroll_factor]; - - #pragma unroll - for (int j = 0; j < unroll_factor; ++j) - { - int index = i + j * unrolled_dim; - - const int tgt_key_idx = token_idx * key_stride + index; - const int tgt_value_idx = token_idx * value_stride + index; - - const int head_idx = index / head_size; - const int head_offset = index % head_size; - const int x_idx = head_offset / x; - const int x_offset = head_offset % x; - - const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x - + head_idx * (head_size / x) * block_size * x - + x_idx * block_size * x - + block_offset * x - + x_offset; - const int src_value_idx = block_idx * num_heads * head_size * block_size - + head_idx * head_size * block_size - + head_offset * block_size - + block_offset; - - tgt_key_indices[j] = tgt_key_idx; - tgt_value_indices[j] = tgt_value_idx; - src_key_indices[j] = src_key_idx; - src_value_indices[j] = src_value_idx; - - keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]); - values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]); - } - - #pragma unroll - for (int j = 0; j < unroll_factor; ++j) - { - key[tgt_key_indices[j]] = keys_to_store[j]; - value[tgt_value_indices[j]] = values_to_store[j]; - } - } -} - -} // namespace vllm - -void gather_cached_kv( - torch::Tensor& key, // [out] [num_tokens, num_heads, head_size] - torch::Tensor& value, // [out] [num_tokens, num_heads, head_size] - torch::Tensor& key_cache, // [in] [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [in] [num_blocks, num_heads, head_size, block_size] - torch::Tensor& slot_mapping) // [in] [num_tokens] -{ - int num_tokens = key.size(0); - int num_heads = key.size(1); - int head_size = key.size(2); - int block_size = key_cache.size(3); - int x = key_cache.size(4); - - int key_stride = key.stride(0); - int value_stride = value.stride(0); - - dim3 grid(num_tokens); - dim3 block(std::min(num_heads * head_size, 512)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( - key.scalar_type(), - "gather_cached_kv_kernel_optimized", - [&] { - vllm::gather_cached_kv_kernel_optimized<<>>( - key.data_ptr(), - value.data_ptr(), - key_cache.data_ptr(), - value_cache.data_ptr(), - slot_mapping.data_ptr(), - key_stride, - value_stride, - num_heads, - head_size, - block_size, - x); - }); -} - -namespace vllm { - template __global__ void convert_fp8_e5m2_kernel( const Tin* __restrict__ src_cache, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 24c22020131e8..5d062bb5700bc 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -79,10 +79,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "reshape_and_cache", &reshape_and_cache, "Reshape the key and value tensors and cache them"); - cache_ops.def( - "gather_cached_kv", - &gather_cached_kv, - "Gather key and value from the cache into contiguous QKV tensors"); cache_ops.def( "convert_fp8_e5m2", &convert_fp8_e5m2,