diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 6a5ea8ef9d..8e047985c5 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -72,7 +72,7 @@
     AQInt8WeightOnlyQuantizedLinearWeight2,
     AQInt8WeightOnlyQuantizedLinearWeight3,
     AutoQuantizableLinearWeight,
-
+    AQFloat8WeightOnlyQuantizedLinearWeight,
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
@@ -98,6 +98,7 @@
 COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
 
 COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 def _int8wo_api(mod):
     if TORCH_VERSION_AT_LEAST_2_4:
@@ -744,6 +745,14 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
             AQInt8WeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
         )
 
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
+    @unittest.skipIf(not is_H100, "Need CUDA arch >= 8.9 (e.g. H100) to run")
+    def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
+        self._test_lin_weight_subclass_impl(
+            AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
+        )
+
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index fa4ca36d85..089add1d87 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -501,7 +501,6 @@ def from_float(cls, weight):
     # AQInt8WeightOnlyQuantizedLinearWeight3,
     # TODO this gets picked in places where it makes perf worse, why?
     AQInt8DynamicallyQuantizedLinearWeight,
-    AQFloat8WeightOnlyQuantizedLinearWeight,
 ]
 
 DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
@@ -510,6 +509,11 @@ def from_float(cls, weight):
     AQInt4G64WeightOnlyQuantizedLinearWeight
 ]
 
+OTHER_AUTOQUANT_CLASS_LIST = [
+    AQFloat8WeightOnlyQuantizedLinearWeight,
+]
+
+
 def _change_linears_to_autoquantizable(model, **kwargs):
     """
     Converts all linear weight tensors to the
@@ -634,6 +638,9 @@ def autoquant(
     if set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
 
+    # float8 weight-only autoquant is gated on CUDA compute capability >= 8.9
+    if any(qclass in OTHER_AUTOQUANT_CLASS_LIST for qclass in qtensor_class_list):
+        assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9), "float8 requires CUDA arch >= 8.9"
     # perform initial swap from linear weights
     # to AutoQuantizableLinearWeight
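
For context, a minimal usage sketch of the new opt-in float8 path added by this diff. The `autoquant` entry point, its `qtensor_class_list` parameter, and `OTHER_AUTOQUANT_CLASS_LIST` are taken from the diff above; the toy model, shapes, and device setup are illustrative assumptions, not part of this PR:

```python
import torch
from torchao.quantization.autoquant import autoquant, OTHER_AUTOQUANT_CLASS_LIST

# Toy model and input, for illustration only; requires a GPU with
# CUDA compute capability >= 8.9 (e.g. H100), per the new assert.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to("cuda", torch.bfloat16)
example_input = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)

# Passing OTHER_AUTOQUANT_CLASS_LIST makes AQFloat8WeightOnlyQuantizedLinearWeight
# an autoquant candidate; it is no longer in the default class list.
model = autoquant(model, qtensor_class_list=OTHER_AUTOQUANT_CLASS_LIST)
model(example_input)  # forward pass drives autoquant's candidate benchmarking
```

Keeping the float8 subclass out of `DEFAULT_AUTOQUANT_CLASS_LIST` and behind an explicit opt-in means users on pre-8.9 hardware never hit the capability assert by default.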