diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 7b869854d6..35a8d1f41e 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -437,7 +437,7 @@ class LLMChat { /*! * \brief Reload model, tokenizers and configurations from the specified model path. - * \param executable The module to reload. + * \param reload_lib The module to reload, it can either be a path to the library or a tvm Module. * \param model_path The path to search for models. * \param app_config_json The JSON string used to partially override the configuration loaded from * disk, default to empty string. diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py index 02625f4ef4..b2e0ec126c 100644 --- a/python/mlc_chat/chat_module.py +++ b/python/mlc_chat/chat_module.py @@ -532,14 +532,16 @@ def _get_lib_module_path( raise FileNotFoundError(err_msg) -def _convert_chat_config_to_json_str(chat_config: Optional[ChatConfig], conv_template: str) -> str: +def _convert_chat_config_to_json_str( + chat_config: Optional[ChatConfig], conv_template: Optional[str] +) -> str: """Convert user's input ChatConfig to a json string, omitting ``None`` fields. Parameters ---------- chat_config : Optional[ChatConfig] User's input. A partial ChatConfig for overriding ``mlc-chat-config.json``. - conv_template : str + conv_template : Optional[str] The ``conv_template`` that will be used after considering potential override. Returns @@ -591,7 +593,7 @@ def _convert_generation_config_to_json_str(generation_config: Optional[Generatio return json.dumps(asdict(generation_config)) -def _parse_device_str(device: str): +def _parse_device_str(device: str) -> (tvm.runtime.Device, str): """Parse the input device identifier into device name and id. Parameters @@ -603,11 +605,11 @@ def _parse_device_str(device: str): Returns ------- + dev : tvm.runtime.Device + The device. + device_name : str The name of the device. - - device_id : int - The id of the device, or 0 if not specified in the input. """ device_err_msg = ( f"Invalid device name: {device}. Please enter the device in the form " @@ -616,14 +618,32 @@ def _parse_device_str(device: str): ) device_args = device.split(":") if len(device_args) == 1: - return device_args[0], 0 + device_name, device_id = device_args[0], 0 elif len(device_args) == 2: - return device_args[0], int(device_args[1]) + device_name, device_id = device_args[0], int(device_args[1]) elif len(device_args) > 2: raise ValueError(device_err_msg) + if device_name == "cuda": + device = tvm.cuda(device_id) + elif device_name == "metal": + device = tvm.metal(device_id) + elif device_name == "vulkan": + device = tvm.vulkan(device_id) + elif device_name == "rocm": + device = tvm.rocm(device_id) + elif device_name == "opencl": + device = tvm.opencl(device_id) + elif device_name == "auto": + device, device_name = _detect_local_device(device_id) + logging.info(f"System automatically detected device: {device_name}") + else: + raise ValueError(device_err_msg) + + return device, device_name -def _detect_local_device(device_id: int = 0): + +def _detect_local_device(device_id: int = 0) -> (tvm.runtime.Device, str): """Automatically detect the local device if user does not specify. Parameters @@ -633,8 +653,11 @@ def _detect_local_device(device_id: int = 0): Returns ------ - dev : Device + dev : tvm.runtime.Device The local device. + + device_name : str + The name of the device. """ if tvm.metal().exist: return tvm.metal(device_id), "metal" @@ -715,34 +738,13 @@ def __init__( chat_config: Optional[ChatConfig] = None, model_lib_path: Optional[str] = None, ): - device_err_msg = ( - f"Invalid device name: {device}. Please enter the device in the form " - "'device_name:device_id' or 'device_name', where 'device_name' needs to be " - "one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto'." - ) - - # 0. Retrieve device_name and device_id (if any, default 0) from device arg - device_name, device_id = _parse_device_str(device) - - # 1. Get self.device - if device_name == "cuda": - self.device = tvm.cuda(device_id) - elif device_name == "metal": - self.device = tvm.metal(device_id) - elif device_name == "vulkan": - self.device = tvm.vulkan(device_id) - elif device_name == "rocm": - self.device = tvm.rocm(device_id) - elif device_name == "opencl": - self.device = tvm.opencl(device_id) - elif device_name == "auto": - self.device, device_name = _detect_local_device(device_id) - logging.info(f"System automatically detected device: {device_name}") - else: - raise ValueError(device_err_msg) + # 0. Get device: + # Retrieve device_name and device_id (if any, default 0) from device arg + self.device, device_name = _parse_device_str(device) device_type = self.device.device_type + device_id = self.device.device_id - # 2. Populate chat module and their functions + # 1. Populate chat module and their functions fcreate_chat_mod = tvm.get_global_func("mlc.llm_chat_create") assert fcreate_chat_mod is not None chat_mod = fcreate_chat_mod(device_type, device_id) @@ -768,13 +770,13 @@ def __init__( self._get_role0_func = chat_mod["get_role0"] self._get_role1_func = chat_mod["get_role1"] - # 3. Look up model_path + # 2. Look up model_path self.model_path, self.config_file_path = _get_model_path(model) - # 4. Instantiate chat_config + # 3. Instantiate chat_config self.chat_config = _get_chat_config(self.config_file_path, chat_config) - # 5. Look up model library + # 4. Look up model library self.model_lib_path = _get_lib_module_path( model, self.model_path, @@ -784,7 +786,7 @@ def __init__( self.config_file_path, ) - # 6. Call reload + # 5. Call reload user_chat_config_json_str = _convert_chat_config_to_json_str( self.chat_config, self.chat_config.conv_template )