Commit 6411b23
Add Int4TilePackedTo4dTensor for int4 weight-only quantization with tile-packed-to-4d packing

This commit introduces Int4TilePackedTo4dTensor, a new tensor subclass for int4 weight-only quantization using the tensor core tiled packing format.

Key features:
- Implements tensor core tiled packing for efficient computation on tensor cores
- Supports PackingFormat.TILE_PACKED_TO_4D in Int4WeightOnlyConfig version 2
- Optimized for the tinygemm int4mm kernel (_weight_int4pack_mm)
- Includes a comprehensive test suite

The implementation follows the same pattern as the other int4 tensor subclasses but uses a specialized packing format optimized for tensor core matrix multiplication performance.

Changes:
- Add the Int4TilePackedTo4dTensor implementation
- Update Int4WeightOnlyConfig version 2 to support the TILE_PACKED_TO_4D packing format
- Add TILE_PACKED_TO_4D to the PackingFormat enum
- Add comprehensive tests covering serialization, different group sizes, and error conditions
- Update __init__.py files to export the new tensor class

Test: python test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py
1 parent 9056c46 commit 6411b23
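For context, the new packing format is enabled through Int4WeightOnlyConfig, exactly as the test file below does. A minimal usage sketch (the layer shape and input are illustrative; the test suite gates on CUDA with sm90+):

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.quantize_.common.packing_format import PackingFormat

# Illustrative bf16 CUDA linear layer; dimensions just need to be compatible
# with the chosen group size.
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")

config = Int4WeightOnlyConfig(
    group_size=128,
    packing_format=PackingFormat.TILE_PACKED_TO_4D,
    version=2,
)
quantize_(linear, config)  # linear.weight is now an Int4TilePackedTo4dTensor

out = linear(torch.randn(32, 128, dtype=torch.bfloat16, device="cuda"))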

File tree: 8 files changed, +616 −14 lines
test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py

Lines changed: 270 additions & 0 deletions

@@ -0,0 +1,270 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.quantize_.common.packing_format import PackingFormat
from torchao.quantization.quantize_.workflows.int4.int4_tile_packed_to_4d_tensor import (
    Int4TilePackedTo4dTensor,
)
from torchao.quantization.utils import compute_error
from torchao.testing.utils import TorchAOIntegrationTestCase
from torchao.utils import is_sm_at_least_90

INT4_CONFIG = Int4WeightOnlyConfig(
    group_size=128,
    packing_format=PackingFormat.TILE_PACKED_TO_4D,
    version=2,
)


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
class TestInt4TilePackedTo4dTensor(TorchAOIntegrationTestCase):
    def setUp(self):
        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

    @parametrize(
        "sizes",
        [
            ((128,), 256, 128),
            ((32, 128), 512, 128),
            ((2, 32, 128), 256, 128),
        ],
    )
    def test_linear(self, sizes):
        config = INT4_CONFIG
        dtype = torch.bfloat16
        device = "cuda"

        M, N, K = sizes
        input = torch.randn(*M, K, dtype=dtype, device=device)
        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)

        original = linear(input)
        quantize_(linear, config)
        quantized = linear(input)
        self.assertTrue(compute_error(original, quantized) > 20)

        compiled_linear = torch.compile(linear)
        quantized_and_compiled = compiled_linear(input)
        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)

    def test_module_path(self):
        config = INT4_CONFIG
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
        quantize_(linear.cuda(), config)
        self.assertEqual(
            str(type(linear.weight)),
            "<class 'torchao.quantization.Int4TilePackedTo4dTensor'>",
        )

        with tempfile.NamedTemporaryFile() as f:
            torch.save(linear.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f)
            self.assertEqual(
                str(type(state_dict["weight"])),
                "<class 'torchao.quantization.Int4TilePackedTo4dTensor'>",
            )

    def test_slice(self):
        """Note: we use multiples of 1024 for both in_features and out_features
        so that padding does not affect the weight after slicing
        """
        config = INT4_CONFIG
        dtype = torch.bfloat16
        device = "cuda"

        # Create a 2048x2048 linear layer for testing
        dummy = torch.nn.Linear(2048, 2048, bias=False, dtype=dtype, device=device)

        # Create reference sliced linear layers
        dummy1 = torch.nn.Linear(2048, 1024, bias=False, dtype=dtype, device=device)
        dummy1.weight = torch.nn.Parameter(
            dummy.weight.narrow(0, 0, 1024), requires_grad=False
        )
        dummy2 = torch.nn.Linear(1024, 2048, dtype=dtype, device=device)
        dummy2.weight = torch.nn.Parameter(
            dummy.weight.narrow(1, 0, 1024), requires_grad=False
        )

        # Quantize the main linear layer
        quantize_(dummy, config)

        # Shape analysis for TilePackedTo4d format:
        # Original weight shape: (2048, 2048) -> no padding needed (already a multiple of 1024)
        # n = 2048, k = 2048, inner_k_tiles = 8, group_size = 128
        #
        # qdata shape: [n/8, k/(inner_k_tiles*16), 32, inner_k_tiles/2]
        #            = [2048/8, 2048/(8*16), 32, 8/2]
        #            = [256, 16, 32, 4]
        #
        # scale_and_zero shape: [in_features/group_size, out_features, 2] (packed format)
        #                     = [2048/128, 2048, 2] = [16, 2048, 2]

        # Test slicing along the output dimension (dim=0: 2048 -> 1024)
        weight1 = dummy.weight.narrow(0, 0, 1024)

        # qdata slicing: narrow from [256, 16, 32, 4] to [128, 16, 32, 4]
        # Calculation: 1024 out_features / 2048 total * 256 qdata_dim0 = 128
        expected_qdata_slice_0 = dummy.weight.qdata.narrow(0, 0, 128)
        self.assertEqual(weight1.qdata, expected_qdata_slice_0)

        # scale_and_zero slicing: narrow from [16, 2048, 2] to [16, 1024, 2]
        # slicing the 0th dim of qdata means we have to slice the 1st dim of scale_and_zero
        expected_scale_zero_slice_0 = dummy.weight.scale_and_zero.narrow(1, 0, 1024)
        self.assertEqual(weight1.scale_and_zero, expected_scale_zero_slice_0)

        # Test slicing along the input dimension (dim=1: 2048 -> 1024)
        weight2 = dummy.weight.narrow(1, 0, 1024)

        # qdata slicing: narrow from [256, 16, 32, 4] to [256, 8, 32, 4]
        # k = 2048
        # Calculation: 1024 in_features (1/2 of in_features) corresponds to 1/2 of qdata dimension 1,
        # which is k / (inner_k_tiles * 16) / 2 = 2048 / (8 * 16) / 2 = 8
        expected_qdata_slice_1 = dummy.weight.qdata.narrow(1, 0, 8)
        self.assertEqual(weight2.qdata, expected_qdata_slice_1)

        # scale_and_zero slicing: narrow from [16, 2048, 2] to [8, 2048, 2]
        expected_scale_zero_slice_1 = dummy.weight.scale_and_zero.narrow(0, 0, 8)
        self.assertEqual(weight2.scale_and_zero, expected_scale_zero_slice_1)

        # Verify that sliced weights produce similar results to the reference implementations
        input1 = torch.randn(2, 2048, dtype=dtype, device=device)
        res_ref1 = dummy1(input1)

        # Create a new linear layer with the sliced weight
        test_linear1 = torch.nn.Linear(
            2048, 1024, bias=False, dtype=dtype, device=device
        )
        test_linear1.weight = torch.nn.Parameter(
            weight1.contiguous(), requires_grad=False
        )
        res1 = test_linear1(input1)
        self.assertGreater(compute_error(res_ref1, res1), 14)

        input2 = torch.randn(2, 1024, dtype=dtype, device=device)
        res_ref2 = dummy2(input2)

        # Create a new linear layer with the sliced weight
        test_linear2 = torch.nn.Linear(
            1024, 2048, bias=False, dtype=dtype, device=device
        )
        test_linear2.weight = torch.nn.Parameter(
            weight2.contiguous(), requires_grad=False
        )
        res2 = test_linear2(input2)
        self.assertGreater(compute_error(res_ref2, res2), 14)

    def test_slice_preserves_aliasing(self):
        config = INT4_CONFIG
        l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
        l.weight = torch.nn.Parameter(
            torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
        )
        quantize_(l, config)
        param = l.weight
        param_data = param.data
        param_data = param_data.narrow(0, 0, 512)
        # Make sure the aliasing is preserved in the sliced quantized Tensor
        assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr()
        assert (
            param.data.scale_and_zero.data_ptr() == param_data.scale_and_zero.data_ptr()
        )

    def test_cant_initialize_in_cpu(self):
        config = INT4_CONFIG
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
        # make sure there is no CPU implementation of the packing op currently
        with self.assertRaisesRegex(
            NotImplementedError,
            "Could not run 'aten::_convert_weight_to_int4pack' with arguments from the 'CPU' backend. ",
        ):
            quantize_(linear, config)

    def test_to_device(self):
        # test that calling `to` on a tensor that's already on the same device works
        config = INT4_CONFIG

        for device in self.GPU_DEVICES:
            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device)
            quantize_(linear, config)
            linear.to(device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device)
            quantize_(linear, config)
            linear.to(device=device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device)
            quantize_(linear, config)
            linear.to(device)

    def test_slice_and_copy_similar_to_vllm(self):
        self._test_slice_and_copy_similar_to_vllm(INT4_CONFIG)

    @parametrize("device", ["cuda"])
    @parametrize("dtype", [torch.bfloat16])
    def test_mm_int4wo(self, device, dtype):
        weight = torch.randn(512, 1024).to(device).to(dtype)
        weight = weight.t()

        l = torch.nn.Linear(512, 1024).to(device).to(dtype)
        l.weight = torch.nn.Parameter(weight)
        quantize_(l, INT4_CONFIG)
        # weight shape: 1024 x 512
        weight = l.weight

        input = torch.randn(1, 512, device=device, dtype=dtype)
        # make sure it runs
        torch.nn.functional.linear(input, weight)

    @parametrize("group_size", [32, 64, 128])
    def test_different_group_sizes(self, group_size):
        """Test with different group sizes"""
        dtype = torch.bfloat16
        device = "cuda"
        hp_tensor = torch.randn(256, 512, dtype=dtype, device=device)
        block_size = (1, group_size)

        tensor = Int4TilePackedTo4dTensor.from_hp(hp_tensor, block_size)

        self.assertEqual(tensor.shape, hp_tensor.shape)
        self.assertEqual(tensor.block_size, block_size)

    def test_error_conditions(self):
        """Test various error conditions"""
        dtype = torch.bfloat16
        device = "cuda"
        hp_tensor = torch.randn(128, 256, dtype=dtype, device=device)

        # Test invalid block_size length
        with self.assertRaises(AssertionError):
            Int4TilePackedTo4dTensor.from_hp(
                hp_tensor, (64,)
            )  # block_size length mismatch

        # Test non-groupwise quantization
        with self.assertRaises(AssertionError):
            Int4TilePackedTo4dTensor.from_hp(
                hp_tensor, (2, 64)
            )  # first element should be 1


instantiate_parametrized_tests(TestInt4TilePackedTo4dTensor)


if __name__ == "__main__":
    run_tests()
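The shape bookkeeping exercised in test_slice can be summarized in a small helper. This is only a sketch of the arithmetic stated in the test comments above; inner_k_tiles=8 is the value those comments assume, not a parameter documented here:

def tile_packed_to_4d_shapes(n, k, group_size=128, inner_k_tiles=8):
    """Expected component shapes for an (n, k) weight, per the comments in test_slice.

    qdata:          [n / 8, k / (inner_k_tiles * 16), 32, inner_k_tiles / 2]
    scale_and_zero: [k / group_size, n, 2]
    """
    qdata_shape = (n // 8, k // (inner_k_tiles * 16), 32, inner_k_tiles // 2)
    scale_and_zero_shape = (k // group_size, n, 2)
    return qdata_shape, scale_and_zero_shape


# For the 2048 x 2048 weight in test_slice:
#   qdata          -> (256, 16, 32, 4)
#   scale_and_zero -> (16, 2048, 2)
assert tile_packed_to_4d_shapes(2048, 2048) == ((256, 16, 32, 4), (16, 2048, 2))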

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -93,6 +93,7 @@
     Int4MarlinSparseTensor,
     Int4PreshuffledTensor,
     Int4Tensor,
+    Int4TilePackedTo4dTensor,
     IntxUnpackedToInt8Tensor,
 )
 from .smoothquant import (
@@ -163,6 +164,7 @@
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
     "IntxUnpackedToInt8Tensor",
+    "Int4TilePackedTo4dTensor",
     "Float8Tensor",
     # smooth quant - subject to change
     "get_scale",

torchao/quantization/quant_api.py

Lines changed: 15 additions & 3 deletions
@@ -75,6 +75,7 @@
     Int4MarlinSparseTensor,
     Int4PreshuffledTensor,
     Int4Tensor,
+    Int4TilePackedTo4dTensor,
     IntxUnpackedToInt8Tensor,
     QuantizeTensorToFloat8Kwargs,
 )
@@ -1120,6 +1121,12 @@ def _int4_weight_only_quantize_tensor(weight, config):
             block_size,
         )
         return new_weight
+    elif packing_format == PackingFormat.TILE_PACKED_TO_4D:
+        new_weight = Int4TilePackedTo4dTensor.from_hp(
+            weight,
+            block_size,
+        )
+        return new_weight
     else:
         raise ValueError(f"Unsupported packing format: {packing_format}")

@@ -1494,10 +1501,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
+    warnings.warn(
+        """int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.

    from torchao.dtypes import SemiSparseLayout
-    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""")
+    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()"""
+    )

     return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())

@@ -2073,7 +2082,10 @@ def __post_init__(self):
         assert self.granularity.axis == 0, (
             f"axis must be 0 with PerAxis, but got {self.granularity.axis}"
         )
-        assert self.mapping_type in [MappingType.ASYMMETRIC, MappingType.SYMMETRIC], (
+        assert self.mapping_type in [
+            MappingType.ASYMMETRIC,
+            MappingType.SYMMETRIC,
+        ], (
             f"mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.mapping_type}"
         )
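As the dispatch above shows, Int4WeightOnlyConfig version 2 routes PackingFormat.TILE_PACKED_TO_4D weights through Int4TilePackedTo4dTensor.from_hp. A minimal sketch of constructing the tensor directly, mirroring test_different_group_sizes (the shape and group size are illustrative):

import torch
from torchao.quantization.quantize_.workflows.int4.int4_tile_packed_to_4d_tensor import (
    Int4TilePackedTo4dTensor,
)

# Group-wise int4 quantization along the input-feature dimension; the first
# block_size element must be 1, per the assertions exercised in test_error_conditions.
hp_weight = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
block_size = (1, 128)  # group_size = 128

w_int4 = Int4TilePackedTo4dTensor.from_hp(hp_weight, block_size)
assert w_int4.shape == hp_weight.shape
assert w_int4.block_size == block_size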

torchao/quantization/quantize_/common/packing_format.py

Lines changed: 5 additions & 0 deletions
@@ -41,6 +41,11 @@ class PackingFormat(str, Enum):
     """
     UNPACKED_TO_INT8 = "unpacked_to_int8"

+    """
+    tile_packed_to_4d is referring to the format used by tinygemm kernels for int4 quantization
+    """
+    TILE_PACKED_TO_4D = "tile_packed_to_4d"
+
     """
     Opaque packing format that's used for tensors that does not have a predefined packing format
     (that may be decided on hardware, tensor shape, library availability etc.) and it's not
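The tile-packed layout exists to feed the tinygemm kernel named in the commit message, aten::_weight_int4pack_mm. A rough sketch of the matmul the quantized linear ultimately lowers to; argument dtypes and packing details vary across PyTorch versions, so treat this as illustrative rather than the exact torchao call:

import torch

def int4_tinygemm_mm(activation, qdata, scale_and_zero, group_size=128):
    # activation: [m, k] bf16; qdata: the 4D tile-packed int4 weight
    # (Int4TilePackedTo4dTensor.qdata); scale_and_zero: [k // group_size, n, 2].
    # Assumed shapes mirror the comments in test_slice above.
    return torch.ops.aten._weight_int4pack_mm(activation, qdata, group_size, scale_and_zero)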

torchao/quantization/quantize_/workflows/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,11 +14,13 @@
 from .intx.intx_unpacked_to_int8_tensor import (
     IntxUnpackedToInt8Tensor,
 )
+from .int4.int4_tile_packed_to_4d_tensor import Int4TilePackedTo4dTensor

 __all__ = [
     "Int4Tensor",
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
+    "Int4TilePackedTo4dTensor",
     "Float8Tensor",
     "QuantizeTensorToFloat8Kwargs",
     "IntxUnpackedToInt8Tensor",
Lines changed: 0 additions & 7 deletions
@@ -1,7 +0,0 @@
-from .int4_preshuffled_tensor import Int4PreshuffledTensor
-from .int4_tensor import Int4Tensor
-
-__all__ = [
-    "Int4PreshuffledTensor",
-    "Int4Tensor",
-]

0 commit comments
