2121
2222import dataclasses
2323import enum
24- from typing import Any , Optional , Union
24+ from typing import Any , Optional , TYPE_CHECKING , Union
2525
2626from torch ._guards import ChainedSource , GuardSource , Source
2727
2828from . import utils
2929from .bytecode_transformation import create_call_function , create_instruction
3030
3131
32+ if TYPE_CHECKING :
33+ from .codegen import PyCodegen
34+
3235# It shouldn't be supported to construct an NNModuleVariable inside an FSDP module,
3336# so those cases are omitted intentionally
3437
@@ -120,7 +123,7 @@ class LocalSource(Source):
120123 # or `co_freevars`.
121124 is_derefed_cell_contents : bool = False
122125
123- def reconstruct (self , codegen ):
126+ def reconstruct (self , codegen : "PyCodegen" ):
124127 if self .is_derefed_cell_contents :
125128 codegen .load_deref (self .local_name )
126129 else :
@@ -137,7 +140,7 @@ def name(self):
137140class SyntheticLocalSource (Source ):
138141 local_name : str
139142
140- def reconstruct (self , codegen ):
143+ def reconstruct (self , codegen : "PyCodegen" ):
141144 codegen .append_output (codegen .create_load (self .local_name ))
142145
143146 def guard_source (self ):
@@ -154,7 +157,7 @@ class RandomValueSource(Source):
154157 def guard_source (self ):
155158 return GuardSource .RANDOM_VALUE
156159
157- def reconstruct (self , codegen ):
160+ def reconstruct (self , codegen : "PyCodegen" ):
158161 codegen .append_output (codegen .create_load (codegen .tx .output .random_values_var ))
159162 codegen .append_output (codegen .create_load_const (self .random_call_index ))
160163 codegen .append_output (create_instruction ("BINARY_SUBSCR" ))
@@ -167,7 +170,7 @@ def name(self):
167170class GlobalSource (Source ):
168171 global_name : str
169172
170- def reconstruct (self , codegen ):
173+ def reconstruct (self , codegen : "PyCodegen" ):
171174 codegen .append_output (codegen .create_load_global (self .global_name , add = True ))
172175
173176 def guard_source (self ):
@@ -181,7 +184,7 @@ def name(self):
181184class GlobalWeakRefSource (Source ):
182185 global_name : str
183186
184- def reconstruct (self , codegen ):
187+ def reconstruct (self , codegen : "PyCodegen" ):
185188 codegen .add_push_null (
186189 lambda : codegen .append_output (
187190 codegen .create_load_global (self .global_name , add = True )
@@ -198,7 +201,7 @@ def name(self):
198201
199202@dataclasses .dataclass (frozen = True )
200203class WeakRefCallSource (ChainedSource ):
201- def reconstruct (self , codegen ):
204+ def reconstruct (self , codegen : "PyCodegen" ):
202205 codegen .add_push_null (lambda : codegen (self .base ))
203206 codegen .extend_output (create_call_function (0 , False ))
204207
@@ -227,7 +230,7 @@ def __post_init__(self):
227230 )
228231 object .__setattr__ (self , "member" , member_parts [- 1 ])
229232
230- def reconstruct (self , codegen ):
233+ def reconstruct (self , codegen : "PyCodegen" ):
231234 codegen (self .base )
232235 codegen .extend_output (codegen .create_load_attrs (self .member ))
233236
@@ -249,7 +252,7 @@ class LocalCellSource(Source):
249252
250253 local_name : str
251254
252- def reconstruct (self , codegen ):
255+ def reconstruct (self , codegen : "PyCodegen" ):
253256 # Although `LOAD_FAST` and `LOAD_CLOSURE` have the same semantics,
254257 # Dynamo's bytecode transformation differentiates them slightly, so we
255258 # always emit `LOAD_CLOSURE` here.
@@ -267,7 +270,7 @@ def reconstruct(self, codegen):
267270class GradSource (ChainedSource ):
268271 member : str = "grad"
269272
270- def reconstruct (self , codegen ):
273+ def reconstruct (self , codegen : "PyCodegen" ):
271274 codegen (self .base )
272275 codegen .extend_output (codegen .create_load_attrs (self .member ))
273276
@@ -342,7 +345,7 @@ def __post_init__(self):
342345 else :
343346 assert self .idx is not None
344347
345- def reconstruct (self , codegen ):
348+ def reconstruct (self , codegen : "PyCodegen" ):
346349 codegen .add_push_null (
347350 lambda : codegen .load_import_from (
348351 utils .__name__ , f"call_{ self .prop .method_name ()} "
@@ -378,7 +381,7 @@ class IndexedSource(ChainedSource):
378381 def __post_init__ (self ):
379382 assert self .base is not None
380383
381- def reconstruct (self , codegen ):
384+ def reconstruct (self , codegen : "PyCodegen" ):
382385 raise NotImplementedError
383386
384387 def guard_source (self ):
@@ -393,7 +396,7 @@ class NegateSource(ChainedSource):
393396 def __post_init__ (self ):
394397 assert self .base is not None
395398
396- def reconstruct (self , codegen ):
399+ def reconstruct (self , codegen : "PyCodegen" ):
397400 raise NotImplementedError
398401
399402 def guard_source (self ):
@@ -409,7 +412,7 @@ class ConvertIntSource(ChainedSource):
409412 def __post_init__ (self ):
410413 assert self .base is not None
411414
412- def reconstruct (self , codegen ):
415+ def reconstruct (self , codegen : "PyCodegen" ):
413416 codegen (self .base )
414417
415418 def guard_source (self ):
@@ -424,7 +427,7 @@ class FlattenScriptObjectSource(ChainedSource):
424427 def __post_init__ (self ):
425428 assert self .base is not None
426429
427- def reconstruct (self , codegen ):
430+ def reconstruct (self , codegen : "PyCodegen" ):
428431 codegen (self .base )
429432
430433 def guard_source (self ):
@@ -439,7 +442,7 @@ class ScriptObjectQualifiedNameSource(ChainedSource):
439442 def __post_init__ (self ):
440443 assert self .base is not None
441444
442- def reconstruct (self , codegen ):
445+ def reconstruct (self , codegen : "PyCodegen" ):
443446 codegen (self .base )
444447
445448 def guard_source (self ):
@@ -450,7 +453,7 @@ def name(self):
450453
451454
452455class AttrProxySource (ChainedSource ):
453- def reconstruct (self , codegen ):
456+ def reconstruct (self , codegen : "PyCodegen" ):
454457 codegen (self .base )
455458
456459 def guard_source (self ):
@@ -484,7 +487,7 @@ def __post_init__(self):
484487 self , "_name" , f"{ self .base .name ()} .{ self .field } [{ self .idx_key } ]"
485488 )
486489
487- def reconstruct (self , codegen ):
490+ def reconstruct (self , codegen : "PyCodegen" ):
488491 codegen (self .base )
489492 codegen .extend_output (codegen .create_load_attrs (self .field ))
490493 codegen .append_output (codegen .create_load_const (self .idx_key ))
@@ -509,7 +512,7 @@ def __post_init__(self):
509512 super ().__setattr__ ("index" , self .index .__reduce__ ())
510513 super ().__setattr__ ("index_is_slice" , True )
511514
512- def reconstruct (self , codegen ):
515+ def reconstruct (self , codegen : "PyCodegen" ):
513516 codegen (self .base )
514517 if self .index_is_slice :
515518 codegen .append_output (codegen .create_load_const (self .unpack_slice ()))
@@ -543,7 +546,7 @@ class ConstDictKeySource(ChainedSource):
543546 def guard_source (self ):
544547 return self .base .guard_source ()
545548
546- def reconstruct (self , codegen ):
549+ def reconstruct (self , codegen : "PyCodegen" ):
547550 codegen .add_push_null (
548551 lambda : codegen .load_import_from (utils .__name__ , "dict_keys_getitem" )
549552 )
@@ -577,7 +580,7 @@ def __post_init__(self):
577580 def guard_source (self ):
578581 return self .base .guard_source ()
579582
580- def reconstruct (self , codegen ):
583+ def reconstruct (self , codegen : "PyCodegen" ):
581584 # reconstruct dict.__getitem__(dct, key)
582585
583586 # Load dict.__getitem__
@@ -609,7 +612,7 @@ class ListGetItemSource(GetItemSource):
609612 Same as GetItemSource with reconstruct and name overridden to be list specific.
610613 """
611614
612- def reconstruct (self , codegen ):
615+ def reconstruct (self , codegen : "PyCodegen" ):
613616 # Reconstruct list.__getitem__(lst, index) to avoid any side effects
614617 # from possibly overridden __getitem__.
615618
@@ -646,7 +649,7 @@ def name(self):
646649
647650@dataclasses .dataclass (frozen = True )
648651class TupleIteratorGetItemSource (GetItemSource ):
649- def reconstruct (self , codegen ):
652+ def reconstruct (self , codegen : "PyCodegen" ):
650653 codegen .add_push_null (
651654 lambda : codegen .load_import_from (utils .__name__ , "tuple_iterator_getitem" )
652655 )
@@ -663,7 +666,7 @@ class TypeSource(ChainedSource):
663666 def __post_init__ (self ):
664667 assert self .base is not None
665668
666- def reconstruct (self , codegen ):
669+ def reconstruct (self , codegen : "PyCodegen" ):
667670 codegen .add_push_null (lambda : codegen .load_import_from ("builtins" , "type" ))
668671 codegen (self .base )
669672 codegen .extend_output (create_call_function (1 , False ))
@@ -677,7 +680,7 @@ def name(self):
677680
678681@dataclasses .dataclass (frozen = True )
679682class OptimizerSource (ChainedSource ):
680- def reconstruct (self , codegen ):
683+ def reconstruct (self , codegen : "PyCodegen" ):
681684 codegen (self .base )
682685
683686 def guard_source (self ):
@@ -689,7 +692,7 @@ def name(self):
689692
690693@dataclasses .dataclass (frozen = True )
691694class NNModuleSource (ChainedSource ):
692- def reconstruct (self , codegen ):
695+ def reconstruct (self , codegen : "PyCodegen" ):
693696 codegen (self .base )
694697
695698 def guard_source (self ):
@@ -738,7 +741,7 @@ def _get_index(self):
738741
739742 return TorchFunctionModeStackVariable .get_mode_index (self .ind )
740743
741- def reconstruct (self , codegen ):
744+ def reconstruct (self , codegen : "PyCodegen" ):
742745 codegen .add_push_null (
743746 lambda : codegen .load_import_from (
744747 utils .__name__ , "get_torch_function_mode_stack_at"
@@ -755,7 +758,7 @@ def guard_source(self):
755758class ConstantSource (Source ):
756759 source_name : str
757760
758- def reconstruct (self , codegen ):
761+ def reconstruct (self , codegen : "PyCodegen" ):
759762 codegen .append_output (codegen .create_load_global (self .source_name , add = False ))
760763
761764 def guard_source (self ):
@@ -776,7 +779,7 @@ def name(self) -> str:
776779 def guard_source (self ):
777780 return self .base .guard_source ()
778781
779- def reconstruct (self , codegen ):
782+ def reconstruct (self , codegen : "PyCodegen" ):
780783 codegen .add_push_null (lambda : codegen .load_import_from ("torch" , "as_tensor" ))
781784 codegen (self .base )
782785 codegen .extend_output (create_call_function (1 , False ))
0 commit comments