diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 19e99c30f9..6165b1214c 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -158,18 +158,29 @@ class TorchAOCompileTestCase(common_utils.TestCase): @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) - def test_input_output(self, device, dtype): + def test_input_output_tensor_subclass(self, device, dtype): hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) def f(tensor): - return tensor.t() + return tensor f = torch.compile(f) self.assertTrue(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS)) @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) - def test_input_output(self, device, dtype): + def test_input_tensor_subclass(self, device, dtype): + hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) + lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + def f(tensor): + return tensor.dequantize() + + f = torch.compile(f) + self.assertFalse(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS)) + + @common_utils.parametrize("device", COMMON_DEVICES) + @common_utils.parametrize("dtype", COMMON_DTYPES) + def test_output_tensor_subclass(self, device, dtype): hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) def f(hp_tensor): return self.FACTORY_FN(hp_tensor, **self.kwargs)