From 5dba73cad2ab1d5b17312abaafa6a09431d4f3ed Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Fri, 6 Sep 2024 20:47:24 +0000
Subject: [PATCH] initial commit

---
 .../modifiers/quantization/gptq/base.py       | 23 +++++++++++++++++-
 .../quantization/quantization/base.py         | 27 ++++++++++++++++++--
 2 files changed, 47 insertions(+), 3 deletions(-)

diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py
index ebe826768..1807c9df4 100644
--- a/src/llmcompressor/modifiers/quantization/gptq/base.py
+++ b/src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -8,11 +8,12 @@
     freeze_module_quantization,
 )
 from loguru import logger
-from pydantic import Field
+from pydantic import Field, field_validator
 from torch.nn import Module
 
 from llmcompressor.core.state import State
 from llmcompressor.modifiers import Modifier, ModifierFactory
+from llmcompressor.modifiers.quantization import CONTIGUOUS_ACTIVATION_ORDERINGS
 from llmcompressor.modifiers.quantization.gptq.utils.gptq_wrapper import GPTQWrapper
 from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor
 from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
@@ -110,6 +111,26 @@ class GPTQModifier(Modifier):
     compressible_layers_: Optional[List] = None
     quantization_modifier_: Any = None
 
+    @field_validator("config_groups", mode="after")
+    @classmethod
+    def validate_config_groups(cls, config_groups) -> Optional[Dict[str, QuantizationScheme]]:
+        if config_groups is not None:
+            for value in config_groups.values():
+                subfields = [value.weights, value.input_activations, value.output_activations]
+                for quant_args in subfields:
+                    if quant_args is not None:
+                        # contiguous_groups must agree with the activation ordering
+                        if quant_args.actorder in CONTIGUOUS_ACTIVATION_ORDERINGS:
+                            if quant_args.contiguous_groups is False:
+                                raise ValueError(f"Cannot set contiguous_groups={quant_args.contiguous_groups} for activation ordering {quant_args.actorder}")
+                            quant_args.contiguous_groups = True
+                        else:
+                            if quant_args.contiguous_groups is True:
+                                raise ValueError(f"Cannot set contiguous_groups={quant_args.contiguous_groups} for activation ordering {quant_args.actorder}")
+                            quant_args.contiguous_groups = False
+
+        return config_groups
+
     def on_initialize_structure(self, state: State, **kwargs):
         """
         Check the model's quantization state matches that expected by this modifier,
diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py
index bcca9853b..c9aec908f 100644
--- a/src/llmcompressor/modifiers/quantization/quantization/base.py
+++ b/src/llmcompressor/modifiers/quantization/quantization/base.py
@@ -10,17 +10,19 @@
     is_preset_scheme,
     preset_name_to_scheme,
     set_module_for_calibration,
+    ActivationOrdering,
 )
 from compressed_tensors.quantization.observers.helpers import get_observer_token_count
 from loguru import logger
-from pydantic import Field
+from pydantic import Field, field_validator
 from torch.nn import Module
 
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
 
-__all__ = ["QuantizationModifier"]
+__all__ = ["QuantizationModifier", "CONTIGUOUS_ACTIVATION_ORDERINGS"]
+CONTIGUOUS_ACTIVATION_ORDERINGS = {ActivationOrdering.GROUP}
 
 
 class QuantizationModifier(Modifier):
@@ -69,6 +71,27 @@ class QuantizationModifier(Modifier):
     calibration_dataloader_: Any = None
     calibration_function_: Any = None
 
+    @field_validator("config_groups", mode="after")
+    @classmethod
+    def validate_config_groups(cls, config_groups) -> Optional[Dict[str, QuantizationScheme]]:
+        if config_groups is not None:
+            for value in config_groups.values():
+                subfields = [value.weights, value.input_activations, value.output_activations]
+                for quant_args in subfields:
+                    if quant_args is not None:
+                        # contiguous_groups must agree with the activation ordering
+                        if quant_args.actorder in CONTIGUOUS_ACTIVATION_ORDERINGS:
+                            if quant_args.contiguous_groups is False:
+                                raise ValueError(f"Cannot set contiguous_groups={quant_args.contiguous_groups} for activation ordering {quant_args.actorder}")
+                            quant_args.contiguous_groups = True
+                        else:
+                            if quant_args.contiguous_groups is True:
+                                raise ValueError(f"Cannot set contiguous_groups={quant_args.contiguous_groups} for activation ordering {quant_args.actorder}")
+                            quant_args.contiguous_groups = False
+
+        return config_groups
+
+
     def on_initialize_structure(self, state: State, **kwargs):
         pass
 