Commit 03e2c9b
Refactor rest of tinygemm quant primitive ops (#321)
Summary: This PR replaces the remaining tinygemm-specific quant primitive ops with the general quant primitive ops that we want to use for everything; the tinygemm-specific ops could be deleted in a separate PR if needed.

Test Plan:

python test/quantization/test_quant_primitives.py -k test_get_groupwise_affine_qparams
python test/quantization/test_quant_primitives.py -k test_groupwise_affine_quantize_tensor_from_qparams
python test/quantization/test_quant_primitives.py -k test_groupwise_affine_dequantize_tensor_from_qparams

accuracy:

perf: no diff for generated code with `TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py`
1 parent 08fb8bf commit 03e2c9b
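For orientation before reading the diff: the claim behind the refactor is that the legacy tinygemm qparams math can be expressed through the general choose_qparams_affine primitive with preserve_zero=False and a float-domain zero point. A minimal sketch of the equivalence the new tests check (assuming a torchao checkout that includes this commit; the call mirrors the refactored get_groupwise_affine_qparams below):

import torch
from torchao.quantization.quant_primitives import (
    choose_qparams_affine,
    MappingType,
    ZeroPointDomain,
)

w = torch.randn(10, 256)
n_bit, groupsize = 4, 128

# Legacy tinygemm formulas (the code this PR removes from quant_primitives.py):
grouped = w.reshape(-1, groupsize)
min_val = grouped.amin(dim=1, keepdim=True)
max_val = grouped.amax(dim=1, keepdim=True)
scales = (max_val - min_val).clamp(min=1e-6) / (2**n_bit - 1)
zeros = min_val + scales * 2 ** (n_bit - 1)  # zero point in the float domain
scale_ref = scales.to(torch.bfloat16).reshape(w.shape[0], -1)
zero_ref = zeros.to(torch.bfloat16).reshape(w.shape[0], -1)

# General primitive, parameterized the same way the refactored op does it:
scale, zero_point = choose_qparams_affine(
    w, MappingType.ASYMMETRIC, (1, groupsize), torch.int32,
    0, 2**n_bit - 1, 1e-6,
    scale_dtype=torch.bfloat16, zero_point_dtype=torch.bfloat16,
    preserve_zero=False, zero_point_domain=ZeroPointDomain.FLOAT,
)

# per the new test, the two paths should agree bit-for-bit
assert torch.equal(scale.reshape(w.shape[0], -1), scale_ref)
assert torch.equal(zero_point.reshape(w.shape[0], -1), zero_ref)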

File tree: 2 files changed, +142 -43 lines changed

test/quantization/test_quant_primitives.py

Lines changed: 107 additions & 2 deletions
@@ -11,6 +11,8 @@
 from torchao.quantization.quant_primitives import (
     get_group_qparams_symmetric,
     get_groupwise_affine_qparams,
+    groupwise_affine_quantize_tensor_from_qparams,
+    groupwise_affine_dequantize_tensor_from_qparams,
     quantize_affine,
     dequantize_affine,
     choose_qparams_affine,
@@ -38,6 +40,86 @@ def check_idempotent(self, fn, *args, **kwargs):
         self.assertTrue(torch.equal(output0, output1), f"Expected given function {fn} to be idempotent.")
         return output1
 
+# Legacy tinygemm ops
+def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
+    if groupsize > w.shape[-1]:
+        groupsize = w.shape[-1]
+    assert groupsize > 1
+    assert w.shape[-1] % groupsize == 0
+    assert w.dim() == 2
+
+    to_quant = w.reshape(-1, groupsize)
+    # assert torch.isnan(to_quant).sum() == 0
+
+    max_val = to_quant.amax(dim=1, keepdim=True)
+    min_val = to_quant.amin(dim=1, keepdim=True)
+    max_int = 2**n_bit - 1
+    scales = (max_val - min_val).clamp(min=1e-6) / max_int
+    zeros = min_val + scales * (2 ** (n_bit - 1))
+    return scales.to(dtype=dtype).reshape(w.shape[0], -1), zeros.to(
+        dtype=dtype
+    ).reshape(w.shape[0], -1)
+
+def _groupwise_affine_quantize_tensor_from_qparams(
+    w,
+    scales,
+    zeros,
+    n_bit=4,
+    groupsize=128,
+):
+    assert groupsize > 1
+    # needed for GPTQ single column quantize
+    if groupsize > w.shape[-1] and scales.shape[-1] == 1:
+        groupsize = w.shape[-1]
+
+    assert w.shape[-1] % groupsize == 0
+    assert w.dim() == 2
+
+    to_quant = w.reshape(-1, groupsize)
+    # assert torch.isnan(to_quant).sum() == 0
+
+    scales = scales.reshape(-1, 1)
+    zeros = zeros.reshape(-1, 1)
+    min_val = zeros - scales * (2 ** (n_bit - 1))
+    max_int = 2**n_bit - 1
+    min_int = 0
+    w_int4x8 = (
+        to_quant.sub(min_val)
+        .div(scales)
+        .round()
+        .clamp_(min_int, max_int)
+        .to(torch.int32)
+        .reshape_as(w)
+    )
+
+    return w_int4x8
+
+def _groupwise_affine_dequantize_tensor_from_qparams(
+    w_int4x8,
+    scales,
+    zeros,
+    n_bit=4,
+    groupsize=128,
+):
+    assert groupsize > 1
+    # needed for GPTQ single column dequantize
+    if groupsize > w_int4x8.shape[-1] and scales.shape[-1] == 1:
+        groupsize = w_int4x8.shape[-1]
+    assert w_int4x8.shape[-1] % groupsize == 0
+    assert w_int4x8.dim() == 2
+
+    w_int4x8_grouped = w_int4x8.reshape(-1, groupsize)
+    scales = scales.reshape(-1, 1)
+    zeros = zeros.reshape(-1, 1)
+
+    w_dq = (
+        w_int4x8_grouped.sub(2 ** (n_bit - 1))
+        .mul(scales)
+        .add(zeros)
+        .reshape_as(w_int4x8)
+    )
+    return w_dq
+
 
 class TestQuantPrimitives(unittest.TestCase):
     SEED = 123
@@ -356,12 +438,12 @@ def test_not_preserve_zero_not_supported(self):
         )
 
 
-    def test_tinygemm_get_groupwise_affine_qparams(self):
+    def test_get_groupwise_affine_qparams(self):
         from torchao.quantization.quant_primitives import ZeroPointDomain
 
         input = torch.randn(10, 256)
         n_bit = 4
-        scale_ref, zero_point_ref = get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)
+        scale_ref, zero_point_ref = _get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)
 
         mapping_type = MappingType.ASYMMETRIC
         dtype = torch.int8
@@ -389,6 +471,29 @@ def test_tinygemm_get_groupwise_affine_qparams(self):
         self.assertTrue(torch.equal(scale, scale_ref))
         self.assertTrue(torch.equal(zero_point, zero_point_ref))
 
+    def test_groupwise_affine_quantize_tensor_from_qparams(self):
+        input = torch.randn(10, 256)
+        scales = torch.randn(10, 2)
+        zeros = torch.randn(10, 2)
+        n_bit = 4
+        groupsize = 128
+
+        w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
+        w_int4x8_ref = _groupwise_affine_quantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
+
+        self.assertTrue(torch.equal(w_int4x8, w_int4x8_ref))
+
+    def test_groupwise_affine_dequantize_tensor_from_qparams(self):
+        input = torch.randint(0, 15, (10, 256), dtype=torch.int32)
+        scales = torch.randn(10, 2).bfloat16()
+        zeros = torch.randn(10, 2).bfloat16()
+        n_bit = 4
+        groupsize = 128
+
+        w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
+        w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
+
+        self.assertTrue(torch.equal(w_bf16, w_bf16_ref))
 
 if __name__ == "__main__":
     unittest.main()
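A side note on the reference helpers added above: their arithmetic round-trips when the qparams come from the legacy formula. A standalone sketch in plain PyTorch (not part of the diff) that checks dequantize(quantize(w)) reconstructs w to within half a quantization step per element:

import torch

torch.manual_seed(123)
n_bit, groupsize = 4, 128
w = torch.randn(10, 256)
grouped = w.reshape(-1, groupsize)

# qparams via the legacy tinygemm formula
min_val = grouped.amin(dim=1, keepdim=True)
max_val = grouped.amax(dim=1, keepdim=True)
scales = (max_val - min_val).clamp(min=1e-6) / (2**n_bit - 1)
zeros = min_val + scales * 2 ** (n_bit - 1)  # zero point lives in the float domain

# quantize: min_val is recovered inside the helper as zeros - scales * 2**(n_bit - 1)
q = ((grouped - min_val) / scales).round().clamp(0, 2**n_bit - 1).to(torch.int32)
# dequantize: the mirror image, as in _groupwise_affine_dequantize_tensor_from_qparams
w_dq = (q.float() - 2 ** (n_bit - 1)) * scales + zeros

# round-to-nearest error is bounded by half a scale step per element
assert ((grouped - w_dq).abs() <= scales / 2 + 1e-5).all()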

torchao/quantization/quant_primitives.py

Lines changed: 35 additions & 41 deletions
@@ -249,7 +249,7 @@ def dequantize_affine(
 
     # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
-    assert input.dtype == input_dtype
+    assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}"
     assert output_dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported output dtype: {output_dtype}"
     quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)
@@ -644,22 +644,37 @@ def quant_int8_per_token_matmul(
 
 
 def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
-    """This is tinygemm specific, we'll keep this for now"""
     if groupsize > w.shape[-1]:
         groupsize = w.shape[-1]
     assert groupsize > 1
     assert w.shape[-1] % groupsize == 0
     assert w.dim() == 2
+    assert n_bit <= 8, f"only n_bit smaller than 8 is supported, got: {n_bit}"
 
-    to_quant = w.reshape(-1, groupsize)
-    # assert torch.isnan(to_quant).sum() == 0
+    mapping_type = MappingType.ASYMMETRIC
+    target_dtype = torch.int32
+    block_size = (1, groupsize)
+    quant_min = 0
+    quant_max = 2**n_bit - 1
+    eps = 1e-6
+    scale_dtype = dtype
+    zero_point_dtype = dtype
+
+    scale, zero_point = choose_qparams_affine(
+        w,
+        mapping_type,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        scale_dtype=scale_dtype,
+        zero_point_dtype=zero_point_dtype,
+        preserve_zero=False,
+        zero_point_domain=ZeroPointDomain.FLOAT
+    )
 
-    max_val = to_quant.amax(dim=1, keepdim=True)
-    min_val = to_quant.amin(dim=1, keepdim=True)
-    max_int = 2**n_bit - 1
-    scales = (max_val - min_val).clamp(min=1e-6) / max_int
-    zeros = min_val + scales * (2 ** (n_bit - 1))
-    return scales.to(dtype=dtype).reshape(w.shape[0], -1), zeros.to(
+    return scale.to(dtype=dtype).reshape(w.shape[0], -1), zero_point.to(
         dtype=dtype
     ).reshape(w.shape[0], -1)
@@ -692,7 +707,6 @@ def groupwise_affine_quantize_tensor_from_qparams(
     n_bit=4,
     groupsize=128,
 ):
-    """This is tinygemm specific, we'll keep this for now"""
     assert groupsize > 1
     # needed for GPTQ single column quantize
     if groupsize > w.shape[-1] and scales.shape[-1] == 1:
@@ -701,25 +715,12 @@ def groupwise_affine_quantize_tensor_from_qparams(
     assert w.shape[-1] % groupsize == 0
     assert w.dim() == 2
 
-    to_quant = w.reshape(-1, groupsize)
-    # assert torch.isnan(to_quant).sum() == 0
-
-    scales = scales.reshape(-1, 1)
-    zeros = zeros.reshape(-1, 1)
-    min_val = zeros - scales * (2 ** (n_bit - 1))
-    max_int = 2**n_bit - 1
-    min_int = 0
-    w_int4x8 = (
-        to_quant.sub(min_val)
-        .div(scales)
-        .round()
-        .clamp_(min_int, max_int)
-        .to(torch.int32)
-        .reshape_as(w)
-    )
-
-    return w_int4x8
+    block_size = (1, groupsize)
+    output_dtype = torch.int32
+    quant_min = 0
+    quant_max = 2 ** n_bit - 1
 
+    return quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
 
 def groupwise_affine_dequantize_tensor_from_qparams(
     w_int4x8,
@@ -728,25 +729,18 @@ def groupwise_affine_dequantize_tensor_from_qparams(
     n_bit=4,
     groupsize=128,
 ):
-    """This is tinygemm specific, we'll keep this for now"""
     assert groupsize > 1
     # needed for GPTQ single column dequantize
     if groupsize > w_int4x8.shape[-1] and scales.shape[-1] == 1:
         groupsize = w_int4x8.shape[-1]
     assert w_int4x8.shape[-1] % groupsize == 0
     assert w_int4x8.dim() == 2
 
-    w_int4x8_grouped = w_int4x8.reshape(-1, groupsize)
-    scales = scales.reshape(-1, 1)
-    zeros = zeros.reshape(-1, 1)
-
-    w_dq = (
-        w_int4x8_grouped.sub(2 ** (n_bit - 1))
-        .mul(scales)
-        .add(zeros)
-        .reshape_as(w_int4x8)
-    )
-    return w_dq
+    block_size = (1, groupsize)
+    input_dtype = torch.int32
+    quant_min = 0
+    quant_max = 2**n_bit - 1
+    return dequantize_affine(w_int4x8, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype)
 
 
 def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
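To close the loop: after this change the public helpers keep their tinygemm-style signatures but route through the general affine ops. A usage sketch (assuming a torchao build that includes this commit; shapes and dtypes follow the tests above):

import torch
from torchao.quantization.quant_primitives import (
    get_groupwise_affine_qparams,
    groupwise_affine_quantize_tensor_from_qparams,
    groupwise_affine_dequantize_tensor_from_qparams,
)

w = torch.randn(10, 256, dtype=torch.bfloat16)
n_bit, groupsize = 4, 128

# one (scale, zero) pair per group: shapes are (10, 256 // 128) == (10, 2)
scales, zeros = get_groupwise_affine_qparams(w, n_bit=n_bit, groupsize=groupsize)

w_int = groupwise_affine_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
w_dq = groupwise_affine_dequantize_tensor_from_qparams(w_int, scales, zeros, n_bit, groupsize)

assert w_int.dtype == torch.int32   # quantize_affine emits int32 here
assert w_dq.dtype == scales.dtype   # dequantize_affine returns scales.dtype (bfloat16)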
