Skip to content

Commit a929e11

Browse files
pianpwkpytorchmergebot
authored andcommitted
[dynamic shapes][export] ignore when real-tensor fallback fails (pytorch#147779)
Summary: uninspired solution to pytorch#147402 Test Plan: test_draft_export Differential Revision: D70132269 Pull Request resolved: pytorch#147779 Approved by: https://github.com/bobrenjc93
1 parent 0929181 commit a929e11

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

test/export/test_draft_export.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,21 @@ def forward(self, x):
386386
self.assertTrue(report.successful())
387387
self.assertEqual(inp[0], torch.ones(3, 3))
388388

389+
def test_masked_linear(self):
390+
class M(torch.nn.Module):
391+
def forward(self, x, mask, weight, bias):
392+
masked = x[mask != 0, :, :]
393+
return torch.nn.functional.linear(masked, weight, bias)
394+
395+
x = torch.zeros(10)
396+
inp = (torch.randn(10, 8, 7), x, torch.randn(25, 7), torch.randn(25))
397+
draft_ep = draft_export(M(), inp)
398+
ep = export(M(), inp)
399+
self.assertEqual(draft_ep.module()(*inp), ep.module()(*inp))
400+
x[2] += 1
401+
x[3] += 1
402+
self.assertEqual(draft_ep.module()(*inp), ep.module()(*inp))
403+
389404
def test_torchbind(self):
390405
class Model(torch.nn.Module):
391406
def __init__(self) -> None:

torch/_subclasses/fake_tensor.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2201,7 +2201,17 @@ def maybe_to_real_tensor(
22012201
func, real_flat_args, args_spec
22022202
)
22032203

2204-
real_out = func(*real_args, **real_kwargs)
2204+
try:
2205+
real_out = func(*real_args, **real_kwargs)
2206+
except ZeroDivisionError as exc:
2207+
# we shouldn't broadly catch all errors here;
2208+
# some come from real-kernel mutation/aliasing checks we want to run.
2209+
# add more exception types as needed.
2210+
log.debug(
2211+
"real-tensor fallback failed for %s: %s; silently ignoring",
2212+
func,
2213+
exc,
2214+
)
22052215

22062216
if not is_builtin:
22072217
mutation_checker.check() # type: ignore[possibly-undefined]

0 commit comments

Comments
 (0)