[hardware][misc] introduce platform abstraction #6080

Merged · 5 commits · Jul 3, 2024
4 changes: 2 additions & 2 deletions tests/kernels/test_cutlass.py
@@ -8,13 +8,13 @@
 import torch

 from vllm import _custom_ops as ops
-from vllm.utils import get_device_capability_stateless
+from vllm.platforms import current_platform

 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]

-capability = get_device_capability_stateless()
+capability = current_platform.get_device_capability()
 capability = capability[0] * 10 + capability[1]

4 changes: 2 additions & 2 deletions tests/quantization/utils.py
@@ -1,15 +1,15 @@
 import torch

 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
-from vllm.utils import get_device_capability_stateless
+from vllm.platforms import current_platform


 def is_quant_method_supported(quant_method: str) -> bool:
     # Currently, all quantization methods require Nvidia or AMD GPUs
     if not torch.cuda.is_available():
         return False

-    capability = get_device_capability_stateless()
+    capability = current_platform.get_device_capability()
     capability = capability[0] * 10 + capability[1]
     return (capability >=
             QUANTIZATION_METHODS[quant_method].get_min_capability())
5 changes: 3 additions & 2 deletions vllm/attention/ops/blocksparse_attention/interface.py
@@ -2,13 +2,14 @@

 import torch

-from vllm.utils import get_device_capability_stateless, is_cpu, is_hip
+from vllm.platforms import current_platform
+from vllm.utils import is_cpu, is_hip

 from .utils import (dense_to_crow_col, get_head_sliding_step,
                     get_sparse_attn_mask)

 IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
-                         and get_device_capability_stateless()[0] >= 8)
+                         and current_platform.get_device_capability()[0] >= 8)

 if IS_COMPUTE_8_OR_ABOVE:
     from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
4 changes: 2 additions & 2 deletions vllm/attention/ops/prefix_prefill.py
@@ -5,7 +5,7 @@
 import triton
 import triton.language as tl

-from vllm.utils import get_device_capability_stateless
+from vllm.platforms import current_platform

 if triton.__version__ >= "2.1.0":

@@ -685,7 +685,7 @@ def context_attention_fwd(q,
                           alibi_slopes=None,
                           sliding_window=None):

-    cap = get_device_capability_stateless()
+    cap = current_platform.get_device_capability()
     BLOCK = 128 if cap[0] >= 8 else 64
     # shape constraints
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
4 changes: 2 additions & 2 deletions vllm/lora/punica.py
@@ -5,14 +5,14 @@
 import torch

 from vllm import _custom_ops as ops
-from vllm.utils import get_device_capability_stateless
+from vllm.platforms import current_platform


 def _check_punica_support():
     if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"):
         return

-    if get_device_capability_stateless() < (8, 0):
+    if current_platform.get_device_capability() < (8, 0):
         raise ImportError(
             "punica LoRA kernels require compute capability >= 8.0")
     else:
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -14,7 +14,7 @@
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     CompressionFormat, QuantizationArgs, QuantizationStrategy,
     find_first_name_or_class_match)
-from vllm.utils import get_device_capability_stateless
+from vllm.platforms import current_platform


 class CompressedTensorsConfig(QuantizationConfig):
@@ -85,7 +85,7 @@ def get_config_filenames(cls) -> List[str]:
         return []

     def _check_gptq_and_marlin_can_run(self):
-        capability = get_device_capability_stateless()
+        capability = current_platform.get_device_capability()
         capability = capability[0] * 10 + capability[1]
         if capability < 80:
             raise RuntimeError("The quantization config is not supported for ",
5 changes: 3 additions & 2 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -12,15 +12,16 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import get_device_capability_stateless, print_warning_once
+from vllm.platforms import current_platform
+from vllm.utils import print_warning_once

 ACTIVATION_SCHEMES = ["static", "dynamic"]

 logger = init_logger(__name__)


 def cutlass_fp8_supported() -> bool:
-    capability = get_device_capability_stateless()
+    capability = current_platform.get_device_capability()
     capability = capability[0] * 10 + capability[1]

     return ops.cutlass_scaled_mm_supports_fp8(capability)
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -12,7 +12,7 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
-from vllm.utils import get_device_capability_stateless
+from vllm.platforms import current_platform

 logger = init_logger(__name__)

@@ -173,7 +173,7 @@ def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
             return False

         # If the capability of the device is too low, cannot convert.
-        major, minor = get_device_capability_stateless()
+        major, minor = current_platform.get_device_capability()
         device_capability = major * 10 + minor
         if device_capability < cls.get_min_capability():
             return False
@@ -12,9 +12,9 @@
     marlin_perm, marlin_scale_perm, marlin_scale_perm_single)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     get_pack_factor, quantize_weights, sort_weights)
-from vllm.utils import get_device_capability_stateless
+from vllm.platforms import current_platform

-__cuda_arch = get_device_capability_stateless()
+__cuda_arch = current_platform.get_device_capability()

 MARLIN_TILE = 16

5 changes: 3 additions & 2 deletions vllm/model_executor/model_loader/loader.py
@@ -35,7 +35,8 @@
 from vllm.model_executor.models.interfaces import (supports_lora,
                                                     supports_vision)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import get_device_capability_stateless, is_tpu
+from vllm.platforms import current_platform
+from vllm.utils import is_tpu

 logger = init_logger(__name__)

@@ -46,7 +47,7 @@ def _get_quantization_config(
     """Get the quantization config."""
     if model_config.quantization is not None:
         quant_config = get_quant_config(model_config, load_config)
-        capability = get_device_capability_stateless()
+        capability = current_platform.get_device_capability()
         capability = capability[0] * 10 + capability[1]
         if capability < quant_config.get_min_capability():
             raise ValueError(
18 changes: 18 additions & 0 deletions vllm/platforms/__init__.py
@@ -0,0 +1,18 @@
+from typing import Optional
+
+import torch
+
+from .interface import Platform, PlatformEnum
+
+current_platform: Optional[Platform]
+
+if torch.version.cuda is not None:
+    from .cuda import CudaPlatform
+    current_platform = CudaPlatform()
+elif torch.version.hip is not None:
+    from .rocm import RocmPlatform
+    current_platform = RocmPlatform()
+else:
+    current_platform = None
+
+__all__ = ['Platform', 'PlatformEnum', 'current_platform']
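
The call-site changes above and below all follow the same pattern: import the module-level current_platform singleton, ask it for the (major, minor) capability, and flatten that tuple into the two-digit integer the quantization configs compare against. A minimal sketch of that pattern; the meets_min_capability helper is illustrative only and not part of this PR, and current_platform is None on builds that are neither CUDA nor ROCm:

from vllm.platforms import current_platform

def meets_min_capability(min_capability: int) -> bool:
    # e.g. (8, 0) on an A100 becomes 80, (9, 0) on an H100 becomes 90.
    assert current_platform is not None, "no CUDA/ROCm platform detected"
    major, minor = current_platform.get_device_capability()
    return major * 10 + minor >= min_capability
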
34 changes: 34 additions & 0 deletions vllm/platforms/cuda.py
@@ -0,0 +1,34 @@
+"""Code inside this file can safely assume cuda platform, e.g. importing
+pynvml. However, it should not initialize cuda context.
+"""
+
+from functools import lru_cache, wraps
+from typing import Tuple
+
+import pynvml
+
+from .interface import Platform, PlatformEnum
+
+
+def with_nvml_context(fn):
+
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        pynvml.nvmlInit()
+        try:
+            return fn(*args, **kwargs)
+        finally:
+            pynvml.nvmlShutdown()
+
+    return wrapper
+
+
+class CudaPlatform(Platform):
+    _enum = PlatformEnum.CUDA
+
+    @staticmethod
+    @lru_cache(maxsize=8)
+    @with_nvml_context
+    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
+        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
+        return pynvml.nvmlDeviceGetCudaComputeCapability(handle)
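
Two properties of the CUDA implementation above are worth noting while reviewing: with_nvml_context pairs every underlying call with nvmlInit()/nvmlShutdown(), and because lru_cache(maxsize=8) wraps the NVML-context-wrapped function, NVML is only touched on the first query for a given device_id. A small sketch, assuming an NVIDIA GPU and pynvml installed:

from vllm.platforms.cuda import CudaPlatform

first = CudaPlatform.get_device_capability(0)   # runs nvmlInit/nvmlShutdown once
again = CudaPlatform.get_device_capability(0)   # served from the lru_cache
assert first == again                           # e.g. (8, 0) on an A100
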
21 changes: 21 additions & 0 deletions vllm/platforms/interface.py
@@ -0,0 +1,21 @@
+import enum
+from typing import Tuple
+
+
+class PlatformEnum(enum.Enum):
+    CUDA = enum.auto()
+    ROCM = enum.auto()
+
+
+class Platform:
+    _enum: PlatformEnum
+
+    def is_cuda(self) -> bool:
+        return self._enum == PlatformEnum.CUDA
+
+    def is_rocm(self) -> bool:
+        return self._enum == PlatformEnum.ROCM
+
+    @staticmethod
+    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
+        raise NotImplementedError
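
The _enum-backed helpers give callers a uniform way to branch on the backend without probing torch.version themselves. The dispatch function below is purely illustrative (the call sites touched by this PR only use get_device_capability()); it also covers the None case from vllm/platforms/__init__.py:

from vllm.platforms import current_platform

def platform_name() -> str:
    # current_platform is None when neither CUDA nor ROCm is detected.
    if current_platform is None:
        return "other"
    if current_platform.is_cuda():
        return "cuda"
    if current_platform.is_rocm():
        return "rocm"
    return "unknown"
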
15 changes: 15 additions & 0 deletions vllm/platforms/rocm.py
@@ -0,0 +1,15 @@
+from functools import lru_cache
+from typing import Tuple
+
+import torch
+
+from .interface import Platform, PlatformEnum
+
+
+class RocmPlatform(Platform):
+    _enum = PlatformEnum.ROCM
+
+    @staticmethod
+    @lru_cache(maxsize=8)
+    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
+        return torch.cuda.get_device_capability(device_id)
7 changes: 0 additions & 7 deletions vllm/utils.py
@@ -866,13 +866,6 @@ def is_full_nvlink(device_ids: List[int]) -> bool:
     return True


-@lru_cache(maxsize=8)
-@with_nvml_context
-def get_device_capability_stateless(device_id: int = 0) -> Tuple[int, int]:
-    handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
-    return pynvml.nvmlDeviceGetCudaComputeCapability(handle)
-
-
 #From: https://stackoverflow.com/a/4104188/2749989
 def run_once(f):

4 changes: 2 additions & 2 deletions vllm/worker/worker.py
@@ -15,8 +15,8 @@
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
+from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import get_device_capability_stateless
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.embedding_model_runner import EmbeddingModelRunner
 from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
@@ -333,7 +333,7 @@ def init_worker_distributed_environment(
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     # Check if the GPU supports the dtype.
     if torch_dtype == torch.bfloat16:
-        compute_capability = get_device_capability_stateless()
+        compute_capability = current_platform.get_device_capability()
         if compute_capability[0] < 8:
             gpu_name = torch.cuda.get_device_name()
             raise ValueError(