From ec25547299d43c0ca22c813cd18d273022a5e0cf Mon Sep 17 00:00:00 2001 From: MitchellX Date: Tue, 21 Nov 2023 14:41:16 -0500 Subject: [PATCH 1/7] test --- xmc_test.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 xmc_test.py diff --git a/xmc_test.py b/xmc_test.py new file mode 100644 index 0000000000000..9daeafb9864cf --- /dev/null +++ b/xmc_test.py @@ -0,0 +1 @@ +test From 1b268233cdee7bcb47ab6c1513d87390bc81dd0c Mon Sep 17 00:00:00 2001 From: MitchellX Date: Mon, 27 Nov 2023 01:45:00 -0500 Subject: [PATCH 2/7] end_test --- xmc_test.py | 1 - 1 file changed, 1 deletion(-) delete mode 100644 xmc_test.py diff --git a/xmc_test.py b/xmc_test.py deleted file mode 100644 index 9daeafb9864cf..0000000000000 --- a/xmc_test.py +++ /dev/null @@ -1 +0,0 @@ -test From a74d372ddec6e2e6df733e06800a258fe16e71ba Mon Sep 17 00:00:00 2001 From: MitchellX Date: Tue, 28 Nov 2023 11:43:22 -0500 Subject: [PATCH 3/7] avoid multiple redefinition --- csrc/ops.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/csrc/ops.h b/csrc/ops.h index cfb18fbefd7a9..41cbecdd7ff22 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -1,3 +1,7 @@ +//avoid multiple redefinition +#ifndef OPS_H +#define OPS_H + #include void paged_attention_v1( @@ -73,3 +77,5 @@ void squeezellm_gemm( torch::Tensor mat, torch::Tensor mul, torch::Tensor lookup_table); + +#endif // OPS_H From f14f82e6c36323506a6d82038f0746145156f9ae Mon Sep 17 00:00:00 2001 From: MitchellX Date: Mon, 4 Dec 2023 15:07:44 -0500 Subject: [PATCH 4/7] backup --- csrc/attention/attention_kernels_bk.cu | 872 +++++++++++++++++++++++++ 1 file changed, 872 insertions(+) create mode 100644 csrc/attention/attention_kernels_bk.cu diff --git a/csrc/attention/attention_kernels_bk.cu b/csrc/attention/attention_kernels_bk.cu new file mode 100644 index 0000000000000..78e8d8ecd6d41 --- /dev/null +++ b/csrc/attention/attention_kernels_bk.cu @@ -0,0 +1,872 @@ +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include "attention_dtypes.h" +#include "attention_utils.cuh" + +#include + +#define WARP_SIZE 32 +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +namespace vllm { + +// Utility function for attention softmax. +template +inline __device__ float block_sum(float* red_smem, float sum) { + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + + // Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Warp leaders store the data to shared memory. 
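+  // After the shuffle reduction above, every lane already holds its warp's total,
+  // so only lane 0 of each warp needs to publish it.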
+ if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < NUM_WARPS) { + sum = red_smem[lane]; + } + + // Parallel reduction inside the warp. +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Broadcast to other threads. + return __shfl_sync(uint32_t(-1), sum, 0); +} + +// TODO(woosuk): Merge the last two dimensions of the grid. +// Grid: (num_heads, num_seqs, max_num_partitions). +template< + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS, + int PARTITION_SIZE = 0> // Zero means no partitioning. +__device__ void paged_attention_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int* __restrict__ head_mapping, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride) { + const int seq_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int max_num_partitions = gridDim.z; + constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; + const int context_len = context_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { + // No work to do. Terminate the thread block. + return; + } + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; + + // [start_block_idx, end_block_idx) is the range of blocks to process. + const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); + const int num_blocks = end_block_idx - start_block_idx; + + // [start_token_idx, end_token_idx) is the range of tokens to process. + const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); + const int num_tokens = end_token_idx - start_token_idx; + + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS + assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = threadIdx.x; + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int kv_head_idx = head_mapping[head_idx]; + const float alibi_slope = alibi_slopes == nullptr ? 
0.f : alibi_slopes[head_idx]; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread group + // fetch or compute 16 bytes at a time. + // For example, if the size of a thread group is 4 and the data type is half, + // then the vector size is 16 / (4 * sizeof(half)) == 2. + constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... + // th vectors of the query, and so on. + // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. + const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[thread_group_offset][i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs + + // Memory planning. + extern __shared__ char shared_mem[]; + // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. + float* logits = reinterpret_cast(shared_mem); + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(scalar_t); + float qk_max = -FLT_MAX; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 + // because int32 can lead to overflow when this variable is multiplied by large numbers + // (e.g., kv_block_stride). + const int64_t physical_block_number = static_cast(block_table[block_idx]); + + // Load a key to registers. + // Each thread in a thread group has a different part of the key. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th + // vectors of the key, and so on. 
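+    // For example, with WARP_SIZE == 32 and BLOCK_SIZE == 16, THREAD_GROUP_SIZE is 2 and
+    // NUM_TOKENS_PER_THREAD_GROUP is 1, so each of the 16 thread groups in a warp
+    // processes one token of the block per iteration.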
+ for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + K_vec k_vecs[NUM_VECS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + + physical_block_offset * x; + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); + // Add the ALiBi bias if slopes are given. + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the masked logits. + const bool mask = token_idx >= context_len; + logits[token_idx - start_token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + __syncthreads(); + + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + // Broadcast the max qk value to all threads. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + + // Compute softmax. + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + + // If partitioning is enabled, store the max logit and exp_sum. + if (USE_PARTITIONING && thread_idx == 0) { + float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + + partition_idx; + *max_logits_ptr = qk_max; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + + partition_idx; + *exp_sums_ptr = exp_sum; + } + + // Each thread will fetch 16 bytes from the value cache at a time. 
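+  // For example, for half-precision values this is 8 elements per fetch,
+  // capped at BLOCK_SIZE.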
+ constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using L_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + + // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + scalar_t zero_value; + zero(zero_value); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 + // because int32 can lead to overflow when this variable is multiplied by large numbers + // (e.g., kv_block_stride). + const int64_t physical_block_number = static_cast(block_table[block_idx]); + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + from_float(logits_vec, *reinterpret_cast(logits + token_idx - start_token_idx)); + + const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + V_vec v_vec = *reinterpret_cast(v_ptr + offset); + if (block_idx == num_context_blocks - 1) { + // NOTE(woosuk): When v_vec contains the tokens that are out of the context, + // we should explicitly zero out the values since they may contain NaNs. + // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); +#pragma unroll + for (int j = 0; j < V_VEC_SIZE; j++) { + v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value; + } + } + accs[i] += dot(logits_vec, v_vec); + } + } + } + + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += __shfl_xor_sync(uint32_t(-1), acc, mask); + } + accs[i] = acc; + } + + // NOTE(woosuk): A barrier is required because the shared memory space for logits + // is reused for the output. + __syncthreads(); + + // Perform reduction across warps. + float* out_smem = reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + __syncthreads(); + + // Lower warps update the output. 
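+    // Each remaining warp adds the partial sums written by its upper counterpart;
+    // the number of active warps halves every iteration until warp 0 holds the result.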
+ if (warp_idx < mid) { + const float* src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + __syncthreads(); + } + + // Write the final output. + if (warp_idx == 0) { + scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + from_float(*(out_ptr + row_idx), accs[i]); + } + } + } +} + +// Grid: (num_heads, num_seqs, 1). +template< + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS> +__global__ void paged_attention_v1_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int* __restrict__ head_mapping, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride) { + paged_attention_kernel( + /* exp_sums */ nullptr, /* max_logits */ nullptr, + out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens, + max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); +} + +// Grid: (num_heads, num_seqs, max_num_partitions). +template< + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS, + int PARTITION_SIZE> +__global__ void paged_attention_v2_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int* __restrict__ head_mapping, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride) { + paged_attention_kernel( + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale, + block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, + q_stride, kv_block_stride, kv_head_stride); +} + +// Grid: (num_heads, num_seqs). 
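+// Each thread block combines the per-partition results (tmp_out, max_logits, exp_sums)
+// of one (head, sequence) pair into the final output.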
+template< + typename scalar_t, + int HEAD_SIZE, + int NUM_THREADS, + int PARTITION_SIZE> +__global__ void paged_attention_v2_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + if (num_partitions == 1) { + // No need to reduce. Only copy tmp_out to out. + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { + out_ptr[i] = tmp_out_ptr[i]; + } + // Terminate the thread block. + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warp_idx = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + + // Size: 2 * num_partitions. + extern __shared__ char shared_mem[]; + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // Load max logits to shared memory. + float* shared_max_logits = reinterpret_cast(shared_mem); + const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float max_logit = -FLT_MAX; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + const float l = max_logits_ptr[i]; + shared_max_logits[i] = l; + max_logit = fmaxf(max_logit, l); + } + __syncthreads(); + + // Get the global max logit. + // Reduce within the warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = max_logit; + } + __syncthreads(); + // Reduce across warps. + max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); + } + // Broadcast the max value to all threads. + max_logit = __shfl_sync(uint32_t(-1), max_logit, 0); + + // Load rescaled exp sums to shared memory. + float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); + const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float global_exp_sum = 0.0f; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + float l = shared_max_logits[i]; + float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit); + global_exp_sum += rescaled_exp_sum; + shared_exp_sums[i] = rescaled_exp_sum; + } + __syncthreads(); + global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); + const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); + + // Aggregate tmp_out to out. 
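+  // Each partition's partial output is weighted by its rescaled exp-sum divided by the
+  // global exp-sum, i.e. by that partition's share of the softmax mass.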
+ const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { + float acc = 0.0f; + for (int j = 0; j < num_partitions; ++j) { + acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; + } + from_float(out_ptr[i], acc); + } +} + +} // namespace vllm + +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + cudaFuncSetAttribute( \ + vllm::paged_attention_v1_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + head_mapping_ptr, \ + scale, \ + block_tables_ptr, \ + context_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride); + +// TODO(woosuk): Tune NUM_THREADS. +template< + typename T, + int BLOCK_SIZE, + int NUM_THREADS = 128> +void paged_attention_v1_launcher( + torch::Tensor& out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& head_mapping, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int max_context_len, + const c10::optional& alibi_slopes) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = alibi_slopes ? + reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* head_mapping_ptr = reinterpret_cast(head_mapping.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_context_len * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len + // Keep that in sync with the logic here! + int shared_mem_size = std::max(logits_size, outputs_size); + + dim3 grid(num_heads, num_seqs, 1); + dim3 block(NUM_THREADS); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. 
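+    // Each case instantiates the templated kernel for a fixed, compile-time head size.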
+ case 64: + LAUNCH_PAGED_ATTENTION_V1(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V1(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V1(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V1(112); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V1(128); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V1(256); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v1_launcher( \ + out, \ + query, \ + key_cache, \ + value_cache, \ + head_mapping, \ + scale, \ + block_tables, \ + context_lens, \ + max_context_len, \ + alibi_slopes); + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. +#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER(T, 8); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER(T, 16); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER(T, 32); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void paged_attention_v1( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& head_mapping, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + int block_size, + int max_context_len, + const c10::optional& alibi_slopes) { + if (query.dtype() == at::ScalarType::Float) { + CALL_V1_LAUNCHER_BLOCK_SIZE(float); + } else if (query.dtype() == at::ScalarType::Half) { + CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } +} + +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, \ + max_logits_ptr, \ + tmp_out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + head_mapping_ptr, \ + scale, \ + block_tables_ptr, \ + context_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, \ + exp_sums_ptr, \ + max_logits_ptr, \ + tmp_out_ptr, \ + context_lens_ptr, \ + max_num_partitions); + +template< + typename T, + int BLOCK_SIZE, + int NUM_THREADS = 128, + int PARTITION_SIZE = 512> +void paged_attention_v2_launcher( + torch::Tensor& out, + torch::Tensor& exp_sums, + torch::Tensor& max_logits, + torch::Tensor& tmp_out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& head_mapping, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int max_context_len, + const c10::optional& alibi_slopes) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = alibi_slopes ? 
+ reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* head_mapping_ptr = reinterpret_cast(head_mapping.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + int logits_size = PARTITION_SIZE * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + + // For paged attention v2 kernel. + dim3 grid(num_heads, num_seqs, max_num_partitions); + int shared_mem_size = std::max(logits_size, outputs_size); + // For paged attention v2 reduce kernel. + dim3 reduce_grid(num_heads, num_seqs); + int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + + dim3 block(NUM_THREADS); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. + case 64: + LAUNCH_PAGED_ATTENTION_V2(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V2(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V2(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V2(112); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V2(128); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V2(256); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v2_launcher( \ + out, \ + exp_sums, \ + max_logits, \ + tmp_out, \ + query, \ + key_cache, \ + value_cache, \ + head_mapping, \ + scale, \ + block_tables, \ + context_lens, \ + max_context_len, \ + alibi_slopes); + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. 
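+// Only block sizes 8, 16, and 32 are dispatched below; any other size fails the TORCH_CHECK.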
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER(T, 8); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER(T, 16); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER(T, 32); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void paged_attention_v2( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& head_mapping, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + int block_size, + int max_context_len, + const c10::optional& alibi_slopes) { + if (query.dtype() == at::ScalarType::Float) { + CALL_V2_LAUNCHER_BLOCK_SIZE(float); + } else if (query.dtype() == at::ScalarType::Half) { + CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } +} + +#undef WARP_SIZE +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP From 8062bdd13dc94ce9c3a84712f29e96c44941a2c5 Mon Sep 17 00:00:00 2001 From: MitchellX Date: Mon, 11 Dec 2023 10:54:05 -0500 Subject: [PATCH 5/7] add some notes --- csrc/attention/attention_kernels.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 78e8d8ecd6d41..b847763b67fc7 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -74,6 +74,8 @@ template< int BLOCK_SIZE, int NUM_THREADS, int PARTITION_SIZE = 0> // Zero means no partitioning. + +// xmc: replace the kernel with faster flashattention __device__ void paged_attention_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] From b8a727fb92ef5a3e0eaad35fdcadd420f878d5f4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 14 Dec 2023 17:33:00 +0000 Subject: [PATCH 6/7] Remove unnecessary change --- csrc/attention/attention_kernels.cu | 2 - csrc/attention/attention_kernels_bk.cu | 872 ------------------------- 2 files changed, 874 deletions(-) delete mode 100644 csrc/attention/attention_kernels_bk.cu diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index c81cc70c1ef1c..eff28d3dacd0e 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -82,8 +82,6 @@ template< int BLOCK_SIZE, int NUM_THREADS, int PARTITION_SIZE = 0> // Zero means no partitioning. 
- -// xmc: replace the kernel with faster flashattention __device__ void paged_attention_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] diff --git a/csrc/attention/attention_kernels_bk.cu b/csrc/attention/attention_kernels_bk.cu deleted file mode 100644 index 78e8d8ecd6d41..0000000000000 --- a/csrc/attention/attention_kernels_bk.cu +++ /dev/null @@ -1,872 +0,0 @@ -/* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * Copyright (c) 2023, The vLLM team. - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include - -#include "attention_dtypes.h" -#include "attention_utils.cuh" - -#include - -#define WARP_SIZE 32 -#define MAX(a, b) ((a) > (b) ? (a) : (b)) -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) - -namespace vllm { - -// Utility function for attention softmax. -template -inline __device__ float block_sum(float* red_smem, float sum) { - // Decompose the thread index into warp / lane. - int warp = threadIdx.x / WARP_SIZE; - int lane = threadIdx.x % WARP_SIZE; - - // Compute the sum per warp. -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - // Warp leaders store the data to shared memory. - if (lane == 0) { - red_smem[warp] = sum; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The warps compute the final sums. - if (lane < NUM_WARPS) { - sum = red_smem[lane]; - } - - // Parallel reduction inside the warp. -#pragma unroll - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - // Broadcast to other threads. - return __shfl_sync(uint32_t(-1), sum, 0); -} - -// TODO(woosuk): Merge the last two dimensions of the grid. -// Grid: (num_heads, num_seqs, max_num_partitions). -template< - typename scalar_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - int PARTITION_SIZE = 0> // Zero means no partitioning. 
-__device__ void paged_attention_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int* __restrict__ head_mapping, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride) { - const int seq_idx = blockIdx.y; - const int partition_idx = blockIdx.z; - const int max_num_partitions = gridDim.z; - constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; - const int context_len = context_lens[seq_idx]; - if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { - // No work to do. Terminate the thread block. - return; - } - - const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); - const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; - - // [start_block_idx, end_block_idx) is the range of blocks to process. - const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; - const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); - const int num_blocks = end_block_idx - start_block_idx; - - // [start_token_idx, end_token_idx) is the range of tokens to process. - const int start_token_idx = start_block_idx * BLOCK_SIZE; - const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); - const int num_tokens = end_token_idx - start_token_idx; - - constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); - constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS - assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); - constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - const int thread_idx = threadIdx.x; - const int warp_idx = thread_idx / WARP_SIZE; - const int lane = thread_idx % WARP_SIZE; - - const int head_idx = blockIdx.x; - const int num_heads = gridDim.x; - const int kv_head_idx = head_mapping[head_idx]; - const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; - - // A vector type to store a part of a key or a query. - // The vector size is configured in such a way that the threads in a thread group - // fetch or compute 16 bytes at a time. - // For example, if the size of a thread group is 4 and the data type is half, - // then the vector size is 16 / (4 * sizeof(half)) == 2. 
- constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); - using K_vec = typename Vec::Type; - using Q_vec = typename Vec::Type; - - constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; - constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; - - const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; - const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; - - // Load the query to registers. - // Each thread in a thread group has a different part of the query. - // For example, if the the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... - // th vectors of the query, and so on. - // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. - const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; - __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; -#pragma unroll - for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) { - const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; - q_vecs[thread_group_offset][i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); - } - __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs - - // Memory planning. - extern __shared__ char shared_mem[]; - // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. - float* logits = reinterpret_cast(shared_mem); - // Workspace for reduction. - __shared__ float red_smem[2 * NUM_WARPS]; - - // x == THREAD_GROUP_SIZE * VEC_SIZE - // Each thread group fetches x elements from the key at a time. - constexpr int x = 16 / sizeof(scalar_t); - float qk_max = -FLT_MAX; - - // Iterate over the key blocks. - // Each warp fetches a block of keys for each iteration. - // Each thread group in a warp fetches a key from the block, and computes - // dot product with the query. - const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 - // because int32 can lead to overflow when this variable is multiplied by large numbers - // (e.g., kv_block_stride). - const int64_t physical_block_number = static_cast(block_table[block_idx]); - - // Load a key to registers. - // Each thread in a thread group has a different part of the key. - // For example, if the the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th - // vectors of the key, and so on. - for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { - const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; - const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - K_vec k_vecs[NUM_VECS_PER_THREAD]; - -#pragma unroll - for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { - const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride - + physical_block_offset * x; - const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; - const int offset1 = (vec_idx * VEC_SIZE) / x; - const int offset2 = (vec_idx * VEC_SIZE) % x; - k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); - } - - // Compute dot product. 
- // This includes a reduction across the threads in the same thread group. - float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); - // Add the ALiBi bias if slopes are given. - qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; - - if (thread_group_offset == 0) { - // Store the partial reductions to shared memory. - // NOTE(woosuk): It is required to zero out the masked logits. - const bool mask = token_idx >= context_len; - logits[token_idx - start_token_idx] = mask ? 0.f : qk; - // Update the max value. - qk_max = mask ? qk_max : fmaxf(qk_max, qk); - } - } - } - - // Perform reduction across the threads in the same warp to get the - // max qk value for each "warp" (not across the thread block yet). - // The 0-th thread of each thread group already has its max qk value. -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - if (lane == 0) { - red_smem[warp_idx] = qk_max; - } - __syncthreads(); - - // TODO(woosuk): Refactor this part. - // Get the max qk value for the sequence. - qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - // Broadcast the max qk value to all threads. - qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); - - // Get the sum of the exp values. - float exp_sum = 0.f; - for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { - float val = __expf(logits[i] - qk_max); - logits[i] = val; - exp_sum += val; - } - exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); - - // Compute softmax. - const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); - for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { - logits[i] *= inv_sum; - } - __syncthreads(); - - // If partitioning is enabled, store the max logit and exp_sum. - if (USE_PARTITIONING && thread_idx == 0) { - float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions - + partition_idx; - *max_logits_ptr = qk_max; - float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions - + partition_idx; - *exp_sums_ptr = exp_sum; - } - - // Each thread will fetch 16 bytes from the value cache at a time. - constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); - using V_vec = typename Vec::Type; - using L_vec = typename Vec::Type; - using Float_L_vec = typename FloatVec::Type; - - constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; - constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; - constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); - - // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. - float accs[NUM_ROWS_PER_THREAD]; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - accs[i] = 0.f; - } - - scalar_t zero_value; - zero(zero_value); - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 - // because int32 can lead to overflow when this variable is multiplied by large numbers - // (e.g., kv_block_stride). 
- const int64_t physical_block_number = static_cast(block_table[block_idx]); - const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; - const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - L_vec logits_vec; - from_float(logits_vec, *reinterpret_cast(logits + token_idx - start_token_idx)); - - const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE) { - const int offset = row_idx * BLOCK_SIZE + physical_block_offset; - V_vec v_vec = *reinterpret_cast(v_ptr + offset); - if (block_idx == num_context_blocks - 1) { - // NOTE(woosuk): When v_vec contains the tokens that are out of the context, - // we should explicitly zero out the values since they may contain NaNs. - // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 - scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); -#pragma unroll - for (int j = 0; j < V_VEC_SIZE; j++) { - v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value; - } - } - accs[i] += dot(logits_vec, v_vec); - } - } - } - - // Perform reduction within each warp. -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - float acc = accs[i]; -#pragma unroll - for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { - acc += __shfl_xor_sync(uint32_t(-1), acc, mask); - } - accs[i] = acc; - } - - // NOTE(woosuk): A barrier is required because the shared memory space for logits - // is reused for the output. - __syncthreads(); - - // Perform reduction across warps. - float* out_smem = reinterpret_cast(shared_mem); -#pragma unroll - for (int i = NUM_WARPS; i > 1; i /= 2) { - int mid = i / 2; - // Upper warps write to shared memory. - if (warp_idx >= mid && warp_idx < i) { - float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - dst[row_idx] = accs[i]; - } - } - } - __syncthreads(); - - // Lower warps update the output. - if (warp_idx < mid) { - const float* src = &out_smem[warp_idx * HEAD_SIZE]; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - accs[i] += src[row_idx]; - } - } - } - __syncthreads(); - } - - // Write the final output. - if (warp_idx == 0) { - scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE - + partition_idx * HEAD_SIZE; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - from_float(*(out_ptr + row_idx), accs[i]); - } - } - } -} - -// Grid: (num_heads, num_seqs, 1). 
-template< - typename scalar_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS> -__global__ void paged_attention_v1_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int* __restrict__ head_mapping, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride) { - paged_attention_kernel( - /* exp_sums */ nullptr, /* max_logits */ nullptr, - out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); -} - -// Grid: (num_heads, num_seqs, max_num_partitions). -template< - typename scalar_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - int PARTITION_SIZE> -__global__ void paged_attention_v2_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int* __restrict__ head_mapping, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride) { - paged_attention_kernel( - exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale, - block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, - q_stride, kv_block_stride, kv_head_stride); -} - -// Grid: (num_heads, num_seqs). -template< - typename scalar_t, - int HEAD_SIZE, - int NUM_THREADS, - int PARTITION_SIZE> -__global__ void paged_attention_v2_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_partitions) { - const int num_heads = gridDim.x; - const int head_idx = blockIdx.x; - const int seq_idx = blockIdx.y; - const int context_len = context_lens[seq_idx]; - const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); - if (num_partitions == 1) { - // No need to reduce. Only copy tmp_out to out. 
- scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; - for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { - out_ptr[i] = tmp_out_ptr[i]; - } - // Terminate the thread block. - return; - } - - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - const int warp_idx = threadIdx.x / WARP_SIZE; - const int lane = threadIdx.x % WARP_SIZE; - - // Size: 2 * num_partitions. - extern __shared__ char shared_mem[]; - // Workspace for reduction. - __shared__ float red_smem[2 * NUM_WARPS]; - - // Load max logits to shared memory. - float* shared_max_logits = reinterpret_cast(shared_mem); - const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; - float max_logit = -FLT_MAX; - for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { - const float l = max_logits_ptr[i]; - shared_max_logits[i] = l; - max_logit = fmaxf(max_logit, l); - } - __syncthreads(); - - // Get the global max logit. - // Reduce within the warp. -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); - } - if (lane == 0) { - red_smem[warp_idx] = max_logit; - } - __syncthreads(); - // Reduce across warps. - max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); - } - // Broadcast the max value to all threads. - max_logit = __shfl_sync(uint32_t(-1), max_logit, 0); - - // Load rescaled exp sums to shared memory. - float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); - const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; - float global_exp_sum = 0.0f; - for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { - float l = shared_max_logits[i]; - float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit); - global_exp_sum += rescaled_exp_sum; - shared_exp_sums[i] = rescaled_exp_sum; - } - __syncthreads(); - global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); - const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); - - // Aggregate tmp_out to out. - const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; -#pragma unroll - for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { - float acc = 0.0f; - for (int j = 0; j < num_partitions; ++j) { - acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; - } - from_float(out_ptr[i], acc); - } -} - -} // namespace vllm - -#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ - cudaFuncSetAttribute( \ - vllm::paged_attention_v1_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ - vllm::paged_attention_v1_kernel \ - <<>>( \ - out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - head_mapping_ptr, \ - scale, \ - block_tables_ptr, \ - context_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride); - -// TODO(woosuk): Tune NUM_THREADS. 
-template<
-  typename T,
-  int BLOCK_SIZE,
-  int NUM_THREADS = 128>
-void paged_attention_v1_launcher(
-  torch::Tensor& out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  torch::Tensor& head_mapping,
-  float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes) {
-  int num_seqs = query.size(0);
-  int num_heads = query.size(1);
-  int head_size = query.size(2);
-  int max_num_blocks_per_seq = block_tables.size(1);
-  int q_stride = query.stride(0);
-  int kv_block_stride = key_cache.stride(0);
-  int kv_head_stride = key_cache.stride(1);
-
-  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
-  assert(head_size % thread_group_size == 0);
-
-  // NOTE: alibi_slopes is optional.
-  const float* alibi_slopes_ptr = alibi_slopes ?
-    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
-    : nullptr;
-
-  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
-  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
-  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
-  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
-  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
-  int* block_tables_ptr = block_tables.data_ptr<int>();
-  int* context_lens_ptr = context_lens.data_ptr<int>();
-
-  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
-  int logits_size = padded_max_context_len * sizeof(float);
-  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
-  // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
-  // Keep that in sync with the logic here!
-  int shared_mem_size = std::max(logits_size, outputs_size);
-
-  dim3 grid(num_heads, num_seqs, 1);
-  dim3 block(NUM_THREADS);
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  switch (head_size) {
-    // NOTE(woosuk): To reduce the compilation time, we only compile for the
-    // head sizes that we use in the model. However, we can easily extend this
-    // to support any head size which is a multiple of 16.
-    case 64:
-      LAUNCH_PAGED_ATTENTION_V1(64);
-      break;
-    case 80:
-      LAUNCH_PAGED_ATTENTION_V1(80);
-      break;
-    case 96:
-      LAUNCH_PAGED_ATTENTION_V1(96);
-      break;
-    case 112:
-      LAUNCH_PAGED_ATTENTION_V1(112);
-      break;
-    case 128:
-      LAUNCH_PAGED_ATTENTION_V1(128);
-      break;
-    case 256:
-      LAUNCH_PAGED_ATTENTION_V1(256);
-      break;
-    default:
-      TORCH_CHECK(false, "Unsupported head size: ", head_size);
-      break;
-  }
-}
-
-#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \
-  paged_attention_v1_launcher<T, BLOCK_SIZE>( \
-    out, \
-    query, \
-    key_cache, \
-    value_cache, \
-    head_mapping, \
-    scale, \
-    block_tables, \
-    context_lens, \
-    max_context_len, \
-    alibi_slopes);
-
-// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
-// 1, 2, 4, 64, 128, 256.
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \
-  switch (block_size) { \
-    case 8: \
-      CALL_V1_LAUNCHER(T, 8); \
-      break; \
-    case 16: \
-      CALL_V1_LAUNCHER(T, 16); \
-      break; \
-    case 32: \
-      CALL_V1_LAUNCHER(T, 32); \
-      break; \
-    default: \
-      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
-      break; \
-  }
-
-void paged_attention_v1(
-  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
-  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
-  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
-  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
-  torch::Tensor& head_mapping,    // [num_heads]
-  float scale,
-  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
-  torch::Tensor& context_lens,    // [num_seqs]
-  int block_size,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes) {
-  if (query.dtype() == at::ScalarType::Float) {
-    CALL_V1_LAUNCHER_BLOCK_SIZE(float);
-  } else if (query.dtype() == at::ScalarType::Half) {
-    CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t);
-  } else if (query.dtype() == at::ScalarType::BFloat16) {
-    CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
-  } else {
-    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
-  }
-}
-
-#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
-  vllm::paged_attention_v2_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE> \
-  <<<grid, block, shared_mem_size, stream>>>( \
-    exp_sums_ptr, \
-    max_logits_ptr, \
-    tmp_out_ptr, \
-    query_ptr, \
-    key_cache_ptr, \
-    value_cache_ptr, \
-    head_mapping_ptr, \
-    scale, \
-    block_tables_ptr, \
-    context_lens_ptr, \
-    max_num_blocks_per_seq, \
-    alibi_slopes_ptr, \
-    q_stride, \
-    kv_block_stride, \
-    kv_head_stride); \
-  vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
-  <<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
-    out_ptr, \
-    exp_sums_ptr, \
-    max_logits_ptr, \
-    tmp_out_ptr, \
-    context_lens_ptr, \
-    max_num_partitions);
-
-template<
-  typename T,
-  int BLOCK_SIZE,
-  int NUM_THREADS = 128,
-  int PARTITION_SIZE = 512>
-void paged_attention_v2_launcher(
-  torch::Tensor& out,
-  torch::Tensor& exp_sums,
-  torch::Tensor& max_logits,
-  torch::Tensor& tmp_out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  torch::Tensor& head_mapping,
-  float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes) {
-  int num_seqs = query.size(0);
-  int num_heads = query.size(1);
-  int head_size = query.size(2);
-  int max_num_blocks_per_seq = block_tables.size(1);
-  int q_stride = query.stride(0);
-  int kv_block_stride = key_cache.stride(0);
-  int kv_head_stride = key_cache.stride(1);
-
-  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
-  assert(head_size % thread_group_size == 0);
-
-  // NOTE: alibi_slopes is optional.
-  const float* alibi_slopes_ptr = alibi_slopes ?
-    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
-    : nullptr;
-
-  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
-  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
-  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
-  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
-  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
-  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
-  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
-  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
-  int* block_tables_ptr = block_tables.data_ptr<int>();
-  int* context_lens_ptr = context_lens.data_ptr<int>();
-
-  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
-  int logits_size = PARTITION_SIZE * sizeof(float);
-  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
-
-  // For paged attention v2 kernel.
-  dim3 grid(num_heads, num_seqs, max_num_partitions);
-  int shared_mem_size = std::max(logits_size, outputs_size);
-  // For paged attention v2 reduce kernel.
-  dim3 reduce_grid(num_heads, num_seqs);
-  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
-
-  dim3 block(NUM_THREADS);
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  switch (head_size) {
-    // NOTE(woosuk): To reduce the compilation time, we only compile for the
-    // head sizes that we use in the model. However, we can easily extend this
-    // to support any head size which is a multiple of 16.
-    case 64:
-      LAUNCH_PAGED_ATTENTION_V2(64);
-      break;
-    case 80:
-      LAUNCH_PAGED_ATTENTION_V2(80);
-      break;
-    case 96:
-      LAUNCH_PAGED_ATTENTION_V2(96);
-      break;
-    case 112:
-      LAUNCH_PAGED_ATTENTION_V2(112);
-      break;
-    case 128:
-      LAUNCH_PAGED_ATTENTION_V2(128);
-      break;
-    case 256:
-      LAUNCH_PAGED_ATTENTION_V2(256);
-      break;
-    default:
-      TORCH_CHECK(false, "Unsupported head size: ", head_size);
-      break;
-  }
-}
-
-#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \
-  paged_attention_v2_launcher<T, BLOCK_SIZE>( \
-    out, \
-    exp_sums, \
-    max_logits, \
-    tmp_out, \
-    query, \
-    key_cache, \
-    value_cache, \
-    head_mapping, \
-    scale, \
-    block_tables, \
-    context_lens, \
-    max_context_len, \
-    alibi_slopes);
-
-// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
-// 1, 2, 4, 64, 128, 256.
-#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \
-  switch (block_size) { \
-    case 8: \
-      CALL_V2_LAUNCHER(T, 8); \
-      break; \
-    case 16: \
-      CALL_V2_LAUNCHER(T, 16); \
-      break; \
-    case 32: \
-      CALL_V2_LAUNCHER(T, 32); \
-      break; \
-    default: \
-      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
-      break; \
-  }
-
-void paged_attention_v2(
-  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
-  torch::Tensor& exp_sums,        // [num_seqs, num_heads, max_num_partitions]
-  torch::Tensor& max_logits,      // [num_seqs, num_heads, max_num_partitions]
-  torch::Tensor& tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
-  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
-  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
-  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
-  torch::Tensor& head_mapping,    // [num_heads]
-  float scale,
-  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
-  torch::Tensor& context_lens,    // [num_seqs]
-  int block_size,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes) {
-  if (query.dtype() == at::ScalarType::Float) {
-    CALL_V2_LAUNCHER_BLOCK_SIZE(float);
-  } else if (query.dtype() == at::ScalarType::Half) {
-    CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t);
-  } else if (query.dtype() == at::ScalarType::BFloat16) {
-    CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
-  } else {
-    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
-  }
-}
-
-#undef WARP_SIZE
-#undef MAX
-#undef MIN
-#undef DIVIDE_ROUND_UP

From 5f9ca44cfc487fb5cc708e7b28c4bba5fad39b02 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 14 Dec 2023 17:34:13 +0000
Subject: [PATCH 7/7] pragma once

---
 csrc/cache.h          | 2 ++
 csrc/cuda_utils.h     | 2 ++
 csrc/dispatch_utils.h | 2 ++
 csrc/ops.h            | 6 +-----
 4 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/csrc/cache.h b/csrc/cache.h
index da49d9103214b..b26faad2ca814 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -1,3 +1,5 @@
+#pragma once
+
 #include <torch/extension.h>
 
 #include <map>
diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h
index 85cb199b9aa0c..69c96cef0d17e 100644
--- a/csrc/cuda_utils.h
+++ b/csrc/cuda_utils.h
@@ -1,3 +1,5 @@
+#pragma once
+
 #include <torch/extension.h>
 
 int get_device_attribute(
diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h
index 7c0c49d392a98..0ae9cd6415982 100644
--- a/csrc/dispatch_utils.h
+++ b/csrc/dispatch_utils.h
@@ -2,6 +2,8 @@
  * Adapted from
  * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
  */
+#pragma once
+
 #include <torch/extension.h>
 
 #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
diff --git a/csrc/ops.h b/csrc/ops.h
index fb6da55fa71a5..5ded1c1286e4e 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -1,6 +1,4 @@
-//avoid multiple redefinition
-#ifndef OPS_H
-#define OPS_H
+#pragma once
 
 #include <torch/extension.h>
 
@@ -79,5 +77,3 @@ void squeezellm_gemm(
   torch::Tensor mat,
   torch::Tensor mul,
   torch::Tensor lookup_table);
-
-#endif // OPS_H
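
Note on the include-guard change above: #pragma once and an explicit #ifndef guard both keep a header from being processed twice in one translation unit, so PATCH 7/7 can drop the OPS_H guard introduced in PATCH 3/7 without changing behavior. A minimal sketch of the two equivalent idioms in C++ (the header name and declaration are illustrative only, not files in this series):

  // Idiom 1: classic include guard, as added to csrc/ops.h in PATCH 3/7.
  #ifndef EXAMPLE_OPS_H
  #define EXAMPLE_OPS_H
  void example_op();   // declarations go here
  #endif  // EXAMPLE_OPS_H

  // Idiom 2: pragma once, as adopted across the headers in PATCH 7/7.
  // Not part of the ISO standard, but supported by GCC, Clang, MSVC, and NVCC.
  #pragma once
  void example_op();   // declarations go here

The pragma form is shorter and cannot suffer guard-name collisions between headers, which is the usual reason projects standardize on it.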