Skip to content

Add pass to tag external constants for delegates #10328

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion backends/xnnpack/operators/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import ctypes
import hashlib
import logging

from typing import cast, Dict, List, Optional, Tuple

Expand Down Expand Up @@ -592,8 +593,16 @@ def get_serialized_buffer_index(
xnn_graph.constant_data.append(
ConstantDataOffset(offset=UINT64_MAX, size=size, named_key=named_key)
)

external_tag = tensor.meta.get("delegate_constant_tag", None)
logging.info(
f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
)
self._named_data_store.add_named_data(
named_key, bytes(array), alignment=CONSTANT_TENSOR_ALIGNMENT
named_key,
bytes(array),
alignment=CONSTANT_TENSOR_ALIGNMENT,
external_tag=external_tag,
)

return buffer_idx
Expand Down
5 changes: 3 additions & 2 deletions backends/xnnpack/runtime/XNNCompiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,9 @@ const uint8_t* getConstantDataPtr(
if (!buffer.ok()) {
ET_LOG(
Error,
"Failed to get constant data for key %s",
data_name.c_str());
"Failed to get constant data for key %s from named_data_map. Error code: %u",
data_name.c_str(),
static_cast<uint32_t>(buffer.error()));
return nullptr;
}
const uint8_t* data_ptr =
Expand Down
42 changes: 42 additions & 0 deletions exir/passes/external_constants_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,25 @@

# pyre-strict

from typing import Callable, Optional

import torch
from executorch.exir.pass_base import PassResult
from executorch.exir.tensor import TensorSpec

from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch.export.exported_program import ExportedProgram, OutputKind
from torch.fx import GraphModule


def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
return (
is_param(exp_prog, node)
or is_buffer(exp_prog, node)
or is_lifted_tensor_constant(exp_prog, node)
)


def external_constants_pass(
gm: GraphModule,
) -> PassResult:
Expand Down Expand Up @@ -74,3 +86,33 @@ def external_mutable_weights_pass(
node.meta["constant_tag"] = "_default_external_constant"
mutated = True
return PassResult(gm, mutated)


def delegate_external_constants_pass(
gm: GraphModule,
ep: ExportedProgram,
gen_tag_fn: Optional[Callable[[torch.fx.Node], str]] = None,
) -> PassResult:
"""
Tag external constants before to_backend.

Note: this pass must be run after run_decompositions(), as tags on
constants are removed then.

Args:
gm: GraphModule to tag.
ep: ExportedProgram, to distinguish if a node is a constant.
gen_tag_fn: node -> str callable indicating the tag for the node.
Returns:
PassResult: The resulting gm, and if it was mutated or not.
"""
mutated = False
for module in gm.modules():
if not isinstance(module, torch.fx.GraphModule):
continue
for node in module.graph.nodes:
if node.op == "placeholder" and is_param_node(ep, node):
if gen_tag_fn is not None:
node.meta["delegate_constant_tag"] = gen_tag_fn(node)
mutated = True
return PassResult(gm, mutated)
26 changes: 26 additions & 0 deletions test/models/export_delegated_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import inspect
import os
import sys

from functools import partial
from typing import Dict, final, Optional, Sequence, Type

import executorch.exir as exir
Expand All @@ -21,6 +23,9 @@
from executorch.exir.backend.test.backend_with_compiler_demo import (
BackendWithCompilerDemo,
)
from executorch.exir.passes.external_constants_pass import (
delegate_external_constants_pass,
)
from executorch.exir.program import ExecutorchProgramManager
from torch import nn
from torch.export import export
Expand Down Expand Up @@ -129,6 +134,7 @@ def export_module_to_program(
constant_tensor_alignment: Optional[int] = None,
delegate_alignment: Optional[int] = None,
method_name: str = "forward",
external_constants: bool = False,
) -> ExecutorchProgramManager:
eager_module = module_class().eval()
inputs = ()
Expand Down Expand Up @@ -158,8 +164,17 @@ def forward(self, *args, **kwargs):
XnnpackPartitioner,
)

transform_passes = []
if external_constants:
partial_function = partial(
delegate_external_constants_pass,
ep=exported_program,
gen_tag_fn=lambda x: module_class.__name__,
)
transform_passes.append(partial_function)
executorch_program = to_edge_transform_and_lower(
exported_program,
transform_passes=transform_passes,
compile_config=edge_config,
partitioner=[XnnpackPartitioner()],
).to_executorch(config=et_config)
Expand Down Expand Up @@ -221,6 +236,11 @@ def main() -> None:
parser.add_argument(
"--delegate_alignment", type=int, default=None, help="Delegate alignment."
)
parser.add_argument(
"--external_constants",
action="store_true",
help="Export the model with all constants saved to an external file.",
)
parser.add_argument(
"--outdir",
type=str,
Expand All @@ -247,16 +267,22 @@ def main() -> None:
suffix += "-nosegments"
if args.delegate_alignment is not None:
suffix += f"-da{args.delegate_alignment}"
if args.external_constants:
suffix += "-e"
outfile = os.path.join(args.outdir, f"{module_name}{suffix}.pte")
executorch_program = export_module_to_program(
module_class,
backend_id=args.backend_id,
extract_delegate_segments=not args.inline_delegate_segments,
delegate_alignment=args.delegate_alignment,
external_constants=args.external_constants,
)
with open(outfile, "wb") as fp:
fp.write(executorch_program.buffer)
print(f"Exported {module_name} and wrote program data to {outfile}")
if args.external_constants:
print(f"Saving external constants to {module_name}.ptd")
executorch_program.write_tensor_data_to_file(args.outdir)


if __name__ == "__main__":
Expand Down
19 changes: 19 additions & 0 deletions test/models/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,22 @@ def define_common_targets():
],
env = {"PYTORCH_DISABLE_JUSTKNOBS": "1",},
)

runtime.genrule(
name = "exported_xnnpack_program_and_data",
cmd = "$(exe :export_delegated_program)" +
" --modules ModuleLinear" +
" --backend_id XnnpackBackend" +
" --external_constants" +
" --outdir $OUT",

outs = {
"ModuleLinear-e.pte": ["ModuleLinear-e.pte"],
"ModuleLinear.ptd": ["ModuleLinear.ptd"],
},
default_outs = ["."],
visibility = [
"//executorch/runtime/executor/test/...",
"//executorch/test/...",
],
)
Loading