
Commit 8353c20

Add support for AQTStorage and PlainAQTStorage
Summary:
Today `AffineQuantizedTensor` has a hardcoded storage format of `int_data`, `scale`, and `zero_point`, which does not work if we want to support packed weights. This PR hides the storage details of `AffineQuantizedTensor` behind a family of tensor subclasses that all inherit from a base storage type: `AQTStorage` (affine quantized tensor storage).

This PR only adds a plain storage tensor (`PlainAQTStorage`) that stores the `int_data`, `scale`, and `zero_point` tensors directly; a follow-up PR will add a different type of `AQTStorage` that stores a packed weight (the result of `torch.ops.aten._convert_weight_to_int4pack`).

`AffineQuantizedTensor` will have the following:
- storage_tensor: AQTStorage (can store data in different storage formats)
- storage_layout: str (a string naming the type of storage_tensor we have; can be used in dispatch)

Test Plan:
python test/quantization/test_quant_api.py

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent: 90b5e17 · commit: 8353c20
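For orientation, here is a minimal, simplified sketch of the storage split described in the commit message. It is not the actual torchao implementation (the real classes are torch.Tensor subclasses with full op dispatch), and the "plain" layout string is an assumption; only the class and field names come from the summary above.

import torch


class AQTStorage:
    """Base type for affine quantized tensor storage formats."""
    pass


class PlainAQTStorage(AQTStorage):
    """Plain storage: keeps int_data, scale and zero_point as separate tensors."""

    storage_layout = "plain"  # assumed name for the plain format

    def __init__(self, int_data, scale, zero_point):
        self.int_data = int_data
        self.scale = scale
        self.zero_point = zero_point


class AffineQuantizedTensor:
    """Holds an AQTStorage instead of hardcoded int_data/scale/zero_point fields."""

    def __init__(self, storage_tensor: AQTStorage):
        self.storage_tensor = storage_tensor

    @property
    def storage_layout(self) -> str:
        # String naming the storage format; usable as a dispatch key, e.g. to pick
        # a packed-weight kernel once a "tensor_core_tiled" storage is added.
        return self.storage_tensor.storage_layout


# Usage: wrap per-tensor int8 data in the plain storage format.
int_data = torch.randint(-128, 128, (32, 64), dtype=torch.int8)
plain = PlainAQTStorage(int_data, scale=torch.tensor(0.1), zero_point=torch.tensor(0))
print(AffineQuantizedTensor(plain).storage_layout)  # "plain"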

File tree

2 files changed: +259 −53 lines


test/quantization/test_quant_api.py

Lines changed: 5 additions & 5 deletions
@@ -106,8 +106,8 @@ def __init__(self, m=64, n=32, k=64):
         self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
         self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)
 
-    def example_inputs(self, batch_size=1):
-        return (torch.randn(batch_size, self.linear1.in_features).to(torch.float),)
+    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
+        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)
 
     def forward(self, x):
         x = self.linear1(x)
@@ -482,10 +482,10 @@ def test_quantized_tensor_subclass_int4(self):
         # use 1024 so that we don't need padding
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
         m_copy = copy.deepcopy(m)
-        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
+        example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
 
         def apply_weight_quant(weight):
-            return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
+            return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, extended_layout="tensor_core_tiled")
 
         m = quantize(m, apply_weight_quant)
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
@@ -562,7 +562,7 @@ def get_per_token_block_size(x):
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
         m_copy = copy.deepcopy(m)
         # setting batch_size to 20 to be compatible with the kernel
-        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs(batch_size=20)))
+        example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
 
         def apply_weight_quant(weight):
             block_size = get_weight_block_size(weight)

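The extended_layout="tensor_core_tiled" argument in the test change above selects a storage format by name. Purely as an illustration of how such a storage_layout string could drive dispatch (none of the names below are torchao APIs; this is a hypothetical sketch under that assumption), a constructor registry keyed by layout string might look like this:

from typing import Callable, Dict

_STORAGE_CONSTRUCTORS: Dict[str, Callable] = {}


def register_aqt_storage(layout: str) -> Callable:
    """Associate a storage constructor with an extended_layout string."""
    def decorator(fn: Callable) -> Callable:
        _STORAGE_CONSTRUCTORS[layout] = fn
        return fn
    return decorator


@register_aqt_storage("plain")
def _make_plain(int_data, scale, zero_point):
    # The plain format keeps the three tensors as-is.
    return {"layout": "plain", "int_data": int_data, "scale": scale, "zero_point": zero_point}


def make_storage(extended_layout: str, *args):
    # A follow-up PR could register a "tensor_core_tiled" entry that packs the
    # weight (e.g. via torch.ops.aten._convert_weight_to_int4pack) before storing it.
    return _STORAGE_CONSTRUCTORS[extended_layout](*args)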