Commit a5e5f5f

Migrate SmoothQuant for IPEX to 3.x API (#1629)
Signed-off-by: Cheng, Zixuan <[email protected]>
Signed-off-by: Lu, Yintong <[email protected]>
1 parent 3fa9ab1 commit a5e5f5f

File tree

11 files changed: +3157 -1982 lines

@@ -0,0 +1,17 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .utility import *
from .smooth_quant import smooth_quantize
@@ -0,0 +1,249 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json

import torch
try:
    import intel_extension_for_pytorch as ipex
except ImportError as e:
    raise ImportError("Please install IPEX for smooth quantization.") from e

from packaging.version import Version

from .utility import (
    TorchSmoothQuant,
    cfg_to_qconfig,
    dump_model_op_stats,
    get_ipex_version,
    get_quantizable_ops_recursively,
    ipex_config_path,
    logger,
    simple_inference,
    update_sq_scale,
)

ipex_ver = get_ipex_version()


def smooth_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
    """Execute the quantization process on the given model.

    Args:
        model: a float model to be quantized.
        tune_cfg: quantization config for ops.
        run_fn: a calibration function for calibrating the model.
        example_inputs: used to trace the torch model.
        inplace: whether to carry out model transformations in-place.

    Returns:
        A quantized model.
    """
    assert ipex_ver.release >= Version("2.1").release, "IPEX version >= 2.1 is required for SmoothQuant."

    _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name = get_quantizable_ops_recursively(model, example_inputs)

    # Check the SmoothQuant folding value. Default to False so that `folding`
    # is always defined, even when the recipe does not set it explicitly.
    folding = False
    recipe_cfgs = tune_cfg.get("recipe_cfgs", None)
    if recipe_cfgs and "smooth_quant_args" in recipe_cfgs and "folding" in recipe_cfgs["smooth_quant_args"]:
        if recipe_cfgs["smooth_quant_args"]["folding"] is None:
            if ipex_ver.release < Version("2.1").release:
                folding = True
            else:
                folding = False
        else:
            folding = recipe_cfgs["smooth_quant_args"]["folding"]

    # Note: make sure SmoothQuant is executed only once when transforming the fp32 model in place.
    if hasattr(model, "_smoothquant_optimized") and model._smoothquant_optimized:
        logger.info("The model is already optimized by the SmoothQuant algorithm, skipping it.")
        return model

    sq = TorchSmoothQuant(model, dataloader=None, example_inputs=example_inputs, q_func=run_fn, record_max_info=True)
    model = sq.transform(
        alpha=recipe_cfgs["smooth_quant_args"]["alpha"],
        folding=folding,
        auto_alpha_args=recipe_cfgs["smooth_quant_args"]["auto_alpha_args"],
        scale_sharing=recipe_cfgs["smooth_quant_args"]["scale_sharing"],
    )

    # Update model parameters when SmoothQuant folding = False.
    if recipe_cfgs and recipe_cfgs.get("smooth_quant", False) and not folding:
        return qdq_quantize(
            model, tune_cfg, run_fn, example_inputs, inplace, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, sq
        )

    # Update model parameters when SmoothQuant folding = True.
    if recipe_cfgs and recipe_cfgs.get("smooth_quant", False) and folding:
        _apply_pre_optimization(model, tune_cfg, sq)
    model.eval()

    # Checking for save_qconf_summary is a workaround for an IPEX bug:
    # the model prepared by get_op_capability sometimes loses this attribute.
    if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):
        static_qconfig = ipex.quantization.default_static_qconfig_mapping
        if isinstance(example_inputs, dict):
            model = ipex.quantization.prepare(
                model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
            )
        else:
            model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace)

    model.load_qconf_summary(qconf_summary=ipex_config_path)
    run_fn(model)
    model.save_qconf_summary(qconf_summary=ipex_config_path)
    model = _ipex_post_quant_process(model, example_inputs, inplace=inplace)

    # Recover model parameters when SmoothQuant folding = True.
    if (
        recipe_cfgs
        and recipe_cfgs.get("smooth_quant", False)
        and recipe_cfgs["smooth_quant_args"]["folding"]
        and not inplace
    ):  # pragma: no cover
        _apply_pre_optimization(model, tune_cfg, sq, recover=True)

    with open(ipex_config_path, "r") as f:
        model.tune_cfg = json.load(f)
    model.ipex_config_path = ipex_config_path
    dump_model_op_stats(tune_cfg)
    return model


def qdq_quantize(
    model, tune_cfg, run_fn, example_inputs, inplace, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, sq
):
    smoothquant_scale_info = sq.sq_scale_info
    sq_minmax_init = tune_cfg.get("act_algo", "kl") == "minmax"

    # Checking for save_qconf_summary is a workaround for an IPEX bug:
    # the model prepared by get_op_capability sometimes loses this attribute.
    if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):
        from torch.ao.quantization.observer import MinMaxObserver

        if ipex_ver.release >= Version("2.1.1").release:
            static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5, act_observer=MinMaxObserver)
        else:
            if sq_minmax_init:
                static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
                    alpha=0.5, act_observer=MinMaxObserver()
                )
                logger.warning(
                    "The int8 model accuracy will be close to 0 with MinMaxObserver; "
                    "the suggested IPEX version is 2.1.100+cpu or higher."
                )
            else:
                static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
        if isinstance(example_inputs, dict):
            model = ipex.quantization.prepare(
                model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
            )
        else:
            model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace)

    # The IPEX SmoothQuant observer can only use save/load_qconf_summary once.
    # The save_qconf_summary API freezes the scales used in the model, so calibration won't work anymore.
    # The load_qconf_summary API overwrites the scales in the model, but only works on the first call.
    # Here, we use the INC-collected scale for Linear and set a normal observer instead of SQObserver
    # to make sure calibration works for other ops, like add and bmm.
    cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, smooth_quant=True)
    update_sq_scale(ipex_config_path, smoothquant_scale_info)
    model.load_qconf_summary(qconf_summary=ipex_config_path)
    # Real calibration for other operators.
    try:
        # IPEX may raise an error on the second iteration:
        # OverflowError: cannot convert float infinity to integer.
        run_fn(model)
    except Exception:
        logger.warning(
            "The calibration failed when calibrating with IPEX; "
            "using scale info from SmoothQuant for Linear and "
            "one-iteration calibration for other ops."
        )

    if ipex_ver.release > Version("2.1.0").release:
        update_sq_scale(ipex_config_path, smoothquant_scale_info)
        model.load_qconf_summary(qconf_summary=ipex_config_path)
    model = _ipex_post_quant_process(model, example_inputs, inplace=inplace)

    with open(ipex_config_path, "r") as f:
        model.tune_cfg = json.load(f)
    model.ipex_config_path = ipex_config_path
    dump_model_op_stats(tune_cfg)
    return model


def _apply_pre_optimization(model, tune_cfg, sq, recover=False):
    sq_max_info = {}
    if sq.record_max_info:
        sq_max_info = sq.max_value_info
    if sq_max_info:
        tsq = TorchSmoothQuant(model, None)
        alpha = tune_cfg["recipe_cfgs"]["smooth_quant_args"]["alpha"]
        for op_name, info in sq_max_info.items():
            if alpha == "auto":
                alpha = info["alpha"]
            absorb_layer = op_name
            absorbed_layer = info["absorbed_layer"]
            input_minmax = info["input_minmax"]
            weight_max = info["weight_max"]
            if sq.weight_clip:
                weight_max = weight_max.clamp(min=1e-5)
            abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1]))
            input_power = torch.pow(abs_input_max, alpha)
            weight_power = torch.pow(weight_max, 1 - alpha)
            scale = torch.clip(input_power / weight_power, min=1e-5)
            with torch.no_grad():
                if recover:
                    scale = 1.0 / scale
                for layer in absorbed_layer:
                    tsq._scale_layer_weight(layer, scale)
                tsq._absorb_scales(absorb_layer, 1.0 / scale)
            logger.debug(f"Current SmoothQuant scale of {op_name} is {scale}, alpha is {alpha}")


def _ipex_post_quant_process(model, example_inputs, inplace=False):
    """Convert to a jit model.

    Args:
        model: a prepared model.
        example_inputs: used to trace the torch model.
        inplace: whether to carry out model transformations in-place.

    Returns:
        A converted jit model.
    """
    model = ipex.quantization.convert(model, inplace=inplace)
    with torch.no_grad():
        try:
            if isinstance(example_inputs, dict):
                model = torch.jit.trace(model, example_kwarg_inputs=example_inputs)
            else:
                model = torch.jit.trace(model, example_inputs)
            model = torch.jit.freeze(model.eval())
        except Exception:
            if isinstance(example_inputs, dict):
                model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False)
            else:
                model = torch.jit.trace(model, example_inputs, strict=False)
            model = torch.jit.freeze(model.eval())
        # After freezing, run once to warm up the profiling graph executor and insert prim::profile nodes.
        # On the second run, the LLGA pass is triggered and the model becomes an int8 model:
        # the prim::profile nodes are removed and LlgaFusionGroup appears in the graph.
        simple_inference(model, example_inputs, iterations=2)
    return model
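
For intuition, the per-channel scale computed in _apply_pre_optimization follows the usual SmoothQuant rule, scale = max|X|^alpha / max|W|^(1 - alpha), clipped at 1e-5. Below is a minimal, self-contained numeric sketch of that computation; the tensors are made-up calibration statistics, not values from a real model.

import torch

alpha = 0.5
# Hypothetical min/max activation statistics for a 2-channel input.
input_minmax = (torch.tensor([-4.0, -0.2]), torch.tensor([6.0, 0.1]))
weight_max = torch.tensor([0.5, 2.0])  # hypothetical per-channel weight maxima

abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1]))  # tensor([6.0, 0.2])
scale = torch.clip(torch.pow(abs_input_max, alpha) / torch.pow(weight_max, 1 - alpha), min=1e-5)
print(scale)  # tensor([3.4641, 0.3162])

Channel 0 has the larger activation range, so its weight column is multiplied by ~3.46 (_scale_layer_weight) while the preceding layer's output is divided by the same factor (_absorb_scales with 1.0 / scale), shifting quantization difficulty from activations to weights.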

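For context, here is a rough end-to-end sketch of how the migrated 3.x entry point might be driven. The toy module, calibration function, and config values are illustrative assumptions rather than part of this commit; the import path is likewise assumed, and a real tune_cfg would normally carry op-level configs produced by the tuning flow.

import torch

# Assumed 3.x import path for the functions added in this commit.
from neural_compressor.torch.algorithms.smooth_quant import smooth_quantize


class ToyModel(torch.nn.Module):
    """Hypothetical model used only for illustration."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(16, 16)

    def forward(self, x):
        return torch.relu(self.fc(x))


def run_fn(model):
    # Hypothetical calibration loop: a few forward passes with sample data.
    for _ in range(4):
        model(torch.randn(2, 16))


example_inputs = torch.randn(2, 16)

# Keys mirror what smooth_quantize reads above; the values are example settings.
tune_cfg = {
    "recipe_cfgs": {
        "smooth_quant": True,
        "smooth_quant_args": {
            "alpha": 0.5,
            "folding": False,  # False routes through qdq_quantize
            "auto_alpha_args": {},
            "scale_sharing": False,
        },
    },
    "act_algo": "minmax",
}

q_model = smooth_quantize(ToyModel().eval(), tune_cfg, run_fn, example_inputs)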
0 commit comments
