Skip to content

Commit 0efdb91

Browse files
authored
Arm backend: Add ERF operator (#9836)
- Register ERF node visitor for MI case - BI handled by table op pass - Add tests Signed-off-by: Madeleine Dunn <[email protected]>
1 parent a5967fd commit 0efdb91

File tree

6 files changed

+111
-0
lines changed

6 files changed

+111
-0
lines changed

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class TableOps:
4141
# Targets that follow a straigtforward one-to-one mapping to their table op
4242
unary_table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = {
4343
exir_ops.edge.aten.ceil.default: torch.ceil,
44+
exir_ops.edge.aten.erf.default: torch.erf,
4445
exir_ops.edge.aten.exp.default: torch.exp,
4546
exir_ops.edge.aten.floor.default: torch.floor,
4647
exir_ops.edge.aten.log.default: torch.log,

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def is_node_supported(
166166
exir_ops.edge.aten.div.Tensor,
167167
exir_ops.edge.aten.eq.Tensor,
168168
exir_ops.edge.aten.eq.Scalar,
169+
exir_ops.edge.aten.erf.default,
169170
exir_ops.edge.aten.exp.default,
170171
exir_ops.edge.aten.log.default,
171172
exir_ops.edge.aten.linear.default,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
op_constant_pad_nd,
2020
op_conv2d,
2121
op_eq,
22+
op_erf,
2223
op_exp,
2324
op_full,
2425
op_ge,

backends/arm/operators/op_erf.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
# pyre-unsafe
6+
from typing import List
7+
8+
import serializer.tosa_serializer as ts # type: ignore
9+
import torch.fx
10+
from executorch.backends.arm.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.arm.tosa_mapping import TosaArg
15+
from executorch.backends.arm.tosa_specification import TosaSpecification
16+
from serializer.tosa_serializer import TosaOp
17+
18+
19+
@register_node_visitor
20+
class ERFVisitor_080_MI(NodeVisitor):
21+
target = "aten.erf.default"
22+
23+
# BI case handled by op_table
24+
tosa_specs = [TosaSpecification.create_from_string("TOSA-0.80+MI")]
25+
26+
def __init__(self, *args):
27+
super().__init__(*args)
28+
29+
def define_node(
30+
self,
31+
node: torch.fx.Node,
32+
tosa_graph: ts.TosaSerializer,
33+
inputs: List[TosaArg],
34+
output: TosaArg,
35+
) -> None:
36+
if not (inputs[0].dtype == output.dtype):
37+
raise ValueError(
38+
"All inputs and output need same dtype."
39+
f"Got {inputs[0].dtype=}, {output.dtype=}"
40+
)
41+
if not (inputs[0].dtype == ts.DType.FP32):
42+
raise ValueError("All inputs need to be FP32." f"Got {inputs[0].dtype=}")
43+
# MI lowering
44+
tosa_graph.addOperator(TosaOp.Op().ERF, [inputs[0].name], [output.name])

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def _match_pattern(
164164
_one_to_one = [
165165
torch.ops.aten.abs.default,
166166
torch.ops.aten.ceil.default,
167+
torch.ops.aten.erf.default,
167168
torch.ops.aten.exp.default,
168169
torch.ops.aten.floor.default,
169170
torch.ops.aten.log.default,

backends/arm/test/ops/test_erf.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
from executorch.backends.arm.test import common
10+
from executorch.backends.arm.test.tester.test_pipeline import (
11+
EthosU55PipelineBI,
12+
EthosU85PipelineBI,
13+
TosaPipelineBI,
14+
TosaPipelineMI,
15+
)
16+
17+
aten_op = "torch.ops.aten.erf.default"
18+
exir_op = "executorch_exir_dialects_edge__ops_aten_erf_default"
19+
input_t1 = Tuple[torch.Tensor] # Input x
20+
21+
22+
class Erf(torch.nn.Module):
23+
def forward(self, x: torch.Tensor):
24+
return torch.erf(x)
25+
26+
test_data: dict[str, input_t1] = {
27+
"zeros": (torch.zeros(1, 10, 10, 10),),
28+
"ones": (torch.ones(10, 10, 10),),
29+
"rand": ((torch.rand(10, 10) - 0.5),),
30+
"randn_pos": ((torch.randn(1, 4, 4, 4) + 10),),
31+
"randn_neg": ((torch.randn(1, 4, 4, 4) - 10),),
32+
"ramp": (torch.arange(-16, 16, 0.2),),
33+
}
34+
35+
36+
@common.parametrize("test_data", Erf.test_data)
37+
def test_erf_tosa_MI(test_data: input_t1):
38+
pipeline = TosaPipelineMI[input_t1](Erf(), test_data, aten_op, exir_op)
39+
pipeline.run()
40+
41+
42+
@common.parametrize("test_data", Erf.test_data)
43+
def test_erf_tosa_BI(test_data: input_t1):
44+
pipeline = TosaPipelineBI[input_t1](Erf(), test_data, aten_op, exir_op)
45+
pipeline.run()
46+
47+
48+
@common.parametrize("test_data", Erf.test_data)
49+
@common.XfailIfNoCorstone300
50+
def test_erf_u55_BI(test_data: input_t1):
51+
pipeline = EthosU55PipelineBI[input_t1](
52+
Erf(), test_data, aten_op, exir_op, run_on_fvp=True
53+
)
54+
pipeline.run()
55+
56+
57+
@common.parametrize("test_data", Erf.test_data)
58+
@common.XfailIfNoCorstone320
59+
def test_erf_u85_BI(test_data: input_t1):
60+
pipeline = EthosU85PipelineBI[input_t1](
61+
Erf(), test_data, aten_op, exir_op, run_on_fvp=True
62+
)
63+
pipeline.run()

0 commit comments

Comments
 (0)