diff --git a/exir/backend/backend_api.py b/exir/backend/backend_api.py
index 310e5ea9379..838156498c4 100644
--- a/exir/backend/backend_api.py
+++ b/exir/backend/backend_api.py
@@ -204,12 +204,16 @@ def _insert_lowered_submodule(
     owning_graph_module = call_submodule_node.graph.owning_module
     # call delegate args should only use user_inputs
     call_delegate_args = []
-    # Preserve input order as user_inputs
-    for inp_name in submodule_program.graph_signature.user_inputs:
-        for inp_node in call_submodule_node.all_input_nodes:
-            if inp_node.name == inp_name:
-                call_delegate_args.append(inp_node)
-                break
+    # names of input_specs to delete
+    input_specs_to_delete = toplevel_input_specs_to_delete
+    # Delete owned constants from the call_submodule_node args
+    for call_sm_input in call_submodule_node.args:
+        if (
+            isinstance(call_sm_input, torch.fx.Node)
+            and call_sm_input.name in input_specs_to_delete.keys()
+        ):
+            continue
+        call_delegate_args.append(call_sm_input)
 
     def generate_debug_handle(ep: ExportedProgram) -> int:
         """
@@ -324,6 +328,7 @@ def _partition_and_lower_one_graph_module(
             toplevel_input_specs_to_delete,
             toplevel_output_specs_to_delete,
         )
+        owning_program._validate()
 
     return tagged_graph_module
 
@@ -742,6 +747,7 @@ def to_backend(
     for method_name in method_to_edge_program.keys():
         if method_name in method_to_tagged_exported_program:
             tagged_exported_program = method_to_tagged_exported_program[method_name]
+            tagged_exported_program._validate()
             partitioned_and_lowered_exported_programs[method_name] = ExportedProgram(
                 root=tagged_exported_program.graph_module,
                 graph=tagged_exported_program.graph_module.graph,
diff --git a/exir/backend/test/backend_with_preprocess_all_demo.py b/exir/backend/test/backend_with_preprocess_all_demo.py
index ae9a8174be5..11941b703a0 100644
--- a/exir/backend/test/backend_with_preprocess_all_demo.py
+++ b/exir/backend/test/backend_with_preprocess_all_demo.py
@@ -21,10 +21,30 @@
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.graph_module import get_control_flow_submodules
+from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
 from torch.export.exported_program import ExportedProgram
 from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
 
 
+def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
+    return (
+        is_param(exp_prog, node)
+        or is_buffer(exp_prog, node)
+        or is_lifted_tensor_constant(exp_prog, node)
+    )
+
+
+def get_total_num_ops_in_ep(edge_programs, supported_ops):
+    total_number_of_ops = 0
+    for edge_program in edge_programs.values():
+        for partitioned_program in edge_program:
+            for node in partitioned_program.graph.nodes:
+                if node.op == "call_function":
+                    if node.target in supported_ops:
+                        total_number_of_ops += 1
+    return total_number_of_ops
+
+
 def _preprocess_multimethod(
     edge_programs: Dict[str, List[ExportedProgram]],
     compile_specs: Dict[str, List[List[CompileSpec]]],
@@ -37,13 +57,7 @@ def _preprocess_multimethod(
     in testing for a partitioner which tags different partitions for different
     backends to be lowered to
     """
-    total_number_of_ops = 0
-    for edge_program in edge_programs.values():
-        for partitioned_program in edge_program:
-            for node in partitioned_program.graph.nodes:
-                if node.op == "call_function":
-                    if node.target in supported_ops:
-                        total_number_of_ops += 1
+    total_number_of_ops = get_total_num_ops_in_ep(edge_programs, supported_ops)
     all_processed_results = {key: [] for key in edge_programs.keys()}
 
     for method_name, partitioned_programs in edge_programs.items():
@@ -67,6 +81,8 @@ def _preprocess_multimethod(
                     else:
                         raise RuntimeError(
                             f"{node.op} {node.target.__name__} is not supported in backend {backend_name}"
                         )
+                if is_param_node(partitioned_program, node):
+                    processed_bytes += f"CONST{node.name}:"
 
             processed_bytes += "#"
             for cs in compile_spec_for_partition:
@@ -171,14 +187,30 @@ def preprocess_multimethod(
 
 
 class AddSinOperatorSupport(OperatorSupportBase):
+    def __init__(self, original_program):
+        self.original_program = original_program
+        super().__init__()
+
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
-        return node.op == "call_function" and node.target in [
+        supported_targets = [
             exir_ops.edge.aten.add.Tensor,
             exir_ops.edge.aten.sin.default,
         ]
+        if node.op == "call_function" and node.target in supported_targets:
+            return True
+
+        if node.op == "placeholder" and is_param_node(self.original_program, node):
+            for user in node.users.keys():
+                if user.target in supported_targets:
+                    return True
+        return False
 
 
 class SubCosOperatorSupport(OperatorSupportBase):
+    def __init__(self, original_program):
+        self.original_program = original_program
+        super().__init__()
+
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         return node.op == "call_function" and node.target in [
             exir_ops.edge.aten.sub.Tensor,
@@ -199,11 +231,8 @@ class BackendWithPreprocessAllPartitioner(Partitioner):
     """
 
     def __init__(self) -> None:
-        self.add_sin_support = any_chain(AddSinOperatorSupport())
-        self.add_sin_backend_id = FirstBackendWithPreprocessAll.__name__
-
-        self.sub_cos_support = any_chain(SubCosOperatorSupport())
         self.sub_cos_backend_id = SecondBackendWithPreprocessAll.__name__
+        self.add_sin_backend_id = FirstBackendWithPreprocessAll.__name__
 
     def _partition_graph_module(
         self,
@@ -260,6 +289,8 @@ def _partition_graph_module(
         return partition_tags, start_idx_for_submodules
 
     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
+        self.add_sin_support = any_chain(AddSinOperatorSupport(exported_program))
+        self.sub_cos_support = any_chain(SubCosOperatorSupport(exported_program))
         partition_tags, _ = self._partition_graph_module(exported_program.graph_module)
         return PartitionResult(
             tagged_exported_program=exported_program, partition_tags=partition_tags
diff --git a/exir/backend/test/test_to_backend_multi_method.py b/exir/backend/test/test_to_backend_multi_method.py
index d4f8fccb8f2..045de253e0f 100644
--- a/exir/backend/test/test_to_backend_multi_method.py
+++ b/exir/backend/test/test_to_backend_multi_method.py
@@ -392,6 +392,77 @@ def forward(self, x):
         }
         self._test(test_set)
 
+    def test_multi_method_to_backend_sequential_delegates(self):
+        class SequentialBackendModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y, z):
+                # delegate one
+                x = x - x
+                y = y - y
+                z = z - z
+                # graph break
+                a = x * y * z
+                # delegate two uses outputs from delegate one and the
+                # output from the graph break
+                b = x + a
+                b = b + z + a
+                b = b + y + a
+                return b
+
+        module = SequentialBackendModule()
+        example_inputs = (torch.ones(1), torch.ones(1), torch.ones(1))
+        seq_edgeir_m = to_edge(torch.export.export(module, example_inputs))
+
+        test_set = {
+            "seq_edgeir": (
+                seq_edgeir_m.exported_program(),
+                BackendWithPreprocessAllPartitioner(),
+                [
+                    "SecondBackendWithPreprocessAll#3#aten.sub.Tensor:aten.sub.Tensor:aten.sub.Tensor:#sub:b'\\x02';sub:b'\\x02';sub:b'\\x02';",
+                    "FirstBackendWithPreprocessAll#5#aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:#add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';",
+                ],
+            ),
+        }
+        self._test(test_set)
+
+    def test_multi_method_to_backend_constants(self):
+        class SequentialBackendModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.const = torch.zeros(1)
+
+            def forward(self, x, y, z):
+                # delegate one
+                x = x - x
+                y = y - y
+                z = z - z
+                # graph break
+                a = x * y * z * self.const
+                # delegate two uses outputs from delegate one and the
+                # output from the graph break
+                b = x + self.const + a
+                b = z + a + b
+                b = y + a + b
+                return b
+
+        module = SequentialBackendModule()
+        example_inputs = (torch.ones(1), torch.ones(1), torch.ones(1))
+        seq_const_m = to_edge(torch.export.export(module, example_inputs))
+
+        test_set = {
+            "seq_const": (
+                seq_const_m.exported_program(),
+                BackendWithPreprocessAllPartitioner(),
+                [
+                    "SecondBackendWithPreprocessAll#3#aten.sub.Tensor:aten.sub.Tensor:aten.sub.Tensor:#sub:b'\\x02';sub:b'\\x02';sub:b'\\x02';",
+                    "FirstBackendWithPreprocessAll#6#CONSTc_const_copy_0:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:#add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';",
+                ],
+            ),
+        }
+        self._test(test_set)
+
     def test_multi_method_to_backend_not_found(self):
         class SinModule(torch.nn.Module):
             def __init__(self):
diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py
index 78b031a238e..6792626d4ac 100644
--- a/exir/lowered_backend_module.py
+++ b/exir/lowered_backend_module.py
@@ -381,7 +381,7 @@ def _fixup_output_node(gm: torch.fx.GraphModule) -> None:
 
 
 def arrange_graph_placeholders(
-    gm: torch.fx.GraphModule, owning_program: ExportedProgram
+    gm: torch.fx.GraphModule, owning_program: ExportedProgram, tag
 ) -> torch.fx.GraphModule:
     """
     Modifies the graph of the given graphmodule with one that contains the same nodes as the original,
@@ -411,9 +411,15 @@ def arrange_graph_placeholders(
         if node.op != "placeholder":
             continue
 
-        if node.name in graph_sign.inputs_to_parameters:
+        if (
+            node.name in graph_sign.inputs_to_parameters
+            and node.meta.get("delegation_tag", None) == tag
+        ):
             param_nodes.append(node)
-        elif node.name in graph_sign.inputs_to_buffers:
+        elif (
+            node.name in graph_sign.inputs_to_buffers
+            and node.meta.get("delegation_tag", None) == tag
+        ):
             buffer_nodes.append(node)
         else:
             input_nodes.append(node)
@@ -694,7 +700,7 @@ def create_exported_program_from_submodule(
         removed from the toplevel ExportedProgram.
     """
     # Arrange the submodule's placeholders in order
-    submodule = arrange_graph_placeholders(submodule, owning_program)
+    submodule = arrange_graph_placeholders(submodule, owning_program, tag)
 
     # TODO: we probably need to arrange the outputs wrt buffer mutations.
 
@@ -958,5 +964,3 @@ def _unsafe_adjust_original_program(  # noqa: C901
                 if user_idx > idx:
                     user.args = (user.args[0], user_idx - (len(getitem_idxs) - i))
                     break
-
-    original_program._validate()