diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 2eefcc4f30516..c0d8553c0df1a 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -2,7 +2,7 @@ import inspect from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Mapping, Optional, Type import torch from torch import nn @@ -59,6 +59,7 @@ def method_has_implemented_embedding( class QuantizationConfig(ABC): """Base class for quantization configs.""" + packed_modules_mapping: Mapping[str, List[str]] = dict() @abstractmethod def get_name(self) -> str: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 1a11b2419cc88..0e3258e4afba0 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -83,7 +83,9 @@ def get_quant_method( # Check if the layer is skipped for quantization. # TODO (@robertgshaw2): support module names - if should_ignore_layer(prefix, ignore=self.ignore): + if should_ignore_layer(prefix, + ignore=self.ignore, + fused_mapping=self.packed_modules_mapping): return UnquantizedLinearMethod() if isinstance(layer, LinearBase): scheme = self.get_scheme(layer=layer, layer_name=prefix) @@ -379,34 +381,29 @@ def get_scheme(self, # Will be empty for models with only sparsity weight_quant = input_quant = None - sparsity_scheme: Optional[SparsityCompressionConfig] = None if self.target_scheme_map: matched_target = find_matched_target( layer_name=layer_name, module=layer, - targets=self.target_scheme_map.keys()) + targets=self.target_scheme_map.keys(), + fused_mapping=self.packed_modules_mapping) scheme_dict = self.target_scheme_map[matched_target] weight_quant = scheme_dict.get("weights") input_quant = scheme_dict.get("input_activations") - if self.sparsity_scheme_map: - is_ignored = False - with suppress(ValueError): - is_ignored = find_matched_target( - layer_name=layer_name, - module=layer, - targets=self.sparsity_ignore_list) - - # if the layer is in the sparsity ignore list, - # we should not apply any sparsity scheme - - if not is_ignored: - matched_target = find_matched_target( - layer_name=layer_name, - module=layer, - targets=self.sparsity_scheme_map.keys()) - sparsity_scheme = self.sparsity_scheme_map.get(matched_target) + # Find the sparsity scheme of the layer + # assume that fused layers inerhit first component's sparsity scheme + sparsity_targets = (self.sparsity_scheme_map.keys() - + set(self.sparsity_ignore_list)) + sparsity_scheme: Optional[SparsityCompressionConfig] = None + with suppress(ValueError): + matched_target = find_matched_target( + layer_name=layer_name, + module=layer, + targets=sparsity_targets, + fused_mapping=self.packed_modules_mapping) + sparsity_scheme = self.sparsity_scheme_map[matched_target] if self.supports_cutlass_24(weight_quant=weight_quant, input_quant=input_quant, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index 4ea79531efe7e..85ae1d5cb7878 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -1,14 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 import re -from typing import Iterable, Optional +from types import MappingProxyType +from typing import Iterable, List, Mapping, Optional from compressed_tensors import CompressionFormat from torch.nn import Module -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - FUSED_LAYER_NAME_MAPPING) - def is_activation_quantization_format(format: str) -> bool: _ACTIVATION_QUANTIZATION_FORMATS = [ @@ -19,8 +17,11 @@ def is_activation_quantization_format(format: str) -> bool: return format in _ACTIVATION_QUANTIZATION_FORMATS -def should_ignore_layer(layer_name: Optional[str], - ignore: Iterable[str]) -> bool: +def should_ignore_layer( + layer_name: Optional[str], + ignore: Iterable[str] = tuple(), + fused_mapping: Mapping[str, List[str]] = MappingProxyType({}) +) -> bool: if layer_name is None: return False @@ -32,8 +33,8 @@ def should_ignore_layer(layer_name: Optional[str], # in the safetensors checkpoint. So, we convert the name # from the fused version to unfused + check to make sure that # each shard of the fused layer has the same scheme. - if proj_name in FUSED_LAYER_NAME_MAPPING and layer_name not in ignore: - shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name] + if proj_name in fused_mapping and layer_name not in ignore: + shard_proj_names = fused_mapping[proj_name] # Convert fused_name --> [shard_names] shard_names = [ @@ -79,55 +80,12 @@ def check_equal_or_regex_match(layer_name: str, return False -def _handle_fused_layers(func): - """ - Decorator to handle fused layers by mapping vllm fused layer names - to their corresponding unfused layer names for quantization/pruning schemes. - """ - # fused_layer_name -> unfused_layer_name - fused_layer_map = { - "qkv_proj": "q_proj", - "gate_up_proj": "up_proj", - } - - def fused_layer_handler(layer_name: Optional[str], module: Module, - targets: Iterable[str]) -> Optional[str]: - """ - Wrapper function specifically designed to support the - find_matched_target function. - - It handles cases where the provided layer name corresponds to a - fused layer in vllm, mapping it to its equivalent unfused layer name - based on the predefined fused_layer_map. If the original layer name - raises a ValueError in the wrapped function, this handler - will attempt to resolve the issue by substituting with unfused - layer name. - - :param layer_name: Name of the layer, which may be fused. - :param module: An instance of torch.nn.Module. - :param targets: A list of target names or patterns to match. - :return: The result of the wrapped find_matched_target function with - the resolved layer name. - :raises ValueError: If the layer name cannot be resolved to a - valid target. - """ - try: - return func(layer_name, module, targets) - except ValueError: - if layer_name is None: - layer_name = "" - parent_name, fused_proj_name = layer_name.rsplit(".", 1) - unfused_proj_name = fused_layer_map.get(fused_proj_name, - fused_proj_name) - new_layer_name = f"{parent_name}.{unfused_proj_name}" - return func(new_layer_name, module, targets) - - return fused_layer_handler - - -@_handle_fused_layers -def find_matched_target(layer_name: Optional[str], module: Module, - targets: Iterable[str]) -> str: +def find_matched_target( + layer_name: Optional[str], + module: Module, + targets: Iterable[str], + fused_mapping: Mapping[str, List[str]] = MappingProxyType({}) +) -> str: """ Helper function to look up which "target" in the compressed-tensors config that a layer corresponds to. @@ -141,19 +99,25 @@ def find_matched_target(layer_name: Optional[str], module: Module, First, we try to match the layer_name with a target Second, we try to match the module's name with a target + Third, we try to map the layer_name to a list of fused module names. + *All* component module names must match in order for a match to be + successful. A successful match returns the first component target :param layer_name: layer name :param module: torch.nn.Module :param targets: list of targets to match the layer against + :param fused_mapping: map from fused layer names to its components + :param fused_strategy: either "all" or "any". If using "all", fused + layers match if "all" of its components match """ if layer_name is None: layer_name = "" - matched_target = (_find_first_match(layer_name, targets) - or _find_first_match(module.__class__.__name__, targets, - True) - or _match_fused_layer(layer_name, targets)) + matched_target = ( + _find_first_match(layer_name, targets) + or _find_first_match(module.__class__.__name__, targets, True) + or _match_fused_layer(layer_name, targets, fused_mapping)) if matched_target is None: raise ValueError( @@ -205,11 +169,19 @@ def _is_equal_or_regex_match(value: str, return False -def _match_fused_layer(layer_name: str, - target_layers: Iterable[str]) -> Optional[str]: +def _match_fused_layer( + layer_name: str, target_layers: Iterable[str], + fused_mapping: Mapping[str, List[str]]) -> Optional[str]: """ Match a fused layer name to its corresponding individual layer in - target_layers. + target_layers. Returns first value in fused_mapping which matches targets + + Implements an "all" matching strategy where a fused layer matches iff + "all" of its components match + + :param layer_name: layer name + :param target_layers: list of targets to match the layer against + :param fused_mapping: map from fused layer names to its components Examples: layer_name = "model.layers.0.self_attn.qkv_proj" @@ -217,27 +189,25 @@ def _match_fused_layer(layer_name: str, "model.layers.0.self_attn.k_proj", "model.layers.0.self_attn.v_proj"] """ - # Split into parent path and layer type - # e.g., "model.layers.0.self_attn" and "qkv_proj" - parent_path = ".".join(layer_name.split(".")[:-1]) - layer_type = layer_name.split(".")[-1] - - if layer_type not in FUSED_LAYER_NAME_MAPPING: + # find layer_name in mapping + fused = next((key for key in fused_mapping if layer_name.endswith(key)), + None) + if fused is None: return None - possible_layer_types = FUSED_LAYER_NAME_MAPPING[layer_type] - - # Look for a target layer that: - # 1. Has the same parent path - # 2. Ends with one of the possible individual layer types - for target in target_layers: - is_same_parent = parent_path in target - is_matching_type = any(type_suffix in target - for type_suffix in possible_layer_types) - - if is_same_parent and is_matching_type and all( - (f"{parent_path}.{type_suffix}" in target_layers) - for type_suffix in possible_layer_types): - return target + # expand path of unfused components + unfused_paths = [ + layer_name.replace(fused, unfused) for unfused in fused_mapping[fused] + ] - return None + # for each unfused component, find a match in targets + unfused_matches: List[Optional[str]] = [] + for unfused in unfused_paths: + for target in target_layers: + if _is_equal_or_regex_match(unfused, target): + unfused_matches.append(target) + break + else: + unfused_matches.append(None) + + return unfused_matches[0] if all(unfused_matches) else None diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 0451cf82b9971..ba123565a0ecc 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -18,8 +18,6 @@ QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8) from vllm.model_executor.layers.quantization.quark.utils import ( deep_compare, should_ignore_layer) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - FUSED_LAYER_NAME_MAPPING) from vllm.platforms import current_platform __all__ = ["QuarkLinearMethod"] @@ -58,7 +56,9 @@ def get_quant_method(self, layer: torch.nn.Module, # Check if the layer is skipped for quantization. exclude_layers = cast(List[str], self.quant_config.get("exclude")) - if should_ignore_layer(prefix, ignore=exclude_layers): + if should_ignore_layer(prefix, + ignore=exclude_layers, + fused_mapping=self.packed_modules_mapping): return UnquantizedLinearMethod() if isinstance(layer, LinearBase): scheme = self.get_scheme(layer=layer, layer_name=prefix) @@ -201,8 +201,8 @@ def _find_matched_config(self, layer_name: str, module: torch.nn.Module) -> Dict[str, Any]: proj_name = layer_name.split(".")[-1] - if proj_name in FUSED_LAYER_NAME_MAPPING: - shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name] + if proj_name in self.packed_modules_mapping: + shard_proj_names = self.packed_modules_mapping[proj_name] # Convert fused_name --> [shard_names] shard_names = [ diff --git a/vllm/model_executor/layers/quantization/quark/utils.py b/vllm/model_executor/layers/quantization/quark/utils.py index afb1d9d63e73a..17e0df021085a 100644 --- a/vllm/model_executor/layers/quantization/quark/utils.py +++ b/vllm/model_executor/layers/quantization/quark/utils.py @@ -1,10 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import re -from typing import Any, Iterable, Optional - -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - FUSED_LAYER_NAME_MAPPING) +from types import MappingProxyType +from typing import Any, Iterable, List, Mapping, Optional def deep_compare(dict1: Any, dict2: Any) -> bool: @@ -20,8 +18,11 @@ def deep_compare(dict1: Any, dict2: Any) -> bool: return dict1 == dict2 -def should_ignore_layer(layer_name: Optional[str], - ignore: Iterable[str]) -> bool: +def should_ignore_layer( + layer_name: Optional[str], + ignore: Iterable[str], + fused_mapping: Mapping[str, List[str]] = MappingProxyType({}) +) -> bool: if layer_name is None: return False @@ -33,8 +34,8 @@ def should_ignore_layer(layer_name: Optional[str], # in the safetensors checkpoint. So, we convert the name # from the fused version to unfused + check to make sure that # each shard of the fused layer has the same scheme. - if proj_name in FUSED_LAYER_NAME_MAPPING: - shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name] + if proj_name in fused_mapping: + shard_proj_names = fused_mapping[proj_name] # Convert fused_name --> [shard_names] shard_names = [ diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 62484f62f6187..c7ce3a42c81f9 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """This file is used for /tests and /benchmarks""" -from typing import List, Optional, Tuple +from types import MappingProxyType +from typing import List, Mapping, Optional, Tuple import numpy import torch @@ -12,14 +13,6 @@ SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] -# Note: this is a hack. We should update each model to register the -# stacked params and get it from there instead in a future PR. -# fused_name: List[shard_name] -FUSED_LAYER_NAME_MAPPING = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] -} - # Normalize the group_shape to the full extent for any dims that are -1 def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int, @@ -178,14 +171,23 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor, return res.permute(inv_perm) -def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: +def is_layer_skipped( + prefix: str, + ignored_layers: List[str], + fused_mapping: Mapping[str, List[str]] = MappingProxyType({}) +) -> bool: # prefix: model.layers.0.self_attn.q_proj # proj_name: q_proj proj_name = prefix.split(".")[-1] - if proj_name in FUSED_LAYER_NAME_MAPPING: + + # Fused layers like gate_up_proj or qkv_proj will not be fused + # in the safetensors checkpoint. So, we convert the name + # from the fused version to unfused + check to make sure that + # each shard of the fused layer has the same scheme. + if proj_name in fused_mapping: shard_prefixes = [ prefix.replace(proj_name, shard_proj_name) - for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] + for shard_proj_name in fused_mapping[proj_name] ] is_skipped = None diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 19e3bc6a259e9..2a2c2523b725d 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -43,6 +43,7 @@ TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, serialize_vllm_model, tensorizer_weights_iterator) from vllm.model_executor.model_loader.utils import (ParamMapping, + configure_quant_config, get_model_architecture, set_default_torch_dtype) from vllm.model_executor.model_loader.weight_utils import ( @@ -113,6 +114,9 @@ def _initialize_model( model_config = vllm_config.model_config model_class, _ = get_model_architecture(model_config) + if vllm_config.quant_config is not None: + configure_quant_config(vllm_config.quant_config, model_class) + signatures = inspect.signature(model_class.__init__) all_params = [param.name for param in signatures.parameters.values()] if "vllm_config" in all_params and "prefix" in all_params: diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 7a82a695c5070..dc620d4984a77 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -11,6 +11,8 @@ from vllm.config import ModelConfig, ModelImpl from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.adapters import (as_classification_model, as_embedding_model, @@ -138,3 +140,23 @@ def get_sub_modules(self, if module_name.endswith(key): return key, value return None + + +def configure_quant_config(quant_config: QuantizationConfig, + model_class: Type[nn.Module]): + """ + Pass packed_modules_mapping by reference to quant_config so that + quant_config can properly match fused modules + + Note that model attributes are passed by reference to quant_config, + enabling them to be updated by model_class.__new__ (ex. chatglm, qwen) + """ + packed_mapping = getattr(model_class, "packed_modules_mapping", None) + if packed_mapping is not None: + # pass packed_modules_mapping by reference to quant_config + quant_config.packed_modules_mapping = packed_mapping + else: + logger.warning( + "The model class %s has not defined `packed_modules_mapping`, " + "this may lead to incorrect mapping of quantized or ignored " + "modules", model_class.__name__) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index b81a9e917d455..a316486752590 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -265,12 +265,14 @@ def __init__( self.total_num_kv_heads, bias=config.add_bias_linear or config.add_qkv_bias, quant_config=quant_config, + prefix=f"{prefix}.query_key_value", ) self.dense = RowParallelLinear( self.total_num_heads * self.head_dim, config.hidden_size, bias=config.add_bias_linear, quant_config=quant_config, + prefix=f"{prefix}.dense", ) # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141 @@ -327,6 +329,7 @@ def __init__( self, config: ChatGLMConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() @@ -338,6 +341,7 @@ def __init__( [config.ffn_hidden_size] * 2, bias=config.add_bias_linear, quant_config=quant_config, + prefix=f"{prefix}.dense_h_to_4h", ) self.activation_func = SiluAndMul() @@ -348,6 +352,7 @@ def __init__( config.hidden_size, bias=config.add_bias_linear, quant_config=quant_config, + prefix=f"{prefix}.dense_4h_to_h", ) def forward(self, hidden_states): @@ -396,7 +401,7 @@ def __init__( config.hidden_size, eps=config.layernorm_epsilon) # MLP - self.mlp = GLMMLP(config, quant_config) + self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp") def forward( self, @@ -507,7 +512,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.embedding = VocabParallelEmbedding(config.padded_vocab_size, config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.embedding") self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num @@ -766,6 +772,7 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, SupportsMultiModal): # Ensure that the LoRA support check passes when the class is not # initialized, but set all these attributes to empty. + # These will be updated when an instance class is selected packed_modules_mapping = {} supported_lora_modules = [] embedding_modules = {} @@ -777,9 +784,18 @@ def __new__( prefix: str = "", ) -> None: config = vllm_config.model_config.hf_config + # Initialize VL - if hasattr(config, "vision_config"): - return ChatGLMV(vllm_config=vllm_config, prefix=prefix) + if hasattr(config, "vision_config"): # noqa: SIM108 + instance_cls = ChatGLMV # Initialize LLM else: - return ChatGLM(vllm_config=vllm_config, prefix=prefix) \ No newline at end of file + instance_cls = ChatGLM + + # quant_config references base class members, + # so update values before init is called + cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping) + cls.supported_lora_modules += instance_cls.supported_lora_modules + cls.embedding_modules.update(instance_cls.embedding_modules) + cls.embedding_padding_modules += instance_cls.embedding_padding_modules + return instance_cls(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/glm4_vision_encoder.py b/vllm/model_executor/models/glm4_vision_encoder.py index 4449eb8e8b143..2facd1353aef1 100644 --- a/vllm/model_executor/models/glm4_vision_encoder.py +++ b/vllm/model_executor/models/glm4_vision_encoder.py @@ -74,11 +74,13 @@ def __init__( self.head_dim, config.num_heads, quant_config=quant_config, + prefix=f"{prefix}.query_key_value", ) self.dense = RowParallelLinear( config.hidden_size, config.hidden_size, quant_config=quant_config, + prefix=f"{prefix}.dense", ) self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim, @@ -101,6 +103,7 @@ def __init__( self, config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', ): super().__init__() self.config = config @@ -109,11 +112,13 @@ def __init__( config.hidden_size, config.intermediate_size, quant_config=quant_config, + prefix=f"{prefix}.fc1", ) self.fc2 = RowParallelLinear( config.intermediate_size, config.hidden_size, quant_config=quant_config, + prefix=f"{prefix}.fc2", ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -137,7 +142,9 @@ def __init__( self.attention = Attention(config, quant_config=quant_config, prefix=f"{prefix}.attention") - self.mlp = MLP(config, quant_config=quant_config) + self.mlp = MLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") self.post_attention_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -164,7 +171,7 @@ def __init__( self.layers = nn.ModuleList([ TransformerLayer(config, quant_config=quant_config, - prefix=f"{prefix}.layer.{layer_idx}") + prefix=f"{prefix}.layers.{layer_idx}") for layer_idx in range(config.num_hidden_layers) ]) @@ -181,6 +188,7 @@ def __init__( config, in_features, quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', ): """ The original implementation is the same as: @@ -222,7 +230,8 @@ def __init__( self.linear_proj = ReplicatedLinear(in_features, config.hidden_size, bias=False, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.linear_proj") self.norm1 = nn.LayerNorm(config.hidden_size) self.act1 = nn.GELU() self.act2 = SiluAndMul() @@ -230,12 +239,15 @@ def __init__( self.merged_proj = MergedColumnParallelLinear( config.hidden_size, [config.ffn_hidden_size] * 2, bias=False, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.merged_proj") - self.dense_4h_to_h = RowParallelLinear(config.ffn_hidden_size, - config.hidden_size, - bias=False, - quant_config=quant_config) + self.dense_4h_to_h = RowParallelLinear( + config.ffn_hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.dense_4h_to_h") def forward(self, x): x, _ = self.linear_proj(x) @@ -262,7 +274,8 @@ def __init__( prefix=f"{prefix}.transformer") self.linear_proj = GLU(config, in_features=config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.linear_proj") self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, out_channels=config.hidden_size, kernel_size=2, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 3d16d635b578a..20f3a3d1989be 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -1473,6 +1473,7 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA): """ # Ensure that the LoRA support check passes when the class is not # initialized, but set all these attributes to empty. + # These will be updated when an instance class is selected packed_modules_mapping = {} supported_lora_modules = [] embedding_modules = {} @@ -1489,8 +1490,15 @@ def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""): version = str(config.version).split(".") version = tuple([int(x) for x in version]) # Dispatch class based on version - instance_class = _SUPPORT_VERSION.get(version) - if instance_class is None: + instance_cls = _SUPPORT_VERSION.get(version) + if instance_cls is None: raise ValueError( "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6") - return instance_class(vllm_config=vllm_config, prefix=prefix) + + # quant_config references base class members, + # so update values before init is called + cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping) + cls.supported_lora_modules += instance_cls.supported_lora_modules + cls.embedding_modules.update(instance_cls.embedding_modules) + cls.embedding_padding_modules += instance_cls.embedding_padding_modules + return instance_cls(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 327fad0f5702d..8970661243148 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -1135,6 +1135,7 @@ class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA): """ # Ensure that the LoRA support check passes when the class is not # initialized, but set all these attributes to empty. + # These will be updated when an instance class is selected packed_modules_mapping = {} supported_lora_modules = [] embedding_modules = {} @@ -1146,9 +1147,18 @@ def __new__( prefix: str = "", ) -> QWenBaseModel: config = vllm_config.model_config.hf_config + # Initialize VL - if hasattr(config, "visual"): - return QWenVL(vllm_config=vllm_config, prefix=prefix) + if hasattr(config, "visual"): # noqa: SIM108 + instance_cls = QWenVL # Initialize LLM else: - return QWenLLM(vllm_config=vllm_config, prefix=prefix) + instance_cls = QWenLLM + + # quant_config references base class members, + # so update values before init is called + cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping) + cls.supported_lora_modules += instance_cls.supported_lora_modules + cls.embedding_modules.update(instance_cls.embedding_modules) + cls.embedding_padding_modules += instance_cls.embedding_padding_modules + return instance_cls(vllm_config=vllm_config, prefix=prefix)