diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 7ec04ea0844..faab48ce2be 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -44,8 +44,10 @@ from executorch.backends.arm._passes.decompose_select import (  # type: ignore[import-not-found]
     DecomposeSelectPass,
 )
-from executorch.backends.arm._passes.decompose_softmax_pass import DecomposeSoftmaxPass
-from executorch.backends.arm._passes.decompose_softmax_unstable_pass import (
+from executorch.backends.arm._passes.decompose_softmax_pass import (  # type: ignore[import-not-found]
+    DecomposeSoftmaxPass,
+)
+from executorch.backends.arm._passes.decompose_softmax_unstable_pass import (  # type: ignore[import-not-found]
     DecomposeSoftmaxUnstablePass,
 )
 from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
@@ -85,6 +87,10 @@ from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
 from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
+from executorch.backends.transforms.replace_scalar_tensor_with_full import (  # type: ignore[import-not-found]
+    ReplaceScalarTensorWithFullPass,
+)
+
 from executorch.backends.transforms.replace_scalar_with_tensor import (
     ReplaceScalarWithTensorArgPass,
 )
@@ -143,6 +149,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule
         return self._transform(exported_program.graph_module)
 
     def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
+        self.add_pass(ReplaceScalarTensorWithFullPass())
         self.add_pass(ReplaceScalarWithTensorArgPass())
         self.add_pass(FuseQuantizedActivationPass())
         self.add_pass(RemoveGetItemPass())
@@ -213,4 +220,5 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeSoftmaxPass())
         self.add_pass(ConvertMinMaxPass())
+        self.add_pass(ReplaceScalarTensorWithFullPass())
         return self._transform(graph_module)
diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index f86bc21009c..6d7f528c2e7 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -200,6 +200,7 @@ def is_node_supported(
             exir_ops.edge.aten.constant_pad_nd.default,
             exir_ops.edge.aten.amax.default,
             exir_ops.edge.aten.amin.default,
+            torch.ops.aten.scalar_tensor.default,
         ]
 
         return supported
diff --git a/backends/arm/test/models/test_conformer.py b/backends/arm/test/models/test_conformer.py
index 4ed203a964e..6ba53d9844e 100644
--- a/backends/arm/test/models/test_conformer.py
+++ b/backends/arm/test/models/test_conformer.py
@@ -36,7 +36,6 @@ class TestConformer(unittest.TestCase):
         "executorch_exir_dialects_edge__ops_aten_where_self": 4,
         "torch.ops.aten._assert_scalar.default": 10,
         "torch.ops.aten._local_scalar_dense.default": 1,
-        "torch.ops.aten.scalar_tensor.default": 2,
         "torch.ops.higher_order.executorch_call_delegate": 6,
     }
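The three hunks above all lean on the same graph rewrite: `aten.scalar_tensor` can be expressed as `aten.full` with shape `[1]`. A minimal sketch of that equivalence in plain PyTorch (the semantics the pass relies on, not the pass itself):

```python
import torch

# scalar_tensor yields a rank-0 tensor; full with shape [1] yields a rank-1
# tensor holding the same value, which the TOSA backend can lower directly.
scalar = torch.scalar_tensor(3.7, dtype=torch.float32)  # tensor(3.7000)
full = torch.full([1], 3.7, dtype=torch.float32)        # tensor([3.7000])

assert torch.squeeze(full) == scalar  # equal once the extra dim is squeezed
assert full.dtype == scalar.dtype
```

Because the pass now runs in the MI and annotation pipelines, `test_conformer.py` no longer expects any unlowered `aten.scalar_tensor` nodes, and the operator can be reported as supported for partitioning.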
diff --git a/backends/arm/test/ops/test_scalar_tensor.py b/backends/arm/test/ops/test_scalar_tensor.py
new file mode 100644
index 00000000000..770ebc1b4ec
--- /dev/null
+++ b/backends/arm/test/ops/test_scalar_tensor.py
@@ -0,0 +1,137 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.xnnpack.test.tester.tester import Quantize
+from parameterized import parameterized
+
+
+float_test_data_suite = [
+    # (test_name, scalar input, scalar input type)
+    (
+        "scalar_tensor_float_1",
+        3.7,
+        torch.float32,
+    ),
+    (
+        "scalar_tensor_float_2",
+        66,
+        torch.float32,
+    ),
+]
+
+int_test_data_suite = [
+    # (test_name, scalar input, scalar input type)
+    (
+        "scalar_tensor_int32",
+        33,
+        torch.int32,
+    ),
+    (
+        "scalar_tensor_int8",
+        8,
+        torch.int8,
+    ),
+    (
+        "scalar_tensor_int16",
+        16 * 16 * 16,
+        torch.int16,
+    ),
+]
+
+
+class ScalarTensor(torch.nn.Module):
+    def __init__(self, scalar, dtype=torch.float32):
+        super().__init__()
+        self.scalar = scalar
+        self.dtype = dtype
+
+    def forward(self):
+        return torch.scalar_tensor(self.scalar, dtype=self.dtype)
+
+
+class TestScalarTensor(unittest.TestCase):
+
+    def _test_scalar_tensor_tosa_MI_pipeline(
+        self, module: torch.nn.Module, expected_output
+    ):
+        test_outputs = []
+        in_data = ()
+
+        (
+            ArmTester(
+                module,
+                example_inputs=in_data,
+                compile_spec=common.get_tosa_compile_spec(
+                    "TOSA-0.80+MI",
+                ),
+            )
+            .export()
+            .check_count({"torch.ops.aten.scalar_tensor.default": 1})
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_get_output(test_outputs, inputs=in_data)
+        )
+        self._verify_output(test_outputs, expected_output)
+
+    def _test_scalar_tensor_tosa_BI_pipeline(
+        self, module: torch.nn.Module, expected_output
+    ):
+        test_outputs = []
+        in_data = ()
+        tosa_spec = TosaSpecification.create_from_string("TOSA-0.80+BI")
+        compile_spec = common.get_tosa_compile_spec(tosa_spec)
+        quantizer = TOSAQuantizer(tosa_spec).set_io(get_symmetric_quantization_config())
+
+        (
+            ArmTester(
+                module,
+                example_inputs=in_data,
+                compile_spec=compile_spec,
+            )
+            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
+            .export()
+            .check_count({"torch.ops.aten.full.default": 1})  # scalar_tensor already replaced during quantize()
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_get_output(test_outputs, inputs=in_data)
+        )
+        self._verify_output(test_outputs, expected_output)
+
+    def _verify_output(self, test_outputs, expected_output):
+        out_data = torch.squeeze(test_outputs[0][0])
+        assert out_data == expected_output
+        assert out_data.dtype == expected_output.dtype
+
+    @parameterized.expand(int_test_data_suite + float_test_data_suite)
+    def test_scalar_tensor_tosa_MI(  # Note: TOSA MI supports all types
+        self, test_name: str, scalar_value, scalar_type
+    ):
+        scalar = scalar_value
+        dtype = scalar_type
+        self._test_scalar_tensor_tosa_MI_pipeline(
+            ScalarTensor(scalar, dtype), torch.scalar_tensor(scalar, dtype=dtype)
+        )
+
+    @parameterized.expand(float_test_data_suite)
+    def test_scalar_tensor_tosa_BI(self, test_name: str, scalar_value, scalar_type):
+        scalar = scalar_value
+        dtype = scalar_type
+        self._test_scalar_tensor_tosa_BI_pipeline(
+            ScalarTensor(scalar, dtype), torch.scalar_tensor(scalar, dtype=dtype)
+        )
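The BI pipeline's `check_count` on `aten.full` relies on `transform_for_annotation_pipeline` (patched above) having already replaced `scalar_tensor` during the `quantize()` stage, so by the time `export()` runs, the op is gone. The untransformed starting point can be inspected with plain `torch.export`; a sketch, standalone and outside `ArmTester`:

```python
import torch
from torch.export import export


class ScalarTensor(torch.nn.Module):
    def forward(self):
        return torch.scalar_tensor(3.7, dtype=torch.float32)


# With no transform applied yet, the exported graph still contains
# aten.scalar_tensor.default, the op the MI test's first check_count sees.
ep = export(ScalarTensor(), args=())
targets = [n.target for n in ep.graph.nodes if n.op == "call_function"]
assert torch.ops.aten.scalar_tensor.default in targets
```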
diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py
index 7b74603cfb2..9ab8d5232b8 100644
--- a/backends/arm/test/tester/arm_tester.py
+++ b/backends/arm/test/tester/arm_tester.py
@@ -372,6 +372,62 @@ def serialize(
     def is_quantized(self) -> bool:
         return self.stages[self.stage_name(tester.Quantize)] is not None
 
+    def run_method_and_get_output(
+        self,
+        test_outputs: List,
+        inputs: Optional[Tuple[torch.Tensor]] = None,
+        stage: Optional[str] = None,
+        num_runs=1,
+    ):
+        """
+        Runs the given stage's run_artifact and appends its outputs to the
+        'test_outputs' list. Returns self so the call can be part of a test chain.
+
+        Args:
+            test_outputs (List): Collects the outputs of every run.
+            inputs (Optional[Tuple[torch.Tensor]]): Allows you to input custom input data.
+                The default is random data.
+            stage (Optional[str]): The name of the stage to run.
+                The default is the latest run stage.
+            num_runs: How many times to run the stage; fresh random inputs are
+                generated for each run when no inputs are given. Defaults to 1.
+        """
+        edge_stage = self.stages[self.stage_name(tester.ToEdge)]
+        if edge_stage is None:
+            edge_stage = self.stages[self.stage_name(tester.ToEdgeTransformAndLower)]
+        assert (
+            edge_stage is not None
+        ), "To get outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run."
+
+        stage = stage or self.cur
+        test_stage = self.stages[stage]
+
+        exported_program = self.stages[self.stage_name(tester.Export)].artifact
+        output_nodes = get_output_nodes(exported_program)
+        output_qparams = get_output_quantization_params(output_nodes)
+
+        quantization_scales = []
+        for node in output_qparams:
+            quantization_scales.append(getattr(output_qparams[node], "scale", None))
+
+        # Loop over the runs and collect the outputs of the test stage.
+        for run_iteration in range(num_runs):
+            reference_input = inputs if inputs else next(self.generate_random_inputs())
+
+            input_shapes = [
+                generated_input.shape if hasattr(generated_input, "shape") else (1,)
+                for generated_input in reference_input
+            ]
+            input_shape_str = ", ".join([str(list(i)) for i in input_shapes])
+            logger.info(f"Run #{run_iteration}, input shapes: {input_shape_str}")
+
+            test_output, _ = pytree.tree_flatten(
+                test_stage.run_artifact(reference_input)
+            )
+            test_outputs.append(test_output)
+
+        return self
+
     def run_method_and_compare_outputs(
         self,
         inputs: Optional[Tuple[torch.Tensor]] = None,
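Unlike `run_method_and_compare_outputs`, the new helper hands the raw stage outputs back to the caller, which is what lets the scalar_tensor tests assert on both value and dtype. A usage sketch with a trivial placeholder module (`AddOne` is hypothetical; this also assumes the TOSA reference-model tooling that `ArmTester` drives is available):

```python
import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester


class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1.0


inputs = (torch.ones(1, 4),)
test_outputs = []  # run_method_and_get_output appends one entry per run

(
    ArmTester(
        AddOne(),
        example_inputs=inputs,
        compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
    )
    .export()
    .to_edge_transform_and_lower()
    .to_executorch()
    .run_method_and_get_output(test_outputs, inputs=inputs, num_runs=2)
)

# test_outputs[i] is the flattened output of run #i, so test_outputs[0][0]
# is the first output tensor of the first run.
assert len(test_outputs) == 2
```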
- """ - - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in { - exir_ops.edge.aten.scalar_tensor.default, - torch.ops.aten.scalar_tensor.default, - }: - return super().call_operator(op, args, kwargs, meta) - - return super().call_operator( - exir_ops.edge.aten.full.default, - ( - [1], - args[0], - ), - {"dtype": torch.float32}, - meta, - ) +register_cadence_pass(CadencePassAttribute(opt_level=0))( + ReplaceScalarTensorWithFullPass +) @register_cadence_pass(CadencePassAttribute(opt_level=0)) diff --git a/backends/transforms/replace_scalar_tensor_with_full.py b/backends/transforms/replace_scalar_tensor_with_full.py new file mode 100644 index 00000000000..13cea5cc20a --- /dev/null +++ b/backends/transforms/replace_scalar_tensor_with_full.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, Tuple + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue +from torch.fx.node import Argument + + +class ReplaceScalarTensorWithFullPass(ExportPass): + """ + aten.scalar_tensor can be replaced by aten.full with a shape of [1]. + """ + + def call_operator( + self, + op, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op not in { + exir_ops.edge.aten.scalar_tensor.default, + torch.ops.aten.scalar_tensor.default, + }: + return super().call_operator(op, args, kwargs, meta) + + return super().call_operator( + exir_ops.edge.aten.full.default, + ( + [1], + args[0], + ), + {"dtype": kwargs["dtype"]}, + meta, + )