Commit 5767aed

add docstring for torch.quantization and torch.utils (#1928)
Signed-off-by: xin3he <[email protected]>
Parent: f909bca

File tree: 5 files changed, +176 −5 lines

neural_compressor/torch/quantization/algorithm_entry.py

Lines changed: 143 additions & 5 deletions
@@ -66,7 +66,16 @@ def rtn_entry(
     *args,
     **kwargs,
 ) -> torch.nn.Module:
-    """The main entry to apply rtn quantization."""
+    """The main entry to apply rtn quantization.
+
+    Args:
+        model (torch.nn.Module): raw fp32 model or prepared model.
+        configs_mapping (Dict[Tuple[str, callable], RTNConfig]): per-op configuration.
+        mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.
+
+    Returns:
+        torch.nn.Module: prepared model or quantized model.
+    """
     from neural_compressor.torch.algorithms.weight_only.rtn import RTNQuantizer
     from neural_compressor.torch.algorithms.weight_only.save_load import save
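These entry functions are registered via register_algo and dispatched internally; users normally reach them through the public prepare/convert API. A minimal sketch of the RTN path (the toy model is illustrative, and the import surface is assumed to be neural_compressor.torch.quantization):

    import torch
    from neural_compressor.torch.quantization import RTNConfig, convert, prepare

    fp32_model = torch.nn.Sequential(torch.nn.Linear(32, 32))  # toy model for illustration
    prepared_model = prepare(fp32_model, quant_config=RTNConfig())  # exercises the Mode.PREPARE path
    quantized_model = convert(prepared_model)  # exercises the Mode.CONVERT path and returns the quantized model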

@@ -115,6 +124,16 @@ def gptq_entry(
     *args,
     **kwargs,
 ) -> torch.nn.Module:
+    """The main entry to apply gptq quantization.
+
+    Args:
+        model (torch.nn.Module): raw fp32 model or prepared model.
+        configs_mapping (Dict[Tuple[str, callable], GPTQConfig]): per-op configuration.
+        mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.
+
+    Returns:
+        torch.nn.Module: prepared model or quantized model.
+    """
     logger.info("Quantize model with the GPTQ algorithm.")
     from neural_compressor.torch.algorithms.weight_only.gptq import GPTQuantizer
     from neural_compressor.torch.algorithms.weight_only.save_load import save
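Unlike RTN, GPTQ needs calibration data between the prepare and convert steps. A sketch of that flow, assuming the same prepare/convert API (fp32_model and calib_dataloader are placeholders):

    from neural_compressor.torch.quantization import GPTQConfig, convert, prepare

    prepared_model = prepare(fp32_model, quant_config=GPTQConfig())  # Mode.PREPARE: set up the GPTQ quantizer
    for batch in calib_dataloader:  # user-supplied calibration loop (placeholder)
        prepared_model(batch)
    quantized_model = convert(prepared_model)  # Mode.CONVERT: produce the quantized model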
@@ -169,6 +188,16 @@ def static_quant_entry(
     *args,
     **kwargs,
 ) -> torch.nn.Module:
+    """The main entry to apply static quantization, includes pt2e quantization and ipex quantization.
+
+    Args:
+        model (torch.nn.Module): raw fp32 model or prepared model.
+        configs_mapping (Dict[Tuple[str, callable], StaticQuantConfig]): per-op configuration.
+        mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.
+
+    Returns:
+        torch.nn.Module: prepared model or quantized model.
+    """
     if not is_ipex_imported():
         return pt2e_static_quant_entry(model, configs_mapping, mode, *args, **kwargs)
     logger.info("Quantize model with the static quant algorithm.")
@@ -212,7 +241,23 @@ def static_quant_entry(
 ###################### PT2E Dynamic Quant Algo Entry ##################################
 @register_algo(name=PT2E_DYNAMIC_QUANT)
 @torch.no_grad()
-def pt2e_dynamic_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
+def pt2e_dynamic_quant_entry(
+    model: torch.nn.Module,
+    configs_mapping,
+    mode: Mode,
+    *args,
+    **kwargs,
+) -> torch.nn.Module:
+    """The main entry to apply pt2e dynamic quantization.
+
+    Args:
+        model (torch.nn.Module): raw fp32 model or prepared model.
+        configs_mapping: per-op configuration.
+        mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.
+
+    Returns:
+        torch.nn.Module: prepared model or quantized model.
+    """
     logger.info("Quantize model with the PT2E static quant algorithm.")
     from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
     from neural_compressor.torch.algorithms.pt2e_quant.save_load import save
@@ -235,7 +280,23 @@ def pt2e_dynamic_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode
 ###################### PT2E Static Quant Algo Entry ##################################
 @register_algo(name=PT2E_STATIC_QUANT)
 @torch.no_grad()
-def pt2e_static_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
+def pt2e_static_quant_entry(
+    model: torch.nn.Module,
+    configs_mapping,
+    mode: Mode,
+    *args,
+    **kwargs,
+) -> torch.nn.Module:
+    """The main entry to apply pt2e static quantization.
+
+    Args:
+        model (torch.nn.Module): raw fp32 model or prepared model.
+        configs_mapping: per-op configuration.
+        mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.
+
+    Returns:
+        torch.nn.Module: prepared model or quantized model.
+    """
     logger.info("Quantize model with the PT2E static quant algorithm.")
     from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
     from neural_compressor.torch.algorithms.pt2e_quant.save_load import save
@@ -264,6 +325,16 @@ def smooth_quant_entry(
     *args,
     **kwargs,
 ) -> torch.nn.Module:
+    """The main entry to apply smooth quantization.
+
+    Args:
+        model (torch.nn.Module): raw fp32 model or prepared model.
+        configs_mapping (Dict[Tuple[str, callable], SmoothQuantConfig]): per-op configuration.
+        mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.
+
+    Returns:
+        torch.nn.Module: prepared model or quantized model.
+    """
     logger.info("Quantize model with the smooth quant algorithm.")
     from neural_compressor.torch.algorithms.smooth_quant import SmoothQuantQuantizer, TorchSmoothQuant

@@ -323,6 +394,16 @@ def awq_quantize_entry(
     *args,
     **kwargs,
 ) -> torch.nn.Module:
+    """The main entry to apply AWQ quantization.
+
+    Args:
+        model (torch.nn.Module): raw fp32 model or prepared model.
+        configs_mapping (Dict[Tuple[str, callable], AWQConfig]): per-op configuration.
+        mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.
+
+    Returns:
+        torch.nn.Module: prepared model or quantized model.
+    """
     logger.info("Quantize model with the AWQ algorithm.")
     from neural_compressor.torch.algorithms.weight_only.awq import AWQQuantizer
     from neural_compressor.torch.algorithms.weight_only.save_load import save
@@ -391,8 +472,22 @@ def awq_quantize_entry(
 ###################### TEQ Algo Entry ##################################
 @register_algo(name=TEQ)
 def teq_quantize_entry(
-    model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], TEQConfig], mode: Mode, *args, **kwargs
+    model: torch.nn.Module,
+    configs_mapping: Dict[Tuple[str, callable], TEQConfig],
+    mode: Mode,
+    *args,
+    **kwargs,
 ) -> torch.nn.Module:
+    """The main entry to apply TEQ quantization.
+
+    Args:
+        model (torch.nn.Module): raw fp32 model or prepared model.
+        configs_mapping (Dict[Tuple[str, callable], TEQConfig]): per-op configuration.
+        mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.
+
+    Returns:
+        torch.nn.Module: prepared model or quantized model.
+    """
     from neural_compressor.torch.algorithms.weight_only.save_load import save
     from neural_compressor.torch.algorithms.weight_only.teq import TEQuantizer

@@ -453,6 +548,16 @@ def autoround_quantize_entry(
     *args,
     **kwargs,
 ) -> torch.nn.Module:
+    """The main entry to apply AutoRound quantization.
+
+    Args:
+        model (torch.nn.Module): raw fp32 model or prepared model.
+        configs_mapping (Dict[Tuple[str, callable], AutoRoundConfig]): per-op configuration.
+        mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.
+
+    Returns:
+        torch.nn.Module: prepared model or quantized model.
+    """
     from neural_compressor.torch.algorithms.weight_only.autoround import AutoRoundQuantizer
     from neural_compressor.torch.algorithms.weight_only.save_load import save

@@ -530,6 +635,16 @@ def hqq_entry(
     *args,
     **kwargs,
 ) -> torch.nn.Module:
+    """The main entry to apply HQQ quantization.
+
+    Args:
+        model (torch.nn.Module): raw fp32 model or prepared model.
+        configs_mapping (Dict[Tuple[str, callable], HQQConfig]): per-op configuration.
+        mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.
+
+    Returns:
+        torch.nn.Module: prepared model or quantized model.
+    """
     from neural_compressor.torch.algorithms.weight_only.hqq import HQQuantizer
     from neural_compressor.torch.algorithms.weight_only.save_load import save

@@ -572,6 +687,16 @@ def mx_quant_entry(
     *args,
     **kwargs,
 ) -> torch.nn.Module:
+    """The main entry to apply MX quantization.
+
+    Args:
+        model (torch.nn.Module): raw fp32 model or prepared model.
+        configs_mapping (Dict[Tuple[str, callable], MXQuantConfig]): per-op configuration.
+        mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.
+
+    Returns:
+        torch.nn.Module: prepared model or quantized model.
+    """
     logger.info("Quantize model with the mx quant algorithm.")
     from neural_compressor.torch.algorithms.mx_quant.mx import MXQuantizer

@@ -586,8 +711,21 @@ def mx_quant_entry(
 ###################### Mixed Precision Algo Entry ##################################
 @register_algo(MIX_PRECISION)
 def mix_precision_entry(
-    model: torch.nn.Module, configs_mapping: Dict[Tuple[str], MixPrecisionConfig], *args, **kwargs
+    model: torch.nn.Module,
+    configs_mapping: Dict[Tuple[str], MixPrecisionConfig],
+    *args,
+    **kwargs,
 ) -> torch.nn.Module:
+    """The main entry to apply Mixed Precision.
+
+    Args:
+        model (torch.nn.Module): raw fp32 model or prepared model.
+        configs_mapping (Dict[Tuple[str, callable], MixPrecisionConfig]): per-op configuration.
+        mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.
+
+    Returns:
+        torch.nn.Module: prepared model or quantized model.
+    """
     # only support fp16 and bf16 now, more types might be added later
     from neural_compressor.torch.algorithms.mix_precision import HalfPrecisionConverter

neural_compressor/torch/quantization/autotune.py

Lines changed: 10 additions & 0 deletions
@@ -32,13 +32,23 @@


 def get_rtn_double_quant_config_set() -> List[RTNConfig]:
+    """Generate RTN double quant config set.
+
+    Returns:
+        List[RTNConfig]: a set of quant config
+    """
     rtn_double_quant_config_set = []
     for double_quant_type, double_quant_config in constants.DOUBLE_QUANT_CONFIGS.items():
         rtn_double_quant_config_set.append(RTNConfig.from_dict(double_quant_config))
     return rtn_double_quant_config_set


 def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
+    """Generate all quant config set.
+
+    Returns:
+        Union[BaseConfig, List[BaseConfig]]: a set of quant config
+    """
     return get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME)
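get_all_config_set is what feeds the autotune flow. A usage sketch, assuming TuningConfig and autotune are exported from the same package (fp32_model and eval_fn are placeholders):

    from neural_compressor.torch.quantization import TuningConfig, autotune, get_all_config_set

    tune_config = TuningConfig(config_set=get_all_config_set())  # search over every registered config
    best_model = autotune(model=fp32_model, tune_config=tune_config, eval_fn=eval_fn)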

neural_compressor/torch/quantization/config.py

Lines changed: 4 additions & 0 deletions
@@ -67,12 +67,16 @@


 class OperatorConfig(NamedTuple):
+    """OperatorConfig."""
+
     config: BaseConfig
     operators: List[Union[str, Callable]]
     valid_func_list: List[Callable] = []


 class TorchBaseConfig(BaseConfig):
+    """Base config class for torch backend."""
+
     # re-write func _get_op_name_op_type_config to fallback op_type with string
     # because there are some special op_types for IPEX backend: `Linear&Relu`, `Linear&add`, ...
     def _get_op_name_op_type_config(self):
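Since OperatorConfig is a plain NamedTuple, constructing one only requires a config and the operators it applies to. A small sketch (the RTNConfig and operator choices are illustrative, not a prescribed pattern):

    import torch
    from neural_compressor.torch.quantization import RTNConfig

    op_config = OperatorConfig(config=RTNConfig(), operators=[torch.nn.Linear, "Linear"])
    print(op_config.valid_func_list)  # defaults to []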

neural_compressor/torch/utils/environ.py

Lines changed: 13 additions & 0 deletions
@@ -22,13 +22,15 @@

 ################ Check imported sys.module first to decide behavior #################
 def is_ipex_imported() -> bool:
+    """Check whether intel_extension_for_pytorch is imported."""
     for name, _ in sys.modules.items():
         if name == "intel_extension_for_pytorch":
             return True
     return False


 def is_transformers_imported() -> bool:
+    """Check whether transformers is imported."""
     for name, _ in sys.modules.items():
         if name == "transformers":
             return True
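Both helpers simply scan sys.modules, so the result depends on what the caller has already imported. A quick sketch:

    import transformers  # noqa: F401  (any prior import flips the check)

    print(is_transformers_imported())  # True once transformers is in sys.modules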
@@ -37,6 +39,11 @@ def is_transformers_imported() -> bool:

 ################ Check available sys.module to decide behavior #################
 def is_package_available(package_name):
+    """Check if the package exists in the environment without importing.
+
+    Args:
+        package_name (str): package name
+    """
     from importlib.util import find_spec

     package_spec = find_spec(package_name)
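Because the check relies on importlib.util.find_spec, it can gate optional code paths without paying the import cost. For example (illustrative):

    if is_package_available("intel_extension_for_pytorch"):
        import intel_extension_for_pytorch as ipex  # imported only when the package is installed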
@@ -52,6 +59,7 @@ def is_package_available(package_name):


 def is_hpex_available():
+    """Returns whether hpex is available."""
     return _hpex_available

@@ -63,10 +71,12 @@ def is_hpex_available():


 def is_ipex_available():
+    """Return whether ipex is available."""
     return _ipex_available


 def get_ipex_version():
+    """Return ipex version if ipex exists."""
     if is_ipex_available():
         try:
             import intel_extension_for_pytorch as ipex
@@ -84,6 +94,7 @@ def get_ipex_version():


 def get_torch_version():
+    """Return torch version."""
     try:
         torch_version = torch.__version__.split("+")[0]
     except ValueError as e:  # pragma: no cover
@@ -96,6 +107,7 @@ def get_torch_version():


 def get_accelerator(device_name="auto"):
+    """Return the recommended accelerator based on device priority."""
     global accelerator  # update the global accelerator when calling this func
     from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

@@ -109,6 +121,7 @@ def get_accelerator(device_name="auto"):

 # for habana ease-of-use
 def device_synchronize(raw_func):
+    """Function decorator that calls accelerator.synchronize before and after a function call."""
     from functools import wraps

     @wraps(raw_func)
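device_synchronize is intended to wrap functions whose device work should be fully flushed before and after the call, e.g. for timing on HPU. A usage sketch with a placeholder function:

    @device_synchronize
    def run_step(model, example_input):
        return model(example_input)  # the accelerator is synchronized before and after this call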

neural_compressor/torch/utils/utility.py

Lines changed: 6 additions & 0 deletions
@@ -115,6 +115,7 @@ def set_module(model, op_name, new_module):


 def get_model_info(model: torch.nn.Module, white_module_list: List[Callable]) -> List[Tuple[str, str]]:
+    """Get model info according to white_module_list."""
     module_dict = dict(model.named_modules())
     filter_result = []
     filter_result_set = set()
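get_model_info filters the model's named modules by type. A sketch of how it might be called (the toy model and the expected output are illustrative):

    import torch

    toy_model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
    info = get_model_info(toy_model, white_module_list=[torch.nn.Linear])
    # e.g. something like [("0", "Linear")] -- (op name, op type) pairs for the whitelisted modules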
@@ -129,6 +130,11 @@


 def get_double_quant_config_dict(double_quant_type="BNB_NF4"):
+    """Query config dict of double_quant according to double_quant_type.
+
+    Args:
+        double_quant_type (str, optional): double_quant type. Defaults to "BNB_NF4".
+    """
     from neural_compressor.torch.utils.constants import DOUBLE_QUANT_CONFIGS

     assert double_quant_type in DOUBLE_QUANT_CONFIGS, "Supported double quant configs: {}".format(
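The returned dict mirrors an entry in DOUBLE_QUANT_CONFIGS and can be turned back into a config object, following the same pattern get_rtn_double_quant_config_set uses in autotune.py above. A sketch:

    from neural_compressor.torch.quantization import RTNConfig

    double_quant_dict = get_double_quant_config_dict("BNB_NF4")  # query the predefined BNB_NF4 settings
    rtn_config = RTNConfig.from_dict(double_quant_dict)  # same RTNConfig.from_dict pattern as in autotune.py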

Comments (0)