@@ -155,6 +155,7 @@ class TorchAOCompileTestCase(common_utils.TestCase):
155
155
# minimum sqnr for linear operation when the weight is quantized to low precision
156
156
# with the above setting
157
157
LINEAR_MIN_SQNR = 40
158
+ COMPILE_MIN_SQNR = 50
158
159
159
160
@common_utils .parametrize ("device" , COMMON_DEVICES )
160
161
@common_utils .parametrize ("dtype" , COMMON_DTYPES )
@@ -164,8 +165,11 @@ def test_input_output_tensor_subclass(self, device, dtype):
164
165
def f (tensor ):
165
166
return tensor
166
167
168
+ ref = f (lp_tensor )
167
169
f = torch .compile (f )
170
+ compiled = f (lp_tensor )
168
171
self .assertTrue (isinstance (f (lp_tensor ), self .TENSOR_SUBCLASS ))
172
+ self .assertEqual (ref .dequantize (), compiled .dequantize ())
169
173
170
174
@common_utils .parametrize ("device" , COMMON_DEVICES )
171
175
@common_utils .parametrize ("dtype" , COMMON_DTYPES )
@@ -175,8 +179,11 @@ def test_input_tensor_subclass(self, device, dtype):
175
179
def f (tensor ):
176
180
return tensor .dequantize ()
177
181
182
+ ref = f (lp_tensor )
178
183
f = torch .compile (f )
184
+ compiled = f (lp_tensor )
179
185
self .assertFalse (isinstance (f (lp_tensor ), self .TENSOR_SUBCLASS ))
186
+ self .assertEqual (ref , compiled )
180
187
181
188
@common_utils .parametrize ("device" , COMMON_DEVICES )
182
189
@common_utils .parametrize ("dtype" , COMMON_DTYPES )
@@ -185,8 +192,13 @@ def test_output_tensor_subclass(self, device, dtype):
185
192
def f (hp_tensor ):
186
193
return self .FACTORY_FN (hp_tensor , ** self .kwargs )
187
194
195
+ ref = f (hp_tensor )
188
196
f = torch .compile (f )
197
+ compiled = f (hp_tensor )
189
198
self .assertTrue (isinstance (f (hp_tensor ), self .TENSOR_SUBCLASS ))
199
+ # bfloat16 seems to result in much larger numerical differences
200
+ if dtype != torch .bfloat16 :
201
+ self .assertGreater (torchao .quantization .utils .compute_error (ref .dequantize (), compiled .dequantize ()), self .COMPILE_MIN_SQNR )
190
202
191
203
@common_utils .parametrize ("device" , COMMON_DEVICES )
192
204
@common_utils .parametrize ("dtype" , COMMON_DTYPES )
0 commit comments