
Commit a246d87

Add static quant to float8 (#787)
* Add static quant
* Merged static_quant intx - floatx
* Merged static_quant intx - floatx
* Add assert for mapping type
* Add assert for mapping type
* Update intx_static to support floatx call
1 parent 740c6b3 commit a246d87

File tree

6 files changed: +139 additions, -66 deletions


docs/source/api_ref_dtypes.rst

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ torchao.dtypes
     to_affine_quantized_intx
     to_affine_quantized_floatx
     to_affine_quantized_intx_static
+    to_affine_quantized_floatx_static
     AffineQuantizedTensor

 ..

torchao/dtypes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,7 @@
     # experimental, will be merged into floatx in the future
     to_affine_quantized_fpx,
     to_affine_quantized_floatx,
+    to_affine_quantized_floatx_static,
     LayoutType,
     PlainLayoutType,
     SemiSparseLayoutType,
@@ -25,6 +26,7 @@
     "to_affine_quantized_intx_static",
     "to_affine_quantized_fpx",
     "to_affine_quantized_floatx",
+    "to_affine_quantized_floatx_static",
     "LayoutType",
     "PlainLayoutType",
     "SemiSparseLayoutType",

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 31 additions & 2 deletions
@@ -292,14 +292,17 @@ def from_hp_to_intx_static(
         cls,
         input_float: torch.Tensor,
         scale: torch.Tensor,
-        zero_point: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
         block_size: Tuple[int, ...],
         target_dtype: torch.dtype,
         quant_min: Optional[int] = None,
         quant_max: Optional[int] = None,
-        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
         layout_type: LayoutType = PlainLayoutType(),
     ):
+        if target_dtype not in FP8_TYPES:
+            assert zero_point_domain is not None, "zero_point_domain must be specified for non-fp8 types"
+            assert zero_point is not None, "zero_point must be specified for non-fp8 types"
         original_shape = input_float.shape
         input_float = layout_type.pre_process(input_float)

@@ -348,6 +351,31 @@ def from_hp_to_floatx(
         else:
             raise NotImplementedError(f"Unsupported dtype {target_dtype} for from_hp_to_floatx")

+    @classmethod
+    def from_hp_to_floatx_static(
+        cls,
+        input_float: torch.Tensor,
+        scale: torch.Tensor,
+        block_size: Tuple[int, ...],
+        target_dtype: torch.dtype,
+        layout_type: LayoutType,
+    ):
+
+        if target_dtype in FP8_TYPES:
+            return cls.from_hp_to_intx_static(
+                input_float=input_float,
+                scale=scale,
+                zero_point=None,
+                block_size=block_size,
+                target_dtype=target_dtype,
+                quant_min=math.ceil(torch.finfo(target_dtype).min),
+                quant_max=math.ceil(torch.finfo(target_dtype).max),
+                zero_point_domain=None,
+                layout_type=layout_type,
+            )
+        else:
+            raise NotImplementedError(f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static")
+
     @classmethod
     def from_hp_to_fpx(
         cls,
@@ -1319,6 +1347,7 @@ def _(func, types, args, kwargs):
 to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx
 to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
 to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
+to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
 # experimental will be merged in to floatx
 to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx
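For orientation, here is a minimal usage sketch of the new entry point (not part of the commit). It mirrors the per-tensor activation call used in the tutorial further below, and assumes a CUDA device, a torchao build that includes this change, and a placeholder scale rather than a calibrated one.

# Minimal sketch: statically quantize a tensor to float8_e4m3fn with a precomputed scale.
# The scale value is a placeholder; in practice it comes from an observer.
import torch
from torchao.dtypes import to_affine_quantized_floatx_static, Float8LayoutType

x = torch.randn(16, 64, dtype=torch.bfloat16, device="cuda")
scale = torch.tensor(0.05, dtype=torch.float32, device="cuda")  # placeholder per-tensor scale
qx = to_affine_quantized_floatx_static(
    x,                                 # input_float
    scale,                             # scale
    x.shape,                           # block_size: whole tensor -> per-tensor quantization
    torch.float8_e4m3fn,               # FP8 target dtype, routed through from_hp_to_intx_static
    Float8LayoutType(mm_config=None),  # layout_type
)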

torchao/quantization/quant_api.py

Lines changed: 0 additions & 2 deletions
@@ -29,8 +29,6 @@
     PlainLayoutType,
     AffineQuantizedTensor,
     SemiSparseLayoutType,
-    to_affine_quantized_floatx,
-    Float8AQTLayout,
     Float8LayoutType
 )
 from torchao.utils import (

torchao/quantization/quant_primitives.py

Lines changed: 2 additions & 0 deletions
@@ -696,6 +696,8 @@ def _choose_qparams_affine(
     """
     quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
     assert mapping_type in [MappingType.SYMMETRIC.name, MappingType.ASYMMETRIC.name], f"Unsupported mapping type: {mapping_type}"
+    if target_dtype in FP8_TYPES:
+        assert mapping_type == MappingType.SYMMETRIC.name, f"Only symmetric quantization is supported for FP8 types, got {mapping_type}"

     if input is not None:
         if scale_dtype is None:
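The new assert reflects that the FP8 path runs through the affine machinery with the zero point fixed at zero, so only a symmetric mapping makes sense. As an illustration only (not torchao's actual code path), symmetric qparam selection for float8_e4m3fn looks roughly like this:

# Illustrative symmetric scale selection for an FP8 dtype; zero_point is implicitly 0.
import torch

x = torch.randn(64, 64)
fp8_max = torch.finfo(torch.float8_e4m3fn).max        # 448.0 for e4m3fn
scale = x.abs().max() / fp8_max                       # one scale, no zero point
x_fp8 = (x / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
x_dq = x_fp8.to(torch.float32) * scale                # dequantize for comparison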

tutorials/calibration_flow/static_quant.py

Lines changed: 103 additions & 62 deletions
@@ -6,7 +6,11 @@

 import torch.nn.functional as F
 from torch import Tensor
-from torchao.dtypes import to_affine_quantized_intx_static
+from torchao.dtypes import (
+    to_affine_quantized_intx_static,
+    to_affine_quantized_floatx_static,
+    Float8LayoutType,
+)
 from torchao.quantization.utils import compute_error
 from torchao.quantization import quantize_
 from torchao.quantization import to_linear_activation_quantized
@@ -18,6 +22,7 @@
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
+    FP8_TYPES,
 )


@@ -51,53 +56,81 @@ def replacement_fn(m):

 # converting observed linear module to linear module with quantzied weights (and quantized activations)
 # with tensor subclasses
-def apply_static_quant(observed_linear):
-    target_dtype = torch.uint8
-
-    # weight quantization
-    weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()
-    def weight_quant_func(weight):
-        block_size = (1, weight.shape[1])
-        return to_affine_quantized_intx_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
-    linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype)
-    linear.weight = observed_linear.weight
-    linear.bias = observed_linear.bias
-
-    linear.weight = torch.nn.Parameter(weight_quant_func(linear.weight), requires_grad=False)
-
-    # activation quantization
-    act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams()
-    input_quant_func = lambda x: to_affine_quantized_intx_static(x, act_scale, act_zero_point, x.shape, target_dtype)
-    linear.weight = torch.nn.Parameter(to_linear_activation_quantized(linear.weight, input_quant_func), requires_grad=False)
-
-    return linear
-
+def apply_static_quant(target_dtype: torch.dtype):
+    # target_dtype = torch.uint8
+    def _apply_static_quant_to_linear(observed_linear):
+        # weight quantization
+        weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()
+        def weight_quant_func(weight):
+            block_size = (1, weight.shape[1])
+            if target_dtype == torch.uint8:
+                return to_affine_quantized_intx_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
+            elif target_dtype == torch.float8_e4m3fn:
+                return to_affine_quantized_floatx_static(weight, weight_scale, block_size, target_dtype, Float8LayoutType(mm_config=None))
+            else:
+                raise ValueError(f"Unsupported target dtype {target_dtype}")
+        linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype)
+        linear.weight = observed_linear.weight
+        linear.bias = observed_linear.bias
+
+        linear.weight = torch.nn.Parameter(weight_quant_func(linear.weight), requires_grad=False)
+
+        # activation quantization
+        act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams()
+        if target_dtype == torch.uint8:
+            input_quant_func = lambda x: to_affine_quantized_intx_static(x, act_scale, act_zero_point, x.shape, target_dtype)
+        elif target_dtype == torch.float8_e4m3fn:
+            input_quant_func = lambda x: to_affine_quantized_floatx_static(x, act_scale, x.shape, target_dtype, Float8LayoutType(mm_config=None))
+        else:
+            raise ValueError(f"Unsupported target dtype {target_dtype}")
+        linear.weight = torch.nn.Parameter(to_linear_activation_quantized(linear.weight, input_quant_func), requires_grad=False)
+
+        return linear
+
+    return _apply_static_quant_to_linear

 # alternative for converting observed linear module to quantized linear module
 class QuantizedLinear(torch.nn.Module):
-    def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, weight: torch.Tensor, bias: torch.Tensor):
+    def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, weight: torch.Tensor, bias: torch.Tensor, target_dtype: torch.dtype):
         super().__init__()
         self.act_scale, self.act_zero_point = act_obs.calculate_qparams()
         weight_scale, weight_zero_point = weight_obs.calculate_qparams()
         assert weight.dim() == 2
         block_size = (1, weight.shape[1])
-        target_dtype = torch.uint8
-        self.qweight = to_affine_quantized_intx_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
+        self.target_dtype = target_dtype
         self.bias = bias
+        if self.target_dtype == torch.uint8:
+            self.qweight = to_affine_quantized_intx_static(weight, weight_scale, weight_zero_point, block_size, self.target_dtype)
+        elif self.target_dtype == torch.float8_e4m3fn:
+            self.qweight = to_affine_quantized_floatx_static(weight, weight_scale, block_size, target_dtype, Float8LayoutType(mm_config=None))
+        else:
+            raise ValueError(f"Unsupported target dtype {self.target_dtype}")

     def forward(self, input: Tensor):
         block_size = input.shape
-        target_dtype = torch.uint8
-        qinput = to_affine_quantized_intx_static(input, self.act_scale, self.act_zero_point, block_size, target_dtype)
+        if self.target_dtype == torch.uint8:
+            qinput = to_affine_quantized_intx_static(input, self.act_scale, self.act_zero_point, block_size, self.target_dtype)
+        elif self.target_dtype == torch.float8_e4m3fn:
+            qinput = to_affine_quantized_floatx_static(input, self.act_scale, block_size, self.target_dtype, Float8LayoutType(mm_config=None))
+        else:
+            raise ValueError(f"Unsupported target dtype {self.target_dtype}")
         return F.linear(qinput, self.qweight, self.bias)

     @classmethod
-    def from_observed(cls, observed_linear):
-        quantized_linear = cls(observed_linear.in_features, observed_linear.out_features, observed_linear.act_obs, observed_linear.weight_obs, observed_linear.weight, observed_linear.bias)
+    def from_observed(cls, observed_linear, target_dtype):
+        quantized_linear = cls(observed_linear.in_features,
+                               observed_linear.out_features,
+                               observed_linear.act_obs,
+                               observed_linear.weight_obs,
+                               observed_linear.weight,
+                               observed_linear.bias,
+                               target_dtype)
         return quantized_linear

-def apply_static_quant2(observed_linear):
-    return QuantizedLinear.from_observed(observed_linear)
+def apply_static_quant2(target_dtype: torch.dtype):
+    def _apply_static_quant2(observed_linear):
+        return QuantizedLinear.from_observed(observed_linear, target_dtype)
+    return _apply_static_quant2

 class ToyLinearModel(torch.nn.Module):
     def __init__(self, m=64, n=32, k=64):
@@ -113,46 +146,54 @@ def forward(self, x):
         x = self.linear2(x)
         return x

-torch.manual_seed(0)
-
-dtype = torch.bfloat16
-m = ToyLinearModel().eval().to(dtype).to("cuda")
-
-m_for_test = copy.deepcopy(m)
-
-m_bf16 = copy.deepcopy(m)
-example_inputs = m.example_inputs(dtype=dtype, device="cuda")
-print("example inputs shape:", example_inputs[0].shape)
-
-m_bf16 = torch.compile(m_bf16, mode='max-autotune')
-
-act_obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.int32)
-weight_obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.int32)
-
-before_quant = m(*example_inputs)
-
-insert_observers_(m, act_obs, weight_obs)
-# calibrating / training
-for _ in range(10):
-    m(*example_inputs)
-
-after_obs = m(*example_inputs)
-
-m2 = copy.deepcopy(m)
-
-is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)
-
-# quantized linear represented as an nn.Linear with modified tensor subclass weights
-# for both activation and weight quantization
-quantize_(m, apply_static_quant, is_observed_linear)
-print("quantized model (applying tensor subclass to weight):", m)
-after_quant = m(*example_inputs)
-assert compute_error(before_quant, after_quant) > 30
-print("test passed")
-
-# quantized linear as a standalone module
-quantize_(m2, apply_static_quant2, is_observed_linear)
-print("quantized model (quantized module):", m2)
-after_quant = m2(*example_inputs)
-assert compute_error(before_quant, after_quant) > 30
-print("test passed")
+
+def test_static_quant(target_dtype: torch.dtype, mapping_type: MappingType):
+    print(f"Testing {target_dtype} static quantization:")
+    torch.manual_seed(0)
+
+    dtype = torch.bfloat16
+    m = ToyLinearModel().eval().to(dtype).to("cuda")
+
+    m_for_test = copy.deepcopy(m)
+
+    m_bf16 = copy.deepcopy(m)
+    example_inputs = m.example_inputs(dtype=dtype, device="cuda")
+    print("example inputs shape:", example_inputs[0].shape)
+
+    m_bf16 = torch.compile(m_bf16, mode='max-autotune')
+
+    act_obs = AffineQuantizedMinMaxObserver(mapping_type, target_dtype, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.float32)
+    weight_obs = AffineQuantizedMinMaxObserver(mapping_type, target_dtype, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.float32)
+
+    before_quant = m(*example_inputs)
+
+    insert_observers_(m, act_obs, weight_obs)
+    # calibrating / training
+    for _ in range(10):
+        m(*example_inputs)
+
+    after_obs = m(*example_inputs)
+
+    m2 = copy.deepcopy(m)
+
+    is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)
+
+    # quantized linear represented as an nn.Linear with modified tensor subclass weights
+    # for both activation and weight quantization
+    quantize_(m, apply_static_quant(target_dtype), is_observed_linear)
+    print("quantized model (applying tensor subclass to weight):", m)
+    after_quant = m(*example_inputs)
+    assert compute_error(before_quant, after_quant) > 25
+    print("test passed")
+
+    # quantized linear as a standalone module
+    quantize_(m2, apply_static_quant2(target_dtype), is_observed_linear)
+    print("quantized model (quantized module):", m2)
+    after_quant = m2(*example_inputs)
+    assert compute_error(before_quant, after_quant) > 25
+    print("test passed")
+
+
+if __name__ == "__main__":
+    test_static_quant(torch.uint8, MappingType.ASYMMETRIC)
+    test_static_quant(torch.float8_e4m3fn, MappingType.SYMMETRIC)
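The compute_error assertions above compare the bfloat16 model's outputs with the quantized model's outputs. Assuming compute_error follows the standard SQNR definition, the check reads as requiring more than 25 dB of signal-to-quantization-noise for both the uint8 and float8 paths. A sketch of that metric:

# Sketch of the SQNR metric the tutorial asserts on (assumed standard definition,
# not necessarily torchao's exact implementation of compute_error). Higher is better.
import torch

def sqnr_db(reference: torch.Tensor, quantized: torch.Tensor) -> torch.Tensor:
    return 20 * torch.log10(reference.norm() / (reference - quantized).norm())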
