Skip to content

Commit 7d3ceb7

Browse files
committed
add result check
1 parent a180182 commit 7d3ceb7

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

torchao/testing/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ class TorchAOCompileTestCase(common_utils.TestCase):
155155
# minimum sqnr for linear operation when the weight is quantized to low precision
156156
# with the above setting
157157
LINEAR_MIN_SQNR = 40
158+
COMPILE_MIN_SQNR = 50
158159

159160
@common_utils.parametrize("device", COMMON_DEVICES)
160161
@common_utils.parametrize("dtype", COMMON_DTYPES)
@@ -164,8 +165,11 @@ def test_input_output_tensor_subclass(self, device, dtype):
164165
def f(tensor):
165166
return tensor
166167

168+
ref = f(lp_tensor)
167169
f = torch.compile(f)
170+
compiled = f(lp_tensor)
168171
self.assertTrue(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS))
172+
self.assertEqual(ref.dequantize(), compiled.dequantize())
169173

170174
@common_utils.parametrize("device", COMMON_DEVICES)
171175
@common_utils.parametrize("dtype", COMMON_DTYPES)
@@ -175,8 +179,11 @@ def test_input_tensor_subclass(self, device, dtype):
175179
def f(tensor):
176180
return tensor.dequantize()
177181

182+
ref = f(lp_tensor)
178183
f = torch.compile(f)
184+
compiled = f(lp_tensor)
179185
self.assertFalse(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS))
186+
self.assertEqual(ref, compiled)
180187

181188
@common_utils.parametrize("device", COMMON_DEVICES)
182189
@common_utils.parametrize("dtype", COMMON_DTYPES)
@@ -185,8 +192,13 @@ def test_output_tensor_subclass(self, device, dtype):
185192
def f(hp_tensor):
186193
return self.FACTORY_FN(hp_tensor, **self.kwargs)
187194

195+
ref = f(hp_tensor)
188196
f = torch.compile(f)
197+
compiled = f(hp_tensor)
189198
self.assertTrue(isinstance(f(hp_tensor), self.TENSOR_SUBCLASS))
199+
# bfloat16 seems to result in much larger numerical differences
200+
if dtype != torch.bfloat16:
201+
self.assertGreater(torchao.quantization.utils.compute_error(ref.dequantize(), compiled.dequantize()), self.COMPILE_MIN_SQNR)
190202

191203
@common_utils.parametrize("device", COMMON_DEVICES)
192204
@common_utils.parametrize("dtype", COMMON_DTYPES)

0 commit comments

Comments
 (0)