
Commit cfca108

Arm backend: Add support for Leaky ReLU (#9903)
Decompose LeakyReLU into operators supported by the Arm backend.

Signed-off-by: George Gekov <[email protected]>
1 parent d48d7a9 commit cfca108
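
The decomposition this commit introduces rewrites LeakyReLU with clamp, full, mul and add via the identity LeakyReLU(x, slope) = max(0, x) + slope * min(0, x). A minimal eager-mode sketch of that identity (illustrative only, not part of the commit):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)
slope = 0.2

# max(0, x) + slope * min(0, x), written with clamp as the pass does
decomposed = torch.clamp(x, min=0) + slope * torch.clamp(x, max=0)
assert torch.allclose(decomposed, F.leaky_relu(x, negative_slope=slope))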

6 files changed, +167 -2 lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@
 from .decompose_batchnorm_pass import DecomposeBatchNormPass  # noqa
 from .decompose_div_pass import DecomposeDivPass  # noqa
 from .decompose_layernorm_pass import DecomposeLayerNormPass  # noqa
+from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass  # noqa
 from .decompose_linear_pass import DecomposeLinearPass  # noqa
 from .decompose_meandim_pass import DecomposeMeanDimPass  # noqa
 from .decompose_select import DecomposeSelectPass  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
@@ -26,6 +26,7 @@
     DecomposeBatchNormPass,
     DecomposeDivPass,
     DecomposeLayerNormPass,
+    DecomposeLeakyReLUPass,
     DecomposeLinearPass,
     DecomposeMeanDimPass,
     DecomposeSelectPass,
@@ -121,6 +122,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(FuseBatchnorm2DPass(exported_program))
         self.add_pass(ConvertMmToBmmPass())
         self.add_pass(DecomposeLinearPass())
+        self.add_pass(DecomposeLeakyReLUPass())
         self.add_pass(DecomposeBatchNormPass())
         self.add_pass(DecomposeLayerNormPass())
         self.add_pass(DecomposeVarPass())
@@ -178,6 +180,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeVarPass())
         self.add_pass(DecomposeMeanDimPass())
         self.add_pass(DecomposeDivPass())
+        self.add_pass(DecomposeLeakyReLUPass())

         if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
             # Numerically stable softmax uses amax which is not supported on Ethos-U55

backends/arm/_passes/decompose_leaky_relu_pass.py

Lines changed: 71 additions & 0 deletions

@@ -0,0 +1,71 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import torch
+from executorch.backends.arm._passes import ArmPass
+from executorch.exir.dialects._ops import ops as exir_ops
+
+edge_ops = (exir_ops.edge.aten.leaky_relu.default,)
+torch_ops = (torch.ops.aten.leaky_relu.default,)
+
+
+def _get_leaky_relu_ops(op) -> tuple:
+    if op in edge_ops:
+        return (
+            exir_ops.edge.aten.clamp.default,
+            exir_ops.edge.aten.full.default,
+            exir_ops.edge.aten.mul.Tensor,
+            exir_ops.edge.aten.add.Tensor,
+        )
+    elif op in torch_ops:
+        return (
+            torch.ops.aten.clamp.default,
+            torch.ops.aten.full.default,
+            torch.ops.aten.mul.Tensor,
+            torch.ops.aten.add.Tensor,
+        )
+    else:
+        raise RuntimeError(f"Can't get decomposition ops for op {op}")
+
+
+class DecomposeLeakyReLUPass(ArmPass):
+    """
+    This pass decomposes Leaky ReLU into primitive operations.
+    LeakyReLU(x,slope) = max(0,x) + slope * min(0,x)
+
+    Example:
+        %op1 = clamp(x,0,None) (equivalent to max(0,x))
+        %op2 = clamp(x,None,0) (equivalent to min(0,x))
+        %op3 = full(x.shape,slope)
+        %op4 = mul(%op3,%op2)
+        %op5 = add(%op1,%op4)
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in (edge_ops + torch_ops):
+            return super().call_operator(op, args, kwargs, meta)
+
+        x = args[0]
+        slope = args[1] if len(args) > 1 else 0.01
+        dtype = x.node.meta["val"].dtype
+        clamp, full, mul, add = _get_leaky_relu_ops(op)
+        op1 = super().call_operator(
+            op=clamp, args=(x, 0, None), kwargs=kwargs, meta=meta
+        )
+        op2 = super().call_operator(
+            op=clamp, args=(x, None, 0), kwargs=kwargs, meta=meta
+        )
+        op3 = super().call_operator(
+            op=full,
+            args=(x.node.meta["val"].shape, slope),
+            kwargs={"dtype": dtype},
+            meta=meta,
+        )
+        op4 = super().call_operator(op=mul, args=(op3, op2), kwargs=kwargs, meta=meta)
+        op5 = super().call_operator(op=add, args=(op1, op4), kwargs=kwargs, meta=meta)
+        return op5
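
For reference, the five-step sequence that call_operator emits above can be reproduced eagerly with the same ATen overloads. A small illustrative sanity check (not part of the commit; the input shape and slope are arbitrary):

import torch

x = torch.randn(1, 8, 16, 16)
slope = 0.01

# Mirror the pass: clamp / full / mul / add with the same ATen overloads
op1 = torch.ops.aten.clamp.default(x, 0, None)   # max(0, x)
op2 = torch.ops.aten.clamp.default(x, None, 0)   # min(0, x)
op3 = torch.ops.aten.full.default(x.shape, slope, dtype=x.dtype)
op4 = torch.ops.aten.mul.Tensor(op3, op2)
op5 = torch.ops.aten.add.Tensor(op1, op4)
assert torch.allclose(op5, torch.ops.aten.leaky_relu.default(x, slope))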

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
@@ -193,6 +193,7 @@ def is_node_supported(
             exir_ops.edge.aten.repeat.default,
             exir_ops.edge.aten.reciprocal.default,
             exir_ops.edge.aten.relu.default,
+            exir_ops.edge.aten.leaky_relu.default,
             exir_ops.edge.aten.rsqrt.default,
             exir_ops.edge.aten._softmax.default,
             exir_ops.edge.aten.select_copy.int,
@@ -258,6 +259,7 @@
             exir_ops.edge.aten.sub.Scalar,
             exir_ops.edge.aten.mul.Scalar,
             exir_ops.edge.aten.div.Scalar,
+            exir_ops.edge.aten.leaky_relu.default,
         ]
         if needs_decomp:
             self.reporter.report_reject(node, "Needs to be decomposed.")

backends/arm/quantizer/quantization_annotator.py

Lines changed: 2 additions & 2 deletions
@@ -218,6 +218,8 @@ def _match_pattern(
     torch.ops.aten.pad.default,
     torch.ops.aten.amax.default,
     torch.ops.aten.amin.default,
+    torch.ops.aten.clamp.default,
+    torch.ops.aten.clamp.Tensor,
 ]

 # Operators that can inherit the quantization specs from its parent node
@@ -237,8 +239,6 @@ def _match_pattern(
     torch.ops.aten.flatten.using_ints,
     torch.ops.aten.dropout.default,
     torch.ops.aten.dropout_.default,
-    torch.ops.aten.clamp.default,
-    torch.ops.aten.clamp.Tensor,
     torch.ops.aten.where,
     operator.getitem,
 ]

Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineBI,
+    EthosU85PipelineBI,
+    TosaPipelineBI,
+    TosaPipelineMI,
+)
+
+aten_op = "torch.ops.aten.leaky_relu.default"
+exir_op = "executorch_exir_dialects_edge__ops_aten_leaky_relu_default"
+input_t1 = Tuple[torch.Tensor]  # Input x
+
+
+class LeakyReLU(torch.nn.Module):
+    def __init__(self, slope: float = 0.01):
+        super().__init__()
+        self.activation = torch.nn.LeakyReLU(slope)
+
+    def forward(self, x: torch.Tensor):
+        return self.activation(x)
+
+    test_data: dict[str, input_t1] = {
+        "zeros": ((torch.zeros(1, 1, 5, 5),), 0.01),
+        "ones": ((torch.ones(1, 32, 112, 112),), 0.01),
+        "rand": ((torch.rand(1, 96, 56, 56),), 0.2),
+        "3Dtensor": ((torch.rand(5, 5, 5),), 0.001),
+        "negative_slope": ((torch.rand(1, 16, 128, 128),), -0.002),
+    }
+
+
+@common.parametrize("test_data", LeakyReLU.test_data)
+def test_leaky_relu_tosa_MI(test_data):
+    data, slope = test_data
+    pipeline = TosaPipelineMI[input_t1](
+        LeakyReLU(slope), data, [], use_to_edge_transform_and_lower=True
+    )
+    pipeline.add_stage_after(
+        "to_edge_transform_and_lower", pipeline.tester.check_not, [exir_op]
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", LeakyReLU.test_data)
+def test_leaky_relu_tosa_BI(test_data):
+    data, slope = test_data
+    pipeline = TosaPipelineBI[input_t1](
+        LeakyReLU(slope), data, [], use_to_edge_transform_and_lower=True
+    )
+    pipeline.add_stage_after("quantize", pipeline.tester.check_not, [aten_op])
+    pipeline.run()
+
+
+@common.parametrize("test_data", LeakyReLU.test_data)
+@common.XfailIfNoCorstone300
+def test_leaky_relu_u55_BI(test_data):
+    data, slope = test_data
+    pipeline = EthosU55PipelineBI[input_t1](
+        LeakyReLU(slope),
+        data,
+        [],
+        run_on_fvp=True,
+        use_to_edge_transform_and_lower=True,
+    )
+    pipeline.add_stage_after("quantize", pipeline.tester.check_not, [aten_op])
+    pipeline.run()
+
+
+@common.parametrize("test_data", LeakyReLU.test_data)
+@common.XfailIfNoCorstone320
+def test_leaky_relu_u85_BI(test_data):
+    data, slope = test_data
+    pipeline = EthosU85PipelineBI[input_t1](
+        LeakyReLU(slope),
+        data,
+        [],
+        run_on_fvp=True,
+        use_to_edge_transform_and_lower=True,
+    )
+    pipeline.add_stage_after("quantize", pipeline.tester.check_not, [aten_op])
+    pipeline.run()
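
The new tests exercise the full set of Arm test pipelines (TOSA MI/BI plus Ethos-U55/U85 FVP runs). Assuming the test module is added as backends/arm/test/ops/test_leaky_relu.py (an inferred path; the diff header above does not show it), the TOSA variants can be run locally with, for example, pytest backends/arm/test/ops/test_leaky_relu.py -k tosa. The u55/u85 variants are decorated with XfailIfNoCorstone300/320 and use run_on_fvp=True, so they expect the corresponding Corstone FVPs to be available.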
