@@ -564,6 +564,8 @@ def check_function_signature(self, fdef: FuncItem) -> None:
564
564
565
565
def visit_class_def (self , defn : ClassDef ) -> None :
566
566
self .clean_up_bases_and_infer_type_variables (defn )
567
+ if self .analyze_typeddict_classdef (defn ):
568
+ return
567
569
if self .analyze_namedtuple_classdef (defn ):
568
570
return
569
571
self .setup_class_def_analysis (defn )
@@ -1009,6 +1011,101 @@ def bind_class_type_variables_in_symbol_table(
1009
1011
nodes .append (node )
1010
1012
return nodes
1011
1013
1014
+ def is_typeddict (self , expr : Expression ) -> bool :
1015
+ return (isinstance (expr , RefExpr ) and isinstance (expr .node , TypeInfo ) and
1016
+ expr .node .typeddict_type is not None )
1017
+
1018
+ def analyze_typeddict_classdef (self , defn : ClassDef ) -> bool :
1019
+ # special case for TypedDict
1020
+ possible = False
1021
+ for base_expr in defn .base_type_exprs :
1022
+ if isinstance (base_expr , RefExpr ):
1023
+ base_expr .accept (self )
1024
+ if (base_expr .fullname == 'mypy_extensions.TypedDict' or
1025
+ self .is_typeddict (base_expr )):
1026
+ possible = True
1027
+ if possible :
1028
+ node = self .lookup (defn .name , defn )
1029
+ if node is not None :
1030
+ node .kind = GDEF # TODO in process_namedtuple_definition also applies here
1031
+ if (len (defn .base_type_exprs ) == 1 and
1032
+ isinstance (defn .base_type_exprs [0 ], RefExpr ) and
1033
+ defn .base_type_exprs [0 ].fullname == 'mypy_extensions.TypedDict' ):
1034
+ # Building a new TypedDict
1035
+ fields , types = self .check_typeddict_classdef (defn )
1036
+ node .node = self .build_typeddict_typeinfo (defn .name , fields , types )
1037
+ return True
1038
+ # Extending/merging existing TypedDicts
1039
+ if any (not isinstance (expr , RefExpr ) or
1040
+ expr .fullname != 'mypy_extensions.TypedDict' and
1041
+ not self .is_typeddict (expr ) for expr in defn .base_type_exprs ):
1042
+ self .fail ("All bases of a new TypedDict must be TypedDict types" , defn )
1043
+ typeddict_bases = list (filter (self .is_typeddict , defn .base_type_exprs ))
1044
+ newfields = [] # type: List[str]
1045
+ newtypes = [] # type: List[Type]
1046
+ tpdict = None # type: OrderedDict[str, Type]
1047
+ for base in typeddict_bases :
1048
+ assert isinstance (base , RefExpr )
1049
+ assert isinstance (base .node , TypeInfo )
1050
+ assert isinstance (base .node .typeddict_type , TypedDictType )
1051
+ tpdict = base .node .typeddict_type .items
1052
+ newdict = tpdict .copy ()
1053
+ for key in tpdict :
1054
+ if key in newfields :
1055
+ self .fail ('Cannot overwrite TypedDict field "{}" while merging'
1056
+ .format (key ), defn )
1057
+ newdict .pop (key )
1058
+ newfields .extend (newdict .keys ())
1059
+ newtypes .extend (newdict .values ())
1060
+ fields , types = self .check_typeddict_classdef (defn , newfields )
1061
+ newfields .extend (fields )
1062
+ newtypes .extend (types )
1063
+ node .node = self .build_typeddict_typeinfo (defn .name , newfields , newtypes )
1064
+ return True
1065
+ return False
1066
+
1067
+ def check_typeddict_classdef (self , defn : ClassDef ,
1068
+ oldfields : List [str ] = None ) -> Tuple [List [str ], List [Type ]]:
1069
+ TPDICT_CLASS_ERROR = ('Invalid statement in TypedDict definition; '
1070
+ 'expected "field_name: field_type"' )
1071
+ if self .options .python_version < (3 , 6 ):
1072
+ self .fail ('TypedDict class syntax is only supported in Python 3.6' , defn )
1073
+ return [], []
1074
+ fields = [] # type: List[str]
1075
+ types = [] # type: List[Type]
1076
+ for stmt in defn .defs .body :
1077
+ if not isinstance (stmt , AssignmentStmt ):
1078
+ # Still allow pass or ... (for empty TypedDict's).
1079
+ if (not isinstance (stmt , PassStmt ) and
1080
+ not (isinstance (stmt , ExpressionStmt ) and
1081
+ isinstance (stmt .expr , EllipsisExpr ))):
1082
+ self .fail (TPDICT_CLASS_ERROR , stmt )
1083
+ elif len (stmt .lvalues ) > 1 or not isinstance (stmt .lvalues [0 ], NameExpr ):
1084
+ # An assignment, but an invalid one.
1085
+ self .fail (TPDICT_CLASS_ERROR , stmt )
1086
+ else :
1087
+ name = stmt .lvalues [0 ].name
1088
+ if name in (oldfields or []):
1089
+ self .fail ('Cannot overwrite TypedDict field "{}" while extending'
1090
+ .format (name ), stmt )
1091
+ continue
1092
+ if name in fields :
1093
+ self .fail ('Duplicate TypedDict field "{}"' .format (name ), stmt )
1094
+ continue
1095
+ # Append name and type in this case...
1096
+ fields .append (name )
1097
+ types .append (AnyType () if stmt .type is None else self .anal_type (stmt .type ))
1098
+ # ...despite possible minor failures that allow further analyzis.
1099
+ if name .startswith ('_' ):
1100
+ self .fail ('TypedDict field name cannot start with an underscore: {}'
1101
+ .format (name ), stmt )
1102
+ if stmt .type is None or hasattr (stmt , 'new_syntax' ) and not stmt .new_syntax :
1103
+ self .fail (TPDICT_CLASS_ERROR , stmt )
1104
+ elif not isinstance (stmt .rvalue , TempNode ):
1105
+ # x: int assigns rvalue to TempNode(AnyType())
1106
+ self .fail ('Right hand side values are not supported in TypedDict' , stmt )
1107
+ return fields , types
1108
+
1012
1109
def visit_import (self , i : Import ) -> None :
1013
1110
for id , as_id in i .ids :
1014
1111
if as_id is not None :
0 commit comments