Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[platform] Add verify_quantization in platform. #10757

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 1 addition & 27 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,17 +393,11 @@ def _parse_quant_hf_config(self):

def _verify_quantization(self) -> None:
supported_quantization = QUANTIZATION_METHODS
rocm_supported_quantization = [
"awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
"fbgemm_fp8", "gguf"
]
optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed_tensors",
"compressed-tensors", "experts_int8"
]
tpu_supported_quantization = ["tpu_int8"]
neuron_supported_quantization = ["neuron_quant"]
if self.quantization is not None:
self.quantization = self.quantization.lower()

Expand Down Expand Up @@ -438,32 +432,12 @@ def _verify_quantization(self) -> None:
raise ValueError(
f"Unknown quantization method: {self.quantization}. Must "
f"be one of {supported_quantization}.")
if current_platform.is_rocm(
) and self.quantization not in rocm_supported_quantization:
raise ValueError(
f"{self.quantization} quantization is currently not "
f"supported in ROCm.")
if current_platform.is_tpu(
) and self.quantization not in tpu_supported_quantization:
raise ValueError(
f"{self.quantization} quantization is currently not "
f"supported in TPU Backend.")
current_platform.verify_quantization(self.quantization)
if self.quantization not in optimized_quantization_methods:
logger.warning(
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models.", self.quantization)
if (self.quantization == "awq" and current_platform.is_rocm()
and not envs.VLLM_USE_TRITON_AWQ):
logger.warning(
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
" is not set, enabling VLLM_USE_TRITON_AWQ.")
envs.VLLM_USE_TRITON_AWQ = True
if current_platform.is_neuron(
) and self.quantization not in neuron_supported_quantization:
raise ValueError(
f"{self.quantization} quantization is currently not "
f"supported in Neuron Backend.")

def _verify_cuda_graph(self) -> None:
if self.max_seq_len_to_capture is None:
Expand Down
1 change: 1 addition & 0 deletions vllm/platforms/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

class CpuPlatform(Platform):
_enum = PlatformEnum.CPU
device_name: str = "cpu"
device_type: str = "cpu"
dispatch_key: str = "CPU"

Expand Down
1 change: 1 addition & 0 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:

class CudaPlatformBase(Platform):
_enum = PlatformEnum.CUDA
device_name: str = "cuda"
device_type: str = "cuda"
dispatch_key: str = "CUDA"

Expand Down
1 change: 1 addition & 0 deletions vllm/platforms/hpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

class HpuPlatform(Platform):
_enum = PlatformEnum.HPU
device_name: str = "hpu"
device_type: str = "hpu"
dispatch_key: str = "HPU"

Expand Down
13 changes: 13 additions & 0 deletions vllm/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,13 @@ def to_int(self) -> int:

class Platform:
_enum: PlatformEnum
device_name: str
device_type: str
# available dispatch keys:
# check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
# use "CPU" as a fallback for platforms not registered in PyTorch
dispatch_key: str = "CPU"
supported_quantization: list[str] = []

def is_cuda(self) -> bool:
return self._enum == PlatformEnum.CUDA
Expand Down Expand Up @@ -171,6 +173,17 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
"""
pass

@classmethod
def verify_quantization(cls, quant: str) -> None:
"""
Verify whether the quantization is supported by the current platform.
"""
if cls.supported_quantization and \
quant not in cls.supported_quantization:
raise ValueError(
f"{quant} quantization is currently not supported in "
f"{cls.device_name}.")


class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
Expand Down
2 changes: 2 additions & 0 deletions vllm/platforms/neuron.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@

class NeuronPlatform(Platform):
_enum = PlatformEnum.NEURON
device_name: str = "neuron"
device_type: str = "neuron"
supported_quantization: list[str] = ["neuron_quant"]

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
Expand Down
1 change: 1 addition & 0 deletions vllm/platforms/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

class OpenVinoPlatform(Platform):
_enum = PlatformEnum.OPENVINO
device_name: str = "openvino"
device_type: str = "openvino"
dispatch_key: str = "CPU"

Expand Down
15 changes: 15 additions & 0 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch

import vllm.envs as envs
from vllm.logger import init_logger

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
Expand Down Expand Up @@ -35,8 +36,13 @@

class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM
device_name: str = "rocm"
device_type: str = "cuda"
dispatch_key: str = "CUDA"
supported_quantization: list[str] = [
"awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
"fbgemm_fp8", "gguf"
]

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
Expand Down Expand Up @@ -79,3 +85,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"

@classmethod
def verify_quantization(cls, quant: str) -> None:
super().verify_quantization(quant)
if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
logger.warning(
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
" is not set, enabling VLLM_USE_TRITON_AWQ.")
envs.VLLM_USE_TRITON_AWQ = True
2 changes: 2 additions & 0 deletions vllm/platforms/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

class TpuPlatform(Platform):
_enum = PlatformEnum.TPU
device_name: str = "tpu"
device_type: str = "tpu"
dispatch_key: str = "XLA"
supported_quantization: list[str] = ["tpu_int8"]

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
Expand Down
1 change: 1 addition & 0 deletions vllm/platforms/xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

class XPUPlatform(Platform):
_enum = PlatformEnum.XPU
device_name: str = "xpu"
device_type: str = "xpu"
dispatch_key: str = "XPU"

Expand Down