Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc] Pass attention to impl backend #12218

Merged
merged 2 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions vllm/attention/backends/abstract.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
Tuple, Type, TypeVar)
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
Protocol, Set, Tuple, Type, TypeVar)

import torch

Expand Down Expand Up @@ -223,6 +223,22 @@ def build(self, seq_lens: List[int], query_lens: List[int],
raise NotImplementedError


class AttentionLayer(Protocol):

_k_scale: float
_v_scale: float

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
...


class AttentionImpl(ABC, Generic[T]):

@abstractmethod
Expand All @@ -244,13 +260,12 @@ def __init__(
@abstractmethod
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: T,
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
12 changes: 6 additions & 6 deletions vllm/attention/backends/blocksparse_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
Expand Down Expand Up @@ -358,13 +359,12 @@ def __init__(

def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: BlocksparseFlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
Expand Down Expand Up @@ -401,8 +401,8 @@ def forward(
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)

if prefill_meta := attn_metadata.prefill_metadata:
Expand Down Expand Up @@ -439,8 +439,8 @@ def forward(
self.num_kv_heads,
self.scale,
self.alibi_slopes,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
tp_rank=self.tp_rank,
blocksparse_local_blocks=self.local_blocks,
blocksparse_vert_stride=self.vert_stride,
Expand Down
10 changes: 5 additions & 5 deletions vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
Expand Down Expand Up @@ -634,13 +635,12 @@ def __init__(

def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
Expand All @@ -657,7 +657,7 @@ def forward(
NOTE: It in-place updates the output tensor.
"""
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert k_scale == 1.0 and v_scale == 1.0, (
assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")

assert output is not None, "Output tensor must be provided."
Expand Down Expand Up @@ -709,8 +709,8 @@ def forward(
kv_cache[1],
updated_slot_mapping.flatten(), # type: ignore[union-attr]
kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)

(num_prefill_query_tokens, num_prefill_kv_tokens,
Expand Down
16 changes: 8 additions & 8 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionState, AttentionType)
Expand Down Expand Up @@ -792,13 +793,12 @@ def __init__(

def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:

Expand Down Expand Up @@ -826,8 +826,8 @@ def forward(
kv_cache[:, 1],
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
Expand Down Expand Up @@ -886,8 +886,8 @@ def forward(
kv_cache,
logits_soft_cap=logits_soft_cap,
causal=True,
k_scale=k_scale,
v_scale=v_scale,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
window_left=window_left)
if decode_meta := attn_metadata.decode_metadata:
assert decode_meta is not None
Expand All @@ -897,8 +897,8 @@ def forward(
kv_cache,
sm_scale=softmax_scale,
logits_soft_cap=logits_soft_cap,
k_scale=k_scale,
v_scale=v_scale,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
window_left=window_left)

if prefill_output is None and decode_output is not None:
Expand Down
4 changes: 2 additions & 2 deletions vllm/attention/backends/hpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
Expand Down Expand Up @@ -152,13 +153,12 @@ def __init__(

def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: HPUAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
Expand Down
18 changes: 9 additions & 9 deletions vllm/attention/backends/ipex_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.paged_attn import (PagedAttention,
Expand Down Expand Up @@ -171,13 +172,12 @@ def split_kv_cache(

def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: IpexAttnMetadata, # type: ignore
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention.
Expand All @@ -193,7 +193,7 @@ def forward(
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert k_scale == 1.0 and v_scale == 1.0
assert layer._k_scale == 1.0 and layer._v_scale == 1.0
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
Expand All @@ -210,8 +210,8 @@ def forward(
value_cache,
attn_metadata.slot_mapping.flatten(),
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)

if attn_metadata.is_prompt:
Expand Down Expand Up @@ -296,8 +296,8 @@ def forward(
max_seq_len,
self.alibi_slopes,
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
else:
# Run PagedAttention V2.
Expand Down Expand Up @@ -329,8 +329,8 @@ def forward(
max_seq_len,
self.alibi_slopes,
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)

# Reshape the output tensor.
Expand Down
6 changes: 3 additions & 3 deletions vllm/attention/backends/pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch_xla.experimental.custom_kernel # Required to register custom ops.

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState

Expand Down Expand Up @@ -150,13 +151,12 @@ def __init__(

def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
attn_metadata: PallasMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
Expand All @@ -173,7 +173,7 @@ def forward(
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
assert k_scale == 1.0 and v_scale == 1.0
assert layer._k_scale == 1.0 and layer._v_scale == 1.0
batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
Expand Down
20 changes: 10 additions & 10 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
Expand Down Expand Up @@ -414,13 +415,12 @@ def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:

def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: ROCmFlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
Expand Down Expand Up @@ -458,8 +458,8 @@ def forward(
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)

num_prefill_tokens = attn_metadata.num_prefill_tokens
Expand Down Expand Up @@ -567,8 +567,8 @@ def forward(
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)

if decode_meta := attn_metadata.decode_metadata:
Expand Down Expand Up @@ -613,8 +613,8 @@ def forward(
max_seq_len,
self.alibi_slopes,
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
else:
output[num_prefill_tokens:] = PagedAttention.forward_decode(
Expand All @@ -628,8 +628,8 @@ def forward(
self.num_kv_heads,
self.scale,
self.alibi_slopes,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)

# Reshape the output tensor.
Expand Down
Loading
Loading