diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index 93b3f85c529..0a501963bfe 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -56,6 +56,7 @@ def convert_pt2( model: torch.nn.Module, inputs: tuple[object, ...], quantizer: CadenceQuantizer, + calibration_data: Optional[list[tuple[object, ...]]] = None, dump_graphs: bool = False, ) -> torch.fx.GraphModule: """ @@ -64,6 +65,9 @@ def convert_pt2( fuse the model later, if applicable. If you do not expect that behavior, please use quantize_and_fuse_pt2 instead, which will instantiate a default quantizer for you if needed. + If calibration data is provided, it will be used to calibrate the model. If + not, the inputs will be used for calibration instead, which is useful for + unit tests but should not be used for end-to-end use cases. Returns a GraphModule with the converted model. """ @@ -95,7 +99,12 @@ def convert_pt2( prepared_model = prepare_pt2e(model_gm, quantizer) # Calibrate - prepared_model(*inputs) + # If no calibration data is provided, use the inputs + if calibration_data is None: + calibration_data = [inputs] + + for samples in calibration_data: + prepared_model(*samples) # Convert converted_model = convert_pt2e(prepared_model) @@ -136,10 +145,14 @@ def quantize_pt2( model: torch.nn.Module, inputs: tuple[object, ...], quantizer: Optional[CadenceQuantizer] = None, + calibration_data: Optional[list[tuple[object, ...]]] = None, dump_graphs: bool = False, ) -> torch.fx.GraphModule: """ Prepare, convert and fuse the model using the given quantizer. + If calibration data is provided, it will be used to calibrate the model. If + not, the inputs will be used for calibration instead, which is useful for + unit tests but should not be used for end-to-end use cases. Returns a GraphModule with the quantized model. """ # Make the model inference mode by calling model.eval() @@ -150,7 +163,9 @@ def quantize_pt2( quantizer = CadenceDefaultQuantizer() # Get converted graph module - converted_gm = convert_pt2(model, inputs, quantizer, dump_graphs) + converted_gm = convert_pt2( + model, inputs, quantizer, calibration_data, dump_graphs=dump_graphs + ) # Get fused model fused_gm = fuse_pt2(converted_gm, quantizer)