diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index c86559e62f94..22dd1b7ccea5 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3597,7 +3597,12 @@ def from_pretrained(
                 )
             else:
                 config.quantization_config = quantization_config
-            hf_quantizer = AutoHfQuantizer.from_config(config.quantization_config, pre_quantized=pre_quantized)
+
+            hf_quantizer = AutoHfQuantizer.from_config(
+                config.quantization_config,
+                pre_quantized=pre_quantized,
+            )
+
         else:
             hf_quantizer = None
 
@@ -4281,7 +4286,7 @@ def from_pretrained(
             dispatch_model(model, **device_map_kwargs)
 
         if hf_quantizer is not None:
-            hf_quantizer.postprocess_model(model)
+            hf_quantizer.postprocess_model(model, config=config)
             model.hf_quantizer = hf_quantizer
 
         if _adapter_model_path is not None:
diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py
index 38bebd2d8410..818072a0d916 100755
--- a/src/transformers/quantizers/auto.py
+++ b/src/transformers/quantizers/auto.py
@@ -173,13 +173,14 @@ def merge_quantization_configs(
             quantization_config = AutoQuantizationConfig.from_dict(quantization_config)
 
         if (
-            isinstance(quantization_config, (GPTQConfig, AwqConfig, FbgemmFp8Config))
+            isinstance(quantization_config, (GPTQConfig, AwqConfig, FbgemmFp8Config, CompressedTensorsConfig))
             and quantization_config_from_args is not None
         ):
             # special case for GPTQ / AWQ / FbgemmFp8 config collision
             loading_attr_dict = quantization_config_from_args.get_loading_attributes()
             for attr, val in loading_attr_dict.items():
                 setattr(quantization_config, attr, val)
+
             warning_msg += f"However, loading attributes (e.g. {list(loading_attr_dict.keys())}) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored."
 
         if warning_msg != "":
diff --git a/src/transformers/quantizers/quantizer_awq.py b/src/transformers/quantizers/quantizer_awq.py
index 0c14c236d260..7b81c93edf1f 100644
--- a/src/transformers/quantizers/quantizer_awq.py
+++ b/src/transformers/quantizers/quantizer_awq.py
@@ -111,7 +111,7 @@ def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwarg
                 " Please double check your model architecture, or submit an issue on github if you think this is a bug."
             )
 
-    def _process_model_after_weight_loading(self, model):
+    def _process_model_after_weight_loading(self, model, **kwargs):
         if self.quantization_config.do_fuse:
             from ..integrations import fuse_awq_modules
 
diff --git a/src/transformers/quantizers/quantizer_compressed_tensors.py b/src/transformers/quantizers/quantizer_compressed_tensors.py
index 61e940886d94..5064f2c019d7 100644
--- a/src/transformers/quantizers/quantizer_compressed_tensors.py
+++ b/src/transformers/quantizers/quantizer_compressed_tensors.py
@@ -12,8 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import os
+
 from ..utils import is_compressed_tensors_available, is_torch_available, logging
-from ..utils.quantization_config import QuantizationConfigMixin
+from ..utils.quantization_config import CompressedTensorsConfig
 from .base import HfQuantizer
 
@@ -32,12 +35,13 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
     requires_calibration = True
     required_packages = ["compressed_tensors"]
 
-    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
+    def __init__(self, quantization_config: CompressedTensorsConfig, **kwargs):
         super().__init__(quantization_config, **kwargs)
-
         from compressed_tensors.compressors import ModelCompressor
 
         self.compressor = ModelCompressor.from_compression_config(quantization_config)
+        self.run_compressed = quantization_config.run_compressed
+        self.quantization_config = quantization_config
 
     def validate_environment(self, *args, **kwargs):
         if not is_compressed_tensors_available():
@@ -63,20 +67,57 @@ def _process_model_before_weight_loading(self, model, **kwargs):
         from compressed_tensors.quantization import apply_quantization_config
 
         ct_quantization_config = self.compressor.quantization_config
-        apply_quantization_config(model, ct_quantization_config, run_compressed=True)
 
-    def _process_model_after_weight_loading(self, model, **kwargs) -> None:
-        pass
+        if self.run_compressed and self.is_quantization_compressed:
+            apply_quantization_config(model, ct_quantization_config, run_compressed=True)
+        elif not self.is_quantization_compressed:
+            apply_quantization_config(model, ct_quantization_config)
+
+    def _process_model_after_weight_loading(self, model, **kwargs):
+        """Decompress the loaded model if necessary - needed for QAT"""
+
+        if (self.is_quantization_compressed and not self.run_compressed) or self.is_sparsification_compressed:
+            config = kwargs.get("config", None)
+            cache_path = config._name_or_path
+
+            if not os.path.exists(cache_path):
+                from transformers.utils import cached_file
+
+                config_file_path = cached_file(cache_path, "config.json")
+                cache_path = os.path.sep.join(config_file_path.split(os.path.sep)[:-1])
+
+            if self.is_quantization_compressed and not self.run_compressed:
+                from compressed_tensors.quantization import QuantizationStatus
+
+                self.compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN
+
+            self.compressor.decompress(model_path=cache_path, model=model)
 
     @property
-    def is_trainable(self) -> bool:
-        """Models quantized using compressed tensors can be finetuned"""
-        return True
+    def is_quantization_compressed(self):
+        from compressed_tensors.quantization import QuantizationStatus
+
+        return (
+            self.quantization_config.quantization_config is not None
+            and self.quantization_config.quantization_config.quantization_status == QuantizationStatus.COMPRESSED
+        )
+
+    @property
+    def is_sparsification_compressed(self):
+        from compressed_tensors.config.base import CompressionFormat
+
+        return (
+            self.quantization_config.sparsity_config is not None
+            and self.quantization_config.sparsity_config.format != CompressionFormat.dense.value
+        )
 
     @property
+    def is_trainable(self):
+        return True
+
     def is_qat_trainable(self) -> bool:
         """Loaded Models can carry out quantization aware training"""
-        return True
+        # models need to be decompressed to carry out QAT
+        return not self.run_compressed or not self.is_quantization_compressed
 
     def is_serializable(self, safe_serialization=None) -> bool:
         """Models quantized using compressed tensors can be saved to disk"""
diff --git a/src/transformers/quantizers/quantizer_quanto.py b/src/transformers/quantizers/quantizer_quanto.py
index d91019dea152..230e8efe1506 100644
--- a/src/transformers/quantizers/quantizer_quanto.py
+++ b/src/transformers/quantizers/quantizer_quanto.py
@@ -197,7 +197,7 @@ def _process_model_before_weight_loading(
         )
         model.config.quantization_config = self.quantization_config
 
-    def _process_model_after_weight_loading(self, model):
+    def _process_model_after_weight_loading(self, model, **kwargs):
         return model
 
     @property
diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py
index e6c2dc1ce36b..10d2b184ef14 100644
--- a/src/transformers/quantizers/quantizer_torchao.py
+++ b/src/transformers/quantizers/quantizer_torchao.py
@@ -195,7 +195,7 @@ def create_quantized_param(
         module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
         quantize_(module, self.quantization_config.get_apply_tensor_subclass())
 
-    def _process_model_after_weight_loading(self, model):
+    def _process_model_after_weight_loading(self, model, **kwargs):
         """No process required for torchao quantized model"""
         return
 
diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index bacbca94cd82..253cc4a06210 100755
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -1077,7 +1077,8 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
         config_groups (`typing.Dict[str, typing.Union[ForwardRef('QuantizationScheme'), typing.List[str]]]`, *optional*):
             dictionary mapping group name to a quantization scheme definition
         format (`str`, *optional*, defaults to `"dense"`):
-            format the model is represented as
+            format the model is represented as. Set `run_compressed` to True to execute the model in the
+            compressed format if it is not `dense`
         quantization_status (`QuantizationStatus`, *optional*, defaults to `"initialized"`):
             status of model in the quantization lifecycle, ie 'initialized', 'calibration', 'frozen'
         kv_cache_scheme (`typing.Union[QuantizationArgs, NoneType]`, *optional*):
@@ -1090,6 +1091,8 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
             configuration for sparsity compression
         quant_method (`str`, *optional*, defaults to `"compressed-tensors"`):
             do not override, should be compressed-tensors
+        run_compressed (`bool`, *optional*, defaults to `True`): alter submodules (usually linear) in order to
+            emulate compressed model execution if True, otherwise use the default submodules
     """
 
     def __init__(
@@ -1102,14 +1105,17 @@ def __init__(
         ignore: Optional[List[str]] = None,
         sparsity_config: Dict[str, Any] = None,
         quant_method: str = "compressed-tensors",
+        run_compressed: bool = True,
         **kwargs,
     ):
-        from compressed_tensors import QuantizationConfig
         from compressed_tensors.config import SparsityCompressionConfig
+        from compressed_tensors.quantization import QuantizationConfig
 
         self.quantization_config = None
         self.sparsity_config = None
 
+        self.run_compressed = run_compressed
+
         # parse from dict to load nested QuantizationScheme objects
         if config_groups or kv_cache_scheme:
             self.quantization_config = QuantizationConfig.parse_obj(
@@ -1121,6 +1127,7 @@ def __init__(
                     "kv_cache_scheme": kv_cache_scheme,
                     "global_compression_ratio": global_compression_ratio,
                     "ignore": ignore,
+                    "run_compressed": run_compressed,
                     **kwargs,
                 }
             )
@@ -1149,6 +1156,7 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
 
         Returns:
             [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters.
+
         """
 
         if "quantization_config" in config_dict:
@@ -1200,6 +1208,9 @@ def to_diff_dict(self) -> Dict[str, Any]:
 
         return serializable_config_dict
 
+    def get_loading_attributes(self):
+        return {"run_compressed": self.run_compressed}
+
 
 @dataclass
 class FbgemmFp8Config(QuantizationConfigMixin):
diff --git a/tests/quantization/compressed_tensor/test_load_sparse_model.py b/tests/quantization/compressed_tensor/test_load_sparse_model.py
new file mode 100644
index 000000000000..8992cd3d9bd4
--- /dev/null
+++ b/tests/quantization/compressed_tensor/test_load_sparse_model.py
@@ -0,0 +1,80 @@
+import gc
+import unittest
+
+from transformers import AutoModelForCausalLM
+from transformers.testing_utils import require_compressed_tensors, require_torch
+from transformers.utils import is_torch_available
+
+
+if is_torch_available():
+    import torch
+
+
+@require_compressed_tensors
+@require_torch
+class CompressedTensorsTest(unittest.TestCase):
+    model_sparse_uncompressed = "horheynm/llama2.c_stories15M_pruned_50.2of4_uncompressed"
+    model_sparse_compressed = "horheynm/llama2.c_stories15M_pruned_50.2of4_compressed"
+
+    prompt = "Paris is the capital of which country?"
+
+    stubs = [model_sparse_uncompressed, model_sparse_compressed]
+
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    def test_compressed_uncompressed_model_shapes(self):
+        """
+        Check that the weights are the same between the uncompressed model
+        and the compressed model after decompression.
+        Sparse compressed modules' weights are "packed", so their shapes/values
+        differ in the compressed representation.
+        """
+
+        def _has_nested_attr(obj, attr_path):
+            attrs = attr_path.split(".")
+            for attr in attrs:
+                if not hasattr(obj, attr):
+                    return None
+                obj = getattr(obj, attr)
+            return obj
+
+        from compressed_tensors.quantization.utils import iter_named_leaf_modules
+
+        uncompressed_model = AutoModelForCausalLM.from_pretrained(
+            self.model_sparse_uncompressed,
+        )
+
+        compressed_model_decompressed = AutoModelForCausalLM.from_pretrained(
+            self.model_sparse_compressed,
+        )
+
+        for name, submodule in iter_named_leaf_modules(
+            uncompressed_model,
+        ):
+            if comp_decomp_obj := _has_nested_attr(compressed_model_decompressed, name):
+                if hasattr(submodule, "weight"):
+                    assert torch.equal(submodule.weight, comp_decomp_obj.weight)
+
+    def test_run_compressed_outputs_match(self):
+        """Check that the uncompressed and compressed-then-decompressed model outputs are the same"""
+
+        from transformers import AutoTokenizer
+
+        for stub in self.stubs:
+            tokenizer = AutoTokenizer.from_pretrained(stub)
+            input_ids = tokenizer(self.prompt, return_tensors="pt").input_ids
+
+            uncompressed_model = AutoModelForCausalLM.from_pretrained(
+                self.model_sparse_uncompressed,
+            )
+            output_rc_true = uncompressed_model.generate(input_ids, max_new_tokens=100)
+
+            compressed_model_decompressed = AutoModelForCausalLM.from_pretrained(
+                self.model_sparse_compressed,
+            )
+            output_rc_false = compressed_model_decompressed.generate(input_ids, max_new_tokens=100)
+
+            assert tokenizer.decode(output_rc_true[0]) == tokenizer.decode(output_rc_false[0])
diff --git a/tests/quantization/compressed_tensor/test_run_compressed_model.py b/tests/quantization/compressed_tensor/test_run_compressed_model.py
new file mode 100644
index 000000000000..b168ca382cce
--- /dev/null
+++ b/tests/quantization/compressed_tensor/test_run_compressed_model.py
@@ -0,0 +1,94 @@
+import gc
+import unittest
+
+from transformers import AutoModelForCausalLM
+from transformers.testing_utils import require_compressed_tensors, require_torch
+from transformers.utils import is_torch_available
+
+
+if is_torch_available():
+    import torch
+
+
+@require_compressed_tensors
+@require_torch
+class CompressedTensorsTest(unittest.TestCase):
+    tinyllama_w4a16 = "nm-testing/tinyllama-w4a16-compressed-hf-quantizer"
+    tinyllama_w8a8 = "nm-testing/tinyllama-w8a8-compressed-hf-quantizer"
+
+    prompt = "Paris is the capital of which country?"
+
+    stubs = [tinyllama_w4a16, tinyllama_w8a8]
+
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    def test_default_run_compressed__True(self):
+        from compressed_tensors.linear.compressed_linear import CompressedLinear
+        from compressed_tensors.quantization.utils import iter_named_leaf_modules
+
+        for stub in self.stubs:
+            model = AutoModelForCausalLM.from_pretrained(
+                stub,
+            )
+            compressed_linear_counts = 0
+
+            for _, submodule in iter_named_leaf_modules(
+                model,
+            ):
+                if isinstance(submodule, CompressedLinear):
+                    compressed_linear_counts += 1
+
+            # some linear modules are not compressed - ex. lm_head
+            assert compressed_linear_counts > 0
+
+    def test_default_run_compressed__False(self):
+        from compressed_tensors.linear.compressed_linear import CompressedLinear
+        from compressed_tensors.quantization.utils import iter_named_leaf_modules
+
+        from transformers.utils.quantization_config import CompressedTensorsConfig
+
+        quantization_config = CompressedTensorsConfig(run_compressed=False)
+
+        for stub in self.stubs:
+            model = AutoModelForCausalLM.from_pretrained(
+                stub,
+                quantization_config=quantization_config,
+            )
+            compressed_linear_counts = 0
+
+            for _, submodule in iter_named_leaf_modules(
+                model,
+            ):
+                if isinstance(submodule, CompressedLinear):
+                    compressed_linear_counts += 1
+
+            # No modules should be CompressedLinear
+            assert compressed_linear_counts == 0
+
+    def test_run_compressed_outputs_match(self):
+        """Check that run_compressed=True/False outputs are the same"""
+
+        from transformers import AutoTokenizer
+        from transformers.utils.quantization_config import CompressedTensorsConfig
+
+        quantization_config = CompressedTensorsConfig(run_compressed=False)
+
+        for stub in self.stubs:
+            tokenizer = AutoTokenizer.from_pretrained(stub)
+            input_ids = tokenizer(self.prompt, return_tensors="pt").input_ids
+
+            model_run_compressed__True = AutoModelForCausalLM.from_pretrained(
+                stub,
+            )
+            output_rc_true = model_run_compressed__True.generate(input_ids, max_new_tokens=100)
+
+            model_run_compressed__False = AutoModelForCausalLM.from_pretrained(
+                stub,
+                quantization_config=quantization_config,
+            )
+            output_rc_false = model_run_compressed__False.generate(input_ids, max_new_tokens=100)
+
+            assert tokenizer.decode(output_rc_true[0]) == tokenizer.decode(output_rc_false[0])
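
Below is a minimal usage sketch of the `run_compressed` flag introduced by this patch, mirroring the new tests in `test_run_compressed_model.py`. The `nm-testing/tinyllama-w4a16-compressed-hf-quantizer` checkpoint and the prompt are taken from those tests; the generation settings and prints are illustrative only and not part of the patch.

```python
# Sketch, assuming compressed-tensors is installed and the nm-testing stub
# used in the tests above is reachable on the Hub.
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils.quantization_config import CompressedTensorsConfig

stub = "nm-testing/tinyllama-w4a16-compressed-hf-quantizer"
tokenizer = AutoTokenizer.from_pretrained(stub)
input_ids = tokenizer("Paris is the capital of which country?", return_tensors="pt").input_ids

# Default path: run_compressed=True keeps CompressedLinear modules and the
# model executes directly in its compressed format.
compressed_model = AutoModelForCausalLM.from_pretrained(stub)

# run_compressed=False makes _process_model_after_weight_loading decompress
# the checkpoint back into regular Linear modules (needed e.g. for QAT).
decompressed_model = AutoModelForCausalLM.from_pretrained(
    stub,
    quantization_config=CompressedTensorsConfig(run_compressed=False),
)

# Per test_run_compressed_outputs_match, both paths should generate the same text.
print(tokenizer.decode(compressed_model.generate(input_ids, max_new_tokens=32)[0]))
print(tokenizer.decode(decompressed_model.generate(input_ids, max_new_tokens=32)[0]))
```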