Skip to content

Commit d865b78

Browse files
bobrenjc93pytorchmergebot
authored andcommitted
Support unbacked whitelist (pytorch#154295)
Pull Request resolved: pytorch#154295 Approved by: https://github.com/angelayi
1 parent ef4d573 commit d865b78

File tree

3 files changed

+98
-8
lines changed

3 files changed

+98
-8
lines changed

test/dynamo/test_misc.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7921,6 +7921,34 @@ def fn(x):
79217921

79227922
self.assertEqual(counter.frame_count, 1)
79237923

7924+
@torch.compiler.config.patch(unbacked_sources="L['x']")
7925+
def test_unbacked_sources_tensor(self):
7926+
counter = CompileCounter()
7927+
7928+
@torch.compile(backend=counter)
7929+
def fn(x):
7930+
return x * x
7931+
7932+
fn(torch.randn(0))
7933+
fn(torch.randn(1))
7934+
fn(torch.randn(2))
7935+
7936+
self.assertEqual(counter.frame_count, 1)
7937+
7938+
@torch.compiler.config.patch(unbacked_sources="L['x']")
7939+
def test_unbacked_sources_scalar(self):
7940+
counter = CompileCounter()
7941+
7942+
@torch.compile(backend=counter)
7943+
def fn(x):
7944+
return x * x
7945+
7946+
fn(0)
7947+
fn(1)
7948+
fn(2)
7949+
7950+
self.assertEqual(counter.frame_count, 1)
7951+
79247952
@torch.compiler.config.patch(dynamic_sources="L['x']")
79257953
def test_dynamic_sources_graph_break(self):
79267954
counter = CompileCounter()

torch/_dynamo/variables/builder.py

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,6 +1454,11 @@ def build_key_value(i, k, v):
14541454
if value.dynamism is None or value.dynamism.type == _DimHintType.STATIC:
14551455
return self.wrap_symint(value.val)
14561456
elif value.dynamism.type == _DimHintType.DYNAMIC:
1457+
log.debug(
1458+
"%s marked %s via IntWrapper",
1459+
self.source.name(),
1460+
DimDynamic.DYNAMIC,
1461+
)
14571462
return self.wrap_symint(
14581463
value.val,
14591464
dynamism=DimDynamic.DYNAMIC,
@@ -1462,6 +1467,11 @@ def build_key_value(i, k, v):
14621467
),
14631468
)
14641469
elif value.dynamism.type == _DimHintType.AUTO:
1470+
log.debug(
1471+
"%s marked %s via IntWrapper",
1472+
self.source.name(),
1473+
DimDynamic.DYNAMIC,
1474+
)
14651475
return self.wrap_symint(value.val, dynamism=DimDynamic.DYNAMIC)
14661476
else:
14671477
raise RuntimeError(f"Undefined dynamism {value.dynamism}")
@@ -1767,7 +1777,12 @@ def wrap_literal(self, value):
17671777
if type(value) is int:
17681778
# allowlist has higher precedence over specialization control.
17691779
if is_dynamic_source(self.source.name()):
1770-
return self.wrap_symint(value, True)
1780+
log.debug("%s marked dynamic via source whitelist", self.source.name())
1781+
return self.wrap_symint(value, dynamism=DimDynamic.DYNAMIC)
1782+
1783+
if is_unbacked_source(self.source.name()):
1784+
log.debug("%s marked unbacked via source whitelist", self.source.name())
1785+
return self.wrap_symint(value, dynamism=DimDynamic.SIZE_LIKE_UNBACKED)
17711786

17721787
if not config.specialize_int:
17731788
# unspecializing int by default, but still
@@ -2117,7 +2132,6 @@ def wrap_numpy_ndarray(self, value):
21172132
def wrap_symint(
21182133
self,
21192134
value,
2120-
is_forced_allow_list_dynamic=False,
21212135
dynamism: Optional[DimDynamic] = None,
21222136
context: Optional[SymIntSymbolicContext] = None,
21232137
):
@@ -2165,12 +2179,8 @@ def wrap_symint(
21652179
if isinstance(base_source, ChainedSource):
21662180
base_source = base_source.get_base()
21672181

2168-
if dynamism == DimDynamic.DYNAMIC:
2169-
log.debug("%s marked %s via IntWrapper", self.source.name(), dynamism)
2182+
if dynamism is not None:
21702183
dynamic_dim = dynamism
2171-
elif is_forced_allow_list_dynamic:
2172-
log.debug("%s marked dynamic via source whitelist", self.source.name())
2173-
dynamic_dim = DimDynamic.DYNAMIC
21742184
elif (
21752185
config.automatic_dynamic_shapes
21762186
and frame_state_entry.scalar is auto_dynamic
@@ -2963,6 +2973,43 @@ def record_automatic_dynamic(
29632973
)
29642974

29652975

2976+
_UNBACKED_SOURCES: Optional[set[str]] = None
2977+
_UNBACKED_SOURCES_CONFIG_HASH: Optional[int] = None
2978+
2979+
2980+
def get_unbacked_sources() -> set[str]:
2981+
global _UNBACKED_SOURCES, _UNBACKED_SOURCES_CONFIG_HASH
2982+
2983+
current_hash = hash(torch.compiler.config.unbacked_sources)
2984+
2985+
# If we have already calculated the sources and the config hasn't changed, return cached result
2986+
if _UNBACKED_SOURCES is not None and _UNBACKED_SOURCES_CONFIG_HASH == current_hash:
2987+
return _UNBACKED_SOURCES
2988+
2989+
# Config has changed or first time, (re)calculate the sources
2990+
_UNBACKED_SOURCES = {
2991+
s
2992+
for s in torch.compiler.config.unbacked_sources.replace(" ", "").split(",")
2993+
if s
2994+
}
2995+
_UNBACKED_SOURCES_CONFIG_HASH = current_hash
2996+
2997+
return _UNBACKED_SOURCES
2998+
2999+
3000+
def is_unbacked_source(source_name: str) -> bool:
3001+
unbacked_sources = get_unbacked_sources()
3002+
for pattern in unbacked_sources:
3003+
if pattern == source_name or re.match(pattern, source_name):
3004+
log.debug(
3005+
"%s was marked unbacked due to unbacked source allowlist pattern: %s",
3006+
source_name,
3007+
pattern,
3008+
)
3009+
return True
3010+
return False
3011+
3012+
29663013
# Performs automatic dynamic dim determination.
29673014
# Returns a SymbolicContext
29683015
def _automatic_dynamic(
@@ -3135,6 +3182,11 @@ def update_dim2constraint(dim, constraint_range, name):
31353182
automatic_dynamic_size = True
31363183
automatic_dynamic_stride = True
31373184

3185+
if is_unbacked_source(name):
3186+
log.debug("%s marked unbacked via source whitelist", name)
3187+
automatic_dynamic_size = True
3188+
automatic_dynamic_stride = True
3189+
31383190
automatic_dynamic = automatic_dynamic_size or automatic_dynamic_stride
31393191

31403192
# We will process constraints first, as they will imply that we
@@ -3185,7 +3237,7 @@ def update_dim2constraint(dim, constraint_range, name):
31853237
constraint_sizes.append(constraint_size)
31863238
constraint_strides.append(constraint_stride)
31873239

3188-
if marked_unbacked:
3240+
if marked_unbacked or is_unbacked_source(name):
31893241
dynamic_size = DimDynamic.SIZE_LIKE_UNBACKED
31903242
elif (
31913243
constraint_size is not None

torch/compiler/config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,5 +83,15 @@
8383
for models that are not identical, but are similar enough to share PGO profiles.
8484
"""
8585

86+
unbacked_sources: str = Config(
87+
env_name_default="TORCH_COMPILE_UNBACKED_SOURCES", default=""
88+
)
89+
"""
90+
Comma delimited list of sources that should be marked as unbacked. Primarily useful for large
91+
models with graph breaks where you need intermediate tensors marked unbacked.
92+
93+
This whitelist is dominant over all other flags dynamic=False, force_nn_module_property_static_shapes
94+
and force_parameter_static_shapes.
95+
"""
8696

8797
install_config_module(sys.modules[__name__])

0 commit comments

Comments
 (0)