
Commit 409c396

anijain2305 authored and pytorchmergebot committed
[dynamo] Record the pre-graph bytecode using fast record function event (pytorch#154769)
![image](https://github.com/user-attachments/assets/1d06618b-1c14-4ed5-ab7b-dcfecbb4d632)

Adds another event in the profiler traces. This can help us find models where pre-graph bytecode is very expensive.

Pull Request resolved: pytorch#154769
Approved by: https://github.com/zou3519, https://github.com/williamwen42, https://github.com/StrongerXi, https://github.com/jansel
1 parent f6b83d4 commit 409c396
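For reference, a minimal sketch (not part of this commit) of how the new event can be surfaced; the model, input shape, and trace filename below are illustrative assumptions:

```python
# Compile a small model, run it under the profiler, and look for the new
# "Pregraph bytecode" record-function event in the resulting trace.
import torch

model = torch.nn.Linear(16, 16)
compiled = torch.compile(model)
x = torch.randn(4, 16)

compiled(x)  # warm up so compilation itself does not dominate the trace

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU]
) as prof:
    compiled(x)

# The pre-graph bytecode cost appears as a "Pregraph bytecode" event.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
prof.export_chrome_trace("trace.json")
```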

File tree

4 files changed: +48 −2 lines changed

- test/inductor/test_compiled_autograd.py
- torch/_dynamo/codegen.py
- torch/_dynamo/config.py
- torch/_dynamo/utils.py

test/inductor/test_compiled_autograd.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -103,10 +103,15 @@ def reset():
 
 class TestCompiledAutograd(TestCase):
     def setUp(self) -> None:
+        self.exit_stack = contextlib.ExitStack()
+        self.exit_stack.enter_context(
+            config.patch("record_pre_graph_bytecode_in_traces", False)
+        )
         super().setUp()
         reset()
 
     def tearDown(self) -> None:
+        self.exit_stack.close()
         super().tearDown()
         reset()
```
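The test disables the new instrumentation for the duration of each test via an ExitStack. Outside a TestCase the same toggle can be applied with a plain with-block; a sketch, assuming the flag exists in your build:

```python
import torch._dynamo.config as dynamo_config

# Temporarily suppress the extra "Pregraph bytecode" profiler event for any
# compilation that happens inside this block.
with dynamo_config.patch("record_pre_graph_bytecode_in_traces", False):
    ...
```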

torch/_dynamo/codegen.py

Lines changed: 25 additions & 1 deletion
```diff
@@ -23,7 +23,7 @@
 import torch.nn
 from torch.utils._ordered_set import OrderedSet
 
-from . import graph_break_hints, utils
+from . import config, graph_break_hints, utils
 from .bytecode_transformation import (
     add_push_null,
     add_push_null_call_function_ex,
@@ -613,6 +613,18 @@ def collect_temp_source(source):
             if arg.source is not None:
                 collect_temp_source(arg.source)
 
+        cm_var = None
+        if config.record_pre_graph_bytecode_in_traces:
+            # Record the pregraph bytecode start
+            self.add_push_null(
+                lambda: self.load_import_from(
+                    utils.__name__, "record_pregraph_bytecode_enter"
+                )
+            )
+            self.extend_output(create_call_function(0, False))
+            cm_var = self.new_var()
+            self.store(cm_var)
+
         for arg in graphargs:
             if arg.pass_arg_as_tensor:
                 self.add_push_null(
@@ -628,6 +640,18 @@ def collect_temp_source(source):
             else:
                 self.call_reconstruct(arg)
 
+        if config.record_pre_graph_bytecode_in_traces:
+            # Record the pregraph bytecode end
+            self.add_push_null(
+                lambda: self.load_import_from(
+                    utils.__name__, "record_pregraph_bytecode_exit"
+                )
+            )
+            assert cm_var is not None
+            self.extend_output([self.create_load(cm_var)])
+            self.extend_output(create_call_function(1, False))
+            self.pop_top()
+
         self.extend_output(create_call_function(len(graphargs), False))
 
     def load_import_from(self, module_name, object_name) -> None:
```
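Conceptually, the emitted bytecode brackets the reconstruction of the graph arguments with the two helpers added in torch/_dynamo/utils.py below. A rough Python-level equivalent (illustrative only; the codegen emits raw bytecode, not source):

```python
from torch._dynamo.utils import (
    record_pregraph_bytecode_enter,
    record_pregraph_bytecode_exit,
)

cm = record_pregraph_bytecode_enter()  # starts the "Pregraph bytecode" event
# ... reconstruct each graph argument (the potentially expensive part) ...
record_pregraph_bytecode_exit(cm)      # ends the event before the compiled graph is called
```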

torch/_dynamo/config.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -615,6 +615,9 @@ def default_debug_dir_root():
 # wrapper. This ensures that nn.module hooks are also compiled in the same frame.
 wrap_top_frame = False
 
+# record pre-graph bytecode in profile traces
+record_pre_graph_bytecode_in_traces = True
+
 # HACK: this is for testing custom ops profiling only
 _custom_ops_profile: Optional[Any] = None
```

torch/_dynamo/utils.py

Lines changed: 15 additions & 1 deletion
```diff
@@ -47,7 +47,7 @@
 import warnings
 import weakref
 from collections import Counter, OrderedDict
-from contextlib import contextmanager
+from contextlib import AbstractContextManager, contextmanager
 from dataclasses import is_dataclass
 from functools import lru_cache
 from types import MethodWrapperType
@@ -4670,3 +4670,17 @@ def maybe_disable_inference_mode_for_fake_prop() -> Generator[None, None, None]:
 
 def is_node_meta_valid(node: Optional[torch.fx.Node]) -> bool:
     return node is None or "example_value" in node.meta or "val" in node.meta
+
+
+def record_pregraph_bytecode_enter() -> AbstractContextManager[None]:
+    cm: AbstractContextManager[None] = (
+        torch._C._profiler._RecordFunctionFast("Pregraph bytecode")
+        if torch.autograd.profiler._is_profiler_enabled
+        else contextlib.nullcontext()
+    )
+    cm.__enter__()
+    return cm
+
+
+def record_pregraph_bytecode_exit(cm: AbstractContextManager[None]) -> None:
+    cm.__exit__(None, None, None)
```
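The helpers wrap torch._C._profiler._RecordFunctionFast, a lightweight record-function context manager, and fall back to a nullcontext when no profiler is active; the enter/exit split exists so the two halves can be emitted at separate bytecode sites. Used directly it behaves like any other context manager; a sketch against this private API, with an arbitrary example region name:

```python
import torch

with torch.profiler.profile():
    # Low-overhead way to add a named region to the trace.
    with torch._C._profiler._RecordFunctionFast("my custom region"):
        torch.randn(8, 8).sum()
```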
