From 8be5980e78940ef5f40f814c96efcb62c8539fb5 Mon Sep 17 00:00:00 2001
From: sang
Date: Tue, 9 Apr 2024 23:10:40 -0700
Subject: [PATCH 1/3] done

---
 .../basic_correctness/test_basic_correctness.py    |  4 ++++
 vllm/attention/selector.py                         | 16 +++++++++++++++-
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py
index 97cff623c5e1d..6842f964f5cfc 100644
--- a/tests/basic_correctness/test_basic_correctness.py
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -14,6 +14,7 @@
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False, True])
+@pytest.mark.parametrize("attn_backend", ["XFORMER", "FLASH"])
 def test_models(
     hf_runner,
     vllm_runner,
@@ -22,7 +23,10 @@ def test_models(
     dtype: str,
     max_tokens: int,
     enforce_eager: bool,
+    attn_backend: str,
+    monkeypatch,
 ) -> None:
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
     hf_model = hf_runner(model, dtype=dtype)
     hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
     del hf_model
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 4c699aed48d49..4f4dede91c6dc 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -1,4 +1,5 @@
 import enum
+import os
 from functools import lru_cache
 from typing import Type
 
@@ -10,6 +11,8 @@
 
 logger = init_logger(__name__)
 
+VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
+
 
 class _Backend(enum.Enum):
     FLASH_ATTN = enum.auto()
@@ -75,4 +78,15 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
             "Cannot use FlashAttention backend because the flash_attn package "
             "is not found. Please install it for better performance.")
         return _Backend.XFORMERS
-    return _Backend.FLASH_ATTN
+
+    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
+    if backend_by_env_var == "XFORMER":
+        return _Backend.XFORMERS
+    elif backend_by_env_var == "FLASH":
+        return _Backend.FLASH_ATTN
+    elif backend_by_env_var is None:
+        # Default case. 
+        return _Backend.FLASH_ATTN
+    else:
+        raise AssertionError(
+            f"{VLLM_ATTENTION_BACKEND}={backend_by_env_var} is not supported.")

From 5162bfd2b31f041c03f0d44012a0175d6799b9df Mon Sep 17 00:00:00 2001
From: sang
Date: Wed, 10 Apr 2024 17:51:38 -0700
Subject: [PATCH 2/3] fixed

---
 tests/basic_correctness/test_basic_correctness.py |  6 ++++--
 vllm/attention/selector.py                        | 13 ++++---------
 2 files changed, 8 insertions(+), 11 deletions(-)

diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py
index 6842f964f5cfc..bd4c7ea3301be 100644
--- a/tests/basic_correctness/test_basic_correctness.py
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -4,6 +4,8 @@
 """
 import pytest
 
+from vllm.attention.selector import VLLM_ATTENTION_BACKEND
+
 MODELS = [
     "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
@@ -14,7 +16,7 @@
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False, True])
-@pytest.mark.parametrize("attn_backend", ["XFORMER", "FLASH"])
+@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"])
 def test_models(
     hf_runner,
     vllm_runner,
@@ -26,7 +28,7 @@ def test_models(
     attn_backend: str,
     monkeypatch,
 ) -> None:
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+    monkeypatch.setenv(VLLM_ATTENTION_BACKEND, attn_backend)
     hf_model = hf_runner(model, dtype=dtype)
     hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
     del hf_model
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 4f4dede91c6dc..8f378934614aa 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -80,13 +80,8 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
         return _Backend.XFORMERS
 
     backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
-    if backend_by_env_var == "XFORMER":
-        return _Backend.XFORMERS
-    elif backend_by_env_var == "FLASH":
-        return _Backend.FLASH_ATTN
-    elif backend_by_env_var is None:
+    if backend_by_env_var is not None:
+        return _Backend[backend_by_env_var]
+
         # Default case. 
-        return _Backend.FLASH_ATTN
-    else:
-        raise AssertionError(
-            f"{VLLM_ATTENTION_BACKEND}={backend_by_env_var} is not supported.")
+    return _Backend.FLASH_ATTN

From bad59b048562d58667eeb22b3157ff62b960e579 Mon Sep 17 00:00:00 2001
From: Simon Mo
Date: Wed, 10 Apr 2024 17:55:54 -0700
Subject: [PATCH 3/3] Apply suggestions from code review

---
 vllm/attention/selector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 8f378934614aa..554e802cd5513 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -83,5 +83,5 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
     if backend_by_env_var is not None:
         return _Backend[backend_by_env_var]
 
-        # Default case. 
+    # Default case.
     return _Backend.FLASH_ATTN
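
Usage sketch (not part of the patches; a minimal illustration assuming vLLM's public
LLM/SamplingParams API, with an illustrative model name and prompt): after patch 2 the
environment variable value is looked up as a _Backend enum member name, so it must be
one of the member names (e.g. XFORMERS or FLASH_ATTN) and must be set before the engine
is constructed, for example:

    import os

    # Select the attention backend by _Backend member name ("XFORMERS" or
    # "FLASH_ATTN"); the selector reads this value when the engine is built,
    # and any other value fails the _Backend[...] enum lookup.
    os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=5))
    print(outputs[0].outputs[0].text)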