From 6cc94bd5c88aacd00526ddbb46fea2162ee6b33b Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Fri, 8 Nov 2024 22:36:46 -0500 Subject: [PATCH] [Bugfix] Ignore GPTQ quantization of Qwen2-VL visual module (#10169) Signed-off-by: mgoin Signed-off-by: Jee Jee Li --- vllm/model_executor/models/qwen2_vl.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 8073c5f4b2fd2..8dd75c9ee7e7b 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -51,7 +51,9 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization import (GPTQConfig, + GPTQMarlinConfig, + QuantizationConfig) from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -982,7 +984,7 @@ def __init__(self, self.visual = Qwen2VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=quant_config, + quant_config=self._maybe_ignore_quant_config(quant_config), prefix="visual", ) @@ -1008,6 +1010,14 @@ def __init__(self, make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + # GPTQ configs do not have a list of ignored modules, however AutoGPTQ + # seems to avoid vision encoder sections for some models. + # See: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4 + if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): + return None + return quant_config + def _validate_and_reshape_mm_tensor(self, mm_input: Union[torch.Tensor, List[torch.Tensor]],