Skip to content

Commit

Permalink
fix: use veturboio to load model
Browse files Browse the repository at this point in the history
  • Loading branch information
xulinhui committed Aug 7, 2024
1 parent a4588d3 commit d28795a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
7 changes: 5 additions & 2 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,9 +825,11 @@ def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if isinstance(load_config.model_loader_extra_config, VeturboIOConfig):
self.veturboio_config = load_config.model_loader_extra_config
else:
elif isinstance(load_config.model_loader_extra_config, dict):
self.veturboio_config = VeturboIOConfig(
**load_config.model_loader_extra_config)
else:
self.veturboio_config = VeturboIOConfig()

def _verify_config(self, model_config: ModelConfig,
parallel_config: ParallelConfig):
Expand Down Expand Up @@ -870,14 +872,15 @@ def load_model(self, *, model_config: ModelConfig,
model_class, lora_config, multimodal_config)
extra_kwargs["quant_config"] = quant_config
extra_kwargs["cache_config"] = cache_config
_, hf_weights_files, _ = self._prepare_weights(
hf_weights_files, _ = self._prepare_weights(
model_config.model, model_config.revision)

veturboio_config = copy.copy(self.veturboio_config)
veturboio_config.model_class = model_class
veturboio_config.hf_config = model_config.hf_config
veturboio_config.dtype = model_config.dtype
veturboio_config.model_files = hf_weights_files
veturboio_config.map_location = device_config.device_type

model = load_with_veturboio(veturboio_config, **extra_kwargs)
return model.eval()
Expand Down
21 changes: 15 additions & 6 deletions vllm/model_executor/model_loader/veturboio.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@

try:
import veturboio
helper = veturboio.init_io_helper()
from veturboio.ops.load_utils import IOHelper
if IOHelper is None:
helper = None
else:
helper = veturboio.init_io_helper()
except ImportError as e:
veturboio_error_msg = str(e)

Expand All @@ -26,7 +30,7 @@

@dataclass
class VeturboIOConfig:
model_files: Tuple[str, os.PathLike]
model_files: Optional[Tuple[str, os.PathLike]] = None
map_location: Optional[str] = "cpu"
enable_fast_mode: Optional[bool] = True
num_thread: Optional[int] = 32
Expand All @@ -39,7 +43,7 @@ class VeturboIOConfig:

def _construct_veturboio_args(self) -> "VeturboIOArgs":
veturboio_args = {
"map_location": self.enable_fast_mode,
"map_location": self.map_location,
"enable_fast_mode": self.enable_fast_mode,
"num_thread": self.num_thread,
"use_pinmem": self.use_pinmem,
Expand Down Expand Up @@ -97,9 +101,12 @@ def deserialize(self):
start = time.perf_counter()
for model_file in self.veturboio_config.model_files:
tensors_dict = veturboio.load(model_file,
helper=helper,
**self.veturboio_args.deserializer_params)
self.model.load_state_dict(tensors_dict, assign=True)
helper=helper,
**self.veturboio_args.deserializer_params)
self.model.load_state_dict(tensors_dict, strict=False, assign=True)
del tensors_dict
torch.cuda.empty_cache()

end = time.perf_counter()
duration = end - start
logger.info("Deserialized model in %0.2fs by VeturboIO", duration)
Expand All @@ -108,6 +115,8 @@ def deserialize(self):

def load_with_veturboio(veturboio_config: VeturboIOConfig,
                        **extra_kwargs) -> nn.Module:
    """Load a model through VeturboIO.

    Builds a ``VeturboIOAgent`` from ``veturboio_config`` and the extra
    keyword arguments, then returns whatever its ``deserialize()`` call
    produces.

    Args:
        veturboio_config: Loader configuration. Its ``model_files``
            attribute must be set before calling this function.
        **extra_kwargs: Forwarded unchanged to ``VeturboIOAgent``.

    Returns:
        The result of ``VeturboIOAgent.deserialize()``.

    Raises:
        ValueError: If ``veturboio_config.model_files`` is ``None``.
    """
    # Explicit raise instead of ``assert`` so the check is not stripped
    # when Python runs with -O.
    if veturboio_config.model_files is None:
        raise ValueError("model files can not be None, "
                         "when load with veturboIO")
    # Named ``agent`` so the local does not shadow the ``veturboio`` module.
    agent = VeturboIOAgent(veturboio_config, **extra_kwargs)
    return agent.deserialize()

Expand Down

0 comments on commit d28795a

Please sign in to comment.