Commit 1322e5e

Update on "[ET-VK] Minor build graph change to improve model load time and memory."
A minor change in GraphBuilder to avoid creating a temporary vector and to reserve memory while building an operator.

Differential Revision: [D73864959](https://our.internmc.facebook.com/intern/diff/D73864959/)

[ghstack-poisoned]

2 parents: 06f5f25 + 4ed5218

File tree

23 files changed: +426 -236 lines
backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
@@ -59,6 +59,9 @@
 )

 from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
+from executorch.backends.transforms.decompose_sdpa import (
+    DecomposeScaledDotProductAttention,
+)
 from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
@@ -194,6 +197,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
         )

     def transform_for_annotation_pipeline(self, graph_module: GraphModule):
+        self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
         self.add_pass(ScalarsToAttributePass())
         self.add_pass(DecomposeLayerNormPass())
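For context: DecomposeScaledDotProductAttention rewrites the fused scaled_dot_product_attention node into primitives (matmuls, scaling, softmax) before annotation, which is what lets the TOSA quantizer see the inner softmax. A minimal sketch of the math the decomposition corresponds to, in plain tensor ops (illustrative only; the actual pass rewrites FX graph nodes rather than calling tensor ops):

import math

import torch


def sdpa_decomposed(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # matmul -> scale -> softmax -> matmul: the standard SDPA expansion
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return torch.softmax(scores, dim=-1) @ v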

backends/arm/_passes/decompose_softmax_pass.py

Lines changed: 5 additions & 1 deletion
@@ -8,7 +8,11 @@
 from executorch.exir.pass_base import ExportPass

 # For BI case
-torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int)
+torch_softmax = (
+    torch.ops.aten.softmax.int,
+    torch.ops.aten._safe_softmax.default,
+    torch.ops.aten.log_softmax.int,
+)
 # For MI case
 edge_softmax = (
     exir_ops.edge.aten._softmax.default,
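Why _safe_softmax: recent PyTorch decomposes scaled_dot_product_attention through aten._safe_softmax rather than aten.softmax, so with DecomposeScaledDotProductAttention now running ahead of this pass (see arm_pass_manager.py above), the BI-case tuple has to match that target too. A small sketch to inspect which softmax flavor a decomposed export actually contains (assumes a PyTorch version that provides _safe_softmax):

import torch
from torch.export import export


class M(torch.nn.Module):
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)


ep = export(M(), tuple(torch.randn(1, 3, 8, 16) for _ in range(3)))
ep = ep.run_decompositions()
# Print every softmax-like target left in the decomposed graph.
print([n.target for n in ep.graph.nodes if "softmax" in str(n.target)])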

backends/arm/test/models/test_conformer.py

Lines changed: 13 additions & 7 deletions
@@ -83,7 +83,6 @@ def test_conformer_tosa_BI(self):
             )
         )

-    @unittest.expectedFailure  # TODO(MLETORCH-635)
     def test_conformer_u55_BI(self):
         tester = (
             ArmTester(
@@ -97,13 +96,20 @@ def test_conformer_u55_BI(self):
             .to_executorch()
             .serialize()
         )
+
         if conftest.is_option_enabled("corstone_fvp"):
-            tester.run_method_and_compare_outputs(
-                qtol=1.0,
-                rtol=1.0,
-                atol=5.0,
-                inputs=get_test_inputs(self.dim, self.lengths, self.num_examples),
-            )
+            try:
+                tester.run_method_and_compare_outputs(
+                    qtol=1.0,
+                    rtol=1.0,
+                    atol=5.0,
+                    inputs=get_test_inputs(self.dim, self.lengths, self.num_examples),
+                )
+                self.fail(
+                    "TODO(MLETORCH-635): Expected failure under FVP option, but test passed."
+                )
+            except Exception:
+                pass

     @unittest.expectedFailure  # TODO(MLETORCH-635)
     def test_conformer_u85_BI(self):
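The try/except block above replaces a blanket @unittest.expectedFailure with an expectation that only applies when the corstone_fvp option is enabled, and that flags the test for cleanup once it starts passing. A generic, self-contained sketch of the idiom with hypothetical names; note that keeping self.fail outside the try block prevents its own AssertionError from being swallowed:

import unittest


def run_known_broken_check() -> None:
    # Hypothetical stand-in for the comparison that currently fails.
    raise AssertionError("known issue, tracked in a TODO")


class ExampleTest(unittest.TestCase):
    def test_known_issue(self):
        try:
            run_known_broken_check()
        except AssertionError:
            return  # still failing, as expected
        self.fail("Expected failure, but the check passed; remove this guard.")


if __name__ == "__main__":
    unittest.main()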

backends/arm/test/ops/test_sdpa.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Tuple
+
+import torch
+
+from executorch.backends.arm.test.tester.test_pipeline import (
+    TosaPipelineBI,
+    TosaPipelineMI,
+)
+
+
+class SDPA(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, query, key, value):
+        return torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
+        )
+
+
+input_t = Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+
+
+def test_sdpa_MI():
+    test_input = tuple(torch.randn(1, 3, 197, 64) for x in range(3))
+    pipeline = TosaPipelineMI[input_t](SDPA(), test_input, [], [])
+    pipeline.pop_stage("check_count.exir")
+    pipeline.run()
+
+
+def test_sdpa_BI():
+    test_input = tuple(torch.randn(1, 3, 197, 64) for x in range(3))
+    pipeline = TosaPipelineBI[input_t](SDPA(), test_input, [], [])
+    pipeline.pop_stage("check.quant_nodes")
+    pipeline.pop_stage("check_count.exir")
+    pipeline.pop_stage(
+        "run_method_and_compare_outputs"
+    )  # TODO: reference is not quantized
+    pipeline.run()
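To run just this new file locally, an invocation along these lines should work (assuming a development setup with the Arm test dependencies installed; path relative to the repo root):

import pytest

pytest.main(["backends/arm/test/ops/test_sdpa.py", "-v"])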

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -347,6 +347,7 @@ python_unittest(
         ":compiler",
         "//caffe2:torch",
         "//executorch/backends/cadence/aot:compiler",
+        "//executorch/backends/cadence/aot:graph_builder",
         "//executorch/backends/cadence/aot:ops_registrations",
         "//executorch/backends/cadence/aot:pass_utils",
         "//executorch/backends/cadence/aot:remove_ops",

backends/cadence/aot/compiler.py

Lines changed: 33 additions & 6 deletions
@@ -151,7 +151,7 @@ def quantize_pt2(
     quantizer: Optional[CadenceQuantizer] = None,
     calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
-) -> torch.fx.GraphModule:
+) -> ExportedProgram:
     """
     Trace, prepare, convert and fuse the model using the given quantizer.
     If calibration data is provided, it will be used to calibrate the model. If
@@ -178,7 +178,9 @@ def quantize_pt2(
         logging.info("Graph after quantization and fusion:")
         logging.info(fused_gm.graph.print_tabular())

-    return fused_gm
+    program = torch.export.export(fused_gm, inputs, strict=True)
+
+    return program


 # Export the model and lower it to an ExportedProgram (in aten IR)
@@ -260,21 +262,43 @@ def quantize_and_export_to_edge(
     dump_graphs: bool = False,
     constant_methods: Optional[dict[str, object]] = None,
 ) -> EdgeProgramManager:
+    """
+    Trace, quantize and lower a model/inputs pair to edge IR.
+    """
     quantized_model = quantize_pt2(
         model,
         inputs,
         quantizer=quantizer,
         dump_graphs=dump_graphs,
     )

-    return export_to_edge(
+    return lower_ep_to_edge(
         quantized_model,
-        inputs,
         dump_graphs=dump_graphs,
         constant_methods=constant_methods,
     )


+def lower_ep_to_cadence(
+    program: ExportedProgram,
+    dump_graphs: bool = False,
+    opt_level: int = 1,
+) -> EdgeProgramManager:
+    """
+    Lower an existing ExportedProgram to edge IR and apply frontend optimization passes.
+    """
+    edge_prog_manager = lower_ep_to_edge(program, dump_graphs=dump_graphs)
+    cadence_passes = get_cadence_passes(opt_level)
+
+    # Run a couple required passes for quant/dequant ops
+    cadence_prog_manager = edge_prog_manager.transform(
+        cast(
+            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
+        )
+    )
+    return cadence_prog_manager
+
+
 def export_to_cadence(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
@@ -299,11 +323,14 @@ def quantize_and_export_to_cadence(
     dump_graphs: bool = False,
     opt_level: int = 1,
 ) -> EdgeProgramManager:
+    """
+    Trace, quantize, lower a model/inputs pair to edge IR and apply frontend
+    optimization passes.
+    """
     quantized_model = quantize_pt2(model, inputs)

-    return export_to_cadence(
+    return lower_ep_to_cadence(
         quantized_model,
-        inputs,
         opt_level=opt_level,
         dump_graphs=dump_graphs,
     )
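Taken together, the refactor splits tracing/quantization from lowering: quantize_pt2 now returns an ExportedProgram, and lower_ep_to_cadence carries it the rest of the way. A hedged usage sketch (MyModel and its input shape are placeholders; the two compiler functions are the ones defined in this diff):

import torch

from executorch.backends.cadence.aot.compiler import (
    lower_ep_to_cadence,
    quantize_pt2,
)


class MyModel(torch.nn.Module):  # placeholder model
    def forward(self, x):
        return torch.relu(x)


inputs = (torch.randn(1, 8),)

# Trace, quantize, fuse, then re-export to an ExportedProgram.
program = quantize_pt2(MyModel(), inputs)

# Lower the ExportedProgram to edge IR with Cadence passes applied.
edge_manager = lower_ep_to_cadence(program, opt_level=1)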
