From 6d5db058b10a346a7ef8e7122bf93d871997c7e2 Mon Sep 17 00:00:00 2001
From: Scott Roy
Date: Tue, 24 Sep 2024 15:12:11 -0700
Subject: [PATCH] Add torchchat quantizer (#897)

Summary:
Pull Request resolved: https://github.com/pytorch/ao/pull/897

This diff adds a quantizer for the new torchao kernels that is similar to torchchat's Int8DynActInt4WeightQuantizer (imported from torchao.quantization.quant_api). See the draft torchchat PR (https://github.com/pytorch/torchchat/pull/1070/) for how this can integrate with torchchat's quantization API.

I confirmed that models quantized with this quantizer are compatible with eager, compile, AOTI, and export to ExecuTorch in torchchat. They do not yet run on ExecuTorch because we have not written an ExecuTorch kernel wrapper.

@jerryzh168: this does not use the new subclass API, which is something I'd like to discuss further with you. I'll set up a sync with you this week, but I wanted to have a concrete API on the table to ground the discussion. We do not currently have the C++ methods required to support the new subclass API (e.g., we cannot unpack the packed weights from Python; they are instead unpacked inline in the kernel). From a torchchat user's perspective, I do not think this is important, but I'd like to discuss it further.

Reviewed By: digantdesai

Differential Revision: D62394341
---
 .../kernels/cpu/aarch64/CMakeLists.txt        |   8 +-
 .../examples/torch_custom_op/CMakeLists.txt   |   8 +-
 .../torch_custom_op/build_custom_op.sh        |   8 +-
 .../examples/torch_custom_op/run_custom_op.py |  71 ++---
 ...est_int8_dyn_act_intx_weight_quantizer.py} |  55 ++--
 .../torch_custom_op/torch_custom_op.py        | 231 ---------------
 torchao/experimental/quant_api.py             | 277 ++++++++++++++++++
 7 files changed, 342 insertions(+), 316 deletions(-)
 rename torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/{test_custom_op.py => test_int8_dyn_act_intx_weight_quantizer.py} (51%)
 delete mode 100644 torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py
 create mode 100644 torchao/experimental/quant_api.py

diff --git a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
index ecffb579c1..a13737d874 100644
--- a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
+++ b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
@@ -6,8 +6,8 @@
 add_library(
   kernel_aarch64
-  ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
-  ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
-  ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
-  ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
+  ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
+  ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
+  ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
+  ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
 )
diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt
index 55bcdfbc23..10e44a79a8 100644
--- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt
+++ b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt
@@ -13,12 +13,12 @@ set(CMAKE_BUILD_TYPE Release)
 add_compile_options("-Wall" "-Werror")
 
 include(CMakePrintHelpers)
-message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}")
-include_directories(${TORCHAO_LIBRARIES})
+message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
+include_directories(${TORCHAO_INCLUDE_DIRS})
 
-add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64)
+add_subdirectory(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64)
 
-include(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/Utils.cmake)
+include(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/Utils.cmake)
 
 set(PLATFORM "ATEN" CACHE STRING "Choose platform surface: ATEN, EXECUTORCH")
 string(TOUPPER ${PLATFORM} PLATFORM_TO_UPPER)
diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh
index 94cb9587c6..c657857fcc 100644
--- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh
+++ b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh
@@ -6,14 +6,14 @@
 # LICENSE file in the root directory of this source tree.
 
 SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
-export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../../..
+export TORCHAO_INCLUDE_DIRS=${SCRIPT_DIR}/../../../../../../..
 export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')"
 echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
-export CMAKE_OUT=/tmp/cmake-out/torch_ao/examples/torch_custom_op
-cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
+export CMAKE_OUT=/tmp/cmake-out/torchao
+cmake -DTORCHAO_INCLUDE_DIRS=${TORCHAO_INCLUDE_DIRS} \
     -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
    -DPLATFORM="ATEN" \
-    -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \
+    -S ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \
     -B ${CMAKE_OUT}
 
 cmake --build ${CMAKE_OUT}
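Both example scripts below locate the shared library that build_custom_op.sh writes under /tmp/cmake-out/torchao. A minimal sketch of that loading pattern, with an extra guard (not part of this patch) that fails with a clearer message when the library has not been built yet:

    import glob

    import torch

    # Output directory set by build_custom_op.sh (CMAKE_OUT=/tmp/cmake-out/torchao)
    libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*")
    libs = [lib for lib in libs if lib.endswith("so") or lib.endswith("dylib")]
    if not libs:
        # Hypothetical guard, added here for illustration only
        raise RuntimeError("liblowbit_op_aten not found; run build_custom_op.sh first")
    torch.ops.load_library(libs[0])
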
diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py
index 0b85583f76..62c0eb95ca 100644
--- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py
+++ b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py
@@ -5,12 +5,18 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
+import glob
+
+import sys
 
 import torch
-from torch_custom_op import (
-    linear_a8sz_w_lowbit_reference_impl,
-    replace_linear_with_quantized_linear,
-)
+
+sys.path.insert(0, "../../../../..")
+from quant_api import Int8DynActIntxWeightQuantizer
+
+libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*")
+libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
+torch.ops.load_library(libs[0])
 
 group_size = 256
 m = 1
@@ -27,15 +33,15 @@
 print("Quantizing random model")
 quantized_model = copy.deepcopy(model)
-quantized_model = quantized_model.eval()
-replace_linear_with_quantized_linear(
-    quantized_model,
-    kwargs={
-        "group_size": group_size,
-        "nbit": nbit,
-        "has_weight_zeros": has_weight_zeros,
-    },
+quantizer = Int8DynActIntxWeightQuantizer(
+    device="cpu",
+    precision=torch.float32,
+    bitwidth=nbit,
+    groupsize=group_size,
+    has_weight_zeros=has_weight_zeros,
 )
+quantized_model = quantizer.quantize(quantized_model)
+quantized_model = quantized_model.eval()
 
 print("Creating random activations")
 activations = torch.randn(m, k, dtype=torch.float32)
@@ -58,44 +64,3 @@
 print("Running AOTI")
 fn = torch._export.aot_load("/tmp/torch_custom_op_example_model.so", "cpu")
 fn(activations)
-
-
-print("\nChecking correctness on layer 0")
-linear = model[0]
-quantized_linear = quantized_model[0]
-
-with torch.no_grad():
-    result = quantized_linear(activations)
-    expected_result = linear_a8sz_w_lowbit_reference_impl(
-        linear.weight, activations, group_size, nbit, has_weight_zeros
-    )
-    non_quantized_result = linear(activations)
-
-
-# Check that entries in result match entries in expected_result
-num_mismatch_at_low_tol = 0
-num_total = result.reshape(-1).shape[0]
-for i in range(num_total):
-    actual_val = result.reshape(-1)[i]
-    expected_val = expected_result.reshape(-1)[i]
-    if not torch.allclose(actual_val, expected_val):
-        num_mismatch_at_low_tol += 1
-
-        # If results are not close at a relaxed tolerance, exit with failure
-        if not torch.allclose(actual_val, expected_val, atol=1e-6):
-            assert False, "Correctness check failed"
-
-# Assert at most 5% of entries are not close at a low tolerance
-assert num_mismatch_at_low_tol / num_total <= 0.05, "Correctness check failed"
-print(
-    "Correctness check passed.  All results are close, and ",
-    (num_total - num_mismatch_at_low_tol),
-    "/",
-    num_total,
-    " entries are close at a low tolerance.",
-)
-print("Quantization errors:")
-print("\tL1 error: ", torch.mean(torch.abs(result - non_quantized_result)).item())
-print("\tL2 error: ", torch.mean((result - non_quantized_result) ** 2).item())
-print("\tquantized_result[0:5]: ", result[0][0:5])
-print("\tnon_quantized_result[0:5]: ", non_quantized_result[0][0:5])
diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_custom_op.py b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py
similarity index 51%
rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_custom_op.py
rename to torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py
index e4e108b901..8727217618 100644
--- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_custom_op.py
+++ b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py
@@ -4,16 +4,27 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import copy
+
+import glob
+
+import sys
 import unittest
 
 import torch
-from torch_custom_op import (
-    linear_a8sz_w_lowbit_reference_impl,
-    replace_linear_with_quantized_linear,
+
+sys.path.insert(0, "../../../../..")
+from quant_api import (
+    _Int8DynActIntxWeightQuantizedLinearFallback,
+    Int8DynActIntxWeightQuantizer,
 )
-import copy
 
-class TestTorchCustomOp(unittest.TestCase):
+libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*")
+libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
+torch.ops.load_library(libs[0])
+
+
+class TestInt8DynActIntxWeightQuantizer(unittest.TestCase):
     def test_accuracy(self):
         group_size = 128
         m = 1
@@ -22,24 +33,27 @@ def test_accuracy(self):
         activations = torch.randn(m, k, dtype=torch.float32)
         model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
 
-        for nbit in [2, 3, 4, 5]:
-            for has_weight_zeros in [False, True]:
+        for nbit in [1, 2, 3, 4, 5, 6, 7]:
+            for has_weight_zeros in [True, False]:
+                print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}")
                 quantized_model = copy.deepcopy(model)
-                replace_linear_with_quantized_linear(
-                    quantized_model,
-                    kwargs={
-                        "group_size": group_size,
-                        "nbit": nbit,
-                        "has_weight_zeros": has_weight_zeros,
-                    },
+                quantizer = Int8DynActIntxWeightQuantizer(
+                    device="cpu",
+                    precision=torch.float32,
+                    bitwidth=nbit,
+                    groupsize=group_size,
+                    has_weight_zeros=has_weight_zeros,
                 )
+                quantized_model = quantizer.quantize(quantized_model)
 
                 with torch.no_grad():
                     result = quantized_model(activations)
-                    expected_result = linear_a8sz_w_lowbit_reference_impl(
-                        model[0].weight, activations, group_size, nbit, has_weight_zeros
+                    reference_impl = _Int8DynActIntxWeightQuantizedLinearFallback()
+                    reference_impl.quantize_and_pack_weights(
+                        model[0].weight, nbit, group_size, has_weight_zeros
                     )
-
+                    expected_result = reference_impl(activations)
+
                 num_mismatch_at_low_tol = 0
                 num_total = result.reshape(-1).shape[0]
                 for i in range(num_total):
@@ -50,7 +64,8 @@ def test_accuracy(self):
                         num_mismatch_at_low_tol += 1
 
                 # Assert at most 5% of entries are not close at a low tolerance
-                self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05)
-
-if __name__ == '__main__':
+                self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05)
+
+
+if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py
deleted file mode 100644
index 46117db15a..0000000000
--- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py
+++ /dev/null
@@ -1,231 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import Optional
-
-import torch
-import torch.nn as nn
-
-import glob
-libs = glob.glob("/tmp/cmake-out/torch_ao/examples/torch_custom_op/libtorch_custom_op.*")
-libs = list(filter(lambda l:(l.endswith("so") or l.endswith("dylib")), libs))
-torch.ops.load_library(libs[0])
-
-
-def quantize(vals: torch.Tensor, group_size: int, nbit: int, scale_only: bool):
-    assert nbit >= 2 and nbit <= 8
-    qmin = -(1 << (nbit - 1))
-    qmax = (1 << (nbit - 1)) - 1
-
-    n, k = vals.shape
-    vals = vals.reshape(-1, group_size)
-    vmins, _ = torch.min(vals, axis=1)
-    vmaxs, _ = torch.max(vals, axis=1)
-    group_scales = (vmaxs - vmins) / (qmax - qmin)
-
-    if scale_only:
-        group_qvals = torch.round(vals / group_scales.reshape(-1, 1))
-    else:
-        group_zeros = qmin - torch.round(vmins / group_scales)
-        group_qvals = torch.round(
-            group_zeros.reshape(-1, 1) + vals / group_scales.reshape(-1, 1)
-        )
-
-    group_qvals = torch.clip(group_qvals, qmin, qmax).reshape(n, k).to(torch.int8)
-
-    if scale_only:
-        return group_qvals, group_scales
-    return group_qvals, group_scales, group_zeros
-
-
-def linear_a8sz_w_lowbit_reference_impl(
-    weights, activations, group_size, nbit, has_weight_zeros
-):
-    n, k = weights.shape
-    m, k = activations.shape
-    assert m == 1
-    assert k % group_size == 0
-
-    if has_weight_zeros:
-        weight_qvals, weight_scales, weight_zeros = quantize(
-            weights, group_size, nbit, scale_only=False
-        )
-        weights_dequantized = (
-            weight_scales.reshape(-1, 1)
-            * (weight_qvals.reshape(-1, group_size) - weight_zeros.reshape(-1, 1))
-        ).reshape(n, k)
-    else:
-        weight_qvals, weight_scales = quantize(
-            weights, group_size, nbit, scale_only=True
-        )
-        weights_dequantized = (
-            weight_scales.reshape(-1, 1) * (weight_qvals.reshape(-1, group_size))
-        ).reshape(n, k)
-
-    activation_qvals, activations_scales, activations_zeros = quantize(
-        activations, k, 8, False
-    )
-    activations_dequantized = activations_scales * (
-        activation_qvals - activations_zeros
-    ).reshape(m, k)
-    return torch.matmul(activations_dequantized, weights_dequantized.transpose(1, 0))
-
-
-class _quantized_linear(nn.Module):
-    def __init__(
-        self,
-        nbit,
-        has_weight_zeros,
-        pack_weight_op,
-        linear_op,
-        squeeze_unsqueeze_dim0=False,
-    ):
-        super().__init__()
-        self.squeeze_unsqueeze_dim0 = squeeze_unsqueeze_dim0
-        self.nbit = nbit
-
-        self._has_weight_zeros = has_weight_zeros
-        self._pack_weights_op = pack_weight_op
-        self._linear_op = linear_op
-
-    def pack_weights(self, weight_qvals, weight_scales_and_zeros, group_size):
-        n, k = weight_qvals.shape
-
-        # TODO(T200095131): convert self.n, self.k, self.group_size to
-        # int when supported by AOTI
-        self.n = torch.empty(n)
-        self.k = torch.empty(k)
-        self.group_size = torch.empty(group_size)
-
-        if self._has_weight_zeros:
-            weight_scales, weight_zeros = weight_scales_and_zeros
-            self.packed_weights = self._pack_weights_op(
-                weight_qvals, weight_scales, weight_zeros, self.group_size
-            )
-        else:
-            weight_scales = weight_scales_and_zeros
-            self.packed_weights = self._pack_weights_op(
-                weight_qvals, weight_scales, self.group_size
-            )
-
-    def forward(self, x):
-        if self.squeeze_unsqueeze_dim0:
-            x = x.squeeze(0)
-
-        res = self._linear_op(self.packed_weights, self.n, self.k, self.group_size, x)
-
-        if self.squeeze_unsqueeze_dim0:
-            res = res.unsqueeze(0)
-        return res
-
-
-def replace_linear_with_quantized_linear(module: nn.Module, kwargs={}):
-    group_size = kwargs["group_size"]
-    nbit = kwargs["nbit"]
-    has_weight_zeros = kwargs["has_weight_zeros"]
-    squeeze_unsqueeze_dim0 = (
-        kwargs["squeeze_unsqueeze_dim0"]
-        if "squeeze_unsqueeze_dim0" in kwargs
-        else False
-    )
-
-    for name, child in module.named_children():
-        if isinstance(child, nn.Linear):
-            assert child.bias is None
-
-            if not has_weight_zeros:
-                weight_qvals, weight_scales = quantize(
-                    child.weight, group_size=group_size, nbit=nbit, scale_only=True
-                )
-                weight_scales_and_zeros = weight_scales
-            else:
-                weight_qvals, weight_scales, weight_zeros = quantize(
-                    child.weight, group_size=group_size, nbit=nbit, scale_only=False
-                )
-                weight_scales_and_zeros = (weight_scales, weight_zeros.to(torch.int8))
-
-            qlinear = None
-            if nbit == 2:
-                if has_weight_zeros:
-                    qlinear = _quantized_linear(
-                        nbit=nbit,
-                        has_weight_zeros=has_weight_zeros,
-                        pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w2sz,
-                        linear_op=torch.ops.torchao._linear_a8sz_w2sz,
-                        squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0,
-                    )
-                else:
-                    qlinear = _quantized_linear(
-                        nbit=nbit,
-                        has_weight_zeros=has_weight_zeros,
-                        pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w2s,
-                        linear_op=torch.ops.torchao._linear_a8sz_w2s,
-                        squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0,
-                    )
-            elif nbit == 3:
-                if has_weight_zeros:
-                    qlinear = _quantized_linear(
-                        nbit=nbit,
-                        has_weight_zeros=has_weight_zeros,
-                        pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w3sz,
-                        linear_op=torch.ops.torchao._linear_a8sz_w3sz,
-                        squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0,
-                    )
-                else:
-                    qlinear = _quantized_linear(
-                        nbit=nbit,
-                        has_weight_zeros=has_weight_zeros,
-                        pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w3s,
-                        linear_op=torch.ops.torchao._linear_a8sz_w3s,
-                        squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0,
-                    )
-            elif nbit == 4:
-                if has_weight_zeros:
-                    qlinear = _quantized_linear(
-                        nbit=nbit,
-                        has_weight_zeros=has_weight_zeros,
-                        pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w4sz,
-                        linear_op=torch.ops.torchao._linear_a8sz_w4sz,
-                        squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0,
-                    )
-                else:
-                    qlinear = _quantized_linear(
-                        nbit=nbit,
-                        has_weight_zeros=has_weight_zeros,
-                        pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w4s,
-                        linear_op=torch.ops.torchao._linear_a8sz_w4s,
-                        squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0,
-                    )
-            elif nbit == 5:
-                if has_weight_zeros:
-                    qlinear = _quantized_linear(
-                        nbit=nbit,
-                        has_weight_zeros=has_weight_zeros,
-                        pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w5sz,
-                        linear_op=torch.ops.torchao._linear_a8sz_w5sz,
-                        squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0,
-                    )
-                else:
-                    qlinear = _quantized_linear(
-                        nbit=nbit,
-                        has_weight_zeros=has_weight_zeros,
-                        pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w5s,
-                        linear_op=torch.ops.torchao._linear_a8sz_w5s,
-                        squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0,
-                    )
-            else:
-                raise ValueError(
-                    f"Unsupported nbit ({nbit}) and has_weight_zeros ({has_weight_zeros}) combination"
-                )
-
-            assert qlinear is not None
-            setattr(module, name, qlinear)
-            getattr(module, name).pack_weights(
-                weight_qvals,
-                weight_scales_and_zeros,
-                group_size,
-            )
-        else:
-            replace_linear_with_quantized_linear(child, kwargs)
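The new torchao/experimental/quant_api.py below centralizes the group quantization that the deleted example module implemented locally. As a rough sketch of what its _quantize helper computes for weights (per group of group_size values: qmin = -2^(nbit-1), qmax = 2^(nbit-1) - 1, scale = (vmax - vmin) / (qmax - qmin), and, when has_weight_zeros is set, zero point = qmin - round(vmin / scale)); the tensor values here are illustrative only:

    import torch

    # Toy 1x8 weight, group_size=4, nbit=4 (qmin=-8, qmax=7), mirroring the
    # arithmetic in _quantize below with has_weight_zeros=True
    vals = torch.tensor([[0.0, 0.1, 0.2, 0.3, -1.0, -0.5, 0.5, 1.0]])
    group_size, nbit = 4, 4
    qmin, qmax = -(1 << (nbit - 1)), (1 << (nbit - 1)) - 1

    groups = vals.reshape(-1, group_size)
    vmins, _ = torch.min(groups, dim=1)
    vmaxs, _ = torch.max(groups, dim=1)
    scales = (vmaxs - vmins) / (qmax - qmin)    # one scale per group
    zeros = qmin - torch.round(vmins / scales)  # only used when has_weight_zeros=True
    qvals = torch.clip(
        torch.round(zeros.reshape(-1, 1) + groups / scales.reshape(-1, 1)), qmin, qmax
    ).to(torch.int8)
    print(scales, zeros, qvals)  # int8 qvals in [-8, 7], one (scale, zero) pair per group
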
diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py
new file mode 100644
index 0000000000..c8886d3c75
--- /dev/null
+++ b/torchao/experimental/quant_api.py
@@ -0,0 +1,277 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+
+def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros: bool):
+    assert nbit >= 1 and nbit <= 8
+    qmin = -(1 << (nbit - 1))
+    qmax = (1 << (nbit - 1)) - 1
+
+    n, k = vals.shape
+    vals = vals.reshape(-1, group_size)
+    vmins, _ = torch.min(vals, axis=1)
+    vmaxs, _ = torch.max(vals, axis=1)
+    group_scales = (vmaxs - vmins) / (qmax - qmin)
+
+    if not has_weight_zeros:
+        group_qvals = torch.round(vals / group_scales.reshape(-1, 1))
+    else:
+        group_zeros = qmin - torch.round(vmins / group_scales)
+        group_qvals = torch.round(
+            group_zeros.reshape(-1, 1) + vals / group_scales.reshape(-1, 1)
+        )
+
+    group_qvals = torch.clip(group_qvals, qmin, qmax).reshape(n, k).to(torch.int8)
+
+    if not has_weight_zeros:
+        return group_qvals, group_scales
+    return group_qvals, (group_scales, group_zeros)
+
+
+class _Int8DynActIntxWeightQuantizedLinearNative(nn.Module):
+    def __init__(
+        self,
+        pack_weight_op,
+        linear_op,
+    ):
+        super().__init__()
+        self._pack_weights_op = pack_weight_op
+        self._linear_op = linear_op
+
+    def quantize_and_pack_weights(self, weights, nbit, group_size, has_weight_zeros):
+        self.nbit = nbit
+        self.group_size = group_size
+        self.has_weight_zeros = has_weight_zeros
+
+        n, k = weights.shape
+
+        # TODO(T200095131): convert self.n, self.k, self.group_size to
+        # int when supported by AOTI
+        self._n = torch.empty(n)
+        self._k = torch.empty(k)
+        self._group_size = torch.empty(self.group_size)
+
+        weight_qvals, weight_scales_and_zeros = _quantize(
+            weights, self.group_size, self.nbit, self.has_weight_zeros
+        )
+        if self.has_weight_zeros:
+            weight_scales, weight_zeros = weight_scales_and_zeros
+            self.packed_weights = self._pack_weights_op(
+                weight_qvals,
+                weight_scales,
+                weight_zeros.to(torch.int8),
+                self._group_size,
+            )
+        else:
+            weight_scales = weight_scales_and_zeros
+            self.packed_weights = self._pack_weights_op(
+                weight_qvals, weight_scales, self._group_size
+            )
+
+    def forward(self, x):
+        if x.dim() == 2:
+            squeeze_dim0 = False
+        elif x.dim() == 3 and x.shape[0] == 1:
+            squeeze_dim0 = True
+            x = x.squeeze(0)
+        else:
+            assert False, "Unsupported tensor dimension in forward pass"
+
+        res = self._linear_op(
+            self.packed_weights, self._n, self._k, self._group_size, x
+        )
+
+        if squeeze_dim0:
+            res = res.unsqueeze(0)
+        return res
+
+
+# Python-based reference implementation of Int8DynActLowbitWeightQuantizedLinear
+# It is arithmetically equivalent to Int8DynActLowbitWeightQuantizedLinear
+# This is used to test Int8DynActLowbitWeightQuantizedLinear, and as a fallback when
+# Int8DynActLowbitWeightQuantizedLinear is not available
+class _Int8DynActIntxWeightQuantizedLinearFallback(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def quantize_and_pack_weights(self, weights, nbit, group_size, has_weight_zeros):
+        self.nbit = nbit
+        self.group_size = group_size
+        self.has_weight_zeros = has_weight_zeros
+
+        self._n, self._k = weights.shape
+        assert self._k % group_size == 0, "group_size must divide k"
+
+        self.weight_qvals, self.weight_scales_and_zeros = _quantize(
+            weights, self.group_size, self.nbit, self.has_weight_zeros
+        )
+
+    def forward(self, x):
+        if x.dim() == 2:
+            squeeze_dim0 = False
+        elif x.dim() == 3 and x.shape[0] == 1:
+            squeeze_dim0 = True
+            x = x.squeeze(0)
+        else:
+            assert False, "Unsupported tensor dimension in forward pass"
+
+        n, k = self._n, self._k
+        m, k_ = x.shape
+        assert k_ == k
+
+        if self.has_weight_zeros:
+            weight_scales, weight_zeros = self.weight_scales_and_zeros
+            weights_dequantized = (
+                weight_scales.reshape(-1, 1)
+                * (
+                    self.weight_qvals.reshape(-1, self.group_size)
+                    - weight_zeros.reshape(-1, 1)
+                )
+            ).reshape(n, k)
+        else:
+            weight_scales = self.weight_scales_and_zeros
+            weights_dequantized = (
+                weight_scales.reshape(-1, 1)
+                * (self.weight_qvals.reshape(-1, self.group_size))
+            ).reshape(n, k)
+
+        activation_qvals, activations_scales_and_zeros = _quantize(
+            x, group_size=k, nbit=8, has_weight_zeros=True
+        )
+        assert activation_qvals.shape == (m, k)
+
+        activations_scales, activations_zeros = activations_scales_and_zeros
+        activations_dequantized = (
+            activations_scales.reshape(-1, 1)
+            * (activation_qvals.reshape(-1, k) - activations_zeros.reshape(-1, 1))
+        ).reshape(m, k)
+        res = torch.matmul(activations_dequantized, weights_dequantized.transpose(1, 0))
+
+        if squeeze_dim0:
+            res = res.unsqueeze(0)
+        return res
+
+
+def _get_quantized_linear_native(nbit, has_weight_zeros):
+    if nbit in [2, 3, 4, 5]:
+        wzp_suffix = "z" if has_weight_zeros else ""
+        return _Int8DynActIntxWeightQuantizedLinearNative(
+            pack_weight_op=getattr(
+                torch.ops.torchao, f"_pack_weights_a8sz_w{nbit}s{wzp_suffix}"
+            ),
+            linear_op=getattr(torch.ops.torchao, f"_linear_a8sz_w{nbit}s{wzp_suffix}"),
+        )
+    raise NotImplementedError(
+        f"Not currently supported: nbit={nbit}, has_weight_zeros={has_weight_zeros}."
+    )
+
+
+def _replace_linear_with_quantized_linear(module: nn.Module, kwargs={}):
+    group_size = kwargs["group_size"]
+    nbit = kwargs["nbit"]
+    has_weight_zeros = kwargs["has_weight_zeros"]
+
+    assert not isinstance(module, nn.Linear)
+    assert nbit >= 1 and nbit <= 7
+
+    for name, child in module.named_children():
+        if not isinstance(child, nn.Linear):
+            _replace_linear_with_quantized_linear(child, kwargs)
+        else:
+            assert child.bias is None
+            qlinear = None
+            try:
+                qlinear = _get_quantized_linear_native(
+                    nbit=nbit,
+                    has_weight_zeros=has_weight_zeros,
+                )
+            except Exception as e:
+                print(
+                    f"Warning: _Int8DynActIntxWeightQuantizedLinearNative raised an exception during initialization: {e}\n"
+                    + "Falling back to **slow** implementation _Int8DynActIntxWeightQuantizedLinearFallback."
+                )
+                qlinear = _Int8DynActIntxWeightQuantizedLinearFallback()
+
+            assert qlinear is not None
+            try:
+                setattr(module, name, qlinear)
+                getattr(module, name).quantize_and_pack_weights(
+                    child.weight, nbit, group_size, has_weight_zeros
+                )
+            except Exception as e:
+                if not isinstance(qlinear, _Int8DynActIntxWeightQuantizedLinearNative):
+                    raise e
+                print(
+                    f"Warning: _Int8DynActIntxWeightQuantizedLinearNative raised an exception during quantize_and_pack_weights: {e}\n"
+                    + "Falling back to **slow** implementation _Int8DynActIntxWeightQuantizedLinearFallback."
+                )
+                qlinear = _Int8DynActIntxWeightQuantizedLinearFallback()
+                setattr(module, name, qlinear)
+                getattr(module, name).quantize_and_pack_weights(
+                    child.weight, nbit, group_size, has_weight_zeros
+                )
+
+
+class Int8DynActIntxWeightQuantizer:
+    def __init__(
+        self,
+        device,
+        precision,
+        *,
+        bitwidth: Optional[int] = None,
+        groupsize: Optional[int] = None,
+        has_weight_zeros: Optional[bool] = None,
+    ):
+        if device != "cpu":
+            raise NotImplementedError(
+                "Only device=cpu is currently supported in Int8DynActLowbitWeightQuantizer"
+            )
+        else:
+            self.device = device
+
+        if precision != torch.float32:
+            raise NotImplementedError(
+                "Only precision=torch.float32 is currently supported in Int8DynActLowbitWeightQuantizer"
+            )
+        else:
+            self.precision = precision
+
+        if bitwidth is None:
+            self.bitwidth = 4
+            print(f"Warning: bitwidth not specified, defaulting to {self.bitwidth}.")
+        else:
+            self.bitwidth = bitwidth
+
+        if groupsize is None:
+            self.groupsize = 128
+            print(f"Warning: groupsize not specified, defaulting to {self.groupsize}.")
+        else:
+            self.groupsize = groupsize
+
+        if has_weight_zeros is None:
+            self.has_weight_zeros = False
+            print(
+                f"Warning: has_weight_zeros not specified, defaulting to {self.has_weight_zeros}."
+            )
+        else:
+            self.has_weight_zeros = has_weight_zeros
+
+    def quantize(self, model: nn.Module) -> nn.Module:
+        model = model.to(self.device).to(self.precision)
+        _replace_linear_with_quantized_linear(
+            model,
+            kwargs={
+                "group_size": self.groupsize,
+                "nbit": self.bitwidth,
+                "has_weight_zeros": self.has_weight_zeros,
+            },
+        )
+        return model
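
For reference, a minimal end-to-end sketch of the new quantizer API, mirroring run_custom_op.py above. It assumes torchao is importable from this source tree; the in-tree examples instead add the directory to sys.path. If the custom op library has not been built and loaded, quantize() falls back to the slow reference implementation:

    import torch
    import torch.nn as nn

    from torchao.experimental.quant_api import Int8DynActIntxWeightQuantizer

    # Tiny stand-in model; the quantizer swaps out bias-free nn.Linear children
    model = nn.Sequential(nn.Linear(4096, 1024, bias=False))

    quantizer = Int8DynActIntxWeightQuantizer(
        device="cpu",
        precision=torch.float32,
        bitwidth=4,        # 1-7 bit weights; 2-5 hit the native kernels, others fall back
        groupsize=256,     # must divide the linear layer's input dimension
        has_weight_zeros=False,
    )
    quantized_model = quantizer.quantize(model).eval()

    # Works in eager; run_custom_op.py also exercises torch.compile and AOTI
    with torch.no_grad():
        out = quantized_model(torch.randn(1, 4096))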