diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py
index ca7967c1ab0d2..e8b7d9988d546 100644
--- a/benchmarks/kernels/benchmark_paged_attention.py
+++ b/benchmarks/kernels/benchmark_paged_attention.py
@@ -99,6 +99,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
         # Using default kv_scale
         kv_scale = 1.0
 
+        sliding_window = None
+
         for _ in range(num_iters):
             if version == "v1":
                 ops.paged_attention_v1(
@@ -113,6 +115,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
                     block_size,
                     max_seq_len,
                     alibi_slopes,
+                    sliding_window,
                     kv_cache_dtype,
                     kv_scale,
                 )
@@ -132,6 +135,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
                     block_size,
                     max_seq_len,
                     alibi_slopes,
+                    sliding_window,
                     kv_cache_dtype,
                     kv_scale,
                 )
diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
index 41b337dd91d36..e378b38a96416 100644
--- a/csrc/attention/attention_kernels.cu
+++ b/csrc/attention/attention_kernels.cu
@@ -103,6 +103,7 @@ __device__ void paged_attention_kernel(
   const int* __restrict__ seq_lens,       // [num_seqs]
   const int max_num_blocks_per_seq,
   const float* __restrict__ alibi_slopes, // [num_heads]
+  const int sliding_window,
   const int q_stride,
   const int kv_block_stride,
   const int kv_head_stride,
@@ -129,6 +130,7 @@ __device__ void paged_attention_kernel(
   const int start_token_idx = start_block_idx * BLOCK_SIZE;
   const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
   const int num_tokens = end_token_idx - start_token_idx;
+  const int skip_tokens = seq_len > sliding_window ? seq_len - sliding_window : 0;
 
   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
   constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
@@ -237,7 +239,7 @@ __device__ void paged_attention_kernel(
       if (thread_group_offset == 0) {
         // Store the partial reductions to shared memory.
         // NOTE(woosuk): It is required to zero out the masked logits.
-        const bool mask = token_idx >= seq_len;
+        const bool mask = token_idx >= seq_len || token_idx < skip_tokens;
         logits[token_idx - start_token_idx] = mask ? 0.f : qk;
         // Update the max value.
         qk_max = mask ? qk_max : fmaxf(qk_max, qk);
@@ -437,6 +439,7 @@ __global__ void paged_attention_v1_kernel(
   const int* __restrict__ seq_lens,       // [num_seqs]
   const int max_num_blocks_per_seq,
   const float* __restrict__ alibi_slopes, // [num_heads]
+  const int sliding_window,
   const int q_stride,
   const int kv_block_stride,
   const int kv_head_stride,
   const float kv_scale) {
   paged_attention_kernel(
     /* exp_sums */ nullptr, /* max_logits */ nullptr,
     out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens,
-    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale);
+    max_num_blocks_per_seq, alibi_slopes, sliding_window,
+    q_stride, kv_block_stride, kv_head_stride, kv_scale);
 }
 
 // Grid: (num_heads, num_seqs, max_num_partitions).
@@ -469,13 +473,14 @@ __global__ void paged_attention_v2_kernel(
   const int* __restrict__ seq_lens,       // [num_seqs]
   const int max_num_blocks_per_seq,
   const float* __restrict__ alibi_slopes, // [num_heads]
+  const int sliding_window,
   const int q_stride,
   const int kv_block_stride,
   const int kv_head_stride,
   const float kv_scale) {
   paged_attention_kernel(
     exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
-    block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes,
+    block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, sliding_window,
     q_stride, kv_block_stride, kv_head_stride, kv_scale);
 }
 
@@ -596,6 +601,7 @@ __global__ void paged_attention_v2_reduce_kernel(
     seq_lens_ptr, \
     max_num_blocks_per_seq, \
     alibi_slopes_ptr, \
+    sliding_window, \
     q_stride, \
     kv_block_stride, \
     kv_head_stride, \
@@ -619,6 +625,7 @@ void paged_attention_v1_launcher(
   torch::Tensor& seq_lens,
   int max_seq_len,
   const c10::optional<torch::Tensor>& alibi_slopes,
+  int sliding_window,
   float kv_scale) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
@@ -695,6 +702,7 @@ void paged_attention_v1_launcher(
     seq_lens, \
     max_seq_len, \
     alibi_slopes, \
+    sliding_window, \
     kv_scale);
 
 // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
@@ -727,6 +735,7 @@ void paged_attention_v1(
   int block_size,
   int max_seq_len,
   const c10::optional<torch::Tensor>& alibi_slopes,
+  int sliding_window,
   const std::string& kv_cache_dtype,
   float kv_scale) {
 
@@ -749,6 +758,7 @@
     seq_lens_ptr, \
     max_num_blocks_per_seq, \
     alibi_slopes_ptr, \
+    sliding_window, \
     q_stride, \
     kv_block_stride, \
     kv_head_stride, \
@@ -783,6 +793,7 @@ void paged_attention_v2_launcher(
   torch::Tensor& seq_lens,
   int max_seq_len,
   const c10::optional<torch::Tensor>& alibi_slopes,
+  int sliding_window,
   float kv_scale) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
@@ -868,6 +879,7 @@ void paged_attention_v2_launcher(
     seq_lens, \
     max_seq_len, \
     alibi_slopes, \
+    sliding_window, \
     kv_scale);
 
 // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
@@ -903,6 +915,7 @@ void paged_attention_v2(
   int block_size,
   int max_seq_len,
   const c10::optional<torch::Tensor>& alibi_slopes,
+  int sliding_window,
   const std::string& kv_cache_dtype,
   float kv_scale) {
   DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE)
diff --git a/csrc/ops.h b/csrc/ops.h
index 9541adcb3de88..1f988acd132b9 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -14,6 +14,7 @@ void paged_attention_v1(
   int block_size,
   int max_seq_len,
   const c10::optional<torch::Tensor>& alibi_slopes,
+  int sliding_window,
   const std::string& kv_cache_dtype,
   float kv_scale);
 
@@ -32,6 +33,7 @@ void paged_attention_v2(
   int block_size,
   int max_seq_len,
   const c10::optional<torch::Tensor>& alibi_slopes,
+  int sliding_window,
   const std::string& kv_cache_dtype,
   float kv_scale);
 
diff --git a/csrc/quantization/fp8/amd/quant_utils.cuh b/csrc/quantization/fp8/amd/quant_utils.cuh
index df0329f79d361..783438beb9d35 100644
--- a/csrc/quantization/fp8/amd/quant_utils.cuh
+++ b/csrc/quantization/fp8/amd/quant_utils.cuh
@@ -543,6 +543,7 @@ __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
 // data type of the key and value cache. The FN is a macro that calls a function
 // with template.
 #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
+  if (sliding_window <= 0) sliding_window = INT32_MAX; \
   if (KV_DTYPE == "auto") { \
     if (SRC_DTYPE == at::ScalarType::Float) { \
       FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index 28496f187d466..e42c934a77c7d 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -175,6 +175,8 @@ def test_paged_attention(
     # Using default kv_scale
     kv_scale = 1.0
 
+    sliding_window = None
+
     # Call the paged attention kernel.
     output = torch.empty_like(query)
     if version == "v1":
@@ -190,6 +192,7 @@ def test_paged_attention(
             block_size,
             max_seq_len,
             alibi_slopes,
+            sliding_window,
             kv_cache_dtype,
             kv_scale,
         )
@@ -221,6 +224,7 @@ def test_paged_attention(
            block_size,
            max_seq_len,
            alibi_slopes,
+           sliding_window,
            kv_cache_dtype,
            kv_scale,
        )
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 42dedfdf76c4f..49fa653066feb 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -43,13 +43,16 @@ def paged_attention_v1(
     block_size: int,
     max_seq_len: int,
     alibi_slopes: Optional[torch.Tensor],
+    sliding_window: Optional[int],
     kv_cache_dtype: str,
     kv_scale: float,
 ) -> None:
+    if sliding_window is None:
+        sliding_window = -1
     vllm_ops.paged_attention_v1(out, query, key_cache, value_cache,
                                 num_kv_heads, scale, block_tables, seq_lens,
                                 block_size, max_seq_len, alibi_slopes,
-                                kv_cache_dtype, kv_scale)
+                                sliding_window, kv_cache_dtype, kv_scale)
 
 
 def paged_attention_v2(
@@ -67,14 +70,17 @@ def paged_attention_v2(
     block_size: int,
     max_seq_len: int,
     alibi_slopes: Optional[torch.Tensor],
+    sliding_window: Optional[int],
     kv_cache_dtype: str,
     kv_scale: float,
 ) -> None:
+    if sliding_window is None:
+        sliding_window = -1
     vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query,
                                 key_cache, value_cache, num_kv_heads, scale,
                                 block_tables, seq_lens, block_size,
-                                max_seq_len, alibi_slopes, kv_cache_dtype,
-                                kv_scale)
+                                max_seq_len, alibi_slopes, sliding_window,
+                                kv_cache_dtype, kv_scale)
 
 
 # pos encoding ops
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 4bad226512b69..1ef0022d9e44e 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -268,6 +268,7 @@ def forward(
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
+                self.sliding_window[0],
                 kv_scale,
             )
 
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 8fc1af1aa1e1c..07297653eb2e5 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -327,6 +327,7 @@ def forward(
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
+                self.sliding_window[0],
                 kv_scale,
             )
 
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index c29218dfd0cfc..3ea9f9670ce9f 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -199,6 +199,7 @@ def forward(
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
+                self.sliding_window,
                 kv_scale,
             )
 
diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py
index 2a9150dea5875..e92523dd47101 100644
--- a/vllm/attention/backends/xformers.py
+++ b/vllm/attention/backends/xformers.py
@@ -266,6 +266,7 @@ def forward(
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
+                self.sliding_window,
                 kv_scale,
             )
 
diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py
index 3c010b67b3120..e6618f854c7fe 100644
--- a/vllm/attention/ops/paged_attn.py
+++ b/vllm/attention/ops/paged_attn.py
@@ -90,6 +90,7 @@ def forward_decode(
         num_kv_heads: int,
         scale: float,
         alibi_slopes: Optional[torch.Tensor],
+        sliding_window: Optional[int],
         kv_scale: float,
     ) -> torch.Tensor:
         output = torch.empty_like(query)
@@ -121,6 +122,7 @@ def forward_decode(
                 block_size,
                 max_seq_len,
                 alibi_slopes,
+                sliding_window,
                 kv_cache_dtype,
                 kv_scale,
             )
@@ -153,6 +155,7 @@ def forward_decode(
                 block_size,
                 max_seq_len,
                 alibi_slopes,
+                sliding_window,
                 kv_cache_dtype,
                 kv_scale,
             )
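For reference, a minimal PyTorch sketch (not part of the diff) of the masking rule the kernel change implements: with a positive `sliding_window`, tokens older than `seq_len - sliding_window` are excluded from the decode-attention softmax, while `None` or a non-positive value means "no limit" (the Python wrapper maps `None` to `-1`, and the dispatch macro maps non-positive values to `INT32_MAX`). The function name, tensor shapes, and the `-inf` masking below are illustrative only; the kernel instead zeroes the masked logits and skips them in the max/sum reduction.

```python
from typing import Optional

import torch


def ref_sliding_window_decode(
    query: torch.Tensor,   # [num_heads, head_size]
    keys: torch.Tensor,    # [seq_len, num_heads, head_size]
    values: torch.Tensor,  # [seq_len, num_heads, head_size]
    scale: float,
    sliding_window: Optional[int],
) -> torch.Tensor:
    seq_len = keys.shape[0]
    # None or a non-positive window means "no sliding window", mirroring the
    # None -> -1 -> INT32_MAX plumbing in the diff.
    if sliding_window is None or sliding_window <= 0:
        sliding_window = seq_len
    # Same arithmetic as the kernel's skip_tokens computation.
    skip_tokens = max(seq_len - sliding_window, 0)

    logits = scale * torch.einsum("hd,shd->hs", query.float(), keys.float())
    # Mask out tokens older than the window before the softmax; the kernel
    # achieves the same effect via the `token_idx < skip_tokens` check.
    logits[:, :skip_tokens] = float("-inf")
    probs = torch.softmax(logits, dim=-1)
    return torch.einsum("hs,shd->hd", probs, values.float())
```

Such a reference could be compared against the output of `ops.paged_attention_v1(..., sliding_window, ...)` in `tests/kernels/test_attention.py`, though the test in this diff still passes `sliding_window = None`.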