25
25
26
26
27
27
def patch_module (mod , qconfig , mod_dict , patched_mod = None ):
28
+ """Replaces the module with patched module according to mod_dict.
29
+
30
+ Args:
31
+ mod (nn.module): The module that will be replaced with a patched module that quantize the inputs/outputs.
32
+ qconfig (ModuleExtraConfig): The quantization config object with the information how to quantize the inputs/outputs.
33
+ mod_dict (dict): dictionary from module name to its patched module.
34
+
35
+ Returns:
36
+ nn.module: The new patched module after patching.
37
+ """
28
38
parent = parent_child_mod_dict [mod ].parent
29
39
name = parent_child_mod_dict [mod ].name
30
40
if patched_mod is None :
@@ -33,6 +43,8 @@ def patch_module(mod, qconfig, mod_dict, patched_mod=None):
33
43
34
44
35
45
def apply_hf_hook (module ):
46
+ """Applies hf_hook on a given module so its weights will be loaded from disk to cpu and then we can quantize it.
47
+ """
36
48
if hasattr (module , "_hf_hook" ):
37
49
module ._hf_hook .pre_forward (module )
38
50
module ._hf_hook .detach_hook (module )
@@ -43,6 +55,12 @@ def apply_hf_hook(module):
43
55
44
56
45
57
def quantize_params (mod , mod_extra_config ):
58
+ """Quantizes the weights of the given module according to the quantization info from mod_extra_config.
59
+
60
+ Args:
61
+ mod (nn.module): The module that its weights will be quantized.
62
+ mod_extra_config (ModuleExtraConfig): The quantization config object with the information how to quantize the inputs/outputs.
63
+ """
46
64
for param_name in mod_extra_config .params :
47
65
quantizer = mod_extra_config .params [param_name ]
48
66
param = getattr (mod , param_name )
@@ -55,6 +73,15 @@ def quantize_params(mod, mod_extra_config):
55
73
56
74
57
75
def prepare_model (model , qconfig , mod_list , hp_dtype = torch .float ):
76
+ """Replaces the model submodules according to the mod_list with patched quantization modules.
77
+ Configures patched modules with the quantization/dequantization methods to apply on their input and output tensors.
78
+ Quantizes the model parameters as they are static.
79
+
80
+ Args:
81
+ model (nn.module): The model to quantize.
82
+ qconfig (dict): Dict that maps between patched module and its quantization info.
83
+ mod_list (list): The specific submodules that will be quantized in the model.
84
+ """
58
85
config = get_hqt_config (model )
59
86
patched_modules = []
60
87
patched_module_types = set ()
@@ -82,6 +109,12 @@ def prepare_model(model, qconfig, mod_list, hp_dtype=torch.float):
82
109
83
110
84
111
def quantize (model , mod_list ):
112
+ """Builds quantization config object that contains for each submodule its quantization functions as preparation for quantization.
113
+
114
+ Args:
115
+ model (nn.module): The model that will be quantized.
116
+ mod_list (list, optional): The specific modules that will be quantized in the model.
117
+ """
85
118
config = get_hqt_config (model )
86
119
generate_model_info (model )
87
120
hp_dtype = config .cfg ["hp_dtype" ]
0 commit comments