[WIP] Use int_scaled_matmul with int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC) #1402

Closed
wants to merge 7 commits
23 changes: 14 additions & 9 deletions test/integration/test_integration.py
@@ -105,7 +105,9 @@

COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
MAPPING_TYPES = [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]

COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES, MAPPING_TYPES)).copy()

def _int8wo_api(mod):
if TORCH_VERSION_AT_LEAST_2_4:
@@ -125,9 +127,9 @@ def _int8wo_groupwise_api(mod):
group_size = 32
quantize_(mod, int8_weight_only(group_size=group_size), set_inductor_config=False)

def _int8da_int8w_api(mod):
def _int8da_int8w_api(mod, act_mapping_type=MappingType.SYMMETRIC):
if TORCH_VERSION_AT_LEAST_2_4:
quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
quantize_(mod, int8_dynamic_activation_int8_weight(act_mapping_type=act_mapping_type), set_inductor_config=False)
if not TORCH_VERSION_AT_LEAST_2_5:
unwrap_tensor_subclass(mod)
else:
@@ -860,7 +862,7 @@ def _test_lin_weight_subclass_api_impl(
test_device,
min_sqnr=35,
test_dtype=torch.bfloat16,
test_shape=(32, 64, 32)
test_shape=(32, 4096, 14336)
):
m, k, n = test_shape
x = torch.randn(m, k, device=test_device, dtype=test_dtype)
@@ -871,23 +873,26 @@ def _test_lin_weight_subclass_api_impl(
api(mod)

test = mod(x)

self.assertGreater(
SQNR(ref_f, test),
min_sqnr, f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}"
min_sqnr, f"API failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}"
)

mod_qc = torch.compile(mod, mode="max-autotune")
test_comp = mod_qc(x)
self.assertGreater(
SQNR(ref_f, test_comp), min_sqnr,
f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}"
f"API failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}"
)


@parameterized.expand(COMMON_DEVICE_DTYPE)
def test_int8_dynamic_quant_subclass_api(self, device, dtype):
@parameterized.expand(list(itertools.product(COMMON_DEVICES, COMMON_DTYPES, MAPPING_TYPES)))
def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping):
from functools import partial
api = partial(_int8da_int8w_api, act_mapping_type=act_mapping)
self._test_lin_weight_subclass_api_impl(
_int8da_int8w_api, device, 35, test_dtype=dtype
api, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
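For context, a minimal usage sketch of the knob these tests exercise. This is not part of the diff; the import paths, model, and shapes are illustrative assumptions based on torchao's public API, not the test's exact setup.

import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.quantization.quant_primitives import MappingType

# Toy model; sizes are illustrative, not the (32, 4096, 14336) shape used in the test.
model = torch.nn.Sequential(torch.nn.Linear(64, 128)).eval()

# Quantize weights to int8 and dynamically quantize activations to int8 with an
# asymmetric (scale + zero point) mapping, the path this PR adds a kernel for.
# Assumes a recent torch/torchao; the test guards on TORCH_VERSION_AT_LEAST_2_4
# and calls unwrap_tensor_subclass(model) for torch older than 2.5.
quantize_(model, int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC))

out = model(torch.randn(32, 64))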
15 changes: 12 additions & 3 deletions torchao/dtypes/affine_quantized_tensor_ops.py
@@ -32,8 +32,10 @@
PlainAQTTensorImpl,
_linear_fp_act_int8_weight_check,
_linear_fp_act_int8_weight_impl,
_linear_int8_act_int8_weight_check,
_linear_int8_act_int8_weight_impl,
_linear_sym_int8_act_sym_int8_weight_check,
_linear_sym_int8_act_sym_int8_weight_impl,
_linear_asym_int8_act_sym_int8_weight_check,
_linear_asym_int8_act_sym_int8_weight_impl
)
from torchao.dtypes.uintx.semi_sparse_layout import (
_linear_int8_act_int8_weight_semi_structured_sparse_check,
@@ -110,7 +112,14 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias):
# so that these can be shared by F.linear, aten.mm, aten.addmm dispatches
def _register_aqt_quantized_linear_dispatches():
for dispatch_condition, impl in [
(_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
(
_linear_sym_int8_act_sym_int8_weight_check,
_linear_sym_int8_act_sym_int8_weight_impl
),
(
_linear_asym_int8_act_sym_int8_weight_check,
_linear_asym_int8_act_sym_int8_weight_impl
),
(
_linear_int8_act_int8_weight_semi_structured_sparse_check,
_linear_int8_act_int8_weight_semi_structured_sparse_impl,
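The registration above follows a first-match (check, impl) dispatch pattern shared by the F.linear, aten.mm, and aten.addmm entry points. A simplified illustration of that pattern follows; the names and signatures here are hypothetical, not torchao's actual internals.

_QUANTIZED_LINEAR_DISPATCHES = []  # list of (check, impl) pairs, tried in order

def _register_dispatch(check, impl):
    _QUANTIZED_LINEAR_DISPATCHES.append((check, impl))

def _dispatch_quantized_linear(input_tensor, weight_tensor, bias):
    # Use the first implementation whose condition accepts this combination of
    # activation layout/dtype and weight layout/dtype.
    for check, impl in _QUANTIZED_LINEAR_DISPATCHES:
        if check(input_tensor, weight_tensor, bias):
            return impl(input_tensor, weight_tensor, bias)
    raise NotImplementedError("no matching quantized linear dispatch")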
59 changes: 57 additions & 2 deletions torchao/dtypes/uintx/plain_layout.py
@@ -220,7 +220,7 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):
return y


def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias):
def _linear_sym_int8_act_sym_int8_weight_check(input_tensor, weight_tensor, bias):
return (
isinstance(input_tensor, AffineQuantizedTensor)
and _aqt_is_int8_reduced_range(input_tensor)
@@ -231,7 +231,7 @@ def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias):
)


def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
def _linear_sym_int8_act_sym_int8_weight_impl(input_tensor, weight_tensor, bias):
#
# 1. do the matrix form of dot(X_i, W_j)
#
@@ -266,3 +266,58 @@ def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
if bias is not None:
y += bias
return y


def _linear_asym_int8_act_sym_int8_weight_check(input_tensor, weight_tensor, bias):
return (
isinstance(input_tensor, AffineQuantizedTensor)
and _aqt_is_int8(input_tensor)
# ZeroPointDomain.NONE did not work for weight in int8_dynamic_activation_int8_weight
# Uncommenting the next line works in eager mode, but Dynamo runs into a problem with it.
#and torch.equal(weight_tensor.tensor_impl.zero_point, torch.zeros_like(weight_tensor.tensor_impl.zero_point))
and isinstance(weight_tensor, AffineQuantizedTensor)
and input_tensor.dtype == weight_tensor.dtype
and isinstance(input_tensor._layout, PlainLayout)
and isinstance(weight_tensor._layout, PlainLayout)
)


def _linear_asym_int8_act_sym_int8_weight_impl(input_tensor, weight_tensor, bias):
    #
    # 1. do the matrix form of dot(X_i, W_j)
    #
    # 2. rescale the output and subtract the compensation term for the
    #    activation zero point
    #
    # With large matrices, the accumulated matmul result can grow large enough
    # that multiplying it by a float16 scale overflows to inf, even though
    # multiplying by the remaining scale would bring it back into range, so a
    # float16 activation scale is upcast to float32 before int_scaled_matmul.
x_vals_int8 = input_tensor.tensor_impl.int_data
x_zps = input_tensor.tensor_impl.zero_point.reshape(-1, 1)
x_scales = input_tensor.tensor_impl.scale.reshape(-1, 1)
w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t()
w_scales = weight_tensor.tensor_impl.scale
tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
x_scales_dtype = x_scales.dtype
# Cast fp16 scale to float to avoid overflow in int_scaled_matmul
intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype
y_dot_scaled = int_scaled_matmul(
tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)
)
y_dot_scaled = y_dot_scaled.to(x_scales_dtype) * w_scales

    # Compute the zero-point compensation term: the activation dequantizes as
    # (x_q - x_zps) * x_scales, so the integer matmul picks up an extra
    # x_scales * w_scales * x_zps * sum(W_q, dim=0) per output column, which
    # is subtracted here.
    w_col_sum = w_vals_int8_t.to(torch.float).sum(dim=0)
    a_compensation = ((x_scales * w_scales) * x_zps.to(intermediate_dtype)) * w_col_sum
y = (y_dot_scaled - a_compensation).reshape(
*x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
)

# can downcast only at the very end
output_dtype = input_tensor.dtype
y = y.to(output_dtype)
if bias is not None:
y += bias
return y
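For reference, a self-contained numerical sketch of the compensation math above. It uses hand-rolled per-row asymmetric activation quantization, per-channel symmetric weight quantization, and a plain float matmul in place of int_scaled_matmul; all of these are assumptions for illustration, not the library's own helpers.

import torch

torch.manual_seed(0)
m, k, n = 32, 64, 48
x = torch.randn(m, k)
w = torch.randn(n, k)  # weights stored as (out_features, in_features)

# Per-row asymmetric int8 quantization of the activation.
x_min, x_max = x.amin(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)
x_scales = (x_max - x_min) / 255.0
x_zps = -128 - torch.round(x_min / x_scales)
x_q = torch.clamp(torch.round(x / x_scales) + x_zps, -128, 127)

# Per-channel symmetric int8 quantization of the weight.
w_scales = w.abs().amax(dim=1) / 127.0
w_q = torch.clamp(torch.round(w / w_scales[:, None]), -128, 127)

# Scaled integer matmul followed by the zero-point compensation:
#   y ~= x_scales * w_scales * (x_q @ w_q.T) - x_scales * w_scales * x_zps * sum(w_q.T, dim=0)
y_dot_scaled = (x_q @ w_q.t()) * x_scales * w_scales
w_col_sum = w_q.t().sum(dim=0)
y = y_dot_scaled - (x_scales * w_scales * x_zps) * w_col_sum

# Reference: dequantize both operands first, then matmul. The two agree up to
# float rounding, which is exactly what the compensation term provides.
ref = ((x_q - x_zps) * x_scales) @ (w_q * w_scales[:, None]).t()
torch.testing.assert_close(y, ref, rtol=1e-4, atol=1e-4)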