Skip to content

Commit 48286d3

Browse files
Revert "Break graph on manual_seed. (pytorch#107594)"
This reverts commit 6ad5568. Reverted pytorch#107594 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but it has an import issue that breaks internal code ([comment](pytorch#107594 (comment)))
1 parent e08577a commit 48286d3

File tree

6 files changed

+25
-34
lines changed

6 files changed

+25
-34
lines changed

test/dynamo/test_functions.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,15 +1146,6 @@ def test_numpy_random():
11461146
x = np.random.randn(2, 2)
11471147
return x - x
11481148

1149-
def test_manual_seed(self):
1150-
@torch.compile
1151-
def foo():
1152-
torch.manual_seed(3)
1153-
return torch.randint(0, 5, (5,))
1154-
1155-
self.assertEqual(foo(), foo())
1156-
self.assertEqual(foo(), foo())
1157-
11581149

11591150
def global_func_with_default_tensor_args(
11601151
x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2))

test/dynamo/test_misc.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1813,6 +1813,14 @@ def fn(x, obj):
18131813
res = opt_fn(x, obj)
18141814
self.assertTrue(same(ref, res))
18151815

1816+
def test_manual_seed(self):
1817+
def fn(a, b):
1818+
x = a + b
1819+
torch.manual_seed(9000)
1820+
return x + 1
1821+
1822+
torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3)
1823+
18161824
def test_usr_cls_staticmethod(self):
18171825
class Foo:
18181826
@staticmethod
@@ -2259,17 +2267,13 @@ def fn(x):
22592267
torch.manual_seed(attention_seed)
22602268
return (x,)
22612269

2262-
x = torch.randn(10, requires_grad=True)
2270+
x = torch.randn(100, requires_grad=True)
22632271
ref = fn(x)
22642272

2265-
# Python code is needed here, since torch.manual_seed graph-breaks.
2266-
# Refs: https://github.com/pytorch/pytorch/issues/107187
2267-
opt_fn = torch._dynamo.optimize(cnts, nopython=False)(fn)
2273+
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
22682274
res = opt_fn(x)
22692275

22702276
self.assertTrue(same(ref, res))
2271-
self.assertEqual(cnts.op_count, 1)
2272-
self.assertEqual(cnts.frame_count, 1)
22732277

22742278
def test_is_tensor_like(self):
22752279
cnts = torch._dynamo.testing.CompileCounter()

test/inductor/test_torchinductor_opinfo.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ def format_op(op):
262262
"nn.functional.instance_norm": {f16},
263263
"nn.functional.local_response_norm": {f16},
264264
"nn.functional.normalize": {f16},
265+
"nn.functional.rrelu": {f16, f32, f64},
265266
"nn.functional.soft_margin_loss": {f16},
266267
"nn.functional.softsign": {f16},
267268
"nn.functional.triplet_margin_loss": {f16},
@@ -277,6 +278,7 @@ def format_op(op):
277278
"sparse.sampled_addmm": {f32, f64},
278279
("std_mean", "unbiased"): {f16},
279280
"to_sparse": {f16, f32, f64},
281+
"uniform": {f16, f32, f64},
280282
}
281283

282284

@@ -340,13 +342,15 @@ def get_skips_and_xfails(from_dict, xfails=True):
340342
)
341343

342344

343-
def wrapper_noop_set_seed(op, *args, **kwargs):
345+
def wrapper_set_seed(op, *args, **kwargs):
346+
"""Wrapper to set seed manually for some functions like dropout
347+
See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details.
348+
"""
349+
torch.manual_seed(42)
344350
return op(*args, **kwargs)
345351

346352

347-
torch.testing._internal.common_methods_invocations.wrapper_set_seed = (
348-
wrapper_noop_set_seed
349-
)
353+
torch.testing._internal.common_methods_invocations.wrapper_set_seed = wrapper_set_seed
350354

351355
# This file does a global patch to `disable_global_flags()` - which we should not invoke in non testing cases.
352356
torch._dynamo.variables.torch.tensor_dunder_fns.append(

torch/__init__.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,17 +1303,6 @@ def _dtype(self):
13031303
# The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
13041304
_tensor_classes: Set[Type] = set()
13051305

1306-
################################################################################
1307-
# Import TorchDynamo's lazy APIs to avoid circular dependenices
1308-
################################################################################
1309-
1310-
# needs to be before from .functional import * to avoid circular dependencies
1311-
from ._compile import _disable_dynamo
1312-
1313-
################################################################################
1314-
# Import miscelaneous torch functions
1315-
################################################################################
1316-
13171306
# If you edit these imports, please update torch/__init__.py.in as well
13181307
from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
13191308
from .serialization import save, load
@@ -1378,6 +1367,13 @@ def manager_path():
13781367

13791368

13801369

1370+
################################################################################
1371+
# Import TorchDynamo's lazy APIs to avoid circular dependenices
1372+
################################################################################
1373+
1374+
# needs to be before from .functional import * to avoid circular dependencies
1375+
from ._compile import _disable_dynamo
1376+
13811377
################################################################################
13821378
# Import interface functions defined in Python
13831379
################################################################################

torch/_dynamo/variables/torch.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -453,9 +453,6 @@ def call_function(
453453
elif self.value is torch.nn.Parameter:
454454
# https://github.com/pytorch/pytorch/issues/99569
455455
unimplemented("torch.nn.Parameter not supported")
456-
elif self.value is torch.manual_seed:
457-
# https://github.com/pytorch/pytorch/issues/107187
458-
unimplemented("torch.manual_seed not supported")
459456
if (
460457
self.value.__name__ == "get_state"
461458
and hasattr(self.value, "__self__")

torch/random.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ def get_rng_state() -> torch.Tensor:
2323
return default_generator.get_state()
2424

2525

26-
@torch._disable_dynamo
2726
def manual_seed(seed) -> torch._C.Generator:
2827
r"""Sets the seed for generating random numbers. Returns a
2928
`torch.Generator` object.

0 commit comments

Comments
 (0)