diff --git a/exir/backend/test/test_partitioner.py b/exir/backend/test/test_partitioner.py index 74974d16231..d492c291f34 100644 --- a/exir/backend/test/test_partitioner.py +++ b/exir/backend/test/test_partitioner.py @@ -26,7 +26,7 @@ from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import ( ExecutorBackend, ) -from executorch.exir.backend.utils import get_delegates +from executorch.exir.backend.utils import get_delegates, tag_constant_data from executorch.exir.dialects._ops import ops as exir_ops @@ -523,3 +523,85 @@ def partition( "constant data node (b_const) is tagged with (tag0) but has user (aten_sub_tensor) which has tag (None)", str(error.exception), ) + + def test_not_delegate_mutable_buffers(self) -> None: + """ + A test case to check the mutated buffer is not delegated. We'll need to add a test case + to consider when the delegate can consume the mutable buffer. + """ + + class MutableStateModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("my_state", torch.zeros(1)) + + def forward(self, x): + y = x + self.my_state + self.my_state.add_(1) + return y + + edge = exir.to_edge( + torch.export.export( + MutableStateModule(), + (torch.zeros(1),), + ) + ) + self.assertGreater( + len(edge.exported_program().graph_signature.buffers_to_mutate), + 0, + "The test case should at leaset one mutable buffer", + ) + + class PartitionerTagData(Partitioner): + def __init__(self): + super().__init__() + self.delegation_spec = DelegationSpec( + ExecutorBackend.__name__, + [CompileSpec(key, value) for key, value in self.spec.items()], + ) + + def partition( + self, edge_exported_program: ExportedProgram + ) -> PartitionResult: + partition_tags = {} + for node in edge_exported_program.graph.nodes: + if node.op == "call_function" and node.target in [ + exir_ops.edge.aten.add.Tensor + ]: + delegation_tag = "tag0" + node.meta["delegation_tag"] = delegation_tag + partition_tags[delegation_tag] = self.delegation_spec + tag_constant_data(edge_exported_program) + return PartitionResult( + tagged_exported_program=edge_exported_program, + partition_tags=partition_tags, + ) + + # Check the edge program inital buffers_to_mutate + mutate_op = "aten_add_tensor_1" + self.assertEqual( + edge.exported_program().graph_signature.buffers_to_mutate[mutate_op], + "my_state", + ) + edge = edge.to_backend(PartitionerTagData()) + # After to_backend, add is delegated and is no longer in buffers_to_mutate. + self.assertNotIn( + mutate_op, + edge.exported_program().graph_signature.buffers_to_mutate, + ) + + mutate_op = "getitem_1" + # Ensure the mutated buffer is not delegated, and the new mutate node is getitem (from call_delegate) + self.assertEqual( + edge.exported_program().graph_signature.buffers_to_mutate[mutate_op], + "my_state", + ) + # Check the copy_ node is inserted + edge = edge.to_executorch() + copy_node = [ + node + for node in edge.exported_program().graph.nodes + if node.op == "call_function" + and node.target == torch.ops.aten.copy_.default + ] + self.assertEqual(len(copy_node), 1) diff --git a/exir/backend/utils.py b/exir/backend/utils.py index f4c1c28f8bd..b299ba4be8a 100644 --- a/exir/backend/utils.py +++ b/exir/backend/utils.py @@ -508,6 +508,20 @@ def tag_constant_data(edge_program: ExportedProgram) -> None: subgraph. Throw error when const/param/buffers is used across different partitions. That is the underlying data will be owned by multiple delegates. """ + mutated_buffer = set() + for node in edge_program.graph.nodes: + if node.op == "placeholder" and ( + is_param(edge_program, node) + or is_buffer(edge_program, node) + or is_lifted_tensor_constant(edge_program, node) + ): + for node_user in node.users: + if node_user.name in edge_program.graph_signature.buffers_to_mutate: + logging.info( + "The buffer node is a mutated buffer node, which is not constant." + ) + mutated_buffer.add(node) + for node in edge_program.graph.nodes: # go through const/param/buffer nodes, if all users of const/param/buffer nodes are partitioned then partition if node.op == "placeholder" and ( @@ -515,20 +529,21 @@ def tag_constant_data(edge_program: ExportedProgram) -> None: or is_buffer(edge_program, node) or is_lifted_tensor_constant(edge_program, node) ): - user_tags = set() - for user in node.users: - user_tag = user.meta.get("delegation_tag", None) - if user_tag is not None: - user_tags.add(user_tag) - if len(user_tags) > 1: - logging.info( - f"The data node is used across multiple partitions, including {user_tags}. " - "If the data is too large and it's not preferred to copy, please tag the " - "constant node like node.['no_copy'] = True and they won't be copied." - ) - # tag the data node with the same tag as the last user - if len(user_tags) > 0: - node.meta["delegation_tag"] = user_tags.pop() + if node not in mutated_buffer: + user_tags = set() + for user in node.users: + user_tag = user.meta.get("delegation_tag", None) + if user_tag is not None: + user_tags.add(user_tag) + if len(user_tags) > 1: + logging.info( + f"The data node is used across multiple partitions, including {user_tags}. " + "If the data is too large and it's not preferred to copy, please tag the " + "constant node like node.['no_copy'] = True and they won't be copied." + ) + # tag the data node with the same tag as the last user + if len(user_tags) > 0: + node.meta["delegation_tag"] = user_tags.pop() # TODO - style: use templated types