diff --git a/test/quantization/test_observer.py b/test/quantization/test_observer.py index 8ec15eb201..4567f3baef 100644 --- a/test/quantization/test_observer.py +++ b/test/quantization/test_observer.py @@ -200,6 +200,7 @@ def test_linear_observer_tensor(self, observe_weight: bool): input_scale.item(), max_val / max_fp8, ) + self.assertIsNone(input_zero_point) if observe_weight: weight_observer = linear.weight.weight_observer @@ -210,6 +211,7 @@ def test_linear_observer_tensor(self, observe_weight: bool): atol=5e-5, rtol=0.0, ) + self.assertIsNone(weight_zero_point) else: self.assertIsNone(linear.weight.weight_observer) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 00fe300864..9a97218077 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -838,8 +838,8 @@ def test_fake_quantize_affine_cachemask(self): torch.testing.assert_close(dequantized, fake_quantized) torch.testing.assert_close(expected_mask, mask) - # ZeroPointDomain.NONE should work def test_none_zero_point_domain(self): + """A None value for a ZeroPointDomain should not work, but ZeroPointDomain.NONE should""" input = torch.randn(10, 256) mapping_type = MappingType.SYMMETRIC dtype = torch.int8 @@ -849,20 +849,43 @@ def test_none_zero_point_domain(self): eps = 1e-6 scale_dtype = torch.float32 zero_point_dtype = torch.int64 - _, zero_point = choose_qparams_affine( - input, - mapping_type, - block_size, - dtype, - quant_min, - quant_max, - eps, - scale_dtype=scale_dtype, - zero_point_dtype=zero_point_dtype, - preserve_zero=True, - zero_point_domain=ZeroPointDomain.NONE, - ) - self.assertTrue(zero_point is None) + try: + _, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + quant_min, + quant_max, + eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=True, + zero_point_domain=None, + ) + except ValueError: + # This exception was expected + # Now test for ZeroPointDomain.NONE + _, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + quant_min, + quant_max, + eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=True, + zero_point_domain=ZeroPointDomain.NONE, + ) + self.assertTrue(zero_point is None) + else: + # An exception should have been thrown for zero_point_domain None + self.assertTrue( + False, + msg="A runtime exception should have been thrown for zero_point_domain None", + ) if __name__ == "__main__": diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 9d3da97810..14ad158a52 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -85,6 +85,7 @@ def __new__( dtype=None, strides=None, ): + assert zero_point_domain is not None, "zero_point_domain must not be None" kwargs = {} kwargs["device"] = tensor_impl.device kwargs["layout"] = ( @@ -107,6 +108,7 @@ def __init__( dtype=None, strides=None, ): + assert zero_point_domain is not None, "zero_point_domain must not be None" self.tensor_impl = tensor_impl self.block_size = block_size self.quant_min = quant_min @@ -301,10 +303,8 @@ def from_hp_to_intx_static( zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, _layout: Layout = PlainLayout(), ): + assert zero_point_domain is not None, "zero_point_domain must not be None" if target_dtype not in FP8_TYPES: - assert ( - zero_point_domain is not None - ), "zero_point_domain must be specified for non-fp8 types" assert ( zero_point is not None ), "zero_point must be specified for non-fp8 types" diff --git a/torchao/dtypes/uintx/marlin_qqq_tensor.py b/torchao/dtypes/uintx/marlin_qqq_tensor.py index b75d959b41..d55cff1f61 100644 --- a/torchao/dtypes/uintx/marlin_qqq_tensor.py +++ b/torchao/dtypes/uintx/marlin_qqq_tensor.py @@ -58,6 +58,7 @@ def from_hp_to_intx( zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, _layout: Optional[Layout] = None, ): + assert zero_point_domain is not None, "zero_point_domain must not be None" original_shape = input_float.shape input_float = _layout.pre_process(input_float) nbits = int(math.log2(quant_max - quant_min + 1)) diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py index 06509c7b91..fe707f6ebe 100644 --- a/torchao/quantization/observer.py +++ b/torchao/quantization/observer.py @@ -108,7 +108,7 @@ def __init__( ): super().__init__() assert granularity is not None, "granularity is None" - + assert zero_point_domain is not None, "zero_point_domain must not be None" self.mapping_type = mapping_type self.target_dtype = target_dtype self.granularity = granularity diff --git a/torchao/quantization/qat/affine_fake_quantized_tensor.py b/torchao/quantization/qat/affine_fake_quantized_tensor.py index b84200ac9c..40776d89e4 100644 --- a/torchao/quantization/qat/affine_fake_quantized_tensor.py +++ b/torchao/quantization/qat/affine_fake_quantized_tensor.py @@ -41,6 +41,8 @@ def forward( preserve_zero: bool = True, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> "AffineFakeQuantizedTensor": + assert zero_point_domain is not None, "zero_point_domain must not be None" + def apply_fake_quant_fn(t: torch.Tensor): assert isinstance(t, AffineFakeQuantizedTensor) qmin, qmax = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) @@ -158,6 +160,7 @@ def from_float( preserve_zero: bool = True, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ): + assert zero_point_domain is not None, "zero_point_domain must not be None" return _ToAffineFakeQuantized.apply( original_input, mapping_type, diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py index cd3813291f..82faa713e0 100644 --- a/torchao/quantization/qat/api.py +++ b/torchao/quantization/qat/api.py @@ -96,6 +96,7 @@ def __init__( group_size: Optional[int] = None, is_symmetric: Optional[bool] = None, ): + assert zero_point_domain is not None, "zero_point_domain must not be None" self.dtype = dtype self.granularity = self._get_granularity(granularity, group_size) self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 949afc968f..bb35ece86e 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -316,6 +316,7 @@ def quantize_affine( Output: quantized tensor with requested dtype """ + assert zero_point_domain is not None, "zero_point_domain must not be None" return _quantize_affine( input, block_size, @@ -324,7 +325,7 @@ def quantize_affine( output_dtype, quant_min, quant_max, - zero_point_domain.name if zero_point_domain is not None else None, + zero_point_domain.name, ) @@ -468,6 +469,7 @@ def dequantize_affine( Output: dequantized Tensor, with requested dtype or fp32 """ + assert zero_point_domain is not None, "zero_point_domain must not be None" return _dequantize_affine( input, block_size, @@ -476,7 +478,7 @@ def dequantize_affine( input_dtype, quant_min, quant_max, - zero_point_domain.name if zero_point_domain is not None else None, + zero_point_domain.name, output_dtype=output_dtype, ) @@ -612,6 +614,7 @@ def fake_quantize_affine( value during quantization default is ZeroPointDomain.INT """ + assert zero_point_domain is not None, "zero_point_domain must not be None" (_, fq) = _do_fake_quantize_affine( input, block_size, @@ -654,6 +657,7 @@ def fake_quantize_affine_cachemask( ) """ + assert zero_point_domain is not None, "zero_point_domain must not be None" (q, dq) = _do_fake_quantize_affine( input, block_size, @@ -753,6 +757,8 @@ def choose_qparams_affine( Output: Tuple of scales and zero_points Tensor with requested dtype """ + if zero_point_domain is None: + raise ValueError("zero_point_domain must not be None") return _choose_qparams_affine( input, mapping_type.name, @@ -764,7 +770,7 @@ def choose_qparams_affine( scale_dtype, zero_point_dtype, preserve_zero, - zero_point_domain.name if zero_point_domain is not None else None, + zero_point_domain.name, ) @@ -792,6 +798,8 @@ def choose_qparams_affine_with_min_max( difference: instead of passing in `input` Tensor and use that to calculate min_val/max_val and then scale/zero_point, we pass in min_val/max_val directly """ + if zero_point_domain is None: + raise ValueError("zero_point_domain must not be None") return _choose_qparams_affine( None, mapping_type.name, @@ -803,7 +811,7 @@ def choose_qparams_affine_with_min_max( scale_dtype, zero_point_dtype, preserve_zero, - zero_point_domain.name if zero_point_domain is not None else None, + zero_point_domain.name, min_val, max_val, ) @@ -910,10 +918,7 @@ def _choose_qparams_affine( raise ValueError( "preserve_zero == False is not supported for symmetric quantization" ) - if ( - zero_point_domain is not None - and zero_point_domain == ZeroPointDomain.FLOAT.name - ): + if zero_point_domain == ZeroPointDomain.FLOAT.name: # TODO INT should not be a valid ZeroPointDomain for symmetric quantization since # symmetric quant doesn't have a zero_point raise ValueError(