Commit f343336

Fix AOPerModuleConfig bug in skipping quantizing modules (#2135)
* Fix AOPerModuleConfig bug in skipping quantizing modules

  Summary: The previous logic mishandled modules that are explicitly skipped (mapped to None): it fell back to the default config instead of leaving them unquantized. This PR fixes that.

  Test Plan: pytest test/quantization/test_quant_api.py -k test_ao_per_module_config_skip

* Add IntxWeightOnlyConfig to torchao.quantization
1 parent 0810f57 commit f343336
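
For context, a minimal sketch of the usage pattern this fix addresses: mapping a module's fully qualified name to `None` in `AOPerModuleConfig` is supposed to skip quantization for that module, even when a `_default` entry is present. The `Toy` model below is a hypothetical stand-in for the `ToyLinearModel` used in the new test; the config and `quantize_` calls mirror the test.

```python
import torch
from torch import nn
from torchao.quantization import AOPerModuleConfig, Int4WeightOnlyConfig, quantize_

class Toy(nn.Module):  # hypothetical stand-in for the test's ToyLinearModel
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(64, 32, bias=False)
        self.linear2 = nn.Linear(32, 8, bias=False)

    def forward(self, x):
        return self.linear2(self.linear1(x))

model = Toy().cuda().to(torch.bfloat16)  # the int4 weight-only path targets CUDA + bf16
config = AOPerModuleConfig({
    "_default": Int4WeightOnlyConfig(group_size=32),  # applied to modules by default
    "linear2": None,  # explicit None: leave this module unquantized
})
quantize_(model, config)
# After the fix: model.linear1.weight is an AffineQuantizedTensor,
# while model.linear2.weight remains a plain, unquantized weight.
```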

File tree (3 files changed: +23 −7):

test/quantization/test_quant_api.py
torchao/quantization/__init__.py
torchao/quantization/quant_api.py

test/quantization/test_quant_api.py (+12)

```diff
@@ -1005,6 +1005,18 @@ def test_ao_per_module_config_embedding_linear(self):
         assert isinstance(model.emb.weight._layout, QDQLayout)
         assert isinstance(model.linear.weight, LinearActivationQuantizedTensor)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_ao_per_module_config_skip(self):
+        config1 = Int4WeightOnlyConfig(group_size=32)
+        config = AOPerModuleConfig({"_default": config1, "linear2": None})
+        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        quantize_(model, config)
+        model(*example_inputs)
+        assert isinstance(model.linear1.weight, AffineQuantizedTensor)
+        assert isinstance(model.linear1.weight._layout, TensorCoreTiledLayout)
+        assert not isinstance(model.linear2.weight, AffineQuantizedTensor)
+
 
 class TestMultiTensorFlow(TestCase):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
```

torchao/quantization/__init__.py (+2)

```diff
@@ -53,6 +53,7 @@
     Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
+    IntxWeightOnlyConfig,
     PlainLayout,
     TensorCoreTiledLayout,
     UIntXWeightOnlyConfig,
@@ -139,6 +140,7 @@
     "Float8StaticActivationFloat8WeightConfig",
     "Float8DynamicActivationFloat8SemiSparseWeightConfig",
     "UIntXWeightOnlyConfig",
+    "IntxWeightOnlyConfig",
     "FPXWeightOnlyConfig",
     "GemliteUIntXWeightOnlyConfig",
     "AOPerModuleConfig",
```

torchao/quantization/quant_api.py (+9 −7)

```diff
@@ -594,6 +594,7 @@ def quantize_(
 
     """
     filter_fn = _is_linear if filter_fn is None else filter_fn
+
     if isinstance(config, AOPerModuleConfig):
         _replace_with_custom_fn_if_matches_filter_with_name(
             model,
@@ -1975,18 +1976,19 @@ class AOPerModuleConfig(AOBaseConfig):
 def _ao_per_module_config_handler(
     module: torch.nn.Module, module_fqn: str, config: AOPerModuleConfig
 ):
-    c = config.module_fqn_to_config.get(module_fqn, None)
-    # Maybe: we can add module type specific config in the future, in needed
-    # fallback to use default if no module specific config is provided
-    default_c = config.module_fqn_to_config.get("_default", None)
-    if default_c is not None and c is None:
-        c = default_c
+    c = None
+    if module_fqn in config.module_fqn_to_config:
+        # Maybe: we can add module type specific config in the future, in needed
+        c = config.module_fqn_to_config[module_fqn]
+    else:
+        # fallback to use default if no module specific config is provided
+        c = config.module_fqn_to_config.get("_default", None)
 
     if c is not None:
         handler = _QUANTIZE_CONFIG_HANDLER[type(c)]
         return handler(module, c)
 
-    return handler(module, c)
+    return module
 
 
 if TORCH_VERSION_AT_LEAST_2_5:
```
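
The root cause is a classic `dict.get` pitfall: `get(module_fqn, None)` returns `None` both when the key is absent and when the user explicitly mapped the key to `None`, so the old handler fell back to `_default` even for modules marked as skipped. A standalone sketch of the two behaviors, with plain strings standing in for config objects:

```python
module_fqn_to_config = {"_default": "int4-config", "linear2": None}

# Old logic: get() cannot tell "missing" from "explicitly None", so the
# _default fallback fires and linear2 gets quantized despite the skip request.
c = module_fqn_to_config.get("linear2", None)
default_c = module_fqn_to_config.get("_default", None)
if default_c is not None and c is None:
    c = default_c
assert c == "int4-config"  # bug: the skip request is ignored

# New logic: a membership test distinguishes the two cases.
if "linear2" in module_fqn_to_config:
    c = module_fqn_to_config["linear2"]  # explicitly None -> skip
else:
    c = module_fqn_to_config.get("_default", None)
assert c is None  # fixed: linear2 is left unquantized
```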
