Commit da45eb6

Add main tensor conversion API for packed tensors
Summary:
Added `_convert_to_packed_tensor_based_on_current_hardware` to convert a tensor from the unpacked / plain version to the packed version. This is to enable vLLM for packed weights: vLLM slices the quantized weight, but slicing is not supported by all torchao tensor subclasses. So we want to first ship a plain / unpacked checkpoint and then convert it to the packed version using this API.

Test Plan:
pytest test/prototype/test_tensor_conversion.py

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 122b307 commit da45eb6
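
The flow the summary describes, sketched below under stated assumptions: load a plain / unpacked int4 checkpoint, take a vLLM-style slice of the unpacked weight (slicing works because the tensor is still plain), then pack for the local device. The standalone Linear and the shard bounds are illustrative, not part of this commit.

import torch

from torchao.prototype.tensor_conversion.api import (
    _convert_to_packed_tensor_based_on_current_hardware,
)
from torchao.quantization.quant_api import Int4WeightOnlyConfig, quantize_

# Stand-in for a checkpoint shipped in the plain / unpacked int4 format.
linear = torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cuda")
quantize_(linear, Int4WeightOnlyConfig(group_size=128))

# Tensor-parallel style shard of the quantized weight (illustrative bounds);
# the plain Int4Tensor supports slicing, which packed layouts may not.
shard = linear.weight[0:256]

# Repack the shard for the local device: on CUDA an Int4Tensor becomes an
# Int4PreshuffledTensor, any other tensor is returned unchanged.
packed = _convert_to_packed_tensor_based_on_current_hardware(shard)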

File tree: 2 files changed, +52 −3 lines changed

test/prototype/test_tensor_conversion.py

Lines changed: 28 additions & 2 deletions
@@ -13,10 +13,18 @@
     StretchedUnifTorchaoQuantizer,
 )
 from torchao.prototype.quantization.int8_lut_tensor.int8_lut_tensor import Int8LutTensor
-from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64
-from torchao.quantization import MappingType
+from torchao.prototype.tensor_conversion.api import (
+    _convert_model_for_aarch64,
+    _convert_to_packed_tensor_based_on_current_hardware,
+)
+from torchao.quantization import (
+    Int4PreshuffledTensor,
+    Int4Tensor,
+    MappingType,
+)
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.quant_api import (
+    Int4WeightOnlyConfig,
     Int8DynamicActivationIntxWeightConfig,
     IntxWeightOnlyConfig,
     quantize_,
@@ -178,3 +186,21 @@ def test_aarch64_conversion(dtype, granularity, bit_width, lead_dim):
         assert ep.graph_module.code.count(line) == cnt, (
             f"expected {cnt} {line} in {ep.graph_module.code}"
         )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA")
+def test_int4_tensor_conversion():
+    m = torch.nn.Sequential(
+        torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cuda")
+    )
+    quantize_(m, Int4WeightOnlyConfig(group_size=128))
+    weight = m[0].weight
+    assert isinstance(weight, Int4Tensor)
+    example_inputs = (torch.randn(32, 256, dtype=torch.bfloat16, device="cuda"),)
+    before_conversion = m(*example_inputs)
+    m[0].weight = torch.nn.Parameter(
+        _convert_to_packed_tensor_based_on_current_hardware(weight), requires_grad=False
+    )
+    after_conversion = m(*example_inputs)
+    assert isinstance(m[0].weight, Int4PreshuffledTensor)
+    assert torch.equal(before_conversion, after_conversion)

torchao/prototype/tensor_conversion/api.py

Lines changed: 24 additions & 1 deletion
@@ -7,7 +7,14 @@
 import torch
 import torch.nn as nn

-from torchao.quantization.quantize_.workflows import IntxUnpackedToInt8Tensor
+# TODO: move the function to torchao.utils
+from torchao.dtypes.utils import is_device
+from torchao.quantization import (
+    Int4PreshuffledTensor,
+    Int4Tensor,
+    IntxUnpackedToInt8Tensor,
+)
+from torchao.utils import TorchAOBaseTensor


 def _convert_linear_weight_to_int8_lut_tensor(module):
@@ -156,3 +163,19 @@ def _convert_model_for_aarch64(
         raise ValueError(f"Unexpected tensor_type={tensor_type}")

     return model
+
+
+def _convert_to_packed_tensor_based_on_current_hardware(tensor: TorchAOBaseTensor):
+    """Convert a plain / unpacked torchao tensor to a packed one based on hardware.
+
+    The goal is optimized performance on the current hardware, while also allowing us to
+    (1) distribute a single unpacked / plain format that can be used on multiple
+        kinds of hardware, and
+    (2) support the vLLM use case, where we need to slice the weights for distributed
+        inference. Since slicing is not always supported for packed weights, we first
+        load the plain / unpacked weight, slice it, and then convert it to the packed
+        weight to get the best inference speed.
+    """
+    if isinstance(tensor, Int4Tensor) and is_device("cuda", tensor.device):
+        return Int4PreshuffledTensor.from_int4_tensor(tensor)
+    return tensor
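
Because non-matching tensors pass through the conversion unchanged, it can be applied uniformly across a model's linear weights after loading and slicing. A minimal sketch of such a loop (the helper name is hypothetical, not part of this commit):

import torch.nn as nn

from torchao.prototype.tensor_conversion.api import (
    _convert_to_packed_tensor_based_on_current_hardware,
)

def _pack_model_weights_for_current_hardware(model: nn.Module) -> nn.Module:
    # Hypothetical convenience loop: repack every linear weight in place.
    # Weights with no packed variant for this hardware (including plain
    # torch.Tensor weights) come back from the conversion unchanged.
    for module in model.modules():
        if isinstance(module, nn.Linear):
            module.weight = nn.Parameter(
                _convert_to_packed_tensor_based_on_current_hardware(module.weight),
                requires_grad=False,
            )
    return model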
