[V1][BugFix] Raise error when selected attn backend is not supported #13730

Open · wants to merge 4 commits into main
14 changes: 14 additions & 0 deletions vllm/attention/selector.py
@@ -144,6 +144,20 @@ def _cached_get_attn_backend(
    if backend_by_env_var is not None:
        selected_backend = backend_name_to_enum(backend_by_env_var)

    if selected_backend is not None and use_v1:
        if (selected_backend in (_Backend.FLASH_ATTN, _Backend.FLASHINFER,
                                 _Backend.XFORMERS)):
            raise ValueError(
                f"{selected_backend.name} is not compatible with vLLM V1. "
                "Please either do `export VLLM_ATTENTION_BACKEND="
                f"{_Backend.FLASH_ATTN_VLLM_V1.name}` or unset it to use "
                "the default backend.")
        elif selected_backend not in _Backend.get_v1_backends():
            raise ValueError(
                f"{selected_backend.name} attention backend is not compatible "
                "with vLLM V1. Please use a different backend or unset the "
                "VLLM_ATTENTION_BACKEND env variable.")

    # get device-specific attn_backend
    attention_cls = current_platform.get_attn_backend_cls(
        selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
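For reviewers, a minimal self-contained sketch of how the new check behaves. The _Backend enum and the validate_backend helper below are reconstructed for illustration only; the real check lives in _cached_get_attn_backend and uses the enum from vllm.platforms.interface.

# Illustrative sketch of the two error paths added by this PR.
import enum


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    FLASHINFER = enum.auto()
    XFORMERS = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()
    ROCM_FLASH = enum.auto()
    PALLAS_VLLM_V1 = enum.auto()

    @classmethod
    def get_v1_backends(cls):
        return (cls.FLASH_ATTN_VLLM_V1, cls.ROCM_FLASH, cls.PALLAS_VLLM_V1)


def validate_backend(selected_backend, use_v1):
    # Mirrors the validation added to _cached_get_attn_backend.
    if selected_backend is not None and use_v1:
        if selected_backend in (_Backend.FLASH_ATTN, _Backend.FLASHINFER,
                                _Backend.XFORMERS):
            raise ValueError(
                f"{selected_backend.name} is not compatible with vLLM V1. "
                "Please either do `export VLLM_ATTENTION_BACKEND="
                f"{_Backend.FLASH_ATTN_VLLM_V1.name}` or unset it to use "
                "the default backend.")
        elif selected_backend not in _Backend.get_v1_backends():
            raise ValueError(
                f"{selected_backend.name} attention backend is not "
                "compatible with vLLM V1. Please use a different backend "
                "or unset the VLLM_ATTENTION_BACKEND env variable.")


validate_backend(_Backend.FLASH_ATTN_VLLM_V1, use_v1=True)   # passes
validate_backend(_Backend.XFORMERS, use_v1=False)             # passes on the V0 path
try:
    validate_backend(_Backend.XFORMERS, use_v1=True)           # raises
except ValueError as err:
    print(err)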
4 changes: 4 additions & 0 deletions vllm/platforms/interface.py
@@ -42,6 +42,10 @@ class _Backend(enum.Enum):
    BLOCK_SPARSE_FLASH_ATTN = enum.auto()
    NO_ATTENTION = enum.auto()

    @classmethod
    def get_v1_backends(cls) -> Tuple["_Backend", ...]:
        return (cls.FLASH_ATTN_VLLM_V1, cls.ROCM_FLASH, cls.PALLAS_VLLM_V1)


class PlatformEnum(enum.Enum):
    CUDA = enum.auto()
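A hedged usage sketch of the new classmethod, assuming the FLASH_ATTN_VLLM_V1, ROCM_FLASH, and PALLAS_VLLM_V1 members already exist on _Backend (as the return value implies):

from vllm.platforms.interface import _Backend

# The selector treats this tuple as the allow-list of V1-capable backends.
v1_backends = _Backend.get_v1_backends()
assert _Backend.FLASH_ATTN_VLLM_V1 in v1_backends
assert _Backend.XFORMERS not in v1_backends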