Skip to content

Commit

Permalink
add result check
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryzh168 committed Sep 26, 2024
1 parent a180182 commit 7d3ceb7
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions torchao/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ class TorchAOCompileTestCase(common_utils.TestCase):
# minimum sqnr for linear operation when the weight is quantized to low precision
# with the above setting
LINEAR_MIN_SQNR = 40
COMPILE_MIN_SQNR = 50

@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
Expand All @@ -164,8 +165,11 @@ def test_input_output_tensor_subclass(self, device, dtype):
def f(tensor):
return tensor

ref = f(lp_tensor)
f = torch.compile(f)
compiled = f(lp_tensor)
self.assertTrue(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS))
self.assertEqual(ref.dequantize(), compiled.dequantize())

@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
Expand All @@ -175,8 +179,11 @@ def test_input_tensor_subclass(self, device, dtype):
def f(tensor):
return tensor.dequantize()

ref = f(lp_tensor)
f = torch.compile(f)
compiled = f(lp_tensor)
self.assertFalse(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS))
self.assertEqual(ref, compiled)

@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
Expand All @@ -185,8 +192,13 @@ def test_output_tensor_subclass(self, device, dtype):
def f(hp_tensor):
return self.FACTORY_FN(hp_tensor, **self.kwargs)

ref = f(hp_tensor)
f = torch.compile(f)
compiled = f(hp_tensor)
self.assertTrue(isinstance(f(hp_tensor), self.TENSOR_SUBCLASS))
# bfloat16 seems to result in much larger numerical differences
if dtype != torch.bfloat16:
self.assertGreater(torchao.quantization.utils.compute_error(ref.dequantize(), compiled.dequantize()), self.COMPILE_MIN_SQNR)

@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
Expand Down

0 comments on commit 7d3ceb7

Please sign in to comment.