@@ -1022,7 +1022,6 @@ def forward(self, x):
1022
1022
1023
1023
@parameterized .expand (COMMON_DEVICE_DTYPE )
1024
1024
@unittest .skipIf (is_fbcode (), "'PlainAQTLayout' object has no attribute 'int_data'" )
1025
- @unittest .skipIf (TORCH_VERSION_AFTER_2_5 , "Can't save local lambda function for tensor subclass" )
1026
1025
@torch .no_grad ()
1027
1026
def test_save_load_dqtensors (self , device , dtype ):
1028
1027
if device == "cpu" :
@@ -1226,7 +1225,7 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
1226
1225
self .skipTest (f"bfloat16 requires sm80+" )
1227
1226
if m1 == 1 or m2 == 1 :
1228
1227
self .skipTest (f"Shape { (m1 , m2 , k , n )} requires sm80+" )
1229
- # This test fails on v0.4.0 and torch 2.4, so skipping for now.
1228
+ # This test fails on v0.4.0 and torch 2.4, so skipping for now.
1230
1229
if m1 == 1 or m2 == 1 and not TORCH_VERSION_AFTER_2_5 :
1231
1230
self .skipTest (f"Shape { (m1 , m2 , k , n )} requires torch version > 2.4" )
1232
1231
model = torch .nn .Sequential (
@@ -1299,7 +1298,7 @@ def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n):
1299
1298
self .skipTest (f"bfloat16 requires sm80+" )
1300
1299
if m1 == 1 or m2 == 1 :
1301
1300
self .skipTest (f"Shape { (m1 , m2 , k , n )} requires sm80+" )
1302
- # This test fails on v0.4.0 and torch 2.4, so skipping for now.
1301
+ # This test fails on v0.4.0 and torch 2.4, so skipping for now.
1303
1302
if m1 == 1 or m2 == 1 and not TORCH_VERSION_AFTER_2_5 :
1304
1303
self .skipTest (f"Shape { (m1 , m2 , k , n )} requires torch version > 2.4" )
1305
1304
0 commit comments