
Commit 6747bf9

internal developer mryszt authored and committed
Return unsafe_view instead of view from matmul when folding occurs
When tensor folding occurs during a matmul operation, the returned tensor is a view. This can cause issues when matmul is used inside a custom function and such a view is then returned as output: the output cannot be modified in place, which raises errors. This is especially problematic when an in-place allreduce is performed after such a function. The issue is resolved by returning _unsafe_view from matmul instead. This aligns the matmul decomposition with the eager implementation, so that a non-view tensor is returned. Pull request opened to pytorch: pytorch#134568 Change-Id: I77484ff6f22d3e290352348b1acbffa267eb063b
1 parent 2735488 commit 6747bf9
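
Background for the fix: torch.ops.aten._unsafe_view produces a tensor with the same data and shape as a regular view, but it is not registered with autograd as a view of its source, so in-place updates on the result are not subject to the restrictions that apply to view outputs. A minimal standalone sketch of the distinction (not part of the commit; _is_view and _base are internal introspection attributes used here only for illustration):

import torch

t = torch.randn(6, 4, requires_grad=True)

v = t.view(3, 2, 4)                            # ordinary view: tracked as a view of t
u = torch.ops.aten._unsafe_view(t, (3, 2, 4))  # same data and shape, but not recorded as a view

print(v._is_view(), v._base is t)     # expected: True True
print(u._is_view(), u._base is None)  # expected: False True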

File tree

2 files changed: +61 −2 lines changed

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
import torch

from torch.testing._internal.common_utils import TestCase, run_tests


class TestCustomFunction(TestCase):
    def test_autograd_function_with_matmul_folding_at_output(self):
        """
        When tensor folding occurs during a matmul operation, the returned tensor is a view.
        This can cause issues when matmul is used inside a custom function
        and such a view is then returned as output: it cannot be modified in place
        and causes errors.
        It is especially problematic when an in-place allreduce is performed
        after such a function. This test recreates this behaviour.
        The issue is resolved when _unsafe_view is returned from matmul instead.
        """

        class CustomFunction(torch.autograd.Function):

            @staticmethod
            def forward(ctx, inp1, inp2) -> torch.Tensor:
                ctx.save_for_backward(inp2)
                ctx.output_shape = inp1.size()
                return torch.matmul(inp1, inp2)

            @staticmethod
            def backward(ctx, grad_output) -> tuple[torch.Tensor, None]:
                output_shape = ctx.output_shape
                (inp2,) = ctx.saved_tensors
                return torch.mm(grad_output.squeeze(), inp2.t()).view(output_shape), None

        def outer_function(inp1, inp2) -> torch.Tensor:
            res = CustomFunction.apply(inp1, inp2)
            res.add_(1.0)
            return res.sum()

        def usual_function(inp1, inp2) -> torch.Tensor:
            res = torch.matmul(inp1, inp2)
            res.add_(1.0)
            return res.sum()

        inp1_custom = torch.randn(4, 1, 2, requires_grad=True)
        inp1_usual = inp1_custom.detach().clone().requires_grad_(True)

        inp2 = torch.randn(2, 4)
        c_custom_func = torch.compile(outer_function)
        c_usual_func = torch.compile(usual_function)

        result_custom = c_custom_func(inp1_custom, inp2)
        result_custom.backward()
        result_usual = c_usual_func(inp1_usual, inp2)
        result_usual.backward()

        self.assertTrue(torch.allclose(inp1_custom.grad, inp1_usual.grad))


if __name__ == "__main__":
    run_tests()

torch/_decomp/decompositions.py

Lines changed: 2 additions & 2 deletions
@@ -4217,10 +4217,10 @@ def matmul(tensor1, tensor2, *, is_out=False):
         if t2_is_matrix:
             # This copies if we perform a 2D @ 3D and the first tensor requires_grad
             # See should_fold native/LinearAlgebra.cpp for why.
-            output = t1_folded.mm(t2).view(output_shape)
+            output = torch.ops.aten._unsafe_view(t1_folded.mm(t2), output_shape)
             return output.mT.contiguous() if transpose else output
         else:
-            return t1_folded.mv(t2).view(output_shape)
+            return torch.ops.aten._unsafe_view(t1_folded.mv(t2), output_shape)
 
     elif dim_tensor1 >= 1 and dim_tensor2 >= 1:
         # We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list);
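
For readers unfamiliar with this folding path, here is a small standalone sketch (not part of the commit; shapes chosen to match the test above) of what the decomposition does for a 3D @ 2D matmul: the batch dimensions of the left operand are folded into the row dimension, a single mm is performed, and the result is reshaped back with _unsafe_view so that the returned tensor is not tracked as a view.

import torch

inp1 = torch.randn(4, 1, 2)  # batched left operand (3D)
inp2 = torch.randn(2, 4)     # plain matrix (2D)

# Fold the batch dimensions of inp1 into the row dimension and do one mm,
# then restore the broadcast output shape without registering a view.
t1_folded = inp1.reshape(-1, inp1.shape[-1])                      # shape (4, 2)
out = torch.ops.aten._unsafe_view(t1_folded.mm(inp2), (4, 1, 4))  # shape (4, 1, 4)

print(torch.allclose(out, torch.matmul(inp1, inp2)))  # expected: True
print(out._is_view())                                 # expected: False (not tracked as a view)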
