Commit 9fa6fd7

drisspg authored and HDCharles committed
[StaticQuant] add a linear observer class and test (#807)
1 parent f4c8109 commit 9fa6fd7

File tree

6 files changed: +361 -11 lines changed


ruff.toml

Lines changed: 2 additions & 0 deletions
@@ -8,4 +8,6 @@ include = [
     "torchao/dtypes/nf4tensor.py",
     "test/dtypes/test_nf4.py",
     "torchao/float8/float8_tensor.py",
+    "torchao/quantization/linear_activation_weight_observer.py",
+    "test/quantization/test_observer.py",
 ]

test/quantization/test_observer.py

Lines changed: 108 additions & 4 deletions
@@ -1,5 +1,6 @@
 import re
 import torch
+import torch.nn as nn
 from torch.testing._internal.common_utils import TestCase
 from torchao.quantization.observer import (
     AffineQuantizedMinMaxObserver,
@@ -9,13 +10,23 @@
 from torchao.quantization.quant_primitives import (
     MappingType,
 )
+from torchao.quantization.quant_api import (
+    insert_observers_,
+)
+from torch.testing._internal import common_utils
 import unittest
+
 # NOTE: we can copy paste these here if we decide to deprecate them in torch.ao
 from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
 
+
 class TestQuantFlow(TestCase):
     def _test_obs_helper(self, obs1, obs2):
-        example_inputs = [torch.randn(10, 2048), torch.randn(10, 2048), torch.randn(10, 2048)]
+        example_inputs = [
+            torch.randn(10, 2048),
+            torch.randn(10, 2048),
+            torch.randn(10, 2048),
+        ]
         for example_input in example_inputs:
             obs1(example_input)
             obs2(example_input)
@@ -26,13 +37,29 @@ def _test_obs_helper(self, obs1, obs2):
         self.assertTrue(torch.allclose(zero_point1, zero_point2))
 
     def test_min_max_per_tensor_affine(self):
-        obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int)
+        obs = AffineQuantizedMinMaxObserver(
+            MappingType.ASYMMETRIC,
+            torch.uint8,
+            granularity_type=PerTensor(),
+            eps=torch.finfo(torch.float32).eps,
+            scale_dtype=torch.float,
+            zero_point_dtype=torch.int,
+        )
         ref_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine)
         self._test_obs_helper(obs, ref_obs)
 
     def test_min_max_per_channel_affine(self):
-        obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int)
-        ref_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine)
+        obs = AffineQuantizedMinMaxObserver(
+            MappingType.ASYMMETRIC,
+            torch.uint8,
+            granularity_type=PerAxis(axis=0),
+            eps=torch.finfo(torch.float32).eps,
+            scale_dtype=torch.float,
+            zero_point_dtype=torch.int,
+        )
+        ref_obs = PerChannelMinMaxObserver(
+            dtype=torch.uint8, qscheme=torch.per_channel_affine
+        )
         self._test_obs_helper(obs, ref_obs)
 
     def test_block_size_calc_success(self):
@@ -109,5 +136,82 @@ def test_block_size_row_errors(self):
             obs(example_input)
 
 
+class TestLinearObserver(TestCase):
+    @common_utils.parametrize("observe_weight", [True, False])
+    def test_linear_observer_tensor(self, observe_weight: bool):
+        # Create a simple linear layer
+        in_features, out_features = 10, 5
+        linear = nn.Linear(in_features, out_features)
+
+        # Create observers
+        input_observer = AffineQuantizedMinMaxObserver(
+            MappingType.SYMMETRIC,
+            torch.float8_e4m3fn,
+            granularity_type=PerTensor(),
+            eps=torch.finfo(torch.float32).eps,
+            scale_dtype=torch.float,
+            zero_point_dtype=torch.int,
+            zero_point_domain=None,
+        )
+        if observe_weight:
+            weight_observer = AffineQuantizedMinMaxObserver(
+                MappingType.SYMMETRIC,
+                torch.float8_e4m3fn,
+                granularity_type=PerTensor(),
+                eps=torch.finfo(torch.float32).eps,
+                scale_dtype=torch.float,
+                zero_point_dtype=torch.int,
+                zero_point_domain=None,
+            )
+        else:
+            weight_observer = None
+
+        # Wrap the weight with LinearObserverTensor
+        insert_observers_(linear, input_observer, weight_observer)
+
+        # Create some example inputs
+        example_inputs = [torch.randn(5, in_features) for _ in range(3)]
+        max_val = 42.1234
+        min_val = -39.760
+        big_tensor = torch.full((6, in_features), max_val)
+        small_tensor = torch.full((40, in_features), min_val)
+        example_inputs.extend([big_tensor, small_tensor])
+
+        # Run forward passes
+        for example_input in example_inputs:
+            _ = linear(example_input)
+
+        input_observer = linear.weight.input_observer
+
+        # Check that the observers have recorded statistics
+        assert input_observer.min_val == min_val
+        assert input_observer.max_val == max_val
+
+        # Calculate qparams and ensure they're not None
+        input_scale, input_zero_point = input_observer.calculate_qparams()
+
+        max_fp8 = torch.finfo(torch.float8_e4m3fn).max
+        self.assertEqual(
+            input_scale.item(),
+            max_val / max_fp8,
+        )
+        self.assertIsNotNone(input_zero_point)
+
+        if observe_weight:
+            weight_observer = linear.weight.weight_observer
+            weight_scale, weight_zero_point = weight_observer.calculate_qparams()
+            torch.testing.assert_close(
+                weight_scale,
+                torch.max(linear.weight.original_weight_tensor) / max_fp8,
+                atol=5e-5,
+                rtol=0.0,
+            )
+            self.assertIsNotNone(weight_zero_point)
+        else:
+            self.assertIsNone(linear.weight.weight_observer)
+
+
+common_utils.instantiate_parametrized_tests(TestLinearObserver)
+
 if __name__ == "__main__":
     unittest.main()

torchao/quantization/linear_activation_weight_observer.py

Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
+import torch
+from typing import Callable, Optional, Dict
+from torch.utils._python_dispatch import return_and_correct_aliasing
+from torchao.utils import (
+    TorchAOBaseTensor,
+    TORCH_VERSION_AT_LEAST_2_5,
+)
+
+from torchao.quantization.observer import AffineQuantizedObserverBase
+
+__all__ = [
+    "LinearActivationWeightObservedTensor",
+]
+
+aten = torch.ops.aten
+Tensor = torch.Tensor
+
+
+class LinearActivationWeightObservedTensor(TorchAOBaseTensor):
+    """
+    This subclass of Tensor is used in conjunction with a static calibration flow.
+    The flow is broken up into 3 parts:
+    1. Insert the LinearActivationWeightObservedTensor subclass into the model's nn.Linear layers
+    2. Run the model with a calibration dataset, the observer will record the min/max of the input and weight
+    3. quantize_ the model to static using the statistics recorded by the observer
+
+    This subclass wraps the original weight tensor on the nn.Linear layer. When forward is called, the observer
+    will first calculate statistics on BOTH the input and weight, and then run the linear op.
+    """
+
+    original_weight_tensor: torch.Tensor
+    input_observer: Optional[AffineQuantizedObserverBase]
+    weight_observer: Optional[AffineQuantizedObserverBase]
+
+    def __new__(
+        cls,
+        original_weight_tensor: torch.Tensor,
+        input_observer: Optional[AffineQuantizedObserverBase] = None,
+        weight_observer: Optional[AffineQuantizedObserverBase] = None,
+    ):
+        kwargs = {}
+        dtype = original_weight_tensor.dtype
+        kwargs["dtype"] = dtype
+        kwargs["requires_grad"] = False
+        kwargs["device"] = original_weight_tensor.device
+        shape = original_weight_tensor.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        original_weight_tensor: torch.Tensor,
+        input_observer: Optional[AffineQuantizedObserverBase] = None,
+        weight_observer: Optional[AffineQuantizedObserverBase] = None,
+    ):
+        self.original_weight_tensor = original_weight_tensor
+        self.input_observer = input_observer
+        self.weight_observer = weight_observer
+
+    def __repr__(self):
+        return (
+            f"LinearActivationWeightObservedTensor(\n"
+            f"original_weight={self.original_weight_tensor}\n"
+            f"input_observer={self.input_observer.__class__.__name__ if self.input_observer else None}\n"
+            f"weight_observer={self.weight_observer.__class__.__name__ if self.weight_observer else None}\n)"
+        )
+
+    def __tensor_flatten__(self):
+        return ["original_weight_tensor"], [self.input_observer, self.weight_observer]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls,
+        tensor_data_dict: Dict[str, Tensor],
+        tensor_attributes,
+        outer_size,
+        outer_stride,
+    ):
+        original_weight_tensor = tensor_data_dict["original_weight_tensor"]
+        (input_observer, weight_observer) = tensor_attributes
+        return cls(original_weight_tensor, input_observer, weight_observer)
+
+    @classmethod
+    def from_float(
+        cls,
+        original_weight_tensor: Tensor,
+        input_observer: Optional[AffineQuantizedObserverBase] = None,
+        weight_observer: Optional[AffineQuantizedObserverBase] = None,
+    ):
+        return cls(original_weight_tensor, input_observer, weight_observer)
+
+    def _apply_fn_to_data(self, fn: Callable):
+        """Applies a fn to the tensor component of the LinearActivationWeightObservedTensor"""
+        return self.__class__(
+            fn(self.original_weight_tensor),
+            self.input_observer,
+            self.weight_observer,
+        )
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        return self._apply_fn_to_data(lambda x: x.to(**kwargs))
+
+
+implements = LinearActivationWeightObservedTensor.implements
+
+
+@implements(torch.nn.functional.linear)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+    if weight_tensor.input_observer is not None:
+        input_tensor = weight_tensor.input_observer(input_tensor)
+    if weight_tensor.weight_observer is not None:
+        weight_tensor = weight_tensor.weight_observer(
+            weight_tensor.original_weight_tensor
+        )
+    else:
+        weight_tensor = weight_tensor.original_weight_tensor
+
+    return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
+
+
+@implements(aten.detach.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+    )
+
+
+@implements(aten.clone.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+    )
+
+
+@implements(aten._to_copy.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func,
+        args,
+        kwargs,
+        args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
+    )
+
+
+if TORCH_VERSION_AT_LEAST_2_5:
+    # Allow a model with LinearActivationWeightObservedTensor weights to be loaded with `weights_only=True`
+    torch.serialization.add_safe_globals([LinearActivationWeightObservedTensor])
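
For context, the class docstring above describes a three-step static calibration flow. Below is a minimal sketch of steps 1 and 2 using only the APIs exercised by the new test (insert_observers_, AffineQuantizedMinMaxObserver, calculate_qparams); the layer sizes and number of calibration batches are arbitrary illustrations, and the final quantize_ step is not part of this commit, so it is only noted in a comment.

import torch
import torch.nn as nn
from torchao.quantization.observer import AffineQuantizedMinMaxObserver, PerTensor
from torchao.quantization.quant_api import insert_observers_
from torchao.quantization.quant_primitives import MappingType

# Step 1: wrap the linear layer's weight with the observed tensor subclass.
linear = nn.Linear(64, 32)
input_observer = AffineQuantizedMinMaxObserver(
    MappingType.SYMMETRIC,
    torch.float8_e4m3fn,
    granularity_type=PerTensor(),
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float,
    zero_point_dtype=torch.int,
    zero_point_domain=None,
)
insert_observers_(linear, input_observer, None)  # no weight observer in this sketch

# Step 2: run calibration data through the layer; the observer records input min/max.
for _ in range(8):
    linear(torch.randn(16, 64))

# The recorded statistics yield the scale/zero_point for static quantization.
scale, zero_point = linear.weight.input_observer.calculate_qparams()

# Step 3 (not shown in this commit): quantize_ the model using these statistics.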

torchao/quantization/observer.py

Lines changed: 13 additions & 5 deletions
@@ -8,9 +8,10 @@
 
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
-from typing import Callable, List, Tuple, Optional, Any
+from typing import Tuple, Optional, Any
 from functools import partial
 import logging
+
 logger = logging.getLogger(__name__)
 
 
@@ -52,6 +53,7 @@ class PerAxis(GranularityType):
     """
     axis: int
 
+
 # borrowed from torch.ao.quantization.observer
 class _PartialWrapper:
     def __init__(self, p):
@@ -66,6 +68,7 @@ def __repr__(self):
     def with_args(self, *args, **kwargs):
         return _with_args(self, *args, **kwargs)
 
+
 def _with_args(cls_or_self, *args, **kwargs):
     r"""Wrapper that allows creation of class factories.
 
@@ -103,8 +106,10 @@ def get_block_size(
         return tuple(block_size)
     raise ValueError(f"Unsupported GranularityType: {granularity_type}")
 
+
 ABC: Any = ABCMeta("ABC", (object,), {})  # compatible with Python 2 *and* 3:
 
+
 class AffineQuantizedObserverBase(ABC, torch.nn.Module):
@@ -114,9 +119,11 @@ class AffineQuantizedObserverBase(ABC, torch.nn.Module):
     Current supported granularity type are `PerTensor` and `PerAxis`
     other args: please see `:class:torchao.dtypes.AffineQuantizedTensor`
     """
+
     with_args = classmethod(_with_args)
 
-    def __init__(self,
+    def __init__(
+        self,
         mapping_type: MappingType,
         target_dtype: torch.dtype,
         granularity_type: GranularityType,
@@ -126,7 +133,7 @@ def __init__(self,
         scale_dtype: Optional[torch.dtype] = None,
         zero_point_dtype: Optional[torch.dtype] = None,
         preserve_zero: bool = True,
-        zero_point_domain = ZeroPointDomain.INT,
+        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
     ):
         super().__init__()
         assert granularity_type is not None, "granularity_type is None"
@@ -144,7 +151,7 @@ def __init__(self,
 
     @abstractmethod
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        """ forward function should take the input tensor
+        """forward function should take the input tensor
         and updates internal stats and return the original input Tensor
         """
         pass
@@ -156,6 +163,7 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         pass
 
+
 class AffineQuantizedMinMaxObserver(AffineQuantizedObserverBase):
     def forward(self, input: torch.Tensor):
         if input.numel() == 0:
@@ -200,5 +208,5 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
             self.scale_dtype,
             self.zero_point_dtype,
             self.preserve_zero,
-            self.zero_point_domain
+            self.zero_point_domain,
         )
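
As a side note on the granularity types referenced in this file: below is a small sketch (not part of this commit) of how PerTensor and PerAxis map to the block_size used for affine quantization, assuming get_block_size behaves as the surrounding context suggests (PerTensor covers the whole tensor, PerAxis(axis) uses a block of size 1 along that axis, matching the per-channel comparison in the test above).

from torchao.quantization.observer import get_block_size, PerTensor, PerAxis

weight_shape = (32, 64)  # e.g. (out_features, in_features) of an nn.Linear weight
# One set of qparams for the whole tensor:
assert get_block_size(weight_shape, PerTensor()) == (32, 64)
# One set of qparams per output channel (axis 0), as in PerChannelMinMaxObserver:
assert get_block_size(weight_shape, PerAxis(axis=0)) == (1, 64)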
