|
13 | 13 | StretchedUnifTorchaoQuantizer, |
14 | 14 | ) |
15 | 15 | from torchao.prototype.quantization.int8_lut_tensor.int8_lut_tensor import Int8LutTensor |
16 | | -from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64 |
17 | | -from torchao.quantization import MappingType |
| 16 | +from torchao.prototype.tensor_conversion.api import ( |
| 17 | + _convert_model_for_aarch64, |
| 18 | + _convert_to_packed_tensor_based_on_current_hardware, |
| 19 | +) |
| 20 | +from torchao.quantization import ( |
| 21 | + Int4PreshuffledTensor, |
| 22 | + Int4Tensor, |
| 23 | + MappingType, |
| 24 | +) |
18 | 25 | from torchao.quantization.granularity import PerAxis, PerGroup |
19 | 26 | from torchao.quantization.quant_api import ( |
| 27 | + Int4WeightOnlyConfig, |
20 | 28 | Int8DynamicActivationIntxWeightConfig, |
21 | 29 | IntxWeightOnlyConfig, |
22 | 30 | quantize_, |
@@ -178,3 +186,21 @@ def test_aarch64_conversion(dtype, granularity, bit_width, lead_dim): |
178 | 186 | assert ep.graph_module.code.count(line) == cnt, ( |
179 | 187 | f"expected {cnt} {line} in {ep.graph_module.code}" |
180 | 188 | ) |
| 189 | + |
| 190 | + |
| 191 | +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA") |
| 192 | +def test_int4_tensor_conversion(): |
| 193 | + m = torch.nn.Sequential( |
| 194 | + torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cuda") |
| 195 | + ) |
| 196 | + quantize_(m, Int4WeightOnlyConfig(group_size=128)) |
| 197 | + weight = m[0].weight |
| 198 | + assert isinstance(weight, Int4Tensor) |
| 199 | + example_inputs = (torch.randn(32, 256, dtype=torch.bfloat16, device="cuda"),) |
| 200 | + before_conversion = m(*example_inputs) |
| 201 | + m[0].weight = torch.nn.Parameter( |
| 202 | + _convert_to_packed_tensor_based_on_current_hardware(weight), requires_grad=False |
| 203 | + ) |
| 204 | + after_conversion = m(*example_inputs) |
| 205 | + assert isinstance(m[0].weight, Int4PreshuffledTensor) |
| 206 | + assert torch.equal(before_conversion, after_conversion) |
0 commit comments