Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test util for basic tensor subclass functionalities #839

Merged
merged 3 commits into from
Sep 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 161 additions & 0 deletions torchao/testing/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import unittest
import functools
import copy
import torch
import torchao

from torch.testing._internal import common_utils
from torchao.dtypes import AffineQuantizedTensor
from torchao.dtypes import to_affine_quantized_intx
from torchao.quantization.quant_primitives import MappingType

"""
How to use:

import unittest
from torchao.testing.utils import TorchAOBasicTestCase, copy_tests
from torch.testing._internal import common_utils

# TODO: currently there is no way to set COMMON_DEVICES/COMMON_DTYPES
# we can figure out this a bit later

# change arguments
class MyTestCase(TorchAOBasicTestCase):
TENSOR_SUBCLASS = MyDTypeTensor
FACTOR_FN = to_my_dtype
kwargs = {"target_dtype": torch.uint8}
LINEAR_MIN_SQNR = 30

# copy the instantiated tests
copy_tests(TorchAOBasicTestCase, MyTestCase, "my_test_case")

if __name__ == "__main__":
unittest.main()
"""

# copied from https://github.com/pytorch/pytorch/blob/941d094dd1b507dacf06ddc6ed3485a9537e09b7/test/inductor/test_torchinductor.py#L11389
def copy_tests(
my_cls, other_cls, suffix, test_failures=None, xfail_prop=None
): # noqa: B902
for name, value in my_cls.__dict__.items():
if name.startswith("test_"):
# You cannot copy functions in Python, so we use closures here to
# create objects with different ids. Otherwise, unittest.skip
# would modify all methods sharing the same object id. Also, by
# using a default argument, we create a copy instead of a
# reference. Otherwise, we would lose access to the value.

@functools.wraps(value)
def new_test(self, value=value):
return value(self)

# Copy __dict__ which may contain test metadata
new_test.__dict__ = copy.deepcopy(value.__dict__)

if xfail_prop is not None and hasattr(value, xfail_prop):
new_test = unittest.expectedFailure(new_test)

tf = test_failures and test_failures.get(name)
if tf is not None and suffix in tf.suffixes:
skip_func = (
unittest.skip("Skipped!")
if tf.is_skip
else unittest.expectedFailure
)
new_test = skip_func(new_test)

setattr(other_cls, f"{name}_{suffix}", new_test)



class TorchAOBasicTestCase(common_utils.TestCase):
    """Basic test case for tensor subclasses.

    Subclass this and override the class attributes below to get a standard
    battery of tests (flatten/unflatten round-trip, device/dtype movement,
    transpose, linear, compiled linear) for a new quantized tensor subclass.
    Use ``copy_tests`` (see module docstring) to install the tests on the
    subclass under distinct names.
    """
    # "cuda" is only exercised when a GPU is present on the test machine.
    COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
    COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

    # The tensor subclass under test.
    TENSOR_SUBCLASS = AffineQuantizedTensor
    # Factory that converts a high-precision tensor into TENSOR_SUBCLASS;
    # ``kwargs`` below are forwarded to it on every call.
    FACTORY_FN = to_affine_quantized_intx
    kwargs = {
        "mapping_type": MappingType.ASYMMETRIC,
        "block_size": (1, 32),
        "target_dtype": torch.uint8,
    }
    # minimum sqnr for linear operation when the weight is quantized to low precision
    # with the above setting
    LINEAR_MIN_SQNR = 40

def test_flatten_unflatten(self):
hp_tensor = torch.randn(4, 128)
lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict}
outer_size = lp_tensor.size()
outer_stride = lp_tensor.stride()
reconstructed = self.TENSOR_SUBCLASS.__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one kind of annoying thing that makes __tensor_flatten__ and friends difficult to test in isolation is that we want their implementations to be side-effect-free / idempotent. This is kind of difficult to test by calling the fn directly, although usually you will find out if your implementation is wrong pretty quickly because usages of compile with your subclass will break.

This test suite seems like a good start (if the goal is "a test suite for all AO-specific quantization subclasses that conform to a specific API"). Do you want / plan to add more compile-related tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, yeah I can add more compile related tests, although I'm not exactly sure what are all the scenarios I should be testing, I can also start with simple cases like just compile the quantized weights and run on linear op as well, but I'd like to hear what's your thoughts on rough categories of compile related tests we can add as well

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few that come to mind are:

(1) simple test where inputs + outputs of the function f are your subclass and you do some basic compute (return inp + inp)
(2) subclass constructor in the graph (e.g. inputs to f are plain tensors, output is a subclass)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sg, I can add these in a separate PR

self.assertEqual(lp_tensor.dequantize(), reconstructed.dequantize())

@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_hp_tensor_device_dtype(self, device, dtype):
hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)

@common_utils.parametrize("device1", COMMON_DEVICES)
@common_utils.parametrize("device2", COMMON_DEVICES)
def test_device1_to_device2(self, device1, device2):
"""Note: this should be parametrized with device1 and device2
e.g. device1 = ["cpu", "cuda"], device2 = ["cpu", "cuda"]
"""
hp_tensor = torch.randn(4, 128, device=device1, dtype=torch.bfloat16)
lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
lp_tensor.to(device=device2)

hp_tensor = torch.randn(4, 128, device=device1, dtype=torch.bfloat16)
lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
lp_tensor.to(device2)

hp_tensor = torch.randn(4, 128, device=device1, dtype=torch.bfloat16)
lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
lp_tensor.cuda()

hp_tensor = torch.randn(4, 128, device=device1, dtype=torch.bfloat16)
lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
lp_tensor.cpu()

@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_transpose(self, device, dtype):
hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
lp_tensor = lp_tensor.t()
self.assertEqual(lp_tensor.shape, (128, 4))

@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_linear(self, device, dtype):
hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)

hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype)
hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor)
lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor)
self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)

@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_linear_compile(self, device, dtype):
hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)

hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype)
hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor)
l = torch.nn.Linear(128, 4, bias=False, device=device, dtype=dtype)
l.weight = torch.nn.Parameter(lp_tensor)
lp_res = torch.compile(l)(hp_act_tensor)
self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)

# Expand the @parametrize decorators into concrete test methods so the base
# test case is runnable directly (subclasses re-expand via copy_tests).
common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)

if __name__ == "__main__":
    unittest.main()
Loading