Skip to content

Commit 3ca6cbb

Browse files
SaoirseARM and per authored
Arm backend: Update Rescale and affected nodes to support TOSA 1.0 (#10656)
### Summary
Updates Rescale to support the TOSA 1.0 specification, and updates the node visitors affected by that change.

### Test plan
Tested through public and internal CI.

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218

---------

Signed-off-by: Per Åstrand <[email protected]>
Co-authored-by: Per Åstrand <[email protected]>
1 parent d7030aa commit 3ca6cbb

18 files changed

+1811
-172
lines changed

backends/arm/operator_support/convolution_support.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
register_tosa_support_check,
1212
SupportedTOSAOperatorCheck,
1313
)
14-
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
14+
from executorch.backends.arm.tosa_specification import (
15+
Tosa_0_80,
16+
Tosa_1_00,
17+
TosaSpecification,
18+
)
1519
from executorch.exir.dialects._ops import ops as exir_ops
1620

1721

@@ -43,6 +47,9 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
4347

4448
# Hardware specific constraints
4549
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
50+
# TODO remove this once TOSA 1.0 support for u55 is added.
51+
if isinstance(tosa_spec, Tosa_1_00) and "u55" in tosa_spec.extensions:
52+
return False
4653
return True
4754
else:
4855
return self._is_node_supported_u55(node)

backends/arm/operators/op_abs.py

+129-5
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import List
7+
from typing import Any, List
88

99
import executorch.backends.arm.tosa_quant_utils as tqutils
1010
import executorch.backends.arm.tosa_utils as tutils
1111

12-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1312
from executorch.backends.arm.operators.node_visitor import (
1413
NodeVisitor,
1514
register_node_visitor,
@@ -33,10 +32,13 @@ def __init__(self, *args):
3332
def define_node(
3433
self,
3534
node: Node,
36-
tosa_graph: ts.TosaSerializer,
35+
tosa_graph: Any,
3736
inputs: List[TosaArg],
3837
output: TosaArg,
3938
) -> None:
39+
40+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
41+
4042
# Specification (0.80) states that input and output types
4143
# should all be the same
4244
if not (inputs[0].dtype == output.dtype):
@@ -53,7 +55,7 @@ def define_node(
5355
if inputs[0].dtype == ts.DType.INT8:
5456
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
5557
tosa_graph, inputs, node
56-
)
58+
) # type: ignore[possibly-undefined]
5759
else:
5860
# input[0].dtype == ts.DType.INT32
5961
# Non quantized input, natively support by TOSA.abs
@@ -96,10 +98,13 @@ def __init__(self, *args):
9698
def define_node(
9799
self,
98100
node: Node,
99-
tosa_graph: ts.TosaSerializer,
101+
tosa_graph: Any,
100102
inputs: List[TosaArg],
101103
output: TosaArg,
102104
) -> None:
105+
106+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
107+
103108
# Specification (0.80) states that input and output types
104109
# should all be the same
105110
if not (inputs[0].dtype == output.dtype):
@@ -129,3 +134,122 @@ def define_node(
129134
[output.name],
130135
None,
131136
)
137+
138+
139+
@register_node_visitor
class AbsVisitor_INT(NodeVisitor):
    """Lower ``aten.abs.default`` to a TOSA ``ABS`` operator (TOSA-1.0+INT).

    INT32 tensors are fed to ``ABS`` directly. INT8 tensors are first
    rescaled to INT32 with ``tqutils.insert_rescale_ops_to_int32``, the
    ``ABS`` is written to an INT32 intermediate, and the result is rescaled
    back to INT8 with ``tqutils.insert_rescale_op_to_int8``.
    """

    target = "aten.abs.default"

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-1.0+INT"),
    ]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: Any,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Emit the TOSA operators implementing ``abs`` for ``node``.

        Args:
            node: The FX node being lowered; forwarded to the rescale helpers.
            tosa_graph: Serializer graph that receives the new operators.
            inputs: Operands of the node; only ``inputs[0]`` is used.
            output: Destination tensor of the node.

        Raises:
            ValueError: If the input and output dtypes differ, or if the
                input dtype is neither INT8 nor INT32.
        """
        # Imported inside the method: the 0.80 visitors bind a different
        # serializer module to the same ``ts`` alias.
        import serializer.tosa_serializer as ts  # type: ignore

        # Specification (1.0) states that input and output types
        # should all be the same
        if inputs[0].dtype != output.dtype:
            raise ValueError(
                "All inputs and outputs need same dtype. "
                f"Got {inputs[0].dtype=}, {output.dtype=}"
            )
        # Handle int8 (quantized) and int32
        if inputs[0].dtype not in [ts.DType.INT8, ts.DType.INT32]:
            raise ValueError(
                f"All inputs need to be INT8 or INT32. Got {inputs[0].dtype=}"
            )

        # Default for the INT32 path, where no rescale-back is emitted.
        scale_back = 1.0
        if inputs[0].dtype == ts.DType.INT8:
            rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
                tosa_graph, inputs, node, self.tosa_specs
            )  # type: ignore[possibly-undefined]
        else:
            # inputs[0].dtype == ts.DType.INT32
            # Non-quantized input, natively supported by TOSA ABS
            rescaled_inputs = inputs

        if output.dtype == ts.DType.INT8:
            # ABS runs in INT32, so write to an INT32 intermediate that is
            # rescaled into the real INT8 output afterwards.
            broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
            abs_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
        else:
            # output.dtype == ts.DType.INT32
            abs_output = output

        # Do the INT32 Abs
        tosa_graph.addOperator(
            ts.TosaOp.Op().ABS,
            [
                rescaled_inputs[0].name,
            ],
            [abs_output.name],
            None,
        )

        if output.dtype == ts.DType.INT8:
            # Scale output back to 8 bit
            # pyre-ignore
            tqutils.insert_rescale_op_to_int8(
                tosa_graph, abs_output, scale_back, node, self.tosa_specs
            )  # type: ignore[possibly-undefined]
206+
207+
208+
@register_node_visitor
class AbsVisitor_FP(AbsVisitor_INT):
    """Lower ``aten.abs.default`` to a TOSA ``ABS`` operator (TOSA-1.0+FP).

    FP32 tensors are lowered to a single ``ABS`` operator. INT8 and INT32
    tensors are delegated to ``AbsVisitor_INT.define_node``.
    """

    # inheriting 'target' from INT class

    tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: Any,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Emit the TOSA operators implementing ``abs`` for ``node``.

        Args:
            node: The FX node being lowered.
            tosa_graph: Serializer graph that receives the new operators.
            inputs: Operands of the node; only ``inputs[0]`` is used.
            output: Destination tensor of the node.

        Raises:
            ValueError: If the input and output dtypes differ, or if the
                dtype is none of INT8, INT32 or FP32.
        """
        # Imported inside the method: the 0.80 visitors bind a different
        # serializer module to the same ``ts`` alias.
        import serializer.tosa_serializer as ts  # type: ignore

        # Specification (1.0) states that input and output types
        # should all be the same
        if inputs[0].dtype != output.dtype:
            raise ValueError(
                "All inputs and output need same dtype. "
                f"Got {inputs[0].dtype=}, {output.dtype=}"
            )

        if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
            # Call the inherited define_node for handling integers
            # NOTE(review): the inherited path forwards self.tosa_specs, which
            # on this class is the FP spec list - confirm the rescale helpers
            # expect that.
            super().define_node(node, tosa_graph, inputs, output)
            return

        # FP32 Abs lowering
        if inputs[0].dtype != ts.DType.FP32:
            raise ValueError(f"All inputs need to be FP32. Got {inputs[0].dtype=}")

        # No separate output check is needed: the output dtype was already
        # verified to equal the input dtype, which is FP32 at this point.

        # FP lowering
        tosa_graph.addOperator(
            ts.TosaOp.Op().ABS,
            [inputs[0].name],
            [output.name],
            None,
        )

backends/arm/operators/op_add.py

+133-7
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55

66
# pyre-unsafe
77

8-
from typing import List
8+
from typing import Any, List
99

1010
import executorch.backends.arm.tosa_quant_utils as tqutils
1111
import executorch.backends.arm.tosa_utils as tutils
1212

13-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1413
from executorch.backends.arm.operators.node_visitor import (
1514
NodeVisitor,
1615
register_node_visitor,
@@ -34,10 +33,13 @@ def __init__(self, *args):
3433
def define_node(
3534
self,
3635
node: Node,
37-
tosa_graph: ts.TosaSerializer,
36+
tosa_graph: Any,
3837
inputs: List[TosaArg],
3938
output: TosaArg,
4039
) -> None:
40+
41+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
42+
4143
# Specification (0.80) states that input and output types
4244
# should all be the same
4345
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
@@ -58,7 +60,7 @@ def define_node(
5860
if len(inputs[0].shape) > len(inputs[1].shape)
5961
else inputs[1].dim_order
6062
)
61-
63+
scale_back = 1.0
6264
if inputs[0].dtype == ts.DType.INT8:
6365
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
6466
tosa_graph, inputs, node
@@ -90,7 +92,9 @@ def define_node(
9092
if output.dtype == ts.DType.INT8:
9193
# Scale output back to 8 bit
9294
# pyre-ignore
93-
tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node) # type: ignore[possibly-undefined]
95+
tqutils.insert_rescale_op_to_int8(
96+
tosa_graph, add_output, scale_back, node
97+
) # type: ignore[possibly-undefined]
9498

9599

96100
@register_node_visitor
@@ -107,10 +111,13 @@ def __init__(self, *args):
107111
def define_node(
108112
self,
109113
node: Node,
110-
tosa_graph: ts.TosaSerializer,
114+
tosa_graph: Any,
111115
inputs: List[TosaArg],
112116
output: TosaArg,
113117
) -> None:
118+
119+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
120+
114121
# Specification (0.80) states that input and output types
115122
# should all be the same
116123
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
@@ -130,7 +137,7 @@ def define_node(
130137
f"Expected IO data type to be FP32, got {inputs[0].dtype}"
131138
)
132139

133-
input1, input2 = tutils.reshape_for_broadcast(tosa_graph, inputs)
140+
input1, input2 = inputs
134141

135142
# MI lowering
136143
tosa_graph.addOperator(
@@ -139,3 +146,122 @@ def define_node(
139146
[output.name],
140147
None,
141148
)
149+
150+
151+
@register_node_visitor
class AddVisitor_INT(NodeVisitor):
    """Lower ``aten.add.Tensor`` to a TOSA ``ADD`` operator (TOSA-1.0+INT).

    INT32 operands go to ``ADD`` directly. INT8 operands are rescaled to
    INT32 first, added into an INT32 intermediate, and the sum is rescaled
    back to INT8.
    """

    target = "aten.add.Tensor"

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-1.0+INT"),
    ]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: Any,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Emit the TOSA operators implementing ``add`` for ``node``.

        Args:
            node: The FX node being lowered; forwarded to the rescale helpers.
            tosa_graph: Serializer graph that receives the new operators.
            inputs: The two operands of the addition.
            output: Destination tensor of the node.

        Raises:
            TypeError: If the operand and output dtypes are not all equal, or
                if the dtype is neither INT8 nor INT32.
        """
        import serializer.tosa_serializer as ts  # type: ignore

        # Specification (1.0) requires one common dtype for all inputs and
        # the output.
        dtypes_agree = (
            inputs[0].dtype == inputs[1].dtype and inputs[0].dtype == output.dtype
        )
        if not dtypes_agree:
            raise TypeError(
                f"All IO needs to have the same data type, got input 1: "
                f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
                f"{output.dtype}"
            )

        # Only int8 (quantized) and int32 are handled here.
        supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
        if inputs[0].dtype not in supported_dtypes:
            raise TypeError(
                f'IO data type needs to be {supported_dtypes}, got "{inputs[0].dtype}"'
            )

        # INT32 operands are natively supported by TOSA ADD and used as-is;
        # INT8 operands are widened to INT32 first.
        operands = inputs
        scale_back = 1.0
        if inputs[0].dtype == ts.DType.INT8:
            operands, scale_back = tqutils.insert_rescale_ops_to_int32(
                tosa_graph, inputs, node, self.tosa_specs
            )

        # An INT8 result is produced via an INT32 intermediate; an INT32
        # result is written straight to the output.
        add_output = output
        if output.dtype == ts.DType.INT8:
            intermediate_shape = tutils.tosa_shape(output.shape, output.dim_order)
            add_output = tosa_graph.addIntermediate(intermediate_shape, ts.DType.INT32)

        lhs, rhs = operands

        # The INT32 addition itself.
        tosa_graph.addOperator(
            ts.TosaOp.Op().ADD,
            [lhs.name, rhs.name],
            [add_output.name],
            None,
        )

        if output.dtype == ts.DType.INT8:
            # Bring the INT32 sum back down to 8 bit.
            # pyre-ignore
            tqutils.insert_rescale_op_to_int8(
                tosa_graph, add_output, scale_back, node, self.tosa_specs
            )  # type: ignore[possibly-undefined]
219+
220+
221+
@register_node_visitor
class AddVisitor_FP(AddVisitor_INT):
    """Lower ``aten.add.Tensor`` to a TOSA ``ADD`` operator (TOSA-1.0+FP).

    FP32 operands are lowered to a single ``ADD`` operator. INT8 and INT32
    operands are delegated to ``AddVisitor_INT.define_node``.
    """

    # 'target' is inherited from the INT class

    tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: Any,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Emit the TOSA operators implementing ``add`` for ``node``.

        Args:
            node: The FX node being lowered.
            tosa_graph: Serializer graph that receives the new operators.
            inputs: The two operands of the addition.
            output: Destination tensor of the node.

        Raises:
            TypeError: If the operand and output dtypes are not all equal, or
                if the dtype is none of INT8, INT32 or FP32.
        """
        import serializer.tosa_serializer as ts  # type: ignore

        # Specification (1.0) requires one common dtype for all inputs and
        # the output.
        dtypes_agree = (
            inputs[0].dtype == inputs[1].dtype and inputs[0].dtype == output.dtype
        )
        if not dtypes_agree:
            raise TypeError(
                f"All IO needs to have the same data type, got input 1: "
                f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
                f"{output.dtype}"
            )

        # Integer tensors are handled entirely by the inherited lowering.
        if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
            super().define_node(node, tosa_graph, inputs, output)
            return

        # Anything left must be FP32.
        if inputs[0].dtype != ts.DType.FP32:
            raise TypeError(
                f"Expected IO data type to be FP32, got {inputs[0].dtype}"
            )

        lhs, rhs = inputs

        # The FP32 addition.
        tosa_graph.addOperator(
            ts.TosaOp.Op().ADD,
            [lhs.name, rhs.name],
            [output.name],
            None,
        )

0 commit comments

Comments
 (0)