diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index cc4a344bb8555..e213d533c7bb9 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -28,7 +28,7 @@ from vllm.model_executor.model_loader.utils import (get_model_architecture, set_default_torch_dtype) from vllm.model_executor.model_loader.veturboio import ( - VeturboIOConfig, load_with_veturboio) + VeturboIOConfig, load_with_veturboio_into_model) from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, @@ -862,27 +862,30 @@ def load_model(self, *, model_config: ModelConfig, scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module: self._verify_config(model_config, parallel_config) - with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): - model_class = get_model_architecture(model_config)[0] - quant_config = _get_quantization_config( - model_config, self.load_config) - extra_kwargs = _get_model_initialization_kwargs( - model_class, lora_config, multimodal_config) - extra_kwargs["quant_config"] = quant_config - extra_kwargs["cache_config"] = cache_config - hf_weights_files, _ = self._prepare_weights( - model_config.model, model_config.revision) - + model = _initialize_model(model_config, self.load_config, + lora_config, multimodal_config, + cache_config) + + hf_weights_files, _ = self._prepare_weights(model_config.model, + model_config.revision) veturboio_config = copy.copy(self.veturboio_config) - veturboio_config.model_class = model_class - veturboio_config.hf_config = model_config.hf_config - veturboio_config.dtype = model_config.dtype veturboio_config.model_files = hf_weights_files veturboio_config.map_location = device_config.device_type - - model = load_with_veturboio(veturboio_config, **extra_kwargs) + load_with_veturboio_into_model(veturboio_config, model) + # # do quant method + # for _, module in model.named_modules(): + # quant_method = getattr(module, "quant_method", None) + # if quant_method is not None: + # # print(f">>>>>> {_} do quant_method {quant_method}") + # quant_method.process_weights_after_loading(module) + # # FIXME: Remove this after Mixtral is updated + # # to use quant_method. + # if hasattr(module, "process_weights_after_loading"): + # # print(f">>>>>> {_} has process_weights_after_loading") + # module.process_weights_after_loading() + torch.cuda.empty_cache() return model.eval() diff --git a/vllm/model_executor/model_loader/veturboio.py b/vllm/model_executor/model_loader/veturboio.py index ef182ea551744..c285cac838c33 100644 --- a/vllm/model_executor/model_loader/veturboio.py +++ b/vllm/model_executor/model_loader/veturboio.py @@ -1,3 +1,4 @@ +import gc import os import time from dataclasses import dataclass @@ -37,9 +38,6 @@ class VeturboIOConfig: use_pinmem: Optional[bool] = False use_direct_io: Optional[bool] = False use_cipher: Optional[bool] = False # not implemented yet - model_class: Optional[Type[torch.nn.Module]] = None - hf_config: Optional[PretrainedConfig] = None - dtype: Optional[Union[str, torch.dtype]] = None def _construct_veturboio_args(self) -> "VeturboIOArgs": veturboio_args = { @@ -69,8 +67,7 @@ def verify_with_model_config(self, model_config: "ModelConfig") -> None: class VeturboIOAgent: - def __init__(self, veturboio_config: VeturboIOConfig, - quant_config: QuantizationConfig, **extra_kwargs): + def __init__(self, veturboio_config: VeturboIOConfig): if veturboio_error_msg is not None: raise ImportError( "VeturboIO is not installed. Please install ImportError " @@ -80,45 +77,32 @@ def __init__(self, veturboio_config: VeturboIOConfig, self.veturboio_config = veturboio_config self.veturboio_args = ( self.veturboio_config._construct_veturboio_args()) - self.extra_kwargs = extra_kwargs - if extra_kwargs.get("quant_config", None) is not None: - self.quant_config = extra_kwargs["quant_config"] - else: - self.quant_config = quant_config - self.model = self._init_model() - - def _init_model(self): - assert self.veturboio_config.hf_config is not None - model_args = self.veturboio_config.hf_config - model_args.torch_dtype = self.veturboio_config.dtype - assert self.veturboio_config.model_class is not None - - return self.veturboio_config.model_class(config=model_args, - quant_config=self.quant_config, - **self.extra_kwargs) - - def deserialize(self): + + def deserialize(self, model): + assert isinstance(model, torch.nn.Module) start = time.perf_counter() for model_file in self.veturboio_config.model_files: + tensors_dict = veturboio.load(model_file, - helper=helper, - **self.veturboio_args.deserializer_params) - self.model.load_state_dict(tensors_dict, strict=False, assign=True) + helper=helper, + **self.veturboio_args.deserializer_params) + + model.load_state_dict(tensors_dict, strict=False, assign=True) del tensors_dict + # gc.collect() # do gc collect immediately torch.cuda.empty_cache() - + end = time.perf_counter() duration = end - start logger.info("Deserialized model in %0.2fs by VeturboIO", duration) - return self.model.eval() -def load_with_veturboio(veturboio_config: VeturboIOConfig, - **extra_kwargs) -> nn.Module: +def load_with_veturboio_into_model(veturboio_config: VeturboIOConfig, + model: nn.Module): assert veturboio_config.model_files is not None, ("model files can not be None, " "when load with veturboIO") - veturboio = VeturboIOAgent(veturboio_config, **extra_kwargs) - return veturboio.deserialize() + veturboio = VeturboIOAgent(veturboio_config) + return veturboio.deserialize(model) @dataclass