|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import tempfile |
| 8 | +import unittest |
| 9 | + |
| 10 | +import torch |
| 11 | +from torch.testing._internal.common_utils import ( |
| 12 | + instantiate_parametrized_tests, |
| 13 | + parametrize, |
| 14 | + run_tests, |
| 15 | +) |
| 16 | + |
| 17 | +from torchao.quantization import Int4WeightOnlyConfig, quantize_ |
| 18 | +from torchao.quantization.quantize_.common.packing_format import PackingFormat |
| 19 | +from torchao.quantization.quantize_.workflows.int4.int4_tensor_core_tile_packed_tensor import ( |
| 20 | + Int4TensorCoreTilePackedTensor, |
| 21 | +) |
| 22 | +from torchao.quantization.utils import compute_error |
| 23 | +from torchao.testing.utils import TorchAOIntegrationTestCase |
| 24 | +from torchao.utils import is_sm_at_least_90 |
| 25 | + |
| 26 | +TENSOR_CORE_TILED_CONFIG = Int4WeightOnlyConfig( |
| 27 | + group_size=128, |
| 28 | + packing_format=PackingFormat.TENSOR_CORE_TILE_PACKED, |
| 29 | + version=2, |
| 30 | +) |
| 31 | + |
| 32 | + |
| 33 | +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 34 | +@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+") |
| 35 | +class TestInt4TensorCoreTilePackedTensor(TorchAOIntegrationTestCase): |
| 36 | + def setUp(self): |
| 37 | + self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] |
| 38 | + |
| 39 | + @parametrize( |
| 40 | + "sizes", |
| 41 | + [ |
| 42 | + ((128,), 256, 128), |
| 43 | + ((32, 128), 512, 128), |
| 44 | + ((2, 32, 128), 256, 128), |
| 45 | + ], |
| 46 | + ) |
| 47 | + def test_linear(self, sizes): |
| 48 | + config = TENSOR_CORE_TILED_CONFIG |
| 49 | + dtype = torch.bfloat16 |
| 50 | + device = "cuda" |
| 51 | + |
| 52 | + M, N, K = sizes |
| 53 | + input = torch.randn(*M, K, dtype=dtype, device=device) |
| 54 | + linear = torch.nn.Linear(K, N, dtype=dtype, device=device) |
| 55 | + |
| 56 | + original = linear(input) |
| 57 | + quantize_(linear, config) |
| 58 | + quantized = linear(input) |
| 59 | + self.assertTrue(compute_error(original, quantized) > 20) |
| 60 | + |
| 61 | + compiled_linear = torch.compile(linear) |
| 62 | + quantized_and_compiled = compiled_linear(input) |
| 63 | + self.assertTrue(compute_error(original, quantized_and_compiled) > 20) |
| 64 | + |
| 65 | + def test_module_path(self): |
| 66 | + config = TENSOR_CORE_TILED_CONFIG |
| 67 | + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) |
| 68 | + quantize_(linear.cuda(), config) |
| 69 | + self.assertEqual( |
| 70 | + str(type(linear.weight)), |
| 71 | + "<class 'torchao.quantization.Int4TensorCoreTilePackedTensor'>", |
| 72 | + ) |
| 73 | + |
| 74 | + with tempfile.NamedTemporaryFile() as f: |
| 75 | + torch.save(linear.state_dict(), f) |
| 76 | + f.seek(0) |
| 77 | + state_dict = torch.load(f) |
| 78 | + self.assertEqual( |
| 79 | + str(type(state_dict["weight"])), |
| 80 | + "<class 'torchao.quantization.Int4TensorCoreTilePackedTensor'>", |
| 81 | + ) |
| 82 | + |
| 83 | + def test_slice(self): |
| 84 | + config = TENSOR_CORE_TILED_CONFIG |
| 85 | + dtype = torch.bfloat16 |
| 86 | + device = "cuda" |
| 87 | + |
| 88 | + # Create a 256x256 linear layer for testing |
| 89 | + dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device) |
| 90 | + |
| 91 | + # Create reference sliced linear layers |
| 92 | + dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device) |
| 93 | + dummy1.weight = torch.nn.Parameter( |
| 94 | + dummy.weight.narrow(0, 0, 64), requires_grad=False |
| 95 | + ) |
| 96 | + dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device) |
| 97 | + dummy2.weight = torch.nn.Parameter( |
| 98 | + dummy.weight.narrow(1, 0, 128), requires_grad=False |
| 99 | + ) |
| 100 | + |
| 101 | + # Quantize the main linear layer |
| 102 | + quantize_(dummy, config) |
| 103 | + |
| 104 | + # Shape analysis for tensor core tile packed format: |
| 105 | + # Original weight shape: (256, 256) -> after padding: (256, 1024) |
| 106 | + # n = 256, k = 1024, inner_k_tiles = 8, group_size = 128 |
| 107 | + # inner_k_tiles = 8, group_size = 128 |
| 108 | + # |
| 109 | + # qdata shape: [n/8, k/(inner_k_tiles*16), 32, inner_k_tiles/2] |
| 110 | + # = [256/8, 1024/(8*16), 32, 8/2] |
| 111 | + # = [32, 8, 32, 4] |
| 112 | + # |
| 113 | + # scale_and_zero shape: [in_features/group_size, out_features, 2] (packed format) |
| 114 | + # = [1024/128, 256, 2] = [8, 256, 2] |
| 115 | + |
| 116 | + # Test slicing along output dimension (dim=0: 256 -> 64) |
| 117 | + weight1 = dummy.weight.narrow(0, 0, 64) |
| 118 | + |
| 119 | + # qdata slicing: narrow from [32, 8, 32, 4] to [8, 8, 32, 4] |
| 120 | + # Calculation: 64 out_features / 256 total * 32 qdata_dim0 = 8 |
| 121 | + expected_qdata_slice_0 = dummy.weight.qdata.narrow(0, 0, 8) |
| 122 | + self.assertEqual(weight1.qdata, expected_qdata_slice_0) |
| 123 | + |
| 124 | + # scale_and_zero slicing: narrow from [8, 256, 2] to [8, 64, 2] |
| 125 | + # slicing 0th dim of qdata means we have to slice 1th dim of scale_and_zero |
| 126 | + expected_scale_zero_slice_0 = dummy.weight.scale_and_zero.narrow(1, 0, 64) |
| 127 | + self.assertEqual(weight1.scale_and_zero, expected_scale_zero_slice_0) |
| 128 | + |
| 129 | + # Test slicing along input dimension (dim=1: 256 -> 128) |
| 130 | + weight2 = dummy.weight.narrow(1, 0, 128) |
| 131 | + |
| 132 | + # qdata slicing: narrow from [32, 8, 32, 4] to [32, 4, 32, 4] |
| 133 | + # k = 1024 |
| 134 | + # Calculation: 128 in_features (1/2 of in_features) corresponds to 1/2 of qdata dimension 1 |
| 135 | + # which is k / (inner_k_tiles * 16) / 2 = 1024 / (8 * 16) / 2 = 4 |
| 136 | + expected_qdata_slice_1 = dummy.weight.qdata.narrow(1, 0, 4) |
| 137 | + self.assertEqual(weight2.qdata, expected_qdata_slice_1) |
| 138 | + |
| 139 | + # scale_and_zero slicing: narrow from [8, 256, 2] to [4, 256, 2] |
| 140 | + expected_scale_zero_slice_1 = dummy.weight.scale_and_zero.narrow(0, 0, 4) |
| 141 | + self.assertEqual(weight2.scale_and_zero, expected_scale_zero_slice_1) |
| 142 | + |
| 143 | + # Verify that sliced weights produce similar results to reference implementations |
| 144 | + input1 = torch.randn(2, 256, dtype=dtype, device=device) |
| 145 | + res_ref1 = dummy1(input1) |
| 146 | + |
| 147 | + # Create a new linear layer with the sliced weight |
| 148 | + test_linear1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device) |
| 149 | + test_linear1.weight = torch.nn.Parameter( |
| 150 | + weight1.contiguous(), requires_grad=False |
| 151 | + ) |
| 152 | + res1 = test_linear1(input1) |
| 153 | + self.assertGreater(compute_error(res_ref1, res1), 15) |
| 154 | + |
| 155 | + # input2 = torch.randn(2, 128, dtype=dtype, device=device) |
| 156 | + # res_ref2 = dummy2(input2) |
| 157 | + |
| 158 | + # Create a new linear layer with the sliced weight |
| 159 | + # WIP |
| 160 | + # test_linear2 = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device) |
| 161 | + # test_linear2.weight = torch.nn.Parameter( |
| 162 | + # weight2.contiguous(), requires_grad=False |
| 163 | + # ) |
| 164 | + # res2 = test_linear2(input2) |
| 165 | + # self.assertGreater(compute_error(res_ref2, res2), 15) |
| 166 | + |
| 167 | + def test_slice_preserves_aliasing(self): |
| 168 | + config = TENSOR_CORE_TILED_CONFIG |
| 169 | + l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) |
| 170 | + l.weight = torch.nn.Parameter( |
| 171 | + torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda") |
| 172 | + ) |
| 173 | + quantize_(l, config) |
| 174 | + param = l.weight |
| 175 | + param_data = param.data |
| 176 | + param_data = param_data.narrow(0, 0, 512) |
| 177 | + # Making sure the aliasing is preserved in sliced quantized Tensor |
| 178 | + assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr() |
| 179 | + assert ( |
| 180 | + param.data.scale_and_zero.data_ptr() == param_data.scale_and_zero.data_ptr() |
| 181 | + ) |
| 182 | + |
| 183 | + def test_slice_and_copy_similar_to_vllm(self): |
| 184 | + self._test_slice_and_copy_similar_to_vllm(TENSOR_CORE_TILED_CONFIG) |
| 185 | + |
| 186 | + @parametrize("group_size", [32, 64, 128]) |
| 187 | + def test_different_group_sizes(self, group_size): |
| 188 | + """Test with different group sizes""" |
| 189 | + dtype = torch.bfloat16 |
| 190 | + device = "cuda" |
| 191 | + hp_tensor = torch.randn(256, 512, dtype=dtype, device=device) |
| 192 | + block_size = (1, group_size) |
| 193 | + |
| 194 | + tensor = Int4TensorCoreTilePackedTensor.from_hp(hp_tensor, block_size) |
| 195 | + |
| 196 | + self.assertEqual(tensor.shape, hp_tensor.shape) |
| 197 | + self.assertEqual(tensor.block_size, block_size) |
| 198 | + |
| 199 | + def test_error_conditions(self): |
| 200 | + """Test various error conditions""" |
| 201 | + dtype = torch.bfloat16 |
| 202 | + device = "cuda" |
| 203 | + hp_tensor = torch.randn(128, 256, dtype=dtype, device=device) |
| 204 | + |
| 205 | + # Test invalid block_size length |
| 206 | + with self.assertRaises(AssertionError): |
| 207 | + Int4TensorCoreTilePackedTensor.from_hp( |
| 208 | + hp_tensor, (64,) |
| 209 | + ) # block_size length mismatch |
| 210 | + |
| 211 | + # Test non-groupwise quantization |
| 212 | + with self.assertRaises(AssertionError): |
| 213 | + Int4TensorCoreTilePackedTensor.from_hp( |
| 214 | + hp_tensor, (2, 64) |
| 215 | + ) # first element should be 1 |
| 216 | + |
| 217 | + |
| 218 | +instantiate_parametrized_tests(TestInt4TensorCoreTilePackedTensor) |
| 219 | + |
| 220 | + |
| 221 | +if __name__ == "__main__": |
| 222 | + run_tests() |
0 commit comments