Commit 8bee6d6

internal developer aostrowski-hbn authored and committed
Return unsafe_view instead of view from matmul when folding occurs
When tensor folding occurs during a matmul operation, the returned tensor is a view. This can cause issues when matmul is used inside a custom function and such a view is then returned as output: it cannot be modified in place, which causes errors. This is especially problematic when an in-place allreduce is performed after such a function. The issue is resolved by returning unsafe_view from matmul instead. This aligns the matmul decomposition with the eager implementation so that a non-view tensor is returned.

Pull request opened to pytorch: pytorch#134568
Change-Id: I77484ff6f22d3e290352348b1acbffa267eb063b
1 parent d9ba83d commit 8bee6d6
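For context, here is a minimal sketch (not part of the commit) of the eager behaviour the commit message refers to. It uses the folded 3D @ 2D case; the expectation that the result is a non-view tensor follows from the description above and should be treated as an assumption rather than a guarantee across versions.

import torch

# 3D @ 2D triggers the tensor-folding path inside matmul.
inp1 = torch.randn(4, 1, 2, requires_grad=True)
inp2 = torch.randn(2, 4)

out = torch.matmul(inp1, inp2)
# In eager mode the folded result is materialized as a non-view tensor,
# so in-place updates on the output are allowed.
print(out._is_view())  # expected: False
out.add_(1.0)          # no view/in-place autograd error here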

File tree

2 files changed: +61 -2 lines changed
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
import torch

from torch.testing._internal.common_utils import TestCase, run_tests

class TestCustomFunction(TestCase):
    def test_autograd_function_with_matmul_folding_at_output(self):
        """
        When tensor folding occurs during a matmul operation, the returned tensor is a view.
        This can cause issues when matmul is used inside a custom function
        and such a view is then returned as output: it cannot be modified in place,
        which causes errors.
        This is especially problematic when an in-place allreduce is performed
        after such a function. This test recreates that behaviour.
        The issue is resolved when unsafe_view is returned from matmul instead.
        """

        class CustomFunction(torch.autograd.Function):

            @staticmethod
            def forward(ctx, inp1, inp2) -> torch.Tensor:
                ctx.save_for_backward(inp2)
                ctx.output_shape = inp1.size()
                return torch.matmul(inp1, inp2)

            @staticmethod
            def backward(ctx, grad_output) -> tuple[torch.Tensor, None]:
                output_shape = ctx.output_shape
                inp2, = ctx.saved_tensors
                return torch.mm(grad_output.squeeze(), inp2.t()).view(output_shape), None

        def outer_function(inp1, inp2) -> torch.Tensor:
            res = CustomFunction.apply(inp1, inp2)
            res.add_(1.0)
            return res.sum()

        def usual_function(inp1, inp2) -> torch.Tensor:
            res = torch.matmul(inp1, inp2)
            res.add_(1.0)
            return res.sum()

        inp1_custom = torch.randn(4, 1, 2, requires_grad=True)
        inp1_usual = inp1_custom.detach().clone().requires_grad_(True)

        inp2 = torch.randn(2, 4)
        c_custom_func = torch.compile(outer_function)
        c_usual_func = torch.compile(usual_function)

        result_custom = c_custom_func(inp1_custom, inp2)
        result_custom.backward()
        result_usual = c_usual_func(inp1_usual, inp2)
        result_usual.backward()

        self.assertTrue(torch.allclose(inp1_custom.grad, inp1_usual.grad))


if __name__ == "__main__":
    run_tests()

torch/_decomp/decompositions.py

Lines changed: 2 additions & 2 deletions
@@ -4361,10 +4361,10 @@ def matmul(tensor1, tensor2, *, is_out=False):
        if t2_is_matrix:
            # This copies if we perform a 2D @ 3D and the first tensor requires_grad
            # See should_fold native/LinearAlgebra.cpp for why.
-           output = t1_folded.mm(t2).view(output_shape)
+           output = torch.ops.aten._unsafe_view(t1_folded.mm(t2), output_shape)
            return output.mT.contiguous() if transpose else output
        else:
-           return t1_folded.mv(t2).view(output_shape)
+           return torch.ops.aten._unsafe_view(t1_folded.mv(t2), output_shape)

    elif dim_tensor1 >= 1 and dim_tensor2 >= 1:
        # We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list);
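For reference, a minimal sketch (not part of the commit) of what the two-line change above does, assuming torch.ops.aten._unsafe_view keeps its current semantics: the result has the requested shape and shares storage with its input, but is not recorded as an autograd view, so a later in-place op on a custom Function's output no longer trips the view restriction.

import torch

t1_folded = torch.randn(4, 2)   # stands in for the folded lhs in the decomposition
t2 = torch.randn(2, 4)
output_shape = (4, 1, 4)

viewed = t1_folded.mm(t2).view(output_shape)                          # old behaviour
unsafe = torch.ops.aten._unsafe_view(t1_folded.mm(t2), output_shape)  # new behaviour

print(viewed._is_view())  # True: tracked as a view of the mm result
print(unsafe._is_view())  # expected False: same data and shape, not tracked as a view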
