@@ -786,25 +786,20 @@ def visit_func_def(self, o: FuncDef) -> None:
786
786
elif o .name in KNOWN_MAGIC_METHODS_RETURN_TYPES :
787
787
retname = KNOWN_MAGIC_METHODS_RETURN_TYPES [o .name ]
788
788
elif has_yield_expression (o ) or has_yield_from_expression (o ):
789
- self .add_typing_import ("Generator" )
789
+ generator_name = self .add_typing_import ("Generator" )
790
790
yield_name = "None"
791
791
send_name = "None"
792
792
return_name = "None"
793
793
if has_yield_from_expression (o ):
794
- self .add_typing_import ("Incomplete" )
795
- yield_name = send_name = self .typing_name ("Incomplete" )
794
+ yield_name = send_name = self .add_typing_import ("Incomplete" )
796
795
else :
797
796
for expr , in_assignment in all_yield_expressions (o ):
798
797
if expr .expr is not None and not self .is_none_expr (expr .expr ):
799
- self .add_typing_import ("Incomplete" )
800
- yield_name = self .typing_name ("Incomplete" )
798
+ yield_name = self .add_typing_import ("Incomplete" )
801
799
if in_assignment :
802
- self .add_typing_import ("Incomplete" )
803
- send_name = self .typing_name ("Incomplete" )
800
+ send_name = self .add_typing_import ("Incomplete" )
804
801
if has_return_statement (o ):
805
- self .add_typing_import ("Incomplete" )
806
- return_name = self .typing_name ("Incomplete" )
807
- generator_name = self .typing_name ("Generator" )
802
+ return_name = self .add_typing_import ("Incomplete" )
808
803
retname = f"{ generator_name } [{ yield_name } , { send_name } , { return_name } ]"
809
804
elif not has_return_statement (o ) and o .abstract_status == NOT_ABSTRACT :
810
805
retname = "None"
@@ -965,21 +960,19 @@ def get_base_types(self, cdef: ClassDef) -> list[str]:
965
960
nt_fields = self ._get_namedtuple_fields (base )
966
961
assert isinstance (base .args [0 ], StrExpr )
967
962
typename = base .args [0 ].value
968
- if nt_fields is not None :
969
- fields_str = ", " .join (f"({ f !r} , { t } )" for f , t in nt_fields )
970
- namedtuple_name = self .typing_name ("NamedTuple" )
971
- base_types .append (f"{ namedtuple_name } ({ typename !r} , [{ fields_str } ])" )
972
- self .add_typing_import ("NamedTuple" )
973
- else :
963
+ if nt_fields is None :
974
964
# Invalid namedtuple() call, cannot determine fields
975
- base_types .append (self .typing_name ("Incomplete" ))
965
+ base_types .append (self .add_typing_import ("Incomplete" ))
966
+ continue
967
+ fields_str = ", " .join (f"({ f !r} , { t } )" for f , t in nt_fields )
968
+ namedtuple_name = self .add_typing_import ("NamedTuple" )
969
+ base_types .append (f"{ namedtuple_name } ({ typename !r} , [{ fields_str } ])" )
976
970
elif self .is_typed_namedtuple (base ):
977
971
base_types .append (base .accept (p ))
978
972
else :
979
973
# At this point, we don't know what the base class is, so we
980
974
# just use Incomplete as the base class.
981
- base_types .append (self .typing_name ("Incomplete" ))
982
- self .add_typing_import ("Incomplete" )
975
+ base_types .append (self .add_typing_import ("Incomplete" ))
983
976
for name , value in cdef .keywords .items ():
984
977
if name == "metaclass" :
985
978
continue # handled separately
@@ -1059,9 +1052,9 @@ def _get_namedtuple_fields(self, call: CallExpr) -> list[tuple[str, str]] | None
1059
1052
field_names .append (field .value )
1060
1053
else :
1061
1054
return None # Invalid namedtuple fields type
1062
- if field_names :
1063
- self . add_typing_import ( "Incomplete" )
1064
- incomplete = self .typing_name ("Incomplete" )
1055
+ if not field_names :
1056
+ return []
1057
+ incomplete = self .add_typing_import ("Incomplete" )
1065
1058
return [(field_name , incomplete ) for field_name in field_names ]
1066
1059
elif self .is_typed_namedtuple (call ):
1067
1060
fields_arg = call .args [1 ]
@@ -1092,8 +1085,7 @@ def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
1092
1085
if fields is None :
1093
1086
self .annotate_as_incomplete (lvalue )
1094
1087
return
1095
- self .add_typing_import ("NamedTuple" )
1096
- bases = self .typing_name ("NamedTuple" )
1088
+ bases = self .add_typing_import ("NamedTuple" )
1097
1089
# TODO: Add support for generic NamedTuples. Requires `Generic` as base class.
1098
1090
class_def = f"{ self ._indent } class { lvalue .name } ({ bases } ):"
1099
1091
if len (fields ) == 0 :
@@ -1143,14 +1135,13 @@ def process_typeddict(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
1143
1135
total = arg
1144
1136
else :
1145
1137
items .append ((arg_name , arg ))
1146
- self .add_typing_import ("TypedDict" )
1138
+ bases = self .add_typing_import ("TypedDict" )
1147
1139
p = AliasPrinter (self )
1148
1140
if any (not key .isidentifier () or keyword .iskeyword (key ) for key , _ in items ):
1149
1141
# Keep the call syntax if there are non-identifier or reserved keyword keys.
1150
1142
self .add (f"{ self ._indent } { lvalue .name } = { rvalue .accept (p )} \n " )
1151
1143
self ._state = VAR
1152
1144
else :
1153
- bases = self .typing_name ("TypedDict" )
1154
1145
# TODO: Add support for generic TypedDicts. Requires `Generic` as base class.
1155
1146
if total is not None :
1156
1147
bases += f", total={ total .accept (p )} "
@@ -1167,8 +1158,7 @@ def process_typeddict(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
1167
1158
self ._state = CLASS
1168
1159
1169
1160
def annotate_as_incomplete (self , lvalue : NameExpr ) -> None :
1170
- self .add_typing_import ("Incomplete" )
1171
- self .add (f"{ self ._indent } { lvalue .name } : { self .typing_name ('Incomplete' )} \n " )
1161
+ self .add (f"{ self ._indent } { lvalue .name } : { self .add_typing_import ('Incomplete' )} \n " )
1172
1162
self ._state = VAR
1173
1163
1174
1164
def is_alias_expression (self , expr : Expression , top_level : bool = True ) -> bool :
@@ -1384,13 +1374,14 @@ def typing_name(self, name: str) -> str:
1384
1374
else :
1385
1375
return name
1386
1376
1387
- def add_typing_import (self , name : str ) -> None :
1377
+ def add_typing_import (self , name : str ) -> str :
1388
1378
"""Add a name to be imported for typing, unless it's imported already.
1389
1379
1390
1380
The import will be internal to the stub.
1391
1381
"""
1392
1382
name = self .typing_name (name )
1393
1383
self .import_tracker .require_name (name )
1384
+ return name
1394
1385
1395
1386
def add_import_line (self , line : str ) -> None :
1396
1387
"""Add a line of text to the import section, unless it's already there."""
@@ -1448,11 +1439,9 @@ def get_str_type_of_node(
1448
1439
if isinstance (rvalue , NameExpr ) and rvalue .name in ("True" , "False" ):
1449
1440
return "bool"
1450
1441
if can_infer_optional and isinstance (rvalue , NameExpr ) and rvalue .name == "None" :
1451
- self .add_typing_import ("Incomplete" )
1452
- return f"{ self .typing_name ('Incomplete' )} | None"
1442
+ return f"{ self .add_typing_import ('Incomplete' )} | None"
1453
1443
if can_be_any :
1454
- self .add_typing_import ("Incomplete" )
1455
- return self .typing_name ("Incomplete" )
1444
+ return self .add_typing_import ("Incomplete" )
1456
1445
else :
1457
1446
return ""
1458
1447
0 commit comments