diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index e53ef03819..3634ac791f 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -27,7 +27,10 @@
     get_groupwise_affine_qparams,
     groupwise_affine_quantize_tensor,
 )
-from torchao.utils import TORCH_VERSION_AFTER_2_4
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_4,
+    TORCH_VERSION_AFTER_2_5,
+)
 
 
 # TODO: put this in a common test utils file
@@ -366,6 +369,8 @@ def _assert_close_4w(self, val, ref):
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
+    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
     def test_qat_4w_primitives(self):
         n_bit = 4
         group_size = 32
@@ -411,6 +416,8 @@ def test_qat_4w_primitives(self):
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
+    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
     def test_qat_4w_linear(self):
         from torchao.quantization.prototype.qat import Int4WeightOnlyQATLinear
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
@@ -439,6 +446,8 @@ def test_qat_4w_linear(self):
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
+    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
     def test_qat_4w_quantizer(self):
         from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
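
For context, the gating pattern this diff applies is plain `unittest.skipIf` stacking: the existing decorators set a lower bound on the torch version and require CUDA, and the new decorator adds a temporary upper bound for 2.5+. Below is a minimal self-contained sketch of the same pattern, assuming only that `torchao.utils` exposes the boolean version flags imported above; the test class and test body are hypothetical placeholders, not code from this PR.

import unittest

import torch
from torchao.utils import TORCH_VERSION_AFTER_2_4, TORCH_VERSION_AFTER_2_5

_CUDA_IS_AVAILABLE = torch.cuda.is_available()


class ExampleQATTest(unittest.TestCase):
    # Lower bound: only run on torch versions newer than 2.4.
    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
    # Hardware gate: the int4 kernels under test need CUDA.
    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
    # Temporary upper bound added by this PR: skip on 2.5+ until the
    # int4 error tracked in https://github.com/pytorch/ao/pull/517 is fixed.
    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
    def test_example(self):
        # Hypothetical body; the real tests exercise the int4 QAT paths.
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()

Because skip decorators are evaluated independently, the first condition that is true wins: a run on torch 2.5+ reports the "int4 doesn't work" reason even when CUDA is present, which is why the PR can bound the version window without touching the test bodies.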