Commit 742aa69

Fix dynamo tracing into AOTAutogradCache results
ghstack-source-id: 57ec57c
Pull Request resolved: #155251
1 parent d2a2bfc commit 742aa69

File tree: 3 files changed (+60, -2 lines)
test/dynamo/test_aot_autograd_cache.py

Lines changed: 39 additions & 0 deletions
@@ -1,5 +1,6 @@
 # Owner(s): ["module: dynamo"]
 
+import copy
 import os
 import shutil
 import unittest
@@ -822,6 +823,44 @@ def fn(a, b):
         self.assertEqual(a.grad, a2.grad)
         self.assertEqual(b.grad, b2.grad)
 
+    @inductor_config.patch("fx_graph_remote_cache", False)
+    @inductor_config.patch({"fx_graph_cache": True})
+    @functorch_config.patch({"enable_autograd_cache": True})
+    @functorch_config.patch({"strict_autograd_cache": True})
+    def test_autograd_no_dynamo_trace_backward(self):
+        """
+        Test that dynamo does not trace into the backward compiled function,
+        even on cache hit.
+        """
+        torch._dynamo.eval_frame.clear_dynamo_tls()
+
+        @torch.compile
+        def fn(x):
+            # Calls x.sum().backward() during forward execution of fn
+            (x_grad,) = torch.autograd.grad(x.sum(), x)
+            return x_grad
+
+        a = torch.randn(10, 10, requires_grad=True, device="cpu")
+        result = fn(a)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        # Backward of `sum` will run during execution of graph break
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+        traced_frame_infos = copy.deepcopy(
+            torch._dynamo.eval_frame.dynamo_tls.traced_frame_infos
+        )
+
+        torch._dynamo.reset()
+        torch._dynamo.eval_frame.clear_dynamo_tls()
+        result2 = fn(a)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+        new_traced_frame_infos = torch._dynamo.eval_frame.dynamo_tls.traced_frame_infos
+        self.assertEqual(result, result2)
+        # Dynamo should trace exactly the same frames on cache hit
+        self.assertEqual(traced_frame_infos, new_traced_frame_infos)
+
     @inductor_config.patch("fx_graph_remote_cache", False)
     @inductor_config.patch("fx_graph_cache", True)
     @functorch_config.patch({"enable_autograd_cache": True})
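
For reference, a standalone sketch of the frame bookkeeping this test leans on (dynamo_tls and traced_frame_infos are internal Dynamo state, used here exactly as in the test above; this snippet is not part of the commit):

import torch

torch._dynamo.eval_frame.clear_dynamo_tls()

@torch.compile
def f(x):
    return x.sin().sum()

f(torch.randn(4, requires_grad=True))
# One entry per frame Dynamo started tracing; tracing into a cached backward
# would show up as extra entries on a second, cache-hit run.
print(torch._dynamo.eval_frame.dynamo_tls.traced_frame_infos)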

torch/_dynamo/backends/common.py

Lines changed: 4 additions & 1 deletion
@@ -68,7 +68,10 @@ def __call__(self, gm: torch.fx.GraphModule, example_inputs, **kwargs):
 
 def wrap_bw_compiler(bw_compiler_fn):
     def _wrapped_bw_compiler(*args, **kwargs):
-        # stop TorchDynamo from trying to compile our generated backwards pass
+        # Note [Wrapping bw_compiler in disable]
+        # The two disables here:
+        # - stop TorchDynamo from trying to compile the bw_compiler function itself
+        # - stop TorchDynamo from trying to compile the generated backwards pass bw_compiler produces
         return disable(
             disable(
                 bw_compiler_fn, reason="do not trace backward compiler function"
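
A self-contained illustration of the double-disable pattern the new comment describes (toy names, not code from this commit; disable is torch._dynamo.disable, and the reason strings mirror the real ones):

from torch._dynamo import disable

def toy_bw_compiler(graph, inputs):
    # Pretend "compilation": just hand back a callable for the backward.
    def compiled_backward(*grads):
        return grads
    return compiled_backward

def wrap_toy_bw_compiler(compiler_fn):
    def _wrapped(*args, **kwargs):
        # Inner disable: Dynamo must not trace the compiler function while it runs.
        # Outer disable: Dynamo must not trace the backward callable it returns.
        return disable(
            disable(compiler_fn, reason="do not trace backward compiler function")(
                *args, **kwargs
            ),
            reason="do not trace generated backwards pass",
        )

    return _wrapped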

torch/_functorch/_aot_autograd/autograd_cache.py

Lines changed: 17 additions & 1 deletion
@@ -589,6 +589,15 @@ class CompiledBackward(GenericCompiledBackward[CompiledFxGraph], FxGraphCacheLoadable):
     def _is_backward(self) -> bool:
         return True
 
+    def post_compile(
+        self, result: CompiledFxGraph, fx_config: _CompileFxKwargs
+    ) -> CompiledFxGraph:
+        compiled_bw = super().post_compile(result, fx_config)
+        # See note [Wrapping bw_compiler in disable]
+        # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py
+        # But since on cache hit we do not call the bw_compiler, we need to reapply the disable
+        return torch._dynamo.disable(compiled_bw, reason="do not trace generated backwards pass")  # type: ignore[return-value]
+
 
 # Forward types don't have any extra parameters, so this is just a TypeAlias, in essence
 class BundledCompiledForward(CompiledFxGraphLoadable):
@@ -599,7 +608,14 @@ class BundledCompiledForward(CompiledFxGraphLoadable):
 class BundledCompiledBackward(
     GenericCompiledBackward[CompiledFxGraph], CompiledFxGraphLoadable
 ):
-    pass
+    def post_compile(
+        self, result: CompiledFxGraph, fx_config: _CompileFxKwargs
+    ) -> CompiledFxGraph:
+        compiled_bw = super().post_compile(result, fx_config)
+        # See note [Wrapping bw_compiler in disable]
+        # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py
+        # But since on cache hit we do not call the bw_compiler, we need to reapply the disable
+        return torch._dynamo.disable(compiled_bw, reason="do not trace generated backwards pass")  # type: ignore[return-value]
 
 
 TForward = TypeVar("TForward", bound=InductorOutput)
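
The override pattern above in toy form (hypothetical Base/ToyBackward classes, not code from this commit): let the parent finish its post-compile processing, then re-wrap the result, since the bw_compiler (and therefore _wrapped_bw_compiler) never runs on the cache-hit path:

import torch

class Base:
    def post_compile(self, result, fx_config):
        # Parent hook; in the real code this rehydrates the cached artifact.
        return result

class ToyBackward(Base):
    def post_compile(self, result, fx_config):
        compiled_bw = super().post_compile(result, fx_config)
        # Reapply the Dynamo disable so a cache hit behaves like a fresh compile.
        return torch._dynamo.disable(
            compiled_bw, reason="do not trace generated backwards pass"
        )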
