Commit 33aa8e7
Refactor rest of tinygemm quant primitive ops
Summary: This PR replaces the remaining tinygemm-specific quant primitive ops with the general quant primitive ops that we want to use for everything. We could delete these ops in a separate PR if needed.

Test Plan:
python test/quantization/test_quant_primitives.py -k test_get_groupwise_affine_qparams
python test/quantization/test_quant_primitives.py -k test_groupwise_affine_quantize_tensor_from_qparams
python test/quantization/test_quant_primitives.py -k test_groupwise_affine_dequantize_tensor_from_qparams

accuracy:
perf: no diff for generated code with `TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py`
1 parent 338d87c commit 33aa8e7
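For reference, here is a minimal round trip through the three refactored wrapper ops that the test plan exercises. This is an illustrative sketch, not part of the commit; it assumes torchao at this revision and uses the same shapes as the new tests.

```python
# Illustrative round trip through the refactored wrappers (not part of this
# commit; assumes torchao at this revision).
import torch
from torchao.quantization.quant_primitives import (
    get_groupwise_affine_qparams,
    groupwise_affine_quantize_tensor_from_qparams,
    groupwise_affine_dequantize_tensor_from_qparams,
)

w = torch.randn(10, 256, dtype=torch.bfloat16)
n_bit, groupsize = 4, 128

# qparams now come from choose_qparams_affine under the hood
scales, zeros = get_groupwise_affine_qparams(w, n_bit=n_bit, groupsize=groupsize)
# quantize/dequantize now delegate to quantize_affine / dequantize_affine
w_q = groupwise_affine_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
w_dq = groupwise_affine_dequantize_tensor_from_qparams(w_q, scales, zeros, n_bit, groupsize)

print(w_q.dtype, w_dq.dtype)   # torch.int32 torch.bfloat16
print((w - w_dq).abs().max())  # small; bounded by ~scale/2 per group
```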

2 files changed: +142, -43 lines

test/quantization/test_quant_primitives.py

Lines changed: 107 additions & 2 deletions
```diff
@@ -11,6 +11,8 @@
 from torchao.quantization.quant_primitives import (
     get_group_qparams_symmetric,
     get_groupwise_affine_qparams,
+    groupwise_affine_quantize_tensor_from_qparams,
+    groupwise_affine_dequantize_tensor_from_qparams,
     quantize_affine,
     dequantize_affine,
     choose_qparams_affine,
```
```diff
@@ -38,6 +40,86 @@ def check_idempotent(self, fn, *args, **kwargs):
         self.assertTrue(torch.equal(output0, output1), f"Expected given function {fn} to be idempotent.")
         return output1
 
+# Legacy tinygemm ops
+def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
+    if groupsize > w.shape[-1]:
+        groupsize = w.shape[-1]
+    assert groupsize > 1
+    assert w.shape[-1] % groupsize == 0
+    assert w.dim() == 2
+
+    to_quant = w.reshape(-1, groupsize)
+    # assert torch.isnan(to_quant).sum() == 0
+
+    max_val = to_quant.amax(dim=1, keepdim=True)
+    min_val = to_quant.amin(dim=1, keepdim=True)
+    max_int = 2**n_bit - 1
+    scales = (max_val - min_val).clamp(min=1e-6) / max_int
+    zeros = min_val + scales * (2 ** (n_bit - 1))
+    return scales.to(dtype=dtype).reshape(w.shape[0], -1), zeros.to(
+        dtype=dtype
+    ).reshape(w.shape[0], -1)
+
+def _groupwise_affine_quantize_tensor_from_qparams(
+    w,
+    scales,
+    zeros,
+    n_bit=4,
+    groupsize=128,
+):
+    assert groupsize > 1
+    # needed for GPTQ single column quantize
+    if groupsize > w.shape[-1] and scales.shape[-1] == 1:
+        groupsize = w.shape[-1]
+
+    assert w.shape[-1] % groupsize == 0
+    assert w.dim() == 2
+
+    to_quant = w.reshape(-1, groupsize)
+    # assert torch.isnan(to_quant).sum() == 0
+
+    scales = scales.reshape(-1, 1)
+    zeros = zeros.reshape(-1, 1)
+    min_val = zeros - scales * (2 ** (n_bit - 1))
+    max_int = 2**n_bit - 1
+    min_int = 0
+    w_int4x8 = (
+        to_quant.sub(min_val)
+        .div(scales)
+        .round()
+        .clamp_(min_int, max_int)
+        .to(torch.int32)
+        .reshape_as(w)
+    )
+
+    return w_int4x8
+
+def _groupwise_affine_dequantize_tensor_from_qparams(
+    w_int4x8,
+    scales,
+    zeros,
+    n_bit=4,
+    groupsize=128,
+):
+    assert groupsize > 1
+    # needed for GPTQ single column dequantize
+    if groupsize > w_int4x8.shape[-1] and scales.shape[-1] == 1:
+        groupsize = w_int4x8.shape[-1]
+    assert w_int4x8.shape[-1] % groupsize == 0
+    assert w_int4x8.dim() == 2
+
+    w_int4x8_grouped = w_int4x8.reshape(-1, groupsize)
+    scales = scales.reshape(-1, 1)
+    zeros = zeros.reshape(-1, 1)
+
+    w_dq = (
+        w_int4x8_grouped.sub(2 ** (n_bit - 1))
+        .mul(scales)
+        .add(zeros)
+        .reshape_as(w_int4x8)
+    )
+    return w_dq
+
 
 class TestQuantPrimitives(unittest.TestCase):
     SEED = 123
```
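The legacy reference ops kept above encode the tinygemm convention that the zero point lives in the floating-point domain, zeros = min_val + scales * 2**(n_bit - 1), rather than being an integer offset. A small hand-check of that math (illustrative only, not part of the diff; one row, one group):

```python
# Hand-check of the tinygemm float-zero-point math used by the legacy ops
# above (illustrative; one row, one group of 4 values).
import torch

w = torch.tensor([[-1.0, 0.0, 2.0, 3.0]])
n_bit = 4
max_int = 2**n_bit - 1                                  # 15

max_val = w.amax(dim=1, keepdim=True)                   # 3.0
min_val = w.amin(dim=1, keepdim=True)                   # -1.0
scales = (max_val - min_val).clamp(min=1e-6) / max_int  # 4/15
zeros = min_val + scales * (2 ** (n_bit - 1))           # float-domain zero point

q = ((w - (zeros - scales * 2 ** (n_bit - 1))) / scales).round().clamp(0, max_int)
dq = (q - 2 ** (n_bit - 1)) * scales + zeros

print(q)   # tensor([[ 0.,  4., 11., 15.]])
print(dq)  # reconstructs w up to rounding error
```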
```diff
@@ -356,12 +438,12 @@ def test_not_preserve_zero_not_supported(self):
         )
 
 
-    def test_tinygemm_get_groupwise_affine_qparams(self):
+    def test_get_groupwise_affine_qparams(self):
         from torchao.quantization.quant_primitives import ZeroPointDomain
 
         input = torch.randn(10, 256)
         n_bit = 4
-        scale_ref, zero_point_ref = get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)
+        scale_ref, zero_point_ref = _get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)
 
         mapping_type = MappingType.ASYMMETRIC
         dtype = torch.int8
```
```diff
@@ -389,6 +471,29 @@ def test_tinygemm_get_groupwise_affine_qparams(self):
         self.assertTrue(torch.equal(scale, scale_ref))
         self.assertTrue(torch.equal(zero_point, zero_point_ref))
 
+    def test_groupwise_affine_quantize_tensor_from_qparams(self):
+        input = torch.randn(10, 256)
+        scales = torch.randn(10, 2)
+        zeros = torch.randn(10, 2)
+        n_bit = 4
+        groupsize = 128
+
+        w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
+        w_int4x8_ref = _groupwise_affine_quantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
+
+        self.assertTrue(torch.equal(w_int4x8, w_int4x8_ref))
+
+    def test_groupwise_affine_dequantize_tensor_from_qparams(self):
+        input = torch.randint(0, 15, (10, 256), dtype=torch.int32)
+        scales = torch.randn(10, 2).bfloat16()
+        zeros = torch.randn(10, 2).bfloat16()
+        n_bit = 4
+        groupsize = 128
+
+        w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
+        w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
+
+        self.assertTrue(torch.equal(w_bf16, w_bf16_ref))
 
 if __name__ == "__main__":
     unittest.main()
```

torchao/quantization/quant_primitives.py

Lines changed: 35 additions & 41 deletions
```diff
@@ -252,7 +252,7 @@ def dequantize_affine(
 
     # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
-    assert input.dtype == input_dtype
+    assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}"
     assert output_dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported output dtype: {output_dtype}"
     quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)
```

```diff
@@ -647,22 +647,37 @@ def quant_int8_per_token_matmul(
 
 
 def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
-    """This is tinygemm specific, we'll keep this for now"""
     if groupsize > w.shape[-1]:
         groupsize = w.shape[-1]
     assert groupsize > 1
     assert w.shape[-1] % groupsize == 0
     assert w.dim() == 2
+    assert n_bit <= 8, f"only n_bit <= 8 is supported, got: {n_bit}"
 
-    to_quant = w.reshape(-1, groupsize)
-    # assert torch.isnan(to_quant).sum() == 0
+    mapping_type = MappingType.ASYMMETRIC
+    target_dtype = torch.int32
+    block_size = (1, groupsize)
+    quant_min = 0
+    quant_max = 2**n_bit - 1
+    eps = 1e-6
+    scale_dtype = dtype
+    zero_point_dtype = dtype
 
-    max_val = to_quant.amax(dim=1, keepdim=True)
-    min_val = to_quant.amin(dim=1, keepdim=True)
-    max_int = 2**n_bit - 1
-    scales = (max_val - min_val).clamp(min=1e-6) / max_int
-    zeros = min_val + scales * (2 ** (n_bit - 1))
-    return scales.to(dtype=dtype).reshape(w.shape[0], -1), zeros.to(
+    scale, zero_point = choose_qparams_affine(
+        w,
+        mapping_type,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        scale_dtype=scale_dtype,
+        zero_point_dtype=zero_point_dtype,
+        preserve_zero=False,
+        zero_point_domain=ZeroPointDomain.FLOAT,
+    )
+
+    return scale.to(dtype=dtype).reshape(w.shape[0], -1), zero_point.to(
         dtype=dtype
     ).reshape(w.shape[0], -1)
```

```diff
@@ -695,7 +710,6 @@ def groupwise_affine_quantize_tensor_from_qparams(
     n_bit=4,
     groupsize=128,
 ):
-    """This is tinygemm specific, we'll keep this for now"""
     assert groupsize > 1
     # needed for GPTQ single column quantize
     if groupsize > w.shape[-1] and scales.shape[-1] == 1:
```
```diff
@@ -704,25 +718,12 @@ def groupwise_affine_quantize_tensor_from_qparams(
     assert w.shape[-1] % groupsize == 0
     assert w.dim() == 2
 
-    to_quant = w.reshape(-1, groupsize)
-    # assert torch.isnan(to_quant).sum() == 0
-
-    scales = scales.reshape(-1, 1)
-    zeros = zeros.reshape(-1, 1)
-    min_val = zeros - scales * (2 ** (n_bit - 1))
-    max_int = 2**n_bit - 1
-    min_int = 0
-    w_int4x8 = (
-        to_quant.sub(min_val)
-        .div(scales)
-        .round()
-        .clamp_(min_int, max_int)
-        .to(torch.int32)
-        .reshape_as(w)
-    )
-
-    return w_int4x8
+    block_size = (1, groupsize)
+    output_dtype = torch.int32
+    quant_min = 0
+    quant_max = 2**n_bit - 1
 
+    return quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT)
 
 def groupwise_affine_dequantize_tensor_from_qparams(
     w_int4x8,
```
```diff
@@ -731,25 +732,18 @@ def groupwise_affine_dequantize_tensor_from_qparams(
     n_bit=4,
     groupsize=128,
 ):
-    """This is tinygemm specific, we'll keep this for now"""
     assert groupsize > 1
     # needed for GPTQ single column dequantize
     if groupsize > w_int4x8.shape[-1] and scales.shape[-1] == 1:
         groupsize = w_int4x8.shape[-1]
     assert w_int4x8.shape[-1] % groupsize == 0
     assert w_int4x8.dim() == 2
 
-    w_int4x8_grouped = w_int4x8.reshape(-1, groupsize)
-    scales = scales.reshape(-1, 1)
-    zeros = zeros.reshape(-1, 1)
-
-    w_dq = (
-        w_int4x8_grouped.sub(2 ** (n_bit - 1))
-        .mul(scales)
-        .add(zeros)
-        .reshape_as(w_int4x8)
-    )
-    return w_dq
+    block_size = (1, groupsize)
+    input_dtype = torch.int32
+    quant_min = 0
+    quant_max = 2**n_bit - 1
+    return dequantize_affine(w_int4x8, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype)
 
 
 def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
```
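And the matching quantize/dequantize pair through the general ops, with the block size and FLOAT zero-point domain the two wrappers above now pass. A sketch, under the same import assumptions as above:

```python
# Sketch: direct quantize/dequantize through the general affine ops,
# mirroring the refactored wrappers above (illustrative).
import torch
from torchao.quantization.quant_primitives import (
    ZeroPointDomain,
    quantize_affine,
    dequantize_affine,
)

w = torch.randn(10, 256, dtype=torch.bfloat16)
n_bit, groupsize = 4, 128
block_size = (1, groupsize)
quant_min, quant_max = 0, 2**n_bit - 1

# one (scale, zero) pair per row per group, as the wrappers expect;
# scales kept positive so the round trip is meaningful
scales = torch.rand(10, 2).bfloat16() + 0.01
zeros = torch.randn(10, 2).bfloat16()

w_q = quantize_affine(w, block_size, scales, zeros, torch.int32,
                      quant_min, quant_max,
                      zero_point_domain=ZeroPointDomain.FLOAT)
w_dq = dequantize_affine(w_q, block_size, scales, zeros, torch.int32,
                         quant_min, quant_max,
                         zero_point_domain=ZeroPointDomain.FLOAT,
                         output_dtype=scales.dtype)
print(w_q.dtype, w_dq.dtype)  # torch.int32 torch.bfloat16
```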
