Commit 0b18c16

Add Int4TensorCoreTilePackedTensor for tensor core tiled int4 quantization
This commit introduces Int4TensorCoreTilePackedTensor, a new tensor subclass for int4 weight-only quantization using the tensor core tiled packing format.

Key features:
- Implements tensor core tiled packing for efficient computation on tensor cores
- Uses the tinygemm quantization path instead of HQQ for consistency
- Supports PackingFormat.TENSOR_CORE_TILE_PACKED in Int4WeightOnlyConfig version 2
- Optimized for the tinygemm int4mm kernel (_weight_int4pack_mm)
- Includes a comprehensive test suite

The implementation follows the same pattern as other int4 tensor subclasses but uses a specialized packing format optimized for tensor core matrix multiplication performance.

Changes:
- Add the Int4TensorCoreTilePackedTensor implementation
- Update Int4WeightOnlyConfig version 2 to support the TENSOR_CORE_TILE_PACKED packing format
- Add TENSOR_CORE_TILE_PACKED to the PackingFormat enum
- Replace HQQ quantization with _quantize_affine_tinygemm for consistency
- Add comprehensive tests covering serialization, different group sizes, and error conditions
- Update __init__.py files to export the new tensor class
1 parent 69e71d9 commit 0b18c16
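For reference, a minimal usage sketch distilled from the new test suite (the layer shapes and group size are illustrative, not requirements):

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.quantize_.common.packing_format import PackingFormat

# version 2 config with the new tensor-core-tiled packing format
config = Int4WeightOnlyConfig(
    group_size=128,
    packing_format=PackingFormat.TENSOR_CORE_TILE_PACKED,
    version=2,
)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, config)
# the weight is now an Int4TensorCoreTilePackedTensor (see test_module_path below)
output = linear(torch.randn(32, 128, dtype=torch.bfloat16, device="cuda"))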

9 files changed (+923, -65 lines)
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

from torchao.quantization import (
    Int4WeightOnlyConfig,
    quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.sparsity.sparse_api import apply_fake_sparsity
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_8,
)

BF16_ACT_CONFIG = Int4WeightOnlyConfig(
    group_size=128,
    packing_format="marlin_sparse",
    version=2,
)


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
class TestInt4MarlinSparseTensor(TestCase):
    def setUp(self):
        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

    @parametrize("config", [BF16_ACT_CONFIG])
    @parametrize(
        "sizes",
        [
            ((128,), 256, 128),
            ((32, 128), 512, 128),
            ((2, 32, 128), 256, 128),
        ],
    )
    def test_linear(self, config, sizes):
        dtype = torch.float16
        device = "cuda"

        M, N, K = sizes
        input = torch.randn(*M, K, dtype=dtype, device=device)
        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)

        apply_fake_sparsity(linear)
        original = linear(input)
        quantize_(linear, config)
        quantized = linear(input)
        self.assertTrue(compute_error(original, quantized) > 20)

        compiled_linear = torch.compile(linear)
        quantized_and_compiled = compiled_linear(input)
        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)

    @unittest.skip("Fix later")
    @parametrize("config", [BF16_ACT_CONFIG])
    def test_to_device(self, config):
        for device in self.GPU_DEVICES:
            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, config)
            linear.to(device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, config)
            linear.to(device=device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, config)
            linear.to(device)

    @parametrize("config", [BF16_ACT_CONFIG])
    def test_module_path(self, config):
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
        quantize_(linear.cuda(), config)
        self.assertEqual(
            str(type(linear.weight)),
            "<class 'torchao.quantization.Int4MarlinSparseTensor'>",
        )

        with tempfile.NamedTemporaryFile() as f:
            torch.save(linear.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f)
            self.assertEqual(
                str(type(state_dict["weight"])),
                "<class 'torchao.quantization.Int4MarlinSparseTensor'>",
            )


instantiate_parametrized_tests(TestInt4MarlinSparseTensor)


if __name__ == "__main__":
    run_tests()
Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.quantize_.common.packing_format import PackingFormat
from torchao.quantization.quantize_.workflows.int4.int4_tensor_core_tile_packed_tensor import (
    Int4TensorCoreTilePackedTensor,
)
from torchao.quantization.utils import compute_error
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4

TENSOR_CORE_TILED_CONFIG = Int4WeightOnlyConfig(
    group_size=128,
    packing_format=PackingFormat.TENSOR_CORE_TILE_PACKED,
    version=2,
)


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Need pytorch 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
class TestInt4TensorCoreTilePackedTensor(TestCase):
    def setUp(self):
        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

    @parametrize("config", [TENSOR_CORE_TILED_CONFIG])
    @parametrize(
        "sizes",
        [
            ((128,), 256, 128),
            ((32, 128), 512, 128),
            ((2, 32, 128), 256, 128),
        ],
    )
    def test_linear(self, config, sizes):
        dtype = torch.bfloat16
        device = "cuda"

        M, N, K = sizes
        input = torch.randn(*M, K, dtype=dtype, device=device)
        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)

        original = linear(input)
        quantize_(linear, config)
        quantized = linear(input)
        self.assertTrue(compute_error(original, quantized) > 1)

        compiled_linear = torch.compile(linear)
        quantized_and_compiled = compiled_linear(input)
        self.assertTrue(compute_error(original, quantized_and_compiled) > 1)

    def test_from_hp(self):
        """Test creating Int4TensorCoreTilePackedTensor from high precision tensor"""
        dtype = torch.bfloat16
        device = "cuda"
        hp_tensor = torch.randn(256, 128, dtype=dtype, device=device)
        block_size = (1, 64)

        tensor = Int4TensorCoreTilePackedTensor.from_hp(hp_tensor, block_size)

        self.assertEqual(tensor.shape, hp_tensor.shape)
        self.assertEqual(tensor.block_size, block_size)
        self.assertEqual(tensor.device.type, device)
        self.assertEqual(tensor.dtype, dtype)

    @parametrize("config", [TENSOR_CORE_TILED_CONFIG])
    def test_to_device(self, config):
        for device in self.GPU_DEVICES:
            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear.cuda(), config)
            linear.to(device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear.cuda(), config)
            linear.to(device=device)

    @parametrize("config", [TENSOR_CORE_TILED_CONFIG])
    def test_module_path(self, config):
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
        quantize_(linear.cuda(), config)
        self.assertEqual(
            str(type(linear.weight)),
            "<class 'torchao.quantization.Int4TensorCoreTilePackedTensor'>",
        )

    def test_serialization(self):
        """Test saving and loading the tensor directly and via state_dict"""
        dtype = torch.bfloat16
        device = "cuda"
        hp_tensor = torch.randn(128, 256, dtype=dtype, device=device)
        block_size = (1, 64)

        tensor = Int4TensorCoreTilePackedTensor.from_hp(hp_tensor, block_size)

        # Test direct tensor serialization
        with tempfile.NamedTemporaryFile() as f:
            torch.save(tensor, f)
            f.seek(0)
            loaded_tensor = torch.load(f)

        self.assertEqual(loaded_tensor.shape, tensor.shape)
        self.assertEqual(loaded_tensor.block_size, tensor.block_size)
        self.assertEqual(
            str(type(loaded_tensor)),
            "<class 'torchao.quantization.Int4TensorCoreTilePackedTensor'>",
        )

        # Test state_dict serialization
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
        quantize_(linear.cuda(), TENSOR_CORE_TILED_CONFIG)

        with tempfile.NamedTemporaryFile() as f:
            torch.save(linear.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f)
            self.assertEqual(
                str(type(state_dict["weight"])),
                "<class 'torchao.quantization.Int4TensorCoreTilePackedTensor'>",
            )

    @parametrize("group_size", [32, 64, 128])
    def test_different_group_sizes(self, group_size):
        """Test with different group sizes"""
        dtype = torch.bfloat16
        device = "cuda"
        hp_tensor = torch.randn(256, 512, dtype=dtype, device=device)
        block_size = (1, group_size)

        tensor = Int4TensorCoreTilePackedTensor.from_hp(hp_tensor, block_size)

        self.assertEqual(tensor.shape, hp_tensor.shape)
        self.assertEqual(tensor.block_size, block_size)

    def test_error_conditions(self):
        """Test various error conditions"""
        dtype = torch.bfloat16
        device = "cuda"
        hp_tensor = torch.randn(128, 256, dtype=dtype, device=device)

        # Test invalid block_size length
        with self.assertRaises(AssertionError):
            Int4TensorCoreTilePackedTensor.from_hp(
                hp_tensor, (64,)
            )  # block_size length mismatch

        # Test non-groupwise quantization
        with self.assertRaises(AssertionError):
            Int4TensorCoreTilePackedTensor.from_hp(
                hp_tensor, (2, 64)
            )  # first element should be 1


instantiate_parametrized_tests(TestInt4TensorCoreTilePackedTensor)


if __name__ == "__main__":
    run_tests()
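The packing format exists to feed the tinygemm kernel named in the commit message. Below is a hedged sketch of that low-level path, assuming the PyTorch 2.4+ CUDA semantics where torch._convert_weight_to_int4pack takes a uint8 tensor holding two int4 values per byte; the random weights, nibble order, and shapes here are illustrative only, not the subclass's actual packing internals:

import torch

N, K, group_size, inner_k_tiles = 256, 128, 64, 8  # K must be divisible by inner_k_tiles * 16

# fake int4 weight values, packed two per byte into a [N, K // 2] uint8 tensor
w_q = torch.randint(0, 16, (N, K), dtype=torch.uint8, device="cuda")
w_bytes = (w_q[:, ::2] << 4 | w_q[:, 1::2]).contiguous()

# repack into the tensor-core-tiled layout the kernel expects
w_tc = torch._convert_weight_to_int4pack(w_bytes, inner_k_tiles)

# per-group scales and zero points, shape [K // group_size, N, 2]
scales_and_zeros = torch.ones(K // group_size, N, 2, dtype=torch.bfloat16, device="cuda")

x = torch.randn(8, K, dtype=torch.bfloat16, device="cuda")
y = torch._weight_int4pack_mm(x, w_tc, group_size, scales_and_zeros)  # shape [8, N]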

torchao/quantization/__init__.py

Lines changed: 9 additions & 23 deletions
@@ -1,7 +1,4 @@
-from torchao.kernel import (
-    int_scaled_matmul,
-    safe_int_mm,
-)
+from torchao.kernel import int_scaled_matmul, safe_int_mm
 
 from .autoquant import (
     ALL_AUTOQUANT_CLASS_LIST,
@@ -13,18 +10,8 @@
     OTHER_AUTOQUANT_CLASS_LIST,
     autoquant,
 )
-from .GPTQ import (
-    Int4WeightOnlyGPTQQuantizer,
-    MultiTensor,
-    MultiTensorInputRecorder,
-)
-from .granularity import (
-    PerAxis,
-    PerGroup,
-    PerRow,
-    PerTensor,
-    PerToken,
-)
+from .GPTQ import Int4WeightOnlyGPTQQuantizer, MultiTensor, MultiTensorInputRecorder
+from .granularity import PerAxis, PerGroup, PerRow, PerTensor, PerToken
 from .linear_activation_quantized_tensor import (
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
@@ -37,10 +24,7 @@
     Int8DynActInt4WeightLinear,
     Int8DynActInt4WeightQuantizer,
 )
-from .observer import (
-    AffineQuantizedMinMaxObserver,
-    AffineQuantizedObserverBase,
-)
+from .observer import AffineQuantizedMinMaxObserver, AffineQuantizedObserverBase
 from .quant_api import (
     CutlassInt4PackedLayout,
     FbgemmConfig,
@@ -90,8 +74,10 @@
 )
 from .quantize_.workflows import (
     Float8Tensor,
+    Int4MarlinSparseTensor,
     Int4PreshuffledTensor,
     Int4Tensor,
+    Int4TensorCoreTilePackedTensor,
 )
 from .smoothquant import (
     SmoothFakeDynamicallyQuantizedLinear,
@@ -104,9 +90,7 @@
 from .subclass import *  # noqa: F403
 from .transform_module import register_quantize_module_handler
 from .unified import Quantizer, TwoStepQuantizer
-from .utils import (
-    compute_error,
-)
+from .utils import compute_error
 from .weight_only import WeightOnlyInt8QuantLinear
 
 # TODO: remove after migration of APIs are done
@@ -159,6 +143,8 @@
     # tensor subclasses
     "Int4Tensor",
     "Int4PreshuffledTensor",
+    "Int4MarlinSparseTensor",
+    "Int4TensorCoreTilePackedTensor",
     "Float8Tensor",
     # smooth quant - subject to change
     "get_scale",
