Skip to content

Commit

Permalink
rename
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryzh168 committed Sep 19, 2024
1 parent 72c4a43 commit a180182
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions torchao/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,18 +158,29 @@ class TorchAOCompileTestCase(common_utils.TestCase):

@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_input_output(self, device, dtype):
def test_input_output_tensor_subclass(self, device, dtype):
hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
def f(tensor):
return tensor.t()
return tensor

f = torch.compile(f)
self.assertTrue(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS))

@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_input_output(self, device, dtype):
def test_input_tensor_subclass(self, device, dtype):
hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
def f(tensor):
return tensor.dequantize()

f = torch.compile(f)
self.assertFalse(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS))

@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_output_tensor_subclass(self, device, dtype):
hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
def f(hp_tensor):
return self.FACTORY_FN(hp_tensor, **self.kwargs)
Expand Down

0 comments on commit a180182

Please sign in to comment.