Skip to content

Commit 1b0a023

Browse files
shinkpytorchmergebot
authored andcommitted
[Dynamo][Misc] Apply typing hints for codegen (pytorch#150289)
Fixes #ISSUE_NUMBER Pull Request resolved: pytorch#150289 Approved by: https://github.com/Skylion007, https://github.com/cyyever
1 parent 295b7e2 commit 1b0a023

File tree

17 files changed

+101
-80
lines changed

17 files changed

+101
-80
lines changed

torch/_dynamo/codegen.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import sys
1919
import types
2020
from collections import Counter
21-
from typing import Optional, Union
21+
from typing import Optional, TYPE_CHECKING, Union
2222

2323
import torch.nn
2424
from torch.utils._ordered_set import OrderedSet
@@ -54,6 +54,10 @@
5454
from .variables.torch_function import TensorWithTFOverrideVariable
5555

5656

57+
if TYPE_CHECKING:
58+
from .symbolic_convert import InstructionTranslatorBase
59+
60+
5761
@dataclasses.dataclass
5862
class GraphOutputEntry:
5963
index: int
@@ -67,7 +71,7 @@ class PyCodegen:
6771

6872
def __init__(
6973
self,
70-
tx=None,
74+
tx: "InstructionTranslatorBase",
7175
root: Optional[torch.nn.Module] = None,
7276
graph_output_var: Optional[str] = None,
7377
tempvars=None,

torch/_dynamo/output_graph.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ def __init__(
390390
# and LOAD_ATTR for same python objects free.
391391
self.variable_tracker_cache = VariableTrackerCache()
392392
self.unique_var_id = itertools.count()
393-
self.code_options = dict(code_options)
393+
self.code_options: dict[str, Any] = dict(code_options)
394394
self.output_instructions: list[Instruction] = []
395395
# used to track nodes that are added between calls of copy_graphstate
396396
# and restore_graphstate
@@ -401,7 +401,7 @@ def __init__(
401401

402402
# Not checkpointed
403403
self.compiler_fn: Optional[CompilerFn] = compiler_fn
404-
self.global_scope = global_scope
404+
self.global_scope: Scope = global_scope
405405
self.local_scope = local_scope
406406
self.root_tx = root_tx
407407

@@ -462,7 +462,7 @@ def __init__(
462462
self.random_calls: list[
463463
tuple[Callable[..., object], tuple[object, ...], dict[str, object]]
464464
] = []
465-
self.random_values_var = None
465+
self.random_values_var: Any = None
466466

467467
# Bytecode to insert right before we call the graph
468468
self.pregraph_bytecode: list[Instruction] = []
@@ -888,7 +888,9 @@ def wrap_name(module_key):
888888
self.output.update_co_names(module_key)
889889
self.global_scope[module_key] = target
890890
return VariableTracker.build(
891-
self, target, ConstantSource(source_name=module_key)
891+
self, # type: ignore[arg-type]
892+
target,
893+
ConstantSource(source_name=module_key),
892894
)
893895

894896
for k, v in self.nn_modules.items():

torch/_dynamo/source.py

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,17 @@
2121

2222
import dataclasses
2323
import enum
24-
from typing import Any, Optional, Union
24+
from typing import Any, Optional, TYPE_CHECKING, Union
2525

2626
from torch._guards import ChainedSource, GuardSource, Source
2727

2828
from . import utils
2929
from .bytecode_transformation import create_call_function, create_instruction
3030

3131

32+
if TYPE_CHECKING:
33+
from .codegen import PyCodegen
34+
3235
# It shouldn't be supported to construct an NNModuleVariable inside an FSDP module,
3336
# so those cases are omitted intentionally
3437

@@ -120,7 +123,7 @@ class LocalSource(Source):
120123
# or `co_freevars`.
121124
is_derefed_cell_contents: bool = False
122125

123-
def reconstruct(self, codegen):
126+
def reconstruct(self, codegen: "PyCodegen"):
124127
if self.is_derefed_cell_contents:
125128
codegen.load_deref(self.local_name)
126129
else:
@@ -137,7 +140,7 @@ def name(self):
137140
class SyntheticLocalSource(Source):
138141
local_name: str
139142

140-
def reconstruct(self, codegen):
143+
def reconstruct(self, codegen: "PyCodegen"):
141144
codegen.append_output(codegen.create_load(self.local_name))
142145

143146
def guard_source(self):
@@ -154,7 +157,7 @@ class RandomValueSource(Source):
154157
def guard_source(self):
155158
return GuardSource.RANDOM_VALUE
156159

157-
def reconstruct(self, codegen):
160+
def reconstruct(self, codegen: "PyCodegen"):
158161
codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var))
159162
codegen.append_output(codegen.create_load_const(self.random_call_index))
160163
codegen.append_output(create_instruction("BINARY_SUBSCR"))
@@ -167,7 +170,7 @@ def name(self):
167170
class GlobalSource(Source):
168171
global_name: str
169172

170-
def reconstruct(self, codegen):
173+
def reconstruct(self, codegen: "PyCodegen"):
171174
codegen.append_output(codegen.create_load_global(self.global_name, add=True))
172175

173176
def guard_source(self):
@@ -181,7 +184,7 @@ def name(self):
181184
class GlobalWeakRefSource(Source):
182185
global_name: str
183186

184-
def reconstruct(self, codegen):
187+
def reconstruct(self, codegen: "PyCodegen"):
185188
codegen.add_push_null(
186189
lambda: codegen.append_output(
187190
codegen.create_load_global(self.global_name, add=True)
@@ -198,7 +201,7 @@ def name(self):
198201

199202
@dataclasses.dataclass(frozen=True)
200203
class WeakRefCallSource(ChainedSource):
201-
def reconstruct(self, codegen):
204+
def reconstruct(self, codegen: "PyCodegen"):
202205
codegen.add_push_null(lambda: codegen(self.base))
203206
codegen.extend_output(create_call_function(0, False))
204207

@@ -227,7 +230,7 @@ def __post_init__(self):
227230
)
228231
object.__setattr__(self, "member", member_parts[-1])
229232

230-
def reconstruct(self, codegen):
233+
def reconstruct(self, codegen: "PyCodegen"):
231234
codegen(self.base)
232235
codegen.extend_output(codegen.create_load_attrs(self.member))
233236

@@ -249,7 +252,7 @@ class LocalCellSource(Source):
249252

250253
local_name: str
251254

252-
def reconstruct(self, codegen):
255+
def reconstruct(self, codegen: "PyCodegen"):
253256
# Although `LOAD_FAST` and `LOAD_CLOSURE` have the same semantics,
254257
# Dynamo's bytecode transformation differentiates them slightly, so we
255258
# always emit `LOAD_CLOSURE` here.
@@ -267,7 +270,7 @@ def reconstruct(self, codegen):
267270
class GradSource(ChainedSource):
268271
member: str = "grad"
269272

270-
def reconstruct(self, codegen):
273+
def reconstruct(self, codegen: "PyCodegen"):
271274
codegen(self.base)
272275
codegen.extend_output(codegen.create_load_attrs(self.member))
273276

@@ -342,7 +345,7 @@ def __post_init__(self):
342345
else:
343346
assert self.idx is not None
344347

345-
def reconstruct(self, codegen):
348+
def reconstruct(self, codegen: "PyCodegen"):
346349
codegen.add_push_null(
347350
lambda: codegen.load_import_from(
348351
utils.__name__, f"call_{self.prop.method_name()}"
@@ -378,7 +381,7 @@ class IndexedSource(ChainedSource):
378381
def __post_init__(self):
379382
assert self.base is not None
380383

381-
def reconstruct(self, codegen):
384+
def reconstruct(self, codegen: "PyCodegen"):
382385
raise NotImplementedError
383386

384387
def guard_source(self):
@@ -393,7 +396,7 @@ class NegateSource(ChainedSource):
393396
def __post_init__(self):
394397
assert self.base is not None
395398

396-
def reconstruct(self, codegen):
399+
def reconstruct(self, codegen: "PyCodegen"):
397400
raise NotImplementedError
398401

399402
def guard_source(self):
@@ -409,7 +412,7 @@ class ConvertIntSource(ChainedSource):
409412
def __post_init__(self):
410413
assert self.base is not None
411414

412-
def reconstruct(self, codegen):
415+
def reconstruct(self, codegen: "PyCodegen"):
413416
codegen(self.base)
414417

415418
def guard_source(self):
@@ -424,7 +427,7 @@ class FlattenScriptObjectSource(ChainedSource):
424427
def __post_init__(self):
425428
assert self.base is not None
426429

427-
def reconstruct(self, codegen):
430+
def reconstruct(self, codegen: "PyCodegen"):
428431
codegen(self.base)
429432

430433
def guard_source(self):
@@ -439,7 +442,7 @@ class ScriptObjectQualifiedNameSource(ChainedSource):
439442
def __post_init__(self):
440443
assert self.base is not None
441444

442-
def reconstruct(self, codegen):
445+
def reconstruct(self, codegen: "PyCodegen"):
443446
codegen(self.base)
444447

445448
def guard_source(self):
@@ -450,7 +453,7 @@ def name(self):
450453

451454

452455
class AttrProxySource(ChainedSource):
453-
def reconstruct(self, codegen):
456+
def reconstruct(self, codegen: "PyCodegen"):
454457
codegen(self.base)
455458

456459
def guard_source(self):
@@ -484,7 +487,7 @@ def __post_init__(self):
484487
self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]"
485488
)
486489

487-
def reconstruct(self, codegen):
490+
def reconstruct(self, codegen: "PyCodegen"):
488491
codegen(self.base)
489492
codegen.extend_output(codegen.create_load_attrs(self.field))
490493
codegen.append_output(codegen.create_load_const(self.idx_key))
@@ -509,7 +512,7 @@ def __post_init__(self):
509512
super().__setattr__("index", self.index.__reduce__())
510513
super().__setattr__("index_is_slice", True)
511514

512-
def reconstruct(self, codegen):
515+
def reconstruct(self, codegen: "PyCodegen"):
513516
codegen(self.base)
514517
if self.index_is_slice:
515518
codegen.append_output(codegen.create_load_const(self.unpack_slice()))
@@ -543,7 +546,7 @@ class ConstDictKeySource(ChainedSource):
543546
def guard_source(self):
544547
return self.base.guard_source()
545548

546-
def reconstruct(self, codegen):
549+
def reconstruct(self, codegen: "PyCodegen"):
547550
codegen.add_push_null(
548551
lambda: codegen.load_import_from(utils.__name__, "dict_keys_getitem")
549552
)
@@ -577,7 +580,7 @@ def __post_init__(self):
577580
def guard_source(self):
578581
return self.base.guard_source()
579582

580-
def reconstruct(self, codegen):
583+
def reconstruct(self, codegen: "PyCodegen"):
581584
# reconstruct dict.__getitem__(dct, key)
582585

583586
# Load dict.__getitem__
@@ -609,7 +612,7 @@ class ListGetItemSource(GetItemSource):
609612
Same as GetItemSource with reconstruct and name overridden to be list specific.
610613
"""
611614

612-
def reconstruct(self, codegen):
615+
def reconstruct(self, codegen: "PyCodegen"):
613616
# Reconstruct list.__getitem__(lst, index) to avoid any side effects
614617
# from possibly overridden __getitem__.
615618

@@ -646,7 +649,7 @@ def name(self):
646649

647650
@dataclasses.dataclass(frozen=True)
648651
class TupleIteratorGetItemSource(GetItemSource):
649-
def reconstruct(self, codegen):
652+
def reconstruct(self, codegen: "PyCodegen"):
650653
codegen.add_push_null(
651654
lambda: codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
652655
)
@@ -663,7 +666,7 @@ class TypeSource(ChainedSource):
663666
def __post_init__(self):
664667
assert self.base is not None
665668

666-
def reconstruct(self, codegen):
669+
def reconstruct(self, codegen: "PyCodegen"):
667670
codegen.add_push_null(lambda: codegen.load_import_from("builtins", "type"))
668671
codegen(self.base)
669672
codegen.extend_output(create_call_function(1, False))
@@ -677,7 +680,7 @@ def name(self):
677680

678681
@dataclasses.dataclass(frozen=True)
679682
class OptimizerSource(ChainedSource):
680-
def reconstruct(self, codegen):
683+
def reconstruct(self, codegen: "PyCodegen"):
681684
codegen(self.base)
682685

683686
def guard_source(self):
@@ -689,7 +692,7 @@ def name(self):
689692

690693
@dataclasses.dataclass(frozen=True)
691694
class NNModuleSource(ChainedSource):
692-
def reconstruct(self, codegen):
695+
def reconstruct(self, codegen: "PyCodegen"):
693696
codegen(self.base)
694697

695698
def guard_source(self):
@@ -738,7 +741,7 @@ def _get_index(self):
738741

739742
return TorchFunctionModeStackVariable.get_mode_index(self.ind)
740743

741-
def reconstruct(self, codegen):
744+
def reconstruct(self, codegen: "PyCodegen"):
742745
codegen.add_push_null(
743746
lambda: codegen.load_import_from(
744747
utils.__name__, "get_torch_function_mode_stack_at"
@@ -755,7 +758,7 @@ def guard_source(self):
755758
class ConstantSource(Source):
756759
source_name: str
757760

758-
def reconstruct(self, codegen):
761+
def reconstruct(self, codegen: "PyCodegen"):
759762
codegen.append_output(codegen.create_load_global(self.source_name, add=False))
760763

761764
def guard_source(self):
@@ -776,7 +779,7 @@ def name(self) -> str:
776779
def guard_source(self):
777780
return self.base.guard_source()
778781

779-
def reconstruct(self, codegen):
782+
def reconstruct(self, codegen: "PyCodegen"):
780783
codegen.add_push_null(lambda: codegen.load_import_from("torch", "as_tensor"))
781784
codegen(self.base)
782785
codegen.extend_output(create_call_function(1, False))

torch/_dynamo/variables/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929

3030

3131
if TYPE_CHECKING:
32-
from .symbolic_convert import InstructionTranslator, InstructionTranslatorBase
32+
from ..codegen import PyCodegen
33+
from ..symbolic_convert import InstructionTranslator, InstructionTranslatorBase
3334

3435

3536
class SourceType(Enum):
@@ -399,7 +400,7 @@ def maybe_fx_node(self):
399400
except NotImplementedError:
400401
return None
401402

402-
def reconstruct(self, codegen):
403+
def reconstruct(self, codegen: "PyCodegen"):
403404
raise NotImplementedError
404405

405406
def unpack_var_sequence(self, tx) -> list["VariableTracker"]:

torch/_dynamo/variables/builder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@
276276

277277

278278
if TYPE_CHECKING:
279+
from torch._dynamo.codegen import PyCodegen
279280
from torch._dynamo.symbolic_convert import InstructionTranslator
280281

281282

@@ -348,7 +349,7 @@ def __post_init__(self):
348349
self._example = TensorWeakRef(self._example)
349350
assert is_fake(self.fake_tensor)
350351

351-
def reconstruct(self, codegen):
352+
def reconstruct(self, codegen: "PyCodegen"):
352353
codegen(self.source)
353354

354355
def erase(self):
@@ -369,7 +370,7 @@ def __init__(self) -> None:
369370
is_tensor=False,
370371
)
371372

372-
def reconstruct(self, codegen):
373+
def reconstruct(self, codegen: "PyCodegen"):
373374
assert codegen.tx.output.backward_state_var
374375
codegen.add_push_null(
375376
lambda: codegen.load_import_from(BackwardState.__module__, "BackwardState")

torch/_dynamo/variables/builtin.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787

8888
if TYPE_CHECKING:
8989
# Cyclic dependency...
90+
from torch._dynamo.codegen import PyCodegen
9091
from torch._dynamo.symbolic_convert import InstructionTranslator
9192

9293
log = logging.getLogger(__name__)
@@ -730,7 +731,7 @@ def as_proxy(self):
730731
return DTYPE[self.fn]
731732
return super().as_proxy()
732733

733-
def reconstruct(self, codegen: "torch._dynamo.codegen.PyCodegen"):
734+
def reconstruct(self, codegen: "PyCodegen"):
734735
name = self.fn.__name__
735736
assert self.fn.__module__ == "builtins"
736737
assert name not in codegen.tx.f_globals, "shadowed global"

0 commit comments

Comments
 (0)