Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consolidate Logics for GPU Detection #1297

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 4 additions & 85 deletions python/mlc_chat/chat_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import tvm
from tvm.runtime import disco # pylint: disable=unused-import

from mlc_chat.support.auto_device import detect_device

from . import base # pylint: disable=unused-import

if TYPE_CHECKING:
Expand Down Expand Up @@ -591,89 +593,6 @@ def _convert_generation_config_to_json_str(generation_config: Optional[Generatio
return json.dumps(asdict(generation_config))


def _parse_device_str(device: str) -> Tuple[tvm.runtime.Device, str]:
"""Parse the input device identifier into device name and id.

Parameters
----------
device : str
The device identifier to parse.
It can be "device_name" (e.g., "cuda") or
"device_name:device_id" (e.g., "cuda:1").

Returns
-------
dev : tvm.runtime.Device
The device.

device_name : str
The name of the device.
"""
device_err_msg = (
f"Invalid device name: {device}. Please enter the device in the form "
"'device_name:device_id' or 'device_name', where 'device_name' needs to be "
"one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto'."
)
device_args = device.split(":")
if len(device_args) == 1:
device_name, device_id = device_args[0], 0
elif len(device_args) == 2:
device_name, device_id = device_args[0], int(device_args[1])
elif len(device_args) > 2:
raise ValueError(device_err_msg)

if device_name == "cuda":
device = tvm.cuda(device_id)
elif device_name == "metal":
device = tvm.metal(device_id)
elif device_name == "vulkan":
device = tvm.vulkan(device_id)
elif device_name == "rocm":
device = tvm.rocm(device_id)
elif device_name == "opencl":
device = tvm.opencl(device_id)
elif device_name == "auto":
device, device_name = _detect_local_device(device_id)
logging.info("System automatically detected device: %s", device_name)
else:
raise ValueError(device_err_msg)

return device, device_name


def _detect_local_device(device_id: int = 0) -> Tuple[tvm.runtime.Device, str]:
    """Pick a local device automatically when the user did not name one.

    Backends are probed in a fixed order: metal, rocm, cuda, vulkan, opencl.
    Availability is checked on each backend's default device (the constructor
    called with no index); the first backend that reports `exist` wins.

    Parameters
    ----------
    device_id : int
        The device id used to construct the returned device.

    Returns
    -------
    dev : tvm.runtime.Device
        The selected device, or `tvm.cpu(device_id)` when no backend exists.

    device_name : str
        The backend name, or "llvm" for the CPU fallback.
    """
    # Order matters: earlier entries take priority over later ones.
    backends = (
        ("metal", tvm.metal),
        ("rocm", tvm.rocm),
        ("cuda", tvm.cuda),
        ("vulkan", tvm.vulkan),
        ("opencl", tvm.opencl),
    )
    for backend_name, make_device in backends:
        if make_device().exist:
            return make_device(device_id), backend_name

    logging.info(
        "None of the following device is detected: metal, rocm, cuda, vulkan, opencl. "
        "Switch to llvm instead."
    )
    return tvm.cpu(device_id), "llvm"


class ChatModule: # pylint: disable=too-many-instance-attributes
r"""The ChatModule for MLC LLM.

Expand Down Expand Up @@ -738,7 +657,7 @@ def __init__(
):
# 0. Get device:
# Retrieve device_name and device_id (if any, default 0) from device arg
self.device, device_name = _parse_device_str(device)
self.device = detect_device(device)
device_type = self.device.device_type
device_id = self.device.device_id

Expand Down Expand Up @@ -780,7 +699,7 @@ def __init__(
self.model_path,
self.chat_config,
model_lib_path,
device_name,
self.device.MASK2STR[self.device.device_type],
self.config_file_path,
)

Expand Down
2 changes: 1 addition & 1 deletion python/mlc_chat/cli/convert_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ..support.argparse import ArgumentParser
from ..support.auto_config import detect_config, detect_model_type
from ..support.auto_target import detect_device
from ..support.auto_device import detect_device
from ..support.auto_weight import detect_weight


Expand Down
41 changes: 41 additions & 0 deletions python/mlc_chat/support/auto_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Automatic detection of the device available on the local machine."""
import logging

import tvm
from tvm.runtime import Device

from .style import bold, green, red

FOUND = green("Found")
NOT_FOUND = red("Not found")
AUTO_DETECT_DEVICES = ["cuda", "rocm", "metal", "vulkan", "opencl"]


logger = logging.getLogger(__name__)


def detect_device(device_hint: str) -> Device:
    """Resolve a device hint string to a device present on this machine.

    A hint other than "auto" is passed to `tvm.device`; a `ValueError` is
    raised if TVM rejects it or if the resulting device does not exist.

    For "auto", device 0 of every entry in `AUTO_DETECT_DEVICES` is probed
    and each result is logged. The first entry that exists is returned; if
    none exists, `tvm.device("cpu:0")` is returned instead.
    """
    # Explicit hint: validate and return early.
    if device_hint != "auto":
        try:
            requested = tvm.device(device_hint)
        except Exception as err:
            raise ValueError(f"Invalid device name: {device_hint}") from err
        if requested.exist:
            return requested
        raise ValueError(f"Device is not found on your local environment: {device_hint}")

    # Auto-detection: keep probing after the first hit so that every
    # candidate's availability is still reported in the log.
    chosen = None
    for kind in AUTO_DETECT_DEVICES:
        candidate = tvm.device(dev_type=kind, dev_id=0)
        if not candidate.exist:
            logger.info("%s device: %s:0", NOT_FOUND, kind)
            continue
        logger.info("%s device: %s:0", FOUND, kind)
        if chosen is None:
            chosen = candidate

    if chosen is None:
        logger.info("%s: No available device detected. Falling back to CPU", NOT_FOUND)
        return tvm.device("cpu:0")

    device_str = f"{tvm.runtime.Device.MASK2STR[chosen.device_type]}:{chosen.device_id}"
    logger.info("Using device: %s. Use `--device` to override.", bold(device_str))
    return chosen
32 changes: 1 addition & 31 deletions python/mlc_chat/support/auto_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
import os
from typing import TYPE_CHECKING, Callable, Optional, Tuple

import tvm
from tvm import IRModule, relax
from tvm._ffi import get_global_func, register_func
from tvm.contrib import tar, xcode
from tvm.runtime import Device
from tvm.target import Target

from .auto_device import AUTO_DETECT_DEVICES
from .style import bold, green, red

if TYPE_CHECKING:
Expand Down Expand Up @@ -45,33 +44,6 @@ def detect_target_and_host(target_hint: str, host_hint: str = "auto") -> Tuple[T
return target, build_func


def detect_device(device_hint: str) -> Device:
    """Turn a device hint string into a locally available device.

    With the hint "auto", device 0 of each type in `AUTO_DETECT_DEVICES` is
    checked in order, the outcome of every check is logged, and the first
    existing one is used (or `tvm.device("cpu:0")` when none exists). Any
    other hint is handed to `tvm.device`; a `ValueError` is raised when TVM
    cannot parse it or when the device does not exist.
    """
    if device_hint == "auto":
        available = []
        for device_type in AUTO_DETECT_DEVICES:
            probe = tvm.device(dev_type=device_type, dev_id=0)
            is_present = probe.exist
            status = FOUND if is_present else NOT_FOUND
            logger.info("%s device: %s:0", status, device_type)
            if is_present:
                available.append(probe)

        if not available:
            logger.info("%s: No available device detected. Falling back to CPU", NOT_FOUND)
            return tvm.device("cpu:0")

        # The list preserves probing order, so index 0 is the highest priority.
        device = available[0]
        device_str = f"{tvm.runtime.Device.MASK2STR[device.device_type]}:{device.device_id}"
        logger.info("Using device: %s. Use `--device` to override.", bold(device_str))
        return device

    try:
        device = tvm.device(device_hint)
    except Exception as err:
        raise ValueError(f"Invalid device name: {device_hint}") from err

    if device.exist:
        return device
    raise ValueError(f"Device is not found on your local environment: {device_hint}")


def _detect_target_gpu(hint: str) -> Tuple[Target, BuildFunc]:
if hint in ["iphone", "android", "webgpu", "mali", "opencl"]:
hint += ":generic"
Expand Down Expand Up @@ -287,8 +259,6 @@ def tvm_callback_cuda_compile(code, target): # pylint: disable=unused-argument
return ptx


AUTO_DETECT_DEVICES = ["cuda", "rocm", "metal", "vulkan"]

PRESET = {
"iphone:generic": {
"target": {
Expand Down