
Commit 53e6ee6

Support xpu for ipex static quant (#1916)

Signed-off-by: violetch24 <[email protected]>
1 parent a1cc618 · commit 53e6ee6

5 files changed: +206, -55 lines changed

neural_compressor/torch/algorithms/static_quant/save_load.py (+10, -3)

@@ -32,9 +32,16 @@ def save(model, output_dir="./saved_results"):

     qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
     qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME)
-    model.ori_save(qmodel_file_path)
-    with open(qconfig_file_path, "w") as f:
-        json.dump(model.tune_cfg, f, indent=4)
+    device = next(model.parameters(), None).device.type if next(model.parameters(), None) else "cpu"
+    if device == "cpu":
+        model.ori_save(qmodel_file_path)
+        with open(qconfig_file_path, "w") as f:
+            json.dump(model.tune_cfg, f, indent=4)
+    else:  # pragma: no cover
+        from neural_compressor.common.utils import save_config_mapping
+
+        torch.jit.save(model, qmodel_file_path)
+        save_config_mapping(model.qconfig, qconfig_file_path)

     logger.info("Save quantized model to {}.".format(qmodel_file_path))
     logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path))
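Since the XPU branch above serializes the quantized model with torch.jit.save, reloading it goes through torch.jit.load rather than the CPU load path. A minimal load-side sketch follows; the file name `quantized_model.pt` stands in for WEIGHT_NAME and is an assumption, not something shown in this diff:

```python
import os

import torch


def load_xpu_quantized_model(output_dir="./saved_results", weight_name="quantized_model.pt"):
    """Illustrative only: reload a TorchScript model written by the XPU save branch above.

    ``weight_name`` is an assumed value for WEIGHT_NAME; the diff itself only shows
    ``torch.jit.save(model, qmodel_file_path)``.
    """
    qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), weight_name)
    # Map to "xpu" when an XPU runtime is present, otherwise fall back to CPU.
    device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"
    model = torch.jit.load(qmodel_file_path, map_location=device)
    model.eval()
    return model
```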

neural_compressor/torch/algorithms/static_quant/static_quant.py (+69, -39)

@@ -33,11 +33,13 @@

 from neural_compressor.torch.algorithms import Quantizer
 from neural_compressor.torch.utils import logger
+from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

 from .utility import (
     CpuInfo,
     cfg_to_qconfig,
     dump_model_op_stats,
+    generate_xpu_qconfig,
     get_ipex_version,
     get_quantizable_ops_recursively,
     ipex_config_path,
@@ -56,6 +58,7 @@ def __init__(self, quant_config: OrderedDict = {}):
         """
         super().__init__(quant_config)
         self.user_cfg = OrderedDict()
+        self.device = auto_detect_accelerator().current_device()

     def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
         """Prepares a given model for quantization.
@@ -70,43 +73,61 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
         """
         assert example_inputs is not None, "Please provide example_inputs for static quantization."

-        _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(
-            model, example_inputs
-        )
-        # update json file in ipex_config_path; map ipex op_name to pt op_name
-        self.user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
-        model.eval()
+        if self.device == "cpu":
+            _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(
+                model, example_inputs
+            )
+            # update json file in ipex_config_path; map ipex op_name to pt op_name
+            self.user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
+        else:  # pragma: no cover
+            model = model.to("xpu")

-        use_bf16 = self.quant_config.get("use_bf16", None)
+        model.eval()

         # Check save_qconf_summary part is a workaround for IPEX bug.
-        # Sometimes the prepared model from get_op_capablitiy loss this attribute
-        if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):
-            from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
-
-            if ipex_ver.release >= Version("2.1").release:
-                # HistogramObserver will cause a performance issue.
-                # static_qconfig = ipex.quantization.default_static_qconfig_mapping
-                qconfig = QConfig(
-                    activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
-                    weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
-                )
-                from torch.ao.quantization import QConfigMapping
-
-                static_qconfig = QConfigMapping().set_global(qconfig)
-            else:
-                static_qconfig = QConfig(
-                    activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
-                    weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
-                )
-            if isinstance(example_inputs, dict):
-                model = ipex.quantization.prepare(
-                    model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
-                )
+        # Sometimes the prepared model from get_op_capablitiy loss this attributes
+        if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):  # pragma: no cover
+            from torch.ao.quantization import HistogramObserver, MinMaxObserver, PerChannelMinMaxObserver, QConfig
+
+            if self.device != "cpu":  # pragma: no cover
+                from torch.quantization.quantize_jit import prepare_jit
+
+                with torch.no_grad():
+                    modelJit = torch.jit.trace(model, example_inputs)
+                qconfig = generate_xpu_qconfig(self.quant_config)
+                model = prepare_jit(modelJit, qconfig, inplace)
             else:
-                model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace)
+                if ipex_ver.release >= Version("2.1").release:
+                    # HistogramObserver will cause a performance issue.
+                    # static_qconfig = ipex.quantization.default_static_qconfig_mapping
+                    qconfig = QConfig(
+                        activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
+                        weight=PerChannelMinMaxObserver.with_args(
+                            dtype=torch.qint8, qscheme=torch.per_channel_symmetric
+                        ),
+                    )
+                    from torch.ao.quantization import QConfigMapping
+
+                    static_qconfig = QConfigMapping().set_global(qconfig)
+                else:  # pragma: no cover
+                    static_qconfig = QConfig(
+                        activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
+                        weight=PerChannelMinMaxObserver.with_args(
+                            dtype=torch.qint8, qscheme=torch.per_channel_symmetric
+                        ),
+                    )
+                if isinstance(example_inputs, dict):
+                    model = ipex.quantization.prepare(
+                        model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
+                    )
+                else:
+                    model = ipex.quantization.prepare(
+                        model, static_qconfig, example_inputs=example_inputs, inplace=inplace
+                    )
+
+        if self.device == "cpu":
+            model.load_qconf_summary(qconf_summary=ipex_config_path)

-        model.load_qconf_summary(qconf_summary=ipex_config_path)
         return model

     def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
@@ -124,18 +145,27 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs):

         from neural_compressor.torch.algorithms.static_quant import save

-        model.save_qconf_summary(qconf_summary=ipex_config_path)
-        model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace)
+        if self.device != "cpu":  # pragma: no cover
+            from torch.quantization.quantize_jit import convert_jit

-        with open(ipex_config_path, "r") as f:
-            model.tune_cfg = json.load(f)
-        model.ipex_config_path = ipex_config_path
+            model = convert_jit(model, inplace)
+            simple_inference(model, example_inputs, iterations=2)
+            model.qconfig = self.quant_config["op"]
+            dump_model_op_stats(model.qconfig)
+        else:
+            model.save_qconf_summary(qconf_summary=ipex_config_path)
+            model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace)

-        dump_model_op_stats(self.user_cfg)
+            with open(ipex_config_path, "r") as f:
+                model.tune_cfg = json.load(f)
+            model.ipex_config_path = ipex_config_path
+
+            dump_model_op_stats(self.user_cfg)

-        logger.info("Static quantization done.")
         model.ori_save = model.save
         model.save = MethodType(save, model)
+
+        logger.info("Static quantization done.")
         return model

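Taken together, prepare() and convert() now branch on the detected accelerator: the CPU path keeps the IPEX qconf-summary flow, while the XPU path traces the model and goes through prepare_jit/convert_jit. A rough end-to-end sketch using the 3.x-style entry points, assuming intel_extension_for_pytorch is installed; the toy model, calibration loop, and device handling are illustrative assumptions, not taken from this commit:

```python
import torch
from neural_compressor.torch.quantization import StaticQuantConfig, convert, prepare

# Pick the device the same way the quantizer will: XPU when available, else CPU.
device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"

# Toy model and example input for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval().to(device)
example_inputs = torch.randn(4, 16, device=device)

quant_config = StaticQuantConfig(act_algo="minmax")

# prepare() dispatches to the static-quant Quantizer.prepare(); with an XPU accelerator
# detected, the model is moved to "xpu", traced, and wrapped via prepare_jit as shown above.
prepared = prepare(model, quant_config, example_inputs=example_inputs)

# Calibration: run a few representative batches through the prepared model.
with torch.no_grad():
    for _ in range(8):
        prepared(torch.randn(4, 16, device=device))

# convert() finalizes the model (convert_jit on XPU, IPEX convert + qconf summary on CPU)
# and binds the save() helper from save_load.py.
q_model = convert(prepared)
q_model.save("./saved_results")
```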

neural_compressor/torch/algorithms/static_quant/utility.py (+41)

@@ -163,6 +163,47 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
     return cfgs, ori_user_cfg


+def generate_xpu_qconfig(tune_cfg):  # pragma: no cover
+    # qconfig observer & config constants for ipex-xpu
+    from torch.ao.quantization import HistogramObserver, MinMaxObserver, QConfig
+
+    act_observer_minmax_asym = MinMaxObserver.with_args(quant_min=0, quant_max=127)
+    act_observer_minmax_sym = MinMaxObserver.with_args(
+        dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, quant_min=-128, quant_max=127
+    )
+    act_observer_kl_asym = HistogramObserver.with_args(quant_min=0, quant_max=127)
+    act_observer_kl_sym = HistogramObserver.with_args(
+        dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, quant_min=-128, quant_max=127
+    )
+    # no tuning for granularity due to tuning space
+    weight_observer_minmax_sym = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
+
+    qconfig = {}
+    user_cfg = copy.deepcopy(tune_cfg["op"])
+    for _, cfg in user_cfg.items():
+        act_algo = cfg["activation"]["algorithm"]
+        act_sym = cfg["activation"]["scheme"]
+        break
+
+    if act_algo == "minmax":
+        if act_sym == "sym":
+            activation = act_observer_minmax_sym
+        else:
+            activation = act_observer_minmax_asym
+    else:
+        if act_sym == "sym":
+            activation = act_observer_kl_sym
+        else:
+            activation = act_observer_kl_asym
+
+    qconfig[""] = QConfig(activation=activation, weight=weight_observer_minmax_sym)
+
+    for (op_name, op_type), cfg in user_cfg.items():
+        if cfg["weight"]["dtype"] == "fp32":
+            qconfig[op_name] = None
+    return qconfig
+
+
 def generate_activation_observer(
     scheme, algorithm, smooth_quant=False, smooth_quant_enable=False, alpha=0.5
 ):  # pragma: no cover
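generate_xpu_qconfig only reads the activation algorithm/scheme of the first op entry plus each op's weight dtype, and returns a plain dict with a global QConfig under the empty key and None for ops kept in fp32. A small sketch of a tune_cfg that exercises it; the op names and the exact dict layout are assumptions inferred from the keys the function accesses:

```python
from neural_compressor.torch.algorithms.static_quant.utility import generate_xpu_qconfig

# Hypothetical tuning config: an "op" dict keyed by (op_name, op_type) tuples with
# "activation"/"weight" sub-dicts, mirroring the fields generate_xpu_qconfig reads.
tune_cfg = {
    "op": {
        ("conv1", "Conv2d"): {
            "activation": {"algorithm": "minmax", "scheme": "asym"},
            "weight": {"dtype": "int8"},
        },
        ("fc", "Linear"): {
            "activation": {"algorithm": "minmax", "scheme": "asym"},
            "weight": {"dtype": "fp32"},  # mapped to None, i.e. the op stays in fp32
        },
    }
}

qconfig_dict = generate_xpu_qconfig(tune_cfg)
# Expected shape: {"": QConfig(activation=MinMaxObserver..., weight=MinMaxObserver...), "fc": None}
```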

neural_compressor/torch/quantization/config.py (+23, -3)

@@ -1095,6 +1095,7 @@ def __init__(
         act_algo: str = "minmax",
         excluded_precisions: list = [],
         white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
+        model_info: Optional[List[Tuple[str, Callable]]] = None,
     ):
         """Init Static Quant Configs."""
         super().__init__(white_list=white_list)
@@ -1107,6 +1108,7 @@ def __init__(
         self.act_granularity = act_granularity
         self.act_algo = act_algo
         self.excluded_precisions = excluded_precisions
+        self.model_info = model_info
         self._post_init()

     @classmethod
@@ -1124,10 +1126,28 @@ def get_model_info_for_ipex(model: torch.nn.Module, example_inputs) -> List[Tupl
         _, _, _, _, model_info = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
         return model_info

-    @staticmethod
-    def get_model_info(model: torch.nn.Module, example_inputs=None) -> List[Tuple[str, Callable]]:
+    def get_model_info_for_ipex_xpu(self, model: torch.nn.Module) -> List[Tuple[str, Callable]]:  # pragma: no cover
+        if self.model_info:
+            return self.model_info
+        else:
+            white_list = torch.quantization.quantization_mappings.get_default_qconfig_propagation_list()
+            filter_result = []
+            for op_name, module in model.named_modules():
+                if type(module) in white_list:
+                    pair = (op_name, type(module).__name__)
+                    filter_result.append(pair)
+            logger.debug(f"Get model info: {filter_result}")
+            self.model_info = filter_result
+            return filter_result
+
+    def get_model_info(self, model: torch.nn.Module, example_inputs=None) -> List[Tuple[str, Callable]]:
+        from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
+
         if is_ipex_imported():
-            return StaticQuantConfig.get_model_info_for_ipex(model, example_inputs)
+            if auto_detect_accelerator().current_device() == "cpu":
+                return StaticQuantConfig.get_model_info_for_ipex(model, example_inputs)
+            else:
+                return StaticQuantConfig.get_model_info_for_ipex_xpu(self, model)

     def to_config_mapping(
         self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
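With get_model_info now an instance method, the XPU path either walks model.named_modules() against PyTorch's default qconfig propagation list or short-circuits on a user-supplied model_info. A small sketch of both options on a toy model; the module names and expected list contents are illustrative:

```python
import torch
from neural_compressor.torch.quantization import StaticQuantConfig

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 4))

# Option 1: let the config discover quantizable modules; the result is cached on
# config.model_info, e.g. [("0", "Linear"), ("1", "ReLU"), ("2", "Linear")].
config = StaticQuantConfig(act_algo="minmax")
discovered = config.get_model_info_for_ipex_xpu(model)

# Option 2: pre-seed model_info so discovery is skipped entirely.
seeded = StaticQuantConfig(act_algo="minmax", model_info=[("0", "Linear"), ("2", "Linear")])
assert seeded.get_model_info_for_ipex_xpu(model) == [("0", "Linear"), ("2", "Linear")]
```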
