Skip to content

Commit b64182d

Browse files
dulinriley authored and facebook-github-bot committed
Fix memory.view insertion except for output nodes (#3602)
Summary: Pull Request resolved: #3602 The previous implementation of ignoring `view_copy` on outputs was incorrect in that it only checked `node.next` instead of all users of the node. `node.next` just selects the next node in topological order, which may or may not be the output if there is more than one output. In the case of more than one output, the next node may not be related at all! Check if any of the users of the node are an output instead. Reviewed By: metascroy, mcremon-meta Differential Revision: D57299853 fbshipit-source-id: 6a373181f6bdd58444e0c859fce320d576b7f749
1 parent 01ce72c commit b64182d

File tree

3 files changed

+24
-12
lines changed

3 files changed

+24
-12
lines changed

exir/passes/replace_view_copy_with_view_pass.py

+5-3
Original file line number | Diff line number | Diff line change
@@ -275,7 +275,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
275275
for node in module.graph.nodes:
276276
# Note: We only replace view_copy nodes that are not output, since
277277
# the output pointer could be modified at runtime (T187925929)
278-
if _is_view_copy(node) and node.next.op != "output":
278+
if _is_view_copy(node) and all(u.op != "output" for u in node.users):
279279
base, _ = node.args
280280
node.target = _VIEW_OP
281281

@@ -302,7 +302,9 @@ def ensures(self, graph_module: torch.fx.GraphModule) -> None:
302302
for node in module.graph.nodes:
303303
# Note: We only replace view_copy nodes that are not output, since
304304
# the output pointer could be modified at runtime (T187925929)
305-
assert not (_is_view_copy(node) and node.next.op != "output")
305+
assert not (
306+
_is_view_copy(node) and all(u.op != "output" for u in node.users)
307+
)
306308
if node.op == "call_function" and node.target == _VIEW_OP:
307309
assert isinstance(node.meta["spec"], _ViewSpec)
308310

@@ -317,6 +319,6 @@ def requires(self, graph_module: torch.fx.GraphModule) -> None:
317319
for node in module.graph.nodes:
318320
# Note: We only replace view_copy nodes that are not output, since
319321
# the output pointer could be modified at runtime (T187925929)
320-
if _is_view_copy(node) and node.next.op != "output":
322+
if _is_view_copy(node) and all(u.op != "output" for u in node.users):
321323
base, size = node.args
322324
assert not _is_view_copy(base)

exir/tests/test_passes.py

+5-4
Original file line number | Diff line number | Diff line change
@@ -1602,7 +1602,9 @@ def __init__(self):
16021602
def forward(self, x):
16031603
o1 = torch.ops.aten.view_copy.default(x, [1])
16041604
o2 = torch.ops.aten.view_copy.default(self.parameter, [1])
1605-
return o1, o2
1605+
# view_copys at the end of a function are not replaced, so add
1606+
# a computation before the end of the graph.
1607+
return torch.ops.aten.add.Tensor(o1, o2)
16061608

16071609
ep = torch.export.export(
16081610
TestViewCopies(),
@@ -1631,10 +1633,9 @@ def forward(self, x):
16311633
gm = gm_res.graph_module
16321634

16331635
# Check after transformation
1634-
# Note: one view copy is not replaced, because it's the output of the graph
16351636
FileCheck().check_count(
1636-
"torch.ops.aten.view_copy.default", 1, exactly=True
1637+
"torch.ops.aten.view_copy.default", 0, exactly=True
16371638
).run(gm.code)
1638-
FileCheck().check_count("executorch_exir_memory_view", 1, exactly=True).run(
1639+
FileCheck().check_count("executorch_exir_memory_view", 2, exactly=True).run(
16391640
gm.code
16401641
)

exir/tests/test_remove_view_copy.py

+14-5
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,8 @@ def forward(self, x):
3232
) # removed, lifetime of mul.Tensor will be extended
3333
v4 = torch.ops.aten.mul.Tensor(v3, self.parameter2)
3434
v5 = v4.view(6, 5) # not removed, output of the graph
35-
return v5
35+
v6 = v4.view(2, 15) # not removed, output of the graph
36+
return v5, v6
3637

3738
def get_example_inputs(self):
3839
return (torch.rand(5, 6),)
@@ -87,10 +88,15 @@ def test_output_matches(self) -> None:
8788
),
8889
)
8990

90-
out_remove = etpm_remove.exported_program().module()(*example_inputs)
91-
out_no_remove = etpm_no_remove.exported_program().module()(*example_inputs)
91+
out_remove_v5, out_remove_v6 = etpm_remove.exported_program().module()(
92+
*example_inputs
93+
)
94+
out_no_remove_v5, out_no_remove_v6 = etpm_no_remove.exported_program().module()(
95+
*example_inputs
96+
)
9297

93-
self.assertTrue(torch.allclose(out_remove, out_no_remove))
98+
self.assertTrue(torch.allclose(out_remove_v5, out_no_remove_v5))
99+
self.assertTrue(torch.allclose(out_remove_v6, out_no_remove_v6))
94100

95101
def test_spec(self) -> None:
96102
model = TestModel1()
@@ -196,7 +202,7 @@ def test_spec(self) -> None:
196202
self.assertEqual(plan.operators[2].name, "aten::view_copy")
197203

198204
instructions = plan.chains[0].instructions
199-
self.assertEqual(len(instructions), 6)
205+
self.assertEqual(len(instructions), 7)
200206

201207
self.assertEqual(
202208
instructions[0].instr_args.op_index, 0 # pyre-ignore
@@ -216,3 +222,6 @@ def test_spec(self) -> None:
216222
self.assertEqual(
217223
instructions[5].instr_args.op_index, 2 # pyre-ignore
218224
) # aten:view_copy @ idx11
225+
self.assertEqual(
226+
instructions[6].instr_args.op_index, 2 # pyre-ignore
227+
) # aten:view_copy @ idx11

0 commit comments

Comments
 (0)