|
| 1 | +#!/usr/bin/env python |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | +# |
| 4 | +# Copyright (c) 2024 Intel Corporation |
| 5 | +# |
| 6 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 7 | +# you may not use this file except in compliance with the License. |
| 8 | +# You may obtain a copy of the License at |
| 9 | +# |
| 10 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | +# |
| 12 | +# Unless required by applicable law or agreed to in writing, software |
| 13 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 14 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 15 | +# See the License for the specific language governing permissions and |
| 16 | +# limitations under the License. |
| 17 | + |
| 18 | +import json |
| 19 | + |
| 20 | +import torch |
| 21 | + |
| 22 | +try: |
| 23 | + import intel_extension_for_pytorch as ipex |
| 24 | +except: |
| 25 | + assert False, "Please install IPEX for smooth quantization." |
| 26 | + |
| 27 | +from packaging.version import Version |
| 28 | + |
| 29 | +from .utility import ( |
| 30 | + TorchSmoothQuant, |
| 31 | + cfg_to_qconfig, |
| 32 | + dump_model_op_stats, |
| 33 | + get_ipex_version, |
| 34 | + get_quantizable_ops_recursively, |
| 35 | + ipex_config_path, |
| 36 | + logger, |
| 37 | + simple_inference, |
| 38 | + update_sq_scale, |
| 39 | +) |
| 40 | + |
| 41 | +ipex_ver = get_ipex_version() |
| 42 | + |
| 43 | + |
| 44 | +def smooth_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True): |
| 45 | + """Execute the quantize process on the specified model. |
| 46 | +
|
| 47 | + Args: |
| 48 | + model: a float model to be quantized. |
| 49 | + tune_cfg: quantization config for ops. |
| 50 | + run_fn: a calibration function for calibrating the model. |
| 51 | + example_inputs: used to trace torch model. |
| 52 | + inplace: whether to carry out model transformations in-place. |
| 53 | +
|
| 54 | + Returns: |
| 55 | + A quantized model. |
| 56 | + """ |
| 57 | + assert not ipex_ver.release < Version("2.1").release, "IPEX version >= 2.1 is required for SmoothQuant." |
| 58 | + |
| 59 | + _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name = get_quantizable_ops_recursively(model, example_inputs) |
| 60 | + |
| 61 | + # check smoothquant folding value |
| 62 | + recipe_cfgs = tune_cfg.get("recipe_cfgs", None) |
| 63 | + if "smooth_quant_args" in recipe_cfgs and "folding" in recipe_cfgs["smooth_quant_args"]: |
| 64 | + if recipe_cfgs["smooth_quant_args"]["folding"] is None: |
| 65 | + if ipex_ver.release < Version("2.1").release: |
| 66 | + folding = True |
| 67 | + else: |
| 68 | + folding = False |
| 69 | + else: |
| 70 | + folding = recipe_cfgs["smooth_quant_args"]["folding"] |
| 71 | + |
| 72 | + # Note: we should make sure smoothquant is only executed once with inplacing fp32 model. |
| 73 | + if hasattr(model, "_smoothquant_optimized") and model._smoothquant_optimized: |
| 74 | + logger.info("The model is already optimized by SmoothQuant algorithm, skip it.") |
| 75 | + return model |
| 76 | + |
| 77 | + sq = TorchSmoothQuant(model, dataloader=None, example_inputs=example_inputs, q_func=run_fn, record_max_info=True) |
| 78 | + model = sq.transform( |
| 79 | + alpha=recipe_cfgs["smooth_quant_args"]["alpha"], |
| 80 | + folding=folding, |
| 81 | + auto_alpha_args=recipe_cfgs["smooth_quant_args"]["auto_alpha_args"], |
| 82 | + scale_sharing=recipe_cfgs["smooth_quant_args"]["scale_sharing"], |
| 83 | + ) |
| 84 | + |
| 85 | + # Update model parameter when smoothquant folding = False |
| 86 | + if recipe_cfgs and recipe_cfgs.get("smooth_quant", False) and not folding: |
| 87 | + return qdq_quantize( |
| 88 | + model, tune_cfg, run_fn, example_inputs, inplace, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, sq |
| 89 | + ) |
| 90 | + |
| 91 | + # Update model parameter when smoothquant folding = True |
| 92 | + if recipe_cfgs and recipe_cfgs.get("smooth_quant", False) and folding: |
| 93 | + _apply_pre_optimization(model, tune_cfg, sq) |
| 94 | + model.eval() |
| 95 | + |
| 96 | + # Check save_qconf_summary part is a workaround for IPEX bug. |
| 97 | + # Sometimes the prepared model from get_op_capablitiy loss this attribute |
| 98 | + if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"): |
| 99 | + static_qconfig = ipex.quantization.default_static_qconfig_mapping |
| 100 | + if isinstance(example_inputs, dict): |
| 101 | + model = ipex.quantization.prepare( |
| 102 | + model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace |
| 103 | + ) |
| 104 | + else: |
| 105 | + model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace) |
| 106 | + |
| 107 | + model.load_qconf_summary(qconf_summary=ipex_config_path) |
| 108 | + run_fn(model) |
| 109 | + model.save_qconf_summary(qconf_summary=ipex_config_path) |
| 110 | + model = _ipex_post_quant_process(model, example_inputs, inplace=inplace) |
| 111 | + |
| 112 | + # Recover model parameter when smoothquant folding = True |
| 113 | + if ( |
| 114 | + recipe_cfgs |
| 115 | + and recipe_cfgs.get("smooth_quant", False) |
| 116 | + and recipe_cfgs["smooth_quant_args"]["folding"] |
| 117 | + and not inplace |
| 118 | + ): # pragma: no cover |
| 119 | + _apply_pre_optimization(model, tune_cfg, sq, recover=True) |
| 120 | + |
| 121 | + with open(ipex_config_path, "r") as f: |
| 122 | + model.tune_cfg = json.load(f) |
| 123 | + model.ipex_config_path = ipex_config_path |
| 124 | + dump_model_op_stats(tune_cfg) |
| 125 | + return model |
| 126 | + |
| 127 | + |
| 128 | +def qdq_quantize( |
| 129 | + model, tune_cfg, run_fn, example_inputs, inplace, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, sq |
| 130 | +): |
| 131 | + smoothquant_scale_info = sq.sq_scale_info |
| 132 | + sq_minmax_init = True if tune_cfg.get("act_algo", "kl") == "minmax" else False |
| 133 | + |
| 134 | + # Check save_qconf_summary part is a workaround for IPEX bug. |
| 135 | + # Sometimes the prepared model from get_op_capablitiy loss this attribute |
| 136 | + if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"): |
| 137 | + from torch.ao.quantization.observer import MinMaxObserver |
| 138 | + |
| 139 | + if ipex_ver.release >= Version("2.1.1").release: |
| 140 | + static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5, act_observer=MinMaxObserver) |
| 141 | + else: |
| 142 | + if sq_minmax_init: |
| 143 | + static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping( |
| 144 | + alpha=0.5, act_observer=MinMaxObserver() |
| 145 | + ) |
| 146 | + logger.warning( |
| 147 | + "The int8 model accuracy will be close to 0 with MinMaxobserver, " |
| 148 | + + "the suggested IPEX version is higher or equal than 2.1.100+cpu." |
| 149 | + ) |
| 150 | + else: |
| 151 | + static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5) |
| 152 | + if isinstance(example_inputs, dict): |
| 153 | + model = ipex.quantization.prepare( |
| 154 | + model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace |
| 155 | + ) |
| 156 | + else: |
| 157 | + model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace) |
| 158 | + |
| 159 | + # The IPEX SmoothQuant observer can only use save/load_qconf_summary once. |
| 160 | + # The save_qconf_summary API will freeze the scale used in model and calibration won't work anymore. |
| 161 | + # The load_qconf_summary will overwrite the scales used in model but only work in the first call. |
| 162 | + # Here, we use INC collected scale for Linear and set normal observer instead of SQObserver \ |
| 163 | + # to make sure calibration works for other ops, like add, bmm. |
| 164 | + cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, smooth_quant=True) |
| 165 | + update_sq_scale(ipex_config_path, smoothquant_scale_info) |
| 166 | + model.load_qconf_summary(qconf_summary=ipex_config_path) |
| 167 | + # real calibration for other operators |
| 168 | + try: |
| 169 | + # IPEX may raise an error on the second iteration. |
| 170 | + # OverflowError: cannot convert float infinity to integer |
| 171 | + run_fn(model) |
| 172 | + except: |
| 173 | + logger.warning( |
| 174 | + "The calibration failed when calibrating with ipex, " |
| 175 | + + "using scale info from SmoothQuant for Linear and " |
| 176 | + + "one iter calibration for other ops." |
| 177 | + ) |
| 178 | + |
| 179 | + if ipex_ver.release > Version("2.1.0").release: |
| 180 | + update_sq_scale(ipex_config_path, smoothquant_scale_info) |
| 181 | + model.load_qconf_summary(qconf_summary=ipex_config_path) |
| 182 | + _ipex_post_quant_process(model, example_inputs, inplace=inplace) |
| 183 | + |
| 184 | + with open(ipex_config_path, "r") as f: |
| 185 | + model.tune_cfg = json.load(f) |
| 186 | + model.ipex_config_path = ipex_config_path |
| 187 | + dump_model_op_stats(tune_cfg) |
| 188 | + return model |
| 189 | + |
| 190 | + |
| 191 | +def _apply_pre_optimization(model, tune_cfg, sq, recover=False): |
| 192 | + sq_max_info = {} |
| 193 | + if sq.record_max_info: |
| 194 | + sq_max_info = sq.max_value_info |
| 195 | + if sq_max_info: |
| 196 | + tsq = TorchSmoothQuant(model, None) |
| 197 | + alpha = tune_cfg["recipe_cfgs"]["smooth_quant_args"]["alpha"] |
| 198 | + for op_name, info in sq_max_info.items(): |
| 199 | + if alpha == "auto": |
| 200 | + alpha = info["alpha"] |
| 201 | + absorb_layer = op_name |
| 202 | + absorbed_layer = info["absorbed_layer"] |
| 203 | + input_minmax = info["input_minmax"] |
| 204 | + weight_max = info["weight_max"] |
| 205 | + if sq.weight_clip: |
| 206 | + weight_max = weight_max.clamp(min=1e-5) |
| 207 | + abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1])) |
| 208 | + input_power = torch.pow(abs_input_max, alpha) |
| 209 | + weight_power = torch.pow(weight_max, 1 - alpha) |
| 210 | + scale = torch.clip(input_power / weight_power, min=1e-5) |
| 211 | + with torch.no_grad(): |
| 212 | + if recover: |
| 213 | + scale = 1.0 / scale |
| 214 | + for layer in absorbed_layer: |
| 215 | + tsq._scale_layer_weight(layer, scale) |
| 216 | + tsq._absorb_scales(absorb_layer, 1.0 / scale) |
| 217 | + logger.debug(f"Current smoothquant scale of {op_name} is {scale}, alpha is {alpha}") |
| 218 | + |
| 219 | + |
| 220 | +def _ipex_post_quant_process(model, example_inputs, inplace=False): |
| 221 | + """Convert to a jit model. |
| 222 | +
|
| 223 | + Args: |
| 224 | + model: a prepared model. |
| 225 | + example_inputs: used to trace torch model. |
| 226 | + inplace: whether to carry out model transformations in-place. |
| 227 | +
|
| 228 | + Returns: |
| 229 | + A converted jit model. |
| 230 | + """ |
| 231 | + model = ipex.quantization.convert(model, inplace=inplace) |
| 232 | + with torch.no_grad(): |
| 233 | + try: |
| 234 | + if isinstance(example_inputs, dict): |
| 235 | + model = torch.jit.trace(model, example_kwarg_inputs=example_inputs) |
| 236 | + else: |
| 237 | + model = torch.jit.trace(model, example_inputs) |
| 238 | + model = torch.jit.freeze(model.eval()) |
| 239 | + except: |
| 240 | + if isinstance(example_inputs, dict): |
| 241 | + model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False) |
| 242 | + else: |
| 243 | + model = torch.jit.trace(model, example_inputs, strict=False) |
| 244 | + model = torch.jit.freeze(model.eval()) |
| 245 | + # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile |
| 246 | + # At the 2nd run, the llga pass will be triggered and the model is turned into |
| 247 | + # an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph |
| 248 | + simple_inference(model, example_inputs, iterations=2) |
| 249 | + return model |
0 commit comments