From 6a5a4e50fc9b6deb5b9878e833cb3a6e5ba77154 Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Mon, 20 Jan 2025 17:43:21 +0800 Subject: [PATCH 1/2] [Misc] Pass `attention` to impl backend Signed-off-by: wangxiyuan --- vllm/attention/backends/abstract.py | 3 +-- vllm/attention/backends/blocksparse_attn.py | 11 +++++------ vllm/attention/backends/flash_attn.py | 9 ++++----- vllm/attention/backends/flashinfer.py | 15 +++++++-------- vllm/attention/backends/hpu_attn.py | 3 +-- vllm/attention/backends/ipex_attn.py | 17 ++++++++--------- vllm/attention/backends/pallas.py | 5 ++--- vllm/attention/backends/rocm_flash_attn.py | 19 +++++++++---------- vllm/attention/backends/torch_sdpa.py | 17 +++++++---------- vllm/attention/backends/xformers.py | 19 ++++++++----------- vllm/attention/layer.py | 8 +++----- vllm/v1/attention/backends/flash_attn.py | 9 ++++----- 12 files changed, 59 insertions(+), 76 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 737559bfe70ca..8ce50751d27f9 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -244,13 +244,12 @@ def __init__( @abstractmethod def forward( self, + layer: torch.nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: T, - k_scale: float = 1.0, - v_scale: float = 1.0, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 77cfa8490172b..603de740a6e7d 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -358,13 +358,12 @@ def __init__( def forward( self, + layer: torch.nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: BlocksparseFlashAttentionMetadata, - k_scale: float = 1.0, - v_scale: float = 1.0, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -401,8 +400,8 @@ def forward( value_cache, attn_metadata.slot_mapping, self.kv_cache_dtype, - k_scale, - v_scale, + layer._k_scale, + layer._v_scale, ) if prefill_meta := attn_metadata.prefill_metadata: @@ -439,8 +438,8 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - k_scale, - v_scale, + layer._k_scale, + layer._v_scale, tp_rank=self.tp_rank, blocksparse_local_blocks=self.local_blocks, blocksparse_vert_stride=self.vert_stride, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 48b3e8d177ec9..3d566c3c6c5b7 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -634,13 +634,12 @@ def __init__( def forward( self, + layer: torch.nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, - k_scale: float = 1.0, - v_scale: float = 1.0, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -657,7 +656,7 @@ def forward( NOTE: It in-place updates the output tensor. """ # NOTE(woosuk): FlashAttention does not support FP8 KV cache. - assert k_scale == 1.0 and v_scale == 1.0, ( + assert layer._k_scale == 1.0 and layer._v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") assert output is not None, "Output tensor must be provided." 
@@ -709,8 +708,8 @@ def forward( kv_cache[1], updated_slot_mapping.flatten(), # type: ignore[union-attr] kv_cache_dtype, - k_scale, - v_scale, + layer._k_scale, + layer._v_scale, ) (num_prefill_query_tokens, num_prefill_kv_tokens, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 6ca75fabdfc38..03e0ae4487639 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -792,13 +792,12 @@ def __init__( def forward( self, + layer: torch.nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashInferMetadata, - k_scale: float = 1.0, - v_scale: float = 1.0, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -826,8 +825,8 @@ def forward( kv_cache[:, 1], attn_metadata.slot_mapping.flatten(), kv_cache_dtype, - k_scale, - v_scale, + layer._k_scale, + layer._v_scale, ) # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 # to process the cache when the kv_cache_dtype is fp8 @@ -886,8 +885,8 @@ def forward( kv_cache, logits_soft_cap=logits_soft_cap, causal=True, - k_scale=k_scale, - v_scale=v_scale, + k_scale=layer._k_scale, + v_scale=layer._v_scale, window_left=window_left) if decode_meta := attn_metadata.decode_metadata: assert decode_meta is not None @@ -897,8 +896,8 @@ def forward( kv_cache, sm_scale=softmax_scale, logits_soft_cap=logits_soft_cap, - k_scale=k_scale, - v_scale=v_scale, + k_scale=layer._k_scale, + v_scale=layer._v_scale, window_left=window_left) if prefill_output is None and decode_output is not None: diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 94a461e0c8c29..ef35a324d82c5 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -152,13 +152,12 @@ def __init__( def forward( self, + layer: torch.nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: HPUAttentionMetadata, - k_scale: float = 1.0, - v_scale: float = 1.0, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index da1d307daa517..fdd419c3dcf8b 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -171,13 +171,12 @@ def split_kv_cache( def forward( self, + layer: torch.nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: IpexAttnMetadata, # type: ignore - k_scale: float = 1.0, - v_scale: float = 1.0, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with IPEX varlen_attention and PagedAttention. @@ -193,7 +192,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - assert k_scale == 1.0 and v_scale == 1.0 + assert layer._k_scale == 1.0 and layer._v_scale == 1.0 num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) @@ -210,8 +209,8 @@ def forward( value_cache, attn_metadata.slot_mapping.flatten(), self.kv_cache_dtype, - k_scale, - v_scale, + layer._k_scale, + layer._v_scale, ) if attn_metadata.is_prompt: @@ -296,8 +295,8 @@ def forward( max_seq_len, self.alibi_slopes, self.kv_cache_dtype, - k_scale, - v_scale, + layer._k_scale, + layer._v_scale, ) else: # Run PagedAttention V2. 
@@ -329,8 +328,8 @@ def forward( max_seq_len, self.alibi_slopes, self.kv_cache_dtype, - k_scale, - v_scale, + layer._k_scale, + layer._v_scale, ) # Reshape the output tensor. diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 2ac492dd8ae54..bfca3c06c4942 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -150,13 +150,12 @@ def __init__( def forward( self, + layer: torch.nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: Tuple[torch.Tensor, torch.Tensor], attn_metadata: PallasMetadata, - k_scale: float = 1.0, - v_scale: float = 1.0, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with Pallas attention. @@ -173,7 +172,7 @@ def forward( Returns: shape = [batch_size, seq_len, num_heads * head_size] """ - assert k_scale == 1.0 and v_scale == 1.0 + assert layer._k_scale == 1.0 and layer._v_scale == 1.0 batch_size, seq_len, hidden_size = query.shape query = query.view(batch_size, seq_len, self.num_heads, self.head_size) key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index a91a5af5c3d58..4c62b3aa25017 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -414,13 +414,12 @@ def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: def forward( self, + layer: torch.nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: ROCmFlashAttentionMetadata, - k_scale: float = 1.0, - v_scale: float = 1.0, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -458,8 +457,8 @@ def forward( value_cache, attn_metadata.slot_mapping, self.kv_cache_dtype, - k_scale, - v_scale, + layer._k_scale, + layer._v_scale, ) num_prefill_tokens = attn_metadata.num_prefill_tokens @@ -567,8 +566,8 @@ def forward( prefill_meta.max_query_len, self.alibi_slopes, self.sliding_window[0], - k_scale, - v_scale, + layer._k_scale, + layer._v_scale, ) if decode_meta := attn_metadata.decode_metadata: @@ -613,8 +612,8 @@ def forward( max_seq_len, self.alibi_slopes, self.kv_cache_dtype, - k_scale, - v_scale, + layer._k_scale, + layer._v_scale, ) else: output[num_prefill_tokens:] = PagedAttention.forward_decode( @@ -628,8 +627,8 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - k_scale, - v_scale, + layer._k_scale, + layer._v_scale, ) # Reshape the output tensor. diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index ca1c4618615de..38fafd990e941 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -429,13 +429,12 @@ def __init__( def forward( self, + layer: torch.nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: TorchSDPAMetadata, # type: ignore - k_scale: float = 1.0, - v_scale: float = 1.0, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. 
@@ -451,7 +450,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - assert k_scale == 1.0 and v_scale == 1.0 + assert layer._k_scale == 1.0 and layer._v_scale == 1.0 attn_type = self.attn_type if (attn_type == AttentionType.ENCODER and (not attn_metadata.is_all_encoder_attn_metadata_set)): @@ -493,11 +492,9 @@ def forward( # Update self-attention KV cache (prefill/decode) updated_slot_mapping = attn_metadata.slot_mapping - PagedAttention.write_to_paged_cache(key, value, key_cache, - value_cache, - updated_slot_mapping, - self.kv_cache_dtype, - k_scale, v_scale) + PagedAttention.write_to_paged_cache( + key, value, key_cache, value_cache, updated_slot_mapping, + self.kv_cache_dtype, layer._k_scale, layer._v_scale) if attn_type != AttentionType.ENCODER: # Decoder self-attention supports chunked prefill. @@ -571,8 +568,8 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - k_scale, - v_scale, + layer._k_scale, + layer._v_scale, ) # Reshape the output tensor. diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 8c8ca8520a9db..a1cd2aefe015a 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -412,13 +412,12 @@ def __init__( def forward( self, + layer: torch.nn.Module, query: torch.Tensor, key: Optional[torch.Tensor], value: Optional[torch.Tensor], kv_cache: torch.Tensor, attn_metadata: "XFormersMetadata", - k_scale: float = 1.0, - v_scale: float = 1.0, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -524,11 +523,9 @@ def forward( # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory # profiling run. - PagedAttention.write_to_paged_cache(key, value, key_cache, - value_cache, - updated_slot_mapping, - self.kv_cache_dtype, - k_scale, v_scale) + PagedAttention.write_to_paged_cache( + key, value, key_cache, value_cache, updated_slot_mapping, + self.kv_cache_dtype, layer._k_scale, layer._v_scale) (num_prefill_query_tokens, num_prefill_kv_tokens, num_decode_query_tokens) = \ get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) @@ -580,8 +577,8 @@ def forward( prefill_meta.max_query_len, self.alibi_slopes, self.sliding_window, - k_scale, - v_scale, + layer._k_scale, + layer._v_scale, ) assert output[:num_prefill_query_tokens].shape == out.shape output[:num_prefill_query_tokens] = out @@ -607,8 +604,8 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - k_scale, - v_scale, + layer._k_scale, + layer._v_scale, ) # Reshape the output tensor. 
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index e2403306950a3..c36f8d08eb4a7 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -243,8 +243,7 @@ def unified_attention( attn_metadata = forward_context.attn_metadata self = forward_context.attn_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] - return self.impl.forward(query, key, value, kv_cache, attn_metadata, - self._k_scale, self._v_scale) + return self.impl.forward(self, query, key, value, kv_cache, attn_metadata) def unified_attention_fake( @@ -276,13 +275,12 @@ def unified_attention_with_output( attn_metadata = forward_context.attn_metadata self = forward_context.attn_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] - self.impl.forward(query, + self.impl.forward(self, + query, key, value, kv_cache, attn_metadata, - self._k_scale, - self._v_scale, output=output) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 7b0786261a6a6..fd36ea8d8806b 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -130,13 +130,12 @@ def __init__( def forward( self, + layer: torch.nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, - k_scale: float = 1.0, - v_scale: float = 1.0, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -151,7 +150,7 @@ def forward( shape = [num_tokens, num_heads * head_size] """ # NOTE(woosuk): FlashAttention does not support FP8 KV cache. - assert k_scale == 1.0 and v_scale == 1.0, ( + assert layer._k_scale == 1.0 and layer._v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") assert output is not None, "Output tensor must be provided." @@ -183,8 +182,8 @@ def forward( value_cache, attn_metadata.slot_mapping, self.kv_cache_dtype, - k_scale, - v_scale, + layer._k_scale, + layer._v_scale, ) # Compute attention and update output up to `num_actual_tokens`. 
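Note on patch 1: the net effect of the hunks above is that every backend's AttentionImpl.forward now receives the owning attention layer as its first argument and reads the per-layer quantization scales from layer._k_scale and layer._v_scale, while the unified_attention entry points in vllm/attention/layer.py pass `self` through instead of forwarding the scale floats. Below is a minimal, self-contained sketch of that calling convention; it is not vLLM code, and ToyAttentionLayer / ToyAttentionImpl are illustrative names only.

    import torch


    class ToyAttentionLayer(torch.nn.Module):
        """Stands in for vllm.attention.layer.Attention: it owns the KV scales."""

        def __init__(self) -> None:
            super().__init__()
            self._k_scale: float = 1.0
            self._v_scale: float = 1.0


    class ToyAttentionImpl:
        """Stands in for a backend AttentionImpl after this patch."""

        def forward(
            self,
            layer: torch.nn.Module,  # new leading argument (typed AttentionLayer in patch 2)
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
        ) -> torch.Tensor:
            # Scales are read from the layer instead of arriving as float arguments.
            k_scale, v_scale = layer._k_scale, layer._v_scale
            # Mirrors the FP8 checks in the hunks above: this toy path only
            # handles unit scales.
            assert k_scale == 1.0 and v_scale == 1.0
            scores = query @ key.transpose(-1, -2) * query.shape[-1] ** -0.5
            return torch.softmax(scores, dim=-1) @ value


    layer = ToyAttentionLayer()
    impl = ToyAttentionImpl()
    q = k = v = torch.randn(2, 4, 8)
    # The call site passes the layer object, as unified_attention now does.
    out = impl.forward(layer, q, k, v)

Patch 2 below formalizes the type of the new `layer` argument with a Protocol.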
From 1cb9ccdb39d73f26039f81431666cb1ae6cf231f Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Mon, 20 Jan 2025 20:38:06 +0800 Subject: [PATCH 2/2] Add Attention Interface Signed-off-by: wangxiyuan --- vllm/attention/backends/abstract.py | 22 ++++++++++++++++++--- vllm/attention/backends/blocksparse_attn.py | 3 ++- vllm/attention/backends/flash_attn.py | 3 ++- vllm/attention/backends/flashinfer.py | 3 ++- vllm/attention/backends/hpu_attn.py | 3 ++- vllm/attention/backends/ipex_attn.py | 3 ++- vllm/attention/backends/pallas.py | 3 ++- vllm/attention/backends/rocm_flash_attn.py | 3 ++- vllm/attention/backends/torch_sdpa.py | 3 ++- vllm/attention/backends/xformers.py | 3 ++- 10 files changed, 37 insertions(+), 12 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 8ce50751d27f9..e6ddca69bf01b 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod from contextlib import contextmanager from dataclasses import dataclass, fields -from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, - Tuple, Type, TypeVar) +from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, + Protocol, Set, Tuple, Type, TypeVar) import torch @@ -223,6 +223,22 @@ def build(self, seq_lens: List[int], query_lens: List[int], raise NotImplementedError +class AttentionLayer(Protocol): + + _k_scale: float + _v_scale: float + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + ... + + class AttentionImpl(ABC, Generic[T]): @abstractmethod @@ -244,7 +260,7 @@ def __init__( @abstractmethod def forward( self, - layer: torch.nn.Module, + layer: AttentionLayer, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 603de740a6e7d..9089db1126c94 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -4,6 +4,7 @@ import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, AttentionMetadata, AttentionType) from vllm.attention.backends.utils import (CommonAttentionState, CommonMetadataBuilder) @@ -358,7 +359,7 @@ def __init__( def forward( self, - layer: torch.nn.Module, + layer: AttentionLayer, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 3d566c3c6c5b7..40250ef08b595 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -8,6 +8,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, AttentionMetadata, AttentionMetadataBuilder, AttentionType) @@ -634,7 +635,7 @@ def __init__( def forward( self, - layer: torch.nn.Module, + layer: AttentionLayer, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 03e0ae4487639..b9cd805e81b45 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -23,6 +23,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, 
AttentionMetadata, AttentionMetadataBuilder, AttentionState, AttentionType) @@ -792,7 +793,7 @@ def __init__( def forward( self, - layer: torch.nn.Module, + layer: AttentionLayer, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index ef35a324d82c5..80c132c0a8c05 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -11,6 +11,7 @@ from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, AttentionMetadata, AttentionType) from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention, @@ -152,7 +153,7 @@ def __init__( def forward( self, - layer: torch.nn.Module, + layer: AttentionLayer, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index fdd419c3dcf8b..cd729a1c8b274 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -7,6 +7,7 @@ from vllm._ipex_ops import ipex_ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, AttentionMetadata, AttentionType) from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.ops.paged_attn import (PagedAttention, @@ -171,7 +172,7 @@ def split_kv_cache( def forward( self, - layer: torch.nn.Module, + layer: AttentionLayer, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index bfca3c06c4942..f5bf390df6afb 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -5,6 +5,7 @@ import torch_xla.experimental.custom_kernel # Required to register custom ops. 
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, AttentionMetadata, AttentionType) from vllm.attention.backends.utils import CommonAttentionState @@ -150,7 +151,7 @@ def __init__( def forward( self, - layer: torch.nn.Module, + layer: AttentionLayer, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 4c62b3aa25017..e9f2808ff1674 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -7,6 +7,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, AttentionMetadata, AttentionType) from vllm.attention.backends.utils import (CommonAttentionState, CommonMetadataBuilder) @@ -414,7 +415,7 @@ def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: def forward( self, - layer: torch.nn.Module, + layer: AttentionLayer, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 38fafd990e941..7cd2049f0c0a5 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -7,6 +7,7 @@ from torch.nn.functional import scaled_dot_product_attention from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, AttentionMetadata, AttentionMetadataBuilder, AttentionType) @@ -429,7 +430,7 @@ def __init__( def forward( self, - layer: torch.nn.Module, + layer: AttentionLayer, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index a1cd2aefe015a..38e27434dab2c 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -10,6 +10,7 @@ LowerTriangularMaskWithTensorBias) from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, AttentionMetadata, AttentionType) from vllm.attention.backends.utils import ( CommonAttentionState, CommonMetadataBuilder, @@ -412,7 +413,7 @@ def __init__( def forward( self, - layer: torch.nn.Module, + layer: AttentionLayer, query: torch.Tensor, key: Optional[torch.Tensor], value: Optional[torch.Tensor],
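Note on patch 2: the second patch types the new `layer` parameter with an AttentionLayer Protocol declared in abstract.py, listing the _k_scale and _v_scale attributes (plus a forward method) that the backends rely on. Because typing.Protocol uses structural subtyping, the existing Attention layer, which already carries _k_scale and _v_scale (see patch 1), satisfies the annotation without having to inherit from AttentionLayer, so the backends stay decoupled from the concrete layer class. The following standalone illustration of that mechanism is not vLLM code; AttentionLayerLike, SomeLayer, and read_scales are made-up names.

    from typing import Protocol, Tuple

    import torch


    class AttentionLayerLike(Protocol):
        # Structural type: anything exposing these attributes conforms.
        _k_scale: float
        _v_scale: float


    class SomeLayer(torch.nn.Module):
        # Note: no explicit AttentionLayerLike base class is needed.
        def __init__(self) -> None:
            super().__init__()
            self._k_scale = 1.0
            self._v_scale = 1.0


    def read_scales(layer: AttentionLayerLike) -> Tuple[float, float]:
        # A static checker accepts SomeLayer here purely by structure,
        # just as the backends accept vLLM's Attention layer.
        return layer._k_scale, layer._v_scale


    print(read_scales(SomeLayer()))  # -> (1.0, 1.0)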