[Doc] Consistent naming of attention backends (vllm-project#9498)
Signed-off-by: Thomas Parnell <[email protected]>
tdoublep authored Oct 21, 2024
1 parent 969d879 commit 6f47fef
Showing 14 changed files with 23 additions and 19 deletions.
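
The pattern this commit standardizes: every attention backend exposes a static get_name(), and callers dispatch by comparing against that string, so the identifiers become uppercase, underscore-separated constants (e.g. FLASH_ATTN, ROCM_FLASH, NO_ATTENTION) instead of mixed lowercase/hyphenated names. Below is a minimal, self-contained Python sketch of that convention; the class and function names are hypothetical stand-ins, not vLLM's actual code.

class FlashAttentionBackendSketch:
    """Hypothetical stand-in for a vLLM AttentionBackend subclass."""

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN"  # previously "flash-attn"


# Consumers compare against the canonical uppercase identifiers.
MULTI_STEP_BACKENDS = ["FLASH_ATTN", "ROCM_FLASH", "FLASHINFER"]


def supports_multi_step(backend) -> bool:
    return backend.get_name() in MULTI_STEP_BACKENDS


assert supports_multi_step(FlashAttentionBackendSketch())
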
2 changes: 1 addition & 1 deletion vllm/attention/backends/flash_attn.py
@@ -32,7 +32,7 @@ def get_supported_head_sizes() -> List[int]:
 
     @staticmethod
     def get_name() -> str:
-        return "flash-attn"
+        return "FLASH_ATTN"
 
     @staticmethod
     def get_impl_cls() -> Type["FlashAttentionImpl"]:

2 changes: 1 addition & 1 deletion vllm/attention/backends/flashinfer.py
@@ -40,7 +40,7 @@ class FlashInferBackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "flashinfer"
+        return "FLASHINFER"
 
     @staticmethod
     def get_impl_cls() -> Type["FlashInferImpl"]:

2 changes: 1 addition & 1 deletion vllm/attention/backends/ipex_attn.py
@@ -19,7 +19,7 @@ class IpexAttnBackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "ipex-attn"
+        return "IPEX"
 
     @staticmethod
    def get_impl_cls() -> Type["IpexAttnBackendImpl"]:

2 changes: 1 addition & 1 deletion vllm/attention/backends/openvino.py
@@ -38,7 +38,7 @@ class OpenVINOAttentionBackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "openvino"
+        return "OPENVINO"
 
     @staticmethod
     def get_impl_cls():

4 changes: 4 additions & 0 deletions vllm/attention/backends/pallas.py
@@ -11,6 +11,10 @@
 
 class PallasAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "PALLAS"
+
     @staticmethod
     def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
         return PallasAttentionBackendImpl

2 changes: 1 addition & 1 deletion vllm/attention/backends/placeholder_attn.py
@@ -20,7 +20,7 @@ class PlaceholderAttentionBackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "placeholder-attn"
+        return "NO_ATTENTION"
 
     @staticmethod
     def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:

2 changes: 1 addition & 1 deletion vllm/attention/backends/rocm_flash_attn.py
@@ -28,7 +28,7 @@ class ROCmFlashAttentionBackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "rocm-flash-attn"
+        return "ROCM_FLASH"
 
     @staticmethod
     def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:

2 changes: 1 addition & 1 deletion vllm/attention/backends/torch_sdpa.py
@@ -25,7 +25,7 @@ class TorchSDPABackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "torch-sdpa"
+        return "TORCH_SDPA"
 
     @staticmethod
     def get_impl_cls() -> Type["TorchSDPABackendImpl"]:

12 changes: 6 additions & 6 deletions vllm/attention/backends/utils.py
@@ -317,8 +317,8 @@ def graph_capture_get_metadata_for_batch(
         if is_encoder_decoder_model:
             # The encoder decoder model works only with XFormers backend.
             # Assert the same.
-            assert self.runner.attn_backend.get_name() == "xformers", \
-                f"Expected attn_backend name to be 'xformers', but "\
+            assert self.runner.attn_backend.get_name() == "XFORMERS", \
+                f"Expected attn_backend name to be 'XFORMERS', but "\
                 f" got '{self.runner.attn_backend.get_name()}'"
             self._update_captured_metadata_for_enc_dec_model(
                 batch_size=batch_size, attn_metadata=attn_metadata)
@@ -337,8 +337,8 @@ def get_graph_input_buffers(
         if is_encoder_decoder_model:
             # The encoder decoder model works only with XFormers backend.
             # Assert the same.
-            assert self.runner.attn_backend.get_name() == "xformers", \
-                f"Expected attn_backend name to be 'xformers', but "\
+            assert self.runner.attn_backend.get_name() == "XFORMERS", \
+                f"Expected attn_backend name to be 'XFORMERS', but "\
                 f" got '{self.runner.attn_backend.get_name()}'"
             self._add_additonal_input_buffers_for_enc_dec_model(
                 attn_metadata=attn_metadata, input_buffers=input_buffers)
@@ -356,8 +356,8 @@ def prepare_graph_input_buffers(
         if is_encoder_decoder_model:
             # The encoder decoder model works only with XFormers backend.
             # Assert the same.
-            assert self.runner.attn_backend.get_name() == "xformers", \
-                f"Expected attn_backend name to be 'xformers', but "\
+            assert self.runner.attn_backend.get_name() == "XFORMERS", \
+                f"Expected attn_backend name to be 'XFORMERS', but "\
                 f" got '{self.runner.attn_backend.get_name()}'"
             self._prepare_input_buffers_for_enc_dec_model(
                 attn_metadata, input_buffers)

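The same XFORMERS assertion appears in all three hunks above. Purely as an illustrative aside (a hypothetical helper, not part of this commit or of vLLM), the check could be written once:

def assert_xformers_backend(attn_backend) -> None:
    # Encoder-decoder CUDA-graph capture works only with the XFormers backend.
    name = attn_backend.get_name()
    assert name == "XFORMERS", (
        f"Expected attn_backend name to be 'XFORMERS', but got '{name}'")
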
2 changes: 1 addition & 1 deletion vllm/attention/backends/xformers.py
@@ -24,7 +24,7 @@ class XFormersBackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "xformers"
+        return "XFORMERS"
 
     @staticmethod
     def get_impl_cls() -> Type["XFormersImpl"]:

2 changes: 1 addition & 1 deletion vllm/spec_decode/draft_model_runner.py
@@ -179,7 +179,7 @@ def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
             return False
 
         # TODO: Add support for other attn backends
-        if self.attn_backend.get_name() != "flash-attn":
+        if self.attn_backend.get_name() != "FLASH_ATTN":
             return False
 
         # TODO: Add support for LORA

2 changes: 1 addition & 1 deletion vllm/spec_decode/spec_decode_worker.py
@@ -184,7 +184,7 @@ def create_worker(
 
         if not disable_mqa_scorer:
             if scorer_worker.model_runner.attn_backend.get_name(
-            ) != "flash-attn":
+            ) != "FLASH_ATTN":
                 disable_mqa_scorer = True
                 logger.info(
                     "[Speculative Decoding] Disabling MQA scorer as the "

2 changes: 1 addition & 1 deletion vllm/worker/model_runner.py
@@ -1855,7 +1855,7 @@ def forward(
         self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
         self.input_buffers["positions"].copy_(positions, non_blocking=True)
 
-        if self.backend_name != "placeholder-attn":
+        if self.backend_name != "NO_ATTENTION":
             self.input_buffers["slot_mapping"].copy_(
                 attn_metadata.slot_mapping, non_blocking=True)
 

4 changes: 2 additions & 2 deletions vllm/worker/multi_step_model_runner.py
@@ -29,8 +29,8 @@
 
 logger = init_logger(__name__)
 
-MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "rocm-flash-attn", "flashinfer"]
-MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["flash-attn"]
+MULTI_STEP_ATTENTION_BACKENDS = ["FLASH_ATTN", "ROCM_FLASH", "FLASHINFER"]
+MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN"]
 
 def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
         -> List[str]:

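The selector's body is cut off by the diff view; the sketch below is an assumption about what such a function plausibly returns (illustration only, not the actual vLLM implementation), written to be self-contained and runnable.

from typing import List

MULTI_STEP_ATTENTION_BACKENDS = ["FLASH_ATTN", "ROCM_FLASH", "FLASHINFER"]
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN"]


def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
        -> List[str]:
    # Assumed behavior: chunked prefill narrows the set to FLASH_ATTN only.
    if chunked_prefill_enabled:
        return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
    return MULTI_STEP_ATTENTION_BACKENDS


assert _get_supported_attention_backends(True) == ["FLASH_ATTN"]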
