Commit 191383e

Migrate static quant ipex backend to 3.x API (#1596)
Signed-off-by: Cheng, Zixuan <[email protected]>
Signed-off-by: chensuyue <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
1 parent 07f940c commit 191383e

File tree

14 files changed: +1015 −36 lines


.azure-pipelines/scripts/codeScan/pylint/pylint.sh

Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@ apt-get install -y --no-install-recommends --fix-missing \
 pip install -r /neural-compressor/requirements.txt
 pip install cmake
 
-pip install torch==1.12.0 \
+pip install torch \
 horovod \
 google \
 autograd \

neural_compressor/adaptor/pytorch.py

Lines changed: 15 additions & 11 deletions
@@ -397,12 +397,12 @@ def _cfgs_to_fx_cfgs(op_cfgs, observer_type="post_training_static_quant"):
     for key, value in op_cfgs.items():
         if key == "default_qconfig":
             if version.release >= Version("1.13.0").release:  # pragma: no cover
-                fx_op_cfgs.set_global(value)
+                fx_op_cfgs.set_global(value)  # pylint: disable=E1101
             else:
                 fx_op_cfgs[""] = value
             continue
         if version.release >= Version("1.13.0").release:  # pragma: no cover
-            fx_op_cfgs.set_module_name(key, value)
+            fx_op_cfgs.set_module_name(key, value)  # pylint: disable=E1101
         else:
             op_tuple = (key, value)
             op_tuple_cfg_list.append(op_tuple)
@@ -413,7 +413,7 @@ def _cfgs_to_fx_cfgs(op_cfgs, observer_type="post_training_static_quant"):
         from torch.ao.quantization import get_default_qconfig_mapping
 
         for name, q_config in get_default_qconfig_mapping().to_dict()["object_type"]:
-            fx_op_cfgs.set_object_type(name, q_config)
+            fx_op_cfgs.set_object_type(name, q_config)  # pylint: disable=E1101
 
     return fx_op_cfgs
 
@@ -3619,7 +3619,7 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
                     prepare_custom_config=self.prepare_custom_config_dict,
                 )
             else:
-                q_model._model = prepare_qat_fx(
+                q_model._model = prepare_qat_fx(  # pylint: disable=E1120,E1123
                     q_model._model, self.fx_op_cfgs, prepare_custom_config_dict=self.prepare_custom_config_dict
                 )
         else:
@@ -3651,7 +3651,7 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
                     prepare_custom_config=self.prepare_custom_config_dict,
                 )
             else:
-                q_model._model = prepare_fx(
+                q_model._model = prepare_fx(  # pylint: disable=E1120,E1123
                     q_model._model, self.fx_op_cfgs, prepare_custom_config_dict=self.prepare_custom_config_dict
                 )
         else:
@@ -3681,7 +3681,9 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
                 # pylint: disable=E1123
                 q_model._model = convert_fx(q_model._model, convert_custom_config=self.convert_custom_config_dict)
             else:
-                q_model._model = convert_fx(q_model._model, convert_custom_config_dict=self.convert_custom_config_dict)
+                q_model._model = convert_fx(  # pylint: disable=E1123
+                    q_model._model, convert_custom_config_dict=self.convert_custom_config_dict
+                )
             torch_utils.util.append_attr(q_model._model, tmp_model)
             del tmp_model
             gc.collect()
@@ -3830,7 +3832,7 @@ def _pre_hook_for_qat(self, dataloader=None):
                 ),
             )
         else:
-            self.model._model = prepare_qat_fx(
+            self.model._model = prepare_qat_fx(  # pylint: disable=E1120,E1123
                 self.model._model,
                 fx_op_cfgs,
                 prepare_custom_config_dict=(
@@ -3877,7 +3879,7 @@ def _post_hook_for_qat(self):
                 ),
             )
         else:
-            self.model._model = convert_fx(
+            self.model._model = convert_fx(  # pylint: disable=E1123
                 self.model._model,
                 convert_custom_config_dict=(
                     self.model.kwargs.get("convert_custom_config_dict", None)
@@ -4331,15 +4333,15 @@ def prepare_sub_graph(
         # pragma: no cover
         if is_qat:
             module_pre = (
-                prepare_qat_fx(tmp_module, fx_sub_op_cfgs)
+                prepare_qat_fx(tmp_module, fx_sub_op_cfgs)  # pylint: disable=E1120
                 if version <= Version("1.12.1")
                 else prepare_qat_fx(tmp_module, fx_sub_op_cfgs, example_inputs=example_inputs)
            )
         # pylint: disable=E1123
         # pragma: no cover
         else:
             module_pre = (
-                prepare_fx(tmp_module, fx_sub_op_cfgs)
+                prepare_fx(tmp_module, fx_sub_op_cfgs)  # pylint: disable=E1120
                 if version <= Version("1.12.1")
                 else prepare_fx(tmp_module, fx_sub_op_cfgs, example_inputs=example_inputs)
            )
@@ -4433,7 +4435,9 @@ def fuse_fx_model(self, model, is_qat):
                fused_model = _fuse_fx(graph_module, is_qat, fuse_custom_config=prepare_custom_config_dict)
            elif self.version.release >= Version("1.11.0").release:  # pragma: no cover
                # pylint: disable=E1124
-               fused_model = _fuse_fx(graph_module, is_qat, fuse_custom_config_dict=prepare_custom_config_dict)
+               fused_model = _fuse_fx(  # pylint: disable=E1123
+                   graph_module, is_qat, fuse_custom_config_dict=prepare_custom_config_dict
+               )
            else:
                fused_model = _fuse_fx(graph_module, prepare_custom_config_dict)
        except:
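The E1101 suppressions above are needed because fx_op_cfgs is a plain qconfig dict on torch <= 1.12 but a QConfigMapping on torch >= 1.13, so set_global, set_module_name, and set_object_type only exist on one side of the version gate and pylint flags them on the other. A minimal sketch of that gate, using only the public torch.ao.quantization API (the op names and qconfig choices below are illustrative, not neural-compressor defaults):

# Illustrative version-gated qconfig container; not the exact neural-compressor code.
from packaging.version import Version
import torch
from torch.ao.quantization import get_default_qconfig

default_qconfig = get_default_qconfig("fbgemm")

if Version(torch.__version__).release >= Version("1.13.0").release:
    from torch.ao.quantization import QConfigMapping

    fx_op_cfgs = QConfigMapping()
    fx_op_cfgs.set_global(default_qconfig)                # these methods exist only on QConfigMapping,
    fx_op_cfgs.set_module_name("conv1", default_qconfig)  # hence the pylint E1101 disables above
else:
    # Legacy qconfig_dict format accepted by prepare_fx on torch <= 1.12
    fx_op_cfgs = {"": default_qconfig, "module_name": [("conv1", default_qconfig)]}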

neural_compressor/adaptor/torch_utils/util.py

Lines changed: 3 additions & 3 deletions
@@ -821,7 +821,7 @@ def output_hook(self, input, output):
     if adaptor.version.release >= Version("1.13.0").release:  # pragma: no cover
         tmp_model = prepare_fx(tmp_model, fx_op_cfgs, example_inp)
     else:
-        tmp_model = prepare_fx(
+        tmp_model = prepare_fx(  # pylint: disable=E1120
             tmp_model,
             fx_op_cfgs,
         )
@@ -877,7 +877,7 @@ def output_hook(self, input, output):
     if adaptor.version.release >= Version("1.13.0").release:  # pragma: no cover
         tmp_model = prepare_fx(tmp_model, fx_op_cfgs, example_inp)
     else:
-        tmp_model = prepare_fx(
+        tmp_model = prepare_fx(  # pylint: disable=E1120
             tmp_model,
             fx_op_cfgs,
         )
@@ -958,7 +958,7 @@ def output_hook(self, input, output):
     if adaptor.version.release >= Version("1.13.0").release:  # pragma: no cover
         tmp_model = prepare_fx(tmp_model, fx_op_cfgs, example_inp)
     else:
-        tmp_model = prepare_fx(
+        tmp_model = prepare_fx(  # pylint: disable=E1120
             tmp_model,
             fx_op_cfgs,
         )
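The E1120 suppressions in util.py come from the same torch-version split: from torch 1.13 on, prepare_fx requires example_inputs as a positional argument, while torch <= 1.12 does not accept it, so whichever branch pylint analyzes against the installed torch looks like a call with a missing argument. A rough sketch of the pattern, with a toy model and input standing in for the real objects:

# Illustrative gate around prepare_fx; the toy model and qconfigs are placeholders.
from packaging.version import Version
import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
example_inp = (torch.randn(1, 4),)

if Version(torch.__version__).release >= Version("1.13.0").release:
    from torch.ao.quantization import get_default_qconfig_mapping

    # example_inputs is a required positional argument from torch 1.13 on
    prepared = prepare_fx(model, get_default_qconfig_mapping(), example_inp)
else:
    fx_op_cfgs = {"": get_default_qconfig("fbgemm")}  # legacy qconfig_dict format
    prepared = prepare_fx(model, fx_op_cfgs)  # no example_inputs on torch <= 1.12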

Lines changed: 17 additions & 0 deletions (new file)
@@ -0,0 +1,17 @@
+# Copyright (c) 2024 Intel Corporation
+
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .utility import *
+from .static_quant import static_quantize
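The new package's __init__ re-exports the 3.x entry point so callers can import static_quantize directly. The filename headers for the added files are not preserved in this view, so the module path below is an assumption based on the relative imports shown:

# Assumed import path for the new static-quant backend (not confirmed by this excerpt).
from neural_compressor.torch.algorithms.static_quant import static_quantize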

Lines changed: 130 additions & 0 deletions (new file)
@@ -0,0 +1,130 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from neural_compressor.torch.utils import get_ipex_version
+
+try:
+    import intel_extension_for_pytorch as ipex
+except:
+    assert False, "Please install IPEX for static quantization."
+
+import torch
+from packaging.version import Version
+
+from .utility import (
+    cfg_to_qconfig,
+    dump_model_op_stats,
+    get_quantizable_ops_recursively,
+    ipex_config_path,
+    simple_inference,
+)
+
+ipex_ver = get_ipex_version()
+
+
+def static_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
+    """Execute the quantize process on the specified model.
+
+    Args:
+        model: a float model to be quantized.
+        tune_cfg: quantization config for ops.
+        run_fn: a calibration function for calibrating the model.
+        example_inputs: used to trace torch model.
+        inplace: whether to carry out model transformations in-place.
+
+    Returns:
+        A quantized model.
+    """
+    model.eval()
+
+    if ipex_ver.release >= Version("1.12.0").release:
+        # Check save_qconf_summary part is a workaround for IPEX bug.
+        # Sometimes the prepared model from get_op_capablitiy loss this attribute
+        if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):
+            from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
+
+            if ipex_ver.release >= Version("2.1").release:
+                static_qconfig = ipex.quantization.default_static_qconfig_mapping
+            else:
+                static_qconfig = QConfig(
+                    activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
+                    weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
+                )
+            if isinstance(example_inputs, dict):
+                model = ipex.quantization.prepare(
+                    model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
+                )
+            else:
+                model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace)
+
+        model.load_qconf_summary(qconf_summary=ipex_config_path)
+        run_fn(model)
+        model.save_qconf_summary(qconf_summary=ipex_config_path)
+        model = _ipex_post_quant_process(model, example_inputs, inplace=inplace)
+
+    else:  # pragma: no cover
+        # for IPEX version < 1.12
+        _, cfgs, default_cfgs, fuse_ops = get_quantizable_ops_recursively(model, example_inputs)
+        qscheme = cfg_to_qconfig(tune_cfg, cfgs, default_cfgs, fuse_ops)
+        ipex_conf = ipex.quantization.QuantConf(
+            configure_file=ipex_config_path, qscheme=qscheme
+        )  # pylint: disable=E1101
+        run_fn(model)
+        ipex_conf.save(ipex_config_path)
+        ipex_conf = ipex.quantization.QuantConf(ipex_config_path)  # pylint: disable=E1101
+        model = ipex.quantization.convert(model, ipex_conf, example_inputs, inplace=True)  # pylint: disable=E1121
+
+    with open(ipex_config_path, "r") as f:
+        model.tune_cfg = json.load(f)
+    model.ipex_config_path = ipex_config_path
+    if ipex_ver.release >= Version("1.12.0").release:
+        dump_model_op_stats(tune_cfg)
+    return model
+
+
+def _ipex_post_quant_process(model, example_inputs, inplace=False):
+    """Convert to a jit model.
+
+    Args:
+        model: a prepared model.
+        example_inputs: used to trace torch model.
+        inplace: whether to carry out model transformations in-place.
+
+    Returns:
+        A converted jit model.
+    """
+    model = ipex.quantization.convert(model, inplace=inplace)
+    with torch.no_grad():
+        try:
+            if isinstance(example_inputs, dict):
+                model = torch.jit.trace(model, example_kwarg_inputs=example_inputs)
+            else:
+                model = torch.jit.trace(model, example_inputs)
+            model = torch.jit.freeze(model.eval())
+        except:
+            if isinstance(example_inputs, dict):
+                model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False)
+            else:
+                model = torch.jit.trace(model, example_inputs, strict=False)
+            model = torch.jit.freeze(model.eval())
+    # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile
+    # At the 2nd run, the llga pass will be triggered and the model is turned into
+    # an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph
+    simple_inference(model, example_inputs, iterations=2)
+    return model
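Taken together, static_quantize wraps the IPEX static-quantization flow: prepare with a static qconfig, calibrate via run_fn, convert, then trace and freeze through _ipex_post_quant_process. A minimal sketch of that underlying flow, assuming IPEX >= 2.1 is installed; the toy model, calibration loop, and input are illustrative, and the tune_cfg/qconf-summary round trip through ipex_config_path is omitted:

# Sketch of the prepare -> calibrate -> convert -> jit trace/freeze flow shown above.
import torch
import intel_extension_for_pytorch as ipex

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval()
example_inputs = torch.randn(4, 8)

# IPEX >= 2.1 ships a ready-made static qconfig mapping (used above as well)
static_qconfig = ipex.quantization.default_static_qconfig_mapping
prepared = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=False)

# Calibration: run representative data through the prepared model
with torch.no_grad():
    for _ in range(10):
        prepared(torch.randn(4, 8))

converted = ipex.quantization.convert(prepared, inplace=False)
with torch.no_grad():
    traced = torch.jit.trace(converted, example_inputs)
    traced = torch.jit.freeze(traced.eval())
    # Two warm-up runs so the LLGA fusion pass is triggered, as noted in the comments above
    traced(example_inputs)
    traced(example_inputs)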
