Skip to content

Commit 6a4d064

Browse files
authored
Add composable QAT quantizer (#938)
Summary: This is a utility for users who wish to apply multiple QAT quantizers to their models. In the near future, we expect to add an embedding QAT quantizer that composes with the existing linear QAT quantizers. Test Plan: python test/quantization/test_qat.py -k test_composable_qat_quantizer
1 parent fbe97a0 commit 6a4d064

File tree

4 files changed

+89
-7
lines changed

4 files changed

+89
-7
lines changed

test/quantization/test_qat.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
from torchao.dtypes import (
1616
TensorCoreTiledLayoutType,
1717
)
18+
from torchao.quantization.prototype.qat.api import (
19+
ComposableQATQuantizer,
20+
)
1821
from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
1922
AffineFakeQuantizedTensor,
2023
)
@@ -34,6 +37,9 @@
3437
MappingType,
3538
ZeroPointDomain,
3639
)
40+
from torchao.quantization.unified import (
41+
TwoStepQuantizer,
42+
)
3743
from torchao.quantization.utils import (
3844
get_group_qparams_symmetric,
3945
get_groupwise_affine_qparams,
@@ -626,6 +632,42 @@ def test_qat_4w_quantizer_module_swap(self):
626632
module_swap_out = module_swap_model(*x2)
627633
torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)
628634

635+
class _MyQATQuantizer(TwoStepQuantizer):
    """
    Test-only quantizer that leaves a trace on every nn.Linear it visits.

    Both `prepare` and `convert` append this quantizer's `value` to the list
    stored under each linear module's `_temp_quantizer_values` attribute,
    creating that list on first visit. Modules of other types are untouched.
    """

    ATTR_NAME = "_temp_quantizer_values"

    def __init__(self, value: str):
        self.value = value

    def _attach_value(self, module: torch.nn.Module):
        """Append `self.value` to `module`'s record list if it is an nn.Linear."""
        if not isinstance(module, torch.nn.Linear):
            return
        try:
            recorded = getattr(module, self.ATTR_NAME)
        except AttributeError:
            # First visit to this module: start a fresh record list.
            recorded = []
            setattr(module, self.ATTR_NAME, recorded)
        recorded.append(self.value)

    def _tag_all(self, model: torch.nn.Module):
        """Run `_attach_value` over `model` via `apply` and return `model`."""
        model.apply(self._attach_value)
        return model

    def prepare(self, model: torch.nn.Module):
        """Record this quantizer's value on every nn.Linear, then return `model`."""
        return self._tag_all(model)

    def convert(self, model: torch.nn.Module):
        """Record this quantizer's value on every nn.Linear, then return `model`."""
        return self._tag_all(model)
658+
659+
def test_composable_qat_quantizer(self):
    """
    Check that ComposableQATQuantizer runs its quantizers in constructor
    order for both `prepare` and `convert`.
    """
    attr_name = self._MyQATQuantizer.ATTR_NAME
    first = self._MyQATQuantizer("quantizer1")
    second = self._MyQATQuantizer("quantizer2")
    composed = ComposableQATQuantizer([first, second])

    # After prepare, linear1 must carry one entry per quantizer, in order.
    prepared = composed.prepare(M())
    self.assertTrue(hasattr(prepared.linear1, attr_name))
    after_prepare = getattr(prepared.linear1, attr_name)
    self.assertEqual(after_prepare, ["quantizer1", "quantizer2"])

    # Convert appends a second round of entries in the same order.
    composed.convert(prepared)
    after_convert = getattr(prepared.linear1, attr_name)
    self.assertEqual(
        after_convert,
        ["quantizer1", "quantizer2", "quantizer1", "quantizer2"],
    )
629671

630672
if __name__ == "__main__":
631673
unittest.main()

torchao/quantization/prototype/qat/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
enable_8da4w_fake_quant,
66
int4_weight_only_fake_quantize,
77
int8_dynamic_activation_int4_weight_fake_quantize,
8+
ComposableQATQuantizer,
89
Int4WeightOnlyQATQuantizer,
910
Int8DynActInt4WeightQATQuantizer,
1011
)
@@ -20,6 +21,7 @@
2021
"enable_8da4w_fake_quant",
2122
"int4_weight_only_fake_quantize",
2223
"int8_dynamic_activation_int4_weight_fake_quantize",
24+
"ComposableQATQuantizer",
2325
"Int4WeightOnlyQATQuantizer",
2426
"Int8DynActInt4WeightQATQuantizer",
2527
"Int8DynActInt4WeightQATLinear",

torchao/quantization/prototype/qat/api.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import Any, Optional
7+
from typing import Any, List, Optional
88

99
import torch
1010
import torch.nn.functional as F
@@ -34,6 +34,44 @@
3434
)
3535

3636

37+
class ComposableQATQuantizer(TwoStepQuantizer):
    """
    Composable quantizer that users can use to apply multiple QAT quantizers easily.
    Quantizers will be applied in the order they are specified in the constructor.

    Note: the quantizers provided must apply to different modules in the model,
    e.g. nn.Linear and nn.Embedding, otherwise the behavior will be undefined.

    Any extra positional or keyword arguments passed to `prepare` or `convert`
    are forwarded unchanged to every wrapped quantizer.

    Example usage::

        my_quantizer = ComposableQATQuantizer([
            QATQuantizer1(),
            QATQuantizer2(),
            QATQuantizer3(),
        ])
        model = my_quantizer.prepare(model)
        train(model)
        model = my_quantizer.convert(model)

    Args:
        quantizers: the two-step quantizers to apply, in application order.
    """

    def __init__(self, quantizers: List[TwoStepQuantizer]):
        # Snapshot the input so that (a) a one-shot iterable is not exhausted
        # by `prepare`, leaving `convert` with nothing to do, and (b) later
        # mutation of the caller's list cannot make the two steps disagree.
        self.quantizers = list(quantizers)

    def prepare(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        """
        Run `prepare` of each wrapped quantizer on `model`, in order.

        Each quantizer receives the model returned by the previous one, along
        with any extra arguments given here. Returns the final model.
        """
        for quantizer in self.quantizers:
            model = quantizer.prepare(model, *args, **kwargs)
        return model

    def convert(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        """
        Run `convert` of each wrapped quantizer on `model`, in order.

        Each quantizer receives the model returned by the previous one, along
        with any extra arguments given here. Returns the final model.
        """
        for quantizer in self.quantizers:
            model = quantizer.convert(model, *args, **kwargs)
        return model
73+
74+
3775
# =================
3876
# | 8da4w QAT |
3977
# =================
@@ -44,7 +82,8 @@ def int8_dynamic_activation_int4_weight_fake_quantize(group_size=32):
4482
int4 per group weight symmetric fake quantization to linear. Please see
4583
:func:`~torchao.quantization.int8_dynamic_activation_int4_weight` for more details.
4684
47-
Example usage:
85+
Example usage::
86+
4887
from torchao.quantization import quantize_
4988
quantize_(model, int8_dynamic_activation_int4_weight_fake_quantize(group_size=32))
5089
"""
@@ -151,7 +190,8 @@ def int4_weight_only_fake_quantize(group_size=128):
151190
Applies uint4 weight-only asymmetric per-group fake quantization to linear layers.
152191
Please see :func:`~torchao.quantization.int4_weight_only` for more details.
153192
154-
Example usage:
193+
Example usage::
194+
155195
from torchao.quantization import quantize_
156196
quantize_(model, int4_weight_only_fake_quantize(group_size=32))
157197
"""

torchao/quantization/unified.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
from typing import Any
2+
from typing import Any, List
33
from abc import ABC, abstractmethod
44

55
"""
@@ -17,7 +17,6 @@ class Quantizer(ABC):
1717
def quantize(
1818
self, model: torch.nn.Module, *args: Any, **kwargs: Any
1919
) -> torch.nn.Module:
20-
2120
pass
2221

2322

@@ -27,11 +26,10 @@ class TwoStepQuantizer:
2726
def prepare(
2827
self, model: torch.nn.Module, *args: Any, **kwargs: Any
2928
) -> torch.nn.Module:
30-
3129
pass
3230

31+
@abstractmethod
3332
def convert(
3433
self, model: torch.nn.Module, *args: Any, **kwargs: Any
3534
) -> torch.nn.Module:
36-
3735
pass

0 commit comments

Comments
 (0)