 def _int8wo_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int8_weight_only())
+        quantize(mod, int8_weight_only(), set_inductor_config=False)
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_woqtensors(mod)

 def _int8da_int8w_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int8_dynamic_activation_int8_weight())
+        quantize(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_dqtensors(mod)

 def _int4wo_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int4_weight_only())
+        quantize(mod, int4_weight_only(), set_inductor_config=False)
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int4_woqtensors(mod)
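All three wrappers above follow the same shape; a minimal standalone sketch of the 2.4+ branch, reusing the names this test module already imports (the exact import path may vary across torchao versions):

    import torch
    from torchao.quantization import quantize, int8_weight_only

    # Toy module; any model whose nn.Linear weights should be quantized
    # works the same way.
    model = torch.nn.Sequential(torch.nn.Linear(128, 128))

    # set_inductor_config=False, the keyword this diff threads through,
    # asks the API to leave the process-global torch._inductor.config
    # flags untouched instead of applying the recommended settings.
    quantize(model, int8_weight_only(), set_inductor_config=False)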
@@ -124,6 +124,13 @@ def _int4wo_api(mod):
     _int4wo_api,
 ]

+def undo_recommended_configs():
+    torch._inductor.config.coordinate_descent_tuning = False
+    torch._inductor.config.coordinate_descent_check_all_directions = False
+    torch._inductor.config.force_fuse_int_mm_with_mul = False
+    torch._inductor.config.fx_graph_cache = False
+    torch._inductor.config.triton.unique_kernel_names = False
+    torch.set_float32_matmul_precision("highest")

 def combine_parameters(a, b):
     new_tuples = []
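undo_recommended_configs() is needed because these flags live on process-global config objects: once any earlier test, or a quantization call left at its default set_inductor_config=True, sets them, every later test in the same process inherits them. A small illustration of that leakage (not part of the diff itself):

    import torch

    # torch._inductor.config is a module-level singleton, so mutations
    # persist for the lifetime of the process.
    torch._inductor.config.coordinate_descent_tuning = True

    def later_test():
        # Sees True even though this function never set the flag; hence
        # the tests below call undo_recommended_configs() up front.
        assert torch._inductor.config.coordinate_descent_tuning

    later_test()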
@@ -689,6 +696,7 @@ def test_int8_dynamic_quant_subclass(self, device, dtype):

     @parameterized.expand(COMMON_DEVICE_DTYPE)
     def test_int8_weight_only_quant_subclass(self, device, dtype):
+        undo_recommended_configs()
         self._test_lin_weight_subclass_impl(
             Int8WeightOnlyQuantizedLinearWeight.from_float, device, 40, test_dtype=dtype
         )
@@ -794,6 +802,7 @@ def test_int8_dynamic_quant_subclass_api(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_int8_weight_only_quant_subclass_api(self, device, dtype):
+        undo_recommended_configs()
         self._test_lin_weight_subclass_api_impl(
             _int8wo_api, device, 40, test_dtype=dtype
         )
@@ -879,6 +888,7 @@ def test_weight_only_quant(self):
     @torch.no_grad()
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_weight_only_quant_force_mixed_mm(self, device, dtype):
+        undo_recommended_configs()
         if device != "cuda":
             self.skipTest(f"weight_only_quant_force_mixed_mm can't be constructed on {device}")
         if dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
@@ -907,6 +917,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_weight_only_quant_use_mixed_mm(self, device, dtype):
+        undo_recommended_configs()
         if device != "cuda":
             self.skipTest(f"weight_only_quant_force_mixed_mm can't be constructed on {device}")
         if dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
@@ -1004,6 +1015,7 @@ def test_save_load_dqtensors(self, device, dtype):
     @torch.no_grad()
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_save_load_int8woqtensors(self, device, dtype):
+        undo_recommended_configs()
         self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype)

     @parameterized.expand(COMMON_DEVICE_DTYPE)
@@ -1153,6 +1165,7 @@ class TestAutoQuant(unittest.TestCase):
     ]))
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
     def test_autoquant_one_input(self, device, dtype, m, k, n):
+        undo_recommended_configs()
         print("(m, k, n): ", (m, k, n))
         if device != "cuda" or not torch.cuda.is_available():
             self.skipTest(f"autoquant currently does not support {device}")
@@ -1173,7 +1186,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
             torch.nn.ReLU(),
         ).to(device).to(dtype)
         out = model(example_input)
-        torchao.autoquant(model)
+        torchao.autoquant(model, set_inductor_config=False)
         out2 = model(example_input)
         sqnr = SQNR(out, out2)
         self.assertTrue(sqnr >= 30)
@@ -1186,6 +1199,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
     ]))
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
     def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
+        undo_recommended_configs()
         if device != "cuda" or not torch.cuda.is_available():
             self.skipTest(f"autoquant currently does not support {device}")
         if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
@@ -1202,7 +1216,7 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
         example_input2 = torch.randn(m2, k, device=device, dtype=dtype)
         out = model(example_input)

-        mod = torchao.autoquant(torch.compile(model), manual=True)
+        mod = torchao.autoquant(torch.compile(model), manual=True, set_inductor_config=False)
         mod(example_input)
         mod(example_input2)
         mod.finalize_autoquant()
@@ -1214,6 +1228,7 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
     def test_autoquant_manual(self, device, dtype):
+        undo_recommended_configs()
         if device != "cuda" or not torch.cuda.is_available():
             self.skipTest(f"autoquant currently does not support {device}")
         if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
@@ -1229,15 +1244,15 @@ def test_autoquant_manual(self, device, dtype):
         example_input2 = torch.randn(m2, k, device=device, dtype=dtype)
         out = model(example_input)

-        mod = torchao.autoquant(torch.compile(model), manual=True)
+        mod = torchao.autoquant(torch.compile(model), manual=True, set_inductor_config=False)
         mod(example_input)
         mod(example_input2)
         mod.finalize_autoquant()
         out2 = mod(example_input)
         sqnr = SQNR(out, out2)
         self.assertTrue(sqnr >= 30)

-        mod2 = torchao.autoquant(model, manual=True)
+        mod2 = torchao.autoquant(model, manual=True, set_inductor_config=False)
         mod2(example_input)
         mod2(example_input2)
         mod2.finalize_autoquant()
@@ -1254,6 +1269,7 @@ def test_autoquant_manual(self, device, dtype):
     ]))
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
     def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n):
+        undo_recommended_configs()
         if device != "cuda" or not torch.cuda.is_available():
             self.skipTest(f"autoquant currently does not support {device}")
         if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
@@ -1280,7 +1296,7 @@ def forward(self, x, y):
         }
         out = model(**example_input)

-        mod = torchao.autoquant(torch.compile(model))
+        mod = torchao.autoquant(torch.compile(model), set_inductor_config=False)
         mod(**example_input)

         out2 = mod(**example_input)
@@ -1293,6 +1309,7 @@ def forward(self, x, y):
     ]))
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
     def test_autoquant_double_access(self, device, dtype, m, k, n):
+        undo_recommended_configs()
         if device != "cuda" or not torch.cuda.is_available():
             self.skipTest(f"autoquant currently does not support {device}")
         if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
@@ -1316,7 +1333,7 @@ def forward(self, x):
         x_in = torch.randn(m, k, device=device, dtype=dtype)
         model = DoubleAccess().to(device).to(dtype)
         model(x_in)
-        torchao.autoquant(model)
+        torchao.autoquant(model, set_inductor_config=False)
         assert not isinstance(model.lin1.weight.weight, AutoQuantizableLinearWeight)
         model(x_in)
@@ -1443,7 +1460,7 @@ def test_get_model_size_autoquant(self, device, dtype):
         qtensor_class_list = (
             AQWeightOnlyQuantizedLinearWeight2,
         )
-        mod = torchao.autoquant(torch.compile(model), qtensor_class_list=qtensor_class_list)
+        mod = torchao.autoquant(torch.compile(model), qtensor_class_list=qtensor_class_list, set_inductor_config=False)
         mod(example_input)
         size2 = torchao.utils.get_model_size_in_bytes(mod)
         self.assertTrue(size2 < size)
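Taken together, every autoquant call in the suite now pairs an up-front reset with set_inductor_config=False. A condensed sketch of the manual-mode flow exercised above (assuming a CUDA device, which these tests require):

    import torch
    import torchao

    model = torch.nn.Sequential(torch.nn.Linear(64, 64)).to("cuda")
    example_input = torch.randn(32, 64, device="cuda")

    # manual=True defers kernel selection until finalize_autoquant().
    mod = torchao.autoquant(torch.compile(model), manual=True, set_inductor_config=False)
    mod(example_input)        # record input shapes to benchmark against
    mod.finalize_autoquant()  # pick the best quantization per layer
    out = mod(example_input)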