diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 397283a59b..457e3a060f 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -15,6 +15,9 @@ from torchao.dtypes import (
     TensorCoreTiledLayoutType,
 )
+from torchao.quantization.prototype.qat.api import (
+    ComposableQATQuantizer,
+)
 from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
     AffineFakeQuantizedTensor,
 )
@@ -34,6 +37,9 @@
     MappingType,
     ZeroPointDomain,
 )
+from torchao.quantization.unified import (
+    TwoStepQuantizer,
+)
 from torchao.quantization.utils import (
     get_group_qparams_symmetric,
     get_groupwise_affine_qparams,
@@ -626,6 +632,45 @@ def test_qat_4w_quantizer_module_swap(self):
         module_swap_out = module_swap_model(*x2)
         torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)
 
+    class _MyQATQuantizer(TwoStepQuantizer):
+        """
+        Dummy quantizer that appends a given value to each nn.Linear's
+        `_temp_quantizer_values` list attribute.
+        """
+        ATTR_NAME = "_temp_quantizer_values"
+
+        def __init__(self, value: str):
+            self.value = value
+
+        def _attach_value(self, module: torch.nn.Module):
+            # create the list attribute on first use, then record this value
+            if isinstance(module, torch.nn.Linear):
+                if not hasattr(module, self.ATTR_NAME):
+                    setattr(module, self.ATTR_NAME, [])
+                getattr(module, self.ATTR_NAME).append(self.value)
+
+        def prepare(self, model: torch.nn.Module):
+            model.apply(self._attach_value)
+            return model
+
+        def convert(self, model: torch.nn.Module):
+            model.apply(self._attach_value)
+            return model
+
+    def test_composable_qat_quantizer(self):
+        quantizer1 = self._MyQATQuantizer("quantizer1")
+        quantizer2 = self._MyQATQuantizer("quantizer2")
+        composable_quantizer = ComposableQATQuantizer([quantizer1, quantizer2])
+        model = M()
+        # prepare should call the quantizers in their constructor order
+        model = composable_quantizer.prepare(model)
+        self.assertTrue(hasattr(model.linear1, self._MyQATQuantizer.ATTR_NAME))
+        values_list = getattr(model.linear1, self._MyQATQuantizer.ATTR_NAME)
+        self.assertEqual(values_list, ["quantizer1", "quantizer2"])
+        # convert should call the same quantizers again, in the same order
+        composable_quantizer.convert(model)
+        values_list = getattr(model.linear1, self._MyQATQuantizer.ATTR_NAME)
+        self.assertEqual(values_list, ["quantizer1", "quantizer2", "quantizer1", "quantizer2"])
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/prototype/qat/__init__.py b/torchao/quantization/prototype/qat/__init__.py
index c16b3ead44..9f8dd74e44 100644
--- a/torchao/quantization/prototype/qat/__init__.py
+++ b/torchao/quantization/prototype/qat/__init__.py
@@ -5,6 +5,7 @@
     enable_8da4w_fake_quant,
     int4_weight_only_fake_quantize,
     int8_dynamic_activation_int4_weight_fake_quantize,
+    ComposableQATQuantizer,
     Int4WeightOnlyQATQuantizer,
     Int8DynActInt4WeightQATQuantizer,
 )
@@ -20,6 +21,7 @@
     "enable_8da4w_fake_quant",
     "int4_weight_only_fake_quantize",
     "int8_dynamic_activation_int4_weight_fake_quantize",
+    "ComposableQATQuantizer",
     "Int4WeightOnlyQATQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
     "Int8DynActInt4WeightQATLinear",
diff --git a/torchao/quantization/prototype/qat/api.py b/torchao/quantization/prototype/qat/api.py
index 2f3368ff1c..e1c5221e1e 100644
--- a/torchao/quantization/prototype/qat/api.py
+++ b/torchao/quantization/prototype/qat/api.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Optional
+from typing import Any, List, Optional
 
 import torch
 import torch.nn.functional as F
@@ -34,6 +34,46 @@
 )
 
 
+class ComposableQATQuantizer(TwoStepQuantizer):
+    """
+    Composable quantizer that applies multiple QAT quantizers in sequence.
+    Quantizers are applied in the order they are specified in the constructor.
+
+    Note: the quantizers provided must target different module types in the
+    model, e.g. nn.Linear and nn.Embedding; otherwise the behavior is undefined.
+
+    Example usage::
+
+        my_quantizer = ComposableQATQuantizer([
+            QATQuantizer1(),
+            QATQuantizer2(),
+            QATQuantizer3(),
+        ])
+        model = my_quantizer.prepare(model)
+        train(model)
+        model = my_quantizer.convert(model)
+    """
+
+    def __init__(self, quantizers: List[TwoStepQuantizer]):
+        self.quantizers = quantizers
+
+    def prepare(
+        self, model: torch.nn.Module, *args: Any, **kwargs: Any
+    ) -> torch.nn.Module:
+        # call each quantizer's prepare in constructor order
+        for quantizer in self.quantizers:
+            model = quantizer.prepare(model)
+        return model
+
+    def convert(
+        self, model: torch.nn.Module, *args: Any, **kwargs: Any
+    ) -> torch.nn.Module:
+        # call each quantizer's convert in the same order
+        for quantizer in self.quantizers:
+            model = quantizer.convert(model)
+        return model
+
+
 # =================
 # |   8da4w QAT   |
 # =================
@@ -44,7 +84,8 @@ def int8_dynamic_activation_int4_weight_fake_quantize(group_size=32):
     int4 per group weight symmetric fake quantization to linear. Please see
     :func:`~torchao.quantization.int8_dynamic_activation_int4_weight` for more details.
 
-    Example usage:
+    Example usage::
+
         from torchao.quantization import quantize_
         quantize_(model, int8_dynamic_activation_int4_weight_fake_quantize(group_size=32))
     """
@@ -151,7 +192,8 @@ def int4_weight_only_fake_quantize(group_size=128):
     Applies uint4 weight-only asymmetric per-group fake quantization to linear layers.
     Please see :func:`~torchao.quantization.int4_weight_only` for more details.
 
-    Example usage:
+    Example usage::
+
         from torchao.quantization import quantize_
         quantize_(model, int4_weight_only_fake_quantize(group_size=32))
     """
diff --git a/torchao/quantization/unified.py b/torchao/quantization/unified.py
index 7da915dec7..1bd62b8979 100644
--- a/torchao/quantization/unified.py
+++ b/torchao/quantization/unified.py
@@ -1,5 +1,5 @@
 import torch
-from typing import Any
+from typing import Any, List
 from abc import ABC, abstractmethod
 
 """
@@ -17,4 +17,4 @@ class Quantizer(ABC):
     def quantize(
         self, model: torch.nn.Module, *args: Any, **kwargs: Any
     ) -> torch.nn.Module:
-        pass
+        ...
@@ -27,10 +27,10 @@ class TwoStepQuantizer:
     def prepare(
         self, model: torch.nn.Module, *args: Any, **kwargs: Any
     ) -> torch.nn.Module:
-        pass
+        ...
 
     @abstractmethod
     def convert(
         self, model: torch.nn.Module, *args: Any, **kwargs: Any
     ) -> torch.nn.Module:
-        pass
+        ...
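Note: a possible end-to-end usage sketch for the new API (not part of the patch itself). `ComposableQATQuantizer` and `Int8DynActInt4WeightQATQuantizer` (with its `groupsize` argument) are the exports touched above; `MyEmbeddingQATQuantizer`, `model`, and `train` are hypothetical placeholders::

    from torchao.quantization.prototype.qat import (
        ComposableQATQuantizer,
        Int8DynActInt4WeightQATQuantizer,
    )

    # Per the ComposableQATQuantizer docstring, each child quantizer should
    # target a different module type: linears here, embeddings below.
    quantizer = ComposableQATQuantizer([
        Int8DynActInt4WeightQATQuantizer(groupsize=32),
        MyEmbeddingQATQuantizer(),  # hypothetical TwoStepQuantizer for nn.Embedding
    ])

    model = quantizer.prepare(model)   # each child inserts its fake-quantize ops
    train(model)                       # fine-tune with fake quantization in place
    model = quantizer.convert(model)   # each child swaps in real quantized weights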