diff --git a/vllm/config.py b/vllm/config.py
index 8b93bf4f0ed53..405787d64e374 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -156,6 +156,17 @@ def _verify_embedding_mode(self) -> None:
         self.embedding_mode = any(
             ModelRegistry.is_embedding_model(arch) for arch in architectures)
 
+    def _parse_quant_hf_config(self):
+        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        if quant_cfg is None:
+            # SparseML uses a "compression_config" with a "quantization_config".
+            compression_cfg = getattr(self.hf_config, "compression_config",
+                                      None)
+            if compression_cfg is not None:
+                quant_cfg = compression_cfg.get("quantization_config", None)
+
+        return quant_cfg
+
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
         rocm_supported_quantization = ["gptq", "squeezellm"]
@@ -163,12 +174,13 @@ def _verify_quantization(self) -> None:
             self.quantization = self.quantization.lower()
 
         # Parse quantization method from the HF model config, if available.
-        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        quant_cfg = self._parse_quant_hf_config()
+
         if quant_cfg is not None:
             quant_method = quant_cfg.get("quant_method", "").lower()
 
             # Detect which checkpoint is it
-            for name, method in QUANTIZATION_METHODS.items():
+            for _, method in QUANTIZATION_METHODS.items():
                 quantization_override = method.override_quantization_method(
                     quant_cfg, self.quantization)
                 if quantization_override:
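
For context, here is a minimal, self-contained sketch (not part of the patch) of the lookup order the new `_parse_quant_hf_config` helper implements: prefer a top-level `quantization_config` attribute, and only fall back to the `quantization_config` dict nested inside a SparseML-style `compression_config`. The `SimpleNamespace` stand-ins and the `quant_method` values are illustrative assumptions; in vLLM the object is a `transformers` config loaded from the checkpoint.

```python
from types import SimpleNamespace

def parse_quant_hf_config(hf_config):
    """Standalone replica of the ModelConfig._parse_quant_hf_config logic."""
    quant_cfg = getattr(hf_config, "quantization_config", None)
    if quant_cfg is None:
        # SparseML checkpoints nest the quantization details under
        # "compression_config" instead of exposing them at the top level.
        compression_cfg = getattr(hf_config, "compression_config", None)
        if compression_cfg is not None:
            quant_cfg = compression_cfg.get("quantization_config", None)
    return quant_cfg

# Ordinary HF checkpoint: the top-level attribute wins.
hf_gptq = SimpleNamespace(quantization_config={"quant_method": "gptq"})
assert parse_quant_hf_config(hf_gptq) == {"quant_method": "gptq"}

# SparseML-style checkpoint: fall back to the nested dict.
hf_sparseml = SimpleNamespace(
    compression_config={"quantization_config": {"quant_method": "gptq"}})
assert parse_quant_hf_config(hf_sparseml) == {"quant_method": "gptq"}

# Unquantized checkpoint: neither attribute exists, so None is returned
# and _verify_quantization skips the override loop entirely.
assert parse_quant_hf_config(SimpleNamespace()) is None
```

Factoring the lookup into a helper keeps `_verify_quantization` unchanged downstream: whichever source the dict came from, the same `quant_method` detection and `override_quantization_method` loop run against it.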