Skip to content

Commit 40d6a4a

Browse files
committed
Update
[ghstack-poisoned]
2 parents acd2079 + da36d8a commit 40d6a4a

File tree

101 files changed

+3344
-613
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

101 files changed

+3344
-613
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ It supports a wide range of models including LLMs (Large Language Models), CV (C
1919
Platform Support:
2020
- Operating Systems:
2121
- iOS
22-
- Mac
22+
- MacOS (ARM64)
2323
- Android
2424
- Linux
2525
- Microcontrollers

backends/arm/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ if(NOT EXECUTORCH_ROOT)
1212
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
1313
endif()
1414

15+
add_compile_options("-Wall" "-Werror")
16+
1517
include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
1618

1719
set(_common_include_directories ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10)

backends/arm/_passes/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55

66

77
from . import arm_pass_utils # noqa
8+
from .arm_pass import ArmPass # noqa # usort: skip
9+
from .add_bias_pass import AddBiasPass # noqa
810
from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder # noqa
911
from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
10-
from .arm_pass import ArmPass # noqa
1112
from .broadcast_args_pass import BroadcastArgsPass # noqa
13+
from .cast_bool_to_int8_pass import CastBoolToInt8Pass # noqa
1214
from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
1315
from .cast_to_int32_pass import CastToInt32Pass # noqa
1416
from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa
@@ -20,10 +22,12 @@
2022
from .convert_split_to_slice import ConvertSplitToSlicePass # noqa
2123
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
2224
from .convert_to_clamp import ConvertToClampPass # noqa
25+
from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa
2326
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
2427
from .decompose_div_pass import DecomposeDivPass # noqa
2528
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa
2629
from .decompose_gelu_pass import DecomposeGeluPass # noqa
30+
from .decompose_grouped_conv import DecomposeGroupedConv # noqa
2731
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
2832
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
2933
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa
@@ -32,6 +36,7 @@
3236
from .decompose_maxpool2d_with_dilation import DecomposeMaxPool2DPass # noqa
3337
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
3438
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
39+
from .decompose_round_pass import DecomposeRoundPass # noqa
3540
from .decompose_select import DecomposeSelectPass # noqa
3641
from .decompose_silu_pass import DecomposeSiluPass # noqa
3742
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.backends.arm._passes import ArmPass
8+
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
9+
from executorch.backends.transforms.utils import create_constant_placeholder
10+
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.pass_base import PassResult
13+
from torch.export.graph_signature import InputKind
14+
15+
16+
class AddBiasPass(ArmPass):
    """Give every bias-less convolution an explicit all-zero bias.

    TOSA requires convolution nodes to have a bias input, so each targeted
    convolution with fewer than three input nodes gets a new constant
    placeholder wired in as its third argument. The placeholder holds one
    zero per output channel: int32 when the node carries non-empty
    ``output_qparams`` metadata, float32 otherwise.
    """

    targeted_ops = (exir_ops.edge.aten.convolution.default,)

    def call(self, graph_module):
        """Add zero-bias placeholders where needed.

        Returns a ``PassResult`` whose flag reports whether any convolution
        was changed.
        """
        graph = graph_module.graph
        changed = False

        for conv in graph.nodes:
            is_target = (
                conv.op == "call_function" and conv.target in self.targeted_ops
            )
            # Three or more input nodes means a bias is already connected.
            if not is_target or len(conv.all_input_nodes) >= 3:
                continue

            changed = True
            weights = conv.all_input_nodes[1]
            # The leading weight dimension is used as the bias length.
            num_out_channels = get_first_fake_tensor(weights).shape[0]

            # Create a tensor containing zeros: int32 if the node is
            # quantized, float32 otherwise.
            is_quantized = len(conv.meta.get("output_qparams", ())) > 0
            zero_dtype = torch.int32 if is_quantized else torch.float32
            zero_bias = torch.zeros(size=(num_out_channels,), dtype=zero_dtype)

            with graph.inserting_after(weights):
                bias_placeholder = create_constant_placeholder(
                    self.exported_program,
                    graph=graph,
                    kind=InputKind.PARAMETER,
                    data=zero_bias,
                    persistent_buffer=True,
                    name=f"{conv.name}_bias",
                )
                conv.update_arg(2, bias_placeholder)

        if changed:
            # Hand the edited module to the parent pass implementation and
            # continue with the graph module it returns.
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, changed)

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,18 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
203203
- 1D/2D tensors
204204
"""
205205
for node in graph_module.graph.nodes:
206-
if node.op != "call_function":
206+
# call_function and placeholder allowed due to
207+
# index.Tensor being able to come in as both
208+
if node.op not in ["call_function", "placeholder"]:
207209
continue
208210

209-
elif node.target == exir_ops.edge.aten.view_copy.default:
211+
elif node.target in (
212+
exir_ops.edge.aten.view_copy.default,
213+
exir_ops.edge.aten.index.Tensor,
214+
):
215+
# For index.Tensor:
216+
# If we want to support 4D indexing tensors this logic
217+
# should be updated.
210218
input_node = node.args[0]
211219
input_shape = input_node.meta["val"].shape
212220
output_shape = node.meta["val"].shape

backends/arm/_passes/arm_pass_manager.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77

88
# pyre-unsafe
99
from executorch.backends.arm._passes import (
10+
AddBiasPass,
1011
AnnotateChannelsLastDimOrder,
1112
AnnotateDecomposedMatmulPass,
1213
BroadcastArgsPass,
14+
CastBoolToInt8Pass,
1315
CastInt64BuffersToInt32Pass,
1416
CastToInt32Pass,
1517
ComputeConstantOpsAOT,
@@ -23,10 +25,12 @@
2325
ConvertSplitToSlicePass,
2426
ConvertSqueezesToViewPass,
2527
ConvertToClampPass,
28+
DecomposeAvgPool2d,
2629
DecomposeCosineSimilarityPass,
2730
DecomposeDivPass,
2831
DecomposeEmbeddingPass,
2932
DecomposeGeluPass,
33+
DecomposeGroupedConv,
3034
DecomposeGroupNormPass,
3135
DecomposeLayerNormPass,
3236
DecomposeLeakyReLUPass,
@@ -35,6 +39,7 @@
3539
DecomposeMaxPool2DPass,
3640
DecomposeMeanDimPass,
3741
DecomposeNotEqualPass,
42+
DecomposeRoundPass,
3843
DecomposeSelectPass,
3944
DecomposeSiluPass,
4045
DecomposeSoftmaxPass,
@@ -63,7 +68,6 @@
6368
UnsqueezeBeforeRepeatPass,
6469
UnsqueezeScalarPlaceholdersPass,
6570
)
66-
6771
from executorch.backends.arm.tosa_specification import (
6872
TosaLoweringContext,
6973
TosaSpecification,
@@ -105,6 +109,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
105109
if self.tosa_spec.is_U55_subset:
106110
self.add_pass(CastToInt32Pass())
107111

112+
self.add_pass(CastBoolToInt8Pass())
108113
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
109114
self.add_pass(AnnotateDecomposedMatmulPass())
110115
self.add_pass(QuantizeOperatorArguments())
@@ -115,8 +120,10 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
115120
if self.tosa_spec.is_U55_subset:
116121
self.add_pass(BroadcastArgsPass())
117122
self.add_pass(DecomposeLinearPass())
123+
self.add_pass(DecomposeAvgPool2d())
118124
self.add_pass(ComputeConstantOpsAOT(exported_program))
119125

126+
self.add_pass(DecomposeGroupedConv())
120127
self.add_pass(RemoveClonePass())
121128
self.add_pass(SizeAdjustConv2DPass())
122129
self.add_pass(ConvertExpandCopyToRepeatPass())
@@ -130,6 +137,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
130137

131138
self.add_pass(FuseViewCopyTransform())
132139
self.add_pass(FuseConstantArgsPass(exported_program))
140+
self.add_pass(AddBiasPass(exported_program))
133141

134142
self.add_pass(InsertTableOpsPass(exported_program))
135143
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
@@ -139,8 +147,10 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
139147
return self._transform(exported_program.graph_module)
140148

141149
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
150+
self.add_pass(DecomposeRoundPass())
142151
self.add_pass(DecomposeSqrtPass())
143152
self.add_pass(ConvertIntPowToMuls())
153+
self.add_pass(CastBoolToInt8Pass())
144154
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
145155
self.add_pass(DecomposeEmbeddingPass())
146156
self.add_pass(FuseQuantizedActivationPass())
@@ -172,8 +182,10 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
172182
self.add_pass(RetraceFoldedDtypesPass())
173183
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
174184
self.add_pass(MatchArgRanksPass(exported_program))
185+
self.add_pass(DecomposeAvgPool2d())
175186
self.add_pass(ComputeConstantOpsAOT(exported_program))
176187

188+
self.add_pass(DecomposeGroupedConv())
177189
self.add_pass(RemoveClonePass())
178190
self.add_pass(SizeAdjustConv2DPass())
179191
self.add_pass(ConvertExpandCopyToRepeatPass())
@@ -187,6 +199,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
187199

188200
self.add_pass(FuseViewCopyTransform())
189201
self.add_pass(FuseConstantArgsPass(exported_program))
202+
self.add_pass(AddBiasPass(exported_program))
190203
self.add_pass(InsertTableOpsPass(exported_program))
191204
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
192205
self.add_pass(AnnotateChannelsLastDimOrder())
@@ -219,6 +232,8 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
219232
self.add_pass(InsertCastForOpsWithInt64InputPass())
220233
self.add_pass(DecomposeEmbeddingPass())
221234
self.add_pass(DecomposeScaledDotProductAttention())
235+
self.add_pass(DecomposeRoundPass())
236+
self.add_pass(CastBoolToInt8Pass())
222237
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
223238
self.add_pass(ScalarsToAttributePass())
224239
self.add_pass(DecomposeGroupNormPass())
@@ -232,6 +247,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
232247
self.add_pass(DecomposeLinearVectorNormPass())
233248
self.add_pass(DecomposeSqrtPass())
234249
self.add_pass(DecomposeSiluPass())
250+
self.add_pass(DecomposeAvgPool2d())
235251

236252
if self.tosa_spec.is_U55_subset:
237253
# Numerically stable softmax uses amax which is not supported on Ethos-U55
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# The TOSA BITWISE_AND, BITWISE_OR, and BITWISE_XOR don't handle bool as input
7+
# If input/output is bool, let's add a cast/conversion pass before/after to/from int8.
8+
9+
import torch
10+
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.pass_base import ExportPass
13+
14+
15+
class CastBoolToInt8Pass(ExportPass):
    """Route bool operands of bitwise and/or/xor through int8.

    For the targeted ops, every operand whose dtype is ``torch.bool`` is
    copied to int8 before the op is emitted. If at least one operand was
    converted, the op's result is copied back to the dtype of the first
    original operand. All other ops are passed through unchanged.
    """

    targeted_ops = {
        exir_ops.edge.aten.bitwise_and.Tensor,
        exir_ops.edge.aten.bitwise_or.Tensor,
        exir_ops.edge.aten.bitwise_xor.Tensor,
    }

    def call_operator(self, op, args, kwargs, meta):
        """Emit ``op``, wrapping it in int8 casts when it has bool operands."""
        emit = super().call_operator
        if op not in self.targeted_ops:
            return emit(op, args, kwargs, meta)

        operands: list = []
        converted_any = False
        for operand in args:
            if operand.data.dtype != torch.bool:
                operands.append(operand)
                continue
            converted_any = True
            as_int8 = emit(
                exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
                (operand,),
                {"dtype": torch.int8},
                meta,
            )
            operands.append(as_int8)

        # The bitwise op itself is emitted with empty kwargs.
        result = emit(op, tuple(operands), {}, meta)
        if not converted_any:
            return result

        # NOTE(review): the restore dtype is taken from the first operand
        # only; confirm this is intended when operand dtypes differ.
        return emit(
            exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
            (result,),
            {"dtype": args[0].data.dtype},
            meta,
        )

0 commit comments

Comments
 (0)