@@ -56,6 +56,7 @@ def convert_pt2(
56
56
model : torch .nn .Module ,
57
57
inputs : tuple [object , ...],
58
58
quantizer : CadenceQuantizer ,
59
+ calibration_data : Optional [list [tuple [object , ...]]] = None ,
59
60
dump_graphs : bool = False ,
60
61
) -> torch .fx .GraphModule :
61
62
"""
@@ -64,6 +65,9 @@ def convert_pt2(
64
65
fuse the model later, if applicable. If you do not expect that behavior,
65
66
please use quantize_and_fuse_pt2 instead, which will instantiate a
66
67
default quantizer for you if needed.
68
+ If calibration data is provided, it will be used to calibrate the model. If
69
+ not, the inputs will be used for calibration instead, which is useful for
70
+ unit tests but should not be used for end-to-end use cases.
67
71
Returns a GraphModule with the converted model.
68
72
"""
69
73
@@ -95,7 +99,12 @@ def convert_pt2(
95
99
prepared_model = prepare_pt2e (model_gm , quantizer )
96
100
97
101
# Calibrate
98
- prepared_model (* inputs )
102
+ # If no calibration data is provided, use the inputs
103
+ if calibration_data is None :
104
+ calibration_data = [inputs ]
105
+
106
+ for samples in calibration_data :
107
+ prepared_model (* samples )
99
108
100
109
# Convert
101
110
converted_model = convert_pt2e (prepared_model )
@@ -136,10 +145,14 @@ def quantize_pt2(
136
145
model : torch .nn .Module ,
137
146
inputs : tuple [object , ...],
138
147
quantizer : Optional [CadenceQuantizer ] = None ,
148
+ calibration_data : Optional [list [tuple [object , ...]]] = None ,
139
149
dump_graphs : bool = False ,
140
150
) -> torch .fx .GraphModule :
141
151
"""
142
152
Prepare, convert and fuse the model using the given quantizer.
153
+ If calibration data is provided, it will be used to calibrate the model. If
154
+ not, the inputs will be used for calibration instead, which is useful for
155
+ unit tests but should not be used for end-to-end use cases.
143
156
Returns a GraphModule with the quantized model.
144
157
"""
145
158
# Make the model inference mode by calling model.eval()
@@ -150,7 +163,9 @@ def quantize_pt2(
150
163
quantizer = CadenceDefaultQuantizer ()
151
164
152
165
# Get converted graph module
153
- converted_gm = convert_pt2 (model , inputs , quantizer , dump_graphs )
166
+ converted_gm = convert_pt2 (
167
+ model , inputs , quantizer , calibration_data , dump_graphs = dump_graphs
168
+ )
154
169
155
170
# Get fused model
156
171
fused_gm = fuse_pt2 (converted_gm , quantizer )
0 commit comments