14
14
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
15
# See the License for the specific language governing permissions and
16
16
# limitations under the License.
17
+ """The quantizer using SmoothQuant path."""
18
+
17
19
18
20
import json
19
21
import os
49
51
50
52
51
53
class SmoothQuantQuantizer (Quantizer ):
54
+ """SmoothQuantQuantizer Class."""
55
+
52
56
def __init__ (self , quant_config : OrderedDict = {}): # pragma: no cover
53
57
"""Init a SmoothQuantQuantizer object.
54
58
@@ -61,9 +65,9 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
61
65
"""Prepares a given model for quantization.
62
66
63
67
Args:
64
- model: A float model to be quantized .
65
- example_inputs: Used to trace torch model.
66
- inplace: Whether to carry out model transformations in-place. Defaults to True.
68
+ model (torch.nn.Module): raw fp32 model or prepared model .
69
+ example_inputs (tensor/tuple/dict): used to trace torch model.
70
+ inplace (bool, optional): whether to carry out model transformations in-place. Defaults to True.
67
71
68
72
Returns:
69
73
A prepared model.
@@ -128,9 +132,9 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
128
132
"""Converts a prepared model to a quantized model.
129
133
130
134
Args:
131
- model: The prepared model to be converted.
132
- example_inputs: Used to trace torch model.
133
- inplace: Whether to carry out model transformations in-place. Defaults to True.
135
+ model (QuantizationInterceptionModule): the prepared model to be converted.
136
+ example_inputs (tensor/tuple/dict): used to trace torch model.
137
+ inplace (bool, optional): whether to carry out model transformations in-place. Defaults to True.
134
138
135
139
Returns:
136
140
A quantized model.
@@ -153,14 +157,14 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
153
157
return model
154
158
155
159
def quantize (self , model , tune_cfg , run_fn , example_inputs , inplace = True , * args , ** kwargs ):
156
- """Execute the quantize process on the specified model.
160
+ """Executes the quantize process on the specified model.
157
161
158
162
Args:
159
- model: a float model to be quantized .
160
- tune_cfg: quantization config for ops.
161
- run_fn: a calibration function for calibrating the model.
162
- example_inputs: used to trace torch model.
163
- inplace: whether to carry out model transformations in-place.
163
+ model (torch.nn.Module): raw fp32 model or prepared model .
164
+ tune_cfg (OrderedDict) : quantization config for ops.
165
+ run_fn (Callable) : a calibration function for calibrating the model.
166
+ example_inputs (tensor/tuple/dict) : used to trace torch model.
167
+ inplace (bool, optional) : whether to carry out model transformations in-place. Defaults to True .
164
168
165
169
Returns:
166
170
A quantized model.
@@ -255,6 +259,22 @@ def quantize(self, model, tune_cfg, run_fn, example_inputs, inplace=True, *args,
255
259
def qdq_quantize (
256
260
model , tune_cfg , run_fn , example_inputs , inplace , cfgs , op_infos_from_cfgs , output_tensor_id_op_name , sq
257
261
):
262
+ """Executes the smooth quantize process.
263
+
264
+ Args:
265
+ model (torch.nn.Module): raw fp32 model or prepared model.
266
+ tune_cfg (OrderedDict): quantization config for ops.
267
+ run_fn (Callable): a calibration function for calibrating the model.
268
+ example_inputs (tensor/tuple/dict): used to trace torch model.
269
+ inplace (bool): whether to carry out model transformations in-place.
270
+ cfgs (dict): configs loaded from ipex config path.
271
+ op_infos_from_cfgs (dict): dict containing configs that have been parsed for each op.
272
+ output_tensor_id_op_name (dict): dict containing op names corresponding to 'op_infos_from_cfgs'.
273
+ sq (TorchSmoothQuant): TorchSmoothQuant class containing sq infos.
274
+
275
+ Returns:
276
+ A quantized model.
277
+ """
258
278
smoothquant_scale_info = sq .sq_scale_info
259
279
sq_minmax_init = True if tune_cfg .get ("act_algo" , "kl" ) == "minmax" else False
260
280
@@ -325,6 +345,14 @@ def qdq_quantize(
325
345
326
346
327
347
def _apply_pre_optimization (model , tune_cfg , sq , recover = False ):
348
+ """Retrieves sq info to absorb the scale to the layer at output channel.
349
+
350
+ Args:
351
+ model (QuantizationInterceptionModule): a prepared model.
352
+ tune_cfg (OrderedDict): quantization config for ops.
353
+ sq (TorchSmoothQuant): TorchSmoothQuant class containing sq infos.
354
+ recover (bool, optional): whether to recover the scale. Defaults to False.
355
+ """
328
356
sq_max_info = {}
329
357
if sq .record_max_info :
330
358
sq_max_info = sq .max_value_info
@@ -354,13 +382,13 @@ def _apply_pre_optimization(model, tune_cfg, sq, recover=False):
354
382
355
383
356
384
def _ipex_post_quant_process (model , example_inputs , use_bf16 , inplace = False ):
357
- """Convert to a jit model.
385
+ """Converts to a jit model.
358
386
359
387
Args:
360
- model: a prepared model.
361
- example_inputs: used to trace torch model.
362
- use_bf16: whether to use bf16 for mixed precision.
363
- inplace: whether to carry out model transformations in-place.
388
+ model (QuantizationInterceptionModule) : a prepared model.
389
+ example_inputs (tensor/tuple/dict) : used to trace torch model.
390
+ use_bf16 (bool) : whether to use bf16 for mixed precision.
391
+ inplace (bool, optional) : whether to carry out model transformations in-place. Defaults to False .
364
392
365
393
Returns:
366
394
A converted jit model.
0 commit comments