Skip to content

Commit 5459bb5

Browse files
mcremon-meta authored and facebook-github-bot committed
Allow passing in calibration data to convert_pt2
Summary: As titled. The data should be passed in as a list of inputs, and will be used to calibrate the PTQ model. Differential Revision: D71289674
1 parent d6bc799 commit 5459bb5

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

backends/cadence/aot/compiler.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def convert_pt2(
5656
model: torch.nn.Module,
5757
inputs: tuple[object, ...],
5858
quantizer: CadenceQuantizer,
59+
calibration_data: Optional[list[tuple[object, ...]]] = None,
5960
dump_graphs: bool = False,
6061
) -> torch.fx.GraphModule:
6162
"""
@@ -64,6 +65,9 @@ def convert_pt2(
6465
fuse the model later, if applicable. If you do not expect that behavior,
6566
please use quantize_and_fuse_pt2 instead, which will instantiate a
6667
default quantizer for you if needed.
68+
If calibration data is provided, it will be used to calibrate the model. If
69+
not, the inputs will be used for calibration instead, which is useful for
70+
unit tests but should not be used for end-to-end use cases.
6771
Returns a GraphModule with the converted model.
6872
"""
6973

@@ -95,7 +99,12 @@ def convert_pt2(
9599
prepared_model = prepare_pt2e(model_gm, quantizer)
96100

97101
# Calibrate
98-
prepared_model(*inputs)
102+
# If no calibration data is provided, use the inputs
103+
if calibration_data is None:
104+
calibration_data = [inputs]
105+
106+
for samples in calibration_data:
107+
prepared_model(*samples)
99108

100109
# Convert
101110
converted_model = convert_pt2e(prepared_model)
@@ -136,10 +145,14 @@ def quantize_pt2(
136145
model: torch.nn.Module,
137146
inputs: tuple[object, ...],
138147
quantizer: Optional[CadenceQuantizer] = None,
148+
calibration_data: Optional[list[tuple[object, ...]]] = None,
139149
dump_graphs: bool = False,
140150
) -> torch.fx.GraphModule:
141151
"""
142152
Prepare, convert and fuse the model using the given quantizer.
153+
If calibration data is provided, it will be used to calibrate the model. If
154+
not, the inputs will be used for calibration instead, which is useful for
155+
unit tests but should not be used for end-to-end use cases.
143156
Returns a GraphModule with the quantized model.
144157
"""
145158
# Make the model inference mode by calling model.eval()
@@ -150,7 +163,9 @@ def quantize_pt2(
150163
quantizer = CadenceDefaultQuantizer()
151164

152165
# Get converted graph module
153-
converted_gm = convert_pt2(model, inputs, quantizer, dump_graphs)
166+
converted_gm = convert_pt2(
167+
model, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
168+
)
154169

155170
# Get fused model
156171
fused_gm = fuse_pt2(converted_gm, quantizer)

0 commit comments

Comments (0)