Skip to content

Commit 0fdc8df

Browse files
authored
Generalize view_copy fusion.
Differential Revision: D73443870 Pull Request resolved: #10356
1 parent ad0f610 commit 0fdc8df

File tree

2 files changed

+25
-25
lines changed

2 files changed

+25
-25
lines changed

backends/cadence/aot/fuse_ops.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -526,34 +526,14 @@ class FuseCascadedViewOps(ExportPass):
526526
Fuse a cascaded chain of view ops
527527
"""
528528

529-
# Find a chain of view ops, and fuse them into a single permute op.
530-
531529
def fuse_cascaded_view_ops(self, graph_module: torch.fx.GraphModule):
532-
graph = graph_module.graph
533-
for node in graph.nodes:
534-
# We are only interested in view ops
535-
if node.target != exir_ops.edge.aten.view_copy.default:
536-
continue
537-
538-
# Get the cascaded chain of view ops starting at node
539-
cascaded_view_ops = get_cascaded_ops(
540-
[node], [exir_ops.edge.aten.view_copy.default]
541-
)
542-
# The chain must have more than 1 node
543-
if len(cascaded_view_ops) == 1:
530+
view_target = exir_ops.edge.aten.view_copy.default
531+
for view_node in graph_module.graph.find_nodes(op="call_function", target=view_target, sort=True):
532+
input_view = view_node.args[0]
533+
if input_view.op != "call_function" or input_view.target != view_target:
544534
continue
545535

546-
last_view_node = cascaded_view_ops[-1]
547-
with graph.inserting_before(last_view_node):
548-
new_view = graph.call_function(
549-
exir_ops.edge.aten.view_copy.default,
550-
args=(node.args[0], last_view_node.args[1]),
551-
)
552-
last_view_node.replace_all_uses_with(new_view)
553-
554-
# Now erase the chain
555-
for v in reversed(cascaded_view_ops):
556-
graph.erase_node(v)
536+
view_node.replace_input_with(input_view, input_view.args[0])
557537

558538
graph_module.recompile()
559539

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,26 @@ def forward(self, x):
222222
count_node(graph_module, exir_ops.edge.aten.view_copy.default), 1
223223
)
224224

225+
def test_view_fusion_branched(self):
226+
class ViewFusion(torch.nn.Module):
227+
def forward(self, x):
228+
y = x.view([1, 8, 15])
229+
z = y.view([1, 1, 120])
230+
t = y.view([120, 1, 1])
231+
return z, t
232+
233+
x = torch.randn(8, 5, 3)
234+
graph_module = (
235+
compiler.export_to_cadence(ViewFusion(), (x,))
236+
.exported_program()
237+
.graph_module
238+
)
239+
graph_module.graph.eliminate_dead_code()
240+
# z and t should be fused and y should be eliminated.
241+
self.assertEqual(
242+
count_node(graph_module, exir_ops.edge.aten.view_copy.default), 2
243+
)
244+
225245
def test_force_quant_dequant_fusion(self):
226246
class M(torch.nn.Module):
227247
def __init__(self):

0 commit comments

Comments
 (0)