diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py index 50781eade4d..8470184d808 100644 --- a/backends/xnnpack/operators/node_visitor.py +++ b/backends/xnnpack/operators/node_visitor.py @@ -6,6 +6,7 @@ import ctypes import hashlib +import logging from typing import cast, Dict, List, Optional, Tuple @@ -592,8 +593,16 @@ def get_serialized_buffer_index( xnn_graph.constant_data.append( ConstantDataOffset(offset=UINT64_MAX, size=size, named_key=named_key) ) + + external_tag = tensor.meta.get("delegate_constant_tag", None) + logging.info( + f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store" + ) self._named_data_store.add_named_data( - named_key, bytes(array), alignment=CONSTANT_TENSOR_ALIGNMENT + named_key, + bytes(array), + alignment=CONSTANT_TENSOR_ALIGNMENT, + external_tag=external_tag, ) return buffer_idx diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 0b187d05df0..1d8c6e3fdbc 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -204,8 +204,9 @@ const uint8_t* getConstantDataPtr( if (!buffer.ok()) { ET_LOG( Error, - "Failed to get constant data for key %s", - data_name.c_str()); + "Failed to get constant data for key %s from named_data_map. Error code: %u", + data_name.c_str(), + static_cast(buffer.error())); return nullptr; } const uint8_t* data_ptr = diff --git a/exir/passes/external_constants_pass.py b/exir/passes/external_constants_pass.py index e024fdcbcd2..d9bba4635ff 100644 --- a/exir/passes/external_constants_pass.py +++ b/exir/passes/external_constants_pass.py @@ -6,13 +6,25 @@ # pyre-strict +from typing import Callable, Optional + import torch from executorch.exir.pass_base import PassResult from executorch.exir.tensor import TensorSpec + +from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param from torch.export.exported_program import ExportedProgram, OutputKind from torch.fx import GraphModule +def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool: + return ( + is_param(exp_prog, node) + or is_buffer(exp_prog, node) + or is_lifted_tensor_constant(exp_prog, node) + ) + + def external_constants_pass( gm: GraphModule, ) -> PassResult: @@ -74,3 +86,33 @@ def external_mutable_weights_pass( node.meta["constant_tag"] = "_default_external_constant" mutated = True return PassResult(gm, mutated) + + +def delegate_external_constants_pass( + gm: GraphModule, + ep: ExportedProgram, + gen_tag_fn: Optional[Callable[[torch.fx.Node], str]] = None, +) -> PassResult: + """ + Tag external constants before to_backend. + + Note: this pass must be run after run_decompositions(), as tags on + constants are removed then. + + Args: + gm: GraphModule to tag. + ep: ExportedProgram, to distinguish if a node is a constant. + gen_tag_fn: node -> str callable indicating the tag for the node. + Returns: + PassResult: The resulting gm, and if it was mutated or not. + """ + mutated = False + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if node.op == "placeholder" and is_param_node(ep, node): + if gen_tag_fn is not None: + node.meta["delegate_constant_tag"] = gen_tag_fn(node) + mutated = True + return PassResult(gm, mutated) diff --git a/test/models/export_delegated_program.py b/test/models/export_delegated_program.py index 25818a721a6..f7b2f354373 100644 --- a/test/models/export_delegated_program.py +++ b/test/models/export_delegated_program.py @@ -10,6 +10,8 @@ import inspect import os import sys + +from functools import partial from typing import Dict, final, Optional, Sequence, Type import executorch.exir as exir @@ -21,6 +23,9 @@ from executorch.exir.backend.test.backend_with_compiler_demo import ( BackendWithCompilerDemo, ) +from executorch.exir.passes.external_constants_pass import ( + delegate_external_constants_pass, +) from executorch.exir.program import ExecutorchProgramManager from torch import nn from torch.export import export @@ -129,6 +134,7 @@ def export_module_to_program( constant_tensor_alignment: Optional[int] = None, delegate_alignment: Optional[int] = None, method_name: str = "forward", + external_constants: bool = False, ) -> ExecutorchProgramManager: eager_module = module_class().eval() inputs = () @@ -158,8 +164,17 @@ def forward(self, *args, **kwargs): XnnpackPartitioner, ) + transform_passes = [] + if external_constants: + partial_function = partial( + delegate_external_constants_pass, + ep=exported_program, + gen_tag_fn=lambda x: module_class.__name__, + ) + transform_passes.append(partial_function) executorch_program = to_edge_transform_and_lower( exported_program, + transform_passes=transform_passes, compile_config=edge_config, partitioner=[XnnpackPartitioner()], ).to_executorch(config=et_config) @@ -221,6 +236,11 @@ def main() -> None: parser.add_argument( "--delegate_alignment", type=int, default=None, help="Delegate alignment." ) + parser.add_argument( + "--external_constants", + action="store_true", + help="Export the model with all constants saved to an external file.", + ) parser.add_argument( "--outdir", type=str, @@ -247,16 +267,22 @@ def main() -> None: suffix += "-nosegments" if args.delegate_alignment is not None: suffix += f"-da{args.delegate_alignment}" + if args.external_constants: + suffix += "-e" outfile = os.path.join(args.outdir, f"{module_name}{suffix}.pte") executorch_program = export_module_to_program( module_class, backend_id=args.backend_id, extract_delegate_segments=not args.inline_delegate_segments, delegate_alignment=args.delegate_alignment, + external_constants=args.external_constants, ) with open(outfile, "wb") as fp: fp.write(executorch_program.buffer) print(f"Exported {module_name} and wrote program data to {outfile}") + if args.external_constants: + print(f"Saving external constants to {module_name}.ptd") + executorch_program.write_tensor_data_to_file(args.outdir) if __name__ == "__main__": diff --git a/test/models/targets.bzl b/test/models/targets.bzl index ab5fcc8a51d..0e3b881b706 100644 --- a/test/models/targets.bzl +++ b/test/models/targets.bzl @@ -206,3 +206,22 @@ def define_common_targets(): ], env = {"PYTORCH_DISABLE_JUSTKNOBS": "1",}, ) + + runtime.genrule( + name = "exported_xnnpack_program_and_data", + cmd = "$(exe :export_delegated_program)" + + " --modules ModuleLinear" + + " --backend_id XnnpackBackend" + + " --external_constants" + + " --outdir $OUT", + + outs = { + "ModuleLinear-e.pte": ["ModuleLinear-e.pte"], + "ModuleLinear.ptd": ["ModuleLinear.ptd"], + }, + default_outs = ["."], + visibility = [ + "//executorch/runtime/executor/test/...", + "//executorch/test/...", + ], + )