Add composable QAT quantizer #938

Merged: 1 commit, Sep 25, 2024

42 changes: 42 additions & 0 deletions test/quantization/test_qat.py
@@ -15,6 +15,9 @@
from torchao.dtypes import (
    TensorCoreTiledLayoutType,
)
from torchao.quantization.prototype.qat.api import (
    ComposableQATQuantizer,
)
from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
    AffineFakeQuantizedTensor,
)
@@ -34,6 +37,9 @@
    MappingType,
    ZeroPointDomain,
)
from torchao.quantization.unified import (
    TwoStepQuantizer,
)
from torchao.quantization.utils import (
    get_group_qparams_symmetric,
    get_groupwise_affine_qparams,
@@ -626,6 +632,42 @@ def test_qat_4w_quantizer_module_swap(self):
        module_swap_out = module_swap_model(*x2)
        torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)

    class _MyQATQuantizer(TwoStepQuantizer):
        """
        Dummy quantizer that attaches a certain value to each nn.Linear's
        `_temp_quantizer_values` attribute.
        """
        ATTR_NAME = "_temp_quantizer_values"

        def __init__(self, value: str):
            self.value = value

        def _attach_value(self, module: torch.nn.Module):
            if isinstance(module, torch.nn.Linear):
                if not hasattr(module, self.ATTR_NAME):
                    setattr(module, self.ATTR_NAME, [])
                getattr(module, self.ATTR_NAME).append(self.value)

        def prepare(self, model: torch.nn.Module):
            model.apply(self._attach_value)
            return model

        def convert(self, model: torch.nn.Module):
            model.apply(self._attach_value)
            return model

    def test_composable_qat_quantizer(self):
        quantizer1 = self._MyQATQuantizer("quantizer1")
        quantizer2 = self._MyQATQuantizer("quantizer2")
        composable_quantizer = ComposableQATQuantizer([quantizer1, quantizer2])
        model = M()
        model = composable_quantizer.prepare(model)
        self.assertTrue(hasattr(model.linear1, self._MyQATQuantizer.ATTR_NAME))
        values_list = getattr(model.linear1, self._MyQATQuantizer.ATTR_NAME)
        self.assertEqual(values_list, ["quantizer1", "quantizer2"])
        composable_quantizer.convert(model)
        values_list = getattr(model.linear1, self._MyQATQuantizer.ATTR_NAME)
        self.assertEqual(values_list, ["quantizer1", "quantizer2", "quantizer1", "quantizer2"])

if __name__ == "__main__":
    unittest.main()
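To run just the new test (standard unittest filtering from the repo root; the -k flag selects tests by name):

python test/quantization/test_qat.py -k test_composable_qat_quantizer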
2 changes: 2 additions & 0 deletions torchao/quantization/prototype/qat/__init__.py
@@ -5,6 +5,7 @@
    enable_8da4w_fake_quant,
    int4_weight_only_fake_quantize,
    int8_dynamic_activation_int4_weight_fake_quantize,
    ComposableQATQuantizer,
    Int4WeightOnlyQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)
@@ -20,6 +21,7 @@
"enable_8da4w_fake_quant",
"int4_weight_only_fake_quantize",
"int8_dynamic_activation_int4_weight_fake_quantize",
"ComposableQATQuantizer",
"Int4WeightOnlyQATQuantizer",
"Int8DynActInt4WeightQATQuantizer",
"Int8DynActInt4WeightQATLinear",
Expand Down
46 changes: 43 additions & 3 deletions torchao/quantization/prototype/qat/api.py
@@ -4,7 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

-from typing import Any, Optional
+from typing import Any, List, Optional

import torch
import torch.nn.functional as F
@@ -34,6 +34,44 @@
)


class ComposableQATQuantizer(TwoStepQuantizer):
    """
    Composable quantizer that users can use to apply multiple QAT quantizers easily.
    Quantizers will be applied in the order they are specified in the constructor.

    Note: the quantizers provided must apply to different modules in the model,
    e.g. nn.Linear and nn.Embedding, otherwise the behavior will be undefined.

    Example usage::

        my_quantizer = ComposableQATQuantizer([
            QATQuantizer1(),
            QATQuantizer2(),
            QATQuantizer3(),
        ])
        model = my_quantizer.prepare(model)
        train(model)
        model = my_quantizer.convert(model)
    """

    def __init__(self, quantizers: List[TwoStepQuantizer]):
        self.quantizers = quantizers

    def prepare(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        for quantizer in self.quantizers:
            model = quantizer.prepare(model)
        return model

    def convert(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        for quantizer in self.quantizers:
            model = quantizer.convert(model)
        return model

[Review comment on convert, @jerryzh168 (Contributor), Sep 24, 2024]
should the order be reversed?

[Author (Contributor) reply]
Discussed offline: the quantizers are supposed to apply to different modules in the same model. Added some docs to clarify instead.

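To make the "different modules" note and the review discussion concrete, here is a minimal sketch (not part of the PR; the quantizer classes and the _touched_by attribute are hypothetical) of two TwoStepQuantizers that touch disjoint module types, so the order in which ComposableQATQuantizer applies them cannot matter:

import torch
from torchao.quantization.prototype.qat import ComposableQATQuantizer
from torchao.quantization.unified import TwoStepQuantizer

class LinearOnlyQuantizer(TwoStepQuantizer):
    # Hypothetical quantizer that only tags nn.Linear modules
    def prepare(self, model, *args, **kwargs):
        for m in model.modules():
            if isinstance(m, torch.nn.Linear):
                m._touched_by = "linear"
        return model

    def convert(self, model, *args, **kwargs):
        return model

class EmbeddingOnlyQuantizer(TwoStepQuantizer):
    # Hypothetical quantizer that only tags nn.Embedding modules
    def prepare(self, model, *args, **kwargs):
        for m in model.modules():
            if isinstance(m, torch.nn.Embedding):
                m._touched_by = "embedding"
        return model

    def convert(self, model, *args, **kwargs):
        return model

model = torch.nn.Sequential(torch.nn.Embedding(10, 16), torch.nn.Linear(16, 16))
quantizer = ComposableQATQuantizer([LinearOnlyQuantizer(), EmbeddingOnlyQuantizer()])
model = quantizer.prepare(model)
# Each module was touched by exactly one quantizer, so application order is irrelevant
assert model[0]._touched_by == "embedding" and model[1]._touched_by == "linear"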

# =================
# | 8da4w QAT |
# =================
@@ -44,7 +82,8 @@ def int8_dynamic_activation_int4_weight_fake_quantize(group_size=32):
    int4 per group weight symmetric fake quantization to linear. Please see
    :func:`~torchao.quantization.int8_dynamic_activation_int4_weight` for more details.

-    Example usage:
+    Example usage::
+
        from torchao.quantization import quantize_
        quantize_(model, int8_dynamic_activation_int4_weight_fake_quantize(group_size=32))
    """
@@ -151,7 +190,8 @@ def int4_weight_only_fake_quantize(group_size=128):
    Applies uint4 weight-only asymmetric per-group fake quantization to linear layers.
    Please see :func:`~torchao.quantization.int4_weight_only` for more details.

-    Example usage:
+    Example usage::
+
        from torchao.quantization import quantize_
        quantize_(model, int4_weight_only_fake_quantize(group_size=32))
    """
6 changes: 2 additions & 4 deletions torchao/quantization/unified.py
@@ -1,5 +1,5 @@
import torch
-from typing import Any
+from typing import Any, List
from abc import ABC, abstractmethod

"""
@@ -17,7 +17,6 @@ class Quantizer(ABC):
    def quantize(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
-
        pass


@@ -27,11 +26,10 @@ class TwoStepQuantizer:
    def prepare(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
-
        pass

    @abstractmethod
    def convert(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
-
        pass
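For context, the intended end-to-end flow with a real quantizer looks roughly like this (a sketch, not from the PR; it assumes the quantizer's default group size divides the Linear's in_features, and the toy model and one-step "training" stand in for a real model and fine-tuning loop):

import torch
from torchao.quantization.prototype.qat import (
    ComposableQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)

model = torch.nn.Sequential(torch.nn.Linear(256, 256))  # toy stand-in for a real model
quantizer = ComposableQATQuantizer([
    Int8DynActInt4WeightQATQuantizer(),  # fake-quantizes nn.Linear layers
    # additional quantizers targeting other module types could be listed here
])

model = quantizer.prepare(model)     # step 1: insert fake quantization
out = model(torch.randn(8, 256))     # step 2: fine-tune with fake quant in the loop
out.sum().backward()                 # gradients flow through the fake-quant ops
model = quantizer.convert(model)     # step 3: swap in actual quantized operations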