Skip to content

Commit 1ebf698

Browse files
violetch24 authored
add docstring for static quant and smooth quant (#1936)
* add docstring for static quant and smooth quant Signed-off-by: violetch24 <[email protected]> * format fix Signed-off-by: violetch24 <[email protected]> * update scan path Signed-off-by: violetch24 <[email protected]> * Update utility.py --------- Signed-off-by: violetch24 <[email protected]> Co-authored-by: violetch24 <[email protected]>
1 parent 296c5d4 commit 1ebf698

File tree

10 files changed

+783
-175
lines changed

10 files changed

+783
-175
lines changed

.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
/neural-compressor/neural_compressor/strategy
1616
/neural-compressor/neural_compressor/training.py
1717
/neural-compressor/neural_compressor/utils
18+
/neural-compressor/neural_compressor/torch/algorithms/static_quant
19+
/neural-compressor/neural_compressor/torch/algorithms/smooth_quant
1820
/neural_compressor/torch/algorithms/pt2e_quant
1921
/neural_compressor/torch/export
2022
/neural_compressor/common
21-
/neural_compressor/torch/algorithms/weight_only/hqq
23+
/neural_compressor/torch/algorithms/weight_only/hqq

neural_compressor/torch/algorithms/smooth_quant/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
"""The SmoothQuant-related modules."""
16+
1517

1618
from .utility import *
1719
from .smooth_quant import SmoothQuantQuantizer

neural_compressor/torch/algorithms/smooth_quant/save_load.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
"""Save and load the quantized model."""
15+
1416

1517
# pylint:disable=import-error
1618
import torch
@@ -32,7 +34,7 @@ def recover_model_from_json(model, json_file_path, example_inputs): # pragma: n
3234
example_inputs (tuple or torch.Tensor or dict): example inputs that will be passed to the ipex function.
3335
3436
Returns:
35-
(object): quantized model
37+
model (object): quantized model
3638
"""
3739
from torch.ao.quantization.observer import MinMaxObserver
3840

neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
17+
"""The quantizer using SmoothQuant path."""
18+
1719

1820
import json
1921
import os
@@ -49,6 +51,8 @@
4951

5052

5153
class SmoothQuantQuantizer(Quantizer):
54+
"""SmoothQuantQuantizer Class."""
55+
5256
def __init__(self, quant_config: OrderedDict = {}): # pragma: no cover
5357
"""Init a SmoothQuantQuantizer object.
5458
@@ -61,9 +65,9 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
6165
"""Prepares a given model for quantization.
6266
6367
Args:
64-
model: A float model to be quantized.
65-
example_inputs: Used to trace torch model.
66-
inplace: Whether to carry out model transformations in-place. Defaults to True.
68+
model (torch.nn.Module): raw fp32 model or prepared model.
69+
example_inputs (tensor/tuple/dict): used to trace torch model.
70+
inplace (bool, optional): whether to carry out model transformations in-place. Defaults to True.
6771
6872
Returns:
6973
A prepared model.
@@ -128,9 +132,9 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
128132
"""Converts a prepared model to a quantized model.
129133
130134
Args:
131-
model: The prepared model to be converted.
132-
example_inputs: Used to trace torch model.
133-
inplace: Whether to carry out model transformations in-place. Defaults to True.
135+
model (QuantizationInterceptionModule): the prepared model to be converted.
136+
example_inputs (tensor/tuple/dict): used to trace torch model.
137+
inplace (bool, optional): whether to carry out model transformations in-place. Defaults to True.
134138
135139
Returns:
136140
A quantized model.
@@ -153,14 +157,14 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
153157
return model
154158

155159
def quantize(self, model, tune_cfg, run_fn, example_inputs, inplace=True, *args, **kwargs):
156-
"""Execute the quantize process on the specified model.
160+
"""Executes the quantize process on the specified model.
157161
158162
Args:
159-
model: a float model to be quantized.
160-
tune_cfg: quantization config for ops.
161-
run_fn: a calibration function for calibrating the model.
162-
example_inputs: used to trace torch model.
163-
inplace: whether to carry out model transformations in-place.
163+
model (torch.nn.Module): raw fp32 model or prepared model.
164+
tune_cfg (OrderedDict): quantization config for ops.
165+
run_fn (Callable): a calibration function for calibrating the model.
166+
example_inputs (tensor/tuple/dict): used to trace torch model.
167+
inplace (bool, optional): whether to carry out model transformations in-place. Defaults to True.
164168
165169
Returns:
166170
A quantized model.
@@ -255,6 +259,22 @@ def quantize(self, model, tune_cfg, run_fn, example_inputs, inplace=True, *args,
255259
def qdq_quantize(
256260
model, tune_cfg, run_fn, example_inputs, inplace, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, sq
257261
):
262+
"""Executes the smooth quantize process.
263+
264+
Args:
265+
model (torch.nn.Module): raw fp32 model or prepared model.
266+
tune_cfg (OrderedDict): quantization config for ops.
267+
run_fn (Callable): a calibration function for calibrating the model.
268+
example_inputs (tensor/tuple/dict): used to trace torch model.
269+
inplace (bool): whether to carry out model transformations in-place.
270+
cfgs (dict): configs loaded from ipex config path.
271+
op_infos_from_cfgs (dict): dict containing configs that have been parsed for each op.
272+
output_tensor_id_op_name (dict): dict containing op names corresponding to 'op_infos_from_cfgs'.
273+
sq (TorchSmoothQuant): TorchSmoothQuant class containing sq infos.
274+
275+
Returns:
276+
A quantized model.
277+
"""
258278
smoothquant_scale_info = sq.sq_scale_info
259279
sq_minmax_init = True if tune_cfg.get("act_algo", "kl") == "minmax" else False
260280

@@ -325,6 +345,14 @@ def qdq_quantize(
325345

326346

327347
def _apply_pre_optimization(model, tune_cfg, sq, recover=False):
348+
"""Retrieves sq info to absorb the scale to the layer at output channel.
349+
350+
Args:
351+
model (QuantizationInterceptionModule): a prepared model.
352+
tune_cfg (OrderedDict): quantization config for ops.
353+
sq (TorchSmoothQuant): TorchSmoothQuant class containing sq infos.
354+
recover (bool, optional): whether to recover the scale. Defaults to False.
355+
"""
328356
sq_max_info = {}
329357
if sq.record_max_info:
330358
sq_max_info = sq.max_value_info
@@ -354,13 +382,13 @@ def _apply_pre_optimization(model, tune_cfg, sq, recover=False):
354382

355383

356384
def _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=False):
357-
"""Convert to a jit model.
385+
"""Converts to a jit model.
358386
359387
Args:
360-
model: a prepared model.
361-
example_inputs: used to trace torch model.
362-
use_bf16: whether to use bf16 for mixed precision.
363-
inplace: whether to carry out model transformations in-place.
388+
model (QuantizationInterceptionModule): a prepared model.
389+
example_inputs (tensor/tuple/dict): used to trace torch model.
390+
use_bf16 (bool): whether to use bf16 for mixed precision.
391+
inplace (bool, optional): whether to carry out model transformations in-place. Defaults to False.
364392
365393
Returns:
366394
A converted jit model.

0 commit comments

Comments (0)