From bdd1486486f740da854a594bb5b23552bbc0b6de Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 25 Sep 2024 12:23:34 -0700 Subject: [PATCH] Add torchchat quantizer (#897) Summary: Pull Request resolved: https://github.com/pytorch/ao/pull/897 This diff adds a quantizer for the new torchao kernels that is similar to the Int8DynActInt4WeightQuantizer quantizer in torchchat (imported from from torchao.quantization.quant_api). See the draft torchchat PR (https://github.com/pytorch/torchchat/pull/1070/) for how this can integrate with torchchat's quantization API. I confirmed that models quantized with this are compatible with eager, compile, AOTI, and export to ExecuTorch in torchchat. They do not run on ExecuTorch because we still have not written an ExecuTorch kernel wrapper. jerryzh168 this does not use the new subclass API, and this is something I'd like to discuss further with you. I'll set up a sync with you this week, but I wanted to have some API on the table to ground the discussion. We do not currently have the required C++ methods implemented to support the new subclass API (e.g., we cannot unpack the packed weights from python; they are instead unpacked inline in the kernel). From a torchchat user's perspective, I do not think this is important, but I'd like to discuss further. Differential Revision: D62394341 --- .../kernels/cpu/aarch64/CMakeLists.txt | 8 +- .../examples/torch_custom_op/CMakeLists.txt | 8 +- .../torch_custom_op/build_custom_op.sh | 8 +- .../examples/torch_custom_op/run_custom_op.py | 72 ++-- .../torch_custom_op/test_custom_op.py | 56 --- ...test_int8_dyn_act_intx_weight_quantizer.py | 79 +++++ .../torch_custom_op/torch_custom_op.py | 231 ------------- torchao/experimental/quant_api.py | 321 ++++++++++++++++++ 8 files changed, 432 insertions(+), 351 deletions(-) delete mode 100644 torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_custom_op.py create mode 100644 torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py delete mode 100644 torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py create mode 100644 torchao/experimental/quant_api.py diff --git a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt index ecffb579c1..a13737d874 100644 --- a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt @@ -6,8 +6,8 @@ add_library( kernel_aarch64 - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp ) diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt index 55bcdfbc23..10e44a79a8 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt @@ -13,12 +13,12 @@ set(CMAKE_BUILD_TYPE Release) add_compile_options("-Wall" "-Werror") include(CMakePrintHelpers) -message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}") -include_directories(${TORCHAO_LIBRARIES}) +message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}") +include_directories(${TORCHAO_INCLUDE_DIRS}) -add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64) +add_subdirectory(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64) -include(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/Utils.cmake) +include(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/Utils.cmake) set(PLATFORM "ATEN" CACHE STRING "Choose platform surface: ATEN, EXECUTORCH") string(TOUPPER ${PLATFORM} PLATFORM_TO_UPPER) diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh index 94cb9587c6..c657857fcc 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh +++ b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh @@ -6,14 +6,14 @@ # LICENSE file in the root directory of this source tree. SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) -export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../../.. +export TORCHAO_INCLUDE_DIRS=${SCRIPT_DIR}/../../../../../../.. export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}" -export CMAKE_OUT=/tmp/cmake-out/torch_ao/examples/torch_custom_op -cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ +export CMAKE_OUT=/tmp/cmake-out/torchao +cmake -DTORCHAO_INCLUDE_DIRS=${TORCHAO_INCLUDE_DIRS} \ -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ -DPLATFORM="ATEN" \ - -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \ + -S ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \ -B ${CMAKE_OUT} cmake --build ${CMAKE_OUT} diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py index 0b85583f76..e3d96df63c 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py +++ b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py @@ -5,12 +5,21 @@ # LICENSE file in the root directory of this source tree. import copy +import glob +import os + +import sys import torch -from torch_custom_op import ( - linear_a8sz_w_lowbit_reference_impl, - replace_linear_with_quantized_linear, + +sys.path.insert( + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../..")) ) +from quant_api import Int8DynActIntxWeightQuantizer + +libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*") +libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs)) +torch.ops.load_library(libs[0]) group_size = 256 m = 1 @@ -27,15 +36,15 @@ print("Quantizing random model") quantized_model = copy.deepcopy(model) -quantized_model = quantized_model.eval() -replace_linear_with_quantized_linear( - quantized_model, - kwargs={ - "group_size": group_size, - "nbit": nbit, - "has_weight_zeros": has_weight_zeros, - }, +quantizer = Int8DynActIntxWeightQuantizer( + device="cpu", + precision=torch.float32, + bitwidth=nbit, + groupsize=group_size, + has_weight_zeros=has_weight_zeros, ) +quantized_model = quantizer.quantize(quantized_model) +quantized_model = quantized_model.eval() print("Creating random activations") activations = torch.randn(m, k, dtype=torch.float32) @@ -58,44 +67,3 @@ print("Running AOTI") fn = torch._export.aot_load("/tmp/torch_custom_op_example_model.so", "cpu") fn(activations) - - -print("\nChecking correctness on layer 0") -linear = model[0] -quantized_linear = quantized_model[0] - -with torch.no_grad(): - result = quantized_linear(activations) - expected_result = linear_a8sz_w_lowbit_reference_impl( - linear.weight, activations, group_size, nbit, has_weight_zeros - ) - non_quantized_result = linear(activations) - - -# Check that entries in result match entries in expected_result -num_mismatch_at_low_tol = 0 -num_total = result.reshape(-1).shape[0] -for i in range(num_total): - actual_val = result.reshape(-1)[i] - expected_val = expected_result.reshape(-1)[i] - if not torch.allclose(actual_val, expected_val): - num_mismatch_at_low_tol += 1 - - # If results are not close at a relaxed tolerance, exit with failure - if not torch.allclose(actual_val, expected_val, atol=1e-6): - assert False, "Correctness check failed" - -# Assert at most 5% of entries are not close at a low tolerance -assert num_mismatch_at_low_tol / num_total <= 0.05, "Correctness check failed" -print( - "Correctness check passed. All results are close, and ", - (num_total - num_mismatch_at_low_tol), - "/", - num_total, - " entries are close at a low tolerance.", -) -print("Quantization errors:") -print("\tL1 error: ", torch.mean(torch.abs(result - non_quantized_result)).item()) -print("\tL2 error: ", torch.mean((result - non_quantized_result) ** 2).item()) -print("\tquantized_result[0:5]: ", result[0][0:5]) -print("\tnon_quantized_result[0:5]: ", non_quantized_result[0][0:5]) diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_custom_op.py b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_custom_op.py deleted file mode 100644 index e4e108b901..0000000000 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_custom_op.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import unittest - -import torch -from torch_custom_op import ( - linear_a8sz_w_lowbit_reference_impl, - replace_linear_with_quantized_linear, -) -import copy - -class TestTorchCustomOp(unittest.TestCase): - def test_accuracy(self): - group_size = 128 - m = 1 - n = 1071 - k = 4096 - activations = torch.randn(m, k, dtype=torch.float32) - model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)]) - - for nbit in [2, 3, 4, 5]: - for has_weight_zeros in [False, True]: - quantized_model = copy.deepcopy(model) - replace_linear_with_quantized_linear( - quantized_model, - kwargs={ - "group_size": group_size, - "nbit": nbit, - "has_weight_zeros": has_weight_zeros, - }, - ) - - with torch.no_grad(): - result = quantized_model(activations) - expected_result = linear_a8sz_w_lowbit_reference_impl( - model[0].weight, activations, group_size, nbit, has_weight_zeros - ) - - num_mismatch_at_low_tol = 0 - num_total = result.reshape(-1).shape[0] - for i in range(num_total): - actual_val = result.reshape(-1)[i] - expected_val = expected_result.reshape(-1)[i] - self.assertTrue(torch.allclose(actual_val, expected_val, atol=1e-6)) - if not torch.allclose(actual_val, expected_val): - num_mismatch_at_low_tol += 1 - - # Assert at most 5% of entries are not close at a low tolerance - self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05) - -if __name__ == '__main__': - unittest.main() diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py new file mode 100644 index 0000000000..513088d2f0 --- /dev/null +++ b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import copy + +import glob +import os + +import sys +import unittest + +import torch + +sys.path.insert( + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../..")) +) +from quant_api import ( + _Int8DynActIntxWeightQuantizedLinearFallback, + Int8DynActIntxWeightQuantizer, +) + +libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*") +libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs)) +if len(libs) == 0: + print( + "Could not find library lowbit_op_aten; please run `sh build_custom_op.sh` to build the library. A slow fallback kernel will be used instaed." + ) +else: + torch.ops.load_library(libs[0]) + + +class TestInt8DynActIntxWeightQuantizer(unittest.TestCase): + def test_accuracy(self): + group_size = 128 + m = 1 + n = 1071 + k = 4096 + activations = torch.randn(m, k, dtype=torch.float32) + model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)]) + + for nbit in [1, 2, 3, 4, 5, 6, 7]: + for has_weight_zeros in [True, False]: + print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}") + quantized_model = copy.deepcopy(model) + quantizer = Int8DynActIntxWeightQuantizer( + device="cpu", + precision=torch.float32, + bitwidth=nbit, + groupsize=group_size, + has_weight_zeros=has_weight_zeros, + ) + quantized_model = quantizer.quantize(quantized_model) + + with torch.no_grad(): + result = quantized_model(activations) + reference_impl = _Int8DynActIntxWeightQuantizedLinearFallback() + reference_impl.quantize_and_pack_weights( + model[0].weight, nbit, group_size, has_weight_zeros + ) + expected_result = reference_impl(activations) + + num_mismatch_at_low_tol = 0 + num_total = result.reshape(-1).shape[0] + for i in range(num_total): + actual_val = result.reshape(-1)[i] + expected_val = expected_result.reshape(-1)[i] + self.assertTrue(torch.allclose(actual_val, expected_val, atol=1e-6)) + if not torch.allclose(actual_val, expected_val): + num_mismatch_at_low_tol += 1 + + # Assert at most 5% of entries are not close at a low tolerance + self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py deleted file mode 100644 index 46117db15a..0000000000 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Optional - -import torch -import torch.nn as nn - -import glob -libs = glob.glob("/tmp/cmake-out/torch_ao/examples/torch_custom_op/libtorch_custom_op.*") -libs = list(filter(lambda l:(l.endswith("so") or l.endswith("dylib")), libs)) -torch.ops.load_library(libs[0]) - -def quantize(vals: torch.Tensor, group_size: int, nbit: int, scale_only: bool): - assert nbit >= 2 and nbit <= 8 - qmin = -(1 << (nbit - 1)) - qmax = (1 << (nbit - 1)) - 1 - - n, k = vals.shape - vals = vals.reshape(-1, group_size) - vmins, _ = torch.min(vals, axis=1) - vmaxs, _ = torch.max(vals, axis=1) - group_scales = (vmaxs - vmins) / (qmax - qmin) - - if scale_only: - group_qvals = torch.round(vals / group_scales.reshape(-1, 1)) - else: - group_zeros = qmin - torch.round(vmins / group_scales) - group_qvals = torch.round( - group_zeros.reshape(-1, 1) + vals / group_scales.reshape(-1, 1) - ) - - group_qvals = torch.clip(group_qvals, qmin, qmax).reshape(n, k).to(torch.int8) - - if scale_only: - return group_qvals, group_scales - return group_qvals, group_scales, group_zeros - - -def linear_a8sz_w_lowbit_reference_impl( - weights, activations, group_size, nbit, has_weight_zeros -): - n, k = weights.shape - m, k = activations.shape - assert m == 1 - assert k % group_size == 0 - - if has_weight_zeros: - weight_qvals, weight_scales, weight_zeros = quantize( - weights, group_size, nbit, scale_only=False - ) - weights_dequantized = ( - weight_scales.reshape(-1, 1) - * (weight_qvals.reshape(-1, group_size) - weight_zeros.reshape(-1, 1)) - ).reshape(n, k) - else: - weight_qvals, weight_scales = quantize( - weights, group_size, nbit, scale_only=True - ) - weights_dequantized = ( - weight_scales.reshape(-1, 1) * (weight_qvals.reshape(-1, group_size)) - ).reshape(n, k) - - activation_qvals, activations_scales, activations_zeros = quantize( - activations, k, 8, False - ) - activations_dequantized = activations_scales * ( - activation_qvals - activations_zeros - ).reshape(m, k) - return torch.matmul(activations_dequantized, weights_dequantized.transpose(1, 0)) - - -class _quantized_linear(nn.Module): - def __init__( - self, - nbit, - has_weight_zeros, - pack_weight_op, - linear_op, - squeeze_unsqueeze_dim0=False, - ): - super().__init__() - self.squeeze_unsqueeze_dim0 = squeeze_unsqueeze_dim0 - self.nbit = nbit - - self._has_weight_zeros = has_weight_zeros - self._pack_weights_op = pack_weight_op - self._linear_op = linear_op - - def pack_weights(self, weight_qvals, weight_scales_and_zeros, group_size): - n, k = weight_qvals.shape - - # TODO(T200095131): convert self.n, self.k, self.group_size to - # int when supported by AOTI - self.n = torch.empty(n) - self.k = torch.empty(k) - self.group_size = torch.empty(group_size) - - if self._has_weight_zeros: - weight_scales, weight_zeros = weight_scales_and_zeros - self.packed_weights = self._pack_weights_op( - weight_qvals, weight_scales, weight_zeros, self.group_size - ) - else: - weight_scales = weight_scales_and_zeros - self.packed_weights = self._pack_weights_op( - weight_qvals, weight_scales, self.group_size - ) - - def forward(self, x): - if self.squeeze_unsqueeze_dim0: - x = x.squeeze(0) - - res = self._linear_op(self.packed_weights, self.n, self.k, self.group_size, x) - - if self.squeeze_unsqueeze_dim0: - res = res.unsqueeze(0) - return res - - -def replace_linear_with_quantized_linear(module: nn.Module, kwargs={}): - group_size = kwargs["group_size"] - nbit = kwargs["nbit"] - has_weight_zeros = kwargs["has_weight_zeros"] - squeeze_unsqueeze_dim0 = ( - kwargs["squeeze_unsqueeze_dim0"] - if "squeeze_unsqueeze_dim0" in kwargs - else False - ) - - for name, child in module.named_children(): - if isinstance(child, nn.Linear): - assert child.bias is None - - if not has_weight_zeros: - weight_qvals, weight_scales = quantize( - child.weight, group_size=group_size, nbit=nbit, scale_only=True - ) - weight_scales_and_zeros = weight_scales - else: - weight_qvals, weight_scales, weight_zeros = quantize( - child.weight, group_size=group_size, nbit=nbit, scale_only=False - ) - weight_scales_and_zeros = (weight_scales, weight_zeros.to(torch.int8)) - - qlinear = None - if nbit == 2: - if has_weight_zeros: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w2sz, - linear_op=torch.ops.torchao._linear_a8sz_w2sz, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - else: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w2s, - linear_op=torch.ops.torchao._linear_a8sz_w2s, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - elif nbit == 3: - if has_weight_zeros: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w3sz, - linear_op=torch.ops.torchao._linear_a8sz_w3sz, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - else: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w3s, - linear_op=torch.ops.torchao._linear_a8sz_w3s, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - elif nbit == 4: - if has_weight_zeros: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w4sz, - linear_op=torch.ops.torchao._linear_a8sz_w4sz, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - else: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w4s, - linear_op=torch.ops.torchao._linear_a8sz_w4s, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - elif nbit == 5: - if has_weight_zeros: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w5sz, - linear_op=torch.ops.torchao._linear_a8sz_w5sz, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - else: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w5s, - linear_op=torch.ops.torchao._linear_a8sz_w5s, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - else: - raise ValueError( - f"Unsupported nbit ({nbit}) and has_weight_zeros ({has_weight_zeros}) combination" - ) - - assert qlinear is not None - setattr(module, name, qlinear) - getattr(module, name).pack_weights( - weight_qvals, - weight_scales_and_zeros, - group_size, - ) - else: - replace_linear_with_quantized_linear(child, kwargs) diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py new file mode 100644 index 0000000000..26797bdb1c --- /dev/null +++ b/torchao/experimental/quant_api.py @@ -0,0 +1,321 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Optional + +import torch +import torch.nn as nn +from torch.ao.quantization.fx._decomposed import ( + dequantize_per_channel_group, + quantize_per_channel_group, +) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +import sys + +handler = logging.StreamHandler(sys.stdout) +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + + +def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros: bool): + assert nbit >= 1 and nbit <= 8 + qmin = -(1 << (nbit - 1)) + qmax = (1 << (nbit - 1)) - 1 + + n, k = vals.shape + vals = vals.reshape(-1, group_size) + vmins, _ = torch.min(vals, axis=1) + vmaxs, _ = torch.max(vals, axis=1) + group_scales = (vmaxs - vmins) / (qmax - qmin) + + if not has_weight_zeros: + group_zeros = torch.zeros_like(group_scales) + else: + group_zeros = qmin - torch.round(vmins / group_scales) + + vals = vals.reshape(n, k) + group_scales = group_scales.reshape(n, -1) + group_zeros = group_zeros.reshape(n, -1) + + group_qvals = quantize_per_channel_group( + input=vals, + scales=group_scales, + zero_points=group_zeros, + quant_min=qmin, + quant_max=qmax, + dtype=torch.int8, + group_size=group_size, + ) + + if not has_weight_zeros: + group_zeros = None + + return group_qvals, group_scales, group_zeros + + +class _Int8DynActIntxWeightQuantizedLinearNative(nn.Module): + def __init__( + self, + pack_weight_op, + linear_op, + ): + super().__init__() + self._pack_weights_op = pack_weight_op + self._linear_op = linear_op + + def quantize_and_pack_weights(self, weights, nbit, group_size, has_weight_zeros): + self.nbit = nbit + self.group_size = group_size + self.has_weight_zeros = has_weight_zeros + + n, k = weights.shape + + # TODO(T200095131): convert self.n, self.k, self.group_size to + # int when supported by AOTI + self._n = torch.empty(n, dtype=torch.int8) + self._k = torch.empty(k, dtype=torch.int8) + self._group_size = torch.empty(self.group_size, dtype=torch.int8) + + weight_qvals, weight_scales, weight_zeros = _quantize( + weights, self.group_size, self.nbit, self.has_weight_zeros + ) + if self.has_weight_zeros: + self.packed_weights = self._pack_weights_op( + weight_qvals, + weight_scales.reshape(-1), + weight_zeros.to(torch.int8).reshape(-1), + self._group_size, + ) + else: + self.packed_weights = self._pack_weights_op( + weight_qvals, weight_scales.reshape(-1), self._group_size + ) + + def forward(self, x): + assert x.dim() >= 2 + if x.dim() == 2: + return self._linear_op( + self.packed_weights, self._n, self._k, self._group_size, x + ) + + assert x.dim() >= 3 + lead_shape = x.shape[0:-2] + m, k = x.shape[-2], x.shape[-1] + n = self._n.shape[0] + x = x.reshape(-1, m, k) + + res = [ + self._linear_op( + self.packed_weights, self._n, self._k, self._group_size, x[i, :, :] + ) + for i in range(x.shape[0]) + ] + res = torch.stack(res) + res = res.reshape(*lead_shape, m, n) + return res + + +# Python-based reference implementation of Int8DynActLowbitWeightQuantizedLinear +# It is arithmetically equivalent to Int8DynActLowbitWeightQuantizedLinear +# This is used to test Int8DynActLowbitWeightQuantizedLinear, and as a fallback when +# Int8DynActLowbitWeightQuantizedLinear is not available +class _Int8DynActIntxWeightQuantizedLinearFallback(nn.Module): + def __init__(self): + super().__init__() + + def quantize_and_pack_weights(self, weights, nbit, group_size, has_weight_zeros): + self.nbit = nbit + self.group_size = group_size + self.has_weight_zeros = has_weight_zeros + + self._n, self._k = weights.shape + assert self._k % group_size == 0, "group_size must divide k" + + self.weight_qvals, self.weight_scales, self.weight_zeros = _quantize( + weights, self.group_size, self.nbit, self.has_weight_zeros + ) + + def _forward_2d(self, x): + assert x.dim() == 2 + + n, k = self._n, self._k + m, k_ = x.shape + assert k_ == k + + weights_dequantized = dequantize_per_channel_group( + w_int8=self.weight_qvals, + scales=self.weight_scales, + zero_points=( + self.weight_zeros + if self.has_weight_zeros + else torch.zeros_like(self.weight_scales) + ), + quant_min=None, # TODO: why is this an arg for this function + quant_max=None, # TODO: why is this an arg for this function + dtype=None, # TODO: why is this an arg for this function + group_size=self.group_size, + output_dtype=torch.float32, + ) + + activation_qvals, activation_scales, activation_zeros = _quantize( + x, group_size=k, nbit=8, has_weight_zeros=True + ) + activations_dequantized = dequantize_per_channel_group( + w_int8=activation_qvals, + scales=activation_scales, + zero_points=activation_zeros, + quant_min=None, # TODO: why is this an arg for this function + quant_max=None, # TODO: why is this an arg for this function + dtype=None, # TODO: why is this an arg for this function + group_size=k, + output_dtype=torch.float32, + ) + + res = torch.matmul(activations_dequantized, weights_dequantized.transpose(1, 0)) + return res + + def forward(self, x): + assert x.dim() >= 2 + if x.dim() == 2: + return self._forward_2d(x) + + assert x.dim() >= 3 + lead_shape = x.shape[0:-2] + m, k = x.shape[-2], x.shape[-1] + n = self._n + x = x.reshape(-1, m, k) + + res = [self._forward_2d(x[i, :, :]) for i in range(x.shape[0])] + res = torch.stack(res) + res = res.reshape(*lead_shape, m, n) + return res + + +def _maybe_get_quantized_linear_native(nbit, has_weight_zeros): + try: + if nbit in [2, 3, 4, 5]: + wzp_suffix = "z" if has_weight_zeros else "" + return _Int8DynActIntxWeightQuantizedLinearNative( + pack_weight_op=getattr( + torch.ops.torchao, f"_pack_weights_a8sz_w{nbit}s{wzp_suffix}" + ), + linear_op=getattr( + torch.ops.torchao, f"_linear_a8sz_w{nbit}s{wzp_suffix}" + ), + ) + else: + logger.warning( + f"_Int8DynActIntxWeightQuantizedLinearNative does not support: nbit={nbit}, has_weight_zeros={has_weight_zeros}." + ) + except Exception as e: + logger.warning( + f"_Int8DynActIntxWeightQuantizedLinearNative raised an exception during initialization: {e}" + ) + + logger.warning( + "Falling back to **slow** implementation _Int8DynActIntxWeightQuantizedLinearFallback." + ) + return _Int8DynActIntxWeightQuantizedLinearFallback() + + +def _replace_linear_with_quantized_linear(module: nn.Module, kwargs={}): + group_size = kwargs["group_size"] + nbit = kwargs["nbit"] + has_weight_zeros = kwargs["has_weight_zeros"] + + assert not isinstance(module, nn.Linear) + assert nbit >= 1 and nbit <= 7 + + for name, child in module.named_children(): + if not isinstance(child, nn.Linear): + _replace_linear_with_quantized_linear(child, kwargs) + else: + assert child.bias is None + qlinear = _maybe_get_quantized_linear_native( + nbit=nbit, has_weight_zeros=has_weight_zeros + ) + try: + # The packing function may raise some error from the C++ layer (e.g., if group_size is unsupported) + # so calling quantize_and_pack_weights can fail. In this case, we still switch to fallback + # implementation + setattr(module, name, qlinear) + getattr(module, name).quantize_and_pack_weights( + child.weight, nbit, group_size, has_weight_zeros + ) + except Exception as e: + if not isinstance(qlinear, _Int8DynActIntxWeightQuantizedLinearNative): + raise e + logger.warning( + "_Int8DynActIntxWeightQuantizedLinearNative raised an exception during quantize_and_pack_weights: {e}\n" + + "Falling back to **slow** implementation _Int8DynActIntxWeightQuantizedLinearFallback." + ) + qlinear = _Int8DynActIntxWeightQuantizedLinearFallback() + setattr(module, name, qlinear) + getattr(module, name).quantize_and_pack_weights( + child.weight, nbit, group_size, has_weight_zeros + ) + + +class Int8DynActIntxWeightQuantizer: + def __init__( + self, + device, + precision, + *, + bitwidth: Optional[int] = None, + groupsize: Optional[int] = None, + has_weight_zeros: Optional[bool] = None, + ): + if device != "cpu": + raise NotImplementedError( + "Only device=cpu is currently supported in Int8DynActLowbitWeightQuantizer" + ) + else: + self.device = device + + if precision != torch.float32: + raise NotImplementedError( + "Only precision=torch.float32 is currently supported in Int8DynActLowbitWeightQuantizer" + ) + else: + self.precision = precision + + if bitwidth is None: + self.bitwidth = 4 + logger.warning(f"bitwidth not specified, defaulting to {self.bitwidth}.") + else: + self.bitwidth = bitwidth + + if groupsize is None: + self.groupsize = 128 + logger.warning(f"groupsize not specified, defaulting to {self.groupsize}.") + else: + self.groupsize = groupsize + + if has_weight_zeros is None: + self.has_weight_zeros = False + logger.warning( + f"has_weight_zeros not specified, defaulting to {self.has_weight_zeros}." + ) + else: + self.has_weight_zeros = has_weight_zeros + + def quantize(self, model: nn.Module) -> nn.Module: + model = model.to(self.device).to(self.precision) + _replace_linear_with_quantized_linear( + model, + kwargs={ + "group_size": self.groupsize, + "nbit": self.bitwidth, + "has_weight_zeros": self.has_weight_zeros, + }, + ) + return model