Skip to content

Commit

Permalink
Add exceptions/asserts for None argument for zero_point_domain
Browse files — browse the repository at this point in the history
  • Loading branch information
sanchitintel committed Jan 15, 2025
1 parent bbc8dcd commit cd7abef
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 27 deletions.
2 changes: 2 additions & 0 deletions test/quantization/test_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
input_scale.item(),
max_val / max_fp8,
)
self.assertIsNone(input_zero_point)

if observe_weight:
weight_observer = linear.weight.weight_observer
Expand All @@ -210,6 +211,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
atol=5e-5,
rtol=0.0,
)
self.assertIsNone(weight_zero_point)
else:
self.assertIsNone(linear.weight.weight_observer)

Expand Down
53 changes: 38 additions & 15 deletions test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,8 +838,8 @@ def test_fake_quantize_affine_cachemask(self):
torch.testing.assert_close(dequantized, fake_quantized)
torch.testing.assert_close(expected_mask, mask)

# ZeroPointDomain.NONE should work
def test_none_zero_point_domain(self):
"""A None value for a ZeroPointDomain should not work, but ZeroPointDomain.NONE should"""
input = torch.randn(10, 256)
mapping_type = MappingType.SYMMETRIC
dtype = torch.int8
Expand All @@ -849,20 +849,43 @@ def test_none_zero_point_domain(self):
eps = 1e-6
scale_dtype = torch.float32
zero_point_dtype = torch.int64
_, zero_point = choose_qparams_affine(
input,
mapping_type,
block_size,
dtype,
quant_min,
quant_max,
eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=True,
zero_point_domain=ZeroPointDomain.NONE,
)
self.assertTrue(zero_point is None)
try:
_, zero_point = choose_qparams_affine(
input,
mapping_type,
block_size,
dtype,
quant_min,
quant_max,
eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=True,
zero_point_domain=None,
)
except ValueError:
# This exception was expected
# Now test for ZeroPointDomain.NONE
_, zero_point = choose_qparams_affine(
input,
mapping_type,
block_size,
dtype,
quant_min,
quant_max,
eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=True,
zero_point_domain=ZeroPointDomain.NONE,
)
self.assertTrue(zero_point is None)
else:
# An exception should have been thrown for zero_point_domain None
self.assertTrue(
False,
msg="A runtime exception should have been thrown for zero_point_domain None",
)


if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __new__(
dtype=None,
strides=None,
):
assert zero_point_domain is not None, "zero_point_domain must not be None"
kwargs = {}
kwargs["device"] = tensor_impl.device
kwargs["layout"] = (
Expand All @@ -107,6 +108,7 @@ def __init__(
dtype=None,
strides=None,
):
assert zero_point_domain is not None, "zero_point_domain must not be None"
self.tensor_impl = tensor_impl
self.block_size = block_size
self.quant_min = quant_min
Expand Down Expand Up @@ -301,10 +303,8 @@ def from_hp_to_intx_static(
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
_layout: Layout = PlainLayout(),
):
assert zero_point_domain is not None, "zero_point_domain must not be None"
if target_dtype not in FP8_TYPES:
assert (
zero_point_domain is not None
), "zero_point_domain must be specified for non-fp8 types"
assert (
zero_point is not None
), "zero_point must be specified for non-fp8 types"
Expand Down
1 change: 1 addition & 0 deletions torchao/dtypes/uintx/marlin_qqq_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def from_hp_to_intx(
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
_layout: Optional[Layout] = None,
):
assert zero_point_domain is not None, "zero_point_domain must not be None"
original_shape = input_float.shape
input_float = _layout.pre_process(input_float)
nbits = int(math.log2(quant_max - quant_min + 1))
Expand Down
2 changes: 1 addition & 1 deletion torchao/quantization/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(
):
super().__init__()
assert granularity is not None, "granularity is None"

assert zero_point_domain is not None, "zero_point_domain must not be None"
self.mapping_type = mapping_type
self.target_dtype = target_dtype
self.granularity = granularity
Expand Down
3 changes: 3 additions & 0 deletions torchao/quantization/qat/affine_fake_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def forward(
preserve_zero: bool = True,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
) -> "AffineFakeQuantizedTensor":
assert zero_point_domain is not None, "zero_point_domain must not be None"

def apply_fake_quant_fn(t: torch.Tensor):
assert isinstance(t, AffineFakeQuantizedTensor)
qmin, qmax = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
Expand Down Expand Up @@ -158,6 +160,7 @@ def from_float(
preserve_zero: bool = True,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
):
assert zero_point_domain is not None, "zero_point_domain must not be None"
return _ToAffineFakeQuantized.apply(
original_input,
mapping_type,
Expand Down
1 change: 1 addition & 0 deletions torchao/quantization/qat/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(
group_size: Optional[int] = None,
is_symmetric: Optional[bool] = None,
):
assert zero_point_domain is not None, "zero_point_domain must not be None"
self.dtype = dtype
self.granularity = self._get_granularity(granularity, group_size)
self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric)
Expand Down
21 changes: 13 additions & 8 deletions torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ def quantize_affine(
Output:
quantized tensor with requested dtype
"""
assert zero_point_domain is not None, "zero_point_domain must not be None"
return _quantize_affine(
input,
block_size,
Expand All @@ -324,7 +325,7 @@ def quantize_affine(
output_dtype,
quant_min,
quant_max,
zero_point_domain.name if zero_point_domain is not None else None,
zero_point_domain.name,
)


Expand Down Expand Up @@ -468,6 +469,7 @@ def dequantize_affine(
Output:
dequantized Tensor, with requested dtype or fp32
"""
assert zero_point_domain is not None, "zero_point_domain must not be None"
return _dequantize_affine(
input,
block_size,
Expand All @@ -476,7 +478,7 @@ def dequantize_affine(
input_dtype,
quant_min,
quant_max,
zero_point_domain.name if zero_point_domain is not None else None,
zero_point_domain.name,
output_dtype=output_dtype,
)

Expand Down Expand Up @@ -612,6 +614,7 @@ def fake_quantize_affine(
value during quantization
default is ZeroPointDomain.INT
"""
assert zero_point_domain is not None, "zero_point_domain must not be None"
(_, fq) = _do_fake_quantize_affine(
input,
block_size,
Expand Down Expand Up @@ -654,6 +657,7 @@ def fake_quantize_affine_cachemask(
)
"""
assert zero_point_domain is not None, "zero_point_domain must not be None"
(q, dq) = _do_fake_quantize_affine(
input,
block_size,
Expand Down Expand Up @@ -753,6 +757,8 @@ def choose_qparams_affine(
Output:
Tuple of scales and zero_points Tensor with requested dtype
"""
if zero_point_domain is None:
raise ValueError("zero_point_domain must not be None")
return _choose_qparams_affine(
input,
mapping_type.name,
Expand All @@ -764,7 +770,7 @@ def choose_qparams_affine(
scale_dtype,
zero_point_dtype,
preserve_zero,
zero_point_domain.name if zero_point_domain is not None else None,
zero_point_domain.name,
)


Expand Down Expand Up @@ -792,6 +798,8 @@ def choose_qparams_affine_with_min_max(
difference: instead of passing in `input` Tensor and use that to calculate min_val/max_val
and then scale/zero_point, we pass in min_val/max_val directly
"""
if zero_point_domain is None:
raise ValueError("zero_point_domain must not be None")
return _choose_qparams_affine(
None,
mapping_type.name,
Expand All @@ -803,7 +811,7 @@ def choose_qparams_affine_with_min_max(
scale_dtype,
zero_point_dtype,
preserve_zero,
zero_point_domain.name if zero_point_domain is not None else None,
zero_point_domain.name,
min_val,
max_val,
)
Expand Down Expand Up @@ -910,10 +918,7 @@ def _choose_qparams_affine(
raise ValueError(
"preserve_zero == False is not supported for symmetric quantization"
)
if (
zero_point_domain is not None
and zero_point_domain == ZeroPointDomain.FLOAT.name
):
if zero_point_domain == ZeroPointDomain.FLOAT.name:
# TODO INT should not be a valid ZeroPointDomain for symmetric quantization since
# symmetric quant doesn't have a zero_point
raise ValueError(
Expand Down

0 comments on commit cd7abef

Please sign in to comment.