Skip to content

Commit ad0625b

Browse files
committed
[SW-189684] Add description to functions in HQT
Change-Id: Id5822a21abd1f60f28999574c2ca0e89acc70bf6
1 parent 7bf9521 commit ad0625b

File tree

3 files changed

+66
-0
lines changed

3 files changed

+66
-0
lines changed

neural_compressor/torch/algorithms/fp8_quant/_core/measure.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,16 @@
2828

2929

3030
def patch_module_measure(mod, mconfig, mod_dict):
31+
"""Replaces the module with patched module according to mconfig.
32+
33+
Args:
34+
mod (nn.module): The module that will be replaced with patched module that measures the inputs.
35+
mconfig (e.g. MaxAbsObserver/MaxAbsPerChannelObserver): The observer object that will measure the parameters.
36+
mod_dict (dict): dictionary from module name to its patched module.
37+
38+
Returns:
39+
nn.module: The new module after patching.
40+
"""
3141
parent = parent_child_mod_dict[mod].parent
3242
name = parent_child_mod_dict[mod].name
3343
patched_mod = mod_dict[mod.__class__.__name__].patched_module(mod, mconfig, name)
@@ -72,6 +82,12 @@ def init_measure_object(mod, name, observer_class, mod_type, skip_measure_output
7282

7383

7484
def prepare_model(model, mod_list=None):
85+
"""Defines the observer class and modules for measurement as preparation.
86+
87+
Args:
88+
model (nn.module): The model that will be measured.
89+
mod_list (list, optional): The specific submodules that will be measured in the model. Defaults to None.
90+
"""
7591
config = get_hqt_config(model).cfg
7692
observer_class = observer_types[config["observer"]]
7793
if (config["shape_file"] is not None) and (observer_class != ShapeObserver):
@@ -85,6 +101,16 @@ def prepare_model(model, mod_list=None):
85101

86102

87103
def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=None):
104+
"""Replace the submodules of the model that appear in mod_list with a patched submodule that uses the given observer_class
105+
so the submodule will perform measurement on inputs/outputs in the forward stage.
106+
Weights measurement is done during model preparation as they are static.
107+
108+
Args:
109+
model (nn.module): The model that will be measured.
110+
mod_list (list): The specific submodules that will be measured in the model.
111+
observer_class (e.g. MaxAbsObserver/MaxAbsPerChannelObserver): The observer type that will measure the weights.
112+
d_shapes (dict, optional): Defaults to None.
113+
"""
88114
top_level_config = get_hqt_config(model)
89115
config = top_level_config.cfg
90116
skip_outputs_measurements = config["measure_exclude"] & (MeasureExclude.OUTPUT | MeasureExclude.ALL)

neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,16 @@
2525

2626

2727
def patch_module(mod, qconfig, mod_dict, patched_mod=None):
28+
"""Replaces the module with patched module according to mod_dict.
29+
30+
Args:
31+
mod (nn.module): The module that will be replaced with a patched module that quantizes the inputs/outputs.
32+
qconfig (ModuleExtraConfig): The quantization config object with information on how to quantize the inputs/outputs.
33+
mod_dict (dict): dictionary from module name to its patched module.
34+
35+
Returns:
36+
nn.module: The new patched module after patching.
37+
"""
2838
parent = parent_child_mod_dict[mod].parent
2939
name = parent_child_mod_dict[mod].name
3040
if patched_mod is None:
@@ -33,6 +43,8 @@ def patch_module(mod, qconfig, mod_dict, patched_mod=None):
3343

3444

3545
def apply_hf_hook(module):
46+
"""Applies hf_hook on a given module so its weights will be loaded from disk to CPU and can then be quantized.
47+
"""
3648
if hasattr(module, "_hf_hook"):
3749
module._hf_hook.pre_forward(module)
3850
module._hf_hook.detach_hook(module)
@@ -43,6 +55,12 @@ def apply_hf_hook(module):
4355

4456

4557
def quantize_params(mod, mod_extra_config):
58+
"""Quantizes the weights of the given module according to the quantization info from mod_extra_config.
59+
60+
Args:
61+
mod (nn.module): The module whose weights will be quantized.
62+
mod_extra_config (ModuleExtraConfig): The quantization config object with information on how to quantize the inputs/outputs.
63+
"""
4664
for param_name in mod_extra_config.params:
4765
quantizer = mod_extra_config.params[param_name]
4866
param = getattr(mod, param_name)
@@ -55,6 +73,15 @@ def quantize_params(mod, mod_extra_config):
5573

5674

5775
def prepare_model(model, qconfig, mod_list, hp_dtype=torch.float):
76+
"""Replaces the model submodules according to the mod_list with patched quantization modules.
77+
Configures patched modules with the quantization/dequantization methods to apply on their input and output tensors.
78+
Quantizes the model parameters as they are static.
79+
80+
Args:
81+
model (nn.module): The model to quantize.
82+
qconfig (dict): Dict that maps between patched module and its quantization info.
83+
mod_list (list): The specific submodules that will be quantized in the model.
84+
"""
5885
config = get_hqt_config(model)
5986
patched_modules = []
6087
patched_module_types = set()
@@ -82,6 +109,12 @@ def prepare_model(model, qconfig, mod_list, hp_dtype=torch.float):
82109

83110

84111
def quantize(model, mod_list):
112+
"""Builds quantization config object that contains for each submodule its quantization functions as preparation for quantization.
113+
114+
Args:
115+
model (nn.module): The model that will be quantized.
116+
mod_list (list): The specific modules that will be quantized in the model.
117+
"""
85118
config = get_hqt_config(model)
86119
generate_model_info(model)
87120
hp_dtype = config.cfg["hp_dtype"]

neural_compressor/torch/algorithms/fp8_quant/_core/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ def is_substr(substr_list, target):
4242

4343

4444
def prepare_model(model):
45+
"""Receives the parent module to quantize.
46+
Replaces its submodules with patched submodules that perform calibration and quantization.
47+
Returns the patched parent module that can perform calibration or quantization according to the configuration.
48+
49+
Args:
50+
model (nn.module): The model that will be measured/quantized.
51+
"""
4552
config = get_hqt_config(model)
4653
update_mod_dict(config)
4754
allowlist = set(config.cfg["mod_dict"].keys())

0 commit comments

Comments
 (0)