Skip to content

Commit

Permalink
fix: load VeturboIO weights into an already-initialized model (rename load_with_veturboio to load_with_veturboio_into_model)
Browse files Browse the repository at this point in the history
  • Loading branch information
xulinhui committed Aug 8, 2024
1 parent d28795a commit 8a35916
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 49 deletions.
37 changes: 20 additions & 17 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from vllm.model_executor.model_loader.utils import (get_model_architecture,
set_default_torch_dtype)
from vllm.model_executor.model_loader.veturboio import (
VeturboIOConfig, load_with_veturboio)
VeturboIOConfig, load_with_veturboio_into_model)
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf,
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
Expand Down Expand Up @@ -862,27 +862,30 @@ def load_model(self, *, model_config: ModelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
self._verify_config(model_config, parallel_config)

with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model_class = get_model_architecture(model_config)[0]
quant_config = _get_quantization_config(
model_config, self.load_config)
extra_kwargs = _get_model_initialization_kwargs(
model_class, lora_config, multimodal_config)
extra_kwargs["quant_config"] = quant_config
extra_kwargs["cache_config"] = cache_config
hf_weights_files, _ = self._prepare_weights(
model_config.model, model_config.revision)

model = _initialize_model(model_config, self.load_config,
lora_config, multimodal_config,
cache_config)

hf_weights_files, _ = self._prepare_weights(model_config.model,
model_config.revision)
veturboio_config = copy.copy(self.veturboio_config)
veturboio_config.model_class = model_class
veturboio_config.hf_config = model_config.hf_config
veturboio_config.dtype = model_config.dtype
veturboio_config.model_files = hf_weights_files
veturboio_config.map_location = device_config.device_type

model = load_with_veturboio(veturboio_config, **extra_kwargs)
load_with_veturboio_into_model(veturboio_config, model)
# # do quant method
# for _, module in model.named_modules():
# quant_method = getattr(module, "quant_method", None)
# if quant_method is not None:
# # print(f">>>>>> {_} do quant_method {quant_method}")
# quant_method.process_weights_after_loading(module)
# # FIXME: Remove this after Mixtral is updated
# # to use quant_method.
# if hasattr(module, "process_weights_after_loading"):
# # print(f">>>>>> {_} has process_weights_after_loading")
# module.process_weights_after_loading()
torch.cuda.empty_cache()
return model.eval()


Expand Down
48 changes: 16 additions & 32 deletions vllm/model_executor/model_loader/veturboio.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import gc
import os
import time
from dataclasses import dataclass
Expand Down Expand Up @@ -37,9 +38,6 @@ class VeturboIOConfig:
use_pinmem: Optional[bool] = False
use_direct_io: Optional[bool] = False
use_cipher: Optional[bool] = False # not implemented yet
model_class: Optional[Type[torch.nn.Module]] = None
hf_config: Optional[PretrainedConfig] = None
dtype: Optional[Union[str, torch.dtype]] = None

def _construct_veturboio_args(self) -> "VeturboIOArgs":
veturboio_args = {
Expand Down Expand Up @@ -69,8 +67,7 @@ def verify_with_model_config(self, model_config: "ModelConfig") -> None:

class VeturboIOAgent:

def __init__(self, veturboio_config: VeturboIOConfig,
quant_config: QuantizationConfig, **extra_kwargs):
def __init__(self, veturboio_config: VeturboIOConfig):
if veturboio_error_msg is not None:
raise ImportError(
"VeturboIO is not installed. Please install ImportError "
Expand All @@ -80,45 +77,32 @@ def __init__(self, veturboio_config: VeturboIOConfig,
self.veturboio_config = veturboio_config
self.veturboio_args = (
self.veturboio_config._construct_veturboio_args())
self.extra_kwargs = extra_kwargs
if extra_kwargs.get("quant_config", None) is not None:
self.quant_config = extra_kwargs["quant_config"]
else:
self.quant_config = quant_config
self.model = self._init_model()

def _init_model(self):
assert self.veturboio_config.hf_config is not None
model_args = self.veturboio_config.hf_config
model_args.torch_dtype = self.veturboio_config.dtype
assert self.veturboio_config.model_class is not None

return self.veturboio_config.model_class(config=model_args,
quant_config=self.quant_config,
**self.extra_kwargs)

def deserialize(self):

def deserialize(self, model):
assert isinstance(model, torch.nn.Module)
start = time.perf_counter()
for model_file in self.veturboio_config.model_files:

tensors_dict = veturboio.load(model_file,
helper=helper,
**self.veturboio_args.deserializer_params)
self.model.load_state_dict(tensors_dict, strict=False, assign=True)
helper=helper,
**self.veturboio_args.deserializer_params)

model.load_state_dict(tensors_dict, strict=False, assign=True)
del tensors_dict
# gc.collect() # do gc collect immediately
torch.cuda.empty_cache()

end = time.perf_counter()
duration = end - start
logger.info("Deserialized model in %0.2fs by VeturboIO", duration)
return self.model.eval()


def load_with_veturboio(veturboio_config: VeturboIOConfig,
**extra_kwargs) -> nn.Module:
def load_with_veturboio_into_model(veturboio_config: VeturboIOConfig,
model: nn.Module):
assert veturboio_config.model_files is not None, ("model files can not be None, "
"when load with veturboIO")
veturboio = VeturboIOAgent(veturboio_config, **extra_kwargs)
return veturboio.deserialize()
veturboio = VeturboIOAgent(veturboio_config)
return veturboio.deserialize(model)


@dataclass
Expand Down

0 comments on commit 8a35916

Please sign in to comment.