[Core][Model runner refactoring 1/N] Refactor attn metadata term #4518

Merged: 15 commits, May 3, 2024
25 changes: 12 additions & 13 deletions benchmarks/kernels/benchmark_paged_attention.py
@@ -16,7 +16,7 @@
def main(
version: str,
num_seqs: int,
- context_len: int,
+ seq_len: int,
num_query_heads: int,
num_kv_heads: int,
head_size: int,
@@ -48,12 +48,12 @@ def main(
dtype=torch.float,
device=device)

- context_lens = [context_len for _ in range(num_seqs)]
- max_context_len = max(context_lens)
- context_lens = torch.tensor(context_lens, dtype=torch.int, device=device)
+ seq_lens = [seq_len for _ in range(num_seqs)]
+ max_seq_len = max(seq_lens)
+ seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device)

# Create the block tables.
- max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
+ max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = []
for _ in range(num_seqs):
block_table = [
@@ -77,8 +77,7 @@ def main(
# Prepare for the paged attention kernel.
output = torch.empty_like(query)
if version == "v2":
- num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
- PARTITION_SIZE)
+ num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
tmp_output = torch.empty(
size=(num_seqs, num_query_heads, num_partitions, head_size),
dtype=output.dtype,
@@ -110,9 +109,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
num_kv_heads,
scale,
block_tables,
- context_lens,
+ seq_lens,
block_size,
- max_context_len,
+ max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
@@ -129,9 +128,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
num_kv_heads,
scale,
block_tables,
- context_lens,
+ seq_lens,
block_size,
- max_context_len,
+ max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
@@ -166,7 +165,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
choices=["v1", "v2"],
default="v2")
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--context-len", type=int, default=4096)
parser.add_argument("--seq_len", type=int, default=4096)
parser.add_argument("--num-query-heads", type=int, default=64)
parser.add_argument("--num-kv-heads", type=int, default=8)
parser.add_argument("--head-size",
@@ -199,7 +198,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
main(
version=args.version,
num_seqs=args.batch_size,
- context_len=args.context_len,
+ seq_len=args.seq_len,
num_query_heads=args.num_query_heads,
num_kv_heads=args.num_kv_heads,
head_size=args.head_size,
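For orientation, here is a minimal standalone sketch (not part of the diff) of how the renamed quantities in this benchmark relate to each other. The sizes and PARTITION_SIZE = 512 are assumed purely for illustration.

import torch

# Illustrative only: how seq_lens, max_seq_len, block tables, and v2 partitions
# fit together after the rename. Values are made up.
num_seqs, seq_len, block_size = 8, 4096, 16
PARTITION_SIZE = 512  # assumed here; see the benchmark for the real constant

seq_lens = torch.tensor([seq_len] * num_seqs, dtype=torch.int)
max_seq_len = int(seq_lens.max())

# Each sequence needs enough KV-cache blocks to cover max_seq_len tokens.
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size

# The v2 kernel additionally splits each sequence into fixed-size partitions.
num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE

print(max_num_blocks_per_seq, num_partitions)  # 256 8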
76 changes: 38 additions & 38 deletions csrc/attention/attention_kernels.cu
@@ -104,7 +104,7 @@ __device__ void paged_attention_kernel(
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
- const int* __restrict__ context_lens, // [num_seqs]
+ const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
@@ -115,23 +115,23 @@
const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z;
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
- const int context_len = context_lens[seq_idx];
- if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
+ const int seq_len = seq_lens[seq_idx];
+ if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
// No work to do. Terminate the thread block.
return;
}

- const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
- const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
+ const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
+ const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;

// [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
- const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
+ const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
const int num_blocks = end_block_idx - start_block_idx;

// [start_token_idx, end_token_idx) is the range of tokens to process.
const int start_token_idx = start_block_idx * BLOCK_SIZE;
- const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
+ const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
const int num_tokens = end_token_idx - start_token_idx;

constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
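The hunk above is where the kernel decides, per (sequence, partition), which KV-cache blocks and tokens it owns. A rough Python rendering of that range arithmetic follows, with assumed BLOCK_SIZE and PARTITION_SIZE, just to make the bounds concrete; it is a sketch, not the kernel code.

# Sketch of the kernel's block/token range math (assumed constants, not CUDA).
def partition_ranges(seq_len, partition_idx, BLOCK_SIZE=16, PARTITION_SIZE=512):
    use_partitioning = PARTITION_SIZE > 0
    num_seq_blocks = -(-seq_len // BLOCK_SIZE)  # DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE)
    blocks_per_partition = (PARTITION_SIZE // BLOCK_SIZE) if use_partitioning else num_seq_blocks

    start_block = partition_idx * blocks_per_partition if use_partitioning else 0
    end_block = min(start_block + blocks_per_partition, num_seq_blocks)

    start_token = start_block * BLOCK_SIZE
    end_token = min(start_token + (end_block - start_block) * BLOCK_SIZE, seq_len)
    return (start_block, end_block), (start_token, end_token)

# A 1000-token sequence, second partition:
print(partition_ranges(1000, 1))  # ((32, 63), (512, 1000))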
@@ -245,12 +245,12 @@ __device__ void paged_attention_kernel(
// This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given.
- qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
+ qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;

if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
- const bool mask = token_idx >= context_len;
+ const bool mask = token_idx >= seq_len;
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
@@ -364,14 +364,14 @@ __device__ void paged_attention_kernel(
} else {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
}
- if (block_idx == num_context_blocks - 1) {
+ if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
// we should explicitly zero out the values since they may contain NaNs.
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) {
- v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
+ v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
}
}
accs[i] += dot(logits_vec, v_vec);
@@ -457,7 +457,7 @@ __global__ void paged_attention_v1_kernel(
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
- const int* __restrict__ context_lens, // [num_seqs]
+ const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
@@ -466,7 +466,7 @@ __global__ void paged_attention_v1_kernel(
const float kv_scale) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr,
- out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
+ out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale);
}

@@ -489,7 +489,7 @@ __global__ void paged_attention_v2_kernel(
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
- const int* __restrict__ context_lens, // [num_seqs]
+ const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
@@ -498,7 +498,7 @@ __global__ void paged_attention_v2_kernel(
const float kv_scale) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
- block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
+ block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes,
q_stride, kv_block_stride, kv_head_stride, kv_scale);
}

@@ -513,13 +513,13 @@ __global__ void paged_attention_v2_reduce_kernel(
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
- const int* __restrict__ context_lens, // [num_seqs]
+ const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_partitions) {
const int num_heads = gridDim.x;
const int head_idx = blockIdx.x;
const int seq_idx = blockIdx.y;
- const int context_len = context_lens[seq_idx];
- const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
+ const int seq_len = seq_lens[seq_idx];
+ const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
if (num_partitions == 1) {
// No need to reduce. Only copy tmp_out to out.
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
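Only the single-partition fast path of the reduce kernel is visible in this hunk. For context, the general case has to merge the per-partition pieces (tmp_out, exp_sums, max_logits) with a standard log-sum-exp rebase; the NumPy sketch below states that math under the assumption that each tmp_out row is the partition-local softmax-weighted value sum. It is an illustration of the reduction, not the kernel's literal code.

import numpy as np

def reduce_partitions(tmp_out, exp_sums, max_logits):
    # tmp_out: [num_partitions, head_size]; exp_sums, max_logits: [num_partitions]
    if len(exp_sums) == 1:
        return tmp_out[0]  # single partition: just copy tmp_out to out
    global_max = max_logits.max()
    # Rebase every partition's exp-sum onto the global max, then renormalize.
    rescale = exp_sums * np.exp(max_logits - global_max)
    return (tmp_out * rescale[:, None]).sum(axis=0) / rescale.sum()

out = reduce_partitions(np.random.rand(3, 128), np.ones(3), np.array([0.5, 0.2, -0.1]))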
@@ -616,7 +616,7 @@ __global__ void paged_attention_v2_reduce_kernel(
num_kv_heads, \
scale, \
block_tables_ptr, \
- context_lens_ptr, \
+ seq_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
q_stride, \
@@ -639,8 +639,8 @@ void paged_attention_v1_launcher(
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
- torch::Tensor& context_lens,
- int max_context_len,
+ torch::Tensor& seq_lens,
+ int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) {
int num_seqs = query.size(0);
@@ -664,11 +664,11 @@ void paged_attention_v1_launcher(
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
- int* context_lens_ptr = context_lens.data_ptr<int>();
+ int* seq_lens_ptr = seq_lens.data_ptr<int>();

constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
- int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
- int logits_size = padded_max_context_len * sizeof(float);
+ int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
+ int logits_size = padded_max_seq_len * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
@@ -715,8 +715,8 @@ void paged_attention_v1_launcher(
num_kv_heads, \
scale, \
block_tables, \
- context_lens, \
- max_context_len, \
+ seq_lens, \
+ max_seq_len, \
alibi_slopes, \
kv_scale);

@@ -746,9 +746,9 @@ void paged_attention_v1(
int num_kv_heads, // [num_heads]
float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
- torch::Tensor& context_lens, // [num_seqs]
+ torch::Tensor& seq_lens, // [num_seqs]
int block_size,
- int max_context_len,
+ int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale) {
@@ -790,7 +790,7 @@ void paged_attention_v1(
num_kv_heads, \
scale, \
block_tables_ptr, \
- context_lens_ptr, \
+ seq_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
q_stride, \
@@ -803,7 +803,7 @@
exp_sums_ptr, \
max_logits_ptr, \
tmp_out_ptr, \
- context_lens_ptr, \
+ seq_lens_ptr, \
max_num_partitions);

template<
@@ -824,8 +824,8 @@ void paged_attention_v2_launcher(
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
- torch::Tensor& context_lens,
- int max_context_len,
+ torch::Tensor& seq_lens,
+ int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) {
int num_seqs = query.size(0);
@@ -852,10 +852,10 @@ void paged_attention_v2_launcher(
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
- int* context_lens_ptr = context_lens.data_ptr<int>();
+ int* seq_lens_ptr = seq_lens.data_ptr<int>();

constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
- int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
+ int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
int logits_size = PARTITION_SIZE * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);

@@ -909,8 +909,8 @@ void paged_attention_v2_launcher(
num_kv_heads, \
scale, \
block_tables, \
- context_lens, \
- max_context_len, \
+ seq_lens, \
+ max_seq_len, \
alibi_slopes, \
kv_scale);

@@ -943,9 +943,9 @@ void paged_attention_v2(
int num_kv_heads, // [num_heads]
float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
- torch::Tensor& context_lens, // [num_seqs]
+ torch::Tensor& seq_lens, // [num_seqs]
int block_size,
- int max_context_len,
+ int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale) {
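One way to read why the v2 path exists at all: in the v1 launcher the shared-memory logits buffer is sized by the padded maximum sequence length, while in the v2 launcher it is sized by PARTITION_SIZE and therefore stays fixed as sequences grow. The arithmetic below redoes that sizing with assumed values (head_size=128, NUM_THREADS=128, WARP_SIZE=32, BLOCK_SIZE=16, PARTITION_SIZE=512); it is an illustration, not code from the PR.

FLOAT_BYTES = 4
BLOCK_SIZE, head_size = 16, 128
NUM_WARPS = 128 // 32  # NUM_THREADS / WARP_SIZE

max_seq_len = 4096
padded_max_seq_len = -(-max_seq_len // BLOCK_SIZE) * BLOCK_SIZE  # round up to a block boundary

v1_logits_size = padded_max_seq_len * FLOAT_BYTES  # grows with the sequence length
v2_logits_size = 512 * FLOAT_BYTES                 # one PARTITION_SIZE worth, fixed
outputs_size = (NUM_WARPS // 2) * head_size * FLOAT_BYTES

print(v1_logits_size, v2_logits_size, outputs_size)  # 16384 2048 1024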