From dc8c21a53a3c1811a23aad6d099e86810ed2ca15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Mon, 3 Mar 2025 15:38:38 +0100 Subject: [PATCH 1/3] Arm backend: Move ReplaceScalarTensorWithFullPass to transforms The pass is general and can be used by multiple backends. The aten.scalar_tensor is replaced by a aten.full which is already supported by Arm backend. Adds new method to Arm tester for getting output as the nn.module in the unit test does not take any input. The output is then manually compared within the unit test. Change-Id: I2bf211a2ce561d53e8a6cf683fdbda58e675938e --- backends/arm/_passes/arm_pass_manager.py | 6 +- .../tosa_supported_operators.py | 1 + backends/arm/test/ops/test_scalar_tensor.py | 137 ++++++++++++++++++ backends/arm/test/tester/arm_tester.py | 54 +++++++ backends/cadence/aot/replace_ops.py | 35 +---- .../replace_scalar_tensor_with_full.py | 42 ++++++ 6 files changed, 245 insertions(+), 30 deletions(-) create mode 100644 backends/arm/test/ops/test_scalar_tensor.py create mode 100644 backends/transforms/replace_scalar_tensor_with_full.py diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index f8a4a40648f..8f6ec71d8e7 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -78,7 +78,9 @@ UnsqueezeScalarPlaceholdersPass, ) from executorch.backends.arm.tosa_specification import TosaSpecification - +from executorch.backends.transforms.replace_scalar_tensor_with_full import ( + ReplaceScalarTensorWithFullPass, +) from executorch.backends.transforms.replace_scalar_with_tensor import ( ReplaceScalarWithTensorArgPass, ) @@ -133,6 +135,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul return self._transform(exported_program.graph_module) def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule: + self.add_pass(ReplaceScalarTensorWithFullPass()) self.add_pass(ReplaceScalarWithTensorArgPass()) self.add_pass(FuseQuantizedActivationPass()) self.add_pass(RemoveGetItemPass()) @@ -194,4 +197,5 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(DecomposeDivPass()) self.add_pass(DecomposeSoftmaxesPass()) self.add_pass(ConvertMinMaxPass()) + self.add_pass(ReplaceScalarTensorWithFullPass()) return self._transform(graph_module) diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 7a9ce29ff52..393ecbce562 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -171,6 +171,7 @@ def is_node_supported( exir_ops.edge.aten.constant_pad_nd.default, exir_ops.edge.aten.amax.default, exir_ops.edge.aten.amin.default, + torch.ops.aten.scalar_tensor.default, ] return supported diff --git a/backends/arm/test/ops/test_scalar_tensor.py b/backends/arm/test/ops/test_scalar_tensor.py new file mode 100644 index 00000000000..770ebc1b4ec --- /dev/null +++ b/backends/arm/test/ops/test_scalar_tensor.py @@ -0,0 +1,137 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2024-2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_quantization_config, + TOSAQuantizer, +) +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.backends.xnnpack.test.tester.tester import Quantize +from parameterized import parameterized + + +float_test_data_suite = [ + # (test_name, scalar input, scalar input type,) + ( + "scalar_tensor_float_1", + 3.7, + torch.float32, + ), + ( + "scalar_tensor_float_2", + 66, + torch.float32, + ), +] + +int_test_data_suite = [ + # (test_name, scalar input, scalar input type,) + ( + "scalar_tensor_int32", + 33, + torch.int32, + ), + ( + "scalar_tensor_int8", + 8, + torch.int8, + ), + ( + "scalar_tensor_int16", + 16 * 16 * 16, + torch.int16, + ), +] + + +class ScalarTensor(torch.nn.Module): + def __init__(self, scalar, dtype=torch.float32): + super().__init__() + self.scalar = scalar + self.dtype = dtype + + def forward(self): + return torch.scalar_tensor(self.scalar, dtype=self.dtype) + + +class TestScalarTensor(unittest.TestCase): + + def _test_scalar_tensor_tosa_MI_pipeline( + self, module: torch.nn.Module, expected_output + ): + test_outputs = [] + in_data = () + + ( + ArmTester( + module, + example_inputs=in_data, + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80+MI", + ), + ) + .export() + .check_count({"torch.ops.aten.scalar_tensor.default": 1}) + .to_edge_transform_and_lower() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_get_output(test_outputs, inputs=in_data) + ) + self._verify_output(test_outputs, expected_output) + + def _test_scalar_tensor_tosa_BI_pipeline( + self, module: torch.nn.Module, expected_output + ): + test_outputs = [] + in_data = () + tosa_spec = TosaSpecification.create_from_string("TOSA-0.80+BI") + compile_spec = common.get_tosa_compile_spec(tosa_spec) + quantizer = TOSAQuantizer(tosa_spec).set_io(get_symmetric_quantization_config()) + + ( + ArmTester( + module, + example_inputs=in_data, + compile_spec=compile_spec, + ) + .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .export() + .check_count({"torch.ops.aten.full.default": 1}) # Already replaced + .to_edge_transform_and_lower() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_get_output(test_outputs, inputs=in_data) + ) + self._verify_output(test_outputs, expected_output) + + def _verify_output(self, test_outputs, expected_output): + out_data = torch.squeeze(test_outputs[0][0]) + assert out_data == expected_output + assert out_data.dtype == expected_output.dtype + + @parameterized.expand(int_test_data_suite + float_test_data_suite) + def test_scalar_tensor_tosa_MI( # Note TOSA MI supports all types + self, test_name: str, scalar_value, scalar_type + ): + scalar = scalar_value + dtype = scalar_type + self._test_scalar_tensor_tosa_MI_pipeline( + ScalarTensor(scalar, dtype), torch.scalar_tensor(scalar, dtype=dtype) + ) + + @parameterized.expand(float_test_data_suite) + def test_scalar_tensor_tosa_BI(self, test_name: str, scalar_value, scalar_type): + scalar = scalar_value + dtype = scalar_type + self._test_scalar_tensor_tosa_BI_pipeline( + ScalarTensor(scalar, dtype), torch.scalar_tensor(scalar, dtype=dtype) + ) diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index a6da2accd1d..dbfb59a199a 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -338,6 +338,60 @@ def serialize( def is_quantized(self) -> bool: return self.stages[self.stage_name(tester.Quantize)] is not None + def run_method_and_get_output( + self, + test_outputs: List, + inputs: Optional[Tuple[torch.Tensor]] = None, + stage: Optional[str] = None, + num_runs=1, + ): + """ + Returns the run_artifact output of 'stage'. This output is returned as parameter of type List. + Returns self to allow the function to be run in a test chain. + + Args: + stage: (Optional[str]): The name of the stage to compare. + The default is the latest run stage. + test_output: All output results. + inputs (Optional[Tuple[torch.Tensor]]): Allows you to input custom input data. + The default is random data. + """ + edge_stage = self.stages[self.stage_name(tester.ToEdge)] + if edge_stage is None: + edge_stage = self.stages[self.stage_name(tester.ToEdgeTransformAndLower)] + assert ( + edge_stage is not None + ), "To get outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run." + + stage = stage or self.cur + test_stage = self.stages[stage] + + exported_program = self.stages[self.stage_name(tester.Export)].artifact + output_nodes = get_output_nodes(exported_program) + output_qparams = get_output_quantization_params(output_nodes) + + quantization_scales = [] + for node in output_qparams: + quantization_scales.append(getattr(output_qparams[node], "scale", None)) + + # Loop inputs and get outputs of the test stage. + for run_iteration in range(num_runs): + reference_input = inputs if inputs else next(self.generate_random_inputs()) + + input_shapes = [ + generated_input.shape if hasattr(generated_input, "shape") else (1,) + for generated_input in reference_input + ] + input_shape_str = ", ".join([str(list(i)) for i in input_shapes]) + logger.info(f"Run #{run_iteration}, input shapes: {input_shape_str}") + + test_output, _ = pytree.tree_flatten( + test_stage.run_artifact(reference_input) + ) + test_outputs.append(test_output) + + return self + def run_method_and_compare_outputs( self, inputs: Optional[Tuple[torch.Tensor]] = None, diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index f91fb26ddc8..5d4a237bc18 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -38,6 +38,9 @@ ) from executorch.backends.cadence.aot.remove_ops import RemoveNopSelectOpPass from executorch.backends.cadence.aot.utils import get_edge_overload_packet +from executorch.backends.transforms.replace_scalar_tensor_with_full import ( + ReplaceScalarTensorWithFullPass, +) from executorch.backends.transforms.replace_scalar_with_tensor import ( ReplaceScalarWithTensorArgPass, ) @@ -1722,35 +1725,9 @@ def call_operator(self, op, args, kwargs, meta): register_cadence_pass(CadencePassAttribute(opt_level=0))(ReplaceScalarWithTensorArgPass) -@register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceScalarTensorWithFullPass(ExportPass): - """ - aten.scalar_tensor can be replaced by aten.full with a shape of [1]. - scalar_tensor is not supported, so this is an opt_level=0 pass. - """ - - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in { - exir_ops.edge.aten.scalar_tensor.default, - torch.ops.aten.scalar_tensor.default, - }: - return super().call_operator(op, args, kwargs, meta) - - return super().call_operator( - exir_ops.edge.aten.full.default, - ( - [1], - args[0], - ), - {"dtype": torch.float32}, - meta, - ) +register_cadence_pass(CadencePassAttribute(opt_level=0))( + ReplaceScalarTensorWithFullPass +) @register_cadence_pass(CadencePassAttribute(opt_level=0)) diff --git a/backends/transforms/replace_scalar_tensor_with_full.py b/backends/transforms/replace_scalar_tensor_with_full.py new file mode 100644 index 00000000000..13cea5cc20a --- /dev/null +++ b/backends/transforms/replace_scalar_tensor_with_full.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, Tuple + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue +from torch.fx.node import Argument + + +class ReplaceScalarTensorWithFullPass(ExportPass): + """ + aten.scalar_tensor can be replaced by aten.full with a shape of [1]. + """ + + def call_operator( + self, + op, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op not in { + exir_ops.edge.aten.scalar_tensor.default, + torch.ops.aten.scalar_tensor.default, + }: + return super().call_operator(op, args, kwargs, meta) + + return super().call_operator( + exir_ops.edge.aten.full.default, + ( + [1], + args[0], + ), + {"dtype": kwargs["dtype"]}, + meta, + ) From f52c63cebad2d8a3ea705f85e306b0e186653d57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Thu, 6 Mar 2025 12:06:07 +0100 Subject: [PATCH 2/3] Update conformer model test --- backends/arm/test/models/test_conformer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/backends/arm/test/models/test_conformer.py b/backends/arm/test/models/test_conformer.py index d9bc4e363c1..685e768ecdb 100644 --- a/backends/arm/test/models/test_conformer.py +++ b/backends/arm/test/models/test_conformer.py @@ -38,7 +38,6 @@ class TestConformer(unittest.TestCase): "executorch_exir_dialects_edge__ops_aten_any_dim": 2, "torch.ops.aten._assert_scalar.default": 10, "torch.ops.aten._local_scalar_dense.default": 1, - "torch.ops.aten.scalar_tensor.default": 2, "torch.ops.higher_order.executorch_call_delegate": 4, } From edc1470e82d243bdcfbad48beda79a0610b21e8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Thu, 6 Mar 2025 13:48:31 +0100 Subject: [PATCH 3/3] Fix formatting --- backends/arm/_passes/arm_pass_manager.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index a4729f9b29c..42ee636598e 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -80,13 +80,12 @@ ) from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform + from executorch.backends.transforms.replace_scalar_tensor_with_full import ( ReplaceScalarTensorWithFullPass, ) -from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform - - from executorch.backends.transforms.replace_scalar_with_tensor import ( ReplaceScalarWithTensorArgPass, )