Skip to content

Commit a0ca486

Browse files
authored
Arm backend: Update node visitors to support TOSA 1.0 (#10390)
### Summary Updates node visitors to support TOSA 1.0 specification. ### Test plan Tested through public and internal CI. Signed-off-by: Per Åstrand <[email protected]>
1 parent 450d008 commit a0ca486

12 files changed

+602
-66
lines changed

backends/arm/operator_support/to_copy_support.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ def is_node_tosa_supported(
7777
) -> bool:
7878
assert node.target in self.targets
7979

80-
assert tosa_spec.support_integer()
8180
supported_dtypes = (
8281
self.ALL_SUPPORTED_TYPES
8382
if tosa_spec.support_float()

backends/arm/operators/op_amax.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5-
from typing import List
5+
from typing import Any, List
66

7-
import tosa_tools.v0_80.serializer.tosa_serializer as ts
87
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
98
from executorch.backends.arm.operators.node_visitor import (
109
NodeVisitor,
@@ -15,19 +14,22 @@
1514

1615

1716
@register_node_visitor
18-
class MaxVisitor(NodeVisitor):
17+
class MaxVisitor_0_80(NodeVisitor):
1918
target = "aten.amax.default"
2019

20+
tosa_specs = NodeVisitor.tosa_specs_0_80
21+
2122
def __init__(self, *args):
2223
super().__init__(*args)
2324

2425
def define_node(
2526
self,
2627
node: Node,
27-
tosa_graph: ts.TosaSerializer,
28+
tosa_graph: Any,
2829
inputs: List[TosaArg],
2930
output: TosaArg,
3031
) -> None:
32+
import tosa_tools.v0_80.serializer.tosa_serializer as ts
3133

3234
input = inputs[0]
3335
dim = inputs[1].number
@@ -49,3 +51,42 @@ def define_node(
4951
tosa_graph.addOperator(
5052
ts.TosaOp.Op().REDUCE_MAX, [input.name], [output.name], attr
5153
)
54+
55+
56+
@register_node_visitor
57+
class MaxVisitor(NodeVisitor):
58+
target = "aten.amax.default"
59+
60+
tosa_specs = NodeVisitor.tosa_specs_1_00
61+
62+
def __init__(self, *args):
63+
super().__init__(*args)
64+
65+
def define_node(
66+
self,
67+
node: Node,
68+
tosa_graph: Any,
69+
inputs: List[TosaArg],
70+
output: TosaArg,
71+
) -> None:
72+
import serializer.tosa_serializer as ts
73+
74+
input = inputs[0]
75+
dim = inputs[1].number
76+
77+
if dim < 0:
78+
tensor = get_first_fake_tensor(node)
79+
rank = len(tensor.size())
80+
dim = rank + dim
81+
82+
keep_dims = inputs[2].number
83+
if not keep_dims:
84+
raise RuntimeError(
85+
"TOSA only supports keepdims == True; Did you run the convert_minmax pass?"
86+
)
87+
88+
attr = ts.TosaSerializerAttribute()
89+
attr.ReduceMaxAttribute(axis=input.dim_order.index(dim), nan_mode=1)
90+
tosa_graph.addOperator(
91+
ts.TosaOp.Op().REDUCE_MAX, [input.name], [output.name], attr
92+
)

backends/arm/operators/op_amin.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5-
from typing import List
5+
from typing import Any, List
66

7-
import tosa_tools.v0_80.serializer.tosa_serializer as ts
87
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
98
from executorch.backends.arm.operators.node_visitor import (
109
NodeVisitor,
@@ -15,19 +14,22 @@
1514

1615

1716
@register_node_visitor
18-
class MinVisitor(NodeVisitor):
17+
class MinVisitor_0_80(NodeVisitor):
1918
target = "aten.amin.default"
2019

20+
tosa_specs = NodeVisitor.tosa_specs_0_80
21+
2122
def __init__(self, *args):
2223
super().__init__(*args)
2324

2425
def define_node(
2526
self,
2627
node: Node,
27-
tosa_graph: ts.TosaSerializer,
28+
tosa_graph: Any,
2829
inputs: List[TosaArg],
2930
output: TosaArg,
3031
) -> None:
32+
import tosa_tools.v0_80.serializer.tosa_serializer as ts
3133

3234
input = inputs[0]
3335
dim = inputs[1].number
@@ -49,3 +51,42 @@ def define_node(
4951
tosa_graph.addOperator(
5052
ts.TosaOp.Op().REDUCE_MIN, [input.name], [output.name], attr
5153
)
54+
55+
56+
@register_node_visitor
57+
class MinVisitor(NodeVisitor):
58+
target = "aten.amin.default"
59+
60+
tosa_specs = NodeVisitor.tosa_specs_1_00
61+
62+
def __init__(self, *args):
63+
super().__init__(*args)
64+
65+
def define_node(
66+
self,
67+
node: Node,
68+
tosa_graph: Any,
69+
inputs: List[TosaArg],
70+
output: TosaArg,
71+
) -> None:
72+
import serializer.tosa_serializer as ts
73+
74+
input = inputs[0]
75+
dim = inputs[1].number
76+
77+
if dim < 0:
78+
tensor = get_first_fake_tensor(node)
79+
rank = len(tensor.size())
80+
dim = rank + dim
81+
82+
keep_dims = inputs[2].number
83+
if not keep_dims:
84+
raise RuntimeError(
85+
"TOSA only supports keepdims == True; Did you run the convert_minmax pass?"
86+
)
87+
88+
attr = ts.TosaSerializerAttribute()
89+
attr.ReduceMinAttribute(axis=input.dim_order.index(dim), nan_mode=1)
90+
tosa_graph.addOperator(
91+
ts.TosaOp.Op().REDUCE_MIN, [input.name], [output.name], attr
92+
)

backends/arm/operators/op_clamp.py

Lines changed: 123 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88

99
from typing import Any, List, Tuple
1010

11+
import numpy as np
1112
import torch
1213

13-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1414
from executorch.backends.arm.operators.node_visitor import (
1515
NodeVisitor,
1616
register_node_visitor,
@@ -34,14 +34,16 @@ def __init__(self, *args):
3434

3535
def _create_clamp_node(
3636
self,
37-
tosa_graph: ts.TosaSerializer,
37+
tosa_graph: Any,
3838
input_name: str,
3939
output_name: str,
4040
min_int: int,
4141
max_int: int,
4242
min_fp32: float,
4343
max_fp32: float,
4444
) -> None:
45+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
46+
4547
attr = ts.TosaSerializerAttribute()
4648
attr.ClampAttribute(
4749
tosa_graph.builder,
@@ -81,7 +83,7 @@ def cast_type(value: Any) -> int | float:
8183
def define_node(
8284
self,
8385
node: Node,
84-
tosa_graph: ts.TosaSerializer,
86+
tosa_graph: Any,
8587
inputs: List[TosaArg],
8688
output: TosaArg,
8789
) -> None:
@@ -122,10 +124,12 @@ def __init__(self, *args):
122124
def define_node(
123125
self,
124126
node: Node,
125-
tosa_graph: ts.TosaSerializer,
127+
tosa_graph: Any,
126128
inputs: List[TosaArg],
127129
output: TosaArg,
128130
) -> None:
131+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
132+
129133
if len(node.all_input_nodes) != 1:
130134
raise ValueError(
131135
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
@@ -150,3 +154,118 @@ def define_node(
150154
min_fp32,
151155
max_fp32,
152156
)
157+
158+
159+
@register_node_visitor
160+
class ClampVisitor_INT(NodeVisitor):
161+
target = "aten.clamp.default"
162+
163+
tosa_specs = [
164+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
165+
]
166+
167+
def __init__(self, *args):
168+
super().__init__(*args)
169+
170+
def _get_min_max_arguments(
171+
self, node: Node, dtype_min: int | float, dtype_max: int | float
172+
) -> Tuple[int | float, int | float]:
173+
174+
def cast_type(value: Any) -> int | float:
175+
if isinstance(value, int):
176+
return value
177+
else:
178+
# Attempt to cast to float
179+
return float(value)
180+
181+
if len(node.args) != 2 and len(node.args) != 3:
182+
raise ValueError(f"Expected len(node.args) to be 2 or 3, got {node.args}")
183+
184+
min_arg = dtype_min
185+
max_arg = dtype_max
186+
187+
if node.args[1] is not None:
188+
min_arg = cast_type(node.args[1])
189+
190+
if len(node.args) > 2:
191+
if node.args[2] is not None:
192+
max_arg = cast_type(node.args[2])
193+
194+
return min_arg, max_arg
195+
196+
def define_node(
197+
self,
198+
node: Node,
199+
tosa_graph: Any,
200+
inputs: List[TosaArg],
201+
output: TosaArg,
202+
) -> None:
203+
import serializer.tosa_serializer as ts # type: ignore
204+
205+
if len(node.all_input_nodes) != 1:
206+
raise ValueError(
207+
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
208+
)
209+
210+
# NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
211+
min_int8, max_int8 = self._get_min_max_arguments(
212+
node,
213+
torch.iinfo(torch.int8).min,
214+
torch.iinfo(torch.int8).max,
215+
)
216+
217+
attr = ts.TosaSerializerAttribute()
218+
attr.ClampAttribute(
219+
tosa_graph.builder,
220+
np.int8(min_int8).tobytes(),
221+
np.int8(max_int8).tobytes(),
222+
nan_mode=1,
223+
)
224+
225+
tosa_graph.addOperator(
226+
ts.TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr
227+
)
228+
229+
230+
@register_node_visitor
231+
class ClampVisitor_FP(ClampVisitor_INT):
232+
# inheriting 'target' from INT class
233+
234+
tosa_specs = [
235+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
236+
]
237+
238+
def __init__(self, *args):
239+
super().__init__(*args)
240+
241+
def define_node(
242+
self,
243+
node: Node,
244+
tosa_graph: Any,
245+
inputs: List[TosaArg],
246+
output: TosaArg,
247+
) -> None:
248+
import serializer.tosa_serializer as ts # type: ignore
249+
250+
if len(node.all_input_nodes) != 1:
251+
raise ValueError(
252+
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
253+
)
254+
255+
min_fp32, max_fp32 = self._get_min_max_arguments(
256+
node,
257+
torch.finfo(torch.float32).min,
258+
torch.finfo(torch.float32).max,
259+
)
260+
261+
attr = ts.TosaSerializerAttribute()
262+
attr.ClampAttribute(
263+
tosa_graph.builder,
264+
np.float32(min_fp32).tobytes(),
265+
np.float32(max_fp32).tobytes(),
266+
nan_mode=1,
267+
)
268+
269+
tosa_graph.addOperator(
270+
ts.TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr
271+
)

0 commit comments

Comments
 (0)