
Commit f08f6de

Add support for AQTStorage and PlainAQTStorage
Summary:
Today `AffineQuantizedTensor` hardcodes a storage format of `int_data`, `scale`, and `zero_point`, which does not work if we want to support packed weights. This PR hides the storage details of `AffineQuantizedTensor` behind a family of tensor subclasses that all inherit from the base storage type `AQTStorage` (affine quantized tensor storage).

This PR only adds a plain storage tensor (`PlainAQTStorage`) that stores the `int_data`, `scale`, and `zero_point` tensors directly; the next PR will add support for storing packed weights (the result of `torch.ops.aten._convert_weight_to_int4pack`) in a different type of `AQTStorage`.

`AffineQuantizedTensor` will have the following:
- storage_tensor: AQTStorage (can store data in different storage formats)
- storage_layout: str (a string representing the type of storage_tensor, which can be used in dispatch)

Test Plan:
python test/quantization/test_quant_api.py

Reviewers:

Subscribers:

Tasks:

Tags:
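A minimal sketch of the storage split described above, for illustration only: the real `AQTStorage` and `PlainAQTStorage` are torch.Tensor subclasses with full dispatch support, and helper names such as `get_plain` and `dequantize` here are assumptions rather than the PR's actual API.

    import torch

    # Base storage type: hides how the quantized data is laid out in memory.
    class AQTStorage:
        def get_plain(self):
            # Return (int_data, scale, zero_point) regardless of internal layout.
            raise NotImplementedError

    # Plain layout: stores int_data / scale / zero_point tensors directly.
    class PlainAQTStorage(AQTStorage):
        def __init__(self, int_data, scale, zero_point):
            self.int_data = int_data
            self.scale = scale
            self.zero_point = zero_point

        def get_plain(self):
            return self.int_data, self.scale, self.zero_point

    # AffineQuantizedTensor keeps a storage_tensor plus a storage_layout string
    # that dispatch code can key on (e.g. "plain" vs. a packed layout).
    class AffineQuantizedTensor:
        def __init__(self, storage_tensor: AQTStorage, storage_layout: str = "plain"):
            self.storage_tensor = storage_tensor
            self.storage_layout = storage_layout

        def dequantize(self):
            int_data, scale, zero_point = self.storage_tensor.get_plain()
            # Simplified per-tensor affine dequantization, for illustration only.
            return (int_data.to(torch.float32) - zero_point) * scale

A packed storage type (for example one holding the output of `torch.ops.aten._convert_weight_to_int4pack`) would subclass `AQTStorage` the same way and report a different `storage_layout` string, so kernels can be selected without `AffineQuantizedTensor` knowing the packing details.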
1 parent 42c2376

File tree

3 files changed (+259 -53 lines)

test/quantization/test_quant_api.py

Lines changed: 4 additions & 4 deletions
@@ -110,8 +110,8 @@ def __init__(self, m=64, n=32, k=64):
         self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
         self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)
 
-    def example_inputs(self, batch_size=1):
-        return (torch.randn(batch_size, self.linear1.in_features).to(torch.float),)
+    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
+        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)
 
     def forward(self, x):
         x = self.linear1(x)
@@ -450,7 +450,7 @@ def test_quantized_tensor_subclass_int4(self):
         # use 1024 so that we don't need padding
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
         m_copy = copy.deepcopy(m)
-        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
+        example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
 
         groupsize = 32
         m = quantize(m, get_apply_int4wo_quant(groupsize=groupsize))
@@ -496,7 +496,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
         m_copy = copy.deepcopy(m)
         # setting batch_size to 20 to be compatible with the kernel
-        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs(batch_size=20)))
+        example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
         m = quantize(m, get_apply_int8dyn_quant())
 
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
