diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index bc1c2ef3d66..dd7f3d02518 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -21,6 +21,7 @@ from .decompose_batchnorm_pass import DecomposeBatchNormPass # noqa from .decompose_div_pass import DecomposeDivPass # noqa from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa +from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa from .decompose_linear_pass import DecomposeLinearPass # noqa from .decompose_meandim_pass import DecomposeMeanDimPass # noqa from .decompose_select import DecomposeSelectPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index c085e3def1b..703c6ff214c 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -26,6 +26,7 @@ DecomposeBatchNormPass, DecomposeDivPass, DecomposeLayerNormPass, + DecomposeLeakyReLUPass, DecomposeLinearPass, DecomposeMeanDimPass, DecomposeSelectPass, @@ -121,6 +122,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(FuseBatchnorm2DPass(exported_program)) self.add_pass(ConvertMmToBmmPass()) self.add_pass(DecomposeLinearPass()) + self.add_pass(DecomposeLeakyReLUPass()) self.add_pass(DecomposeBatchNormPass()) self.add_pass(DecomposeLayerNormPass()) self.add_pass(DecomposeVarPass()) @@ -178,6 +180,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(DecomposeVarPass()) self.add_pass(DecomposeMeanDimPass()) self.add_pass(DecomposeDivPass()) + self.add_pass(DecomposeLeakyReLUPass()) if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset: # Numerically stable softmax uses amax which is not supported on Ethos-U55 diff --git a/backends/arm/_passes/decompose_leaky_relu_pass.py b/backends/arm/_passes/decompose_leaky_relu_pass.py new file mode 100644 index 00000000000..e896cc584be --- /dev/null +++ b/backends/arm/_passes/decompose_leaky_relu_pass.py @@ -0,0 +1,71 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops + +edge_ops = (exir_ops.edge.aten.leaky_relu.default,) +torch_ops = (torch.ops.aten.leaky_relu.default,) + + +def _get_leaky_relu_ops(op) -> tuple: + if op in edge_ops: + return ( + exir_ops.edge.aten.clamp.default, + exir_ops.edge.aten.full.default, + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.add.Tensor, + ) + elif op in torch_ops: + return ( + torch.ops.aten.clamp.default, + torch.ops.aten.full.default, + torch.ops.aten.mul.Tensor, + torch.ops.aten.add.Tensor, + ) + else: + raise RuntimeError(f"Can't get decomposition ops for op {op}") + + +class DecomposeLeakyReLUPass(ArmPass): + """ + This pass decomposes Leaky ReLU into primitive operations. + LeakyReLU(x,slope) = max(0,x) + slope * min(0,x) + + Example: + %op1 = clamp(x,0,None) (equivalent to max(0,x)) + %op2 = clamp(x,None,0) (equivalent to min(0,x)) + %op3 = full(x.shape,slope) + %op4 = mul(%op3,%op2) + %op5 = add(%op1,%op4) + """ + + def call_operator(self, op, args, kwargs, meta): + if op not in (edge_ops + torch_ops): + return super().call_operator(op, args, kwargs, meta) + + x = args[0] + slope = args[1] if len(args) > 1 else 0.01 + dtype = x.node.meta["val"].dtype + clamp, full, mul, add = _get_leaky_relu_ops(op) + op1 = super().call_operator( + op=clamp, args=(x, 0, None), kwargs=kwargs, meta=meta + ) + op2 = super().call_operator( + op=clamp, args=(x, None, 0), kwargs=kwargs, meta=meta + ) + op3 = super().call_operator( + op=full, + args=(x.node.meta["val"].shape, slope), + kwargs={"dtype": dtype}, + meta=meta, + ) + op4 = super().call_operator(op=mul, args=(op3, op2), kwargs=kwargs, meta=meta) + op5 = super().call_operator(op=add, args=(op1, op4), kwargs=kwargs, meta=meta) + return op5 diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 0e5d7ecc958..4932a0cf45f 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -192,6 +192,7 @@ def is_node_supported( exir_ops.edge.aten.repeat.default, exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.leaky_relu.default, exir_ops.edge.aten.rsqrt.default, exir_ops.edge.aten._softmax.default, exir_ops.edge.aten.select_copy.int, @@ -257,6 +258,7 @@ def is_node_supported( exir_ops.edge.aten.sub.Scalar, exir_ops.edge.aten.mul.Scalar, exir_ops.edge.aten.div.Scalar, + exir_ops.edge.aten.leaky_relu.default, ] if needs_decomp: self.reporter.report_reject(node, "Needs to be decomposed.") diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index baca13029a3..603dab8b9e8 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -217,6 +217,8 @@ def _match_pattern( torch.ops.aten.pad.default, torch.ops.aten.amax.default, torch.ops.aten.amin.default, + torch.ops.aten.clamp.default, + torch.ops.aten.clamp.Tensor, ] # Operators that can inherit the quantization specs from its parent node @@ -236,8 +238,6 @@ def _match_pattern( torch.ops.aten.flatten.using_ints, torch.ops.aten.dropout.default, torch.ops.aten.dropout_.default, - torch.ops.aten.clamp.default, - torch.ops.aten.clamp.Tensor, torch.ops.aten.where, operator.getitem, ] diff --git a/backends/arm/test/ops/test_leaky_relu.py b/backends/arm/test/ops/test_leaky_relu.py new file mode 100644 index 00000000000..b9f0c3a8d1a --- /dev/null +++ b/backends/arm/test/ops/test_leaky_relu.py @@ -0,0 +1,88 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineBI, + EthosU85PipelineBI, + TosaPipelineBI, + TosaPipelineMI, +) + +aten_op = "torch.ops.aten.leaky_relu.default" +exir_op = "executorch_exir_dialects_edge__ops_aten_leaky_relu_default" +input_t1 = Tuple[torch.Tensor] # Input x + + +class LeakyReLU(torch.nn.Module): + def __init__(self, slope: float = 0.01): + super().__init__() + self.activation = torch.nn.LeakyReLU(slope) + + def forward(self, x: torch.Tensor): + return self.activation(x) + + test_data: dict[str, input_t1] = { + "zeros": ((torch.zeros(1, 1, 5, 5),), 0.01), + "ones": ((torch.ones(1, 32, 112, 112),), 0.01), + "rand": ((torch.rand(1, 96, 56, 56),), 0.2), + "3Dtensor": ((torch.rand(5, 5, 5),), 0.001), + "negative_slope": ((torch.rand(1, 16, 128, 128),), -0.002), + } + + +@common.parametrize("test_data", LeakyReLU.test_data) +def test_leaky_relu_tosa_MI(test_data): + data, slope = test_data + pipeline = TosaPipelineMI[input_t1]( + LeakyReLU(slope), data, [], use_to_edge_transform_and_lower=True + ) + pipeline.add_stage_after( + "to_edge_transform_and_lower", pipeline.tester.check_not, [exir_op] + ) + pipeline.run() + + +@common.parametrize("test_data", LeakyReLU.test_data) +def test_leaky_relu_tosa_BI(test_data): + data, slope = test_data + pipeline = TosaPipelineBI[input_t1]( + LeakyReLU(slope), data, [], use_to_edge_transform_and_lower=True + ) + pipeline.add_stage_after("quantize", pipeline.tester.check_not, [aten_op]) + pipeline.run() + + +@common.parametrize("test_data", LeakyReLU.test_data) +@common.XfailIfNoCorstone300 +def test_leaky_relu_u55_BI(test_data): + data, slope = test_data + pipeline = EthosU55PipelineBI[input_t1]( + LeakyReLU(slope), + data, + [], + run_on_fvp=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.add_stage_after("quantize", pipeline.tester.check_not, [aten_op]) + pipeline.run() + + +@common.parametrize("test_data", LeakyReLU.test_data) +@common.XfailIfNoCorstone320 +def test_leaky_relu_u85_BI(test_data): + data, slope = test_data + pipeline = EthosU85PipelineBI[input_t1]( + LeakyReLU(slope), + data, + [], + run_on_fvp=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.add_stage_after("quantize", pipeline.tester.check_not, [aten_op]) + pipeline.run()