
Commit 1922aaf

Add Int4TensorCoreTilePackedTensor for tensor core tiled int4 quantization
This commit introduces Int4TensorCoreTilePackedTensor, a new tensor subclass for int4 weight-only quantization using tensor core tiled packing format.

Key features:
- Implements tensor core tiled packing for efficient computation on tensor cores
- Uses tinygemm quantization path instead of HQQ for consistency
- Supports PackingFormat.TENSOR_CORE_TILE_PACKED in Int4WeightOnlyConfig version 2
- Optimized for tinygemm int4mm kernel (_weight_int4pack_mm)
- Includes comprehensive test suite

The implementation follows the same pattern as other int4 tensor subclasses but uses a specialized packing format optimized for tensor core matrix multiplication performance.

Changes:
- Add Int4TensorCoreTilePackedTensor implementation
- Update Int4WeightOnlyConfig version 2 to support TENSOR_CORE_TILE_PACKED packing format
- Add TENSOR_CORE_TILE_PACKED to PackingFormat enum
- Replace HQQ quantization with _quantize_affine_tinygemm for consistency
- Add comprehensive tests including serialization, different group sizes, and error conditions
- Update __init__.py files to export new tensor class

Test: python test/quantization/quantize_/workflows/int4/test_int4_tensor_core_tile_packed_tensor.py
1 parent af2cf1e commit 1922aaf
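
Note: the end-to-end usage exercised by the new test suite looks roughly like the sketch below. The config mirrors TENSOR_CORE_TILED_CONFIG from the test file; the layer and input sizes are illustrative only.

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.quantize_.common.packing_format import PackingFormat

# Int4 weight-only quantization with the tensor-core-tiled packing format;
# version=2 selects the tensor-subclass based flow.
config = Int4WeightOnlyConfig(
    group_size=128,
    packing_format=PackingFormat.TENSOR_CORE_TILE_PACKED,
    version=2,
)

# bfloat16 on CUDA, as in the tests.
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
x = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda")

quantize_(linear, config)  # linear.weight becomes an Int4TensorCoreTilePackedTensor
out = linear(x)  # intended to hit the tinygemm int4mm kernel (_weight_int4pack_mm), per the commit message
out_compiled = torch.compile(linear)(x)  # the tests also cover torch.compile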

File tree

8 files changed (+600, -73 lines)
test/quantization/quantize_/workflows/int4/test_int4_tensor_core_tile_packed_tensor.py

Lines changed: 222 additions & 0 deletions
@@ -0,0 +1,222 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.quantize_.common.packing_format import PackingFormat
from torchao.quantization.quantize_.workflows.int4.int4_tensor_core_tile_packed_tensor import (
    Int4TensorCoreTilePackedTensor,
)
from torchao.quantization.utils import compute_error
from torchao.testing.utils import TorchAOIntegrationTestCase
from torchao.utils import is_sm_at_least_90

TENSOR_CORE_TILED_CONFIG = Int4WeightOnlyConfig(
    group_size=128,
    packing_format=PackingFormat.TENSOR_CORE_TILE_PACKED,
    version=2,
)


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
class TestInt4TensorCoreTilePackedTensor(TorchAOIntegrationTestCase):
    def setUp(self):
        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

    @parametrize(
        "sizes",
        [
            ((128,), 256, 128),
            ((32, 128), 512, 128),
            ((2, 32, 128), 256, 128),
        ],
    )
    def test_linear(self, sizes):
        config = TENSOR_CORE_TILED_CONFIG
        dtype = torch.bfloat16
        device = "cuda"

        M, N, K = sizes
        input = torch.randn(*M, K, dtype=dtype, device=device)
        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)

        original = linear(input)
        quantize_(linear, config)
        quantized = linear(input)
        self.assertTrue(compute_error(original, quantized) > 20)

        compiled_linear = torch.compile(linear)
        quantized_and_compiled = compiled_linear(input)
        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)

    def test_module_path(self):
        config = TENSOR_CORE_TILED_CONFIG
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
        quantize_(linear.cuda(), config)
        self.assertEqual(
            str(type(linear.weight)),
            "<class 'torchao.quantization.Int4TensorCoreTilePackedTensor'>",
        )

        with tempfile.NamedTemporaryFile() as f:
            torch.save(linear.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f)
            self.assertEqual(
                str(type(state_dict["weight"])),
                "<class 'torchao.quantization.Int4TensorCoreTilePackedTensor'>",
            )

    def test_slice(self):
        config = TENSOR_CORE_TILED_CONFIG
        dtype = torch.bfloat16
        device = "cuda"

        # Create a 256x256 linear layer for testing
        dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)

        # Create reference sliced linear layers
        dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
        dummy1.weight = torch.nn.Parameter(
            dummy.weight.narrow(0, 0, 64), requires_grad=False
        )
        dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
        dummy2.weight = torch.nn.Parameter(
            dummy.weight.narrow(1, 0, 128), requires_grad=False
        )

        # Quantize the main linear layer
        quantize_(dummy, config)

        # Shape analysis for tensor core tile packed format:
        # Original weight shape: (256, 256) -> after padding: (256, 1024)
        # n = 256, k = 1024, inner_k_tiles = 8, group_size = 128
        #
        # qdata shape: [n/8, k/(inner_k_tiles*16), 32, inner_k_tiles/2]
        #            = [256/8, 1024/(8*16), 32, 8/2]
        #            = [32, 8, 32, 4]
        #
        # scale_and_zero shape: [in_features/group_size, out_features, 2] (packed format)
        #                     = [1024/128, 256, 2] = [8, 256, 2]

        # Test slicing along output dimension (dim=0: 256 -> 64)
        weight1 = dummy.weight.narrow(0, 0, 64)

        # qdata slicing: narrow from [32, 8, 32, 4] to [8, 8, 32, 4]
        # Calculation: 64 out_features / 256 total * 32 qdata_dim0 = 8
        expected_qdata_slice_0 = dummy.weight.qdata.narrow(0, 0, 8)
        self.assertEqual(weight1.qdata, expected_qdata_slice_0)

        # scale_and_zero slicing: narrow from [8, 256, 2] to [8, 64, 2]
        # slicing the 0th dim of qdata means we have to slice the 1st dim of scale_and_zero
        expected_scale_zero_slice_0 = dummy.weight.scale_and_zero.narrow(1, 0, 64)
        self.assertEqual(weight1.scale_and_zero, expected_scale_zero_slice_0)

        # Test slicing along input dimension (dim=1: 256 -> 128)
        weight2 = dummy.weight.narrow(1, 0, 128)

        # qdata slicing: narrow from [32, 8, 32, 4] to [32, 4, 32, 4]
        # k = 1024
        # Calculation: 128 in_features (1/2 of in_features) corresponds to 1/2 of qdata dimension 1,
        # which is k / (inner_k_tiles * 16) / 2 = 1024 / (8 * 16) / 2 = 4
        expected_qdata_slice_1 = dummy.weight.qdata.narrow(1, 0, 4)
        self.assertEqual(weight2.qdata, expected_qdata_slice_1)

        # scale_and_zero slicing: narrow from [8, 256, 2] to [4, 256, 2]
        expected_scale_zero_slice_1 = dummy.weight.scale_and_zero.narrow(0, 0, 4)
        self.assertEqual(weight2.scale_and_zero, expected_scale_zero_slice_1)

        # Verify that sliced weights produce similar results to reference implementations
        input1 = torch.randn(2, 256, dtype=dtype, device=device)
        res_ref1 = dummy1(input1)

        # Create a new linear layer with the sliced weight
        test_linear1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
        test_linear1.weight = torch.nn.Parameter(
            weight1.contiguous(), requires_grad=False
        )
        res1 = test_linear1(input1)
        self.assertGreater(compute_error(res_ref1, res1), 15)

        # input2 = torch.randn(2, 128, dtype=dtype, device=device)
        # res_ref2 = dummy2(input2)

        # Create a new linear layer with the sliced weight
        # WIP
        # test_linear2 = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
        # test_linear2.weight = torch.nn.Parameter(
        #     weight2.contiguous(), requires_grad=False
        # )
        # res2 = test_linear2(input2)
        # self.assertGreater(compute_error(res_ref2, res2), 15)

    def test_slice_preserves_aliasing(self):
        config = TENSOR_CORE_TILED_CONFIG
        l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
        l.weight = torch.nn.Parameter(
            torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
        )
        quantize_(l, config)
        param = l.weight
        param_data = param.data
        param_data = param_data.narrow(0, 0, 512)
        # Making sure the aliasing is preserved in sliced quantized Tensor
        assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr()
        assert (
            param.data.scale_and_zero.data_ptr() == param_data.scale_and_zero.data_ptr()
        )

    def test_slice_and_copy_similar_to_vllm(self):
        self._test_slice_and_copy_similar_to_vllm(TENSOR_CORE_TILED_CONFIG)

    @parametrize("group_size", [32, 64, 128])
    def test_different_group_sizes(self, group_size):
        """Test with different group sizes"""
        dtype = torch.bfloat16
        device = "cuda"
        hp_tensor = torch.randn(256, 512, dtype=dtype, device=device)
        block_size = (1, group_size)

        tensor = Int4TensorCoreTilePackedTensor.from_hp(hp_tensor, block_size)

        self.assertEqual(tensor.shape, hp_tensor.shape)
        self.assertEqual(tensor.block_size, block_size)

    def test_error_conditions(self):
        """Test various error conditions"""
        dtype = torch.bfloat16
        device = "cuda"
        hp_tensor = torch.randn(128, 256, dtype=dtype, device=device)

        # Test invalid block_size length
        with self.assertRaises(AssertionError):
            Int4TensorCoreTilePackedTensor.from_hp(
                hp_tensor, (64,)
            )  # block_size length mismatch

        # Test non-groupwise quantization
        with self.assertRaises(AssertionError):
            Int4TensorCoreTilePackedTensor.from_hp(
                hp_tensor, (2, 64)
            )  # first element should be 1


instantiate_parametrized_tests(TestInt4TensorCoreTilePackedTensor)


if __name__ == "__main__":
    run_tests()
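
Note: the shape bookkeeping in the test_slice comments can be checked with a few lines of arithmetic. The sketch below only restates the layout documented above (qdata: [n/8, k/(inner_k_tiles*16), 32, inner_k_tiles/2]; scale_and_zero: [k/group_size, n, 2]); the helper name is hypothetical and not part of torchao.

def expected_packed_shapes(n, k, group_size=128, inner_k_tiles=8):
    # Shape arithmetic for the tensor-core-tiled layout as documented in the
    # test_slice comments above (hypothetical helper, not a torchao API).
    qdata_shape = (n // 8, k // (inner_k_tiles * 16), 32, inner_k_tiles // 2)
    scale_and_zero_shape = (k // group_size, n, 2)
    return qdata_shape, scale_and_zero_shape

# The test_slice example: a (256, 256) weight padded along k to (256, 1024).
qdata_shape, scale_and_zero_shape = expected_packed_shapes(n=256, k=1024)
assert qdata_shape == (32, 8, 32, 4)
assert scale_and_zero_shape == (8, 256, 2)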

torchao/quantization/__init__.py

Lines changed: 7 additions & 23 deletions
@@ -1,7 +1,4 @@
-from torchao.kernel import (
-    int_scaled_matmul,
-    safe_int_mm,
-)
+from torchao.kernel import int_scaled_matmul, safe_int_mm
 
 from .autoquant import (
     ALL_AUTOQUANT_CLASS_LIST,
@@ -13,18 +10,8 @@
     OTHER_AUTOQUANT_CLASS_LIST,
     autoquant,
 )
-from .GPTQ import (
-    Int4WeightOnlyGPTQQuantizer,
-    MultiTensor,
-    MultiTensorInputRecorder,
-)
-from .granularity import (
-    PerAxis,
-    PerGroup,
-    PerRow,
-    PerTensor,
-    PerToken,
-)
+from .GPTQ import Int4WeightOnlyGPTQQuantizer, MultiTensor, MultiTensorInputRecorder
+from .granularity import PerAxis, PerGroup, PerRow, PerTensor, PerToken
 from .linear_activation_quantized_tensor import (
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
@@ -37,10 +24,7 @@
     Int8DynActInt4WeightLinear,
     Int8DynActInt4WeightQuantizer,
 )
-from .observer import (
-    AffineQuantizedMinMaxObserver,
-    AffineQuantizedObserverBase,
-)
+from .observer import AffineQuantizedMinMaxObserver, AffineQuantizedObserverBase
 from .quant_api import (
     CutlassInt4PackedLayout,
     FbgemmConfig,
@@ -94,6 +78,7 @@
     Int4PreshuffledTensor,
     Int4Tensor,
     IntxUnpackedTensor,
+    Int4TensorCoreTilePackedTensor,
 )
 from .smoothquant import (
     SmoothFakeDynamicallyQuantizedLinear,
@@ -106,9 +91,7 @@
 from .subclass import *  # noqa: F403
 from .transform_module import register_quantize_module_handler
 from .unified import Quantizer, TwoStepQuantizer
-from .utils import (
-    compute_error,
-)
+from .utils import compute_error
 from .weight_only import WeightOnlyInt8QuantLinear
 
 # TODO: remove after migration of APIs are done
@@ -163,6 +146,7 @@
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
     "IntxUnpackedTensor",
+    "Int4TensorCoreTilePackedTensor",
     "Float8Tensor",
     # smooth quant - subject to change
     "get_scale",

torchao/quantization/quant_api.py

Lines changed: 22 additions & 32 deletions
@@ -66,16 +66,14 @@
     LinearActivationWeightObservedTensor,
 )
 from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size
-from torchao.quantization.quantize_.common import (
-    KernelPreference,
-    PackingFormat,
-)
+from torchao.quantization.quantize_.common import KernelPreference, PackingFormat
 from torchao.quantization.quantize_.workflows import (
     Float8Tensor,
     Int4MarlinSparseTensor,
     Int4PreshuffledTensor,
     Int4Tensor,
     IntxUnpackedTensor,
+    Int4TensorCoreTilePackedTensor,
     QuantizeTensorToFloat8Kwargs,
 )
 from torchao.quantization.transform_module import (
@@ -93,35 +91,16 @@
 )
 
 from .autoquant import AutoQuantizableLinearWeight, autoquant
-from .GPTQ import (
-    Int4WeightOnlyGPTQQuantizer,
-)
-from .granularity import (
-    Granularity,
-    PerAxis,
-    PerGroup,
-    PerRow,
-    PerTensor,
-)
+from .GPTQ import Int4WeightOnlyGPTQQuantizer
+from .granularity import Granularity, PerAxis, PerGroup, PerRow, PerTensor
 from .linear_activation_quantized_tensor import (
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
-from .linear_quant_modules import (
-    Int4WeightOnlyQuantizer,
-    Int8DynActInt4WeightQuantizer,
-)
-from .qat import (
-    intx_quantization_aware_training,
-)
-from .quant_primitives import (
-    _DTYPE_TO_QVALUE_BOUNDS,
-    MappingType,
-    ZeroPointDomain,
-)
-from .subclass import (
-    QuantizedLinearWeightBase,
-)
+from .linear_quant_modules import Int4WeightOnlyQuantizer, Int8DynActInt4WeightQuantizer
+from .qat import intx_quantization_aware_training
+from .quant_primitives import _DTYPE_TO_QVALUE_BOUNDS, MappingType, ZeroPointDomain
+from .subclass import QuantizedLinearWeightBase
 from .unified import Quantizer, TwoStepQuantizer
 from .utils import _get_per_token_block_size
 
@@ -1080,6 +1059,12 @@ def _int4_weight_only_quantize_tensor(weight, config):
             block_size,
         )
         return new_weight
+    elif packing_format == PackingFormat.TENSOR_CORE_TILE_PACKED:
+        new_weight = Int4TensorCoreTilePackedTensor.from_hp(
+            weight,
+            block_size,
+        )
+        return new_weight
     else:
         raise ValueError(f"Unsupported packing format: {packing_format}")
 
@@ -1454,10 +1439,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
+    warnings.warn(
+        """int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
 
     from torchao.dtypes import SemiSparseLayout
-    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""")
+    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()"""
+    )
 
     return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
 
@@ -2007,7 +1994,10 @@ def __post_init__(self):
         assert self.granularity.axis == 0, (
             f"axis must be 0 with PerAxis, but got {self.granularity.axis}"
         )
-        assert self.mapping_type in [MappingType.ASYMMETRIC, MappingType.SYMMETRIC], (
+        assert self.mapping_type in [
+            MappingType.ASYMMETRIC,
+            MappingType.SYMMETRIC,
+        ], (
             f"mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.mapping_type}"
         )
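
Note: the block_size passed into from_hp in the dispatch above comes from the config's group_size; for a 2D weight it is effectively (1, group_size), which is also how the tests construct the tensor directly. A rough sketch of that direct construction (sizes illustrative, mirroring test_different_group_sizes):

import torch
from torchao.quantization.quantize_.workflows.int4.int4_tensor_core_tile_packed_tensor import (
    Int4TensorCoreTilePackedTensor,
)

group_size = 128
hp_weight = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")

# Group-wise quantization along the input dimension: one block per row,
# group_size columns per block. from_hp preserves the logical weight shape.
block_size = (1, group_size)
w_int4 = Int4TensorCoreTilePackedTensor.from_hp(hp_weight, block_size)
assert w_int4.shape == hp_weight.shape
assert w_int4.block_size == block_size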
