From 52d37bfb30962688ec21ffc7d41b5bc6f5b90e82 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 6 Sep 2024 16:23:32 -0700 Subject: [PATCH 1/3] Add test util for basic tensor subclass functionalities Summary: This is a small starting point for testing low precision tensor subclass functionalities we can add more test cases for training, tensor parallel, FSDP in the future right now it tests: - tensor flatten/unflatten - constructing low precision tensor with different device/dtype - move tensor subclass from device1 to device2 - transpose works - linear works (weight only quantization with the low precision tensor) It can be extended with new tensor subclasses or test cases by overriding the class variables: e.g. ``` class MyTensorSubclassTest(TorchAOBasicTestCase): COMMON_DEVICES = ["cpu", "cuda"] COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] TENSOR_SUBCLASS = LUTQuantizedTensor FACTORY_FN = to_lut_quantized_intx kwargs = { "target_dtype": torch.uint8, } # minimum sqnr for linear operation when the weight is quantized to low precision # with the above setting LINEAR_MIN_SQNR = 40 ``` Test Plan: python test/utils.py Reviewers: Subscribers: Tasks: Tags: --- torchao/testing/utils.py | 148 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 torchao/testing/utils.py diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py new file mode 100644 index 0000000000..99e751af99 --- /dev/null +++ b/torchao/testing/utils.py @@ -0,0 +1,148 @@ +import unittest +import functools +import copy +import torch +import torchao + +from torch._inductor.test_case import TestCase +from torch.testing._internal import common_utils +from torchao.dtypes import AffineQuantizedTensor +from torchao.dtypes import to_affine_quantized_intx +from torchao.quantization.quant_primitives import MappingType + +""" +How to use: + +import unittest +from torchao.testing.utils import TorchAOBasicTestCase, copy_tests +from torch.testing._internal import common_utils + +# TODO: currently there is no way to set COMMON_DEVICES/COMMON_DTYPES +# we can figure out this a bit later + +# change arguments +class MyTestCase(TorchAOBasicTestCase): + FACTOR_FN = to_my_dtype + kwargs = {...} + LINEAR_MIN_SQNR = 30 + +# copy the instantiated tests +copy_tests(TorchAOBasicTestCase, MyTestCase, "my_test_case") + +if __name__ == "__main__": + unittest.main() +""" + +# copied from https://github.com/pytorch/pytorch/blob/941d094dd1b507dacf06ddc6ed3485a9537e09b7/test/inductor/test_torchinductor.py#L11389 +def copy_tests( + my_cls, other_cls, suffix, test_failures=None, xfail_prop=None +): # noqa: B902 + for name, value in my_cls.__dict__.items(): + if name.startswith("test_"): + # You cannot copy functions in Python, so we use closures here to + # create objects with different ids. Otherwise, unittest.skip + # would modify all methods sharing the same object id. Also, by + # using a default argument, we create a copy instead of a + # reference. Otherwise, we would lose access to the value. + + @functools.wraps(value) + def new_test(self, value=value): + return value(self) + + # Copy __dict__ which may contain test metadata + new_test.__dict__ = copy.deepcopy(value.__dict__) + + if xfail_prop is not None and hasattr(value, xfail_prop): + new_test = unittest.expectedFailure(new_test) + + tf = test_failures and test_failures.get(name) + if tf is not None and suffix in tf.suffixes: + skip_func = ( + unittest.skip("Skipped!") + if tf.is_skip + else unittest.expectedFailure + ) + new_test = skip_func(new_test) + + setattr(other_cls, f"{name}_{suffix}", new_test) + + + +class TorchAOBasicTestCase(TestCase): + """Basic test case for tensor subclasses + """ + COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] + + TENSOR_SUBCLASS = AffineQuantizedTensor + FACTORY_FN = to_affine_quantized_intx + kwargs = { + "mapping_type": MappingType.ASYMMETRIC, + "block_size": (1, 32), + "target_dtype": torch.uint8, + } + # minimum sqnr for linear operation when the weight is quantized to low precision + # with the above setting + LINEAR_MIN_SQNR = 40 + + def test_flatten_unflatten(self): + hp_tensor = torch.randn(4, 128) + lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__() + tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict} + outer_size = lp_tensor.size() + outer_stride = lp_tensor.stride() + reconstructed = self.TENSOR_SUBCLASS.__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride) + self.assertEqual(lp_tensor.dequantize(), reconstructed.dequantize()) + + @common_utils.parametrize("device", COMMON_DEVICES) + @common_utils.parametrize("dtype", COMMON_DTYPES) + def test_hp_tensor_device_dtype(self, device, dtype): + hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) + lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + + @common_utils.parametrize("device1", COMMON_DEVICES) + @common_utils.parametrize("device2", COMMON_DEVICES) + def test_device1_to_device2(self, device1, device2): + """Note: this should be parametrized with device1 and device2 + e.g. device1 = ["cpu", "cuda"], device2 = ["cpu", "cuda"] + """ + hp_tensor = torch.randn(4, 128, device=device1, dtype=torch.bfloat16) + lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + lp_tensor.to(device=device2) + + hp_tensor = torch.randn(4, 128, device=device1, dtype=torch.bfloat16) + lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + lp_tensor.to(device2) + + hp_tensor = torch.randn(4, 128, device=device1, dtype=torch.bfloat16) + lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + lp_tensor.cuda() + + hp_tensor = torch.randn(4, 128, device=device1, dtype=torch.bfloat16) + lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + lp_tensor.cpu() + + @common_utils.parametrize("device", COMMON_DEVICES) + @common_utils.parametrize("dtype", COMMON_DTYPES) + def test_transpose(self, device, dtype): + hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) + lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + lp_tensor = lp_tensor.t() + self.assertEqual(lp_tensor.shape, (128, 4)) + + @common_utils.parametrize("device", COMMON_DEVICES) + @common_utils.parametrize("dtype", COMMON_DTYPES) + def test_linear(self, device, dtype): + hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) + lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + + hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype) + hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor) + lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor) + self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR) + +common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase) + +if __name__ == "__main__": + unittest.main() From a9907bbcbd3a37a7b02e241535e34cc317f4b76a Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 6 Sep 2024 17:55:45 -0700 Subject: [PATCH 2/3] minor fix --- torchao/testing/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 99e751af99..5f6db95eb4 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -22,8 +22,9 @@ # change arguments class MyTestCase(TorchAOBasicTestCase): + TENSOR_SUBCLASS = MyDTypeTensor FACTOR_FN = to_my_dtype - kwargs = {...} + kwargs = {"target_dtype": torch.uint8} LINEAR_MIN_SQNR = 30 # copy the instantiated tests From 75d9c909b5eca55764d01843853b424ab401a66e Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 9 Sep 2024 15:20:59 -0700 Subject: [PATCH 3/3] don't use inductor TestCase --- torchao/testing/utils.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 5f6db95eb4..a6c5bf7e0a 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -4,7 +4,6 @@ import torch import torchao -from torch._inductor.test_case import TestCase from torch.testing._internal import common_utils from torchao.dtypes import AffineQuantizedTensor from torchao.dtypes import to_affine_quantized_intx @@ -69,7 +68,7 @@ def new_test(self, value=value): -class TorchAOBasicTestCase(TestCase): +class TorchAOBasicTestCase(common_utils.TestCase): """Basic test case for tensor subclasses """ COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) @@ -143,6 +142,19 @@ def test_linear(self, device, dtype): lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor) self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR) + @common_utils.parametrize("device", COMMON_DEVICES) + @common_utils.parametrize("dtype", COMMON_DTYPES) + def test_linear_compile(self, device, dtype): + hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) + lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + + hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype) + hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor) + l = torch.nn.Linear(128, 4, bias=False, device=device, dtype=dtype) + l.weight = torch.nn.Parameter(lp_tensor) + lp_res = torch.compile(l)(hp_act_tensor) + self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR) + common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase) if __name__ == "__main__":