class RemoveCatFromSliceCopyPass(ExportPass):
    """
    Bypass a cat op that feeds a slice_copy reading from a single cat input.

    ``slice_copy(cat([a, b, ...], dim), dim, start, end)`` with step 1 is
    rewritten to slice directly from the one cat input that fully contains
    ``[start, end)``, with the indices shifted into that input's coordinates.
    The cat node itself is removed by dead code elimination once it has no
    users left.
    """

    @staticmethod
    def _find_covering_cat_input(
        cat_node: torch.fx.Node,
        dim: int,
        start: int,
        end: int,
        dtype: torch.dtype,
    ) -> "tuple[torch.fx.Node, int] | None":
        """
        Find the cat input that contains the whole range ``[start, end)``.

        Args:
            cat_node: the cat node whose inputs (``args[0]``) are scanned.
            dim: non-negative concatenation dimension.
            start: normalized slice start in cat-output coordinates.
            end: normalized slice end (exclusive) in cat-output coordinates.
            dtype: dtype the chosen input must have so the rewrite does not
                change the dtype of the slice result.

        Returns:
            ``(cat_input, offset)`` where ``offset`` is the position along
            `dim` at which `cat_input` starts inside the cat output, or None
            when no single input qualifies.
        """
        rank = len(cat_node.meta["val"].shape)
        offset = 0
        for cat_input in cat_node.args[0]:
            val = cat_input.meta["val"]
            if len(val.shape) != rank:
                # torch.cat also accepts 1-D empty tensors next to tensors of
                # another rank; their extent along `dim` is not indexable.
                return None
            extent = val.shape[dim]
            if offset <= start and end <= offset + extent and val.dtype == dtype:
                return cat_input, offset
            # Always advance, even when the dtype did not match, so the
            # offsets of the following inputs stay correct.
            offset += extent
        return None

    def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
        """
        Re-route every eligible slice_copy node from its cat producer to the
        cat input it actually reads from. Nodes that are not eligible are
        left untouched.
        """
        slice_copy_nodes = [
            node
            for node in graph_module.graph.nodes
            if node.target == exir_ops.edge.aten.slice_copy.Tensor
        ]
        # Positional defaults of (dim, start, end, step) for omitted args.
        defaults = (0, None, None, 1)
        for slice_copy_node in slice_copy_nodes:
            input_node, *other_args = slice_copy_node.args
            slice_dim, start_idx, end_idx, step = (
                *other_args[:4],
                *defaults[len(other_args) :],
            )
            if step != 1:
                continue
            if input_node.target != exir_ops.edge.aten.cat.default:
                continue
            # Only static integer indices can be compared against extents.
            if not isinstance(slice_dim, int):
                continue
            if not all(
                idx is None or isinstance(idx, int) for idx in (start_idx, end_idx)
            ):
                continue

            slice_copy_dtype = slice_copy_node.meta["val"].dtype
            if slice_copy_dtype != input_node.meta["val"].dtype:
                continue

            cat_output_shape = input_node.meta["val"].shape
            rank = len(cat_output_shape)
            # The cat dim is the optional second positional arg (default 0).
            cat_dim = input_node.args[1] if len(input_node.args) > 1 else 0
            if not isinstance(cat_dim, int):
                continue
            # Compare dims in non-negative form so that e.g. -1 matches rank-1.
            cat_dim = cat_dim + rank if cat_dim < 0 else cat_dim
            slice_dim = slice_dim + rank if slice_dim < 0 else slice_dim
            if cat_dim != slice_dim:
                continue

            # Normalize start/end the way slicing does: missing values cover
            # the full extent, negative values count from the end, and the
            # result is clamped to 0 <= start <= end <= dim_size.
            dim_size = cat_output_shape[cat_dim]
            start = 0 if start_idx is None else start_idx
            end = dim_size if end_idx is None else end_idx
            if start < 0:
                start += dim_size
            if end < 0:
                end += dim_size
            start = min(max(start, 0), dim_size)
            end = min(max(end, start), dim_size)

            found = self._find_covering_cat_input(
                input_node, cat_dim, start, end, slice_copy_dtype
            )
            if found is None:
                continue
            cat_input_to_keep, base_idx = found
            # Shift the indices by the kept input's offset: they were
            # expressed in cat-output coordinates, not in the input's own.
            slice_copy_node.args = (
                cat_input_to_keep,
                cat_dim,
                start - base_idx,
                end - base_idx,
                *other_args[3:],
            )

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Rewrite eligible slice_copy nodes, drop the cat nodes that became
        unused, and hand the cleaned-up module to the parent pass.
        """
        self._remove_unused_cat(graph_module)
        # Eliminate dead nodes first so the recompiled code no longer
        # contains the bypassed cat ops.
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return super().call(graph_module)
def test_remove_cat_from_slice_copy_all_removal(self) -> None:
    """Slice [0, 1) on dim 0 fits in the first cat input: no cat may remain."""

    class CatThenSlice(torch.nn.Module):
        def forward(self, x, y):
            # Two (2, 4) inputs concatenated on dim 0 give a (4, 4) tensor.
            joined = torch.cat((x, y), 0)
            return torch.slice_copy(joined, dim=0, start=0, end=1)

    sample_inputs = (torch.randn(2, 4), torch.randn(2, 4))
    edge_program = export_to_edge(CatThenSlice(), sample_inputs)
    original_gm = edge_program.exported_program().graph_module
    result = cast(PassResult, RemoveCatFromSliceCopyPass()(original_gm))

    # The only cat node must be gone after the pass.
    remaining_cats = count_node(
        result.graph_module, exir_ops.edge.aten.cat.default
    )
    self.assertEqual(remaining_cats, 0)


def test_remove_cat_from_slice_copy_no_removal(self) -> None:
    """Slice [0, 3) on dim 0 spans both cat inputs: the cat must be kept."""

    class CatThenSlice(torch.nn.Module):
        def forward(self, x, y):
            # Two (2, 4) inputs concatenated on dim 0 give a (4, 4) tensor.
            joined = torch.cat((x, y), 0)
            return torch.slice_copy(joined, dim=0, start=0, end=3)

    sample_inputs = (torch.randn(2, 4), torch.randn(2, 4))
    edge_program = export_to_edge(CatThenSlice(), sample_inputs)
    original_gm = edge_program.exported_program().graph_module
    result = cast(PassResult, RemoveCatFromSliceCopyPass()(original_gm))

    # The cat node is still needed, so exactly one must remain.
    remaining_cats = count_node(
        result.graph_module, exir_ops.edge.aten.cat.default
    )
    self.assertEqual(remaining_cats, 1)


def test_remove_cat_from_slice_copy_zero_range(self) -> None:
    """Empty slice [0, 0) on dim 0: the cat is expected to be removed."""

    class CatThenSlice(torch.nn.Module):
        def forward(self, x, y):
            # Two (2, 4) inputs concatenated on dim 0 give a (4, 4) tensor.
            joined = torch.cat((x, y), 0)
            return torch.slice_copy(joined, dim=0, start=0, end=0)

    sample_inputs = (torch.randn(2, 4), torch.randn(2, 4))
    edge_program = export_to_edge(CatThenSlice(), sample_inputs)
    original_gm = edge_program.exported_program().graph_module
    result = cast(PassResult, RemoveCatFromSliceCopyPass()(original_gm))

    # The only cat node must be gone after the pass.
    remaining_cats = count_node(
        result.graph_module, exir_ops.edge.aten.cat.default
    )
    self.assertEqual(remaining_cats, 0)