
Commit cc96feb

anijain2305 authored and pytorchmergebot committed
[dynamo] Mark a vt unspecialized nn module variable source earlier (pytorch#154780)
I am working on providing some skip guard helper functions to allow users to reduce guard overhead. This is a refactor to allow that.

Pull Request resolved: pytorch#154780
Approved by: https://github.com/StrongerXi, https://github.com/jansel
1 parent ea7b233 commit cc96feb
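
At its core, the refactor moves the "unspecialized nn module" marking from guard-installation time to VT-construction time: the builder wraps the module's source in UnspecializedNNModuleSource (or its Builtin variant) up front, and downstream code reaches the original source through `.base`. A minimal sketch of that invariant, using the class names visible in this diff (constructor usage is assumed, not verified against this exact revision):

    from torch._dynamo.source import LocalSource, UnspecializedNNModuleSource

    base = LocalSource("mod")                # where the module was read from
    src = UnspecializedNNModuleSource(base)  # marker wrapper, added at wrap time
    assert src.base is base                  # downstream code unwraps via .base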

File tree

8 files changed: +39, -19 lines changed

benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv

Lines changed: 1 addition & 1 deletion
@@ -138,7 +138,7 @@ hf_Bert_large,pass,0
-hf_BigBird,pass,18
+hf_BigBird,pass,24

benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ hf_Bert_large,pass,0
-hf_BigBird,pass,18
+hf_BigBird,pass,24

test/dynamo/test_modules.py

Lines changed: 1 addition & 0 deletions
@@ -1299,6 +1299,7 @@ def test_unsupportedmodule(self):
         self.assertTrue(torch._dynamo.testing.same(r, m(i)))
         self.assertEqual(cnt.op_count, 6)
 
+    @patch.object(torch._dynamo.config, "allow_unspec_int_on_nn_module", True)
     def test_self_mutating1(self):
         m1 = torch.nn.Linear(10, 10)
         m2 = SelfMutatingModule(m1)
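
For readers unfamiliar with the decorator added above: `patch.object` temporarily forces the dynamo config flag for the duration of one test and restores the previous value on exit. A standalone sketch (the function name is illustrative, not part of the test suite):

    from unittest.mock import patch

    import torch

    @patch.object(torch._dynamo.config, "allow_unspec_int_on_nn_module", True)
    def run_with_flag():  # illustrative name
        # inside the decorated call the flag is forced to True
        assert torch._dynamo.config.allow_unspec_int_on_nn_module is True

    run_with_flag()  # the previous flag value is restored afterwards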

test/functorch/test_control_flow.py

Lines changed: 2 additions & 3 deletions
@@ -7454,14 +7454,13 @@ def forward(self, a, b):
         self.assertExpectedInline(
             backend.graphs[0].code.strip(),
             """\
-def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor, L_self_num : torch.SymInt):
+def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor):
     l_a_ = L_a_
     l_b_ = L_b_
-    l_self_num = L_self_num
     tensor = torch.tensor([True])
     cond_true_0 = self.cond_true_0
     cond_false_0 = self.cond_false_0
-    cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, l_self_num, s97)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = l_self_num = s97 = None
+    cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, s97)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = s97 = None
     getitem = cond[0]; cond = None
     return (getitem,)""", # noqa: B950
         )
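
The expected graph loses `L_self_num` because, with the source marked as an unspecialized nn module earlier, the module's int attribute now takes the int-specialization path (see the torch/_dynamo/utils.py hunk below) and is baked into the graph as a constant rather than lifted as a SymInt input. A sketch of that behavior with an illustrative module (not the test's actual module):

    import torch

    class WithIntAttr(torch.nn.Module):  # illustrative, not the test's module
        def __init__(self):
            super().__init__()
            self.num = 3  # plain Python int attribute

        def forward(self, x):
            return x + self.num

    compiled = torch.compile(WithIntAttr(), backend="eager")
    compiled(torch.randn(4))  # self.num is specialized (guarded), not a graph input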

torch/_dynamo/utils.py

Lines changed: 4 additions & 0 deletions
@@ -2402,6 +2402,10 @@ def is_int_specialization_case(value, source):
             source.guard_source().is_unspecialized_builtin_nn_module()
             and not config.allow_unspec_int_on_nn_module
         )
+        or (
+            source.guard_source().is_unspecialized_nn_module()
+            and not config.allow_unspec_int_on_nn_module
+        )
         or is_from_defaults(source)
         # TODO: Delete this condition when rollout is done. NB: this
         # condition never evaluates True in open source
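
Factoring the two OR'd clauses together, the nn-module part of the predicate now reads as below; this is a condensed restatement for clarity, not the full function:

    def int_on_nn_module_specializes(source, config):  # condensed sketch
        gs = source.guard_source()
        return (
            gs.is_unspecialized_builtin_nn_module() or gs.is_unspecialized_nn_module()
        ) and not config.allow_unspec_int_on_nn_module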

torch/_dynamo/variables/builder.py

Lines changed: 26 additions & 7 deletions
@@ -115,6 +115,8 @@
     Source,
     SubclassAttrListSource,
     TupleIteratorGetItemSource,
+    UnspecializedBuiltinNNModuleSource,
+    UnspecializedNNModuleSource,
 )
 from ..utils import (
     _extract_tensor_dict,
@@ -434,7 +436,10 @@ def __call__(self, value):
             return cached_vt
 
         vt = self._wrap(value)
-        vt.source = self.source
+
+        if vt.source is None:
+            vt.source = self.source
+
         if (
             self._can_lift_attrs_to_inputs(vt)
             and value not in self.tx.output.side_effects
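
The assignment changes from unconditional to a backfill because `_wrap` can now attach a wrapped source itself; overwriting it here would clobber the marker. A self-contained sketch of why (SimpleVT and attach_source are illustrative stand-ins, not dynamo code):

    class SimpleVT:  # stand-in for a VariableTracker
        def __init__(self, source=None):
            self.source = source

    def attach_source(vt, builder_source):
        # mirrors the new logic: only backfill when _wrap left source unset
        if vt.source is None:
            vt.source = builder_source
        return vt

    marked = SimpleVT(source="UnspecializedNNModuleSource(L['mod'])")
    assert attach_source(marked, "L['mod']").source.startswith("Unspecialized")

    plain = SimpleVT()
    assert attach_source(plain, "L['mod']").source == "L['mod']"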
@@ -1714,7 +1719,6 @@ def wrap_module(self, value: torch.nn.Module):
             value = value.get_base()
             self.source = AttrProxySource(self.source)
 
-        self.install_guards(GuardBuilder.TYPE_MATCH)
         if torch._dynamo.config.inline_inbuilt_nn_modules:
             freezing = is_parameter_freezing()
@@ -1749,12 +1753,23 @@ def wrap_module(self, value: torch.nn.Module):
             # this will get cleaned up once compile ends
             self.tx.output.nn_modules[self.name] = value
 
-            if value.__module__.startswith(("torch.nn.", "torch.ao.")) or getattr(
-                value.__class__, "_dynamo_marked_static", False
-            ):
-                result = UnspecializedBuiltinNNModuleVariable(value, source=self.source)
+            if (
+                value.__module__.startswith(("torch.nn.modules", "torch.ao."))
+                and not value.__module__.startswith("torch.nn.modules.container")
+            ) or getattr(value.__class__, "_dynamo_marked_static", False):
+                new_source = self.source
+                if config.inline_inbuilt_nn_modules:
+                    # Export corner case - look at test_repros.py test_inlining_cornercase
+                    new_source = UnspecializedBuiltinNNModuleSource(self.source)
+                result = UnspecializedBuiltinNNModuleVariable(value, source=new_source)
+                install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH))
             else:
-                result = UnspecializedNNModuleVariable(value, source=self.source)
+                new_source = self.source
+                if config.inline_inbuilt_nn_modules:
+                    # Export corner case - look at test_repros.py test_inlining_cornercase
+                    new_source = UnspecializedNNModuleSource(self.source)
+                result = UnspecializedNNModuleVariable(value, source=new_source)
+                install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH))
 
             if not SideEffects.cls_supports_mutation_side_effects(type(value)):
                 # don't allow STORE_ATTR mutation with custom __setattr__
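
Both branches now share one pattern: wrap the source, build the variable on the wrapped source, then install TYPE_MATCH on that wrapped source, replacing the unconditional install_guards call removed from the top of wrap_module. Condensed into a hypothetical helper (a sketch, not code from this commit):

    def build_unspec_module_vt(self, value, wrapper_source_cls, variable_cls):
        # hypothetical helper condensing the two branches above
        new_source = self.source
        if config.inline_inbuilt_nn_modules:
            new_source = wrapper_source_cls(self.source)  # mark early
        result = variable_cls(value, source=new_source)
        install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH))
        return result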
@@ -2127,6 +2142,10 @@ def wrap_numpy_ndarray(self, value):
         )
         proxy.node.meta["grapharg"] = grapharg
 
+        # TODO - Why do we need to set the source of the np ndarray vt back to
+        # original source. Many tests fails.
+        numpy_ndarray_variable.source = self.source
+
         return numpy_ndarray_variable
 
     def wrap_symint(

torch/_dynamo/variables/higher_order_ops.py

Lines changed: 2 additions & 2 deletions
@@ -2658,8 +2658,8 @@ def call_function(
 
 class FlexAttentionBackwardHighOrderVariable(TorchHigherOrderOperatorVariable):
     def proxy_submod(self, tx, arg):
-        assert isinstance(arg.source, DictGetItemSource)
-        submod_name = tx.output.install_subgraph(arg.source.index, arg.value)
+        assert isinstance(arg.source.base, DictGetItemSource)
+        submod_name = tx.output.install_subgraph(arg.source.base.index, arg.value)
         p_submod = make_attr(tx, submod_name)
         set_example_value(p_submod.node, arg.value)
         return p_submod
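
This follows directly from the builder change: the submodule VT's source is now a wrapper around the original DictGetItemSource, so the code unwraps one level via `.base`. A sketch, assuming the usual dataclass constructors for these sources (values are illustrative):

    from torch._dynamo.source import (
        DictGetItemSource,
        LocalSource,
        UnspecializedNNModuleSource,
    )

    inner = DictGetItemSource(LocalSource("mods"), "score_mod")
    wrapped = UnspecializedNNModuleSource(inner)
    assert isinstance(wrapped.base, DictGetItemSource)
    assert wrapped.base.index == "score_mod"  # what install_subgraph now receives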

torch/_dynamo/variables/nn_module.py

Lines changed: 2 additions & 5 deletions
@@ -48,7 +48,6 @@
     FSDPNNModuleSource,
     GetItemSource,
     NNModuleSource,
-    UnspecializedBuiltinNNModuleSource,
     UnspecializedNNModuleSource,
 )
 from ..utils import (

@@ -891,8 +890,7 @@ def __init__(self, value, **kwargs) -> None:
         self.nn_module_stack_source = self.source
 
     def _wrap_source(self, attr_source):
-        if not isinstance(attr_source, UnspecializedNNModuleSource):
-            return UnspecializedNNModuleSource(attr_source)
+        # the vt is already wrapped with UnspecializedNNModuleSource
         return attr_source
 
     def get_nn_module_stack_source(self):

@@ -1193,8 +1191,7 @@ class UnspecializedBuiltinNNModuleVariable(UnspecializedNNModuleVariable):
     """
 
     def _wrap_source(self, attr_source):
-        if not isinstance(attr_source, UnspecializedBuiltinNNModuleSource):
-            return UnspecializedBuiltinNNModuleSource(attr_source)
+        # vt is already wrapped with the UnspecializedBuiltinNNModuleSource
         return attr_source
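
With the wrapping done once in builder.py, `_wrap_source` degenerates to the identity in both classes: the isinstance re-wrap checks are no longer needed because attribute sources arrive already wrapped. A sketch of the invariant this relies on (illustrative, using the real source classes):

    from torch._dynamo.source import LocalSource, UnspecializedNNModuleSource

    src = UnspecializedNNModuleSource(LocalSource("mod"))

    def _wrap_source(attr_source):  # post-refactor behavior, per the hunk above
        return attr_source

    assert _wrap_source(src) is src  # never double-wrapped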
