diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 737559bfe70ca..e6ddca69bf01b 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -1,8 +1,8 @@
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from dataclasses import dataclass, fields
-from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
-                    Tuple, Type, TypeVar)
+from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
+                    Protocol, Set, Tuple, Type, TypeVar)
 
 import torch
 
@@ -223,6 +223,22 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         raise NotImplementedError
 
 
+class AttentionLayer(Protocol):
+
+    _k_scale: float
+    _v_scale: float
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        ...
+
+
 class AttentionImpl(ABC, Generic[T]):
 
     @abstractmethod
@@ -244,13 +260,12 @@ def __init__(
     @abstractmethod
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: T,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py
index 77cfa8490172b..9089db1126c94 100644
--- a/vllm/attention/backends/blocksparse_attn.py
+++ b/vllm/attention/backends/blocksparse_attn.py
@@ -4,6 +4,7 @@
 import torch
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import (CommonAttentionState,
                                            CommonMetadataBuilder)
@@ -358,13 +359,12 @@ def __init__(
 
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: BlocksparseFlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
@@ -401,8 +401,8 @@ def forward(
                 value_cache,
                 attn_metadata.slot_mapping,
                 self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         if prefill_meta := attn_metadata.prefill_metadata:
@@ -439,8 +439,8 @@ def forward(
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
                 tp_rank=self.tp_rank,
                 blocksparse_local_blocks=self.local_blocks,
                 blocksparse_vert_stride=self.vert_stride,
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 48b3e8d177ec9..40250ef08b595 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -8,6 +8,7 @@
 
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
                                               AttentionType)
@@ -634,13 +635,12 @@ def __init__(
 
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
@@ -657,7 +657,7 @@ def forward(
             NOTE: It in-place updates the output tensor.
""" # NOTE(woosuk): FlashAttention does not support FP8 KV cache. - assert k_scale == 1.0 and v_scale == 1.0, ( + assert layer._k_scale == 1.0 and layer._v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") assert output is not None, "Output tensor must be provided." @@ -709,8 +709,8 @@ def forward( kv_cache[1], updated_slot_mapping.flatten(), # type: ignore[union-attr] kv_cache_dtype, - k_scale, - v_scale, + layer._k_scale, + layer._v_scale, ) (num_prefill_query_tokens, num_prefill_kv_tokens, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index d51abb29f929e..73ba542fba6dd 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -27,6 +27,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, AttentionMetadata, AttentionMetadataBuilder, AttentionState, AttentionType) @@ -910,13 +911,12 @@ def __init__( def forward( self, + layer: AttentionLayer, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashInferMetadata, - k_scale: float = 1.0, - v_scale: float = 1.0, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -944,8 +944,8 @@ def forward( kv_cache[:, 1], attn_metadata.slot_mapping.flatten(), kv_cache_dtype, - k_scale, - v_scale, + layer._k_scale, + layer._v_scale, ) # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 # to process the cache when the kv_cache_dtype is fp8 @@ -1009,8 +1009,8 @@ def forward( prefill_output = prefill_meta.prefill_wrapper.run( query, kv_cache, - k_scale=k_scale, - v_scale=v_scale, + k_scale=layer._k_scale, + v_scale=layer._v_scale, ) if decode_meta := attn_metadata.decode_metadata: assert decode_meta is not None @@ -1024,8 +1024,8 @@ def forward( decode_output = decode_meta.decode_wrapper.run( decode_query, kv_cache, - k_scale=k_scale, - v_scale=v_scale, + k_scale=layer._k_scale, + v_scale=layer._v_scale, ) if prefill_output is None and decode_output is not None: diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 94a461e0c8c29..80c132c0a8c05 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -11,6 +11,7 @@ from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, AttentionMetadata, AttentionType) from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention, @@ -152,13 +153,12 @@ def __init__( def forward( self, + layer: AttentionLayer, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: HPUAttentionMetadata, - k_scale: float = 1.0, - v_scale: float = 1.0, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. 
diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py
index da1d307daa517..cd729a1c8b274 100644
--- a/vllm/attention/backends/ipex_attn.py
+++ b/vllm/attention/backends/ipex_attn.py
@@ -7,6 +7,7 @@
 
 from vllm._ipex_ops import ipex_ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.paged_attn import (PagedAttention,
@@ -171,13 +172,12 @@ def split_kv_cache(
 
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: IpexAttnMetadata,  # type: ignore
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with IPEX varlen_attention and PagedAttention.
@@ -193,7 +193,7 @@ def forward(
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        assert k_scale == 1.0 and v_scale == 1.0
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
@@ -210,8 +210,8 @@ def forward(
                 value_cache,
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         if attn_metadata.is_prompt:
@@ -296,8 +296,8 @@ def forward(
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
             else:
                 # Run PagedAttention V2.
@@ -329,8 +329,8 @@ def forward(
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
 
         # Reshape the output tensor.
diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py
index 2ac492dd8ae54..f5bf390df6afb 100644
--- a/vllm/attention/backends/pallas.py
+++ b/vllm/attention/backends/pallas.py
@@ -5,6 +5,7 @@
 import torch_xla.experimental.custom_kernel  # Required to register custom ops.
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 
@@ -150,13 +151,12 @@ def __init__(
 
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: Tuple[torch.Tensor, torch.Tensor],
         attn_metadata: PallasMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with Pallas attention.
@@ -173,7 +173,7 @@ def forward(
         Returns:
             shape = [batch_size, seq_len, num_heads * head_size]
         """
-        assert k_scale == 1.0 and v_scale == 1.0
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
         batch_size, seq_len, hidden_size = query.shape
         query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
         key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index a91a5af5c3d58..e9f2808ff1674 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -7,6 +7,7 @@
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import (CommonAttentionState,
                                            CommonMetadataBuilder)
@@ -414,13 +415,12 @@ def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
 
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: ROCmFlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
@@ -458,8 +458,8 @@ def forward(
                 value_cache,
                 attn_metadata.slot_mapping,
                 self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         num_prefill_tokens = attn_metadata.num_prefill_tokens
@@ -567,8 +567,8 @@ def forward(
                     prefill_meta.max_query_len,
                     self.alibi_slopes,
                     self.sliding_window[0],
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
 
         if decode_meta := attn_metadata.decode_metadata:
@@ -613,8 +613,8 @@ def forward(
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
             else:
                 output[num_prefill_tokens:] = PagedAttention.forward_decode(
@@ -628,8 +628,8 @@ def forward(
                     self.num_kv_heads,
                     self.scale,
                     self.alibi_slopes,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
 
         # Reshape the output tensor.
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index ca1c4618615de..7cd2049f0c0a5 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -7,6 +7,7 @@
 from torch.nn.functional import scaled_dot_product_attention
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
                                               AttentionType)
@@ -429,13 +430,12 @@ def __init__(
 
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: TorchSDPAMetadata,  # type: ignore
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
@@ -451,7 +451,7 @@ def forward(
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        assert k_scale == 1.0 and v_scale == 1.0
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
         attn_type = self.attn_type
         if (attn_type == AttentionType.ENCODER
                 and (not attn_metadata.is_all_encoder_attn_metadata_set)):
@@ -493,11 +493,9 @@ def forward(
 
                 # Update self-attention KV cache (prefill/decode)
                 updated_slot_mapping = attn_metadata.slot_mapping
-                PagedAttention.write_to_paged_cache(key, value, key_cache,
-                                                    value_cache,
-                                                    updated_slot_mapping,
-                                                    self.kv_cache_dtype,
-                                                    k_scale, v_scale)
+                PagedAttention.write_to_paged_cache(
+                    key, value, key_cache, value_cache, updated_slot_mapping,
+                    self.kv_cache_dtype, layer._k_scale, layer._v_scale)
 
         if attn_type != AttentionType.ENCODER:
             # Decoder self-attention supports chunked prefill.
@@ -571,8 +569,8 @@ def forward(
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         # Reshape the output tensor.
diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py
index 8c8ca8520a9db..38e27434dab2c 100644
--- a/vllm/attention/backends/xformers.py
+++ b/vllm/attention/backends/xformers.py
@@ -10,6 +10,7 @@
                                          LowerTriangularMaskWithTensorBias)
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import (
     CommonAttentionState, CommonMetadataBuilder,
@@ -412,13 +413,12 @@ def __init__(
 
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: Optional[torch.Tensor],
         value: Optional[torch.Tensor],
         kv_cache: torch.Tensor,
         attn_metadata: "XFormersMetadata",
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.
@@ -524,11 +524,9 @@ def forward(
             # If kv_cache is not provided, the new key and value tensors are
             # not cached. This happens during the initial memory
             # profiling run.
-            PagedAttention.write_to_paged_cache(key, value, key_cache,
-                                                value_cache,
-                                                updated_slot_mapping,
-                                                self.kv_cache_dtype,
-                                                k_scale, v_scale)
+            PagedAttention.write_to_paged_cache(
+                key, value, key_cache, value_cache, updated_slot_mapping,
+                self.kv_cache_dtype, layer._k_scale, layer._v_scale)
         (num_prefill_query_tokens, num_prefill_kv_tokens,
          num_decode_query_tokens) = \
             get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
@@ -580,8 +578,8 @@ def forward(
                 prefill_meta.max_query_len,
                 self.alibi_slopes,
                 self.sliding_window,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
             assert output[:num_prefill_query_tokens].shape == out.shape
             output[:num_prefill_query_tokens] = out
@@ -607,8 +605,8 @@ def forward(
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         # Reshape the output tensor.
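
Every backend above gets the same mechanical treatment: import `AttentionLayer`, accept `layer` as the new first parameter of `forward`, drop the two float parameters, and replace each `k_scale`/`v_scale` read with `layer._k_scale`/`layer._v_scale`. An out-of-tree backend would follow suit; a sketch under that assumption (`MyImpl` is a hypothetical name, and the actual kernel dispatch is elided):

```python
# Hypothetical out-of-tree backend adapted to the new interface.
from typing import Optional

import torch

from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
                                              AttentionMetadata)


class MyImpl(AttentionImpl):

    def forward(
        self,
        layer: AttentionLayer,  # new leading argument
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Scales are read from the layer rather than passed per call.
        k_scale, v_scale = layer._k_scale, layer._v_scale
        raise NotImplementedError("kernel dispatch elided in this sketch")
```
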
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index e2403306950a3..c36f8d08eb4a7 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -243,8 +243,7 @@ def unified_attention(
     attn_metadata = forward_context.attn_metadata
     self = forward_context.attn_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
-    return self.impl.forward(query, key, value, kv_cache, attn_metadata,
-                             self._k_scale, self._v_scale)
+    return self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
 
 
 def unified_attention_fake(
@@ -276,13 +275,12 @@ def unified_attention_with_output(
     attn_metadata = forward_context.attn_metadata
     self = forward_context.attn_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
-    self.impl.forward(query,
+    self.impl.forward(self,
+                      query,
                       key,
                       value,
                       kv_cache,
                       attn_metadata,
-                      self._k_scale,
-                      self._v_scale,
                       output=output)
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 7b0786261a6a6..fd36ea8d8806b 100644
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -130,13 +130,12 @@ def __init__(
 
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
@@ -151,7 +150,7 @@ def forward(
             shape = [num_tokens, num_heads * head_size]
         """
         # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
-        assert k_scale == 1.0 and v_scale == 1.0, (
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")
 
         assert output is not None, "Output tensor must be provided."
@@ -183,8 +182,8 @@ def forward(
             value_cache,
             attn_metadata.slot_mapping,
             self.kv_cache_dtype,
-            k_scale,
-            v_scale,
+            layer._k_scale,
+            layer._v_scale,
         )
 
         # Compute attention and update output up to `num_actual_tokens`.
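
On the caller side, `unified_attention` and `unified_attention_with_output` now simply hand `self` (the `Attention` module resolved from `layer_name` via the forward context) to the impl, so the scale plumbing drops out of both call sites. Note that the v1 backend annotates `layer` as `torch.nn.Module` rather than the new Protocol; since the impls only touch `_k_scale` and `_v_scale`, either annotation works structurally. For exercising an impl in isolation, a bare stand-in object suffices; a sketch (illustrative only, `SimpleNamespace` is not used anywhere in the diff):

```python
# Illustrative: a minimal stand-in for the layer argument in a unit test.
from types import SimpleNamespace

fake_layer = SimpleNamespace(_k_scale=1.0, _v_scale=1.0)

# Before this diff:
#   impl.forward(query, key, value, kv_cache, attn_metadata,
#                k_scale=1.0, v_scale=1.0)
# After this diff:
#   out = impl.forward(fake_layer, query, key, value, kv_cache, attn_metadata)
```
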