6060 FuncExpr , MDEF , FuncBase , Decorator , SetExpr , TypeVarExpr , NewTypeExpr ,
6161 StrExpr , BytesExpr , PrintStmt , ConditionalExpr , PromoteExpr ,
6262 ComparisonExpr , StarExpr , ARG_POS , ARG_NAMED , MroError , type_aliases ,
63- YieldFromExpr , NamedTupleExpr , NonlocalDecl , SymbolNode ,
63+ YieldFromExpr , NamedTupleExpr , TypedDictExpr , NonlocalDecl , SymbolNode ,
6464 SetComprehension , DictionaryComprehension , TYPE_ALIAS , TypeAliasExpr ,
6565 YieldExpr , ExecStmt , Argument , BackquoteExpr , ImportBase , AwaitExpr ,
6666 IntExpr , FloatExpr , UnicodeExpr , EllipsisExpr ,
@@ -1127,6 +1127,7 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
11271127 self .process_newtype_declaration (s )
11281128 self .process_typevar_declaration (s )
11291129 self .process_namedtuple_definition (s )
1130+ self .process_typeddict_definition (s )
11301131
11311132 if (len (s .lvalues ) == 1 and isinstance (s .lvalues [0 ], NameExpr ) and
11321133 s .lvalues [0 ].name == '__all__' and s .lvalues [0 ].kind == GDEF and
@@ -1498,9 +1499,9 @@ def get_typevar_declaration(self, s: AssignmentStmt) -> Optional[CallExpr]:
14981499 if not isinstance (s .rvalue , CallExpr ):
14991500 return None
15001501 call = s .rvalue
1501- if not isinstance (call .callee , RefExpr ):
1502- return None
15031502 callee = call .callee
1503+ if not isinstance (callee , RefExpr ):
1504+ return None
15041505 if callee .fullname != 'typing.TypeVar' :
15051506 return None
15061507 return call
@@ -1579,10 +1580,9 @@ def process_namedtuple_definition(self, s: AssignmentStmt) -> None:
15791580 # Yes, it's a valid namedtuple definition. Add it to the symbol table.
15801581 node = self .lookup (name , s )
15811582 node .kind = GDEF # TODO locally defined namedtuple
1582- # TODO call.analyzed
15831583 node .node = named_tuple
15841584
1585- def check_namedtuple (self , node : Expression , var_name : str = None ) -> TypeInfo :
1585+ def check_namedtuple (self , node : Expression , var_name : str = None ) -> Optional [ TypeInfo ] :
15861586 """Check if a call defines a namedtuple.
15871587
15881588 The optional var_name argument is the name of the variable to
@@ -1596,9 +1596,9 @@ def check_namedtuple(self, node: Expression, var_name: str = None) -> TypeInfo:
15961596 if not isinstance (node , CallExpr ):
15971597 return None
15981598 call = node
1599- if not isinstance (call .callee , RefExpr ):
1600- return None
16011599 callee = call .callee
1600+ if not isinstance (callee , RefExpr ):
1601+ return None
16021602 fullname = callee .fullname
16031603 if fullname not in ('collections.namedtuple' , 'typing.NamedTuple' ):
16041604 return None
@@ -1607,9 +1607,9 @@ def check_namedtuple(self, node: Expression, var_name: str = None) -> TypeInfo:
16071607 # Error. Construct dummy return value.
16081608 return self .build_namedtuple_typeinfo ('namedtuple' , [], [])
16091609 else :
1610- # Give it a unique name derived from the line number.
16111610 name = cast (StrExpr , call .args [0 ]).value
16121611 if name != var_name :
1612+ # Give it a unique name derived from the line number.
16131613 name += '@' + str (call .line )
16141614 info = self .build_namedtuple_typeinfo (name , items , types )
16151615 # Store it as a global just in case it would remain anonymous.
@@ -1620,7 +1620,7 @@ def check_namedtuple(self, node: Expression, var_name: str = None) -> TypeInfo:
16201620
16211621 def parse_namedtuple_args (self , call : CallExpr ,
16221622 fullname : str ) -> Tuple [List [str ], List [Type ], bool ]:
1623- # TODO Share code with check_argument_count in checkexpr.py?
1623+ # TODO: Share code with check_argument_count in checkexpr.py?
16241624 args = call .args
16251625 if len (args ) < 2 :
16261626 return self .fail_namedtuple_arg ("Too few arguments for namedtuple()" , call )
@@ -1777,6 +1777,114 @@ def analyze_types(self, items: List[Expression]) -> List[Type]:
17771777 result .append (AnyType ())
17781778 return result
17791779
1780+ def process_typeddict_definition (self , s : AssignmentStmt ) -> None :
1781+ """Check if s defines a TypedDict; if yes, store the definition in symbol table."""
1782+ if len (s .lvalues ) != 1 or not isinstance (s .lvalues [0 ], NameExpr ):
1783+ return
1784+ lvalue = s .lvalues [0 ]
1785+ name = lvalue .name
1786+ typed_dict = self .check_typeddict (s .rvalue , name )
1787+ if typed_dict is None :
1788+ return
1789+ # Yes, it's a valid TypedDict definition. Add it to the symbol table.
1790+ node = self .lookup (name , s )
1791+ node .kind = GDEF # TODO locally defined TypedDict
1792+ node .node = typed_dict
1793+
1794+ def check_typeddict (self , node : Expression , var_name : str = None ) -> Optional [TypeInfo ]:
1795+ """Check if a call defines a TypedDict.
1796+
1797+ The optional var_name argument is the name of the variable to
1798+ which this is assigned, if any.
1799+
1800+ If it does, return the corresponding TypeInfo. Return None otherwise.
1801+
1802+ If the definition is invalid but looks like a TypedDict,
1803+ report errors but return (some) TypeInfo.
1804+ """
1805+ if not isinstance (node , CallExpr ):
1806+ return None
1807+ call = node
1808+ callee = call .callee
1809+ if not isinstance (callee , RefExpr ):
1810+ return None
1811+ fullname = callee .fullname
1812+ if fullname != 'mypy_extensions.TypedDict' :
1813+ return None
1814+ items , types , ok = self .parse_typeddict_args (call , fullname )
1815+ if not ok :
1816+ # Error. Construct dummy return value.
1817+ return self .build_typeddict_typeinfo ('TypedDict' , [], [])
1818+ else :
1819+ name = cast (StrExpr , call .args [0 ]).value
1820+ if name != var_name :
1821+ # Give it a unique name derived from the line number.
1822+ name += '@' + str (call .line )
1823+ info = self .build_typeddict_typeinfo (name , items , types )
1824+ # Store it as a global just in case it would remain anonymous.
1825+ self .globals [name ] = SymbolTableNode (GDEF , info , self .cur_mod_id )
1826+ call .analyzed = TypedDictExpr (info )
1827+ call .analyzed .set_line (call .line , call .column )
1828+ return info
1829+
1830+ def parse_typeddict_args (self , call : CallExpr ,
1831+ fullname : str ) -> Tuple [List [str ], List [Type ], bool ]:
1832+ # TODO: Share code with check_argument_count in checkexpr.py?
1833+ args = call .args
1834+ if len (args ) < 2 :
1835+ return self .fail_typeddict_arg ("Too few arguments for TypedDict()" , call )
1836+ if len (args ) > 2 :
1837+ return self .fail_typeddict_arg ("Too many arguments for TypedDict()" , call )
1838+ # TODO: Support keyword arguments
1839+ if call .arg_kinds != [ARG_POS , ARG_POS ]:
1840+ return self .fail_typeddict_arg ("Unexpected arguments to TypedDict()" , call )
1841+ if not isinstance (args [0 ], (StrExpr , BytesExpr , UnicodeExpr )):
1842+ return self .fail_typeddict_arg (
1843+ "TypedDict() expects a string literal as the first argument" , call )
1844+ if not isinstance (args [1 ], DictExpr ):
1845+ return self .fail_typeddict_arg (
1846+ "TypedDict() expects a dictionary literal as the second argument" , call )
1847+ dictexpr = args [1 ]
1848+ items , types , ok = self .parse_typeddict_fields_with_types (dictexpr .items , call )
1849+ return items , types , ok
1850+
1851+ def parse_typeddict_fields_with_types (self , dict_items : List [Tuple [Expression , Expression ]],
1852+ context : Context ) -> Tuple [List [str ], List [Type ], bool ]:
1853+ items = [] # type: List[str]
1854+ types = [] # type: List[Type]
1855+ for (field_name_expr , field_type_expr ) in dict_items :
1856+ if isinstance (field_name_expr , (StrExpr , BytesExpr , UnicodeExpr )):
1857+ items .append (field_name_expr .value )
1858+ else :
1859+ return self .fail_typeddict_arg ("Invalid TypedDict() field name" , field_name_expr )
1860+ try :
1861+ type = expr_to_unanalyzed_type (field_type_expr )
1862+ except TypeTranslationError :
1863+ return self .fail_typeddict_arg ('Invalid field type' , field_type_expr )
1864+ types .append (self .anal_type (type ))
1865+ return items , types , True
1866+
1867+ def fail_typeddict_arg (self , message : str ,
1868+ context : Context ) -> Tuple [List [str ], List [Type ], bool ]:
1869+ self .fail (message , context )
1870+ return [], [], False
1871+
1872+ def build_typeddict_typeinfo (self , name : str , items : List [str ],
1873+ types : List [Type ]) -> TypeInfo :
1874+ strtype = self .named_type ('__builtins__.str' ) # type: Type
1875+ dictype = (self .named_type_or_none ('builtins.dict' , [strtype , AnyType ()])
1876+ or self .object_type ())
1877+ fallback = dictype
1878+
1879+ info = self .basic_new_typeinfo (name , fallback )
1880+ info .is_typed_dict = True
1881+
1882+ # (TODO: Store {items, types} inside "info" somewhere for use later.
1883+ # Probably inside a new "info.keys" field which
1884+ # would be analogous to "info.names".)
1885+
1886+ return info
1887+
17801888 def visit_decorator (self , dec : Decorator ) -> None :
17811889 for d in dec .decorators :
17821890 d .accept (self )
0 commit comments