
Allow passing in calibration data to convert_pt2 #9791


Merged
merged 1 commit into from Apr 1, 2025
19 changes: 17 additions & 2 deletions backends/cadence/aot/compiler.py
@@ -56,6 +56,7 @@ def convert_pt2(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
     quantizer: CadenceQuantizer,
+    calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
 ) -> torch.fx.GraphModule:
     """
@@ -64,6 +65,9 @@ def convert_pt2(
     fuse the model later, if applicable. If you do not expect that behavior,
     please use quantize_and_fuse_pt2 instead, which will instantiate a
     default quantizer for you if needed.
+    If calibration data is provided, it will be used to calibrate the model. If
+    not, the inputs will be used for calibration instead, which is useful for
+    unit tests but should not be used for end-to-end use cases.
     Returns a GraphModule with the converted model.
     """
 
@@ -95,7 +99,12 @@ def convert_pt2(
     prepared_model = prepare_pt2e(model_gm, quantizer)
 
     # Calibrate
-    prepared_model(*inputs)
+    # If no calibration data is provided, use the inputs
+    if calibration_data is None:
+        calibration_data = [inputs]
+
+    for samples in calibration_data:
+        prepared_model(*samples)
 
     # Convert
     converted_model = convert_pt2e(prepared_model)
@@ -136,10 +145,14 @@ def quantize_pt2(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
     quantizer: Optional[CadenceQuantizer] = None,
+    calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
 ) -> torch.fx.GraphModule:
     """
     Prepare, convert and fuse the model using the given quantizer.
+    If calibration data is provided, it will be used to calibrate the model. If
+    not, the inputs will be used for calibration instead, which is useful for
+    unit tests but should not be used for end-to-end use cases.
     Returns a GraphModule with the quantized model.
     """
     # Make the model inference mode by calling model.eval()
@@ -150,7 +163,9 @@ def quantize_pt2(
         quantizer = CadenceDefaultQuantizer()
 
     # Get converted graph module
-    converted_gm = convert_pt2(model, inputs, quantizer, dump_graphs)
+    converted_gm = convert_pt2(
+        model, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
+    )
 
     # Get fused model
     fused_gm = fuse_pt2(converted_gm, quantizer)
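
For illustration, here is a minimal sketch of how the new parameter could be used. The model, sample tensors, and import path are assumptions based on the repository layout, not part of this PR; quantize_pt2 is the entry point the PR changes, and omitting calibration_data falls back to calibrating with the single inputs tuple.

    # Hypothetical usage sketch of the new calibration_data parameter.
    # MyModel, the tensors, and the import path are assumptions for
    # illustration only.
    import torch

    from executorch.backends.cadence.aot.compiler import quantize_pt2


    class MyModel(torch.nn.Module):  # hypothetical toy model
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(16, 8)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(x)


    model = MyModel()
    inputs = (torch.randn(1, 16),)

    # Each entry is one tuple of forward() arguments; several representative
    # samples give the quantization observers a better view of activation
    # ranges than the single example input used as the fallback.
    calibration_data = [(torch.randn(1, 16),) for _ in range(10)]

    quantized_gm = quantize_pt2(model, inputs, calibration_data=calibration_data)

    # Omitting calibration_data reproduces the old behavior: the model is
    # calibrated with `inputs` alone, which the docstring reserves for unit
    # tests rather than end-to-end use cases.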