Skip to content

Commit 0135b99

Browse files
Merge branch 'main' into sigmoid_flaky_fix
2 parents 5137ab7 + e88aafc commit 0135b99

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+2390
-410
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@
5959
)
6060

6161
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
62+
from executorch.backends.transforms.decompose_sdpa import (
63+
DecomposeScaledDotProductAttention,
64+
)
6265
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
6366
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
6467
from executorch.exir import ExportedProgram
@@ -194,6 +197,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
194197
)
195198

196199
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
200+
self.add_pass(DecomposeScaledDotProductAttention())
197201
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
198202
self.add_pass(ScalarsToAttributePass())
199203
self.add_pass(DecomposeLayerNormPass())

backends/arm/_passes/decompose_softmax_pass.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
from executorch.exir.pass_base import ExportPass
99

1010
# For BI case
11-
torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int)
11+
torch_softmax = (
12+
torch.ops.aten.softmax.int,
13+
torch.ops.aten._safe_softmax.default,
14+
torch.ops.aten.log_softmax.int,
15+
)
1216
# For MI case
1317
edge_softmax = (
1418
exir_ops.edge.aten._softmax.default,

backends/arm/operator_support/convolution_support.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
register_tosa_support_check,
1212
SupportedTOSAOperatorCheck,
1313
)
14-
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
14+
from executorch.backends.arm.tosa_specification import (
15+
Tosa_0_80,
16+
Tosa_1_00,
17+
TosaSpecification,
18+
)
1519
from executorch.exir.dialects._ops import ops as exir_ops
1620

1721

@@ -43,6 +47,9 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
4347

4448
# Hardware specific constraints
4549
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
50+
# TODO remove this once TOSA 1.0 support for u55 is added.
51+
if isinstance(tosa_spec, Tosa_1_00) and "u55" in tosa_spec.extensions:
52+
return False
4653
return True
4754
else:
4855
return self._is_node_supported_u55(node)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def is_node_supported(
194194
exir_ops.edge.aten.mul.Tensor,
195195
exir_ops.edge.aten.ne.Tensor,
196196
exir_ops.edge.aten.ne.Scalar,
197+
exir_ops.edge.aten.neg.default,
197198
exir_ops.edge.aten.add.Scalar,
198199
exir_ops.edge.aten.sub.Scalar,
199200
exir_ops.edge.aten.mul.Scalar,
@@ -311,6 +312,7 @@ class CheckProperQuantization(OperatorSupportBase):
311312
exir_ops.edge.aten.max_pool2d_with_indices.default,
312313
exir_ops.edge.aten.mm.default,
313314
exir_ops.edge.aten.mul.Tensor,
315+
exir_ops.edge.aten.neg.default,
314316
exir_ops.edge.aten.relu.default,
315317
exir_ops.edge.aten.sub.Tensor,
316318
exir_ops.edge.aten.upsample_bilinear2d.vec,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
op_maximum,
3232
op_minimum,
3333
op_mul,
34+
op_neg,
3435
op_permute,
3536
op_pow,
3637
op_reciprocal,

backends/arm/operators/op_abs.py

Lines changed: 129 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import List
7+
from typing import Any, List
88

99
import executorch.backends.arm.tosa_quant_utils as tqutils
1010
import executorch.backends.arm.tosa_utils as tutils
1111

12-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1312
from executorch.backends.arm.operators.node_visitor import (
1413
NodeVisitor,
1514
register_node_visitor,
@@ -33,10 +32,13 @@ def __init__(self, *args):
3332
def define_node(
3433
self,
3534
node: Node,
36-
tosa_graph: ts.TosaSerializer,
35+
tosa_graph: Any,
3736
inputs: List[TosaArg],
3837
output: TosaArg,
3938
) -> None:
39+
40+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
41+
4042
# Specification (0.80) states that input and output types
4143
# should all be the same
4244
if not (inputs[0].dtype == output.dtype):
@@ -53,7 +55,7 @@ def define_node(
5355
if inputs[0].dtype == ts.DType.INT8:
5456
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
5557
tosa_graph, inputs, node
56-
)
58+
) # type: ignore[possibly-undefined]
5759
else:
5860
# input[0].dtype == ts.DType.INT32
5961
# Non quantized input, natively support by TOSA.abs
@@ -96,10 +98,13 @@ def __init__(self, *args):
9698
def define_node(
9799
self,
98100
node: Node,
99-
tosa_graph: ts.TosaSerializer,
101+
tosa_graph: Any,
100102
inputs: List[TosaArg],
101103
output: TosaArg,
102104
) -> None:
105+
106+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
107+
103108
# Specification (0.80) states that input and output types
104109
# should all be the same
105110
if not (inputs[0].dtype == output.dtype):
@@ -129,3 +134,122 @@ def define_node(
129134
[output.name],
130135
None,
131136
)
137+
138+
139+
@register_node_visitor
140+
class AbsVisitor_INT(NodeVisitor):
141+
target = "aten.abs.default"
142+
143+
tosa_specs = [
144+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
145+
]
146+
147+
def __init__(self, *args):
148+
super().__init__(*args)
149+
150+
def define_node(
151+
self,
152+
node: Node,
153+
tosa_graph: Any,
154+
inputs: List[TosaArg],
155+
output: TosaArg,
156+
) -> None:
157+
158+
import serializer.tosa_serializer as ts # type: ignore
159+
160+
# Specification (1.0) states that input and output types
161+
# should all be the same
162+
if not (inputs[0].dtype == output.dtype):
163+
raise ValueError(
164+
"All inputs and outputs need same dtype."
165+
f"Got {inputs[0].dtype=}, {output.dtype=}"
166+
)
167+
# Handle int8 (quantized) and int32
168+
if not (inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]):
169+
raise ValueError(
170+
"All inputs need to be INT8 or INT32." f"Got {inputs[0].dtype=}"
171+
)
172+
173+
scale_back = 1.0
174+
if inputs[0].dtype == ts.DType.INT8:
175+
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
176+
tosa_graph, inputs, node, self.tosa_specs
177+
) # type: ignore[possibly-undefined]
178+
else:
179+
# input[0].dtype == ts.DType.INT32
180+
# Non quantized input, natively support by TOSA.abs
181+
rescaled_inputs = inputs
182+
183+
if output.dtype == ts.DType.INT8:
184+
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
185+
abs_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
186+
else:
187+
# output.dtype == ts.DType.INT32
188+
abs_output = output
189+
190+
# Do the INT32 Abs
191+
tosa_graph.addOperator(
192+
ts.TosaOp.Op().ABS,
193+
[
194+
rescaled_inputs[0].name,
195+
],
196+
[abs_output.name],
197+
None,
198+
)
199+
200+
if output.dtype == ts.DType.INT8:
201+
# Scale output back to 8 bit
202+
# pyre-ignore
203+
tqutils.insert_rescale_op_to_int8(
204+
tosa_graph, abs_output, scale_back, node, self.tosa_specs
205+
) # type: ignore[possibly-undefined]
206+
207+
208+
@register_node_visitor
209+
class AbsVisitor_FP(AbsVisitor_INT):
210+
# inheriting 'target' from BI class
211+
212+
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
213+
214+
def __init__(self, *args):
215+
super().__init__(*args)
216+
217+
def define_node(
218+
self,
219+
node: Node,
220+
tosa_graph: Any,
221+
inputs: List[TosaArg],
222+
output: TosaArg,
223+
) -> None:
224+
225+
import serializer.tosa_serializer as ts # type: ignore
226+
227+
# Specification (1.0) states that input and output types
228+
# should all be the same
229+
if not (inputs[0].dtype == output.dtype):
230+
raise ValueError(
231+
"All inputs and output need same dtype."
232+
f"Got {inputs[0].dtype=}, {output.dtype=}"
233+
)
234+
235+
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
236+
# Call the inherited define_node for handling integers
237+
super().define_node(node, tosa_graph, inputs, output)
238+
else:
239+
# FP32 Abs lowering
240+
241+
if not (inputs[0].dtype == ts.DType.FP32):
242+
raise ValueError(
243+
"All inputs need to be FP32." f"Got {inputs[0].dtype=}"
244+
)
245+
246+
if not (output.dtype == ts.DType.FP32):
247+
raise ValueError("All outputs need to be FP32." f"Got {output.dtype=}")
248+
249+
# MI lowering
250+
tosa_graph.addOperator(
251+
ts.TosaOp.Op().ABS,
252+
[inputs[0].name],
253+
[output.name],
254+
None,
255+
)

0 commit comments

Comments
 (0)