Add test util for basic tensor subclass functionalities #839
Merged
@@ -0,0 +1,161 @@
import unittest
import functools
import copy
import torch
import torchao

from torch.testing._internal import common_utils
from torchao.dtypes import AffineQuantizedTensor
from torchao.dtypes import to_affine_quantized_intx
from torchao.quantization.quant_primitives import MappingType

"""
How to use:

import unittest
from torchao.testing.utils import TorchAOBasicTestCase, copy_tests
from torch.testing._internal import common_utils

# TODO: currently there is no way to set COMMON_DEVICES/COMMON_DTYPES
# we can figure out this a bit later

# change arguments
class MyTestCase(TorchAOBasicTestCase):
    TENSOR_SUBCLASS = MyDTypeTensor
    FACTORY_FN = to_my_dtype
    kwargs = {"target_dtype": torch.uint8}
    LINEAR_MIN_SQNR = 30

# copy the instantiated tests
copy_tests(TorchAOBasicTestCase, MyTestCase, "my_test_case")

if __name__ == "__main__":
    unittest.main()
"""

# copied from https://github.com/pytorch/pytorch/blob/941d094dd1b507dacf06ddc6ed3485a9537e09b7/test/inductor/test_torchinductor.py#L11389
def copy_tests(
    my_cls, other_cls, suffix, test_failures=None, xfail_prop=None
):  # noqa: B902
    for name, value in my_cls.__dict__.items():
        if name.startswith("test_"):
            # You cannot copy functions in Python, so we use closures here to
            # create objects with different ids. Otherwise, unittest.skip
            # would modify all methods sharing the same object id. Also, by
            # using a default argument, we create a copy instead of a
            # reference. Otherwise, we would lose access to the value.

            @functools.wraps(value)
            def new_test(self, value=value):
                return value(self)

            # Copy __dict__ which may contain test metadata
            new_test.__dict__ = copy.deepcopy(value.__dict__)

            if xfail_prop is not None and hasattr(value, xfail_prop):
                new_test = unittest.expectedFailure(new_test)

            tf = test_failures and test_failures.get(name)
            if tf is not None and suffix in tf.suffixes:
                skip_func = (
                    unittest.skip("Skipped!")
                    if tf.is_skip
                    else unittest.expectedFailure
                )
                new_test = skip_func(new_test)

            setattr(other_cls, f"{name}_{suffix}", new_test)


class TorchAOBasicTestCase(common_utils.TestCase):
    """Basic test case for tensor subclasses
    """
    COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
    COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

    TENSOR_SUBCLASS = AffineQuantizedTensor
    FACTORY_FN = to_affine_quantized_intx
    kwargs = {
        "mapping_type": MappingType.ASYMMETRIC,
        "block_size": (1, 32),
        "target_dtype": torch.uint8,
    }
    # minimum SQNR for the linear operation when the weight is quantized to
    # low precision with the above settings
    LINEAR_MIN_SQNR = 40

    def test_flatten_unflatten(self):
        hp_tensor = torch.randn(4, 128)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
        tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
        tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict}
        outer_size = lp_tensor.size()
        outer_stride = lp_tensor.stride()
        reconstructed = self.TENSOR_SUBCLASS.__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride)
        self.assertEqual(lp_tensor.dequantize(), reconstructed.dequantize())

    @common_utils.parametrize("device", COMMON_DEVICES)
    @common_utils.parametrize("dtype", COMMON_DTYPES)
    def test_hp_tensor_device_dtype(self, device, dtype):
        # smoke test: construction should succeed for each device / dtype combination
        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)

    @common_utils.parametrize("device1", COMMON_DEVICES)
    @common_utils.parametrize("device2", COMMON_DEVICES)
    def test_device1_to_device2(self, device1, device2):
        """Note: this should be parametrized with device1 and device2
        e.g. device1 = ["cpu", "cuda"], device2 = ["cpu", "cuda"]
        """
        hp_tensor = torch.randn(4, 128, device=device1, dtype=torch.bfloat16)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
        lp_tensor.to(device=device2)

        hp_tensor = torch.randn(4, 128, device=device1, dtype=torch.bfloat16)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
        lp_tensor.to(device2)

        # only exercise .cuda() / device moves to CUDA when CUDA is available
        if torch.cuda.is_available():
            hp_tensor = torch.randn(4, 128, device=device1, dtype=torch.bfloat16)
            lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
            lp_tensor.cuda()

        hp_tensor = torch.randn(4, 128, device=device1, dtype=torch.bfloat16)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
        lp_tensor.cpu()

    @common_utils.parametrize("device", COMMON_DEVICES)
    @common_utils.parametrize("dtype", COMMON_DTYPES)
    def test_transpose(self, device, dtype):
        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
        lp_tensor = lp_tensor.t()
        self.assertEqual(lp_tensor.shape, (128, 4))

    @common_utils.parametrize("device", COMMON_DEVICES)
    @common_utils.parametrize("dtype", COMMON_DTYPES)
    def test_linear(self, device, dtype):
        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)

        hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype)
        hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor)
        lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor)
        self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)

    @common_utils.parametrize("device", COMMON_DEVICES)
    @common_utils.parametrize("dtype", COMMON_DTYPES)
    def test_linear_compile(self, device, dtype):
        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)

        hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype)
        hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor)
        l = torch.nn.Linear(128, 4, bias=False, device=device, dtype=dtype)
        l.weight = torch.nn.Parameter(lp_tensor)
        lp_res = torch.compile(l)(hp_act_tensor)
        self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)


common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)

if __name__ == "__main__":
    unittest.main()
Review comment:
One kind of annoying thing that makes __tensor_flatten__ and friends difficult to test in isolation is that we want their implementations to be side-effect-free / idempotent. This is hard to check by calling the function directly, although you will usually find out quickly if your implementation is wrong, because usages of compile with your subclass will break.

This test suite seems like a good start (if the goal is "a test suite for all AO-specific quantization subclasses that conform to a specific API"). Do you want / plan to add more compile-related tests?
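A minimal sketch of what a direct check along these lines could look like, assuming an AffineQuantizedTensor created via to_affine_quantized_intx with the same defaults used in this test case; the helper name and exact assertions are illustrative only, not part of this PR:

import torch
from torchao.dtypes import to_affine_quantized_intx
from torchao.quantization.quant_primitives import MappingType

def check_tensor_flatten_idempotent(lp_tensor):
    # calling __tensor_flatten__ twice should return the same tensor-data
    # names and the same metadata; a mismatch indicates the first call
    # mutated state it should not have touched
    names1, attrs1 = lp_tensor.__tensor_flatten__()
    names2, attrs2 = lp_tensor.__tensor_flatten__()
    assert names1 == names2
    assert attrs1 == attrs2

lp_tensor = to_affine_quantized_intx(
    torch.randn(4, 128),
    mapping_type=MappingType.ASYMMETRIC,
    block_size=(1, 32),
    target_dtype=torch.uint8,
)
check_tensor_flatten_idempotent(lp_tensor)

This only catches the most obvious violations; as noted above, a wrong implementation usually surfaces once the subclass is actually used with compile.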
Reply:
I see, yeah I can add more compile-related tests, although I'm not exactly sure what all the scenarios are that I should be testing. I can also start with simple cases, like just compiling the quantized weights and running the linear op, but I'd like to hear your thoughts on rough categories of compile-related tests we could add as well.
Review comment:
A few that come to mind are:
(1) a simple test where the inputs and outputs of the function f are your subclass and you do some basic compute (return inp + inp)
(2) subclass constructor in the graph (e.g. the inputs to f are plain tensors and the output is a subclass)
Reply:
sg, I can add these in a separate PR
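A minimal sketch of what the two suggested tests could look like, assuming the same AffineQuantizedTensor defaults as TorchAOBasicTestCase above. This is illustrative, not code from this PR: whether (1) passes depends on the subclass actually supporting elementwise add, and fullgraph=True is only used so that graph breaks surface as errors rather than silent fallbacks.

import torch
from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx
from torchao.quantization.quant_primitives import MappingType

kwargs = {
    "mapping_type": MappingType.ASYMMETRIC,
    "block_size": (1, 32),
    "target_dtype": torch.uint8,
}

# (1) subclass in, subclass out: compile a function whose input and output
# are the subclass and do some basic compute (assumes add is implemented
# for the subclass)
def f(inp):
    return inp + inp

lp_tensor = to_affine_quantized_intx(torch.randn(4, 128), **kwargs)
out = torch.compile(f, fullgraph=True)(lp_tensor)
assert isinstance(out, AffineQuantizedTensor)

# (2) subclass constructor in the graph: the compiled function takes a plain
# tensor and returns the subclass, so construction is traced into the graph
def g(hp_tensor):
    return to_affine_quantized_intx(hp_tensor, **kwargs)

out = torch.compile(g, fullgraph=True)(torch.randn(4, 128))
assert isinstance(out, AffineQuantizedTensor)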