Consolidate ZeroPointDomain.NONE & None zero point domains #1556

Status: Merged (7 commits, Jan 29, 2025)
47 changes: 39 additions & 8 deletions test/integration/test_integration.py
@@ -10,6 +10,7 @@
 import logging
 import os
 import unittest
+from functools import partial
 
 import torch
 import torch.nn as nn
@@ -48,6 +49,7 @@
     quantize_,
 )
 from torchao.quantization.quant_primitives import (
+    MappingType,
     dequantize_affine,
 )
 from torchao.quantization.smoothquant import (
@@ -102,6 +104,8 @@
 
 COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
 
+ACT_MAPPING_TYPES = [MappingType.ASYMMETRIC, MappingType.SYMMETRIC]
+
 COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
 
 
@@ -121,9 +125,18 @@ def _int8wo_groupwise_api(mod):
     quantize_(mod, int8_weight_only(group_size=group_size), set_inductor_config=False)
 
 
-def _int8da_int8w_api(mod):
+def _int8da_int8w_api(
+    mod,
+    act_mapping_type=MappingType.SYMMETRIC,
+):
     if TORCH_VERSION_AT_LEAST_2_4:
-        quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
+        quantize_(
+            mod,
+            int8_dynamic_activation_int8_weight(
+                act_mapping_type=act_mapping_type,
+            ),
+            set_inductor_config=False,
+        )
     if not TORCH_VERSION_AT_LEAST_2_5:
         unwrap_tensor_subclass(mod)
     else:
@@ -962,25 +975,43 @@ def _test_lin_weight_subclass_api_impl(
             mod[0].weight.tensor_impl.get_plain()
 
         test = mod(x)
+
         self.assertGreater(
             SQNR(ref_f, test),
             min_sqnr,
-            f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}",
+            f"API failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}",
         )
 
         mod_qc = torch.compile(mod, mode="max-autotune")
         test_comp = mod_qc(x)
         self.assertGreater(
             SQNR(ref_f, test_comp),
             min_sqnr,
-            f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}",
+            f"API failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}",
         )
 
-    @parameterized.expand(COMMON_DEVICE_DTYPE)
-    def test_int8_dynamic_quant_subclass_api(self, device, dtype):
-        self._test_lin_weight_subclass_api_impl(
-            _int8da_int8w_api, device, 35, test_dtype=dtype
-        )
+    @parameterized.expand(
+        list(
+            itertools.product(
+                COMMON_DEVICES,
+                COMMON_DTYPES,
+                ACT_MAPPING_TYPES,
+            )
+        )
+    )
+    def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping):
+        if (
+            not TORCH_VERSION_AT_LEAST_2_5
+            and dtype in (torch.float16, torch.bfloat16)
+            and act_mapping is MappingType.ASYMMETRIC
+            and device == "cpu"
+        ):
+            self.skipTest("Inductor-CPU codegen issue fixed in torch 2.5")
+        api = partial(
+            _int8da_int8w_api,
+            act_mapping_type=act_mapping,
+        )
+        self._test_lin_weight_subclass_api_impl(api, device, 35, test_dtype=dtype)
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
17 changes: 9 additions & 8 deletions test/quantization/test_observer.py
@@ -21,6 +21,7 @@
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
+    ZeroPointDomain,
 )
 
 
@@ -74,7 +75,7 @@ def test_block_size_calc_success(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         example_inputs = [
             torch.randn(10, 2048),
@@ -93,7 +94,7 @@ def test_block_size_calc_success(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         for example_input in example_inputs:
             obs(example_input)
@@ -108,7 +109,7 @@ def test_block_size_row_errors(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         example_inputs = [
             torch.randn(10, 2048),
@@ -127,7 +128,7 @@ def test_block_size_row_errors(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
        )
         example_inputs = [
             torch.randn(10, 2048),
@@ -155,7 +156,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         if observe_weight:
             weight_observer = AffineQuantizedMinMaxObserver(
@@ -165,7 +166,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
                 eps=torch.finfo(torch.float32).eps,
                 scale_dtype=torch.float,
                 zero_point_dtype=torch.int,
-                zero_point_domain=None,
+                zero_point_domain=ZeroPointDomain.NONE,
             )
         else:
             weight_observer = None
@@ -199,7 +200,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
             input_scale.item(),
             max_val / max_fp8,
         )
-        self.assertIsNotNone(input_zero_point)
+        self.assertIsNone(input_zero_point)
 
         if observe_weight:
             weight_observer = linear.weight.weight_observer

Review thread on this assertion:

  [Reviewer] Is there a change of behavior when you change zero_point_domain from None to ZeroPointDomain.NONE?

  [sanchitintel, author] Yes, input_zero_point would now be None. So, instead of removing that line, I added self.assertIsNone(input_zero_point). Thanks!

  [Reviewer] I see, so what was the meaning of zero_point_domain == None before?

  [sanchitintel, author, Jan 17, 2025] Looks like some APIs/implementations were creating/expecting a None zero_point when zero_point_domain was ZeroPointDomain.NONE or None, while choose_qparams_affine was not.

  [sanchitintel, author] Hi @jerryzh168, is it possible that some torchao users' code may be expecting a non-None zero_point with zero_point_domain ZeroPointDomain.NONE/None, making this change BC-breaking for them? Thanks!

  [drisspg] I think that most usages of this function are internal to torchao, so it is okay to break BC; you can add the label just to be sure.

  [sanchitintel, author] Thanks for your advice, @drisspg! Could you please help add such a label, as GitHub isn't displaying an option for me to add it? Thanks!

  [jerryzh168] I added a bc-breaking label, please also write a bc-breaking note similar to #1049.

  [sanchitintel, author] Thanks again, @jerryzh168! I added a note and rebased the PR.

@@ -210,7 +211,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
                 atol=5e-5,
                 rtol=0.0,
             )
-            self.assertIsNotNone(weight_zero_point)
+            self.assertIsNone(weight_zero_point)
         else:
             self.assertIsNone(linear.weight.weight_observer)
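The thread above boils down to a simple contract change. A minimal sketch, mirroring the call shape of the test_none_zero_point_domain test added below in test_quant_primitives.py (illustrative only, not authoritative API documentation):

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
    choose_qparams_affine,
)

x = torch.randn(10, 256)

# With ZeroPointDomain.NONE, the returned zero point is now None.
scale, zero_point = choose_qparams_affine(
    x,
    MappingType.SYMMETRIC,
    (1, 128),    # block_size
    torch.int8,  # target dtype
    None,        # quant_min (derived from the target dtype)
    None,        # quant_max (derived from the target dtype)
    1e-6,        # eps
    scale_dtype=torch.float32,
    zero_point_dtype=torch.int64,
    preserve_zero=True,
    zero_point_domain=ZeroPointDomain.NONE,
)
assert zero_point is None

# A bare None is no longer accepted as a synonym for ZeroPointDomain.NONE.
try:
    choose_qparams_affine(
        x,
        MappingType.SYMMETRIC,
        (1, 128),
        torch.int8,
        None,
        None,
        1e-6,
        zero_point_domain=None,
    )
except ValueError:
    pass  # expected after this PR

This is the flip the assertions above encode: an observer configured with ZeroPointDomain.NONE now reports a None zero point when its qparams are computed.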
53 changes: 51 additions & 2 deletions test/quantization/test_quant_primitives.py
@@ -843,6 +843,55 @@ def test_fake_quantize_affine_cachemask(self):
         torch.testing.assert_close(dequantized, fake_quantized)
         torch.testing.assert_close(expected_mask, mask)
 
+    def test_none_zero_point_domain(self):
+        """A None value for a ZeroPointDomain should not work, but ZeroPointDomain.NONE should"""
+        input = torch.randn(10, 256)
+        mapping_type = MappingType.SYMMETRIC
+        dtype = torch.int8
+        block_size = (1, 128)
+        quant_min = None
+        quant_max = None
+        eps = 1e-6
+        scale_dtype = torch.float32
+        zero_point_dtype = torch.int64
+        try:
+            _, zero_point = choose_qparams_affine(
+                input,
+                mapping_type,
+                block_size,
+                dtype,
+                quant_min,
+                quant_max,
+                eps,
+                scale_dtype=scale_dtype,
+                zero_point_dtype=zero_point_dtype,
+                preserve_zero=True,
+                zero_point_domain=None,
+            )
+        except ValueError:
+            # This exception was expected
+            # Now test for ZeroPointDomain.NONE
+            _, zero_point = choose_qparams_affine(
+                input,
+                mapping_type,
+                block_size,
+                dtype,
+                quant_min,
+                quant_max,
+                eps,
+                scale_dtype=scale_dtype,
+                zero_point_dtype=zero_point_dtype,
+                preserve_zero=True,
+                zero_point_domain=ZeroPointDomain.NONE,
+            )
+            self.assertTrue(zero_point is None)
+        else:
+            # An exception should have been thrown for zero_point_domain None
+            self.assertTrue(
+                False,
+                msg="A runtime exception should have been thrown for zero_point_domain None",
+            )
+
     @parameterized.expand(
         [
             (
@@ -890,7 +939,7 @@ def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
             quant_min=torch.finfo(float8_dtype).min,
             quant_max=torch.finfo(float8_dtype).max,
             zero_point=None,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         expected_dequantized = dequantize_affine(
             expected_quantized,
@@ -901,7 +950,7 @@ def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
             quant_min=torch.finfo(float8_dtype).min,
             quant_max=torch.finfo(float8_dtype).max,
             zero_point=None,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
 
         self.assertTrue(torch.equal(expected_scale, scale))
20 changes: 11 additions & 9 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -81,6 +81,8 @@ def __new__(
         dtype=None,
         strides=None,
     ):
+        if zero_point_domain is None:
+            raise ValueError("please use ZeroPointDomain.NONE instead of None")
         kwargs = {}
         kwargs["device"] = tensor_impl.device
         kwargs["layout"] = (
@@ -199,7 +201,7 @@ def from_hp_to_intx(
         scale_dtype: Optional[torch.dtype] = None,
         zero_point_dtype: Optional[torch.dtype] = None,
         preserve_zero: bool = True,
-        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
         _layout: Layout = PlainLayout(),
         use_hqq: bool = False,
     ):
@@ -258,8 +260,7 @@ def from_hp_to_intx(
             zero_point_domain,
         )
         # choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None
-        # TODO should probably consolidate ZeroPointDomain.NONE and None
-        if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE:
+        if zero_point_domain == ZeroPointDomain.NONE:
             zero_point = None
         data = quantize_affine(
             input_float,
@@ -296,14 +297,15 @@ def from_hp_to_intx_static(
         target_dtype: torch.dtype,
         quant_min: Optional[int] = None,
         quant_max: Optional[int] = None,
-        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
         _layout: Layout = PlainLayout(),
     ):
         """Create an integer AffineQuantizedTensor from a high precision tensor using static parameters."""
+        if zero_point_domain is None:
+            raise ValueError("please use ZeroPointDomain.NONE instead of None")
+        elif zero_point_domain is ZeroPointDomain.NONE and zero_point is not None:
+            raise ValueError("zero_point should be None when zero_point_domain is NONE")
         if target_dtype not in FP8_TYPES:
-            assert (
-                zero_point_domain is not None
-            ), "zero_point_domain must be specified for non-fp8 types"
             assert (
                 zero_point is not None
             ), "zero_point must be specified for non-fp8 types"
@@ -359,7 +361,7 @@ def from_hp_to_floatx(
             scale_dtype=scale_dtype,
             zero_point_dtype=None,
             preserve_zero=True,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
             _layout=_layout,
             use_hqq=False,
         )
@@ -387,7 +389,7 @@ def from_hp_to_floatx_static(
             target_dtype=target_dtype,
             quant_min=math.ceil(torch.finfo(target_dtype).min),
             quant_max=math.ceil(torch.finfo(target_dtype).max),
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
             _layout=_layout,
         )
     else:
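The static constructor now fails fast on inconsistent zero-point arguments. A minimal sketch of both failure modes follows; it assumes the leading positional parameters are (input_float, scale, zero_point, block_size, target_dtype), which this diff does not show, so check the argument order against the source:

import torch
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization.quant_primitives import ZeroPointDomain

x = torch.randn(16, 32)
scale = torch.ones(16, 1)

# 1) A bare None domain is rejected up front.
try:
    AffineQuantizedTensor.from_hp_to_intx_static(
        x, scale, None, (1, 32), torch.int8,
        zero_point_domain=None,
    )
except ValueError:
    pass  # "please use ZeroPointDomain.NONE instead of None"

# 2) ZeroPointDomain.NONE combined with a concrete zero_point is contradictory.
try:
    AffineQuantizedTensor.from_hp_to_intx_static(
        x, scale, torch.zeros(16, 1), (1, 32), torch.int8,
        zero_point_domain=ZeroPointDomain.NONE,
    )
except ValueError:
    pass  # "zero_point should be None when zero_point_domain is NONE"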
4 changes: 3 additions & 1 deletion torchao/dtypes/uintx/marlin_qqq_tensor.py
@@ -54,10 +54,12 @@ def from_hp_to_intx(
         block_size: Tuple[int, ...],
         quant_min: Optional[int] = None,
         quant_max: Optional[int] = None,
-        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
         _layout: Optional[Layout] = None,
     ):
         """Converts a floating point tensor to a Marlin QQQ quantized tensor."""
+        if zero_point_domain is None:
+            raise ValueError("Please use ZeroPointDomain.NONE instead of None")
         original_shape = input_float.shape
         input_float = _layout.pre_process(input_float)
         nbits = int(math.log2(quant_max - quant_min + 1))
5 changes: 3 additions & 2 deletions torchao/quantization/observer.py
@@ -104,11 +104,12 @@ def __init__(
         scale_dtype: Optional[torch.dtype] = None,
         zero_point_dtype: Optional[torch.dtype] = None,
         preserve_zero: bool = True,
-        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     ):
         super().__init__()
         assert granularity is not None, "granularity is None"
-
+        if zero_point_domain is None:
+            raise ValueError("Please use ZeroPointDomain.NONE instead of None")
         self.mapping_type = mapping_type
         self.target_dtype = target_dtype
         self.granularity = granularity
5 changes: 5 additions & 0 deletions torchao/quantization/qat/affine_fake_quantized_tensor.py
@@ -41,6 +41,9 @@ def forward(
         preserve_zero: bool = True,
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     ) -> "AffineFakeQuantizedTensor":
+        if zero_point_domain is None:
+            raise ValueError("Please use ZeroPointDomain.NONE instead of None")
+
         def apply_fake_quant_fn(t: torch.Tensor):
             assert isinstance(t, AffineFakeQuantizedTensor)
             qmin, qmax = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
@@ -158,6 +161,8 @@ def from_float(
         preserve_zero: bool = True,
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     ):
+        if zero_point_domain is None:
+            raise ValueError("Please use ZeroPointDomain.NONE instead of None")
         return _ToAffineFakeQuantized.apply(
             original_input,
             mapping_type,
2 changes: 2 additions & 0 deletions torchao/quantization/qat/api.py
@@ -96,6 +96,8 @@ def __init__(
         group_size: Optional[int] = None,
         is_symmetric: Optional[bool] = None,
     ):
+        if zero_point_domain is None:
+            raise ValueError("Please use ZeroPointDomain.NONE instead of None")
         self.dtype = dtype
         self.granularity = self._get_granularity(granularity, group_size)
         self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric)
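Taken together, the files above apply one idiom at every public entry point that accepts a zero_point_domain. As a sketch, the shared pattern is (the standalone helper below is hypothetical; the PR inlines the check at each site rather than factoring it out):

from torchao.quantization.quant_primitives import ZeroPointDomain

def check_zero_point_domain(zero_point_domain):
    # Hypothetical helper showing the guard this PR repeats at each entry
    # point: a bare None is no longer accepted as a stand-in for
    # ZeroPointDomain.NONE, which is now the single way to say "no zero point".
    if zero_point_domain is None:
        raise ValueError("Please use ZeroPointDomain.NONE instead of None")
    return zero_point_domain

check_zero_point_domain(ZeroPointDomain.NONE)  # ok; callers then emit zero_point=None
check_zero_point_domain(None)                  # raises ValueError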