[Platform] Do not raise error if _Backend is not found #12023

Merged (2 commits, Jan 15, 2025)
11 changes: 8 additions & 3 deletions tests/kernels/test_attention_selector.py
@@ -94,7 +94,12 @@ def test_flash_attn(monkeypatch):


 def test_invalid_env(monkeypatch):
-    """Throw an exception if the backend name is invalid."""
+    """Ignore the invalid env variable if it is set."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
-    with pytest.raises(ValueError):
-        get_attn_backend(16, torch.float16, None, 16, False)
+    with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+        backend = get_attn_backend(32, torch.float16, None, 16, False)
+        assert backend.get_name() == "FLASH_ATTN"
+
+        # when block size == 16, backend will fall back to XFORMERS
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.get_name() == "XFORMERS"
New file: the dummy attention backend used by the vllm_add_dummy_platform test plugin.

@@ -0,0 +1,8 @@
+from vllm.attention.backends.flash_attn import FlashAttentionBackend
+
+
+class DummyAttentionBackend(FlashAttentionBackend):
+
+    @staticmethod
+    def get_name() -> str:
+        return "Dummy_Backend"
DummyPlatform (vllm_add_dummy_platform test plugin):

@@ -3,3 +3,7 @@

 class DummyPlatform(CudaPlatform):
     device_name = "DummyDevice"
+
+    def get_attn_backend_cls(self, backend_name, head_size, dtype,
+                             kv_cache_dtype, block_size, use_v1):
+        return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501
14 changes: 14 additions & 0 deletions tests/plugins_tests/test_platform_plugins.py
@@ -1,3 +1,10 @@
+import torch
+
+from tests.kernels.utils import override_backend_env_variable
+from vllm.attention.selector import get_attn_backend
+from vllm.utils import STR_INVALID_VAL
+
+
 def test_platform_plugins():
     # simulate workload by running an example
     import runpy
@@ -14,3 +21,10 @@ def test_platform_plugins():
        f"Expected DummyDevice, got {current_platform.device_name}, "
        "possibly because current_platform is imported before the plugin"
        f" is loaded. The first import:\n{_init_trace}")
+
+
+def test_oot_attention_backend(monkeypatch):
+    # ignore the backend env variable if it is set
+    override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
+    backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
+    assert backend.get_name() == "Dummy_Backend"
8 changes: 4 additions & 4 deletions vllm/attention/layer.py
@@ -190,11 +190,11 @@ def __init__(
                                        kv_cache_dtype=None,
                                        block_size=16,
                                        is_attention_free=False)
-        attn_backend = backend_name_to_enum(attn_backend.get_name())
-        if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
-            attn_backend = _Backend.XFORMERS
+        backend = backend_name_to_enum(attn_backend.get_name())
+        if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
+            backend = _Backend.XFORMERS
 
-        self.attn_backend = attn_backend if attn_backend in {
+        self.attn_backend = backend if backend in {
             _Backend.TORCH_SDPA, _Backend.XFORMERS
         } else _Backend.TORCH_SDPA

Contributor Author (on the `backend = backend_name_to_enum(...)` line): Rename here to make lint check happy
20 changes: 11 additions & 9 deletions vllm/attention/selector.py
@@ -14,16 +14,18 @@
 logger = init_logger(__name__)
 
 
-def backend_name_to_enum(backend_name: str) -> _Backend:
+def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
+    """
+    Convert a string backend name to a _Backend enum value.
+
+    Returns:
+    * _Backend: enum value if backend_name is a valid in-tree type
+    * None: otherwise it's an invalid in-tree type or an out-of-tree platform is
+      loaded.
+    """
     assert backend_name is not None
-
-    backend_members = _Backend.__members__
-    if backend_name not in backend_members:
-        raise ValueError(f"Invalid attention backend '{backend_name}'. "
-                         f"Available backends: {', '.join(backend_members)} "
-                         "(case-sensitive).")
-
-    return _Backend[backend_name]
+    return _Backend[backend_name] if backend_name in _Backend.__members__ else \
+        None
 
 
 def get_env_variable_attn_backend() -> Optional[_Backend]:
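Because `backend_name_to_enum` now returns None instead of raising, a caller that reads the backend override has to treat an unrecognized name as "not an in-tree backend" and fall through to whatever the current platform selects, which is how the Dummy_Backend test above can pass with an invalid value set. A self-contained sketch of that caller-side pattern; the enum and helper below are simplified stand-ins, not the literal selector code:

import enum
from typing import Optional


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()


def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
    return _Backend[backend_name] if backend_name in _Backend.__members__ else None


def choose_backend(env_value: Optional[str],
                   platform_default: _Backend) -> _Backend:
    if env_value:
        selected = backend_name_to_enum(env_value)
        if selected is not None:
            return selected
        # Unknown name: it may belong to an out-of-tree backend that the
        # platform plugin resolves itself, so ignore it rather than raise.
    return platform_default


assert choose_backend("INVALID", _Backend.FLASH_ATTN) is _Backend.FLASH_ATTN
assert choose_backend("XFORMERS", _Backend.FLASH_ATTN) is _Backend.XFORMERS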