diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index a938e46e2e..ee2ffcf0a1 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -12,7 +12,7 @@ class TestAffineQuantized(TestCase):

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_tensor_core_layout_transpose(self):
         l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
         t = l.weight
diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index fd6fa89f15..d8b6d71a51 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -631,7 +631,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):

     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest("Currently only supports bfloat16.")
@@ -642,7 +642,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):

     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest("Currently only supports bfloat16.")
@@ -737,7 +737,7 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):

     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_int4_weight_only_quant_subclass(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
@@ -748,7 +748,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):

     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_int4_weight_only_quant_subclass_grouped(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
@@ -823,7 +823,7 @@ def test_int8_weight_only_quant_with_freeze(self, device, dtype):

     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_int4_weight_only_quant_subclass_api(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
@@ -838,7 +838,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):

     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
@@ -1028,7 +1028,7 @@ def test_save_load_int8woqtensors(self, device, dtype):

     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch 2.3+.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
     @torch.no_grad()
     def test_save_load_int4woqtensors(self, device, dtype):
         if dtype != torch.bfloat16:
@@ -1488,7 +1488,7 @@ def test_get_model_size_autoquant(self, device, dtype):
     @parameterized.expand(
         list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
     )
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_get_model_size_aqt(self, api, test_device, test_dtype):
         if test_dtype != torch.bfloat16:
             self.skipTest(f"{api} in {test_dtype} is not supported yet")
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 155a232c3e..99cad46a72 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -525,7 +525,7 @@ def test_quantized_tensor_subclass_8da4w(self):
         self.assertTrue(torch.equal(res, ref))

     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Test currently doesn't work for 2.5+")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Test currently doesn't work for 2.5+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_int4(self):
         # use 1024 so that we don't need padding
diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 5f8680a509..8223b8bfda 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -29,6 +29,7 @@
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_3,
     TORCH_VERSION_AFTER_2_4,
+    TORCH_VERSION_AFTER_2_5,
     is_fbcode,
 )
@@ -99,6 +100,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
         .to(torch.int32)
         .reshape_as(w)
     )
+    if TORCH_VERSION_AFTER_2_5:
+        w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)

     return w_int4x8
@@ -500,7 +503,11 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
         n_bit = 4
         groupsize = 128

-        w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
+        if TORCH_VERSION_AFTER_2_5:
+            input_uint8 = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
+            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_uint8, scales, zeros, n_bit, groupsize)
+        else:
+            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
         w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)

         self.assertTrue(torch.equal(w_bf16, w_bf16_ref))
diff --git a/test/test_ops.py b/test/test_ops.py
index 3a5271d962..b0c3cefd54 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -95,20 +95,24 @@ def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
 TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES))

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
+# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
 def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
     N, K = shape
     assert K % (inner_k_tiles * kTileSizeK) == 0 and N % kTileSizeN == 0

     t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
+    if TORCH_VERSION_AFTER_2_5:
+        t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8)
     packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)
     unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, inner_k_tiles)
+    if TORCH_VERSION_AFTER_2_5:
+        unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8)
     assert torch.equal(t, unpacked)

 # TODO: Fix "test_aot_dispatch_dynamic" test failure
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
+# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK , ids=str)
 def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
     test_utils = [
@@ -122,6 +126,8 @@ def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
         test_utils.append("test_aot_dispatch_dynamic")

     t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
+    if TORCH_VERSION_AFTER_2_5:
+        t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8)
     packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)

     opcheck(
@@ -151,7 +157,7 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
+# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
 def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):
     n, k = shape
@@ -210,7 +216,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, in

 # This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
+# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
 def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
     n, k = shape
@@ -229,6 +235,9 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shap

     # Unpack and dequantize
     unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles)
+    if TORCH_VERSION_AFTER_2_5:
+        unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8)
+
     dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
         unpacked, scales, zeros, n_bit=4, groupsize=group_size
     )
@@ -264,13 +273,15 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shap
     assert diff_op_ao < 1e-1

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
+# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
 def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size):
     n, k = shape
     device = "cuda"

     q = torch.randint(0, 16, shape, dtype=torch.int, device=device)
+    if TORCH_VERSION_AFTER_2_5:
+        q = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
     packed_w = torch._convert_weight_to_int4pack(q, inner_k_tiles)
     q_groups = k // group_size
     scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device)
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index b71e48a3a0..4dd17c24e1 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -26,6 +26,7 @@
 )
 from typing import ClassVar
 from dataclasses import dataclass
+from torchao.utils import TORCH_VERSION_AFTER_2_5


 aten = torch.ops.aten
@@ -245,7 +246,6 @@ def from_float(
         scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
         int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
-        int_data = layout_type.post_process(int_data)

         layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
@@ -570,9 +570,12 @@ def from_plain(
         layout_type: LayoutType
     ):
         assert isinstance(layout_type, TensorCoreTiledLayoutType)
-        # assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack expects `uint8` dtype"
-        # packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles)
-        packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), layout_type.inner_k_tiles)
+        if TORCH_VERSION_AFTER_2_5:
+            int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
+            assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype"
+        else:
+            assert int_data.dtype == torch.int32, "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype"
+        packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, layout_type.inner_k_tiles)
         scale = scale.reshape(int_data.shape[0], -1)
         zero_point = zero_point.reshape(int_data.shape[0], -1)
         scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
diff --git a/torchao/prototype/hqq/hqq_tinygemm_linear.py b/torchao/prototype/hqq/hqq_tinygemm_linear.py
index 0c4ae45c61..1e8c5fc38c 100644
--- a/torchao/prototype/hqq/hqq_tinygemm_linear.py
+++ b/torchao/prototype/hqq/hqq_tinygemm_linear.py
@@ -12,6 +12,7 @@
 from hqq.core.utils import *

 import torch.nn.functional as F
+from torchao.utils import TORCH_VERSION_AFTER_2_5


 class HQQLinearTorchWeightOnlyInt4(torch.nn.Module):
@@ -198,6 +199,8 @@ def hqq_quants_to_torch_quants(
             .reshape(shape)
             .contiguous()
         )
+        if TORCH_VERSION_AFTER_2_5:
+            W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)

         # group_dequantize_tensor_from_qparams
         # W_r = W_q*scales + min_val
diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index cb6acdc617..66a60c1f56 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -17,6 +17,7 @@
     dequantize_affine,
     int_scaled_matmul,
 )
+from torchao.utils import TORCH_VERSION_AFTER_2_5

 __all__ = [
     "compute_error",
@@ -349,6 +350,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
     quant_max = 2 ** n_bit - 1

     int_data = quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
+    if TORCH_VERSION_AFTER_2_5:
+        int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
     return int_data
@@ -359,18 +362,26 @@ def groupwise_affine_dequantize_tensor_from_qparams(
     groupsize=128,
 ):
     assert groupsize > 1
-    # needed for GPTQ single column dequantize
-    if groupsize > w_int4x8.shape[-1] and scales.shape[-1] == 1:
-        groupsize = w_int4x8.shape[-1]
-    assert w_int4x8.shape[-1] % groupsize == 0
     assert w_int4x8.dim() == 2
+    if TORCH_VERSION_AFTER_2_5:
+        data = w_int4x8.to(torch.int32)
+        high_bits = data >> 4
+        low_bits = data & 0x0F
+        w_int32 = torch.zeros((w_int4x8.shape[0], w_int4x8.shape[1] * 2), dtype=torch.int32, device=w_int4x8.device)
+        w_int32[::, ::2] = high_bits
+        w_int32[::, 1::2] = low_bits
+    else:
+        w_int32 = w_int4x8
+    # needed for GPTQ single column dequantize
+    if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
+        groupsize = w_int32.shape[-1]
+    assert w_int32.shape[-1] % groupsize == 0

     block_size = (1, groupsize)
     input_dtype = torch.int32
     quant_min = 0
     quant_max = 2**n_bit - 1
-    return dequantize_affine(w_int4x8, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype)
-
+    return dequantize_affine(w_int32, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype)

 def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
     scales, zeros = get_groupwise_affine_qparams(w, n_bit, groupsize, dtype)
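
For reference (not part of the patch): the same nibble-packing convention recurs throughout the diff, so here is a minimal standalone sketch of it. It assumes a 2D tensor of int4 values stored in an integer dtype, as in the tests above; the helper names pack_int4_pairs and unpack_int4_pairs are illustrative only and do not exist in torchao.

# Minimal sketch (assumptions noted above): pack two int4 values per uint8 byte,
# the convention torch 2.5+'s _convert_weight_to_int4pack expects, and unpack again.
import torch

def pack_int4_pairs(w_int4: torch.Tensor) -> torch.Tensor:
    # Even columns become the high nibble, odd columns the low nibble,
    # mirroring the `(w[::, ::2] << 4 | w[::, 1::2]).to(torch.uint8)` lines in the diff.
    assert w_int4.dim() == 2 and w_int4.shape[-1] % 2 == 0
    return (w_int4[::, ::2] << 4 | w_int4[::, 1::2]).to(torch.uint8)

def unpack_int4_pairs(w_uint8: torch.Tensor) -> torch.Tensor:
    # Inverse operation, mirroring the unpacking branch added to
    # groupwise_affine_dequantize_tensor_from_qparams.
    data = w_uint8.to(torch.int32)
    out = torch.zeros((w_uint8.shape[0], w_uint8.shape[1] * 2), dtype=torch.int32, device=w_uint8.device)
    out[::, ::2] = data >> 4      # high nibble -> even columns
    out[::, 1::2] = data & 0x0F   # low nibble  -> odd columns
    return out

if __name__ == "__main__":
    t = torch.randint(0, 16, (4, 8), dtype=torch.int32)
    packed = pack_int4_pairs(t)
    assert packed.shape == (4, 4) and packed.dtype == torch.uint8
    assert torch.equal(unpack_int4_pairs(packed), t)
    print("int4 pair packing round-trips")

Keeping the same high/low nibble order on both the pack and unpack side is what the round-trip assertions in the tests above rely on.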