Skip to content

Commit c210472

Browse files
pytorchbotlucylq
andauthored
Add pass to tag external constants for delegates (#10422)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #10328 by @lucylq ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/lucylq/63/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/63/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/63/orig @diff-train-skip-merge Co-authored-by: lucylq <[email protected]>
1 parent 954f2cb commit c210472

File tree

5 files changed

+100
-3
lines changed

5 files changed

+100
-3
lines changed

backends/xnnpack/operators/node_visitor.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import ctypes
88
import hashlib
9+
import logging
910

1011
from typing import cast, Dict, List, Optional, Tuple
1112

@@ -592,8 +593,16 @@ def get_serialized_buffer_index(
592593
xnn_graph.constant_data.append(
593594
ConstantDataOffset(offset=UINT64_MAX, size=size, named_key=named_key)
594595
)
596+
597+
external_tag = tensor.meta.get("delegate_constant_tag", None)
598+
logging.info(
599+
f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
600+
)
595601
self._named_data_store.add_named_data(
596-
named_key, bytes(array), alignment=CONSTANT_TENSOR_ALIGNMENT
602+
named_key,
603+
bytes(array),
604+
alignment=CONSTANT_TENSOR_ALIGNMENT,
605+
external_tag=external_tag,
597606
)
598607

599608
return buffer_idx

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,9 @@ const uint8_t* getConstantDataPtr(
204204
if (!buffer.ok()) {
205205
ET_LOG(
206206
Error,
207-
"Failed to get constant data for key %s",
208-
data_name.c_str());
207+
"Failed to get constant data for key %s from named_data_map. Error code: %u",
208+
data_name.c_str(),
209+
static_cast<uint32_t>(buffer.error()));
209210
return nullptr;
210211
}
211212
const uint8_t* data_ptr =

exir/passes/external_constants_pass.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,25 @@
66

77
# pyre-strict
88

9+
from typing import Callable, Optional
10+
911
import torch
1012
from executorch.exir.pass_base import PassResult
1113
from executorch.exir.tensor import TensorSpec
14+
15+
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
1216
from torch.export.exported_program import ExportedProgram, OutputKind
1317
from torch.fx import GraphModule
1418

1519

20+
def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
21+
return (
22+
is_param(exp_prog, node)
23+
or is_buffer(exp_prog, node)
24+
or is_lifted_tensor_constant(exp_prog, node)
25+
)
26+
27+
1628
def external_constants_pass(
1729
gm: GraphModule,
1830
) -> PassResult:
@@ -74,3 +86,33 @@ def external_mutable_weights_pass(
7486
node.meta["constant_tag"] = "_default_external_constant"
7587
mutated = True
7688
return PassResult(gm, mutated)
89+
90+
91+
def delegate_external_constants_pass(
92+
gm: GraphModule,
93+
ep: ExportedProgram,
94+
gen_tag_fn: Optional[Callable[[torch.fx.Node], str]] = None,
95+
) -> PassResult:
96+
"""
97+
Tag external constants before to_backend.
98+
99+
Note: this pass must be run after run_decompositions(), as tags on
100+
constants are removed then.
101+
102+
Args:
103+
gm: GraphModule to tag.
104+
ep: ExportedProgram, to distinguish if a node is a constant.
105+
gen_tag_fn: node -> str callable indicating the tag for the node.
106+
Returns:
107+
PassResult: The resulting gm, and if it was mutated or not.
108+
"""
109+
mutated = False
110+
for module in gm.modules():
111+
if not isinstance(module, torch.fx.GraphModule):
112+
continue
113+
for node in module.graph.nodes:
114+
if node.op == "placeholder" and is_param_node(ep, node):
115+
if gen_tag_fn is not None:
116+
node.meta["delegate_constant_tag"] = gen_tag_fn(node)
117+
mutated = True
118+
return PassResult(gm, mutated)

test/models/export_delegated_program.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import inspect
1111
import os
1212
import sys
13+
14+
from functools import partial
1315
from typing import Dict, final, Optional, Sequence, Type
1416

1517
import executorch.exir as exir
@@ -21,6 +23,9 @@
2123
from executorch.exir.backend.test.backend_with_compiler_demo import (
2224
BackendWithCompilerDemo,
2325
)
26+
from executorch.exir.passes.external_constants_pass import (
27+
delegate_external_constants_pass,
28+
)
2429
from executorch.exir.program import ExecutorchProgramManager
2530
from torch import nn
2631
from torch.export import export
@@ -129,6 +134,7 @@ def export_module_to_program(
129134
constant_tensor_alignment: Optional[int] = None,
130135
delegate_alignment: Optional[int] = None,
131136
method_name: str = "forward",
137+
external_constants: bool = False,
132138
) -> ExecutorchProgramManager:
133139
eager_module = module_class().eval()
134140
inputs = ()
@@ -158,8 +164,17 @@ def forward(self, *args, **kwargs):
158164
XnnpackPartitioner,
159165
)
160166

167+
transform_passes = []
168+
if external_constants:
169+
partial_function = partial(
170+
delegate_external_constants_pass,
171+
ep=exported_program,
172+
gen_tag_fn=lambda x: module_class.__name__,
173+
)
174+
transform_passes.append(partial_function)
161175
executorch_program = to_edge_transform_and_lower(
162176
exported_program,
177+
transform_passes=transform_passes,
163178
compile_config=edge_config,
164179
partitioner=[XnnpackPartitioner()],
165180
).to_executorch(config=et_config)
@@ -221,6 +236,11 @@ def main() -> None:
221236
parser.add_argument(
222237
"--delegate_alignment", type=int, default=None, help="Delegate alignment."
223238
)
239+
parser.add_argument(
240+
"--external_constants",
241+
action="store_true",
242+
help="Export the model with all constants saved to an external file.",
243+
)
224244
parser.add_argument(
225245
"--outdir",
226246
type=str,
@@ -247,16 +267,22 @@ def main() -> None:
247267
suffix += "-nosegments"
248268
if args.delegate_alignment is not None:
249269
suffix += f"-da{args.delegate_alignment}"
270+
if args.external_constants:
271+
suffix += "-e"
250272
outfile = os.path.join(args.outdir, f"{module_name}{suffix}.pte")
251273
executorch_program = export_module_to_program(
252274
module_class,
253275
backend_id=args.backend_id,
254276
extract_delegate_segments=not args.inline_delegate_segments,
255277
delegate_alignment=args.delegate_alignment,
278+
external_constants=args.external_constants,
256279
)
257280
with open(outfile, "wb") as fp:
258281
fp.write(executorch_program.buffer)
259282
print(f"Exported {module_name} and wrote program data to {outfile}")
283+
if args.external_constants:
284+
print(f"Saving external constants to {module_name}.ptd")
285+
executorch_program.write_tensor_data_to_file(args.outdir)
260286

261287

262288
if __name__ == "__main__":

test/models/targets.bzl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,22 @@ def define_common_targets():
206206
],
207207
env = {"PYTORCH_DISABLE_JUSTKNOBS": "1",},
208208
)
209+
210+
runtime.genrule(
211+
name = "exported_xnnpack_program_and_data",
212+
cmd = "$(exe :export_delegated_program)" +
213+
" --modules ModuleLinear" +
214+
" --backend_id XnnpackBackend" +
215+
" --external_constants" +
216+
" --outdir $OUT",
217+
218+
outs = {
219+
"ModuleLinear-e.pte": ["ModuleLinear-e.pte"],
220+
"ModuleLinear.ptd": ["ModuleLinear.ptd"],
221+
},
222+
default_outs = ["."],
223+
visibility = [
224+
"//executorch/runtime/executor/test/...",
225+
"//executorch/test/...",
226+
],
227+
)

0 commit comments

Comments
 (0)