Skip to content

Commit df4f8a0

Browse files
zonglinpengfacebook-github-bot
authored andcommitted
add pass to remove cat from slice pass (#8857)
Summary: only keep the cat before slice iff the slice range overlaps with *both* tensors in cat. TODO: trace to 1+ level for more cat in a chain. TODO: support >2 tensors in a cat. Differential Revision: D70425971
1 parent e673f7c commit df4f8a0

File tree

2 files changed

+109
-0
lines changed

2 files changed

+109
-0
lines changed

backends/cadence/aot/remove_ops.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -807,6 +807,63 @@ def remove_branched(
807807
user.replace_all_uses_with(node.args[0])
808808

809809

810+
class RemoveCatFromSliceCopyPass(ExportPass):
    """
    Bypass a cat that feeds a slice_copy when the slice reads only one cat input.

    For every `slice_copy(cat([t0, t1, ...], cat_dim), dim, start, end, step)`
    with `step == 1` and `dim == cat_dim`, if the sliced range `[start, end)`
    lies entirely inside a single cat input, the slice_copy is rewired to read
    from that input directly, with its indices shifted into that input's own
    coordinates. A cat left without users is removed by dead code elimination.
    """

    def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
        """
        Rewire every eligible slice_copy in `graph_module` past its cat producer.

        The graph is mutated in place. A slice_copy is left untouched when its
        step is not 1, its input is not a cat, dtypes differ, the slice and cat
        dims differ, any involved index or extent is not a plain int, or the
        sliced range spans more than one cat input.
        """
        slice_copy_nodes = [
            node
            for node in graph_module.graph.nodes
            if node.target == exir_ops.edge.aten.slice_copy.Tensor
        ]
        for slice_copy_node in slice_copy_nodes:
            # Values used when trailing arguments are omitted. `None` for
            # start/end means "from the beginning" / "to the end".
            slice_dim, start_idx, end_idx, step = 0, None, None, 1
            input_node, *other_args = slice_copy_node.args
            if len(other_args) >= 1:
                slice_dim = other_args[0]
            if len(other_args) >= 2:
                start_idx = other_args[1]
            if len(other_args) >= 3:
                end_idx = other_args[2]
            if len(other_args) >= 4:
                step = other_args[3]
            if step != 1:
                continue
            if input_node.target != exir_ops.edge.aten.cat.default:
                continue

            slice_copy_dtype = slice_copy_node.meta["val"].dtype
            cat_val = input_node.meta["val"]
            if slice_copy_dtype != cat_val.dtype:
                continue

            # The cat dim is the second positional argument (a scalar), not
            # the tail of the argument tuple.
            cat_dim = input_node.args[1] if len(input_node.args) > 1 else 0
            if not isinstance(cat_dim, int) or not isinstance(slice_dim, int):
                continue
            cat_output_shape = cat_val.shape
            rank = len(cat_output_shape)
            # Normalize negative dims so that e.g. -1 and rank - 1 compare equal.
            if cat_dim < 0:
                cat_dim += rank
            if slice_dim < 0:
                slice_dim += rank
            if cat_dim != slice_dim:
                continue

            cat_len = cat_output_shape[cat_dim]
            if not isinstance(cat_len, int):
                # Non-static extent: the range cannot be reasoned about here.
                continue
            if start_idx is None:
                start_idx = 0
            if end_idx is None:
                end_idx = cat_len
            if not isinstance(start_idx, int) or not isinstance(end_idx, int):
                continue

            # Resolve negative indices against the cat output, then clamp so
            # that 0 <= start_idx <= end_idx <= cat_len.
            if start_idx < 0:
                start_idx += cat_len
            if end_idx < 0:
                end_idx += cat_len
            start_idx = min(max(start_idx, 0), cat_len)
            end_idx = min(max(end_idx, start_idx), cat_len)

            # Walk the cat inputs, tracking the offset (`base_idx`) at which
            # each one starts inside the cat output.
            base_idx = 0
            cat_input_to_keep = None
            for cat_input_node in input_node.args[0]:
                cat_input_val = cat_input_node.meta["val"]
                if len(cat_input_val.shape) != rank:
                    # Input rank differs from the output rank; offsets along
                    # `cat_dim` cannot be computed reliably, so give up.
                    break
                cat_input_len = cat_input_val.shape[cat_dim]
                if not isinstance(cat_input_len, int):
                    break
                # The slice range must be fully contained in this input, and
                # reading from it directly must not change the result dtype.
                if (
                    base_idx <= start_idx
                    and end_idx <= base_idx + cat_input_len
                    and cat_input_val.dtype == slice_copy_dtype
                ):
                    cat_input_to_keep = cat_input_node
                    break
                # Always advance the offset, even when this input is rejected.
                base_idx += cat_input_len

            if cat_input_to_keep is None:
                continue

            # Re-express the range relative to the kept input; the original
            # indices were relative to the concatenated tensor.
            slice_copy_node.args = (
                cat_input_to_keep,
                slice_dim,
                start_idx - base_idx,
                end_idx - base_idx,
                1,
            )

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Rewire slices past cats, drop cats left without users, then run the base pass."""
        self._remove_unused_cat(graph_module)
        # Eliminate dead nodes before recompiling so the generated code
        # reflects the final graph.
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return super().call(graph_module)
865+
866+
810867
# The following class consolidates functions to remove ops that are redundant
811868
# in Jarvis. Currently, each function in this class iterates over each node of
812869
# the graph module once. In future, we could consolidate them into a monolithic

backends/cadence/aot/tests/test_remove_ops_passes.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from executorch.backends.cadence.aot.remove_ops import (
2323
RemoveAliasCopyOpPass,
2424
RemoveBranchedQuantDequant,
25+
RemoveCatFromSliceCopyPass,
2526
RemoveCloneOpPass,
2627
RemoveContiguousOpPass,
2728
RemoveDetachCopyPass,
@@ -741,3 +742,54 @@ def forward(self, x):
741742
},
742743
)
743744
)
745+
746+
def test_remove_cat_from_slice_copy_all_removal(self) -> None:
    """A slice over rows [0, 1) lies inside the first cat input, so no cat should remain."""

    class M(torch.nn.Module):
        def forward(self, x, y):
            # Two (2, 4) inputs concatenated along dim 0 give a (4, 4) tensor.
            concatenated = torch.cat((x, y), 0)
            return torch.slice_copy(concatenated, dim=0, start=0, end=1)

    example_inputs = tuple(torch.randn(2, 4) for _ in range(2))
    edge_program = export_to_edge(M(), example_inputs)
    original_graph_module = edge_program.exported_program().graph_module
    pass_result = cast(
        PassResult, RemoveCatFromSliceCopyPass()(original_graph_module)
    )
    transformed = pass_result.graph_module

    # The cat must no longer be present in the graph.
    remaining_cats = count_node(transformed, exir_ops.edge.aten.cat.default)
    self.assertEqual(remaining_cats, 0)
762+
763+
def test_remove_cat_from_slice_copy_no_removal(self) -> None:
    """A slice over rows [0, 3) reads from both cat inputs, so the cat must be kept."""

    class M(torch.nn.Module):
        def forward(self, x, y):
            # Two (2, 4) inputs concatenated along dim 0 give a (4, 4) tensor.
            concatenated = torch.cat((x, y), 0)
            return torch.slice_copy(concatenated, dim=0, start=0, end=3)

    example_inputs = tuple(torch.randn(2, 4) for _ in range(2))
    edge_program = export_to_edge(M(), example_inputs)
    original_graph_module = edge_program.exported_program().graph_module
    pass_result = cast(
        PassResult, RemoveCatFromSliceCopyPass()(original_graph_module)
    )
    transformed = pass_result.graph_module

    # The cat is still needed and must survive the pass.
    remaining_cats = count_node(transformed, exir_ops.edge.aten.cat.default)
    self.assertEqual(remaining_cats, 1)
779+
780+
def test_remove_cat_from_slice_copy_zero_range(self) -> None:
    """An empty slice (start == end == 0) should also let the cat be removed."""

    class M(torch.nn.Module):
        def forward(self, x, y):
            # Two (2, 4) inputs concatenated along dim 0 give a (4, 4) tensor.
            concatenated = torch.cat((x, y), 0)
            return torch.slice_copy(concatenated, dim=0, start=0, end=0)

    example_inputs = tuple(torch.randn(2, 4) for _ in range(2))
    edge_program = export_to_edge(M(), example_inputs)
    original_graph_module = edge_program.exported_program().graph_module
    pass_result = cast(
        PassResult, RemoveCatFromSliceCopyPass()(original_graph_module)
    )
    transformed = pass_result.graph_module

    # The cat must no longer be present in the graph.
    remaining_cats = count_node(transformed, exir_ops.edge.aten.cat.default)
    self.assertEqual(remaining_cats, 0)

0 commit comments

Comments
 (0)