[Kernel] sliding window support in paged_attention_v1/v2 kernels #4768

Closed · wants to merge 2 commits

4 changes: 4 additions & 0 deletions benchmarks/kernels/benchmark_paged_attention.py
@@ -99,6 +99,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
# Using default kv_scale
kv_scale = 1.0

sliding_window = None

for _ in range(num_iters):
if version == "v1":
ops.paged_attention_v1(
@@ -113,6 +115,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
block_size,
max_seq_len,
alibi_slopes,
sliding_window,
kv_cache_dtype,
kv_scale,
)
@@ -132,6 +135,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
block_size,
max_seq_len,
alibi_slopes,
sliding_window,
kv_cache_dtype,
kv_scale,
)
19 changes: 16 additions & 3 deletions csrc/attention/attention_kernels.cu
@@ -103,6 +103,7 @@ __device__ void paged_attention_kernel(
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int sliding_window,
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
@@ -129,6 +130,7 @@ __device__ void paged_attention_kernel(
const int start_token_idx = start_block_idx * BLOCK_SIZE;
const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
const int num_tokens = end_token_idx - start_token_idx;
const int skip_tokens = seq_len > sliding_window ? seq_len - sliding_window : 0;

constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
@@ -237,7 +239,7 @@ __device__ void paged_attention_kernel(
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= seq_len;
const bool mask = token_idx >= seq_len || token_idx < skip_tokens;
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
@@ -437,14 +439,16 @@ __global__ void paged_attention_v1_kernel(
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int sliding_window,
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
const float kv_scale) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr,
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale);
max_num_blocks_per_seq, alibi_slopes, sliding_window,
q_stride, kv_block_stride, kv_head_stride, kv_scale);
}

// Grid: (num_heads, num_seqs, max_num_partitions).
@@ -469,13 +473,14 @@ __global__ void paged_attention_v2_kernel(
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int sliding_window,
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
const float kv_scale) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, sliding_window,
q_stride, kv_block_stride, kv_head_stride, kv_scale);
}

@@ -596,6 +601,7 @@ __global__ void paged_attention_v2_reduce_kernel(
seq_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
sliding_window, \
q_stride, \
kv_block_stride, \
kv_head_stride, \
@@ -619,6 +625,7 @@ void paged_attention_v1_launcher(
torch::Tensor& seq_lens,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
int sliding_window,
float kv_scale) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
@@ -695,6 +702,7 @@ void paged_attention_v1_launcher(
seq_lens, \
max_seq_len, \
alibi_slopes, \
sliding_window, \
kv_scale);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
@@ -727,6 +735,7 @@ void paged_attention_v1(
int block_size,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
int sliding_window,
const std::string& kv_cache_dtype,
float kv_scale) {

@@ -749,6 +758,7 @@ void paged_attention_v1(
seq_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
sliding_window, \
q_stride, \
kv_block_stride, \
kv_head_stride, \
@@ -783,6 +793,7 @@ void paged_attention_v2_launcher(
torch::Tensor& seq_lens,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
int sliding_window,
float kv_scale) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
@@ -868,6 +879,7 @@ void paged_attention_v2_launcher(
seq_lens, \
max_seq_len, \
alibi_slopes, \
sliding_window, \
kv_scale);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
@@ -903,6 +915,7 @@ void paged_attention_v2(
int block_size,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
int sliding_window,
const std::string& kv_cache_dtype,
float kv_scale) {
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE)
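
In effect, the kernel change widens the existing out-of-range mask so that tokens older than the window are zeroed alongside tokens past seq_len and excluded from the running qk_max. A minimal NumPy sketch of the same masking rule, for a single query and its per-token logits (the helper name and shapes are illustrative, not taken from the kernel):

import numpy as np

def masked_logits(qk: np.ndarray, seq_len: int, sliding_window: int) -> np.ndarray:
    # Mirrors: skip_tokens = seq_len > sliding_window ? seq_len - sliding_window : 0
    skip_tokens = max(seq_len - sliding_window, 0)
    token_idx = np.arange(qk.shape[0])
    # Mirrors: mask = token_idx >= seq_len || token_idx < skip_tokens
    mask = (token_idx >= seq_len) | (token_idx < skip_tokens)
    return np.where(mask, 0.0, qk)  # masked logits are zeroed, as in the kernel

# Example: seq_len=10, sliding_window=4 keeps only the logits for tokens 6..9.
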
2 changes: 2 additions & 0 deletions csrc/ops.h
@@ -14,6 +14,7 @@ void paged_attention_v1(
int block_size,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
int sliding_window,
const std::string& kv_cache_dtype,
float kv_scale);

@@ -32,6 +33,7 @@ void paged_attention_v2(
int block_size,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
int sliding_window,
const std::string& kv_cache_dtype,
float kv_scale);

1 change: 1 addition & 0 deletions csrc/quantization/fp8/amd/quant_utils.cuh
@@ -543,6 +543,7 @@ __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
// data type of the key and value cache. The FN is a macro that calls a function
// with template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (sliding_window <= 0) sliding_window = INT32_MAX; \
if (KV_DTYPE == "auto") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
4 changes: 4 additions & 0 deletions tests/kernels/test_attention.py
@@ -175,6 +175,8 @@ def test_paged_attention(
# Using default kv_scale
kv_scale = 1.0

sliding_window = None

# Call the paged attention kernel.
output = torch.empty_like(query)
if version == "v1":
@@ -190,6 +192,7 @@ def test_paged_attention(
block_size,
max_seq_len,
alibi_slopes,
sliding_window,
kv_cache_dtype,
kv_scale,
)
@@ -221,6 +224,7 @@ def test_paged_attention(
block_size,
max_seq_len,
alibi_slopes,
sliding_window,
kv_cache_dtype,
kv_scale,
)
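
The test keeps sliding_window = None, so the reference implementation and the expected outputs are unchanged; the new argument is only exercised as a pass-through. A windowed reference check could reuse a mask like the one below (a hedged sketch; the helper name is an assumption, not code from this PR):

import torch

def sliding_window_keep_mask(seq_len: int, sliding_window: int,
                             num_context_slots: int) -> torch.Tensor:
    # True where a context position should contribute to the decode step,
    # i.e. it is within seq_len and inside the last `sliding_window` tokens.
    positions = torch.arange(num_context_slots)
    skip_tokens = max(seq_len - sliding_window, 0)
    return (positions >= skip_tokens) & (positions < seq_len)

# Example: sliding_window_keep_mask(10, 4, 16) keeps positions 6..9 only.
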
12 changes: 9 additions & 3 deletions vllm/_custom_ops.py
@@ -43,13 +43,16 @@ def paged_attention_v1(
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
sliding_window: Optional[int],
kv_cache_dtype: str,
kv_scale: float,
) -> None:
if sliding_window is None:
sliding_window = -1
vllm_ops.paged_attention_v1(out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, kv_scale)
sliding_window, kv_cache_dtype, kv_scale)


def paged_attention_v2(
@@ -67,14 +70,17 @@ def paged_attention_v2(
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
sliding_window: Optional[int],
kv_cache_dtype: str,
kv_scale: float,
) -> None:
if sliding_window is None:
sliding_window = -1
vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads, scale,
block_tables, seq_lens, block_size,
max_seq_len, alibi_slopes, kv_cache_dtype,
kv_scale)
max_seq_len, alibi_slopes, sliding_window,
kv_cache_dtype, kv_scale)


# pos encoding ops
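
Two sentinel conventions meet in these wrappers: Python's None becomes -1 here, and the DISPATCH_BY_KV_CACHE_DTYPE change shown earlier maps any non-positive value to INT32_MAX before the kernels launch, so an unset window can never produce a non-zero skip_tokens. A small sketch of the combined normalization (a standalone helper with a hypothetical name):

INT32_MAX = 2**31 - 1

def effective_window(sliding_window):
    """Mirror the Python (None -> -1) and C++ (<= 0 -> INT32_MAX) sentinel handling."""
    if sliding_window is None:   # Python wrapper convention
        sliding_window = -1
    if sliding_window <= 0:      # dispatch-macro convention
        sliding_window = INT32_MAX
    return sliding_window

# With INT32_MAX, seq_len > sliding_window never holds, so skip_tokens stays 0
# and the kernel behaves exactly as it did before this change.
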
1 change: 1 addition & 0 deletions vllm/attention/backends/flash_attn.py
@@ -268,6 +268,7 @@ def forward(
self.num_kv_heads,
self.scale,
self.alibi_slopes,
self.sliding_window[0],
kv_scale,
)

1 change: 1 addition & 0 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -327,6 +327,7 @@ def forward(
self.num_kv_heads,
self.scale,
self.alibi_slopes,
self.sliding_window[0],
kv_scale,
)

1 change: 1 addition & 0 deletions vllm/attention/backends/torch_sdpa.py
@@ -199,6 +199,7 @@ def forward(
self.num_kv_heads,
self.scale,
self.alibi_slopes,
self.sliding_window,
kv_scale,
)

1 change: 1 addition & 0 deletions vllm/attention/backends/xformers.py
@@ -266,6 +266,7 @@ def forward(
self.num_kv_heads,
self.scale,
self.alibi_slopes,
self.sliding_window,
kv_scale,
)

3 changes: 3 additions & 0 deletions vllm/attention/ops/paged_attn.py
@@ -90,6 +90,7 @@ def forward_decode(
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
sliding_window: Optional[int],
kv_scale: float,
) -> torch.Tensor:
output = torch.empty_like(query)
@@ -121,6 +122,7 @@ def forward_decode(
block_size,
max_seq_len,
alibi_slopes,
sliding_window,
kv_cache_dtype,
kv_scale,
)
@@ -153,6 +155,7 @@ def forward_decode(
block_size,
max_seq_len,
alibi_slopes,
sliding_window,
kv_cache_dtype,
kv_scale,
)
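
The backend call sites above differ in what they forward: flash_attn and rocm_flash_attn pass self.sliding_window[0] (their window is presumably stored as a pair), while torch_sdpa and xformers pass self.sliding_window directly. If that ever needs to be unified at the forward_decode boundary, a defensive normalizer could look like the sketch below (a hypothetical helper, not part of this PR):

from typing import Optional, Tuple, Union

def normalize_sliding_window(
        window: Union[None, int, Tuple[int, int]]) -> Optional[int]:
    # Accept None, a plain int, or a (left, right) pair and return the
    # single int the paged attention kernels expect (None = disabled).
    if window is None:
        return None
    if isinstance(window, tuple):
        left = window[0]
        return left if left >= 0 else None
    return window if window > 0 else None
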