From 352211bc50a12fba3c28e0f3cdf523f4a31e37a3 Mon Sep 17 00:00:00 2001 From: sanchitintel Date: Wed, 11 Dec 2024 00:13:42 -0800 Subject: [PATCH 1/7] Use int_scaled_matmul with asymmetrically quantized activation and symmetrically quantized weights --- test/integration/test_integration.py | 23 +++++--- torchao/dtypes/affine_quantized_tensor_ops.py | 15 ++++- torchao/dtypes/uintx/plain_layout.py | 59 ++++++++++++++++++- 3 files changed, 83 insertions(+), 14 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 6aae8b2e31..827a6ed2e1 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -105,7 +105,9 @@ COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] -COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy() +MAPPING_TYPES = [MappingType.SYMMETRIC, MappingType.ASYMMETRIC] + +COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES, MAPPING_TYPES)).copy() def _int8wo_api(mod): if TORCH_VERSION_AT_LEAST_2_4: @@ -125,9 +127,9 @@ def _int8wo_groupwise_api(mod): group_size = 32 quantize_(mod, int8_weight_only(group_size=group_size), set_inductor_config=False) -def _int8da_int8w_api(mod): +def _int8da_int8w_api(mod, act_mapping_type=MappingType.SYMMETRIC): if TORCH_VERSION_AT_LEAST_2_4: - quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False) + quantize_(mod, int8_dynamic_activation_int8_weight(act_mapping_type=act_mapping_type), set_inductor_config=False) if not TORCH_VERSION_AT_LEAST_2_5: unwrap_tensor_subclass(mod) else: @@ -860,7 +862,7 @@ def _test_lin_weight_subclass_api_impl( test_device, min_sqnr=35, test_dtype=torch.bfloat16, - test_shape=(32, 64, 32) + test_shape=(32, 4096, 14336) ): m, k, n = test_shape x = torch.randn(m, k, device=test_device, dtype=test_dtype) @@ -871,23 +873,26 @@ def _test_lin_weight_subclass_api_impl( api(mod) test = mod(x) + self.assertGreater( SQNR(ref_f, test), - min_sqnr, f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}" + min_sqnr, f"API failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}" ) mod_qc = torch.compile(mod, mode="max-autotune") test_comp = mod_qc(x) self.assertGreater( SQNR(ref_f, test_comp), min_sqnr, - f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}" + f"API failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}" ) - @parameterized.expand(COMMON_DEVICE_DTYPE) - def test_int8_dynamic_quant_subclass_api(self, device, dtype): + @parameterized.expand(list(itertools.product(COMMON_DEVICES, COMMON_DTYPES, MAPPING_TYPES))) + def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping): + from functools import partial + api = partial(_int8da_int8w_api, act_mapping_type=act_mapping) self._test_lin_weight_subclass_api_impl( - _int8da_int8w_api, device, 35, test_dtype=dtype + api, device, 35, test_dtype=dtype ) @parameterized.expand(COMMON_DEVICE_DTYPE) diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 8938e7472c..b2681a001e 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -32,8 +32,10 @@ PlainAQTTensorImpl, _linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl, - _linear_int8_act_int8_weight_check, - _linear_int8_act_int8_weight_impl, + _linear_sym_int8_act_sym_int8_weight_check, + _linear_sym_int8_act_sym_int8_weight_impl, + 
_linear_asym_int8_act_sym_int8_weight_check, + _linear_asym_int8_act_sym_int8_weight_impl ) from torchao.dtypes.uintx.semi_sparse_layout import ( _linear_int8_act_int8_weight_semi_structured_sparse_check, @@ -110,7 +112,14 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias): # so that these can be shared by F.linear, aten.mm, aten.addmm dispatches def _register_aqt_quantized_linear_dispatches(): for dispatch_condition, impl in [ - (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), + ( + _linear_sym_int8_act_sym_int8_weight_check, + _linear_sym_int8_act_sym_int8_weight_impl + ), + ( + _linear_asym_int8_act_sym_int8_weight_check, + _linear_asym_int8_act_sym_int8_weight_impl + ), ( _linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl, diff --git a/torchao/dtypes/uintx/plain_layout.py b/torchao/dtypes/uintx/plain_layout.py index ed171634cd..c29574802e 100644 --- a/torchao/dtypes/uintx/plain_layout.py +++ b/torchao/dtypes/uintx/plain_layout.py @@ -220,7 +220,7 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): return y -def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias): +def _linear_sym_int8_act_sym_int8_weight_check(input_tensor, weight_tensor, bias): return ( isinstance(input_tensor, AffineQuantizedTensor) and _aqt_is_int8_reduced_range(input_tensor) @@ -231,7 +231,7 @@ def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias): ) -def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): +def _linear_sym_int8_act_sym_int8_weight_impl(input_tensor, weight_tensor, bias): # # 1. do the matrix form of dot(X_i, W_j) # @@ -266,3 +266,58 @@ def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): if bias is not None: y += bias return y + + +def _linear_asym_int8_act_sym_int8_weight_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8(input_tensor) + # ZeroPointDomain.NONE did not work for weight in int8_dynamic_activation_int8_weight + # Uncommenting the next line works in eager mode, but Dynamo runs into a problem with it. + #and torch.equal(weight_tensor.tensor_impl.zero_point, torch.zeros_like(weight_tensor.tensor_impl.zero_point)) + and isinstance(weight_tensor, AffineQuantizedTensor) + and input_tensor.dtype == weight_tensor.dtype + and isinstance(input_tensor._layout, PlainLayout) + and isinstance(weight_tensor._layout, PlainLayout) + ) + + +def _linear_asym_int8_act_sym_int8_weight_impl(input_tensor, weight_tensor, bias): + # + # 1. do the matrix form of dot(X_i, W_j) + # + # + # 2. 
rescale the output and apply compensation for zero point of A + # + # + # in cases with large matrices, y_dot_int32 can grow sufficiently + # large that y_dot_int32 * a float16 scale is greater than the maximum + # value of a float 16, (which results in a value of inf even if multiplying + # by the other scale would bring it within the expected range) + x_vals_int8 = input_tensor.tensor_impl.int_data + x_zps = input_tensor.tensor_impl.zero_point.reshape(-1, 1) + x_scales = input_tensor.tensor_impl.scale.reshape(-1, 1) + w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t() + w_scales = weight_tensor.tensor_impl.scale + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + x_scales_dtype = x_scales.dtype + # Cast fp16 scale to float to avoid overflow in int_scaled_matmul + intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype + y_dot_scaled = int_scaled_matmul( + tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype) + ) + y_dot_scaled = y_dot_scaled.to(x_scales_dtype) * w_scales + + # Compute compensation + w_col_sum = weight_tensor.tensor_impl.int_data.contiguous().t().to(torch.float).sum(dim=0) + a_compensation = ((x_scales * w_scales) * x_zps.to(intermediate_dtype)) * w_col_sum + y = (y_dot_scaled - a_compensation).reshape( + *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] + ) + + # can downcast only at the very end + output_dtype = input_tensor.dtype + y = y.to(output_dtype) + if bias is not None: + y += bias + return y From 082a565cc9120b68a36a001985c79f80299a8594 Mon Sep 17 00:00:00 2001 From: sanchitintel Date: Thu, 12 Dec 2024 15:19:19 -0800 Subject: [PATCH 2/7] Revise as per review suggestions --- test/integration/test_integration.py | 33 ++++++++++++++++------- torchao/dtypes/affine_quantized_tensor.py | 2 +- torchao/dtypes/uintx/plain_layout.py | 31 ++++++++++----------- torchao/quantization/quant_api.py | 5 +++- torchao/quantization/quant_primitives.py | 4 +-- 5 files changed, 44 insertions(+), 31 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 827a6ed2e1..27b6eecaa2 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -42,6 +42,7 @@ quantize_affine, dequantize_affine, MappingType, + ZeroPointDomain, ) from torchao.quantization.utils import ( dequantize_per_channel, @@ -105,9 +106,11 @@ COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] -MAPPING_TYPES = [MappingType.SYMMETRIC, MappingType.ASYMMETRIC] +ACT_MAPPING_TYPES = [MappingType.ASYMMETRIC, MappingType.SYMMETRIC] -COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES, MAPPING_TYPES)).copy() +WEIGHT_ZERO_POINT_DOMAINS = [ZeroPointDomain.NONE, ZeroPointDomain.INT] + +COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy() def _int8wo_api(mod): if TORCH_VERSION_AT_LEAST_2_4: @@ -127,9 +130,16 @@ def _int8wo_groupwise_api(mod): group_size = 32 quantize_(mod, int8_weight_only(group_size=group_size), set_inductor_config=False) -def _int8da_int8w_api(mod, act_mapping_type=MappingType.SYMMETRIC): +def _int8da_int8w_api(mod, act_mapping_type=MappingType.SYMMETRIC, weight_zero_point_domain=ZeroPointDomain.INT): if TORCH_VERSION_AT_LEAST_2_4: - quantize_(mod, int8_dynamic_activation_int8_weight(act_mapping_type=act_mapping_type), set_inductor_config=False) + quantize_( + mod, + int8_dynamic_activation_int8_weight( + act_mapping_type=act_mapping_type, + weight_zp_domain=weight_zero_point_domain + ), + 
set_inductor_config=False + ) if not TORCH_VERSION_AT_LEAST_2_5: unwrap_tensor_subclass(mod) else: @@ -886,14 +896,17 @@ def _test_lin_weight_subclass_api_impl( f"API failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}" ) - - @parameterized.expand(list(itertools.product(COMMON_DEVICES, COMMON_DTYPES, MAPPING_TYPES))) - def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping): + @parameterized.expand( + list(itertools.product(COMMON_DEVICES, COMMON_DTYPES, ACT_MAPPING_TYPES, WEIGHT_ZERO_POINT_DOMAINS)) + ) + def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping, weight_zero_point_domain): from functools import partial - api = partial(_int8da_int8w_api, act_mapping_type=act_mapping) - self._test_lin_weight_subclass_api_impl( - api, device, 35, test_dtype=dtype + api = partial( + _int8da_int8w_api, + act_mapping_type=act_mapping, + weight_zero_point_domain=weight_zero_point_domain ) + self._test_lin_weight_subclass_api_impl(api, device, 35, test_dtype=dtype) @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(is_fbcode(), "broken in fbcode") diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 7aca25ecc5..88d4a9dfde 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -251,7 +251,7 @@ def from_hp_to_intx( zero_point_domain, ) # choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None - if zero_point_domain is None: + if zero_point_domain == ZeroPointDomain.NONE: zero_point = None data = quantize_affine( input_float, diff --git a/torchao/dtypes/uintx/plain_layout.py b/torchao/dtypes/uintx/plain_layout.py index c29574802e..24720b4f43 100644 --- a/torchao/dtypes/uintx/plain_layout.py +++ b/torchao/dtypes/uintx/plain_layout.py @@ -38,7 +38,7 @@ def __new__( cls, int_data: torch.Tensor, scale: torch.Tensor, - zero_point: torch.Tensor, + zero_point: Optional[torch.Tensor], _layout: Layout, ): kwargs = {} @@ -55,7 +55,7 @@ def __init__( self, int_data: torch.Tensor, scale: torch.Tensor, - zero_point: torch.Tensor, + zero_point: Optional[torch.Tensor], _layout: Layout, ): self.int_data = int_data @@ -64,7 +64,10 @@ def __init__( self._layout = _layout def __tensor_flatten__(self): - return ["int_data", "scale", "zero_point"], [self._layout] + if self.zero_point is not None: + return ["int_data", "scale", "zero_point"], [self._layout] + else: + return ["int_data", "scale"], [self._layout] @classmethod def __tensor_unflatten__( @@ -73,7 +76,7 @@ def __tensor_unflatten__( int_data, scale, zero_point = ( tensor_data_dict["int_data"], tensor_data_dict["scale"], - tensor_data_dict["zero_point"], + tensor_data_dict.get("zero_point", None), ) (_layout,) = tensor_attributes return cls(int_data, scale, zero_point, _layout) @@ -83,7 +86,7 @@ def to(self, *args, **kwargs): return self.__class__( self.int_data.to(kwargs["device"]), self.scale.to(kwargs["device"]), - self.zero_point.to(kwargs["device"]), + self.zero_point.to(kwargs["device"]) if self.zero_point is not None else None, self._layout, ) @@ -91,7 +94,7 @@ def _apply_fn_to_data(self, fn): return self.__class__( fn(self.int_data), fn(self.scale), - fn(self.zero_point), + fn(self.zero_point) if self.zero_point is not None else None, self._layout, ) @@ -134,7 +137,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): return PlainAQTTensorImpl( aten.slice.Tensor(self.int_data, dim, start, end, step), 
self.scale.view(-1), - self.zero_point.view(-1), + self.zero_point.view(-1) if self.zero_point is not None else None, self._layout, ) else: @@ -148,7 +151,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): __torch_function__ = torch._C._disabled_torch_function_impl - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: return self.int_data, self.scale, self.zero_point def get_layout(self) -> Layout: @@ -272,9 +275,7 @@ def _linear_asym_int8_act_sym_int8_weight_check(input_tensor, weight_tensor, bia return ( isinstance(input_tensor, AffineQuantizedTensor) and _aqt_is_int8(input_tensor) - # ZeroPointDomain.NONE did not work for weight in int8_dynamic_activation_int8_weight - # Uncommenting the next line works in eager mode, but Dynamo runs into a problem with it. - #and torch.equal(weight_tensor.tensor_impl.zero_point, torch.zeros_like(weight_tensor.tensor_impl.zero_point)) + and weight_tensor.zero_point_domain == ZeroPointDomain.NONE and isinstance(weight_tensor, AffineQuantizedTensor) and input_tensor.dtype == weight_tensor.dtype and isinstance(input_tensor._layout, PlainLayout) @@ -289,11 +290,6 @@ def _linear_asym_int8_act_sym_int8_weight_impl(input_tensor, weight_tensor, bias # # 2. rescale the output and apply compensation for zero point of A # - # - # in cases with large matrices, y_dot_int32 can grow sufficiently - # large that y_dot_int32 * a float16 scale is greater than the maximum - # value of a float 16, (which results in a value of inf even if multiplying - # by the other scale would bring it within the expected range) x_vals_int8 = input_tensor.tensor_impl.int_data x_zps = input_tensor.tensor_impl.zero_point.reshape(-1, 1) x_scales = input_tensor.tensor_impl.scale.reshape(-1, 1) @@ -309,8 +305,9 @@ def _linear_asym_int8_act_sym_int8_weight_impl(input_tensor, weight_tensor, bias y_dot_scaled = y_dot_scaled.to(x_scales_dtype) * w_scales # Compute compensation - w_col_sum = weight_tensor.tensor_impl.int_data.contiguous().t().to(torch.float).sum(dim=0) + w_col_sum = w_vals_int8_t.to(torch.float).sum(dim=0) a_compensation = ((x_scales * w_scales) * x_zps.to(intermediate_dtype)) * w_col_sum + y = (y_dot_scaled - a_compensation).reshape( *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] ) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 96ccb1889c..87d9f06f87 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -739,7 +739,9 @@ def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor: def int8_dynamic_activation_int8_weight( - layout=PlainLayout(), act_mapping_type=MappingType.SYMMETRIC + layout=PlainLayout(), + act_mapping_type=MappingType.SYMMETRIC, + weight_zp_domain=ZeroPointDomain.INT ): """ Applies int8 dynamic symmetric per-token activation and int8 per-channel weight @@ -781,6 +783,7 @@ def get_weight_block_size(x): eps=eps, zero_point_dtype=zero_point_dtype, _layout=layout, + zero_point_domain=weight_zp_domain ) weight = to_linear_activation_quantized(weight, input_quant_func) return weight diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 37aa609b9b..804c6534ad 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -888,11 +888,11 @@ def _choose_qparams_affine( "preserve_zero == False is not supported for symmetric quantization" ) if ( - zero_point_domain is not None + 
zero_point_domain != ZeroPointDomain.NONE.name
             and zero_point_domain != ZeroPointDomain.INT.name
         ):
             raise ValueError(
-                "zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization"
+                "Only ZeroPointDomain.NONE and ZeroPointDomain.INT are supported for symmetric quantization"
             )
         scale = torch.clamp(scale, min=eps)
         zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))

From a434dcac36e3aa26dbb9170282fa26003b16ca1e Mon Sep 17 00:00:00 2001
From: sanchitintel
Date: Thu, 12 Dec 2024 16:17:42 -0800
Subject: [PATCH 3/7] Add UT for ZeroPointDomain.NONE

---
 test/quantization/test_quant_primitives.py | 26 ++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index a3fef29fea..9cf32a0b1b 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -799,6 +799,32 @@ def test_fake_quantize_affine_cachemask(self):
         torch.testing.assert_close(dequantized, fake_quantized)
         torch.testing.assert_close(expected_mask, mask)
 
+    # ZeroPointDomain.NONE should work
+    def test_none_zero_point_domain(self):
+        input = torch.randn(10, 256)
+        n_bit = 8
+        mapping_type = MappingType.SYMMETRIC
+        dtype = torch.int8
+        block_size = (1, 128)
+        quant_min = None
+        quant_max = None
+        eps = 1e-6
+        scale_dtype = torch.float32
+        zero_point_dtype = torch.int64
+        _, zero_point = choose_qparams_affine(
+            input,
+            mapping_type,
+            block_size,
+            dtype,
+            quant_min,
+            quant_max,
+            eps,
+            scale_dtype=scale_dtype,
+            zero_point_dtype=zero_point_dtype,
+            preserve_zero=True,
+            zero_point_domain=ZeroPointDomain.NONE,
+        )
+        self.assertTrue(torch.equal(zero_point, torch.zeros_like(zero_point)))
 
 if __name__ == "__main__":
     unittest.main()

From 19ce3b8f4b9fcdfff8b377d85cb11b8d6e09c91c Mon Sep 17 00:00:00 2001
From: sanchitintel
Date: Thu, 12 Dec 2024 16:49:31 -0800
Subject: [PATCH 4/7] Support both None and ZeroPointDomain.NONE

---
 test/quantization/test_quant_primitives.py | 2 +-
 torchao/quantization/quant_primitives.py   | 9 +++++++--
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 9cf32a0b1b..17e2955322 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -824,7 +824,7 @@ def test_none_zero_point_domain(self):
             preserve_zero=True,
             zero_point_domain=ZeroPointDomain.NONE,
         )
-        self.assertTrue(torch.equal(zero_point, torch.zeros_like(zero_point)))
+        self.assertTrue(zero_point is None)
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index 804c6534ad..e0350f0142 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -889,13 +889,18 @@ def _choose_qparams_affine(
         )
         if (
             zero_point_domain != ZeroPointDomain.NONE.name
+            and zero_point_domain != None
            and zero_point_domain != ZeroPointDomain.INT.name
         ):
             raise ValueError(
-                "Only ZeroPointDomain.NONE and ZeroPointDomain.INT are supported for symmetric quantization"
+                "Only None, ZeroPointDomain.NONE and ZeroPointDomain.INT"
+                " are supported for symmetric quantization."
) + if zero_point_domain == ZeroPointDomain.NONE.name: + zero_point = None + else: + zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2)) scale = torch.clamp(scale, min=eps) - zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2)) else: assert mapping_type == MappingType.ASYMMETRIC.name scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) From cd51793510e61fd99393d217df3b081eccc02b28 Mon Sep 17 00:00:00 2001 From: sanchitintel Date: Thu, 12 Dec 2024 17:00:53 -0800 Subject: [PATCH 5/7] Use smaller input shapes in UT --- test/integration/test_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 27b6eecaa2..dbfc4aebe6 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -872,7 +872,7 @@ def _test_lin_weight_subclass_api_impl( test_device, min_sqnr=35, test_dtype=torch.bfloat16, - test_shape=(32, 4096, 14336) + test_shape=(32, 64, 32) ): m, k, n = test_shape x = torch.randn(m, k, device=test_device, dtype=test_dtype) From 9bc2f3e6cee596805e50ffc44c0196e8a4960e94 Mon Sep 17 00:00:00 2001 From: sanchitintel Date: Thu, 12 Dec 2024 17:18:55 -0800 Subject: [PATCH 6/7] Add support for both None and ZeroPointDomain.NONE --- torchao/dtypes/affine_quantized_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 88d4a9dfde..d30718853e 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -251,7 +251,7 @@ def from_hp_to_intx( zero_point_domain, ) # choose_qparams_affine is a custom op that does support returning optional Tensors. 
We thus set the zero_point to None if its domain is None - if zero_point_domain == ZeroPointDomain.NONE: + if zero_point_domain == ZeroPointDomain.NONE or zero_point_domain is None: zero_point = None data = quantize_affine( input_float, From 9cc7ebd51145cbbbf3cd99ad21a83eb1c030bbdc Mon Sep 17 00:00:00 2001 From: sanchitintel Date: Thu, 12 Dec 2024 20:40:39 -0800 Subject: [PATCH 7/7] Unify zero_point_domain None and ZeroPointDomain.NONE cases --- test/quantization/test_observer.py | 15 +++++++-------- torchao/dtypes/affine_quantized_tensor.py | 4 ++-- torchao/quantization/quant_primitives.py | 11 ----------- 3 files changed, 9 insertions(+), 21 deletions(-) diff --git a/test/quantization/test_observer.py b/test/quantization/test_observer.py index 0526ee01b2..8ec15eb201 100644 --- a/test/quantization/test_observer.py +++ b/test/quantization/test_observer.py @@ -21,6 +21,7 @@ ) from torchao.quantization.quant_primitives import ( MappingType, + ZeroPointDomain, ) @@ -74,7 +75,7 @@ def test_block_size_calc_success(self): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) example_inputs = [ torch.randn(10, 2048), @@ -93,7 +94,7 @@ def test_block_size_calc_success(self): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) for example_input in example_inputs: obs(example_input) @@ -108,7 +109,7 @@ def test_block_size_row_errors(self): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) example_inputs = [ torch.randn(10, 2048), @@ -127,7 +128,7 @@ def test_block_size_row_errors(self): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) example_inputs = [ torch.randn(10, 2048), @@ -155,7 +156,7 @@ def test_linear_observer_tensor(self, observe_weight: bool): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) if observe_weight: weight_observer = AffineQuantizedMinMaxObserver( @@ -165,7 +166,7 @@ def test_linear_observer_tensor(self, observe_weight: bool): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) else: weight_observer = None @@ -199,7 +200,6 @@ def test_linear_observer_tensor(self, observe_weight: bool): input_scale.item(), max_val / max_fp8, ) - self.assertIsNotNone(input_zero_point) if observe_weight: weight_observer = linear.weight.weight_observer @@ -210,7 +210,6 @@ def test_linear_observer_tensor(self, observe_weight: bool): atol=5e-5, rtol=0.0, ) - self.assertIsNotNone(weight_zero_point) else: self.assertIsNone(linear.weight.weight_observer) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index d30718853e..e6b1aa8db7 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -349,7 +349,7 @@ def from_hp_to_floatx( scale_dtype=scale_dtype, zero_point_dtype=None, preserve_zero=True, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, _layout=_layout, use_hqq=False, ) @@ -376,7 +376,7 @@ def from_hp_to_floatx_static( target_dtype=target_dtype, 
quant_min=math.ceil(torch.finfo(target_dtype).min), quant_max=math.ceil(torch.finfo(target_dtype).max), - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, _layout=_layout, ) else: diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index e0350f0142..69469bbfa8 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -533,16 +533,6 @@ def _dequantize_affine_no_dtype_check( ), "zero_point should be None when zero_point_domain is NONE" dequant = input.to(output_dtype) dequant = dequant * scale - elif zero_point_domain is None: - # This case handles dequantization for float8 we expect no zero point and no zero point domain - assert ( - zero_point is None - ), "zero_point should be None when zero_point_domain is None" - assert _is_float8_type( - input.dtype - ), f"dequantiztion with no zero point domain is only supported with FP8 types, got {input.dtype}" - dequant = input.to(output_dtype) - dequant = dequant * scale else: assert ( zero_point_domain == ZeroPointDomain.FLOAT.name @@ -889,7 +879,6 @@ def _choose_qparams_affine( ) if ( zero_point_domain != ZeroPointDomain.NONE.name - and zero_point_domain != None and zero_point_domain != ZeroPointDomain.INT.name ): raise ValueError(
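
Notes on the series (added for review; not part of the patches):

The compensation term in _linear_asym_int8_act_sym_int8_weight_impl (PATCH 1)
follows from expanding the dequantized product: with an asymmetrically
quantized activation X ~= s_x * (X_q - z_x) and a symmetrically quantized
weight W ~= s_w * W_q,

    X @ W.T ~= s_x * s_w * (X_q @ W_q.T) - s_x * s_w * z_x * colsum(W_q.T)

The sketch below checks that identity in plain PyTorch; it is a standalone
illustration with made-up quantization helpers and variable names, not
torchao code.

    import torch

    torch.manual_seed(0)
    m, k, n = 4, 64, 8
    x = torch.randn(m, k)
    w = torch.randn(n, k)

    # Asymmetric per-row activation quantization: x ~= x_scale * (x_q - x_zp)
    x_min, x_max = x.amin(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)
    x_scale = (x_max - x_min) / 255.0
    x_zp = torch.round(-128.0 - x_min / x_scale)
    x_q = torch.clamp(torch.round(x / x_scale) + x_zp, -128, 127)

    # Symmetric per-channel weight quantization, no zero point: w ~= w_scale * w_q
    w_scale = w.abs().amax(dim=1, keepdim=True) / 127.0
    w_q = torch.round(w / w_scale)

    # Integer-valued matmul (held in fp32 here, since CPU matmul has no int8
    # kernel), rescaled by both scales, then corrected for the activation
    # zero point via the column sums of the transposed weight:
    y_dot = (x_q @ w_q.t()) * x_scale * w_scale.t()
    compensation = (x_scale * w_scale.t()) * x_zp * w_q.t().sum(dim=0)
    y = y_dot - compensation

    print((y - x @ w.t()).abs().max())  # only quantization rounding error remains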
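
The fp32 cast of a half-precision activation scale before int_scaled_matmul
(PATCH 1) guards against the overflow described in the comment carried over
from the symmetric implementation: fp16 saturates at 65504, so a large int32
accumulator times the activation scale can hit inf even though the small
per-channel weight scale would bring the final product back into range. A
minimal repro of that failure mode (the magnitudes are illustrative only):

    import torch

    acc = torch.tensor(60000.0)                    # large int8 dot-product sum
    x_scale = torch.tensor(2.0, dtype=torch.half)
    w_scale = torch.tensor(1e-3, dtype=torch.half)

    bad = (acc.to(torch.half) * x_scale) * w_scale   # 120000 overflows fp16 -> inf
    good = (acc * x_scale.float() * w_scale.float()).to(torch.half)  # ~120, in range
    print(bad, good)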
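
For reference, the knobs this series adds would be exercised as below. The
import locations mirror the test files above, and weight_zp_domain is the
parameter name as of PATCH 2; treat this as an untested sketch rather than
documented API.

    import torch
    from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
    from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).eval()

    # Asymmetric per-token activations with symmetric, zero-point-free weights:
    # this combination should dispatch to _linear_asym_int8_act_sym_int8_weight_impl.
    quantize_(
        model,
        int8_dynamic_activation_int8_weight(
            act_mapping_type=MappingType.ASYMMETRIC,
            weight_zp_domain=ZeroPointDomain.NONE,
        ),
    )
    y = model(torch.randn(16, 1024))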