
Commit a5b6391

Merge branch 'main' into large-whispers

2 parents a2da0d5 + 3dbc15b · commit a5b6391
96 files changed: +2178 additions, -350 deletions

Some file contents are hidden by default for large commits; the first two changed files below are shown without their file names.
(file name not shown)
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-4361747abfc55e40e929396ed986efe775d745f9
+d03e90c2cd9048e6d9a75285c0355f033cd016fc

(file name not shown)
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-556fc09a9f67f24ca5591ec049c5d0c347c5f62a
+b31bad1b8f1331bf43d47f46602cf6141db56844

backends/arm/CMakeLists.txt

Lines changed: 34 additions & 7 deletions
@@ -48,17 +48,44 @@ endif()
 
 # VGF backend builds
 if(EXECUTORCH_BUILD_VGF)
-
-  # include libvgf
-  set(LIBVGF_PATH
-      "${EXECUTORCH_ROOT}/examples/arm/ethos-u-scratch/ml-sdk-for-vulkan-manifest/sw/vgf-lib/"
-  )
-
   set(VULKAN_THIRD_PARTY_PATH ${EXECUTORCH_ROOT}/backends/vulkan/third-party)
   set(VULKAN_HEADERS_PATH ${VULKAN_THIRD_PARTY_PATH}/Vulkan-Headers/include)
   set(VOLK_HEADERS_PATH ${VULKAN_THIRD_PARTY_PATH}/volk)
 
-  set(LIBVGF_STATIC "${LIBVGF_PATH}/build/src/libvgf.a")
+  if(APPLE
+     OR CMAKE_SYSTEM_PROCESSOR MATCHES "^(arm64|aarch64)$"
+     OR EXISTS
+        "${EXECUTORCH_ROOT}/examples/arm/ethos-u-scratch/ml-sdk-for-vulkan-manifest/"
+  )
+    message(STATUS "libvgf sourced from local scratch tree")
+
+    # Legacy layout: libvgf sourced from local scratch tree
+    set(LIBVGF_PATH
+        "${EXECUTORCH_ROOT}/examples/arm/ethos-u-scratch/ml-sdk-for-vulkan-manifest/sw/vgf-lib/"
+    )
+    set(LIBVGF_STATIC "${LIBVGF_PATH}/build/src/libvgf.a")
+  else()
+    message(STATUS "libvgf installed from pip package")
+
+    set(Python3_FIND_VIRTUALENV FIRST)
+    if(EXECUTORCH_ROOT AND EXISTS "${EXECUTORCH_ROOT}/env")
+      set(Python3_EXECUTABLE "${EXECUTORCH_ROOT}/env/bin/python3")
+    endif()
+
+    find_package(Python3 REQUIRED COMPONENTS Interpreter)
+
+    # Prefer arch-specific site-packages if present, else pure
+    set(_vgf_site_arch "${Python3_SITEARCH}/vgf_lib/binaries")
+    set(_vgf_site_pure "${Python3_SITELIB}/vgf_lib/binaries")
+    if(EXISTS "${_vgf_site_arch}")
+      set(LIBVGF_PATH "${_vgf_site_arch}")
+    else()
+      set(LIBVGF_PATH "${_vgf_site_pure}")
+    endif()
+
+    set(LIBVGF_STATIC "${LIBVGF_PATH}/lib/libvgf.a")
+  endif()
+
   set(LIBVGF_INCLUDE "${LIBVGF_PATH}/include/")
 
   add_library(vgf STATIC IMPORTED)
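
As a rough illustration of the new pip-package branch, the Python sketch below performs the same lookup the CMake code above does: it resolves the interpreter's platform-specific and pure site-packages directories (what CMake reads as Python3_SITEARCH and Python3_SITELIB) and prefers the platform-specific vgf_lib/binaries tree when it exists. The vgf_lib package layout is assumed from the CMake snippet, not verified independently.

# Sketch only: mirrors the CMake site-packages lookup above. The
# vgf_lib/binaries layout is assumed from the CMake branch, not verified.
import sysconfig
from pathlib import Path


def find_libvgf() -> Path:
    # CMake's Python3_SITEARCH / Python3_SITELIB correspond to these paths.
    site_arch = Path(sysconfig.get_paths()["platlib"]) / "vgf_lib" / "binaries"
    site_pure = Path(sysconfig.get_paths()["purelib"]) / "vgf_lib" / "binaries"
    # Prefer the arch-specific tree, fall back to the pure one.
    libvgf_path = site_arch if site_arch.exists() else site_pure
    return libvgf_path / "lib" / "libvgf.a"


if __name__ == "__main__":
    print(find_libvgf())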

backends/arm/_passes/arm_pass_manager.py

Lines changed: 42 additions & 16 deletions
@@ -159,7 +159,12 @@ def _transform(self, graph_module: GraphModule):
     def _tosa_pipeline(
         self, exported_program: ExportedProgram, graph_module: GraphModule
     ) -> GraphModule:
+        # Preprocessing passes
+
         self.add_pass(AnnotateOutputDimOrderPass())
+
+        # Node transformation passes (pre q/dq folding)
+
         self.add_pass(FuseQuantizedActivationPass())
         self.add_pass(RemoveGetItemPass())
         self.add_pass(ConvertToClampPass())
@@ -174,8 +179,19 @@ def _tosa_pipeline(
         self.add_pass(ConvertELUParamsPass())
         self.add_pass(ConvertSplitToSlicePass())
         self.add_pass(QuantizeOperatorArguments())
+
+        # Fold Q/DQ nodes, insert INT8/INT32 rescales.
+
         self.add_pass(FoldAndAnnotateQParamsPass(exported_program))  # type: ignore[call-arg]
         self.add_pass(FuseDuplicateUsersPass())
+        # TODO: DecomposeLinearPass should run after InsertRescaleInt32Pass or
+        # before FoldAndAnnotateQParamsPass but is unable to at the moment.
+        # Ticket: MLETORCH-1539
+        self.add_pass(DecomposeLinearPass())
+        self.add_pass(InsertRescaleInt32Pass())
+
+        # Node transformation passes (post q/dq folding)
+
         self.add_pass(DecomposeExpm1Pass())
         self.add_pass(DecomposeLogitPass())
         self.add_pass(DecomposeMaskedFill())
@@ -196,57 +212,67 @@ def _tosa_pipeline(
         self.add_pass(DecomposeSignPass())
         self.add_pass(DecomposeFloorDividePass())
         self.add_pass(DecomposeDivTensorModePass())
+        self.add_pass(DecomposeGeluPass())
+        self.add_pass(DecomposeAddSubAlphaPass())
+        self.add_pass(DecomposeGroupedConv())
+        self.add_pass(Conv1dUnsqueezePass())
+
+        # Scalars -> tensors, match tensor dtypes and ranks.
+
         self.add_pass(ReplaceScalarWithTensorByProfilePass())
+        self.add_pass(ConvertFullLikeToFullPass())
+        self.add_pass(MatchArgDtypePass())
+        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
+        # TODO: Move DecomposeNotEqualPass to before or after this block of
+        # passes. Ticket: MLETORCH-1540
+        self.add_pass(DecomposeNotEqualPass())
+        self.add_pass(MatchArgRanksPass(exported_program))
+        self.add_pass(FuseConstantArgsPass(exported_program))
+
+        # Node transformation passes (post scalar-removal)
+
         self.add_pass(DecomposeRemainderPass())
         self.add_pass(DecomposeDivTensorModePass())
         self.add_pass(DecomposeEmbeddingPass())
         self.add_pass(FuseBatchnorm2DPass(exported_program))
         self.add_pass(ConvertMmToBmmPass())
         self.add_pass(DecomposeGluPass())
-        self.add_pass(DecomposeLinearPass())
         self.add_pass(DecomposeLeakyReLUPass())
-        self.add_pass(DecomposeNotEqualPass())
         self.add_pass(DecomposeDivPass())
-        self.add_pass(DecomposeAddSubAlphaPass())
         self.add_pass(DecomposeSoftmaxPass())
-        self.add_pass(DecomposeGeluPass())
-        self.add_pass(ConvertFullLikeToFullPass())
         self.add_pass(ConvertMinMaxPass())
         self.add_pass(ConvertAnyDefaultDimDimsPass())
-        self.add_pass(MatchArgDtypePass())
-        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
-        self.add_pass(MatchArgRanksPass(exported_program))
         self.add_pass(DecomposeAdaptiveAvgPool2dPass())
         self.add_pass(DecomposeAvgPool2d())
         self.add_pass(
             DecorateFp32toInt32CastingPass()
         )  # Require that no new fp32->int32 is introduced after this pass
         self.add_pass(ComputeConstantOpsAOT(exported_program))
-
-        self.add_pass(DecomposeGroupedConv())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(DecomposeCumsumPass(exported_program))
-        self.add_pass(Conv1dUnsqueezePass())
         self.add_pass(DecomposeMaxPool2DPass())
         self.add_pass(SizeAdjustInputPass())
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())
         self.add_pass(CastToInt32Pass())
         self.add_pass(BroadcastArgsPass())
-
         self.add_pass(ConvertPermuteSingletonToViewPass())
         self.add_pass(FuseViewCopyTransform())
-        self.add_pass(FuseConstantArgsPass(exported_program))
         self.add_pass(DecomposeConv2dWithInt16ActivationPass())
-        self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
+        self.add_pass(DecomposeSumPass())
         self.add_pass(InsertTableOpsPass(exported_program))
+
+        # Aten -> TOSA transformation passes
+
         self.add_pass(RewriteUpsamplePass())
         self.add_pass(RewriteConv2dPass(exported_program))
         self.add_pass(RewriteMatmulPass())
+
+        # Postprocessing/cleanup passes
+
+        self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
         self.add_pass(FuseEqualPlaceholdersPass(exported_program))
-        self.add_pass(InsertRescaleInt32Pass())
-        self.add_pass(DecomposeSumPass())
         self.add_pass(ToTosaMemoryFormatPass(exported_program))
         self.add_pass(RemoveNoopPass())
         self.add_pass(InsertRescalePass())
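
The TODO tickets above (MLETORCH-1539, MLETORCH-1540) concern ordering constraints between passes. The sketch below is a hypothetical illustration of how such constraints can be checked against a schedule, assuming each pass class declares a _passes_required_after set as the Arm passes do (see decompose_linear_pass.py below). The validate_order helper and the toy PassA/PassB classes are invented for the example and are not ExecuTorch API.

# Hypothetical sketch: verifies that every pass listed in a pass's
# _passes_required_after set is scheduled later in the pipeline.
from typing import List, Sequence, Type


class PassA:
    _passes_required_after: set = set()


class PassB:
    # PassB requires PassA to run after it.
    _passes_required_after = {PassA}


def validate_order(schedule: Sequence[Type]) -> List[str]:
    errors = []
    for i, p in enumerate(schedule):
        for dep in getattr(p, "_passes_required_after", set()):
            if dep not in schedule[i + 1 :]:
                errors.append(f"{p.__name__} requires {dep.__name__} to run after it")
    return errors


print(validate_order([PassB, PassA]))  # [] -> order is valid
print(validate_order([PassA, PassB]))  # reports the violated constraint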

backends/arm/_passes/arm_pass_utils.py

Lines changed: 16 additions & 2 deletions
@@ -31,11 +31,25 @@
 from torch.export.graph_signature import InputKind
 
 
+def is_submodule_node(node: torch.fx.Node):
+    if node.op not in ("get_attr", "placeholder"):
+        return False
+    try:
+        node.graph.owning_module.get_submodule(node.target)
+    except AttributeError:
+        return False
+    return True
+
+
 def is_get_attr_node(node: torch.fx.Node) -> bool:
     """
-    Returns true if the given node is a get attr node for a tensor of the model
+    Returns true if the given node is a get attr node for a tensor of the model.
     """
-    return isinstance(node, torch.fx.Node) and node.op == "get_attr"
+    return (
+        isinstance(node, torch.fx.Node)
+        and node.op == "get_attr"
+        and not is_submodule_node(node)
+    )
 
 
 def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
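
A small, self-contained demo of the distinction the new helper draws: a get_attr node whose target resolves to a submodule is no longer treated as a tensor get_attr, while a get_attr on a parameter still is. The helper body is copied from the diff; the toy modules and the manually inserted get_attr node are illustrative assumptions, not executorch code.

# Illustrative only: the helper body mirrors the diff above; the toy
# modules and the manually inserted get_attr node are made up for the demo.
import torch
import torch.fx as fx


def is_submodule_node(node: torch.fx.Node):
    if node.op not in ("get_attr", "placeholder"):
        return False
    try:
        node.graph.owning_module.get_submodule(node.target)
    except AttributeError:
        return False
    return True


class Inner(torch.nn.Module):
    def forward(self, x):
        return x + 1


class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Inner()
        self.weight = torch.nn.Parameter(torch.ones(2))

    def forward(self, x):
        return self.sub(x) * self.weight


gm = fx.symbolic_trace(Outer())
# Insert a get_attr node that targets the submodule itself, mimicking
# graphs where submodules are lifted as attributes.
placeholder = next(n for n in gm.graph.nodes if n.op == "placeholder")
with gm.graph.inserting_after(placeholder):
    sub_node = gm.graph.get_attr("sub")

weight_node = next(n for n in gm.graph.nodes if n.op == "get_attr" and n.target == "weight")
print(is_submodule_node(sub_node))     # True  -> excluded by is_get_attr_node
print(is_submodule_node(weight_node))  # False -> still a tensor get_attr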

backends/arm/_passes/cast_int64_pass.py

Lines changed: 2 additions & 0 deletions
@@ -41,6 +41,8 @@ def _to_int32(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
             if len(node.users) == 0:
                 continue
+            if "val" not in node.meta:
+                continue
             fake_tensor = node.meta["val"]
             if not isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
                 continue

backends/arm/_passes/decompose_linear_pass.py

Lines changed: 2 additions & 1 deletion
@@ -12,6 +12,7 @@
     create_node,
     get_first_fake_tensor,
 )
+from executorch.backends.arm._passes.insert_rescales_pass import InsertRescaleInt32Pass
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 
@@ -26,7 +27,7 @@ class DecomposeLinearPass(ArmPass):
     output = view(conv2d)
     """
 
-    _passes_required_after: Set[Type[ExportPass]] = set()
+    _passes_required_after: Set[Type[ExportPass]] = {InsertRescaleInt32Pass}
 
     def call(self, graph_module):
        for node in graph_module.graph.nodes:

backends/arm/_passes/match_arg_ranks_pass.py

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@ def __init__(self, exported_program: ExportedProgram) -> None:
             exir_ops.edge.aten.lt.Tensor,
             exir_ops.edge.aten.le.Tensor,
             exir_ops.edge.aten.pow.Tensor_Tensor,
+            exir_ops.edge.aten.remainder.Tensor,
             exir_ops.edge.aten.where.self,
             exir_ops.edge.aten.bitwise_and.Tensor,
             exir_ops.edge.aten.bitwise_xor.Tensor,
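
For context, a small example (assumed, not taken from the commit) of the rank mismatch this entry covers: torch.remainder broadcasts a lower-rank argument, and listing aten.remainder.Tensor lets MatchArgRanksPass make that broadcast explicit by raising the lower-rank input to the same rank.

# Assumed illustration of the rank mismatch MatchArgRanksPass resolves;
# this is not executorch code.
import torch

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # rank 2
divisor = torch.tensor(4.0)                             # rank 0

# Eager PyTorch broadcasts the rank-0 divisor implicitly.
print(torch.remainder(x, divisor))

# A rank-matching pass makes the broadcast explicit by reshaping the
# lower-rank argument up to the higher rank, e.g. (1, 1) here.
divisor_matched = divisor.reshape(1, 1)
print(torch.remainder(x, divisor_matched))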

backends/arm/_passes/rewrite_upsample.py

Lines changed: 9 additions & 1 deletion
@@ -11,6 +11,7 @@
     create_node,
     get_first_fake_tensor,
 )
+from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
 from executorch.backends.arm.tosa.utils import get_resize_parameters
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
@@ -52,7 +53,9 @@ def call(self, graph_module):
             node.replace_all_uses_with(tosa_resize_node)
             graph_module.graph.erase_node(node)
             input_dtype = get_first_fake_tensor(x).dtype
-            if input_dtype == torch.int8 and resize_mode == "bilinear":
+            if (
+                input_dtype == torch.int8 or input_dtype == torch.int16
+            ) and resize_mode == "bilinear":
                 input_size = get_first_fake_tensor(x).shape
                 input_size_xy = input_size[2:]
                 output_size = get_first_fake_tensor(node).shape
@@ -71,6 +74,11 @@ def call(self, graph_module):
                     exir_ops.backend.tosa.RESCALE.default,
                 )
                 tosa_resize_node.replace_all_uses_with(rescale_node)
+                if input_dtype == torch.int16:
+                    tosa_resize_node.meta[TosaSpecialDtype.meta_key()] = (
+                        TosaSpecialDtype.INT48
+                    )
+
                 rescale_node.args = (
                     tosa_resize_node,
                     output_dtype,
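
The int16 case is communicated to later lowering stages by tagging the resize node's metadata with an INT48 marker. Below is a toy sketch of that tagging pattern only; the SpecialDtype enum is a stand-in invented for the example and is not the executorch TosaSpecialDtype class.

# Toy stand-in for the meta-tagging pattern above; SpecialDtype is not
# the executorch TosaSpecialDtype class, just an illustration of the idea.
from enum import Enum

import torch
import torch.fx as fx


class SpecialDtype(Enum):
    INT48 = "INT48"

    @staticmethod
    def meta_key() -> str:
        return "special_dtype"


def f(x):
    return x * 2


gm = fx.symbolic_trace(f)

# A rewrite pass tags a node that needs a wider intermediate dtype ...
mul_node = next(n for n in gm.graph.nodes if n.op == "call_function")
mul_node.meta[SpecialDtype.meta_key()] = SpecialDtype.INT48

# ... and a later serialization stage reads the tag back.
for node in gm.graph.nodes:
    if SpecialDtype.meta_key() in node.meta:
        print(node.name, "->", node.meta[SpecialDtype.meta_key()].value)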

backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 2 additions & 0 deletions
@@ -299,6 +299,8 @@ def remove_dim_order_kwargs(
 
     def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
+            if "val" not in node.meta:
+                continue
             node_data = get_first_fake_tensor(node).data
 
             self.remove_dim_order_kwargs(graph_module, node)
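
This guard, like the one added to cast_int64_pass.py above, exists because not every FX node carries a fake tensor under meta["val"], for example nodes freshly created by an earlier pass. A minimal, assumed sketch of the pattern, using torch.export only for illustration:

# Minimal, assumed illustration: a node created inside a pass has no
# meta["val"], so passes that read it should skip such nodes.
import torch
from torch.export import export


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1


gm = export(M(), (torch.randn(2),)).graph_module
add_node = next(n for n in gm.graph.nodes if n.op == "call_function")

# Exported nodes carry a FakeTensor under meta["val"]; a node created
# by a pass does not until shape propagation runs again.
with gm.graph.inserting_after(add_node):
    new_node = gm.graph.call_function(torch.relu, (add_node,))

for node in gm.graph.nodes:
    if "val" not in node.meta:  # the guard added in this commit
        print("skipping", node.name)
        continue
    val = node.meta["val"]
    if isinstance(val, torch._subclasses.fake_tensor.FakeTensor):
        print(node.name, tuple(val.shape))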
