
Commit 8fa11a6

Fix int4pack_mm error (#517)
* Fix int4pack_mm error
* fix CI
* Fix CI
* Fix CI
* Fix CI
* Fix CI
1 parent e5b705c commit 8fa11a6

8 files changed: 61 additions, 26 deletions
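
The hunks below revolve around one change: the 2.5+ branches gated by TORCH_VERSION_AFTER_2_5 hand torch.ops.aten._convert_weight_to_int4pack an int4 weight packed two values per uint8 byte rather than one value per int32 element, as the assertions added in torchao/dtypes/affine_quantized_tensor.py spell out, and the previously skipped int4 tests are re-enabled by commenting out their skip decorators. As an illustration only (not code from this commit), a minimal self-contained sketch of that nibble packing and its inverse, with even columns in the high nibble and odd columns in the low nibble:

import torch

int4_vals = torch.randint(0, 16, (4, 8), dtype=torch.int32)

# Pack pairs of int4 values into one uint8 each (halves the last dimension).
packed = (int4_vals[::, ::2] << 4 | int4_vals[::, 1::2]).to(torch.uint8)

# Unpack: high nibble back to even columns, low nibble back to odd columns.
unpacked = torch.zeros_like(int4_vals)
unpacked[::, ::2] = (packed >> 4).to(torch.int32)
unpacked[::, 1::2] = (packed & 0x0F).to(torch.int32)

assert torch.equal(int4_vals, unpacked)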

test/dtypes/test_affine_quantized.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 
 class TestAffineQuantized(TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_tensor_core_layout_transpose(self):
         l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
         t = l.weight

test/integration/test_integration.py

Lines changed: 8 additions & 8 deletions
@@ -631,7 +631,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest("Currently only supports bfloat16.")
@@ -642,7 +642,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest("Currently only supports bfloat16.")
@@ -737,7 +737,7 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_int4_weight_only_quant_subclass(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
@@ -748,7 +748,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_int4_weight_only_quant_subclass_grouped(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
@@ -823,7 +823,7 @@ def test_int8_weight_only_quant_with_freeze(self, device, dtype):
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_int4_weight_only_quant_subclass_api(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
@@ -838,7 +838,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
@@ -1028,7 +1028,7 @@ def test_save_load_int8woqtensors(self, device, dtype):
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch 2.3+.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
     @torch.no_grad()
     def test_save_load_int4woqtensors(self, device, dtype):
         if dtype != torch.bfloat16:
@@ -1488,7 +1488,7 @@ def test_get_model_size_autoquant(self, device, dtype):
     @parameterized.expand(
         list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
     )
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_get_model_size_aqt(self, api, test_device, test_dtype):
         if test_dtype != torch.bfloat16:
             self.skipTest(f"{api} in {test_dtype} is not supported yet")

test/quantization/test_quant_api.py

Lines changed: 1 addition & 1 deletion
@@ -525,7 +525,7 @@ def test_quantized_tensor_subclass_8da4w(self):
         self.assertTrue(torch.equal(res, ref))
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Test currently doesn't work for 2.5+")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Test currently doesn't work for 2.5+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_int4(self):
         # use 1024 so that we don't need padding

test/quantization/test_quant_primitives.py

Lines changed: 8 additions & 1 deletion
@@ -29,6 +29,7 @@
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_3,
     TORCH_VERSION_AFTER_2_4,
+    TORCH_VERSION_AFTER_2_5,
     is_fbcode,
 )
 
@@ -99,6 +100,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
         .to(torch.int32)
         .reshape_as(w)
     )
+    if TORCH_VERSION_AFTER_2_5:
+        w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
 
     return w_int4x8
 
@@ -500,7 +503,11 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
         n_bit = 4
         groupsize = 128
 
-        w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
+        if TORCH_VERSION_AFTER_2_5:
+            input_uint8 = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
+            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_uint8, scales, zeros, n_bit, groupsize)
+        else:
+            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
         w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
 
         self.assertTrue(torch.equal(w_bf16, w_bf16_ref))

test/test_ops.py

Lines changed: 16 additions & 5 deletions
@@ -95,20 +95,24 @@ def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
 TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES))
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
+# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
 def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
     N, K = shape
     assert K % (inner_k_tiles * kTileSizeK) == 0 and N % kTileSizeN == 0
 
     t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
+    if TORCH_VERSION_AFTER_2_5:
+        t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8)
     packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)
     unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, inner_k_tiles)
+    if TORCH_VERSION_AFTER_2_5:
+        unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8)
     assert torch.equal(t, unpacked)
 
 # TODO: Fix "test_aot_dispatch_dynamic" test failure
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
+# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK , ids=str)
 def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
     test_utils = [
@@ -122,6 +126,8 @@ def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
         test_utils.append("test_aot_dispatch_dynamic")
 
     t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
+    if TORCH_VERSION_AFTER_2_5:
+        t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8)
     packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)
 
     opcheck(
@@ -151,7 +157,7 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
+# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
 def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):
     n, k = shape
@@ -210,7 +216,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):
 
 # This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
+# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
 def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
     n, k = shape
@@ -229,6 +235,9 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
 
     # Unpack and dequantize
     unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles)
+    if TORCH_VERSION_AFTER_2_5:
+        unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8)
+
     dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
         unpacked, scales, zeros, n_bit=4, groupsize=group_size
     )
@@ -264,13 +273,15 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
     assert diff_op_ao < 1e-1
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
+# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
 def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size):
     n, k = shape
     device = "cuda"
 
     q = torch.randint(0, 16, shape, dtype=torch.int, device=device)
+    if TORCH_VERSION_AFTER_2_5:
+        q = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
     packed_w = torch._convert_weight_to_int4pack(q, inner_k_tiles)
     q_groups = k // group_size
     scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device)

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 7 additions & 4 deletions
@@ -26,6 +26,7 @@
 )
 from typing import ClassVar
 from dataclasses import dataclass
+from torchao.utils import TORCH_VERSION_AFTER_2_5
 
 aten = torch.ops.aten
 
@@ -245,7 +246,6 @@ def from_float(
 
         scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
         int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
-
         int_data = layout_type.post_process(int_data)
 
         layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
@@ -570,9 +570,12 @@ def from_plain(
         layout_type: LayoutType
     ):
         assert isinstance(layout_type, TensorCoreTiledLayoutType)
-        # assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack expects `uint8` dtype"
-        # packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles)
-        packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), layout_type.inner_k_tiles)
+        if TORCH_VERSION_AFTER_2_5:
+            int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
+            assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype"
+        else:
+            assert int_data.dtype == torch.int32, "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype"
+        packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, layout_type.inner_k_tiles)
         scale = scale.reshape(int_data.shape[0], -1)
         zero_point = zero_point.reshape(int_data.shape[0], -1)
         scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
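
For callers that build the packed weight directly, the version split in from_plain above boils down to a dtype contract on the aten op. A hedged sketch, where pack_for_int4pack is a hypothetical helper (not torchao API) and only the behavior stated by the assertions above is assumed:

import torch
from torchao.utils import TORCH_VERSION_AFTER_2_5

def pack_for_int4pack(int_data: torch.Tensor, inner_k_tiles: int) -> torch.Tensor:
    # Hypothetical helper mirroring the from_plain change above.
    if TORCH_VERSION_AFTER_2_5:
        # torch 2.5+: two int4 values packed per uint8 byte.
        int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
    else:
        # torch 2.4: one int4 value per int32 element.
        int_data = int_data.to(torch.int32)
    return torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles)

# e.g., for an (N, K) = (256, 1024) int4 weight on a CUDA device:
# packed = pack_for_int4pack(
#     torch.randint(0, 16, (256, 1024), dtype=torch.int32, device="cuda"),
#     inner_k_tiles=8,
# )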

torchao/prototype/hqq/hqq_tinygemm_linear.py

Lines changed: 3 additions & 0 deletions
@@ -12,6 +12,7 @@
 from hqq.core.utils import *
 
 import torch.nn.functional as F
+from torchao.utils import TORCH_VERSION_AFTER_2_5
 
 
 class HQQLinearTorchWeightOnlyInt4(torch.nn.Module):
@@ -198,6 +199,8 @@ def hqq_quants_to_torch_quants(
             .reshape(shape)
             .contiguous()
         )
+        if TORCH_VERSION_AFTER_2_5:
+            W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
 
         # group_dequantize_tensor_from_qparams
         # W_r = W_q*scales + min_val

torchao/quantization/utils.py

Lines changed: 17 additions & 6 deletions
@@ -17,6 +17,7 @@
     dequantize_affine,
     int_scaled_matmul,
 )
+from torchao.utils import TORCH_VERSION_AFTER_2_5
 
 __all__ = [
     "compute_error",
@@ -349,6 +350,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
     quant_max = 2 ** n_bit - 1
 
     int_data = quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
+    if TORCH_VERSION_AFTER_2_5:
+        int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
     return int_data
 
 def groupwise_affine_dequantize_tensor_from_qparams(
@@ -359,18 +362,26 @@ def groupwise_affine_dequantize_tensor_from_qparams(
     groupsize=128,
 ):
     assert groupsize > 1
-    # needed for GPTQ single column dequantize
-    if groupsize > w_int4x8.shape[-1] and scales.shape[-1] == 1:
-        groupsize = w_int4x8.shape[-1]
-    assert w_int4x8.shape[-1] % groupsize == 0
     assert w_int4x8.dim() == 2
+    if TORCH_VERSION_AFTER_2_5:
+        data = w_int4x8.to(torch.int32)
+        high_bits = data >> 4
+        low_bits = data & 0x0F
+        w_int32 = torch.zeros((w_int4x8.shape[0], w_int4x8.shape[1] * 2), dtype=torch.int32, device=w_int4x8.device)
+        w_int32[::, ::2] = high_bits
+        w_int32[::, 1::2] = low_bits
+    else:
+        w_int32 = w_int4x8
 
+    # needed for GPTQ single column dequantize
+    if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
+        groupsize = w_int32.shape[-1]
+    assert w_int32.shape[-1] % groupsize == 0
     block_size = (1, groupsize)
     input_dtype = torch.int32
     quant_min = 0
     quant_max = 2**n_bit - 1
-    return dequantize_affine(w_int4x8, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype)
-
+    return dequantize_affine(w_int32, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype)
 
 def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
     scales, zeros = get_groupwise_affine_qparams(w, n_bit, groupsize, dtype)
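
Taken together, the changes in this file keep the public helpers round-trip compatible: on torch 2.5+ groupwise_affine_quantize_tensor_from_qparams returns uint8 data with two int4 values per byte, and groupwise_affine_dequantize_tensor_from_qparams unpacks that layout itself before calling dequantize_affine. A hedged usage sketch, with illustrative shapes and group size only:

import torch
from torchao.quantization.utils import (
    get_groupwise_affine_qparams,
    groupwise_affine_quantize_tensor_from_qparams,
    groupwise_affine_dequantize_tensor_from_qparams,
)

w = torch.randn(128, 256, dtype=torch.bfloat16)
scales, zeros = get_groupwise_affine_qparams(w, n_bit=4, groupsize=128)

# Quantize, then dequantize through the updated helpers.
w_q = groupwise_affine_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128)
w_dq = groupwise_affine_dequantize_tensor_from_qparams(w_q, scales, zeros, n_bit=4, groupsize=128)

print(w_q.dtype, w_q.shape)    # uint8, (128, 128) on 2.5+; int32, (128, 256) otherwise
print(w_dq.dtype, w_dq.shape)  # bfloat16, (128, 256)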
