diff --git a/test/dtypes/test_uint4.py b/test/dtypes/test_uint4.py
index d4fbdd6b13..aa9415e51b 100644
--- a/test/dtypes/test_uint4.py
+++ b/test/dtypes/test_uint4.py
@@ -4,19 +4,14 @@
     PerChannelSymmetricWeightUInt4Tensor,
 )
 import unittest
-from unittest import TestCase, main
 from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
 from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
 
 from torch._export import capture_pre_autograd_graph
-from torch._export import dynamic_dim
 from torch.testing._internal.common_quantization import (
     NodeSpec as ns,
     QuantizationTestCase,
 )
-from torchao.quantization.utils import (
-    compute_error,
-)
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
 )
@@ -30,7 +25,6 @@
     QuantizationAnnotation,
 )
 import copy
-from packaging import version
 
 
 def _apply_weight_only_uint4_quant(model):
@@ -229,4 +223,4 @@ def forward(self, x):
         )
 
 if __name__ == "__main__":
-    main()
+    unittest.main()
diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index f635fb712a..2286f3856f 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -81,6 +81,7 @@
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_3,
     TORCH_VERSION_AFTER_2_4,
+    TORCH_VERSION_AFTER_2_5,
     unwrap_tensor_subclass,
     is_fbcode,
     benchmark_model
@@ -734,6 +735,7 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
+    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_int4_weight_only_quant_subclass(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
@@ -744,6 +746,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
+    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_int4_weight_only_quant_subclass_grouped(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
@@ -1020,7 +1023,8 @@ def test_save_load_int8woqtensors(self, device, dtype):
         self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype)
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch 2.3+.")
+    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
     @torch.no_grad()
     def test_save_load_int4woqtensors(self, device, dtype):
         if dtype != torch.bfloat16:
@@ -1500,7 +1504,7 @@ def test_get_model_size_aqt(self, api, test_device, test_dtype):
 
 
 class TestBenchmarkModel(unittest.TestCase):
-    
+
     class ToyLinearModel(torch.nn.Module):
         def __init__(self, m=64, n=32, k=64):
             super().__init__()
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 8d949fdf84..a0d2c2d0e6 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -44,6 +44,7 @@
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_3,
     TORCH_VERSION_AFTER_2_4,
+    TORCH_VERSION_AFTER_2_5,
 )
 from pathlib import Path
 from torchao._models.llama.tokenizer import get_tokenizer
@@ -522,6 +523,7 @@ def test_quantized_tensor_subclass_8da4w(self):
         self.assertTrue(torch.equal(res, ref))
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Test currently doesn't work for 2.5+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_int4(self):
         # use 1024 so that we don't need padding
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 3cde983e9c..9c1e5e9bde 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -461,6 +461,8 @@ def __tensor_unflatten__(
 
     @classmethod
     def from_plain(cls, int_data, scale, zero_point, inner_k_tiles=8):
+        # assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack expects `uint8` dtype"
+        # packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles)
         packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), inner_k_tiles)
         scale = scale.reshape(int_data.shape[0], -1)
         zero_point = zero_point.reshape(int_data.shape[0], -1)
diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index cab8f9b622..2cfee6025a 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -348,7 +348,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
     quant_min = 0
     quant_max = 2 ** n_bit - 1
 
-    return quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
+    int_data = quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
+    return int_data
 
 def groupwise_affine_dequantize_tensor_from_qparams(
     w_int4x8,