From 2c56991529fdcc31334755db7f6c2503129bdcc6 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Sat, 3 Aug 2024 13:36:14 +0800
Subject: [PATCH] [Model] Refactor and decouple weight loading logic for
 InternVL2 model (#7067)

---
 vllm/model_executor/models/intern_vit.py | 11 +++-
 vllm/model_executor/models/internvl.py   | 82 ++++++++----------------
 2 files changed, 38 insertions(+), 55 deletions(-)

diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py
index c6c692deca2e1..54c933e3e4959 100644
--- a/vllm/model_executor/models/intern_vit.py
+++ b/vllm/model_executor/models/intern_vit.py
@@ -4,7 +4,7 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-from typing import Optional
+from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -16,6 +16,7 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 NORM2FN = {
     'rms_norm': RMSNorm,
@@ -268,3 +269,11 @@ def forward(
 
         encoder_outputs = self.encoder(inputs_embeds=hidden_states)
         return encoder_outputs
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index eabc283b1efdb..4749251271487 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -4,6 +4,7 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
+import itertools
 from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 
 import torch
@@ -414,58 +415,31 @@ def sample(
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
-            (".gate_up_proj", ".gate_proj", 0),
-            (".gate_up_proj", ".up_proj", 1),
-            (".gate_up_proj", ".w1", 0),
-            (".gate_up_proj", ".w3", 1),
-        ]
-        params_dict = dict(self.named_parameters())
+    def _filter_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
+                        prefix: str):
         for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if self.config.text_config.tie_word_embeddings \
-                and "lm_head.weight" in name:
-                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                # We only do sharding for language model
-                # and not vision model for now.
-                if "vision_embed_tokens" in name and self.vision_embed_tokens:
-                    continue
-                if weight_name not in name:
-                    continue
-                param = params_dict[name.replace(weight_name, param_name)]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                if "wqkv" in name:
-                    config = self.config.text_config
-                    kv_groups = (config.num_attention_heads //
-                                 config.num_key_value_heads)
-                    head_dim = config.hidden_size // config.num_attention_heads
-                    loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
-                                                       head_dim,
-                                                       loaded_weight.shape[-1])
-                    wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
-                                             dim=1)
-                    wq = wq.reshape(-1, wq.shape[-1])
-                    wk = wk.reshape(-1, wk.shape[-1])
-                    wv = wv.reshape(-1, wv.shape[-1])
-                    weight_loader = param.weight_loader
-                    weight_loader(param, wq, 'q')
-                    weight_loader(param, wk, 'k')
-                    weight_loader(param, wv, 'v')
-                    continue
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+            name = name.split(".")
+            if prefix == name.pop(0):
+                name = ".".join(name)
+                yield name, loaded_weight
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        # prepare weight iterators for components
+        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+
+        # load vision encoder
+        vit_weights = self._filter_weights(vit_weights, "vision_model")
+        self.vision_model.load_weights(vit_weights)
+
+        # load mlp projector
+        mlp_weights = self._filter_weights(mlp_weights, "mlp1")
+        mlp_params_dict = dict(self.mlp1.named_parameters())
+        for name, loaded_weight in mlp_weights:
+            param = mlp_params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load llm backbone
+        llm_weights = self._filter_weights(llm_weights, "language_model")
+        self.language_model.load_weights(llm_weights)
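
For context, the loading pattern introduced by this patch can be exercised in isolation. The sketch below is hypothetical (the weight names are placeholders, not taken from a real InternVL2 checkpoint): a single checkpoint iterator is split with itertools.tee and each copy is filtered by its top-level name prefix, mirroring InternVLChatModel._filter_weights, before it would be handed to the matching sub-module's load_weights.

# Standalone sketch of the decoupled loading pattern (hypothetical weight names).
import itertools
from typing import Iterable, Tuple

import torch


def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
    # Yield only weights whose first dotted name component equals `prefix`,
    # with that component stripped (mirrors InternVLChatModel._filter_weights).
    for name, tensor in weights:
        parts = name.split(".")
        if parts[0] == prefix:
            yield ".".join(parts[1:]), tensor


# Hypothetical checkpoint iterator covering the three components.
checkpoint = iter([
    ("vision_model.embeddings.patch_embedding.weight", torch.zeros(4, 3)),
    ("mlp1.0.weight", torch.zeros(4)),
    ("language_model.lm_head.weight", torch.zeros(4, 4)),
])

# tee() gives each component its own pass over the single checkpoint iterator.
vit_w, mlp_w, llm_w = itertools.tee(checkpoint, 3)

print(dict(filter_weights(vit_w, "vision_model")).keys())
print(dict(filter_weights(mlp_w, "mlp1")).keys())
print(dict(filter_weights(llm_w, "language_model")).keys())

This keeps each sub-module's weight-loading logic self-contained, so the composite model no longer needs per-layer special cases such as the wqkv splitting removed above.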