From a84fb9625ef31240faa7a838dddc00a6a2d8b314 Mon Sep 17 00:00:00 2001 From: Rock Neurotiko Date: Mon, 28 Jul 2014 03:06:49 +0200 Subject: [PATCH 01/12] Yield from statements and expressions support. Files changed: nodes.py -> - Added YieldFromStmt, extends from YieldStmt traverser.py -> - Add import YieldFromStmt - Added visit_yield_from function visitor.py -> - Added visit_yield_from function parse.py -> - Add import YieldFromStmt - Modified parse_yield_stmt to allow yield from statements (here comes when there is not assigntment, just a yield from wait) - Added parse_yield_from_expr to allow a yield from assigned to a var, return the callExpr of the yield from - Modified parse_expresssion, changed the self.current() to t (there is no reason to call it when saved it before) and added a clausule that checks that a assignment expression can go through yield from TODO: tests (I made "by hand" but it is not automated) --- mypy/nodes.py | 407 +++++++++++++++++++++++----------------------- mypy/parse.py | 314 ++++++++++++++++++----------------- mypy/traverser.py | 78 ++++----- mypy/visitor.py | 28 ++-- 4 files changed, 426 insertions(+), 401 deletions(-) diff --git a/mypy/nodes.py b/mypy/nodes.py index eef1d5319f32..0a6f67810a3a 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -33,7 +33,7 @@ def get_line(self) -> int: pass MDEF = 2 # type: int MODULE_REF = 3 # type: int # Type variable declared using typevar(...) has kind UNBOUND_TVAR. It's not -# valid as a type. A type variable is valid as a type (kind TVAR) within +# valid as a type. A type variable is valid as a type (kind TVAR) within # (1) a generic class that uses the type variable as a type argument or # (2) a generic function that refers to the type variable in its signature. UNBOUND_TVAR = 4 # type: 'int' @@ -69,7 +69,7 @@ def get_line(self) -> int: pass class Node(Context): """Common base class for all non-type parse tree nodes.""" - + line = -1 # Textual representation repr = None # type: Any @@ -87,7 +87,7 @@ def __str__(self) -> str: def set_line(self, tok: Token) -> 'Node': self.line = tok.line return self - + @overload def set_line(self, line: int) -> 'Node': self.line = line @@ -96,16 +96,16 @@ def set_line(self, line: int) -> 'Node': def get_line(self) -> int: # TODO this should be just 'line' return self.line - + def accept(self, visitor: NodeVisitor[T]) -> T: raise RuntimeError('Not implemented') class SymbolNode(Node): # Nodes that can be stored in a symbol table. - + # TODO do not use methods for these - + @abstractmethod def name(self) -> str: pass @@ -115,7 +115,7 @@ def fullname(self) -> str: pass class MypyFile(SymbolNode): """The abstract syntax tree of a single source file.""" - + _name = None # type: str # Module name ('__main__' for initial file) _fullname = None # type: str # Qualified module name path = '' # Path to the file (None if not known) @@ -123,7 +123,7 @@ class MypyFile(SymbolNode): is_bom = False # Is there a UTF-8 BOM at the start? names = Undefined('SymbolTable') imports = Undefined(List['ImportBase']) # All import nodes within the file - + def __init__(self, defs: List[Node], imports: List['ImportBase'], is_bom: bool = False) -> None: self.defs = defs @@ -136,7 +136,7 @@ def name(self) -> str: def fullname(self) -> str: return self._fullname - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_mypy_file(self) @@ -144,46 +144,46 @@ def accept(self, visitor: NodeVisitor[T]) -> T: class ImportBase(Node): """Base class for all import statements.""" is_unreachable = False - + class Import(ImportBase): """import m [as n]""" - + ids = Undefined(List[Tuple[str, str]]) # (module id, as id) - + def __init__(self, ids: List[Tuple[str, str]]) -> None: self.ids = ids - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_import(self) class ImportFrom(ImportBase): """from m import x, ...""" - + names = Undefined(List[Tuple[str, str]]) # Tuples (name, as name) - + def __init__(self, id: str, names: List[Tuple[str, str]]) -> None: self.id = id self.names = names - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_import_from(self) class ImportAll(ImportBase): """from m import *""" - + def __init__(self, id: str) -> None: self.id = id - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_import_all(self) class FuncBase(SymbolNode): """Abstract base class for function-like nodes""" - + # Type signature (Callable or Overloaded) type = None # type: mypy.types.Type # If method, reference to TypeInfo @@ -191,10 +191,10 @@ class FuncBase(SymbolNode): @abstractmethod def name(self) -> str: pass - + def fullname(self) -> str: return self.name() - + def is_method(self) -> bool: return bool(self.info) @@ -205,20 +205,20 @@ class OverloadedFuncDef(FuncBase): This node has no explicit representation in the source program. Overloaded variants must be consecutive in the source file. """ - + items = Undefined(List['Decorator']) _fullname = None # type: str - + def __init__(self, items: List['Decorator']) -> None: self.items = items self.set_line(items[0].line) - + def name(self) -> str: return self.items[1].func.name() def fullname(self) -> str: return self._fullname - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_overloaded_func_def(self) @@ -226,7 +226,7 @@ def accept(self, visitor: NodeVisitor[T]) -> T: class FuncItem(FuncBase): args = Undefined(List['Var']) # Argument names arg_kinds = Undefined(List[int]) # Kinds of arguments (ARG_*) - + # Initialization expessions for fixed args; None if no initialiser init = Undefined(List['AssignmentStmt']) min_args = 0 # Minimum number of arguments @@ -241,7 +241,7 @@ class FuncItem(FuncBase): is_class = False # Uses @classmethod? expanded = Undefined(List['FuncItem']) # Variants of function with type # variables with values expanded - + def __init__(self, args: List['Var'], arg_kinds: List[int], init: List[Node], body: 'Block', typ: 'mypy.types.Type' = None) -> None: @@ -251,7 +251,7 @@ def __init__(self, args: List['Var'], arg_kinds: List[int], self.body = body self.type = typ self.expanded = [] - + i2 = List[AssignmentStmt]() self.min_args = 0 for i in range(len(init)): @@ -266,24 +266,24 @@ def __init__(self, args: List['Var'], arg_kinds: List[int], if i < self.max_fixed_argc(): self.min_args = i + 1 self.init = i2 - + def max_fixed_argc(self) -> int: return self.max_pos - + @overload def set_line(self, tok: Token) -> Node: super().set_line(tok) for n in self.args: n.line = self.line return self - + @overload def set_line(self, tok: int) -> Node: super().set_line(tok) for n in self.args: n.line = self.line return self - + def init_expressions(self) -> List[Node]: res = List[Node]() for i in self.init: @@ -297,16 +297,16 @@ def init_expressions(self) -> List[Node]: class FuncDef(FuncItem): """Function definition. - This is a non-lambda function defined using 'def'. + This is a non-lambda function defined using 'def'. """ - + _fullname = None # type: str # Name with module prefix is_decorated = False is_conditional = False # Defined conditionally (within block)? is_abstract = False is_property = False original_def = None # type: FuncDef # Original conditional definition - + def __init__(self, name: str, # Function name args: List['Var'], # Argument names @@ -319,13 +319,13 @@ def __init__(self, def name(self) -> str: return self._name - + def fullname(self) -> str: return self._fullname def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_func_def(self) - + def is_constructor(self) -> bool: return self.info is not None and self._name == '__init__' @@ -339,12 +339,12 @@ class Decorator(SymbolNode): A single Decorator object can include any number of function decorators. """ - + func = Undefined(FuncDef) # Decorated function decorators = Undefined(List[Node]) # Decorators, at least one var = Undefined('Var') # Represents the decorated function obj is_overload = False - + def __init__(self, func: FuncDef, decorators: List[Node], var: 'Var') -> None: self.func = func @@ -367,7 +367,7 @@ class Var(SymbolNode): It can refer to global/local variable or a data attribute. """ - + _name = None # type: str # Name without module prefix _fullname = None # type: str # Name with module prefix info = Undefined('TypeInfo') # Defining class (for member variables) @@ -380,7 +380,7 @@ class Var(SymbolNode): is_staticmethod = False is_classmethod = False is_property = False - + def __init__(self, name: str, type: 'mypy.types.Type' = None) -> None: self._name = name self.type = type @@ -393,14 +393,14 @@ def name(self) -> str: def fullname(self) -> str: return self._fullname - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_var(self) class ClassDef(Node): """Class definition""" - + name = Undefined(str) # Name of the class without module prefix fullname = None # type: str # Fully qualified name of the class defs = Undefined('Block') @@ -412,7 +412,7 @@ class ClassDef(Node): decorators = Undefined(List[Node]) # Built-in/extension class? (single implementation inheritance only) is_builtinclass = False - + def __init__(self, name: str, defs: 'Block', type_vars: List['mypy.types.TypeVarDef'] = None, base_types: List['mypy.types.Type'] = None, @@ -425,58 +425,58 @@ def __init__(self, name: str, defs: 'Block', self.base_types = base_types self.metaclass = metaclass self.decorators = [] - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_class_def(self) - + def is_generic(self) -> bool: return self.info.is_generic() class VarDef(Node): """Variable definition with explicit types""" - + items = Undefined(List[Var]) kind = None # type: int # LDEF/GDEF/MDEF/... init = Undefined(Node) # Expression or None is_top_level = False # Is the definition at the top level (not within # a function or a type)? - + def __init__(self, items: List[Var], is_top_level: bool, init: Node = None) -> None: self.items = items self.is_top_level = is_top_level self.init = init - + def info(self) -> 'TypeInfo': return self.items[0].info - + @overload def set_line(self, tok: Token) -> Node: super().set_line(tok) for n in self.items: n.line = self.line return self - + @overload def set_line(self, tok: int) -> Node: super().set_line(tok) for n in self.items: n.line = self.line return self - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_var_def(self) class GlobalDecl(Node): """Declaration global x, y, ...""" - + names = Undefined(List[str]) - + def __init__(self, names: List[str]) -> None: self.names = names - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_global_decl(self) @@ -487,10 +487,10 @@ class Block(Node): # this applies to blocks that are protected by something like "if PY3:" # when using Python 2. is_unreachable = False - + def __init__(self, body: List[Node]) -> None: self.body = body - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_block(self) @@ -501,10 +501,10 @@ def accept(self, visitor: NodeVisitor[T]) -> T: class ExpressionStmt(Node): """An expression as a statament, such as print(s).""" expr = Undefined(Node) - + def __init__(self, expr: Node) -> None: self.expr = expr - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_expression_stmt(self) @@ -519,34 +519,34 @@ class AssignmentStmt(Node): An lvalue can be NameExpr, TupleExpr, ListExpr, MemberExpr, IndexExpr or ParenExpr. """ - + lvalues = Undefined(List[Node]) rvalue = Undefined(Node) type = None # type: mypy.types.Type # Declared type in a comment, # may be None. - + def __init__(self, lvalues: List[Node], rvalue: Node, type: 'mypy.types.Type' = None) -> None: self.lvalues = lvalues self.rvalue = rvalue self.type = type - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_assignment_stmt(self) class OperatorAssignmentStmt(Node): """Operator assignment statement such as x += 1""" - + op = '' lvalue = Undefined(Node) rvalue = Undefined(Node) - + def __init__(self, op: str, lvalue: Node, rvalue: Node) -> None: self.op = op self.lvalue = lvalue self.rvalue = rvalue - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_operator_assignment_stmt(self) @@ -555,12 +555,12 @@ class WhileStmt(Node): expr = Undefined(Node) body = Undefined(Block) else_body = Undefined(Block) - + def __init__(self, expr: Node, body: Block, else_body: Block) -> None: self.expr = expr self.body = body self.else_body = else_body - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_while_stmt(self) @@ -574,7 +574,7 @@ class ForStmt(Node): expr = Undefined(Node) body = Undefined(Block) else_body = Undefined(Block) - + def __init__(self, index: List['NameExpr'], expr: Node, body: Block, else_body: Block, types: List['mypy.types.Type'] = None) -> None: @@ -583,10 +583,10 @@ def __init__(self, index: List['NameExpr'], expr: Node, body: Block, self.body = body self.else_body = else_body self.types = types - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_for_stmt(self) - + def is_annotated(self) -> bool: ann = False for t in self.types: @@ -597,40 +597,45 @@ def is_annotated(self) -> bool: class ReturnStmt(Node): expr = Undefined(Node) # Expression or None - + def __init__(self, expr: Node) -> None: self.expr = expr - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_return_stmt(self) class AssertStmt(Node): expr = Undefined(Node) - + def __init__(self, expr: Node) -> None: self.expr = expr - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_assert_stmt(self) class YieldStmt(Node): expr = Undefined(Node) - + def __init__(self, expr: Node) -> None: self.expr = expr - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_yield_stmt(self) +class YieldFromStmt(YieldStmt): + def accept(self, visitor: NodeVisitor[T]) -> T: + return visitor.visit_yield_from_stmt(self) + + class DelStmt(Node): expr = Undefined(Node) - + def __init__(self, expr: Node) -> None: self.expr = expr - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_del_stmt(self) @@ -654,13 +659,13 @@ class IfStmt(Node): expr = Undefined(List[Node]) body = Undefined(List[Block]) else_body = Undefined(Block) - + def __init__(self, expr: List[Node], body: List[Block], else_body: Block) -> None: self.expr = expr self.body = body self.else_body = else_body - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_if_stmt(self) @@ -668,11 +673,11 @@ def accept(self, visitor: NodeVisitor[T]) -> T: class RaiseStmt(Node): expr = Undefined(Node) from_expr = Undefined(Node) - + def __init__(self, expr: Node, from_expr: Node = None) -> None: self.expr = expr self.from_expr = from_expr - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_raise_stmt(self) @@ -684,7 +689,7 @@ class TryStmt(Node): handlers = Undefined(List[Block]) # Except bodies else_body = Undefined(Block) finally_body = Undefined(Block) - + def __init__(self, body: Block, vars: List['NameExpr'], types: List[Node], handlers: List[Block], else_body: Block, finally_body: Block) -> None: @@ -694,7 +699,7 @@ def __init__(self, body: Block, vars: List['NameExpr'], types: List[Node], self.handlers = handlers self.else_body = else_body self.finally_body = finally_body - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_try_stmt(self) @@ -703,27 +708,27 @@ class WithStmt(Node): expr = Undefined(List[Node]) name = Undefined(List['NameExpr']) body = Undefined(Block) - + def __init__(self, expr: List[Node], name: List['NameExpr'], body: Block) -> None: self.expr = expr self.name = name self.body = body - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_with_stmt(self) class PrintStmt(Node): """Python 2 print statement""" - + args = Undefined(List[Node]) newline = False def __init__(self, args: List[Node], newline: bool) -> None: self.args = args self.newline = newline - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_print_stmt(self) @@ -733,95 +738,95 @@ def accept(self, visitor: NodeVisitor[T]) -> T: class IntExpr(Node): """Integer literal""" - + value = 0 literal = LITERAL_YES - + def __init__(self, value: int) -> None: self.value = value self.literal_hash = value - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_int_expr(self) class StrExpr(Node): """String literal""" - + value = '' literal = LITERAL_YES def __init__(self, value: str) -> None: self.value = value self.literal_hash = value - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_str_expr(self) class BytesExpr(Node): """Bytes literal""" - + value = '' # TODO use bytes literal = LITERAL_YES def __init__(self, value: str) -> None: self.value = value self.literal_hash = value - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_bytes_expr(self) class UnicodeExpr(Node): """Unicode literal (Python 2.x)""" - + value = '' # TODO use bytes literal = LITERAL_YES def __init__(self, value: str) -> None: self.value = value self.literal_hash = value - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_unicode_expr(self) class FloatExpr(Node): """Float literal""" - + value = 0.0 literal = LITERAL_YES - + def __init__(self, value: float) -> None: self.value = value self.literal_hash = value - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_float_expr(self) class ParenExpr(Node): """Parenthesised expression""" - + expr = Undefined(Node) - + def __init__(self, expr: Node) -> None: self.expr = expr self.literal = self.expr.literal self.literal_hash = ('Paren', expr.literal_hash,) - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_paren_expr(self) class RefExpr(Node): """Abstract base class for name-like constructs""" - + kind = None # type: int # LDEF/GDEF/MDEF/... (None if not available) node = Undefined(Node) # Var, FuncDef or TypeInfo that describes this fullname = None # type: str # Fully qualified name (or name if not global) - + # Does this define a new name with inferred type? # # For members, after semantic analysis, this does not take base @@ -834,33 +839,33 @@ class NameExpr(RefExpr): This refers to a local name, global name or a module. """ - + name = None # type: str # Name referred to (may be qualified) info = Undefined('TypeInfo') # TypeInfo of class surrounding expression # (may be None) literal = LITERAL_TYPE - + def __init__(self, name: str) -> None: self.name = name self.literal_hash = ('Var', name,) - + def type_node(self): return cast('TypeInfo', self.node) - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_name_expr(self) class MemberExpr(RefExpr): """Member access expression x.y""" - + expr = Undefined(Node) name = None # type: str # The variable node related to a definition. def_var = None # type: Var # Is this direct assignment to a data member (bypassing accessors)? direct = False - + def __init__(self, expr: Node, name: str, direct: bool = False) -> None: self.expr = expr self.name = name @@ -892,7 +897,7 @@ class CallExpr(Node): This can also represent several special forms that are syntactically calls such as cast(...) and Undefined(...). """ - + callee = Undefined(Node) args = Undefined(List[Node]) arg_kinds = Undefined(List[int]) # ARG_ constants @@ -901,7 +906,7 @@ class CallExpr(Node): analyzed = Undefined(Node) # If not None, the node that represents # the meaning of the CallExpr. For # cast(...) this is a CastExpr. - + def __init__(self, callee: Node, args: List[Node], arg_kinds: List[int], arg_names: List[str] = None, analyzed: Node = None) -> None: if not arg_names: @@ -911,7 +916,7 @@ def __init__(self, callee: Node, args: List[Node], arg_kinds: List[int], self.arg_kinds = arg_kinds self.arg_names = arg_names self.analyzed = analyzed - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_call_expr(self) @@ -921,7 +926,7 @@ class IndexExpr(Node): Also wraps type application as a special form. """ - + base = Undefined(Node) index = Undefined(Node) # Inferred __getitem__ method type @@ -929,7 +934,7 @@ class IndexExpr(Node): # If not None, this is actually semantically a type application # Class[type, ...]. analyzed = Undefined('TypeApplication') - + def __init__(self, base: Node, index: Node) -> None: self.base = base self.index = index @@ -938,25 +943,25 @@ def __init__(self, base: Node, index: Node) -> None: self.literal = self.base.literal self.literal_hash = ('Member', base.literal_hash, index.literal_hash) - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_index_expr(self) class UnaryExpr(Node): """Unary operation""" - + op = '' expr = Undefined(Node) # Inferred operator method type method_type = None # type: mypy.types.Type - + def __init__(self, op: str, expr: Node) -> None: self.op = op self.expr = expr self.literal = self.expr.literal self.literal_hash = ('Unary', op, expr.literal_hash) - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_unary_expr(self) @@ -1017,21 +1022,21 @@ def accept(self, visitor: NodeVisitor[T]) -> T: class OpExpr(Node): """Binary operation (other than . or [], which have specific nodes).""" - + op = '' left = Undefined(Node) right = Undefined(Node) # Inferred type for the operator method type (when relevant; None for # 'is'). method_type = None # type: mypy.types.Type - + def __init__(self, op: str, left: Node, right: Node) -> None: self.op = op self.left = left self.right = right self.literal = min(self.left.literal, self.right.literal) self.literal_hash = ('Binary', op, left.literal_hash, right.literal_hash) - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_op_expr(self) @@ -1041,44 +1046,44 @@ class SliceExpr(Node): This is only valid as index in index expressions. """ - + begin_index = Undefined(Node) # May be None end_index = Undefined(Node) # May be None stride = Undefined(Node) # May be None - + def __init__(self, begin_index: Node, end_index: Node, stride: Node) -> None: self.begin_index = begin_index self.end_index = end_index self.stride = stride - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_slice_expr(self) class CastExpr(Node): """Cast expression cast(type, expr).""" - + expr = Undefined(Node) type = Undefined('mypy.types.Type') - + def __init__(self, expr: Node, typ: 'mypy.types.Type') -> None: self.expr = expr self.type = typ - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_cast_expr(self) class SuperExpr(Node): """Expression super().name""" - + name = '' info = Undefined('TypeInfo') # Type that contains this super expression - + def __init__(self, name: str) -> None: self.name = name - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_super_expr(self) @@ -1088,85 +1093,85 @@ class FuncExpr(FuncItem): def name(self) -> str: return '' - + def expr(self) -> Node: """Return the expression (the body) of the lambda.""" ret = cast(ReturnStmt, self.body.body[0]) return ret.expr - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_func_expr(self) class ListExpr(Node): """List literal expression [...].""" - + items = Undefined(List[Node] ) - + def __init__(self, items: List[Node]) -> None: self.items = items if all(x.literal == LITERAL_YES for x in items): self.literal = LITERAL_YES self.literal_hash = ('List',) + tuple(x.literal_hash for x in items) - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_list_expr(self) class DictExpr(Node): """Dictionary literal expression {key: value, ...}.""" - + items = Undefined(List[Tuple[Node, Node]]) - + def __init__(self, items: List[Tuple[Node, Node]]) -> None: self.items = items if all(x[0].literal == LITERAL_YES and x[1].literal == LITERAL_YES for x in items): self.literal = LITERAL_YES self.literal_hash = ('Dict',) + tuple((x[0].literal_hash, x[1].literal_hash) for x in items) - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_dict_expr(self) class TupleExpr(Node): """Tuple literal expression (..., ...)""" - + items = Undefined(List[Node]) - + def __init__(self, items: List[Node]) -> None: self.items = items if all(x.literal == LITERAL_YES for x in items): self.literal = LITERAL_YES self.literal_hash = ('Tuple',) + tuple(x.literal_hash for x in items) - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_tuple_expr(self) class SetExpr(Node): """Set literal expression {value, ...}.""" - + items = Undefined(List[Node]) - + def __init__(self, items: List[Node]) -> None: self.items = items if all(x.literal == LITERAL_YES for x in items): self.literal = LITERAL_YES self.literal_hash = ('Set',) + tuple(x.literal_hash for x in items) - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_set_expr(self) class GeneratorExpr(Node): """Generator expression ... for ... in ... [ for ... in ... ] [ if ... ].""" - + left_expr = Undefined(Node) sequences_expr = Undefined(List[Node]) condition = Undefined(Node) # May be None indices = Undefined(List[List[NameExpr]]) types = Undefined(List[List['mypy.types.Type']]) - + def __init__(self, left_expr: Node, indices: List[List[NameExpr]], types: List[List['mypy.types.Type']], sequences: List[Node], condition: Node) -> None: @@ -1175,35 +1180,35 @@ def __init__(self, left_expr: Node, indices: List[List[NameExpr]], self.condition = condition self.indices = indices self.types = types - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_generator_expr(self) class ListComprehension(Node): """List comprehension (e.g. [x + 1 for x in a])""" - + generator = Undefined(GeneratorExpr) - + def __init__(self, generator: GeneratorExpr) -> None: self.generator = generator - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_list_comprehension(self) class ConditionalExpr(Node): """Conditional expression (e.g. x if y else z)""" - + cond = Undefined(Node) if_expr = Undefined(Node) else_expr = Undefined(Node) - + def __init__(self, cond: Node, if_expr: Node, else_expr: Node) -> None: self.cond = cond self.if_expr = if_expr self.else_expr = else_expr - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_conditional_expr(self) @@ -1216,7 +1221,7 @@ class UndefinedExpr(Node): x = Undefined(List[int]) """ - + def __init__(self, type: 'mypy.types.Type') -> None: self.type = type @@ -1226,14 +1231,14 @@ def accept(self, visitor: NodeVisitor[T]) -> T: class TypeApplication(Node): """Type application expr[type, ...]""" - + expr = Undefined(Node) types = Undefined(List['mypy.types.Type']) - + def __init__(self, expr: Node, types: List['mypy.types.Type']) -> None: self.expr = expr self.types = types - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_type_application(self) @@ -1258,7 +1263,7 @@ def name(self) -> str: def fullname(self) -> str: return self._fullname - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_type_var_expr(self) @@ -1293,12 +1298,12 @@ class CoerceExpr(Node): This is used only when compiling/transforming. These are inserted after type checking. """ - + expr = Undefined(Node) target_type = Undefined('mypy.types.Type') source_type = Undefined('mypy.types.Type') is_wrapper_class = False - + def __init__(self, expr: Node, target_type: 'mypy.types.Type', source_type: 'mypy.types.Type', is_wrapper_class: bool) -> None: @@ -1306,7 +1311,7 @@ def __init__(self, expr: Node, target_type: 'mypy.types.Type', self.target_type = target_type self.source_type = source_type self.is_wrapper_class = is_wrapper_class - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_coerce_expr(self) @@ -1314,12 +1319,12 @@ def accept(self, visitor: NodeVisitor[T]) -> T: class JavaCast(Node): # TODO obsolete; remove expr = Undefined(Node) - target = Undefined('mypy.types.Type') - + target = Undefined('mypy.types.Type') + def __init__(self, expr: Node, target: 'mypy.types.Type') -> None: self.expr = expr self.target = target - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_java_cast(self) @@ -1330,12 +1335,12 @@ class TypeExpr(Node): This is used only for runtime type checking. This node is always generated only after type checking. """ - - type = Undefined('mypy.types.Type') - + + type = Undefined('mypy.types.Type') + def __init__(self, typ: 'mypy.types.Type') -> None: self.type = typ - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_type_expr(self) @@ -1347,12 +1352,12 @@ class TempNode(Node): of the type checker implementation. It only represents an opaque node with some fixed type. """ - + type = Undefined('mypy.types.Type') - + def __init__(self, typ: 'mypy.types.Type') -> None: self.type = typ - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_temp_node(self) @@ -1363,7 +1368,7 @@ class TypeInfo(SymbolNode): The corresponding ClassDef instance represents the parse tree of the class. """ - + _fullname = None # type: str # Fully qualified name defn = Undefined(ClassDef) # Corresponding ClassDef # Method Resolution Order: the order of looking up attributes. The first @@ -1378,18 +1383,18 @@ class TypeInfo(SymbolNode): # Targets of disjointclass declarations present in this class only (for # generating error messages). disjointclass_decls = Undefined(List['TypeInfo']) - + # Information related to type annotations. - + # Generic type variable names type_vars = Undefined(List[str]) - + # Direct base classes. bases = Undefined(List['mypy.types.Instance']) # Duck type compatibility (ducktype decorator) ducktype = None # type: mypy.types.Type - + def __init__(self, names: 'SymbolTable', defn: ClassDef) -> None: """Initialize a TypeInfo.""" self.names = names @@ -1406,18 +1411,18 @@ def __init__(self, names: 'SymbolTable', defn: ClassDef) -> None: if defn.type_vars: for vd in defn.type_vars: self.type_vars.append(vd.name) - + def name(self) -> str: """Short name.""" return self.defn.name def fullname(self) -> str: return self._fullname - + def is_generic(self) -> bool: """Is the type generic (i.e. does it have type variables)?""" return self.type_vars is not None and len(self.type_vars) > 0 - + def get(self, name: str) -> 'SymbolTableNode': for cls in self.mro: n = cls.names.get(name) @@ -1434,23 +1439,23 @@ def __getitem__(self, name: str) -> 'SymbolTableNode': def __repr__(self) -> str: return '' % self.fullname() - - + + # IDEA: Refactor the has* methods to be more consistent and document # them. - + def has_readable_member(self, name: str) -> bool: return self.get(name) is not None - + def has_writable_member(self, name: str) -> bool: return self.has_var(name) - + def has_var(self, name: str) -> bool: return self.get_var(name) is not None - + def has_method(self, name: str) -> bool: return self.get_method(name) is not None - + def get_var(self, name: str) -> Var: for cls in self.mro: if name in cls.names: @@ -1460,15 +1465,15 @@ def get_var(self, name: str) -> Var: else: return None return None - + def get_var_or_getter(self, name: str) -> SymbolNode: # TODO getter return self.get_var(name) - + def get_var_or_setter(self, name: str) -> SymbolNode: # TODO setter return self.get_var(name) - + def get_method(self, name: str) -> FuncBase: for cls in self.mro: if name in cls.names: @@ -1485,7 +1490,7 @@ def calculate_mro(self) -> None: Raise MroError if cannot determine mro. """ self.mro = linearize_hierarchy(self) - + def has_base(self, fullname: str) -> bool: """Return True if type has a base type with the specified name. @@ -1495,7 +1500,7 @@ def has_base(self, fullname: str) -> bool: if cls.fullname() == fullname: return True return False - + def all_subtypes(self) -> 'Set[TypeInfo]': """Return TypeInfos of all subtypes, including this type, as a set.""" subtypes = set([self]) @@ -1503,18 +1508,18 @@ def all_subtypes(self) -> 'Set[TypeInfo]': for t in subt.all_subtypes(): subtypes.add(t) return subtypes - + def all_base_classes(self) -> 'List[TypeInfo]': """Return a list of base classes, including indirect bases.""" assert False - + def direct_base_classes(self) -> 'List[TypeInfo]': """Return a direct base classes. Omit base classes of other base classes. """ return [base.type for base in self.bases] - + def __str__(self) -> str: """Return a string representation of the type. @@ -1540,9 +1545,9 @@ class SymbolTableNode: tvar_id = 0 # Module id (e.g. "foo.bar") or None mod_id = '' - # If None, fall back to type of node + # If None, fall back to type of node type_override = Undefined('mypy.types.Type') - + def __init__(self, kind: int, node: SymbolNode, mod_id: str = None, typ: 'mypy.types.Type' = None, tvar_id: int = 0) -> None: self.kind = kind @@ -1571,7 +1576,7 @@ def type(self) -> 'mypy.types.Type': return (cast(Decorator, node)).var.type else: return None - + def __str__(self) -> str: s = '{}/{}'.format(node_kinds[self.kind], short_type(self.node)) if self.mod_id is not None: @@ -1603,7 +1608,7 @@ def __str__(self) -> str: def clean_up(s: str) -> str: # TODO remove return re.sub('.*::', '', s) - + def function_type(func: FuncBase) -> 'mypy.types.FunctionLike': if func.type: @@ -1612,7 +1617,7 @@ def function_type(func: FuncBase) -> 'mypy.types.FunctionLike': # Implicit type signature with dynamic types. # Overloaded functions always have a signature, so func must be an # ordinary function. - fdef = cast(FuncDef, func) + fdef = cast(FuncDef, func) name = func.name() if name: name = '"{}"'.format(name) diff --git a/mypy/parse.py b/mypy/parse.py index 7f222a270ca1..f591dfb1b736 100644 --- a/mypy/parse.py +++ b/mypy/parse.py @@ -23,7 +23,7 @@ TupleExpr, GeneratorExpr, ListComprehension, ListExpr, ConditionalExpr, DictExpr, SetExpr, NameExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr, FloatExpr, CallExpr, SuperExpr, MemberExpr, IndexExpr, SliceExpr, OpExpr, - UnaryExpr, FuncExpr, TypeApplication, PrintStmt, ImportBase + UnaryExpr, FuncExpr, TypeApplication, PrintStmt, ImportBase, YieldFromStmt ) from mypy import nodes from mypy import noderepr @@ -82,14 +82,14 @@ class Parser: ind = 0 errors = Undefined(Errors) raise_on_error = False - + # Are we currently parsing the body of a class definition? is_class_body = False # All import nodes encountered so far in this parse unit. imports = Undefined(List[ImportBase]) # Names imported from __future__. future_options = Undefined(List[str]) - + def __init__(self, fnam: str, errors: Errors, pyversion: int, custom_typing_module : str = None) -> None: self.raise_on_error = errors is None @@ -103,7 +103,7 @@ def __init__(self, fnam: str, errors: Errors, pyversion: int, self.errors.set_file(fnam) else: self.errors.set_file('') - + def parse(self, s: str) -> MypyFile: self.tok = lex.lex(s) self.ind = 0 @@ -113,7 +113,7 @@ def parse(self, s: str) -> MypyFile: if self.raise_on_error and self.errors.is_errors(): self.errors.raise_error() return file - + def parse_file(self) -> MypyFile: """Parse a mypy source file.""" is_bom = self.parse_bom() @@ -122,9 +122,9 @@ def parse_file(self) -> MypyFile: node = MypyFile(defs, self.imports, is_bom) self.set_repr(node, noderepr.MypyFileRepr(eof)) return node - + # Parse the initial part - + def parse_bom(self) -> bool: """Parse the optional byte order mark at the beginning of a file.""" if isinstance(self.current(), Bom): @@ -134,7 +134,7 @@ def parse_bom(self) -> bool: return True else: return False - + def parse_import(self) -> Import: import_tok = self.expect('import') ids = List[Tuple[str, str]]() @@ -164,7 +164,7 @@ def parse_import(self) -> Import: self.set_repr(node, noderepr.ImportRepr(import_tok, id_toks, as_names, commas, br)) return node - + def parse_import_from(self) -> Node: from_tok = self.expect('from') name, components = self.parse_qualified_name() @@ -212,7 +212,7 @@ def parse_import_from(self) -> Node: if name == '__future__': self.future_options.extend(target[0] for target in targets) return node - + def parse_import_name(self) -> Tuple[str, str, List[Token]]: tok = self.expect_type(Name) name = tok.string @@ -224,7 +224,7 @@ def parse_import_name(self) -> Tuple[str, str, List[Token]]: return name, as_name.string, tokens else: return name, name, tokens - + def parse_qualified_name(self) -> Tuple[str, List[Token]]: """Parse a name with an optional module qualifier. @@ -241,9 +241,9 @@ def parse_qualified_name(self) -> Tuple[str, List[Token]]: n += '.' + tok.string components.append(tok) return n, components - + # Parsing global definitions - + def parse_defs(self) -> List[Node]: defs = List[Node]() while not self.eof(): @@ -255,24 +255,24 @@ def parse_defs(self) -> List[Node]: except ParseError: pass return defs - + def parse_class_def(self) -> ClassDef: old_is_class_body = self.is_class_body self.is_class_body = True - + type_tok = self.expect('class') lparen = none rparen = none metaclass = None # type: str - + try: commas, base_types = List[Token](), List[Type]() try: name_tok = self.expect_type(Name) name = name_tok.string - + self.errors.push_type(name) - + if self.current_str() == '(': lparen = self.skip() while True: @@ -286,9 +286,9 @@ def parse_class_def(self) -> ClassDef: rparen = self.expect(')') except ParseError: pass - + defs, _ = self.parse_block() - + node = ClassDef(name, defs, None, base_types, metaclass=metaclass) self.set_repr(node, noderepr.TypeDefRepr(type_tok, name_tok, lparen, commas, rparen)) @@ -296,7 +296,7 @@ def parse_class_def(self) -> ClassDef: finally: self.errors.pop_type() self.is_class_body = old_is_class_body - + def parse_super_type(self) -> Type: if (isinstance(self.current(), Name) and self.current_str() != 'void'): return self.parse_type() @@ -307,7 +307,7 @@ def parse_metaclass(self) -> str: self.expect('metaclass') self.expect('=') return self.parse_qualified_name()[0] - + def parse_decorated_function_or_class(self) -> Node: ats = List[Token]() brs = List[Token]() @@ -330,7 +330,7 @@ def parse_decorated_function_or_class(self) -> Node: cls = self.parse_class_def() cls.decorators = decorators return cls - + def parse_function(self) -> FuncDef: def_tok = self.expect('def') is_method = self.is_class_body @@ -338,7 +338,7 @@ def parse_function(self) -> FuncDef: try: (name, args, init, kinds, typ, is_error, toks) = self.parse_function_header() - + body, comment_type = self.parse_block(allow_type=True) if comment_type: # The function has a # type: ... signature. @@ -364,12 +364,12 @@ def parse_function(self) -> FuncDef: [arg.name() for arg in args], sig.ret_type, False) - + # If there was a serious error, we really cannot build a parse tree # node. if is_error: return None - + node = FuncDef(name, args, kinds, init, body, typ) name_tok, arg_reprs = toks node.set_line(name_tok) @@ -396,7 +396,7 @@ def check_argument_kinds(self, funckinds: List[int], sigkinds: List[int], self.fail( "Inconsistent use of '{}' in function " "signature".format(token), line) - + def parse_function_header(self) -> Tuple[str, List[Var], List[Node], List[int], Type, bool, Tuple[Token, Any]]: @@ -410,15 +410,15 @@ def parse_function_header(self) -> Tuple[str, List[Var], List[Node], signature (annotation) error flag (True if error) (name token, representation of arguments) - """ + """ name_tok = none - + try: name_tok = self.expect_type(Name) name = name_tok.string - + self.errors.push_function(name) - + (args, init, kinds, typ, arg_repr) = self.parse_args() except ParseError: if not isinstance(self.current(), Break): @@ -427,20 +427,20 @@ def parse_function_header(self) -> Tuple[str, List[Var], List[Node], if isinstance(self.tok[self.ind - 1], Colon): self.ind -= 1 return (name, [], [], [], None, True, (name_tok, None)) - + return (name, args, init, kinds, typ, False, (name_tok, arg_repr)) - + def parse_args(self) -> Tuple[List[Var], List[Node], List[int], Type, noderepr.FuncArgsRepr]: """Parse a function signature (...) [-> t].""" lparen = self.expect('(') - + # Parse the argument list (everything within '(' and ')'). (args, init, kinds, has_inits, arg_names, commas, asterisk, assigns, arg_types) = self.parse_arg_list() - + rparen = self.expect(')') if self.current_str() == '-': @@ -451,20 +451,20 @@ def parse_args(self) -> Tuple[List[Var], List[Node], List[int], Type, ret_type = None self.verify_argument_kinds(kinds, lparen.line) - + names = [] # type: List[str] for arg in args: names.append(arg.name()) - + annotation = self.build_func_annotation( ret_type, arg_types, kinds, names, lparen.line) - + return (args, init, kinds, annotation, noderepr.FuncArgsRepr(lparen, rparen, arg_names, commas, assigns, asterisk)) - + def build_func_annotation(self, ret_type: Type, arg_types: List[Type], - kinds: List[int], names: List[str], + kinds: List[int], names: List[str], line: int, is_default_ret: bool = False) -> Type: # Are there any type annotations? if ((ret_type and not is_default_ret) @@ -474,7 +474,7 @@ def build_func_annotation(self, ret_type: Type, arg_types: List[Type], ret_type, line) else: return None - + def parse_arg_list( self, allow_signature: bool = True) -> Tuple[List[Var], List[Node], List[int], bool, @@ -494,16 +494,16 @@ def parse_arg_list( names = [] # type: List[str] init = [] # type: List[Node] has_inits = False - arg_types = [] # type: List[Type] - + arg_types = [] # type: List[Type] + arg_names = [] # type: List[Token] commas = [] # type: List[Token] asterisk = [] # type: List[Token] assigns = [] # type: List[Token] - + require_named = False bare_asterisk_before = -1 - + if self.current_str() != ')' and self.current_str() != ':': while self.current_str() != ')': if self.current_str() == '*' and self.peek().string == ',': @@ -539,7 +539,7 @@ def parse_arg_list( arg_names.append(name) args.append(Var(name.string)) arg_types.append(self.parse_arg_type(allow_signature)) - + if self.current_str() == '=': assigns.append(self.expect('=')) init.append(self.parse_expression(precedence[','])) @@ -556,11 +556,11 @@ def parse_arg_list( kinds.append(nodes.ARG_NAMED) else: kinds.append(nodes.ARG_POS) - + if self.current().string != ',': break commas.append(self.expect(',')) - + return (args, init, kinds, has_inits, arg_names, commas, asterisk, assigns, arg_types) @@ -583,7 +583,7 @@ def verify_argument_kinds(self, kinds: List[int], line: int) -> None: elif kind == nodes.ARG_STAR2 and i != len(kinds) - 1: self.fail('Invalid argument list', line) found.add(kind) - + def construct_function_type(self, arg_types: List[Type], kinds: List[int], names: List[str], ret_type: Type, line: int) -> Callable: @@ -596,9 +596,9 @@ def construct_function_type(self, arg_types: List[Type], kinds: List[int], ret_type = AnyType() return Callable(arg_types, kinds, names, ret_type, False, None, None, [], line, None) - + # Parsing statements - + def parse_block(self, allow_type: bool = False) -> Tuple[Block, Type]: colon = self.expect(':') if not isinstance(self.current(), Break): @@ -624,7 +624,7 @@ def parse_block(self, allow_type: bool = False) -> Tuple[Block, Type]: dedent = none if isinstance(self.current(), Dedent): dedent = self.skip() - node = Block(stmt).set_line(colon) + node = Block(stmt).set_line(colon) self.set_repr(node, noderepr.BlockRepr(colon, br, indent, dedent)) return cast(Block, node), type @@ -641,7 +641,7 @@ def try_combine_overloads(self, s: Node, stmt: List[Node]) -> bool: (cast(OverloadedFuncDef, stmt[-1])).items.append(fdef) return True return False - + def parse_statement(self) -> Node: stmt = Undefined # type: Node t = self.current() @@ -684,7 +684,7 @@ def parse_statement(self) -> Node: stmt = self.parse_with_stmt() elif ts == '@': stmt = self.parse_decorated_function_or_class() - elif ts == 'print' and (self.pyversion == 2 and + elif ts == 'print' and (self.pyversion == 2 and 'print_function' not in self.future_options): stmt = self.parse_print_stmt() else: @@ -692,7 +692,7 @@ def parse_statement(self) -> Node: if stmt is not None: stmt.set_line(t) return stmt - + def parse_expression_or_assignment(self) -> Node: e = self.parse_expression() if self.current_str() == '=': @@ -713,7 +713,7 @@ def parse_expression_or_assignment(self) -> Node: expr = ExpressionStmt(e) self.set_repr(expr, noderepr.ExpressionStmtRepr(br)) return expr - + def parse_assignment(self, lv: Any) -> Node: """Parse an assignment statement. @@ -722,7 +722,7 @@ def parse_assignment(self, lv: Any) -> Node: """ assigns = [self.expect('=')] lvalues = [lv] - + e = self.parse_expression() while self.current_str() == '=': lvalues.append(e) @@ -734,7 +734,7 @@ def parse_assignment(self, lv: Any) -> Node: assignment = AssignmentStmt(lvalues, e, type) self.set_repr(assignment, noderepr.AssignmentStmtRepr(assigns, br)) return assignment - + def parse_return_stmt(self) -> ReturnStmt: return_tok = self.expect('return') expr = None # type: Node @@ -744,7 +744,7 @@ def parse_return_stmt(self) -> ReturnStmt: node = ReturnStmt(expr) self.set_repr(node, noderepr.SimpleStmtRepr(return_tok, br)) return node - + def parse_raise_stmt(self) -> RaiseStmt: raise_tok = self.expect('raise') expr = None # type: Node @@ -759,7 +759,7 @@ def parse_raise_stmt(self) -> RaiseStmt: node = RaiseStmt(expr, from_expr) self.set_repr(node, noderepr.RaiseStmtRepr(raise_tok, from_tok, br)) return node - + def parse_assert_stmt(self) -> AssertStmt: assert_tok = self.expect('assert') expr = self.parse_expression() @@ -767,17 +767,29 @@ def parse_assert_stmt(self) -> AssertStmt: node = AssertStmt(expr) self.set_repr(node, noderepr.SimpleStmtRepr(assert_tok, br)) return node - + def parse_yield_stmt(self) -> YieldStmt: yield_tok = self.expect('yield') expr = None # type: Node + node = YieldStmt(expr) if not isinstance(self.current(), Break): - expr = self.parse_expression() + if isinstance(self.current(), Keyword): + from_tok = self.expect("from") + expr = self.parse_expression() # Here comes when yield from is not assigned + node = YieldFromStmt(expr) + else: + expr = self.parse_expression() + node = YieldStmt(expr) br = self.expect_break() - node = YieldStmt(expr) self.set_repr(node, noderepr.SimpleStmtRepr(yield_tok, br)) return node - + + def parse_yield_from_expr(self) -> CallExpr: + y_tok = self.expect("yield") + f_tok = self.expect("from") + tok = self.parse_expression() # Here comes when yield from is assigned to a variable + return tok + def parse_del_stmt(self) -> DelStmt: del_tok = self.expect('del') expr = self.parse_expression() @@ -785,28 +797,28 @@ def parse_del_stmt(self) -> DelStmt: node = DelStmt(expr) self.set_repr(node, noderepr.SimpleStmtRepr(del_tok, br)) return node - + def parse_break_stmt(self) -> BreakStmt: break_tok = self.expect('break') br = self.expect_break() node = BreakStmt() self.set_repr(node, noderepr.SimpleStmtRepr(break_tok, br)) return node - + def parse_continue_stmt(self) -> ContinueStmt: continue_tok = self.expect('continue') br = self.expect_break() node = ContinueStmt() self.set_repr(node, noderepr.SimpleStmtRepr(continue_tok, br)) return node - + def parse_pass_stmt(self) -> PassStmt: pass_tok = self.expect('pass') br = self.expect_break() node = PassStmt() self.set_repr(node, noderepr.SimpleStmtRepr(pass_tok, br)) return node - + def parse_global_decl(self) -> GlobalDecl: global_tok = self.expect('global') names = List[str]() @@ -824,7 +836,7 @@ def parse_global_decl(self) -> GlobalDecl: self.set_repr(node, noderepr.GlobalDeclRepr(global_tok, name_toks, commas, br)) return node - + def parse_while_stmt(self) -> WhileStmt: is_error = False while_tok = self.expect('while') @@ -845,38 +857,38 @@ def parse_while_stmt(self) -> WhileStmt: return node else: return None - + def parse_for_stmt(self) -> ForStmt: for_tok = self.expect('for') index, types, commas = self.parse_for_index_variables() in_tok = self.expect('in') expr = self.parse_expression() - + body, _ = self.parse_block() - + if self.current_str() == 'else': else_tok = self.expect('else') else_body, _ = self.parse_block() else: else_body = None else_tok = none - + node = ForStmt(index, expr, body, else_body, types) self.set_repr(node, noderepr.ForStmtRepr(for_tok, commas, in_tok, else_tok)) return node - + def parse_for_index_variables(self) -> Tuple[List[NameExpr], List[Type], List[Token]]: # Parse index variables of a 'for' statement. index = List[NameExpr]() types = List[Type]() commas = List[Token]() - + is_paren = self.current_str() == '(' if is_paren: self.skip() - + while True: v = self.parse_name_expr() index.append(v) @@ -885,24 +897,24 @@ def parse_for_index_variables(self) -> Tuple[List[NameExpr], List[Type], commas.append(none) break commas.append(self.skip()) - + if is_paren: self.expect(')') - + return index, types, commas - + def parse_if_stmt(self) -> IfStmt: is_error = False - + if_tok = self.expect('if') expr = List[Node]() try: expr.append(self.parse_expression()) except ParseError: is_error = True - + body = [self.parse_block()[0]] - + elif_toks = List[Token]() while self.current_str() == 'elif': elif_toks.append(self.expect('elif')) @@ -911,14 +923,14 @@ def parse_if_stmt(self) -> IfStmt: except ParseError: is_error = True body.append(self.parse_block()[0]) - + if self.current_str() == 'else': else_tok = self.expect('else') else_body, _ = self.parse_block() else: else_tok = none else_body = None - + if not is_error: node = IfStmt(expr, body, else_body) self.set_repr(node, noderepr.IfStmtRepr(if_tok, elif_toks, @@ -926,7 +938,7 @@ def parse_if_stmt(self) -> IfStmt: return node else: return None - + def parse_try_stmt(self) -> Node: try_tok = self.expect('try') body, _ = self.parse_block() @@ -979,7 +991,7 @@ def parse_try_stmt(self) -> Node: return node else: return None - + def parse_with_stmt(self) -> WithStmt: with_tok = self.expect('with') as_toks = List[Token]() @@ -1017,14 +1029,14 @@ def parse_print_stmt(self) -> PrintStmt: break self.expect_break() return PrintStmt(args, newline=not comma) - + # Parsing expressions - + def parse_expression(self, prec: int = 0) -> Node: """Parse a subexpression within a specific precedence context.""" expr = Undefined # type: Node t = self.current() # Remember token for setting the line number. - + # Parse a "value" expression or unary operator expression and store # that in expr. s = self.current_str() @@ -1054,16 +1066,18 @@ def parse_expression(self, prec: int = 0) -> Node: expr = self.parse_unicode_literal() elif isinstance(self.current(), FloatLit): expr = self.parse_float_expr() + elif isinstance(t, Keyword) and s == "yield": #maybe check that next is from + expr = self.parse_yield_from_expr() # The expression yield from to assign else: # Invalid expression. self.parse_error() - + # Set the line of the expression node, if not specified. This # simplifies recording the line number as not every node type needs to # deal with it separately. if expr.line < 0: expr.set_line(t) - + # Parse operations that require a left argument (stored in expr). while True: t = self.current() @@ -1092,7 +1106,7 @@ def parse_expression(self, prec: int = 0) -> Node: # comprehension if needed elsewhere. expr = self.parse_generator_expr(expr) else: - break + break elif s == 'if': # Conditional expression. if precedence[''] > prec: @@ -1118,15 +1132,15 @@ def parse_expression(self, prec: int = 0) -> Node: # Not an operation that accepts a left argument; let the # caller handle the rest. break - + # Set the line of the expression node, if not specified. This # simplifies recording the line number as not every node type # needs to deal with it separately. if expr.line < 0: expr.set_line(t) - + return expr - + def parse_parentheses(self) -> Node: lparen = self.skip() if self.current_str() == ')': @@ -1139,13 +1153,13 @@ def parse_parentheses(self) -> Node: expr = ParenExpr(expr) self.set_repr(expr, noderepr.ParenExprRepr(lparen, rparen)) return expr - + def parse_empty_tuple_expr(self, lparen: Any) -> TupleExpr: rparen = self.expect(')') node = TupleExpr([]) self.set_repr(node, noderepr.TupleExprRepr(lparen, [], rparen)) return node - + def parse_list_expr(self) -> Node: """Parse list literal or list comprehension.""" items = List[Node]() @@ -1157,7 +1171,7 @@ def parse_list_expr(self) -> Node: break commas.append(self.expect(',')) if self.current_str() == 'for' and len(items) == 1: - items[0] = self.parse_generator_expr(items[0]) + items[0] = self.parse_generator_expr(items[0]) rbracket = self.expect(']') if len(items) == 1 and isinstance(items[0], GeneratorExpr): list_comp = ListComprehension(cast(GeneratorExpr, items[0])) @@ -1169,7 +1183,7 @@ def parse_list_expr(self) -> Node: self.set_repr(expr, noderepr.ListSetExprRepr(lbracket, commas, rbracket, none, none)) return expr - + def parse_generator_expr(self, left_expr: Node) -> GeneratorExpr: indices = List[List[NameExpr]]() sequences = List[Node]() @@ -1193,7 +1207,7 @@ def parse_generator_expr(self, left_expr: Node) -> GeneratorExpr: self.set_repr(gen, noderepr.GeneratorExprRepr(for_tok, commas, in_tok, if_tok)) return gen - + def parse_expression_list(self) -> Node: prec = precedence[''] expr = self.parse_expression(prec) @@ -1202,14 +1216,14 @@ def parse_expression_list(self) -> Node: else: t = self.current() return self.parse_tuple_expr(expr, prec).set_line(t) - + def parse_conditional_expr(self, left_expr: Node) -> ConditionalExpr: self.expect('if') cond = self.parse_expression(precedence['']) self.expect('else') else_expr = self.parse_expression(precedence['']) return ConditionalExpr(cond, left_expr, else_expr) - + def parse_dict_or_set_expr(self) -> Node: items = List[Tuple[Node, Node]]() lbrace = self.expect('{') @@ -1232,7 +1246,7 @@ def parse_dict_or_set_expr(self) -> Node: self.set_repr(node, noderepr.DictExprRepr(lbrace, colons, commas, rbrace, none, none, none)) return node - + def parse_set_expr(self, first: Node, lbrace: Token) -> SetExpr: items = [first] commas = List[Token]() @@ -1246,7 +1260,7 @@ def parse_set_expr(self, first: Node, lbrace: Token) -> SetExpr: self.set_repr(expr, noderepr.ListSetExprRepr(lbrace, commas, rbrace, none, none)) return expr - + def parse_tuple_expr(self, expr: Node, prec: int = precedence[',']) -> TupleExpr: items = [expr] @@ -1261,14 +1275,14 @@ def parse_tuple_expr(self, expr: Node, node = TupleExpr(items) self.set_repr(node, noderepr.TupleExprRepr(none, commas, none)) return node - + def parse_name_expr(self) -> NameExpr: tok = self.expect_type(Name) node = NameExpr(tok.string) node.set_line(tok) self.set_repr(node, noderepr.NameExprRepr(tok)) return node - + def parse_int_expr(self) -> IntExpr: tok = self.expect_type(IntLit) s = tok.string @@ -1282,7 +1296,7 @@ def parse_int_expr(self) -> IntExpr: node = IntExpr(v) self.set_repr(node, noderepr.IntExprRepr(tok)) return node - + def parse_str_expr(self) -> Node: # XXX \uxxxx literals tok = [self.expect_type(StrLit)] @@ -1298,7 +1312,7 @@ def parse_str_expr(self) -> Node: node = StrExpr(value) self.set_repr(node, noderepr.StrExprRepr(tok)) return node - + def parse_bytes_literal(self) -> Node: # XXX \uxxxx literals tok = [self.expect_type(BytesLit)] @@ -1313,7 +1327,7 @@ def parse_bytes_literal(self) -> Node: node = StrExpr(value) self.set_repr(node, noderepr.StrExprRepr(tok)) return node - + def parse_unicode_literal(self) -> Node: # XXX \uxxxx literals tok = [self.expect_type(UnicodeLit)] @@ -1329,13 +1343,13 @@ def parse_unicode_literal(self) -> Node: node = UnicodeExpr(value) self.set_repr(node, noderepr.StrExprRepr(tok)) return node - + def parse_float_expr(self) -> FloatExpr: tok = self.expect_type(FloatLit) node = FloatExpr(float(tok.string)) self.set_repr(node, noderepr.FloatExprRepr(tok)) return node - + def parse_call_expr(self, callee: Any) -> CallExpr: lparen = self.expect('(') (args, kinds, names, @@ -1345,7 +1359,7 @@ def parse_call_expr(self, callee: Any) -> CallExpr: self.set_repr(node, noderepr.CallExprRepr(lparen, commas, star, star2, assigns, rparen)) return node - + def parse_arg_expr(self) -> Tuple[List[Node], List[int], List[str], List[Token], Token, Token, List[List[Token]]]: @@ -1359,7 +1373,7 @@ def parse_arg_expr(self) -> Tuple[List[Node], List[int], List[str], * token (if any) ** token (if any) (assignment, name) tokens - """ + """ args = [] # type: List[Node] kinds = [] # type: List[int] names = [] # type: List[str] @@ -1403,7 +1417,7 @@ def parse_arg_expr(self) -> Tuple[List[Node], List[int], List[str], break commas.append(self.expect(',')) return args, kinds, names, commas, star, star2, keywords - + def parse_member_expr(self, expr: Any) -> Node: dot = self.expect('.') name = self.expect_type(Name) @@ -1420,7 +1434,7 @@ def parse_member_expr(self, expr: Any) -> Node: node = MemberExpr(expr, name.string) self.set_repr(node, noderepr.MemberExprRepr(dot, name)) return node - + def parse_index_expr(self, base: Any) -> IndexExpr: lbracket = self.expect('[') if self.current_str() != ':': @@ -1446,7 +1460,7 @@ def parse_index_expr(self, base: Any) -> IndexExpr: node = IndexExpr(base, index) self.set_repr(node, noderepr.IndexExprRepr(lbracket, rbracket)) return node - + def parse_bin_op_expr(self, left: Node, prec: int) -> OpExpr: op = self.expect_type(Op) op2 = none @@ -1467,7 +1481,7 @@ def parse_bin_op_expr(self, left: Node, prec: int) -> OpExpr: node = OpExpr(op_str, left, right) self.set_repr(node, noderepr.OpExprRepr(op, op2)) return node - + def parse_unary_expr(self) -> UnaryExpr: op_tok = self.skip() op = op_tok.string @@ -1479,11 +1493,11 @@ def parse_unary_expr(self) -> UnaryExpr: node = UnaryExpr(op, expr) self.set_repr(node, noderepr.UnaryExprRepr(op_tok)) return node - + def parse_lambda_expr(self) -> FuncExpr: is_error = False lambda_tok = self.expect('lambda') - + (args, init, kinds, has_inits, arg_names, commas, asterisk, assigns, arg_types) = self.parse_arg_list(allow_signature=False) @@ -1498,14 +1512,14 @@ def parse_lambda_expr(self) -> FuncExpr: ret_type = UnboundType('__builtins__.object') typ = self.build_func_annotation(ret_type, arg_types, kinds, names, lambda_tok.line, is_default_ret=True) - + colon = self.expect(':') - + expr = self.parse_expression(precedence[',']) - + body = Block([ReturnStmt(expr).set_line(lambda_tok)]) body.set_line(colon) - + node = FuncExpr(args, kinds, init, body, typ) self.set_repr(node, noderepr.FuncExprRepr( @@ -1513,59 +1527,59 @@ def parse_lambda_expr(self) -> FuncExpr: noderepr.FuncArgsRepr(none, none, arg_names, commas, assigns, asterisk))) return node - + # Helper methods - + def skip(self) -> Token: self.ind += 1 return self.tok[self.ind - 1] - + def expect(self, string: str) -> Token: if self.current_str() == string: self.ind += 1 return self.tok[self.ind - 1] else: self.parse_error() - + def expect_indent(self) -> Token: if isinstance(self.current(), Indent): return self.expect_type(Indent) else: self.fail('Expected an indented block', self.current().line) return none - + def fail(self, msg: str, line: int) -> None: self.errors.report(line, msg) - + def expect_type(self, typ: type) -> Token: if isinstance(self.current(), typ): self.ind += 1 return self.tok[self.ind - 1] else: self.parse_error() - + def expect_colon_and_break(self) -> Tuple[Token, Token]: return self.expect_type(Colon), self.expect_type(Break) - + def expect_break(self) -> Token: return self.expect_type(Break) - + def expect_end(self) -> Tuple[Token, Token]: return self.expect('end'), self.expect_type(Break) - + def current(self) -> Token: return self.tok[self.ind] - + def current_str(self) -> str: return self.current().string - + def peek(self) -> Token: return self.tok[self.ind + 1] - + def parse_error(self) -> None: self.parse_error_at(self.current()) raise ParseError() - + def parse_error_at(self, tok: Token, skip: bool = True) -> None: msg = '' if isinstance(tok, LexError): @@ -1575,12 +1589,12 @@ def parse_error_at(self, tok: Token, skip: bool = True) -> None: msg = 'Inconsistent indentation' else: msg = 'Parse error before {}'.format(token_repr(tok)) - + self.errors.report(tok.line, msg) - + if skip: self.skip_until_next_line() - + def skip_until_break(self) -> None: n = 0 while (not isinstance(self.current(), Break) @@ -1589,20 +1603,20 @@ def skip_until_break(self) -> None: n += 1 if isinstance(self.tok[self.ind - 1], Colon) and n > 1: self.ind -= 1 - + def skip_until_next_line(self) -> None: self.skip_until_break() if isinstance(self.current(), Break): self.skip() - + def eol(self) -> bool: return isinstance(self.current(), Break) or self.eof() - + def eof(self) -> bool: return isinstance(self.current(), Eof) - + # Type annotation related functionality - + def parse_type(self) -> Type: line = self.current().line try: @@ -1641,16 +1655,16 @@ def parse_type_comment(self, token: Token, signature: bool) -> Type: return None return type else: - return None - + return None + # Representation management - + def set_repr(self, node: Node, repr: Any) -> None: node.repr = repr - + def repr(self, node: Node) -> Any: return node.repr - + def paren_repr(self, e: Node) -> Tuple[List[Token], List[Token]]: """If e is a ParenExpr, return an array of left-paren tokens (more that one if nested parens) and an array of corresponding diff --git a/mypy/traverser.py b/mypy/traverser.py index b27125729181..6d9cde91be57 100644 --- a/mypy/traverser.py +++ b/mypy/traverser.py @@ -10,7 +10,7 @@ TryStmt, WithStmt, ParenExpr, MemberExpr, OpExpr, SliceExpr, CastExpr, UnaryExpr, ListExpr, TupleExpr, DictExpr, SetExpr, IndexExpr, GeneratorExpr, ListComprehension, ConditionalExpr, TypeApplication, - FuncExpr, OverloadedFuncDef + FuncExpr, OverloadedFuncDef, YieldFromStmt ) @@ -27,7 +27,7 @@ class TraverserVisitor(NodeVisitor[T], Generic[T]): """ # Visit methods - + def visit_mypy_file(self, o: MypyFile) -> T: for d in o.defs: d.accept(self) @@ -35,7 +35,7 @@ def visit_mypy_file(self, o: MypyFile) -> T: def visit_block(self, block: Block) -> T: for s in block.body: s.accept(self) - + def visit_func(self, o: FuncItem) -> T: for i in o.init: if i is not None: @@ -43,47 +43,47 @@ def visit_func(self, o: FuncItem) -> T: for v in o.args: self.visit_var(v) o.body.accept(self) - + def visit_func_def(self, o: FuncDef) -> T: self.visit_func(o) def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> T: for item in o.items: item.accept(self) - + def visit_class_def(self, o: ClassDef) -> T: o.defs.accept(self) - + def visit_decorator(self, o: Decorator) -> T: o.func.accept(self) o.var.accept(self) for decorator in o.decorators: decorator.accept(self) - + def visit_var_def(self, o: VarDef) -> T: if o.init is not None: o.init.accept(self) for v in o.items: self.visit_var(v) - + def visit_expression_stmt(self, o: ExpressionStmt) -> T: o.expr.accept(self) - + def visit_assignment_stmt(self, o: AssignmentStmt) -> T: o.rvalue.accept(self) for l in o.lvalues: l.accept(self) - + def visit_operator_assignment_stmt(self, o: OperatorAssignmentStmt) -> T: o.rvalue.accept(self) o.lvalue.accept(self) - + def visit_while_stmt(self, o: WhileStmt) -> T: o.expr.accept(self) o.body.accept(self) if o.else_body: o.else_body.accept(self) - + def visit_for_stmt(self, o: ForStmt) -> T: for ind in o.index: ind.accept(self) @@ -91,23 +91,27 @@ def visit_for_stmt(self, o: ForStmt) -> T: o.body.accept(self) if o.else_body: o.else_body.accept(self) - + def visit_return_stmt(self, o: ReturnStmt) -> T: if o.expr is not None: o.expr.accept(self) - + def visit_assert_stmt(self, o: AssertStmt) -> T: if o.expr is not None: o.expr.accept(self) - + def visit_yield_stmt(self, o: YieldStmt) -> T: if o.expr is not None: o.expr.accept(self) - + + def visit_yield_from_stmt(self, o: YieldFromStmt) -> T: + if o.expr is not None: + o.expr.accept(self) + def visit_del_stmt(self, o: DelStmt) -> T: if o.expr is not None: o.expr.accept(self) - + def visit_if_stmt(self, o: IfStmt) -> T: for e in o.expr: e.accept(self) @@ -115,13 +119,13 @@ def visit_if_stmt(self, o: IfStmt) -> T: b.accept(self) if o.else_body: o.else_body.accept(self) - + def visit_raise_stmt(self, o: RaiseStmt) -> T: if o.expr is not None: o.expr.accept(self) if o.from_expr is not None: o.from_expr.accept(self) - + def visit_try_stmt(self, o: TryStmt) -> T: o.body.accept(self) for i in range(len(o.types)): @@ -132,31 +136,31 @@ def visit_try_stmt(self, o: TryStmt) -> T: o.else_body.accept(self) if o.finally_body is not None: o.finally_body.accept(self) - + def visit_with_stmt(self, o: WithStmt) -> T: for i in range(len(o.expr)): o.expr[i].accept(self) if o.name[i] is not None: o.name[i].accept(self) o.body.accept(self) - + def visit_paren_expr(self, o: ParenExpr) -> T: o.expr.accept(self) - + def visit_member_expr(self, o: MemberExpr) -> T: o.expr.accept(self) - + def visit_call_expr(self, o: CallExpr) -> T: for a in o.args: a.accept(self) o.callee.accept(self) if o.analyzed: o.analyzed.accept(self) - + def visit_op_expr(self, o: OpExpr) -> T: o.left.accept(self) o.right.accept(self) - + def visit_slice_expr(self, o: SliceExpr) -> T: if o.begin_index is not None: o.begin_index.accept(self) @@ -164,36 +168,36 @@ def visit_slice_expr(self, o: SliceExpr) -> T: o.end_index.accept(self) if o.stride is not None: o.stride.accept(self) - + def visit_cast_expr(self, o: CastExpr) -> T: o.expr.accept(self) - + def visit_unary_expr(self, o: UnaryExpr) -> T: o.expr.accept(self) - + def visit_list_expr(self, o: ListExpr) -> T: for item in o.items: item.accept(self) - + def visit_tuple_expr(self, o: TupleExpr) -> T: for item in o.items: item.accept(self) - + def visit_dict_expr(self, o: DictExpr) -> T: for k, v in o.items: k.accept(self) v.accept(self) - + def visit_set_expr(self, o: SetExpr) -> T: for item in o.items: item.accept(self) - + def visit_index_expr(self, o: IndexExpr) -> T: o.base.accept(self) o.index.accept(self) if o.analyzed: o.analyzed.accept(self) - + def visit_generator_expr(self, o: GeneratorExpr) -> T: for index, sequence in zip(o.indices, o.sequences): sequence.accept(self) @@ -202,17 +206,17 @@ def visit_generator_expr(self, o: GeneratorExpr) -> T: o.left_expr.accept(self) if o.condition is not None: o.condition.accept(self) - + def visit_list_comprehension(self, o: ListComprehension) -> T: o.generator.accept(self) - + def visit_conditional_expr(self, o: ConditionalExpr) -> T: o.cond.accept(self) o.if_expr.accept(self) o.else_expr.accept(self) - + def visit_type_application(self, o: TypeApplication) -> T: o.expr.accept(self) - + def visit_func_expr(self, o: FuncExpr) -> T: self.visit_func(o) diff --git a/mypy/visitor.py b/mypy/visitor.py index 051141a60524..50bc20cc6ff3 100644 --- a/mypy/visitor.py +++ b/mypy/visitor.py @@ -17,21 +17,21 @@ class NodeVisitor(Generic[T]): TODO make the default return value explicit """ - + # Module structure - + def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T: pass - + def visit_import(self, o: 'mypy.nodes.Import') -> T: pass def visit_import_from(self, o: 'mypy.nodes.ImportFrom') -> T: pass def visit_import_all(self, o: 'mypy.nodes.ImportAll') -> T: pass - + # Definitions - + def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T: pass def visit_overloaded_func_def(self, @@ -45,15 +45,15 @@ def visit_global_decl(self, o: 'mypy.nodes.GlobalDecl') -> T: pass def visit_decorator(self, o: 'mypy.nodes.Decorator') -> T: pass - + def visit_var(self, o: 'mypy.nodes.Var') -> T: pass - + # Statements - + def visit_block(self, o: 'mypy.nodes.Block') -> T: pass - + def visit_expression_stmt(self, o: 'mypy.nodes.ExpressionStmt') -> T: pass def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> T: @@ -71,6 +71,8 @@ def visit_assert_stmt(self, o: 'mypy.nodes.AssertStmt') -> T: pass def visit_yield_stmt(self, o: 'mypy.nodes.YieldStmt') -> T: pass + def visit_yield_from_stmt(self, o: 'mypy.nodes.YieldFromStmt') -> T: + pass def visit_del_stmt(self, o: 'mypy.nodes.DelStmt') -> T: pass def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> T: @@ -89,9 +91,9 @@ def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T: pass def visit_print_stmt(self, o: 'mypy.nodes.PrintStmt') -> T: pass - + # Expressions - + def visit_int_expr(self, o: 'mypy.nodes.IntExpr') -> T: pass def visit_str_expr(self, o: 'mypy.nodes.StrExpr') -> T: @@ -148,13 +150,13 @@ def visit_ducktype_expr(self, o: 'mypy.nodes.DucktypeExpr') -> T: pass def visit_disjointclass_expr(self, o: 'mypy.nodes.DisjointclassExpr') -> T: pass - + def visit_coerce_expr(self, o: 'mypy.nodes.CoerceExpr') -> T: pass def visit_type_expr(self, o: 'mypy.nodes.TypeExpr') -> T: pass def visit_java_cast(self, o: 'mypy.nodes.JavaCast') -> T: pass - + def visit_temp_node(self, o: 'mypy.nodes.TempNode') -> T: pass From 581491e9749af3d526bd532bb59b89ad6d692fae Mon Sep 17 00:00:00 2001 From: Rock Neurotiko Date: Mon, 28 Jul 2014 04:23:18 +0200 Subject: [PATCH 02/12] First steps to the yield from expr --- mypy/nodes.py | 7 +- mypy/output.py | 185 +++++++++++++------------- mypy/semanal.py | 217 ++++++++++++++++--------------- mypy/strconv.py | 145 +++++++++++---------- mypy/test/data/parse-errors.test | 13 +- mypy/test/data/parse.test | 56 ++++++-- mypy/treetransform.py | 133 ++++++++++--------- 7 files changed, 406 insertions(+), 350 deletions(-) diff --git a/mypy/nodes.py b/mypy/nodes.py index 0a6f67810a3a..640fea44c7bb 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -625,7 +625,12 @@ def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_yield_stmt(self) -class YieldFromStmt(YieldStmt): +class YieldFromStmt(Node): + expr = Undefined(Node) + + def __init__(self, expr: Node) -> None: + self.expr = expr + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_yield_from_stmt(self) diff --git a/mypy/output.py b/mypy/output.py index da30434d2765..522eae758c2b 100644 --- a/mypy/output.py +++ b/mypy/output.py @@ -26,15 +26,15 @@ def __init__(self): # break self.extra_indent = 0 self.block_depth = 0 - + def output(self): """Return a string representation of the output.""" return ''.join(self.result) - + def visit_mypy_file(self, o): self.nodes(o.defs) self.token(o.repr.eof) - + def visit_import(self, o): r = o.repr self.token(r.import_tok) @@ -45,13 +45,13 @@ def visit_import(self, o): if i < len(r.commas): self.token(r.commas[i]) self.token(r.br) - + def visit_import_from(self, o): self.output_import_from_or_all(o) - + def visit_import_all(self, o): self.output_import_from_or_all(o) - + def output_import_from_or_all(self, o): r = o.repr self.token(r.from_tok) @@ -63,7 +63,7 @@ def output_import_from_or_all(self, o): self.token(comma) self.token(r.rparen) self.token(r.br) - + def visit_class_def(self, o): r = o.repr self.tokens([r.class_tok, r.name]) @@ -76,7 +76,7 @@ def visit_class_def(self, o): self.token(r.commas[i]) self.token(r.rparen) self.node(o.defs) - + def type_vars(self, v): # IDEA: Combine this with type_vars in TypeOutputVisitor. if v and v.repr: @@ -91,38 +91,38 @@ def type_vars(self, v): if i < len(r.commas): self.token(r.commas[i]) self.token(r.rangle) - + def visit_func_def(self, o): r = o.repr - + if r.def_tok: self.token(r.def_tok) else: self.type(o.type.items()[0].ret_type) - + self.token(r.name) - + self.function_header(o, r.args, o.arg_kinds) - + self.node(o.body) - + def visit_overloaded_func_def(self, o): for f in o.items: f.accept(self) - + def function_header(self, o, arg_repr, arg_kinds, pre_args_func=None, erase_type=False, strip_space_before_first_arg=False): r = o.repr - + t = None if o.type and not erase_type: t = o.type - + init = o.init - + if t: self.type_vars(t.variables) - + self.token(arg_repr.lseparator) if pre_args_func: pre_args_func() @@ -151,7 +151,7 @@ def function_header(self, o, arg_repr, arg_kinds, pre_args_func=None, if i < len(arg_repr.commas): self.token(arg_repr.commas[i]) self.token(arg_repr.rseparator) - + def visit_var_def(self, o): r = o.repr if r: @@ -161,21 +161,21 @@ def visit_var_def(self, o): self.token(r.assign) self.node(o.init) self.token(r.br) - + def visit_var(self, o): r = o.repr self.token(r.name) self.token(r.comma) - + def visit_decorator(self, o): for at, br, dec in zip(o.repr.ats, o.repr.brs, o.decorators): self.token(at) self.node(dec) self.token(br) self.node(o.func) - + # Statements - + def visit_block(self, o): r = o.repr self.tokens([r.colon, r.br, r.indent]) @@ -186,7 +186,7 @@ def visit_block(self, o): self.token(r.dedent) self.indent = old_indent self.block_depth -= 1 - + def visit_global_decl(self, o): r = o.repr self.token(r.global_tok) @@ -195,11 +195,11 @@ def visit_global_decl(self, o): if i < len(r.commas): self.token(r.commas[i]) self.token(r.br) - + def visit_expression_stmt(self, o): self.node(o.expr) self.token(o.repr.br) - + def visit_assignment_stmt(self, o): r = o.repr i = 0 @@ -209,40 +209,43 @@ def visit_assignment_stmt(self, o): i += 1 self.node(o.rvalue) self.token(r.br) - + def visit_operator_assignment_stmt(self, o): r = o.repr self.node(o.lvalue) self.token(r.assign) self.node(o.rvalue) self.token(r.br) - + def visit_return_stmt(self, o): self.simple_stmt(o, o.expr) - + def visit_assert_stmt(self, o): self.simple_stmt(o, o.expr) - + def visit_yield_stmt(self, o): self.simple_stmt(o, o.expr) - + + def visit_yield_from_stmt(self, o): + self.simple_stmt(o, o.expr) + def visit_del_stmt(self, o): self.simple_stmt(o, o.expr) - + def visit_break_stmt(self, o): self.simple_stmt(o) - + def visit_continue_stmt(self, o): self.simple_stmt(o) - + def visit_pass_stmt(self, o): self.simple_stmt(o) - + def simple_stmt(self, o, expr=None): self.token(o.repr.keyword) self.node(expr) self.token(o.repr.br) - + def visit_raise_stmt(self, o): self.token(o.repr.raise_tok) self.node(o.expr) @@ -250,7 +253,7 @@ def visit_raise_stmt(self, o): self.token(o.repr.from_tok) self.node(o.from_expr) self.token(o.repr.br) - + def visit_while_stmt(self, o): self.token(o.repr.while_tok) self.node(o.expr) @@ -258,7 +261,7 @@ def visit_while_stmt(self, o): if o.else_body: self.token(o.repr.else_tok) self.node(o.else_body) - + def visit_for_stmt(self, o): r = o.repr self.token(r.for_tok) @@ -268,12 +271,12 @@ def visit_for_stmt(self, o): self.token(r.commas[i]) self.token(r.in_tok) self.node(o.expr) - + self.node(o.body) if o.else_body: self.token(r.else_tok) self.node(o.else_body) - + def visit_if_stmt(self, o): r = o.repr self.token(r.if_tok) @@ -286,7 +289,7 @@ def visit_if_stmt(self, o): self.token(r.else_tok) if o.else_body: self.node(o.else_body) - + def visit_try_stmt(self, o): r = o.repr self.token(r.try_tok) @@ -303,7 +306,7 @@ def visit_try_stmt(self, o): if o.finally_body: self.token(r.finally_tok) self.node(o.finally_body) - + def visit_with_stmt(self, o): self.token(o.repr.with_tok) for i in range(len(o.expr)): @@ -313,49 +316,49 @@ def visit_with_stmt(self, o): if i < len(o.repr.commas): self.token(o.repr.commas[i]) self.node(o.body) - + # Expressions - + def visit_int_expr(self, o): self.token(o.repr.int) - + def visit_str_expr(self, o): self.tokens(o.repr.string) - + def visit_bytes_expr(self, o): self.tokens(o.repr.string) - + def visit_float_expr(self, o): self.token(o.repr.float) - + def visit_paren_expr(self, o): self.token(o.repr.lparen) self.node(o.expr) self.token(o.repr.rparen) - + def visit_name_expr(self, o): # Supertype references may not have a representation. if o.repr: self.token(o.repr.id) - + def visit_member_expr(self, o): self.node(o.expr) self.token(o.repr.dot) self.token(o.repr.name) - + def visit_index_expr(self, o): self.node(o.base) self.token(o.repr.lbracket) self.node(o.index) self.token(o.repr.rbracket) - + def visit_slice_expr(self, o): self.node(o.begin_index) self.token(o.repr.colon) self.node(o.end_index) self.token(o.repr.colon2) - self.node(o.stride) - + self.node(o.stride) + def visit_call_expr(self, o): r = o.repr self.node(o.callee) @@ -374,41 +377,41 @@ def visit_call_expr(self, o): if i < len(r.commas): self.token(r.commas[i]) self.token(r.rparen) - + def visit_op_expr(self, o): self.node(o.left) self.tokens([o.repr.op, o.repr.op2]) self.node(o.right) - + def visit_cast_expr(self, o): self.token(o.repr.lparen) self.type(o.type) self.token(o.repr.rparen) self.node(o.expr) - + def visit_super_expr(self, o): r = o.repr self.tokens([r.super_tok, r.lparen, r.rparen, r.dot, r.name]) - + def visit_unary_expr(self, o): self.token(o.repr.op) self.node(o.expr) - + def visit_list_expr(self, o): r = o.repr self.token(r.lbracket) self.comma_list(o.items, r.commas) self.token(r.rbracket) - + def visit_set_expr(self, o): self.visit_list_expr(o) - + def visit_tuple_expr(self, o): r = o.repr self.token(r.lparen) self.comma_list(o.items, r.commas) self.token(r.rparen) - + def visit_dict_expr(self, o): r = o.repr self.token(r.lbrace) @@ -421,14 +424,14 @@ def visit_dict_expr(self, o): self.token(r.commas[i]) i += 1 self.token(r.rbrace) - + def visit_func_expr(self, o): r = o.repr self.token(r.lambda_tok) self.function_header(o, r.args, o.arg_kinds) self.token(r.colon) self.node(o.body.body[0].expr) - + def visit_type_application(self, o): self.node(o.expr) self.token(o.repr.langle) @@ -455,12 +458,12 @@ def visit_list_comprehension(self, o): self.token(o.repr.lbracket) self.node(o.generator) self.token(o.repr.rbracket) - + # Helpers - + def line(self): return self.line_number - + def string(self, s): """Output a string.""" if self.omit_next_space: @@ -471,44 +474,44 @@ def string(self, s): if s != '': s = s.replace('\n', '\n' + ' ' * self.extra_indent) self.result.append(s) - + def token(self, t): """Output a token.""" self.string(t.rep()) - + def tokens(self, a): """Output an array of tokens.""" for t in a: self.token(t) - + def node(self, n): """Output a node.""" if n: n.accept(self) - + def nodes(self, a): """Output an array of nodes.""" for n in a: self.node(n) - + def comma_list(self, items, commas): for i in range(len(items)): self.node(items[i]) if i < len(commas): self.token(commas[i]) - + def type_list(self, items, commas): for i in range(len(items)): self.type(items[i]) if i < len(commas): self.token(commas[i]) - + def type(self, t): """Output a type.""" if t: v = TypeOutputVisitor() t.accept(v) self.string(v.output()) - + def last_output_char(self): if self.result and self.result[-1]: return self.result[-1][-1] @@ -521,22 +524,22 @@ class TypeOutputVisitor: """Type visitor that outputs source code.""" def __init__(self): self.result = [] # strings - + def output(self): """Return a string representation of the output.""" return ''.join(self.result) - + def visit_unbound_type(self, t): self.visit_instance(t) - + def visit_any(self, t): if t.repr: self.token(t.repr.any_tok) - + def visit_void(self, t): if t.repr: self.token(t.repr.void) - + def visit_instance(self, t): r = t.repr if isinstance(r, CommonTypeRepr): @@ -549,17 +552,17 @@ def visit_instance(self, t): assert len(t.args) == 1 self.comma_list(t.args, []) self.tokens([r.lbracket, r.rbracket]) - + def visit_type_var(self, t): self.token(t.repr.name) - + def visit_tuple_type(self, t): r = t.repr self.tokens(r.components) self.token(r.langle) self.comma_list(t.items, r.commas) self.token(r.rangle) - + def visit_callable(self, t): r = t.repr self.tokens([r.func, r.langle]) @@ -567,7 +570,7 @@ def visit_callable(self, t): self.token(r.lparen) self.comma_list(t.arg_types, r.commas) self.tokens([r.rparen, r.rangle]) - + def type_vars(self, v): if v and v.repr: r = v.repr @@ -581,26 +584,26 @@ def type_vars(self, v): if i < len(r.commas): self.token(r.commas[i]) self.token(r.rangle) - + # Helpers - + def string(self, s): """Output a string.""" self.result.append(s) - + def token(self, t): """Output a token.""" self.result.append(t.rep()) - + def tokens(self, a): """Output an array of tokens.""" for t in a: self.token(t) - + def type(self, n): """Output a type.""" if n: n.accept(self) - + def comma_list(self, items, commas): for i in range(len(items)): self.type(items[i]) diff --git a/mypy/semanal.py b/mypy/semanal.py index e278f2bfacf3..2b99abd36ee3 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -21,11 +21,11 @@ cyclic references between modules, such as module 'a' that imports module 'b' and used names defined in b *and* vice versa. The first pass can be performed before dependent modules have been processed. - + * SemanticAnalyzer is the second pass. It does the bulk of the work. It assumes that dependent modules have been semantically analyzed, up to the second pass, unless there is a import cycle. - + * ThirdPass checks that type argument counts are valid; for example, it will reject Dict[int]. We don't do this in the second pass, since we infer the type argument counts of classes during this @@ -56,7 +56,7 @@ SymbolTableNode, TVAR, UNBOUND_TVAR, ListComprehension, GeneratorExpr, FuncExpr, MDEF, FuncBase, Decorator, SetExpr, UndefinedExpr, TypeVarExpr, StrExpr, PrintStmt, ConditionalExpr, DucktypeExpr, DisjointclassExpr, - ARG_POS, ARG_NAMED, MroError, type_aliases + ARG_POS, ARG_NAMED, MroError, type_aliases, YieldFromStmt ) from mypy.visitor import NodeVisitor from mypy.traverser import TraverserVisitor @@ -93,7 +93,7 @@ class SemanticAnalyzer(NodeVisitor): This is the second phase of semantic analysis. """ - + # Library search paths lib_path = Undefined(List[str]) # Module name space @@ -117,7 +117,7 @@ class SemanticAnalyzer(NodeVisitor): cur_mod_id = '' # Current module id (or None) (phase 2) imports = Undefined(Set[str]) # Imported modules (during phase 2 analysis) errors = Undefined(Errors) # Keep track of generated errors - + def __init__(self, lib_path: List[str], errors: Errors, pyversion: int = 3) -> None: """Construct semantic analyzer. @@ -137,28 +137,28 @@ def __init__(self, lib_path: List[str], errors: Errors, self.modules = {} self.pyversion = pyversion self.stored_vars = Dict[Node, Type]() - + def visit_file(self, file_node: MypyFile, fnam: str) -> None: self.errors.set_file(fnam) self.globals = file_node.names self.cur_mod_id = file_node.fullname() - + if 'builtins' in self.modules: self.globals['__builtins__'] = SymbolTableNode( MODULE_REF, self.modules['builtins'], self.cur_mod_id) - + defs = file_node.defs for d in defs: d.accept(self) if self.cur_mod_id == 'builtins': remove_imported_names_from_symtable(self.globals, 'builtins') - + def visit_func_def(self, defn: FuncDef) -> None: self.errors.push_function(defn.name()) self.update_function_type_variables(defn) self.errors.pop_function() - + if self.is_class_scope(): # Method definition defn.is_conditional = self.block_depth[-1] > 0 @@ -186,7 +186,7 @@ def visit_func_def(self, defn: FuncDef) -> None: not defn.is_overload): self.add_local_func(defn, defn) defn._fullname = defn.name() - + self.errors.push_function(defn.name()) self.analyse_function(defn) self.errors.pop_function() @@ -249,7 +249,7 @@ def find_type_variables_in_type( def is_defined_type_var(self, tvar: str, context: Node) -> bool: return self.lookup_qualified(tvar, context).kind == TVAR - + def visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None: t = List[Callable]() for item in defn.items: @@ -261,17 +261,17 @@ def visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None: if not [dec for dec in item.decorators if refers_to_fullname(dec, 'typing.overload')]: self.fail("'overload' decorator expected", item) - + defn.type = Overloaded(t) defn.type.line = defn.line - + if self.is_class_scope(): self.type.names[defn.name()] = SymbolTableNode(MDEF, defn, typ=defn.type) defn.info = self.type elif self.is_func_scope(): self.add_local_func(defn, defn) - + def analyse_function(self, defn: FuncItem) -> None: is_method = self.is_class_scope() tvarnodes = self.add_func_type_variables_to_symbol_table(defn) @@ -293,18 +293,18 @@ def analyse_function(self, defn: FuncItem) -> None: for init_ in defn.init: if init_: init_.lvalues[0].accept(self) - + # The first argument of a non-static, non-class method is like 'self' # (though the name could be different), having the enclosing class's # instance type. if is_method and not defn.is_static and not defn.is_class and defn.args: defn.args[0].is_self = True - + defn.body.accept(self) disable_typevars(tvarnodes) self.leave() self.function_stack.pop() - + def add_func_type_variables_to_symbol_table( self, defn: FuncItem) -> List[SymbolTableNode]: nodes = List[SymbolTableNode]() @@ -320,13 +320,13 @@ def add_func_type_variables_to_symbol_table( nodes.append(node) names.add(name) return nodes - + def type_var_names(self) -> Set[str]: if not self.type: return set() else: return set(self.type.type_vars) - + def add_type_var(self, fullname: str, id: int, context: Context) -> SymbolTableNode: node = self.lookup_qualified(fullname, context) @@ -340,7 +340,7 @@ def check_function_signature(self, fdef: FuncItem) -> None: self.fail('Type signature has too few arguments', fdef) elif len(sig.arg_types) > len(fdef.args): self.fail('Type signature has too many arguments', fdef) - + def visit_class_def(self, defn: ClassDef) -> None: self.clean_up_bases_and_infer_type_variables(defn) self.setup_class_def_analysis(defn) @@ -355,7 +355,7 @@ def visit_class_def(self, defn: ClassDef) -> None: self.calculate_abstract_status(defn.info) self.setup_ducktyping(defn) - + # Restore analyzer state. self.block_depth.pop() self.locals.pop() @@ -369,7 +369,7 @@ def analyze_class_decorator(self, defn: ClassDef, decorator: Node) -> None: decorator.accept(self) if refers_to_fullname(decorator, 'typing.builtinclass'): defn.is_builtinclass = True - + def calculate_abstract_status(self, typ: TypeInfo) -> None: """Calculate abstract status of a class. @@ -419,7 +419,7 @@ def clean_up_bases_and_infer_type_variables(self, defn: ClassDef) -> None: For example, consider this class: . class Foo(Bar, Generic[t]): ... - + Now we will remove Generic[t] from bases of Foo and infer that the type variable 't' is a type argument of Foo. """ @@ -493,7 +493,7 @@ def setup_class_def_analysis(self, defn: ClassDef) -> None: self.type = defn.info def analyze_base_classes(self, defn: ClassDef) -> None: - """Analyze and set up base classes.""" + """Analyze and set up base classes.""" bases = List[Instance]() for i in range(len(defn.base_types)): base = self.anal_type(defn.base_types[i]) @@ -565,10 +565,10 @@ def object_type(self) -> Instance: def named_type(self, qualified_name: str) -> Instance: sym = self.lookup_qualified(qualified_name, None) return Instance(cast(TypeInfo, sym.node), []) - + def is_instance_type(self, t: Type) -> bool: return isinstance(t, Instance) - + def add_class_type_variables_to_symbol_table( self, info: TypeInfo) -> List[SymbolTableNode]: vars = info.type_vars @@ -578,7 +578,7 @@ def add_class_type_variables_to_symbol_table( node = self.add_type_var(vars[i], i + 1, info) nodes.append(node) return nodes - + def visit_import(self, i: Import) -> None: for id, as_id in i.ids: if as_id != id: @@ -593,7 +593,7 @@ def add_module_symbol(self, id: str, as_id: str, context: Context) -> None: self.add_symbol(as_id, SymbolTableNode(MODULE_REF, m, self.cur_mod_id), context) else: self.add_unknown_symbol(as_id, context) - + def visit_import_from(self, i: ImportFrom) -> None: if i.id in self.modules: m = self.modules[i.id] @@ -617,7 +617,7 @@ def normalize_type_alias(self, node: SymbolTableNode, # Node refers to an aliased type such as typing.List; normalize. node = self.lookup_qualified(type_aliases[node.fullname], ctx) return node - + def visit_import_all(self, i: ImportAll) -> None: if i.id in self.modules: m = self.modules[i.id] @@ -636,11 +636,11 @@ def add_unknown_symbol(self, name: str, context: Context) -> None: var.is_ready = True var.type = AnyType() self.add_symbol(name, SymbolTableNode(GDEF, var, self.cur_mod_id), context) - + # # Statements # - + def visit_block(self, b: Block) -> None: if b.is_unreachable: return @@ -648,15 +648,15 @@ def visit_block(self, b: Block) -> None: for s in b.body: s.accept(self) self.block_depth[-1] -= 1 - + def visit_block_maybe(self, b: Block) -> None: if b: self.visit_block(b) - + def visit_var_def(self, defn: VarDef) -> None: for i in range(len(defn.items)): defn.items[i].type = self.anal_type(defn.items[i].type) - + for v in defn.items: if self.is_func_scope(): defn.kind = LDEF @@ -669,17 +669,17 @@ def visit_var_def(self, defn: VarDef) -> None: elif v.name not in self.globals: defn.kind = GDEF self.add_var(v, defn) - + if defn.init: defn.init.accept(self) - + def anal_type(self, t: Type) -> Type: if t: a = TypeAnalyser(self.lookup_qualified, self.stored_vars, self.fail) return t.accept(a) else: return None - + def visit_assignment_stmt(self, s: AssignmentStmt) -> None: for lval in s.lvalues: self.analyse_lvalue(lval, explicit_type=s.type is not None) @@ -722,7 +722,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> None: # that refers to a type, rather than making this # just an alias for the type. self.globals[lvalue.name].node = node - + def analyse_lvalue(self, lval: Node, nested: bool = False, add_global: bool = False, explicit_type: bool = False) -> None: @@ -798,7 +798,7 @@ def analyse_lvalue(self, lval: Node, nested: bool = False, explicit_type = explicit_type) else: self.fail('Invalid assignment target', lval) - + def analyse_member_lvalue(self, lval: MemberExpr) -> None: lval.accept(self) if (self.is_self_member_ref(lval) and @@ -848,7 +848,7 @@ def store_declared_types(self, lvalue: Node, typ: Type) -> None: self.store_declared_types(item, itemtype) else: self.fail('Tuple type expected for multiple variables', - lvalue) + lvalue) elif isinstance(lvalue, ParenExpr): self.store_declared_types(lvalue.expr, typ) else: @@ -903,7 +903,7 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> None: else: self.fail('The values argument must be a tuple literal', s) return - else: + else: self.fail('The values argument must be in parentheses (...)', s) return @@ -976,20 +976,20 @@ def check_decorated_function_is_method(self, decorator: str, context: Context) -> None: if not self.type or self.is_func_scope(): self.fail("'%s' used with a non-method" % decorator, context) - + def visit_expression_stmt(self, s: ExpressionStmt) -> None: s.expr.accept(self) - + def visit_return_stmt(self, s: ReturnStmt) -> None: if not self.is_func_scope(): self.fail("'return' outside function", s) if s.expr: s.expr.accept(self) - + def visit_raise_stmt(self, s: RaiseStmt) -> None: if s.expr: s.expr.accept(self) - + def visit_yield_stmt(self, s: YieldStmt) -> None: if not self.is_func_scope(): self.fail("'yield' outside function", s) @@ -997,30 +997,37 @@ def visit_yield_stmt(self, s: YieldStmt) -> None: self.function_stack[-1].is_generator = True if s.expr: s.expr.accept(self) - + + def visit_yield_from_stmt(self, s: YieldFromStmt) -> None: + if not self.is_func_scope(): + self.fail("'yield from' outside function", s) + #Check coroutine?? + if s.expr: + s.expr.accept(self) + def visit_assert_stmt(self, s: AssertStmt) -> None: if s.expr: s.expr.accept(self) - + def visit_operator_assignment_stmt(self, s: OperatorAssignmentStmt) -> None: s.lvalue.accept(self) s.rvalue.accept(self) - + def visit_while_stmt(self, s: WhileStmt) -> None: s.expr.accept(self) self.loop_depth += 1 s.body.accept(self) self.loop_depth -= 1 self.visit_block_maybe(s.else_body) - + def visit_for_stmt(self, s: ForStmt) -> None: s.expr.accept(self) - + # Bind index variables and check if they define new names. for n in s.index: self.analyse_lvalue(n) - + # Analyze index variable types. for i in range(len(s.types)): t = s.types[i] @@ -1029,32 +1036,32 @@ def visit_for_stmt(self, s: ForStmt) -> None: v = cast(Var, s.index[i].node) # TODO check if redefinition v.type = s.types[i] - + # Report error if only some of the loop variables have annotations. if s.types != [None] * len(s.types) and None in s.types: self.fail('Cannot mix unannotated and annotated loop variables', s) - + self.loop_depth += 1 self.visit_block(s.body) self.loop_depth -= 1 - + self.visit_block_maybe(s.else_body) - + def visit_break_stmt(self, s: BreakStmt) -> None: if self.loop_depth == 0: self.fail("'break' outside loop", s) - + def visit_continue_stmt(self, s: ContinueStmt) -> None: if self.loop_depth == 0: self.fail("'continue' outside loop", s) - + def visit_if_stmt(self, s: IfStmt) -> None: infer_reachability_of_if_statement(s, pyversion=self.pyversion) for i in range(len(s.expr)): s.expr[i].accept(self) self.visit_block(s.body[i]) self.visit_block_maybe(s.else_body) - + def visit_try_stmt(self, s: TryStmt) -> None: self.analyze_try_stmt(s, self) @@ -1071,7 +1078,7 @@ def analyze_try_stmt(self, s: TryStmt, visitor: NodeVisitor, s.else_body.accept(visitor) if s.finally_body: s.finally_body.accept(visitor) - + def visit_with_stmt(self, s: WithStmt) -> None: for e in s.expr: e.accept(self) @@ -1079,12 +1086,12 @@ def visit_with_stmt(self, s: WithStmt) -> None: if n: self.analyse_lvalue(n) self.visit_block(s.body) - + def visit_del_stmt(self, s: DelStmt) -> None: s.expr.accept(self) if not isinstance(s.expr, (IndexExpr, NameExpr, MemberExpr)): self.fail('Invalid delete target', s) - + def visit_global_decl(self, g: GlobalDecl) -> None: for n in g.names: self.global_decls[-1].add(n) @@ -1092,11 +1099,11 @@ def visit_global_decl(self, g: GlobalDecl) -> None: def visit_print_stmt(self, s: PrintStmt) -> None: for arg in s.args: arg.accept(self) - + # # Expressions # - + def visit_name_expr(self, expr: NameExpr) -> None: n = self.lookup(expr.name, expr) if n: @@ -1107,33 +1114,33 @@ def visit_name_expr(self, expr: NameExpr) -> None: expr.kind = n.kind expr.node = (cast(Node, n.node)) expr.fullname = n.fullname - + def visit_super_expr(self, expr: SuperExpr) -> None: if not self.type: self.fail('"super" used outside class', expr) - return + return expr.info = self.type - + def visit_tuple_expr(self, expr: TupleExpr) -> None: for item in expr.items: item.accept(self) - + def visit_list_expr(self, expr: ListExpr) -> None: for item in expr.items: item.accept(self) - + def visit_set_expr(self, expr: SetExpr) -> None: for item in expr.items: item.accept(self) - + def visit_dict_expr(self, expr: DictExpr) -> None: for key, value in expr.items: key.accept(self) value.accept(self) - + def visit_paren_expr(self, expr: ParenExpr) -> None: expr.expr.accept(self) - + def visit_call_expr(self, expr: CallExpr) -> None: """Analyze a call expression. @@ -1159,7 +1166,7 @@ def visit_call_expr(self, expr: CallExpr) -> None: elif refers_to_fullname(expr.callee, 'typing.Any'): # Special form Any(...). if not self.check_fixed_args(expr, 1, 'Any'): - return + return expr.analyzed = CastExpr(expr.args[0], AnyType()) expr.analyzed.line = expr.line expr.analyzed.accept(self) @@ -1222,7 +1229,7 @@ def check_fixed_args(self, expr: CallExpr, numargs: int, (name, numargs, s), expr) return False return True - + def visit_member_expr(self, expr: MemberExpr) -> None: base = expr.expr base.accept(self) @@ -1238,14 +1245,14 @@ def visit_member_expr(self, expr: MemberExpr) -> None: expr.kind = n.kind expr.fullname = n.fullname expr.node = n.node - + def visit_op_expr(self, expr: OpExpr) -> None: expr.left.accept(self) expr.right.accept(self) - + def visit_unary_expr(self, expr: UnaryExpr) -> None: expr.expr.accept(self) - + def visit_index_expr(self, expr: IndexExpr) -> None: expr.base.accept(self) if refers_to_class_or_function(expr.base): @@ -1276,14 +1283,14 @@ def visit_slice_expr(self, expr: SliceExpr) -> None: expr.end_index.accept(self) if expr.stride: expr.stride.accept(self) - + def visit_cast_expr(self, expr: CastExpr) -> None: expr.expr.accept(self) expr.type = self.anal_type(expr.type) def visit_undefined_expr(self, expr: UndefinedExpr) -> None: expr.type = self.anal_type(expr.type) - + def visit_type_application(self, expr: TypeApplication) -> None: expr.expr.accept(self) for i in range(len(expr.types)): @@ -1321,11 +1328,11 @@ def visit_ducktype_expr(self, expr: DucktypeExpr) -> None: def visit_disjointclass_expr(self, expr: DisjointclassExpr) -> None: expr.cls.accept(self) - + # # Helpers # - + def lookup(self, name: str, ctx: Context) -> SymbolTableNode: """Look up an unqualified name in all active namespaces.""" # 1. Name declared using 'global x' takes precedence @@ -1358,7 +1365,7 @@ def lookup(self, name: str, ctx: Context) -> SymbolTableNode: # Give up. self.name_not_defined(name, ctx) return None - + def lookup_qualified(self, name: str, ctx: Context) -> SymbolTableNode: if '.' not in name: return self.lookup(name, ctx) @@ -1376,14 +1383,14 @@ def lookup_qualified(self, name: str, ctx: Context) -> SymbolTableNode: if n: n = self.normalize_type_alias(n, ctx) return n - + def qualified_name(self, n: str) -> str: return self.cur_mod_id + '.' + n - + def enter(self) -> None: self.locals.append(SymbolTable()) self.global_decls.append(set()) - + def leave(self) -> None: self.locals.pop() self.global_decls.pop() @@ -1412,14 +1419,14 @@ def add_symbol(self, name: str, node: SymbolTableNode, # of multiple submodules of a package (e.g. a.x and a.y). self.name_already_defined(name, context) self.globals[name] = node - + def add_var(self, v: Var, ctx: Context) -> None: if self.is_func_scope(): self.add_local(v, ctx) else: self.globals[v.name()] = SymbolTableNode(GDEF, v, self.cur_mod_id) v._fullname = self.qualified_name(v.name()) - + def add_local(self, v: Var, ctx: Context) -> None: if v.name() in self.locals[-1]: self.name_already_defined(v.name(), ctx) @@ -1431,7 +1438,7 @@ def add_local_func(self, defn: FuncBase, ctx: Context) -> None: if defn.name() in self.locals[-1]: self.name_already_defined(defn.name(), ctx) self.locals[-1][defn.name()] = SymbolTableNode(LDEF, defn) - + def check_no_global(self, n: str, ctx: Context, is_func: bool = False) -> None: if n in self.globals: @@ -1440,20 +1447,20 @@ def check_no_global(self, n: str, ctx: Context, "must be next to each other)").format(n), ctx) else: self.name_already_defined(n, ctx) - + def name_not_defined(self, name: str, ctx: Context) -> None: self.fail("Name '{}' is not defined".format(name), ctx) - + def name_already_defined(self, name: str, ctx: Context) -> None: self.fail("Name '{}' already defined".format(name), ctx) - + def fail(self, msg: str, ctx: Context) -> None: self.errors.report(ctx.get_line(), msg) class FirstPass(NodeVisitor): """First phase of semantic analysis""" - + def __init__(self, sem: SemanticAnalyzer) -> None: self.sem = sem self.pyversion = sem.pyversion @@ -1475,15 +1482,15 @@ def analyze(self, file: MypyFile, fnam: str, mod_id: str) -> None: sem.block_depth = [0] defs = file.defs - + # Add implicit definitions of module '__name__' etc. for n in implicit_module_attrs: name_def = VarDef([Var(n, AnyType())], True) defs.insert(0, name_def) - + for d in defs: d.accept(self) - + # Add implicit definition of 'None' to builtins, as we cannot define a # variable with a None type explicitly. if mod_id == 'builtins': @@ -1498,12 +1505,12 @@ def visit_block(self, b: Block) -> None: for node in b.body: node.accept(self) self.sem.block_depth[-1] -= 1 - + def visit_assignment_stmt(self, s: AssignmentStmt) -> None: for lval in s.lvalues: self.sem.analyse_lvalue(lval, add_global=True, explicit_type=s.type is not None) - + def visit_func_def(self, d: FuncDef) -> None: sem = self.sem d.is_conditional = sem.block_depth[-1] > 0 @@ -1516,13 +1523,13 @@ def visit_func_def(self, d: FuncDef) -> None: sem.check_no_global(d.name(), d, True) d._fullname = sem.qualified_name(d.name()) sem.globals[d.name()] = SymbolTableNode(GDEF, d, sem.cur_mod_id) - + def visit_overloaded_func_def(self, d: OverloadedFuncDef) -> None: self.sem.check_no_global(d.name(), d) d._fullname = self.sem.qualified_name(d.name()) self.sem.globals[d.name()] = SymbolTableNode(GDEF, d, self.sem.cur_mod_id) - + def visit_class_def(self, d: ClassDef) -> None: self.sem.check_no_global(d.name, d) d.fullname = self.sem.qualified_name(d.name) @@ -1531,7 +1538,7 @@ def visit_class_def(self, d: ClassDef) -> None: d.info = info self.sem.globals[d.name] = SymbolTableNode(GDEF, info, self.sem.cur_mod_id) - + def visit_var_def(self, d: VarDef) -> None: for v in d.items: self.sem.check_no_global(v.name(), d) @@ -1569,14 +1576,14 @@ class ThirdPass(TraverserVisitor[None]): Check type argument counts and values of generic types. Also update TypeInfo disjointclass information. """ - + def __init__(self, errors: Errors) -> None: self.errors = errors - + def visit_file(self, file_node: MypyFile, fnam: str) -> None: self.errors.set_file(fnam) file_node.accept(self) - + def visit_func_def(self, fdef: FuncDef) -> None: self.errors.push_function(fdef.name()) self.analyze(fdef.type) @@ -1617,7 +1624,7 @@ def analyze(self, type: Type) -> None: if type: analyzer = TypeAnalyserPass3(self.fail) type.accept(analyzer) - + def fail(self, msg: str, ctx: Context) -> None: self.errors.report(ctx.get_line(), msg) @@ -1822,7 +1829,7 @@ def mark_block_unreachable(block: Block) -> None: class MarkImportsUnreachableVisitor(TraverserVisitor): """Visitor that flags all imports nested within a node as unreachable.""" - + def visit_import(self, node: Import) -> None: node.is_unreachable = True diff --git a/mypy/strconv.py b/mypy/strconv.py index 37b85408e187..c55862892f3f 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -12,10 +12,10 @@ class StrConv(NodeVisitor[str]): """Visitor for converting a Node to a human-readable string. - + For example, an MypyFile node from program '1' is converted into something like this: - + MypyFile:1( fnam ExpressionStmt:1( @@ -29,7 +29,7 @@ def dump(self, nodes, obj): argument. """ return dump_tagged(nodes, short_type(obj) + ':' + str(obj.line)) - + def func_helper(self, o): """Return a list in a format suitable for dump() that represents the arguments and the body of a function. The caller can then decorate the @@ -60,10 +60,10 @@ def func_helper(self, o): a.append('Generator') a.extend(extra) a.append(o.body) - return a - + return a + # Top-level structures - + def visit_mypy_file(self, o): # Skip implicit definitions. defs = o.defs @@ -81,24 +81,24 @@ def visit_mypy_file(self, o): # case# output in all platforms. a.insert(0, o.path.replace(os.sep, '/')) return self.dump(a, o) - + def visit_import(self, o): a = [] for id, as_id in o.ids: a.append('{} : {}'.format(id, as_id)) return 'Import:{}({})'.format(o.line, ', '.join(a)) - + def visit_import_from(self, o): a = [] for name, as_name in o.names: a.append('{} : {}'.format(name, as_name)) return 'ImportFrom:{}({}, [{}])'.format(o.line, o.id, ', '.join(a)) - + def visit_import_all(self, o): return 'ImportAll:{}({})'.format(o.line, o.id) - + # Definitions - + def visit_func_def(self, o): a = self.func_helper(o) a.insert(0, o.name()) @@ -113,13 +113,13 @@ def visit_func_def(self, o): if o.is_property: a.insert(-1, 'Property') return self.dump(a, o) - + def visit_overloaded_func_def(self, o): a = o.items[:] if o.type: a.insert(0, o.type) return self.dump(a, o) - + def visit_class_def(self, o): a = [o.name, o.defs.body] # Display base types unless they are implicitly just builtins.object @@ -141,7 +141,7 @@ def visit_class_def(self, o): a.insert(1, ('Disjointclasses', [info.fullname() for info in o.info.disjoint_classes])) return self.dump(a, o) - + def visit_var_def(self, o): a = [] for n in o.items: @@ -150,7 +150,7 @@ def visit_var_def(self, o): if o.init: a.append(o.init) return self.dump(a, o) - + def visit_var(self, o): l = '' # Add :nil line number tag if no line number is specified to remain @@ -158,24 +158,24 @@ def visit_var(self, o): if o.line < 0: l = ':nil' return 'Var' + l + '(' + o.name() + ')' - + def visit_global_decl(self, o): return self.dump([o.names], o) - + def visit_decorator(self, o): return self.dump([o.var, o.decorators, o.func], o) - + def visit_annotation(self, o): return 'Type:{}({})'.format(o.line, o.type) - + # Statements - + def visit_block(self, o): return self.dump(o.body, o) - + def visit_expression_stmt(self, o): return self.dump([o.expr], o) - + def visit_assignment_stmt(self, o): if len(o.lvalues) > 1: a = [('Lvalues', o.lvalues)] @@ -185,16 +185,16 @@ def visit_assignment_stmt(self, o): if o.type: a.append(o.type) return self.dump(a, o) - + def visit_operator_assignment_stmt(self, o): return self.dump([o.op, o.lvalue, o.rvalue], o) - + def visit_while_stmt(self, o): a = [o.expr, o.body] if o.else_body: a.append(('Else', o.else_body.body)) return self.dump(a, o) - + def visit_for_stmt(self, o): a = [o.index] if o.types != [None] * len(o.types): @@ -203,58 +203,61 @@ def visit_for_stmt(self, o): if o.else_body: a.append(('Else', o.else_body.body)) return self.dump(a, o) - + def visit_return_stmt(self, o): return self.dump([o.expr], o) - + def visit_if_stmt(self, o): a = [] for i in range(len(o.expr)): a.append(('If', [o.expr[i]])) a.append(('Then', o.body[i].body)) - + if not o.else_body: return self.dump(a, o) else: return self.dump([a, ('Else', o.else_body.body)], o) - + def visit_break_stmt(self, o): return self.dump([], o) - + def visit_continue_stmt(self, o): return self.dump([], o) - + def visit_pass_stmt(self, o): return self.dump([], o) - + def visit_raise_stmt(self, o): return self.dump([o.expr, o.from_expr], o) - + def visit_assert_stmt(self, o): return self.dump([o.expr], o) - + def visit_yield_stmt(self, o): return self.dump([o.expr], o) - + + def visit_yield_from_stmt(self, o): + return self.dump([o.expr], o) + def visit_del_stmt(self, o): return self.dump([o.expr], o) - + def visit_try_stmt(self, o): a = [o.body] - + for i in range(len(o.vars)): a.append(o.types[i]) if o.vars[i]: a.append(o.vars[i]) a.append(o.handlers[i]) - + if o.else_body: a.append(('Else', o.else_body.body)) if o.finally_body: a.append(('Finally', o.finally_body.body)) - + return self.dump(a, o) - + def visit_with_stmt(self, o): a = [] for i in range(len(o.expr)): @@ -268,39 +271,39 @@ def visit_print_stmt(self, o): if o.newline: a.append('Newline') return self.dump(a, o) - + # Expressions - + # Simple expressions - + def visit_int_expr(self, o): return 'IntExpr({})'.format(o.value) - + def visit_str_expr(self, o): return 'StrExpr({})'.format(self.str_repr(o.value)) - + def visit_bytes_expr(self, o): return 'BytesExpr({})'.format(self.str_repr(o.value)) - + def visit_unicode_expr(self, o): return 'UnicodeExpr({})'.format(self.str_repr(o.value)) - + def str_repr(self, s): s = re.sub(r'\\u[0-9a-fA-F]{4}', lambda m: '\\' + m.group(0), s) return re.sub('[^\\x20-\\x7e]', lambda m: r'\u%.4x' % ord(m.group(0)), s) - + def visit_float_expr(self, o): return 'FloatExpr({})'.format(o.value) - + def visit_paren_expr(self, o): return self.dump([o.expr], o) - + def visit_name_expr(self, o): return (short_type(o) + '(' + self.pretty_name(o.name, o.kind, o.fullname, o.is_def) + ')') - + def pretty_name(self, name, kind, fullname, is_def): n = name if is_def: @@ -316,11 +319,11 @@ def pretty_name(self, name, kind, fullname, is_def): # Add tag to signify a member reference. n += ' [m]' return n - + def visit_member_expr(self, o): return self.dump([o.expr, self.pretty_name(o.name, o.kind, o.fullname, o.is_def)], o) - + def visit_call_expr(self, o): if o.analyzed: return o.analyzed.accept(self) @@ -339,39 +342,39 @@ def visit_call_expr(self, o): raise RuntimeError('unknown kind %d' % kind) return self.dump([o.callee, ('Args', args)] + extra, o) - + def visit_op_expr(self, o): return self.dump([o.op, o.left, o.right], o) - + def visit_cast_expr(self, o): return self.dump([o.expr, o.type], o) - + def visit_unary_expr(self, o): return self.dump([o.op, o.expr], o) - + def visit_list_expr(self, o): return self.dump(o.items, o) - + def visit_dict_expr(self, o): return self.dump([[k, v] for k, v in o.items], o) - + def visit_set_expr(self, o): return self.dump(o.items, o) - + def visit_tuple_expr(self, o): return self.dump(o.items, o) - + def visit_index_expr(self, o): if o.analyzed: return o.analyzed.accept(self) return self.dump([o.base, o.index], o) - + def visit_super_expr(self, o): return self.dump([o.name], o) def visit_undefined_expr(self, o): return 'UndefinedExpr:{}({})'.format(o.line, o.type) - + def visit_type_application(self, o): return self.dump([o.expr, ('Types', o.types)], o) @@ -386,21 +389,21 @@ def visit_ducktype_expr(self, o): def visit_disjointclass_expr(self, o): return 'DisjointclassExpr:{}({})'.format(o.line, o.cls.fullname) - + def visit_func_expr(self, o): a = self.func_helper(o) return self.dump(a, o) - + def visit_generator_expr(self, o): # FIX types return self.dump([o.left_expr, o.indices, o.sequences, o.condition], o) - + def visit_list_comprehension(self, o): return self.dump([o.generator], o) - + def visit_conditional_expr(self, o): return self.dump([('Condition', [o.cond]), o.if_expr, o.else_expr], o) - + def visit_slice_expr(self, o): a = [o.begin_index, o.end_index, o.stride] if not a[0]: @@ -408,14 +411,14 @@ def visit_slice_expr(self, o): if not a[1]: a[1] = '' return self.dump(a, o) - + def visit_coerce_expr(self, o): return self.dump([o.expr, ('Types', [o.target_type, o.source_type])], o) - + def visit_type_expr(self, o): return self.dump([str(o.type)], o) - + def visit_filter_node(self, o): # These are for convenience. These node types are not defined in the # parser module. diff --git a/mypy/test/data/parse-errors.test b/mypy/test/data/parse-errors.test index a354e947da22..b9df789efac7 100644 --- a/mypy/test/data/parse-errors.test +++ b/mypy/test/data/parse-errors.test @@ -156,14 +156,14 @@ file: In function "f": file, line 1: Invalid argument list [case testInvalidFuncDefArgs4] -def f(**x, y=x): +def f(**x, y=x): pass [out] file: In function "f": file, line 1: Invalid argument list [case testInvalidStringLiteralType] -def f(x: +def f(x: 'A[' ) -> None: pass [out] @@ -172,7 +172,7 @@ file, line 2: Parse error before end of line file, line 3: Parse error before end of line [case testInvalidStringLiteralType2] -def f(x: +def f(x: 'A B' ) -> None: pass [out] @@ -340,3 +340,10 @@ file, line 1: Parse error before "for" 1 if x else for y in z [out] file, line 1: Parse error before "for" + +[case testYieldFromNotRightParameter] +def f(): + yield from +[out] +file: In function "f": +file, line 2: Parse error before end of line \ No newline at end of file diff --git a/mypy/test/data/parse.test b/mypy/test/data/parse.test index c48a60d0a574..4dce5a5cb526 100644 --- a/mypy/test/data/parse.test +++ b/mypy/test/data/parse.test @@ -339,7 +339,7 @@ MypyFile:1( Block:1( ReturnStmt:2( IntExpr(1))))) - + [case testReturnWithoutValue] def f(): @@ -689,7 +689,7 @@ MypyFile:1( AssignmentStmt:1( NameExpr(x) IntExpr(1))) - + [case testInvalidAnnotation] x=1 ##type: int y=1 #.type: int @@ -870,7 +870,7 @@ MypyFile:1( PassStmt:2()) Block:3( RaiseStmt:4()))) - + [case testRaiseFrom] raise e from x [out] @@ -993,7 +993,7 @@ MypyFile:1( Import:2(y.z.foo : y.z.foo, __foo__.bar : __foo__.bar)) [case testVariableTypeWithQualifiedName] -x = None # type: x.y +x = None # type: x.y [out] MypyFile:1( AssignmentStmt:1( @@ -1172,7 +1172,7 @@ MypyFile:1( NameExpr(x) NameExpr(y) NameExpr(z))))) - + [case testComplexListComprehension] x=[(x, y) for y, z in 1, 2] [out] @@ -1295,6 +1295,34 @@ MypyFile:1( NameExpr(x) IntExpr(1)))))) +[case testYieldFrom] +def f(): + yield from h() +[out] +MypyFile:1( + FuncDef:1( + f + Block:1( + YieldFromStmt:2( + CallExpr:2( + NameExpr(h) + Args()))))) + +[case testYieldFromAssignment] +def f(): + a = yield from h() +[out] +MypyFile:1( + FuncDef:1( + f + Block:1( + AssignmentStmt:2( + NameExpr(a) + CallExpr:2( + NameExpr(h) + Args()))))) + + [case testDel] del x del x[0], y[1] @@ -1758,7 +1786,7 @@ MypyFile:1( Var(a)) Block:1( PassStmt:1()))) - + [case testDictVarArgs] def f(x, **a): pass [out] @@ -1795,7 +1823,7 @@ MypyFile:1( Var(b)) Block:2( PassStmt:2()))) - + [case testDictVarArgsWithType] def f(x: X, **a: A) -> None: pass [out] @@ -1923,7 +1951,7 @@ MypyFile:1( SetExpr:2( IntExpr(1) IntExpr(2)))) - + [case testSetLiteralWithExtraComma] {x,} [out] @@ -2196,7 +2224,7 @@ f = None # type: Function[[], None] MypyFile:1( AssignmentStmt:1( NameExpr(f) - NameExpr(None) + NameExpr(None) Function?[, None?])) [case testFunctionTypeWithArgument] @@ -2205,7 +2233,7 @@ f = None # type: Function[[str], int] MypyFile:1( AssignmentStmt:1( NameExpr(f) - NameExpr(None) + NameExpr(None) Function?[, int?])) [case testFunctionTypeWithTwoArguments] @@ -2214,7 +2242,7 @@ f = None # type: Function[[a[b], x.y], List[int]] MypyFile:1( AssignmentStmt:1( NameExpr(f) - NameExpr(None) + NameExpr(None) Function?[, List?[int?]])) [case testFunctionTypeWithExtraComma] @@ -2407,7 +2435,7 @@ MypyFile:1( Var(y)) Block:6( PassStmt:6()))))) - + [case testDecoratorsThatAreNotOverloads] @foo def f() -> x: pass @@ -2545,9 +2573,9 @@ MypyFile:1( [case testCommentFunctionAnnotationOnSeparateLine2] def f(x): - + # type: (X) -> Y # bar - + pass [out] MypyFile:1( diff --git a/mypy/treetransform.py b/mypy/treetransform.py index 3c11004d68d8..f6cde34880f9 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -16,7 +16,7 @@ UnicodeExpr, FloatExpr, CallExpr, SuperExpr, MemberExpr, IndexExpr, SliceExpr, OpExpr, UnaryExpr, FuncExpr, TypeApplication, PrintStmt, SymbolTable, RefExpr, UndefinedExpr, TypeVarExpr, DucktypeExpr, - DisjointclassExpr, CoerceExpr, TypeExpr, JavaCast, TempNode + DisjointclassExpr, CoerceExpr, TypeExpr, JavaCast, TempNode, YieldFromStmt ) from mypy.types import Type from mypy.visitor import NodeVisitor @@ -45,7 +45,7 @@ def __init__(self) -> None: # There may be multiple references to a Var node. Keep track of # Var translations using a dictionary. self.var_map = Dict[Var, Var]() - + def visit_mypy_file(self, node: MypyFile) -> Node: # NOTE: The 'names' and 'imports' instance variables will be empty! new = MypyFile(self.nodes(node.defs), [], node.is_bom) @@ -54,16 +54,16 @@ def visit_mypy_file(self, node: MypyFile) -> Node: new.path = node.path new.names = SymbolTable() return new - + def visit_import(self, node: Import) -> Node: return Import(node.ids[:]) - + def visit_import_from(self, node: ImportFrom) -> Node: return ImportFrom(node.id, node.names[:]) - + def visit_import_all(self, node: ImportAll) -> Node: return ImportAll(node.id) - + def visit_func_def(self, node: FuncDef) -> FuncDef: # Note that a FuncDef must be transformed to a FuncDef. new = FuncDef(node.name(), @@ -74,7 +74,7 @@ def visit_func_def(self, node: FuncDef) -> FuncDef: self.optional_type(node.type)) self.copy_function_attributes(new, node) - + new._fullname = node._fullname new.is_decorated = node.is_decorated new.is_conditional = node.is_conditional @@ -84,7 +84,7 @@ def visit_func_def(self, node: FuncDef) -> FuncDef: new.is_property = node.is_property new.original_def = node.original_def return new - + def visit_func_expr(self, node: FuncExpr) -> Node: new = FuncExpr([self.visit_var(var) for var in node.args], node.arg_kinds[:], @@ -113,7 +113,7 @@ def duplicate_inits(self, else: result.append(None) return result - + def visit_overloaded_func_def(self, node: OverloadedFuncDef) -> Node: items = [self.visit_decorator(decorator) for decorator in node.items] @@ -124,7 +124,7 @@ def visit_overloaded_func_def(self, node: OverloadedFuncDef) -> Node: new.type = self.type(node.type) new.info = node.info return new - + def visit_class_def(self, node: ClassDef) -> Node: new = ClassDef(node.name, self.block(node.defs), @@ -137,20 +137,20 @@ def visit_class_def(self, node: ClassDef) -> Node: for decorator in node.decorators] new.is_builtinclass = node.is_builtinclass return new - + def visit_var_def(self, node: VarDef) -> Node: new = VarDef([self.visit_var(var) for var in node.items], node.is_top_level, self.optional_node(node.init)) new.kind = node.kind return new - + def visit_global_decl(self, node: GlobalDecl) -> Node: return GlobalDecl(node.names[:]) - + def visit_block(self, node: Block) -> Block: return Block(self.nodes(node.body)) - + def visit_decorator(self, node: Decorator) -> Decorator: # Note that a Decorator must be transformed to a Decorator. func = self.visit_func_def(node.func) @@ -159,7 +159,7 @@ def visit_decorator(self, node: Decorator) -> Decorator: self.visit_var(node.var)) new.is_overload = node.is_overload return new - + def visit_var(self, node: Var) -> Var: # Note that a Var must be transformed to a Var. if node in self.var_map: @@ -177,68 +177,71 @@ def visit_var(self, node: Var) -> Var: new.set_line(node.line) self.var_map[node] = new return new - + def visit_expression_stmt(self, node: ExpressionStmt) -> Node: return ExpressionStmt(self.node(node.expr)) - + def visit_assignment_stmt(self, node: AssignmentStmt) -> Node: return self.duplicate_assignment(node) - + def duplicate_assignment(self, node: AssignmentStmt) -> AssignmentStmt: new = AssignmentStmt(self.nodes(node.lvalues), self.node(node.rvalue), self.optional_type(node.type)) new.line = node.line return new - + def visit_operator_assignment_stmt(self, node: OperatorAssignmentStmt) -> Node: return OperatorAssignmentStmt(node.op, self.node(node.lvalue), self.node(node.rvalue)) - + def visit_while_stmt(self, node: WhileStmt) -> Node: return WhileStmt(self.node(node.expr), self.block(node.body), self.optional_block(node.else_body)) - + def visit_for_stmt(self, node: ForStmt) -> Node: return ForStmt(self.names(node.index), self.node(node.expr), self.block(node.body), self.optional_block(node.else_body), self.optional_types(node.types)) - + def visit_return_stmt(self, node: ReturnStmt) -> Node: return ReturnStmt(self.optional_node(node.expr)) - + def visit_assert_stmt(self, node: AssertStmt) -> Node: return AssertStmt(self.node(node.expr)) - + def visit_yield_stmt(self, node: YieldStmt) -> Node: return YieldStmt(self.node(node.expr)) - + + def visit_yield_from_stmt(self, node: YieldFromStmt) -> Node: + return YieldFromStmt(self.node(node.expr)) + def visit_del_stmt(self, node: DelStmt) -> Node: return DelStmt(self.node(node.expr)) - + def visit_if_stmt(self, node: IfStmt) -> Node: return IfStmt(self.nodes(node.expr), self.blocks(node.body), self.optional_block(node.else_body)) - + def visit_break_stmt(self, node: BreakStmt) -> Node: return BreakStmt() - + def visit_continue_stmt(self, node: ContinueStmt) -> Node: return ContinueStmt() - + def visit_pass_stmt(self, node: PassStmt) -> Node: return PassStmt() - + def visit_raise_stmt(self, node: RaiseStmt) -> Node: return RaiseStmt(self.optional_node(node.expr), self.optional_node(node.from_expr)) - + def visit_try_stmt(self, node: TryStmt) -> Node: return TryStmt(self.block(node.body), self.optional_names(node.vars), @@ -246,34 +249,34 @@ def visit_try_stmt(self, node: TryStmt) -> Node: self.blocks(node.handlers), self.optional_block(node.else_body), self.optional_block(node.finally_body)) - + def visit_with_stmt(self, node: WithStmt) -> Node: return WithStmt(self.nodes(node.expr), self.optional_names(node.name), self.block(node.body)) - + def visit_print_stmt(self, node: PrintStmt) -> Node: return PrintStmt(self.nodes(node.args), node.newline) - + def visit_int_expr(self, node: IntExpr) -> Node: return IntExpr(node.value) - + def visit_str_expr(self, node: StrExpr) -> Node: return StrExpr(node.value) - + def visit_bytes_expr(self, node: BytesExpr) -> Node: return BytesExpr(node.value) - + def visit_unicode_expr(self, node: UnicodeExpr) -> Node: return UnicodeExpr(node.value) - + def visit_float_expr(self, node: FloatExpr) -> Node: return FloatExpr(node.value) - + def visit_paren_expr(self, node: ParenExpr) -> Node: return ParenExpr(self.node(node.expr)) - + def visit_name_expr(self, node: NameExpr) -> Node: return self.duplicate_name(node) @@ -284,7 +287,7 @@ def duplicate_name(self, node: NameExpr) -> NameExpr: new.info = node.info self.copy_ref(new, node) return new - + def visit_member_expr(self, node: MemberExpr) -> Node: member = MemberExpr(self.node(node.expr), node.name) @@ -300,47 +303,47 @@ def copy_ref(self, new: RefExpr, original: RefExpr) -> None: if isinstance(target, Var): target = self.visit_var(target) new.node = target - new.is_def = original.is_def - + new.is_def = original.is_def + def visit_call_expr(self, node: CallExpr) -> Node: return CallExpr(self.node(node.callee), self.nodes(node.args), node.arg_kinds[:], node.arg_names[:], self.optional_node(node.analyzed)) - + def visit_op_expr(self, node: OpExpr) -> Node: new = OpExpr(node.op, self.node(node.left), self.node(node.right)) new.method_type = self.optional_type(node.method_type) return new - + def visit_cast_expr(self, node: CastExpr) -> Node: return CastExpr(self.node(node.expr), self.type(node.type)) - + def visit_super_expr(self, node: SuperExpr) -> Node: new = SuperExpr(node.name) new.info = node.info return new - + def visit_unary_expr(self, node: UnaryExpr) -> Node: new = UnaryExpr(node.op, self.node(node.expr)) new.method_type = self.optional_type(node.method_type) return new - + def visit_list_expr(self, node: ListExpr) -> Node: return ListExpr(self.nodes(node.items)) - + def visit_dict_expr(self, node: DictExpr) -> Node: return DictExpr([(self.node(key), self.node(value)) for key, value in node.items]) - + def visit_tuple_expr(self, node: TupleExpr) -> Node: return TupleExpr(self.nodes(node.items)) - + def visit_set_expr(self, node: SetExpr) -> Node: return SetExpr(self.nodes(node.items)) - + def visit_index_expr(self, node: IndexExpr) -> Node: new = IndexExpr(self.node(node.base), self.node(node.index)) if node.method_type: @@ -349,19 +352,19 @@ def visit_index_expr(self, node: IndexExpr) -> Node: new.analyzed = self.visit_type_application(node.analyzed) new.analyzed.set_line(node.analyzed.line) return new - + def visit_undefined_expr(self, node: UndefinedExpr) -> Node: return UndefinedExpr(self.type(node.type)) - + def visit_type_application(self, node: TypeApplication) -> TypeApplication: return TypeApplication(self.node(node.expr), self.types(node.types)) - + def visit_list_comprehension(self, node: ListComprehension) -> Node: generator = self.duplicate_generator(node.generator) generator.set_line(node.generator.line) return ListComprehension(generator) - + def visit_generator_expr(self, node: GeneratorExpr) -> Node: return self.duplicate_generator(node) @@ -371,17 +374,17 @@ def duplicate_generator(self, node: GeneratorExpr) -> GeneratorExpr: [self.optional_types(t) for t in node.types], [self.node(s) for s in node.sequences], self.optional_node(node.condition)) - + def visit_slice_expr(self, node: SliceExpr) -> Node: return SliceExpr(self.optional_node(node.begin_index), self.optional_node(node.end_index), self.optional_node(node.stride)) - + def visit_conditional_expr(self, node: ConditionalExpr) -> Node: return ConditionalExpr(self.node(node.cond), self.node(node.if_expr), self.node(node.else_expr)) - + def visit_type_var_expr(self, node: TypeVarExpr) -> Node: return TypeVarExpr(node.name(), node.fullname(), self.types(node.values)) @@ -391,23 +394,23 @@ def visit_ducktype_expr(self, node: DucktypeExpr) -> Node: def visit_disjointclass_expr(self, node: DisjointclassExpr) -> Node: return DisjointclassExpr(node.cls) - + def visit_coerce_expr(self, node: CoerceExpr) -> Node: raise RuntimeError('Not supported') - + def visit_type_expr(self, node: TypeExpr) -> Node: raise RuntimeError('Not supported') - + def visit_java_cast(self, node: JavaCast) -> Node: raise RuntimeError('Not supported') - + def visit_temp_node(self, node: TempNode) -> Node: return TempNode(self.type(node.type)) def node(self, node: Node) -> Node: new = node.accept(self) new.set_line(node.line) - return new + return new # Helpers # @@ -450,7 +453,7 @@ def optional_names(self, names: List[NameExpr]) -> List[NameExpr]: else: result.append(None) return result - + def type(self, type: Type) -> Type: # Override this method to transform types. return type From fa64f905b4b1c4e081a783399df6952c931ebbc0 Mon Sep 17 00:00:00 2001 From: Rock Neurotiko Date: Tue, 29 Jul 2014 02:46:27 +0200 Subject: [PATCH 03/12] Add YieldFromExpr class and the visitors. Change return to check if it's not yield from --- mypy/checker.py | 221 ++++++++++++++++--------------- mypy/checkexpr.py | 155 +++++++++++----------- mypy/icode.py | 37 +++--- mypy/nodes.py | 9 ++ mypy/output.py | 3 + mypy/parse.py | 7 +- mypy/pprinter.py | 65 ++++----- mypy/semanal.py | 5 +- mypy/stats.py | 13 +- mypy/strconv.py | 3 + mypy/test/data/parse-errors.test | 9 +- mypy/test/data/parse.test | 8 +- mypy/transform.py | 105 ++++++++------- mypy/traverser.py | 5 +- mypy/treetransform.py | 6 +- mypy/visitor.py | 2 + 16 files changed, 355 insertions(+), 298 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index c4db2aca75d4..b616cb1bc2f0 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1,7 +1,7 @@ """Mypy type checker.""" import itertools - + from typing import Undefined, Any, Dict, Set, List, cast, overload, Tuple, Function, typevar from mypy.errors import Errors @@ -16,7 +16,7 @@ TypeApplication, DictExpr, SliceExpr, FuncExpr, TempNode, SymbolTableNode, Context, ListComprehension, ConditionalExpr, GeneratorExpr, Decorator, SetExpr, PassStmt, TypeVarExpr, UndefinedExpr, PrintStmt, - LITERAL_TYPE, BreakStmt, ContinueStmt + LITERAL_TYPE, BreakStmt, ContinueStmt, YieldFromExpr ) from mypy.nodes import function_type, method_type from mypy import nodes @@ -302,7 +302,7 @@ class TypeChecker(NodeVisitor[Type]): binder = Undefined(ConditionalTypeBinder) # Helper for type checking expressions expr_checker = Undefined('mypy.checkexpr.ExpressionChecker') - + # Stack of function return types return_types = Undefined(List[Type]) # Type context for type inference @@ -313,11 +313,11 @@ class TypeChecker(NodeVisitor[Type]): function_stack = Undefined(List[FuncItem]) # Set to True on return/break/raise, False on blocks that can block any of them breaking_out = False - + globals = Undefined(SymbolTable) locals = Undefined(SymbolTable) modules = Undefined(Dict[str, MypyFile]) - + def __init__(self, errors: Errors, modules: Dict[str, MypyFile], pyversion: int = 3) -> None: """Construct a type checker. @@ -338,16 +338,16 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile], self.type_context = [] self.dynamic_funcs = [] self.function_stack = [] - - def visit_file(self, file_node: MypyFile, path: str) -> None: + + def visit_file(self, file_node: MypyFile, path: str) -> None: """Type check a mypy file with the given path.""" self.errors.set_file(path) self.globals = file_node.names self.locals = None - + for d in file_node.defs: self.accept(d) - + def accept(self, node: Node, type_context: Type = None) -> Type: """Type check a node in the given type context.""" self.type_context.append(type_context) @@ -375,7 +375,7 @@ def accept_in_frame(self, node: Node, type_context: Type = None, # # Definitions # - + def visit_var_def(self, defn: VarDef) -> Type: """Type check a variable definition. @@ -409,11 +409,11 @@ def visit_var_def(self, defn: VarDef) -> Type: if (defn.kind == LDEF and not defn.items[0].type and not defn.is_top_level and not self.is_dynamic_function()): self.fail(messages.NEED_ANNOTATION_FOR_VAR, defn) - + def infer_local_variable_type(self, x, y, z): # TODO raise RuntimeError('Not implemented') - + def visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> Type: num_abstract = 0 for fdef in defn.items: @@ -436,7 +436,7 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: if is_unsafe_overlapping_signatures(sig1, sig2): self.msg.overloaded_signatures_overlap(i + 1, j + 2, item.func) - + def visit_func_def(self, defn: FuncDef) -> Type: """Type check a function definition.""" self.check_func_item(defn, name=defn.name()) @@ -447,7 +447,7 @@ def visit_func_def(self, defn: FuncDef) -> Type: if not is_same_type(function_type(defn), function_type(defn.original_def)): self.msg.incompatible_conditional_function_def(defn) - + def check_func_item(self, defn: FuncItem, type_override: Callable = None, name: str = None) -> Type: @@ -463,10 +463,10 @@ def check_func_item(self, defn: FuncItem, self.function_stack.append(defn) self.dynamic_funcs.append(defn.type is None and not type_override) - + if fdef: self.errors.push_function(fdef.name()) - + typ = function_type(defn) if type_override: typ = type_override @@ -474,13 +474,13 @@ def check_func_item(self, defn: FuncItem, self.check_func_def(defn, typ, name) else: raise RuntimeError('Not supported') - + if fdef: self.errors.pop_function() - + self.dynamic_funcs.pop() self.function_stack.pop() - + def check_func_def(self, defn: FuncItem, typ: Callable, name: str) -> None: """Type check a function definition.""" # Expand type variables with value restrictions to ordinary types. @@ -489,7 +489,7 @@ def check_func_def(self, defn: FuncItem, typ: Callable, name: str) -> None: self.binder = ConditionalTypeBinder(self.basic_types) self.binder.push_frame() defn.expanded.append(item) - + # We may be checking a function definition or an anonymous # function. In the first case, set up another reference with the # precise type. @@ -546,7 +546,7 @@ def check_func_def(self, defn: FuncItem, typ: Callable, name: str) -> None: def check_reverse_op_method(self, defn: FuncItem, typ: Callable, method: str) -> None: """Check a reverse operator method such as __radd__.""" - + # If the argument of a reverse operator method such as __radd__ # does not define the corresponding non-reverse method such as __add__ # the return type of __radd__ may not reliably represent the value of @@ -567,8 +567,8 @@ def check_reverse_op_method(self, defn: FuncItem, typ: Callable, if method in ('__eq__', '__ne__'): # These are defined for all objects => can't cause trouble. - return - + return + # With 'Any' or 'object' return type we are happy, since any possible # return value is valid. ret_type = typ.ret_type @@ -581,7 +581,7 @@ def check_reverse_op_method(self, defn: FuncItem, typ: Callable, # in an error elsewhere. if len(typ.arg_types) <= 2: # TODO check self argument kind - + # Check for the issue described above. arg_type = typ.arg_types[1] other_method = nodes.normal_from_reverse_op[method] @@ -663,7 +663,7 @@ def check_overlapping_op_methods(self, [None] * 2, forward_type.ret_type, is_type_obj=False, - name=forward_type.name) + name=forward_type.name) reverse_args = reverse_type.arg_types reverse_tweaked = Callable([reverse_args[1], reverse_args[0]], [nodes.ARG_POS] * 2, @@ -671,7 +671,7 @@ def check_overlapping_op_methods(self, reverse_type.ret_type, is_type_obj=False, name=reverse_type.name) - + if is_unsafe_overlapping_signatures(forward_tweaked, reverse_tweaked): self.msg.operator_method_signatures_overlap( @@ -733,13 +733,13 @@ def expand_typevars(self, defn: FuncItem, return result else: return [(defn, typ)] - + def check_method_override(self, defn: FuncBase) -> None: """Check if function definition is compatible with base classes.""" # Check against definitions in base classes. for base in defn.info.mro[1:]: self.check_method_or_accessor_override_for_base(defn, base) - + def check_method_or_accessor_override_for_base(self, defn: FuncBase, base: TypeInfo) -> None: """Check if method definition is compatible with a base class.""" @@ -789,7 +789,7 @@ def check_method_override_for_base_with_name( assert original_type is not None self.msg.signature_incompatible_with_supertype( defn.name(), name, base.name(), defn) - + def check_override(self, override: FunctionLike, original: FunctionLike, name: str, name_in_super: str, supertype: str, node: Context) -> None: @@ -826,20 +826,20 @@ def check_override(self, override: FunctionLike, original: FunctionLike, # Give more detailed messages for the common case of both # signatures having the same number of arguments and no # overloads. - + coverride = cast(Callable, override) coriginal = cast(Callable, original) - + for i in range(len(coverride.arg_types)): if not is_equivalent(coriginal.arg_types[i], coverride.arg_types[i]): self.msg.argument_incompatible_with_supertype( i + 1, name, name_in_super, supertype, node) - + if not is_subtype(coverride.ret_type, coriginal.ret_type): self.msg.return_type_incompatible_with_supertype( name, name_in_super, supertype, node) - + def visit_class_def(self, defn: ClassDef) -> Type: """Type check a class definition.""" typ = defn.info @@ -902,11 +902,11 @@ def check_compatibility(self, name: str, base1: TypeInfo, if not ok: self.msg.base_class_definitions_incompatible(name, base1, base2, ctx) - + # # Statements # - + def visit_block(self, b: Block) -> Type: if b.is_unreachable: return None @@ -914,7 +914,7 @@ def visit_block(self, b: Block) -> Type: self.accept(s) if self.breaking_out: break - + def visit_assignment_stmt(self, s: AssignmentStmt) -> Type: """Type check an assignment statement. @@ -940,7 +940,7 @@ def check_assignments(self, lvalues: List[Node], index_lvalues = [] # type: List[IndexExpr] # Each may be None inferred = [] # type: List[Var] is_inferred = False - + for lv in lvalues: if self.is_definition(lv): is_inferred = True @@ -988,7 +988,7 @@ def check_assignments(self, lvalues: List[Node], if is_inferred: self.infer_variable_type(inferred, lvalues, self.accept(rvalue), rvalue) - + def is_definition(self, s: Node) -> bool: if isinstance(s, NameExpr): if s.is_def: @@ -1004,7 +1004,7 @@ def is_definition(self, s: Node) -> bool: elif isinstance(s, MemberExpr): return s.is_def return False - + def expand_lvalues(self, n: Node) -> List[Node]: if isinstance(n, TupleExpr): return self.expr_checker.unwrap_list(n.items) @@ -1014,7 +1014,7 @@ def expand_lvalues(self, n: Node) -> List[Node]: return self.expand_lvalues(n.expr) else: return [n] - + def infer_variable_type(self, names: List[Var], lvalues: List[Node], init_type: Type, context: Context) -> None: """Infer the type of initialized variables from initializer type.""" @@ -1026,10 +1026,10 @@ def infer_variable_type(self, names: List[Var], lvalues: List[Node], self.fail(messages.NEED_ANNOTATION_FOR_VAR, context) else: # Infer type of the target. - + # Make the type more general (strip away function names etc.). init_type = strip_type(init_type) - + if len(names) > 1: if isinstance(init_type, TupleType): # Initializer with a tuple type. @@ -1068,7 +1068,7 @@ def set_inferred_type(self, var: Var, lvalue: Node, type: Type) -> None: if var: var.type = type self.store_type(lvalue, type) - + def is_valid_inferred_type(self, typ: Type) -> bool: """Is an inferred type invalid? @@ -1175,7 +1175,7 @@ def check_single_assignment(self, return rvalue_type elif index_lvalue: self.check_indexed_assignment(index_lvalue, rvalue, context) - + def check_indexed_assignment(self, lvalue: IndexExpr, rvalue: Node, context: Context) -> None: """Type check indexed assignment base[index] = rvalue. @@ -1189,10 +1189,10 @@ def check_indexed_assignment(self, lvalue: IndexExpr, self.expr_checker.check_call(method_type, [lvalue.index, rvalue], [nodes.ARG_POS, nodes.ARG_POS], context) - + def visit_expression_stmt(self, s: ExpressionStmt) -> Type: self.accept(s.expr) - + def visit_return_stmt(self, s: ReturnStmt) -> Type: """Type check a return statement.""" self.breaking_out = True @@ -1218,7 +1218,7 @@ def visit_return_stmt(self, s: ReturnStmt) -> Type: if (not isinstance(self.return_types[-1], Void) and not self.is_dynamic_function()): self.fail(messages.RETURN_VALUE_EXPECTED, s) - + def visit_yield_stmt(self, s: YieldStmt) -> Type: return_type = self.return_types[-1] if isinstance(return_type, Instance): @@ -1238,7 +1238,7 @@ def visit_yield_stmt(self, s: YieldStmt) -> Type: self.check_subtype(actual_item_type, expected_item_type, s, messages.INCOMPATIBLE_TYPES_IN_YIELD, 'actual type', 'expected type') - + def visit_if_stmt(self, s: IfStmt) -> Type: """Type check an if statement.""" broken = True @@ -1314,17 +1314,17 @@ def visit_operator_assignment_stmt(self, method = infer_operator_assignment_method(lvalue_type, s.op) rvalue_type, method_type = self.expr_checker.check_op( method, lvalue_type, s.rvalue, s) - + if isinstance(s.lvalue, IndexExpr): lv = cast(IndexExpr, s.lvalue) self.check_single_assignment(None, lv, s.rvalue, s.rvalue) else: if not is_subtype(rvalue_type, lvalue_type): self.msg.incompatible_operator_assignment(s.op, s) - + def visit_assert_stmt(self, s: AssertStmt) -> Type: self.accept(s.expr) - + def visit_raise_stmt(self, s: RaiseStmt) -> Type: """Type check a raise statement.""" self.breaking_out = True @@ -1342,7 +1342,7 @@ def visit_raise_stmt(self, s: RaiseStmt) -> Type: self.check_subtype(typ, self.named_type('builtins.BaseException'), s, messages.INVALID_EXCEPTION) - + def visit_try_stmt(self, s: TryStmt) -> Type: """Type check a try statement.""" completed_frames = List[Frame]() @@ -1432,7 +1432,7 @@ def visit_for_stmt(self, s: ForStmt) -> Type: def analyse_iterable_item_type(self, expr: Node) -> Type: """Analyse iterable expression and return iterator item type.""" iterable = self.accept(expr) - + self.check_not_void(iterable, expr) if isinstance(iterable, TupleType): joined = NoneTyp() # type: Type @@ -1490,7 +1490,7 @@ def analyse_index_variables(self, index: List[NameExpr], self.check_multi_assignment(t, [None] * len(index), self.temp_node(item_type), context, messages.INCOMPATIBLE_TYPES_IN_FOR) - + def visit_del_stmt(self, s: DelStmt) -> Type: if isinstance(s.expr, IndexExpr): e = cast(IndexExpr, s.expr) # Cast @@ -1502,7 +1502,7 @@ def visit_del_stmt(self, s: DelStmt) -> Type: else: s.expr.accept(self) return None - + def visit_decorator(self, e: Decorator) -> Type: e.func.accept(self) sig = function_type(e.func) # type: Type @@ -1532,26 +1532,31 @@ def visit_with_stmt(self, s: WithStmt) -> Type: def visit_print_stmt(self, s: PrintStmt) -> Type: for arg in s.args: - self.accept(arg) - + self.accept(arg) + # # Expressions # - + def visit_name_expr(self, e: NameExpr) -> Type: return self.expr_checker.visit_name_expr(e) - + def visit_paren_expr(self, e: ParenExpr) -> Type: return self.expr_checker.visit_paren_expr(e) - + def visit_call_expr(self, e: CallExpr) -> Type: result = self.expr_checker.visit_call_expr(e) self.breaking_out = False return result - + + def visit_yield_from_expr(self, e: YieldFromExpr) -> Type: + result = self.expr_checker.visit_yield_from_expr(e) + self.breaking_out = False + return result + def visit_member_expr(self, e: MemberExpr) -> Type: return self.expr_checker.visit_member_expr(e) - + def visit_break_stmt(self, s: BreakStmt) -> Type: self.breaking_out = True self.binder.allow_jump(self.binder.loop_frames[-1]-1) @@ -1564,59 +1569,59 @@ def visit_continue_stmt(self, s: ContinueStmt) -> Type: def visit_int_expr(self, e: IntExpr) -> Type: return self.expr_checker.visit_int_expr(e) - + def visit_str_expr(self, e: StrExpr) -> Type: return self.expr_checker.visit_str_expr(e) - + def visit_bytes_expr(self, e: BytesExpr) -> Type: return self.expr_checker.visit_bytes_expr(e) - + def visit_unicode_expr(self, e: UnicodeExpr) -> Type: return self.expr_checker.visit_unicode_expr(e) - + def visit_float_expr(self, e: FloatExpr) -> Type: return self.expr_checker.visit_float_expr(e) - + def visit_op_expr(self, e: OpExpr) -> Type: return self.expr_checker.visit_op_expr(e) - + def visit_unary_expr(self, e: UnaryExpr) -> Type: return self.expr_checker.visit_unary_expr(e) - + def visit_index_expr(self, e: IndexExpr) -> Type: return self.expr_checker.visit_index_expr(e) - + def visit_cast_expr(self, e: CastExpr) -> Type: return self.expr_checker.visit_cast_expr(e) - + def visit_super_expr(self, e: SuperExpr) -> Type: return self.expr_checker.visit_super_expr(e) - + def visit_type_application(self, e: TypeApplication) -> Type: return self.expr_checker.visit_type_application(e) def visit_type_var_expr(self, e: TypeVarExpr) -> Type: # TODO Perhaps return a special type used for type variables only? return AnyType() - + def visit_list_expr(self, e: ListExpr) -> Type: return self.expr_checker.visit_list_expr(e) - + def visit_set_expr(self, e: SetExpr) -> Type: return self.expr_checker.visit_set_expr(e) - + def visit_tuple_expr(self, e: TupleExpr) -> Type: return self.expr_checker.visit_tuple_expr(e) - + def visit_dict_expr(self, e: DictExpr) -> Type: return self.expr_checker.visit_dict_expr(e) - + def visit_slice_expr(self, e: SliceExpr) -> Type: return self.expr_checker.visit_slice_expr(e) - + def visit_func_expr(self, e: FuncExpr) -> Type: return self.expr_checker.visit_func_expr(e) - + def visit_list_comprehension(self, e: ListComprehension) -> Type: return self.expr_checker.visit_list_comprehension(e) @@ -1631,11 +1636,11 @@ def visit_temp_node(self, e: TempNode) -> Type: def visit_conditional_expr(self, e: ConditionalExpr) -> Type: return self.expr_checker.visit_conditional_expr(e) - + # # Helpers # - + def check_subtype(self, subtype: Type, supertype: Type, context: Context, msg: str = messages.INCOMPATIBLE_TYPES, subtype_label: str = None, @@ -1654,7 +1659,7 @@ def check_subtype(self, subtype: Type, supertype: Type, context: Context, if extra_info: msg += ' (' + ', '.join(extra_info) + ')' self.fail(msg, context) - + def named_type(self, name: str) -> Instance: """Return an instance type with type given by the name and no type arguments. For example, named_type('builtins.object') @@ -1663,11 +1668,11 @@ def named_type(self, name: str) -> Instance: # Assume that the name refers to a type. sym = self.lookup_qualified(name) return Instance(cast(TypeInfo, sym.node), []) - + def named_type_if_exists(self, name: str) -> Type: """Return named instance type, or UnboundType if the type was not defined. - + This is used to simplify test cases by avoiding the need to define basic types not needed in specific test cases (tuple etc.). @@ -1678,7 +1683,7 @@ def named_type_if_exists(self, name: str) -> Type: return Instance(cast(TypeInfo, sym.node), []) except KeyError: return UnboundType(name) - + def named_generic_type(self, name: str, args: List[Type]) -> Instance: """Return an instance with the given name and type arguments. @@ -1691,30 +1696,30 @@ def lookup_typeinfo(self, fullname: str) -> TypeInfo: # Assume that the name refers to a class. sym = self.lookup_qualified(fullname) return cast(TypeInfo, sym.node) - + def type_type(self) -> Instance: """Return instance type 'type'.""" return self.named_type('builtins.type') - + def object_type(self) -> Instance: """Return instance type 'object'.""" return self.named_type('builtins.object') - + def bool_type(self) -> Instance: """Return instance type 'bool'.""" return self.named_type('builtins.bool') - + def str_type(self) -> Instance: """Return instance type 'str'.""" return self.named_type('builtins.str') - + def tuple_type(self) -> Type: """Return instance type 'tuple'.""" # We need the tuple for analysing member access. We want to be able to # do this even if tuple type is not available (useful in test cases), # so we return an unbound type if there is no tuple type. return self.named_type_if_exists('builtins.tuple') - + def check_type_equivalency(self, t1: Type, t2: Type, node: Context, msg: str = messages.INCOMPATIBLE_TYPES) -> None: """Generate an error if the types are not equivalent. The @@ -1722,14 +1727,14 @@ def check_type_equivalency(self, t1: Type, t2: Type, node: Context, """ if not is_equivalent(t1, t2): self.fail(msg, node) - + def store_type(self, node: Node, typ: Type) -> None: """Store the type of a node in the type map.""" self.type_map[node] = typ - + def is_dynamic_function(self) -> bool: return len(self.dynamic_funcs) > 0 and self.dynamic_funcs[-1] - + def lookup(self, name: str, kind: int) -> SymbolTableNode: """Look up a definition from the symbol table with the given name. TODO remove kind argument @@ -1745,7 +1750,7 @@ def lookup(self, name: str, kind: int) -> SymbolTableNode: if name in table: return table[name] raise KeyError('Failed lookup: {}'.format(name)) - + def lookup_qualified(self, name: str) -> SymbolTableNode: if '.' not in name: return self.lookup(name, GDEF) # FIX kind @@ -1755,13 +1760,13 @@ def lookup_qualified(self, name: str) -> SymbolTableNode: for i in range(1, len(parts) - 1): n = cast(MypyFile, ((n.names.get(parts[i], None).node))) return n.names[parts[-1]] - + def enter(self) -> None: self.locals = SymbolTable() - + def leave(self) -> None: self.locals = None - + def basic_types(self) -> BasicTypes: """Return a BasicTypes instance that contains primitive types that are needed for certain type operations (joins, for example). @@ -1769,26 +1774,26 @@ def basic_types(self) -> BasicTypes: return BasicTypes(self.object_type(), self.named_type('builtins.type'), self.named_type_if_exists('builtins.tuple'), self.named_type_if_exists('builtins.function')) - + def is_within_function(self) -> bool: """Are we currently type checking within a function? I.e. not at class body or at the top level. """ return self.return_types != [] - + def check_not_void(self, typ: Type, context: Context) -> None: """Generate an error if the type is Void.""" if isinstance(typ, Void): self.msg.does_not_return_value(typ, context) - + def temp_node(self, t: Type, context: Context = None) -> Node: """Create a temporary node with the given, fixed type.""" temp = TempNode(t) if context: temp.set_line(context.get_line()) return temp - + def fail(self, msg: str, context: Context) -> None: """Produce an error message.""" self.msg.fail(msg, context) @@ -1805,12 +1810,12 @@ def map_type_from_supertype(typ: Type, sub_info: TypeInfo, """Map type variables in a type defined in a supertype context to be valid in the subtype context. Assume that the result is unique; if more than one type is possible, return one of the alternatives. - + For example, assume - + class D(Generic[S]) ... class C(D[E[T]], Generic[T]) ... - + Now S in the context of D would be mapped to E[T] in the context of C. """ # Create the type of self in subtype, of form t[a1, ...]. @@ -1904,7 +1909,7 @@ class TypeTransformVisitor(TransformVisitor): def __init__(self, map: Dict[int, Type]) -> None: super().__init__() self.map = map - + def type(self, type: Type) -> Type: return expand_type(type, self.map) @@ -1947,10 +1952,10 @@ def is_unsafe_overlapping_signatures(signature: Type, other: Type) -> bool: if is_same_type(signature.ret_type, other.ret_type): return False # If the first signature has more general argument types, the - # latter will never be called + # latter will never be called if is_more_general_arg_prefix(signature, other): return False - return not is_more_precise_signature(signature, other) + return not is_more_precise_signature(signature, other) return True diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 9ee66f2e3780..02fdf47ddd8f 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -12,7 +12,7 @@ OpExpr, UnaryExpr, IndexExpr, CastExpr, TypeApplication, ListExpr, TupleExpr, DictExpr, FuncExpr, SuperExpr, ParenExpr, SliceExpr, Context, ListComprehension, GeneratorExpr, SetExpr, MypyFile, Decorator, - UndefinedExpr, ConditionalExpr, TempNode, LITERAL_TYPE + UndefinedExpr, ConditionalExpr, TempNode, LITERAL_TYPE, YieldFromExpr ) from mypy.errors import Errors from mypy.nodes import function_type, method_type @@ -38,19 +38,19 @@ class ExpressionChecker: This class works closely together with checker.TypeChecker. """ - + # Some services are provided by a TypeChecker instance. chk = Undefined('mypy.checker.TypeChecker') # This is shared with TypeChecker, but stored also here for convenience. msg = Undefined(MessageBuilder) - + def __init__(self, chk: 'mypy.checker.TypeChecker', msg: MessageBuilder) -> None: """Construct an expression type checker.""" self.chk = chk self.msg = msg - + def visit_name_expr(self, e: NameExpr) -> Type: """Type check a name expression. @@ -58,7 +58,7 @@ def visit_name_expr(self, e: NameExpr) -> Type: """ result = self.analyse_ref_expr(e) return self.chk.narrow_type_from_binder(e, result) - + def analyse_ref_expr(self, e: RefExpr) -> Type: result = Undefined(Type) node = e.node @@ -97,7 +97,10 @@ def analyse_var_ref(self, var: Var, context: Context) -> Type: return var.type else: return val - + + def visit_yield_from_expr(self, e: YieldFromExpr) -> Type: + return self.visit_call_expr(e.callee) + def visit_call_expr(self, e: CallExpr) -> Type: """Type check a call expression.""" if e.analyzed: @@ -109,7 +112,7 @@ def visit_call_expr(self, e: CallExpr) -> Type: # way we get a more precise callee in dynamically typed functions. callee_type = self.chk.type_map[e.callee] return self.check_call_expr_with_callee_type(callee_type, e) - + def check_call_expr_with_callee_type(self, callee_type: Type, e: CallExpr) -> Type: """Type check call expression. @@ -119,7 +122,7 @@ def check_call_expr_with_callee_type(self, callee_type: Type, """ return self.check_call(callee_type, e.args, e.arg_kinds, e, e.arg_names, callable_node=e.callee)[0] - + def check_call(self, callee: Type, args: List[Node], arg_kinds: List[int], context: Context, arg_names: List[str] = None, @@ -151,24 +154,24 @@ def check_call(self, callee: Type, args: List[Node], self.msg.cannot_instantiate_abstract_class( callee.type_object().name(), type.abstract_attributes, context) - + formal_to_actual = map_actuals_to_formals( arg_kinds, arg_names, callee.arg_kinds, callee.arg_names, lambda i: self.accept(args[i])) - + if callee.is_generic(): callee = self.infer_function_type_arguments_using_context( callee, context) callee = self.infer_function_type_arguments( callee, args, arg_kinds, formal_to_actual, context) - + arg_types = self.infer_arg_types_in_context2( callee, args, arg_kinds, formal_to_actual) self.check_argument_count(callee, arg_types, arg_kinds, arg_names, formal_to_actual, context) - + self.check_argument_types(arg_types, arg_kinds, callee, formal_to_actual, context, messages=arg_messages) @@ -183,7 +186,7 @@ def check_call(self, callee: Type, args: List[Node], self.msg.disable_errors() arg_types = self.infer_arg_types_in_context(None, args) self.msg.enable_errors() - + target = self.overload_call_target(arg_types, is_var_arg, callee, context, messages=arg_messages) @@ -202,7 +205,7 @@ def check_call(self, callee: Type, args: List[Node], callee) else: return self.msg.not_callable(callee, context), AnyType() - + def infer_arg_types_in_context(self, callee: Callable, args: List[Node]) -> List[Type]: """Infer argument expression types using a callable type as context. @@ -212,7 +215,7 @@ def infer_arg_types_in_context(self, callee: Callable, """ # TODO Always called with callee as None, i.e. empty context. res = [] # type: List[Type] - + fixed = len(args) if callee: fixed = min(fixed, callee.max_fixed_args()) @@ -234,7 +237,7 @@ def infer_arg_types_in_context(self, callee: Callable, else: res.append(arg_type) return res - + def infer_arg_types_in_context2( self, callee: Callable, args: List[Node], arg_kinds: List[int], formal_to_actual: List[List[int]]) -> List[Type]: @@ -257,7 +260,7 @@ def infer_arg_types_in_context2( if not t: res[i] = self.accept(args[i]) return res - + def infer_function_type_arguments_using_context( self, callable: Callable, error_context: Context) -> Callable: """Unify callable return type to type context to infer type vars. @@ -285,7 +288,7 @@ def infer_function_type_arguments_using_context( new_args.append(arg) return cast(Callable, self.apply_generic_arguments(callable, new_args, error_context)) - + def infer_function_type_arguments(self, callee_type: Callable, args: List[Node], arg_kinds: List[int], @@ -304,10 +307,10 @@ def infer_function_type_arguments(self, callee_type: Callable, # these errors can be safely ignored as the arguments will be # inferred again later. self.msg.disable_errors() - + arg_types = self.infer_arg_types_in_context2( callee_type, args, arg_kinds, formal_to_actual) - + self.msg.enable_errors() arg_pass_nums = self.get_arg_infer_passes( @@ -319,7 +322,7 @@ def infer_function_type_arguments(self, callee_type: Callable, pass1_args.append(None) else: pass1_args.append(arg) - + inferred_args = infer_function_type_arguments( callee_type, pass1_args, arg_kinds, formal_to_actual, self.chk.basic_types()) # type: List[Type] @@ -390,7 +393,7 @@ def get_arg_infer_passes(self, arg_types: List[Type], for j in formal_to_actual[i]: res[j] = 2 return res - + def apply_inferred_arguments(self, callee_type: Callable, inferred_args: List[Type], context: Context) -> Callable: @@ -464,7 +467,7 @@ def check_argument_count(self, callee: Callable, actual_types: List[Type], actual_kinds[formal_to_actual[i][0]] != nodes.ARG_NAMED): # Positional argument when expecting a keyword argument. self.msg.too_many_positional_arguments(callee, context) - + def check_argument_types(self, arg_types: List[Type], arg_kinds: List[int], callee: Callable, formal_to_actual: List[List[int]], @@ -494,7 +497,7 @@ def check_argument_types(self, arg_types: List[Type], arg_kinds: List[int], self.check_arg(actual_type, arg_type, callee.arg_types[i], actual + 1, callee, context, messages) - + # There may be some remaining tuple varargs items that haven't # been checked yet. Handle them. if (callee.arg_kinds[i] == nodes.ARG_STAR and @@ -508,7 +511,7 @@ def check_argument_types(self, arg_types: List[Type], arg_kinds: List[int], self.check_arg(actual_type, arg_type, callee.arg_types[i], actual + 1, callee, context, messages) - + def check_arg(self, caller_type: Type, original_caller_type: Type, callee_type: Type, n: int, callee: Callable, context: Context, messages: MessageBuilder) -> None: @@ -518,7 +521,7 @@ def check_arg(self, caller_type: Type, original_caller_type: Type, elif not is_subtype(caller_type, callee_type): messages.incompatible_argument(n, callee, original_caller_type, context) - + def overload_call_target(self, arg_types: List[Type], is_var_arg: bool, overload: Overloaded, context: Context, messages: MessageBuilder = None) -> Type: @@ -565,7 +568,7 @@ def overload_call_target(self, arg_types: List[Type], is_var_arg: bool, if self.match_signature_types(arg_types, is_var_arg, m): return m return match[0] - + def matches_signature_erased(self, arg_types: List[Type], is_var_arg: bool, callee: Callable) -> bool: """Determine whether arguments could match the signature at runtime. @@ -575,7 +578,7 @@ def matches_signature_erased(self, arg_types: List[Type], is_var_arg: bool, """ if not is_valid_argc(len(arg_types), False, callee): return False - + if is_var_arg: if not self.is_valid_var_arg(arg_types[-1]): return False @@ -596,7 +599,7 @@ def matches_signature_erased(self, arg_types: List[Type], is_var_arg: bool, self.erase(callee.arg_types[func_fixed])): return False return True - + def match_signature_types(self, arg_types: List[Type], is_var_arg: bool, callee: Callable) -> bool: """Determine whether arguments types match the signature. @@ -620,7 +623,7 @@ def match_signature_types(self, arg_types: List[Type], is_var_arg: bool, callee.arg_types[func_fixed]): return False return True - + def apply_generic_arguments(self, callable: Callable, types: List[Type], context: Context) -> Type: """Apply generic type arguments to a callable type. @@ -628,7 +631,7 @@ def apply_generic_arguments(self, callable: Callable, types: List[Type], For example, applying [int] to 'def [T] (T) -> T' results in 'def [-1:int] (int) -> int'. Here '[-1:int]' is an implicit bound type variable. - + Note that each type can be None; in this case, it will not be applied. """ tvars = callable.variables @@ -636,7 +639,7 @@ def apply_generic_arguments(self, callable: Callable, types: List[Type], self.msg.incompatible_type_application(len(tvars), len(types), context) return AnyType() - + # Check that inferred type variable values are compatible with allowed # values. Also, promote subtype values to allowed values. types = types[:] @@ -661,14 +664,14 @@ def apply_generic_arguments(self, callable: Callable, types: List[Type], # Apply arguments to argument types. arg_types = [expand_type(at, id_to_type) for at in callable.arg_types] - + bound_vars = [(tv.id, id_to_type[tv.id]) for tv in tvars if tv.id in id_to_type] # The callable may retain some type vars if only some were applied. remaining_tvars = [tv for tv in tvars if tv.id not in id_to_type] - + return Callable(arg_types, callable.arg_kinds, callable.arg_names, @@ -678,7 +681,7 @@ def apply_generic_arguments(self, callable: Callable, types: List[Type], remaining_tvars, callable.bound_vars + bound_vars, callable.line, callable.repr) - + def apply_generic_arguments2(self, overload: Overloaded, types: List[Type], context: Context) -> Type: items = [] # type: List[Callable] @@ -690,7 +693,7 @@ def apply_generic_arguments2(self, overload: Overloaded, types: List[Type], # There was an error. return AnyType() return Overloaded(items) - + def visit_member_expr(self, e: MemberExpr) -> Type: """Visit member expression (of form e.id).""" result = self.analyse_ordinary_member_access(e, False) @@ -707,7 +710,7 @@ def analyse_ordinary_member_access(self, e: MemberExpr, return analyse_member_access(e.name, self.accept(e.expr), e, is_lvalue, False, self.chk.basic_types(), self.msg) - + def analyse_external_member_access(self, member: str, base_type: Type, context: Context) -> Type: """Analyse member access that is external, i.e. it cannot @@ -716,27 +719,27 @@ def analyse_external_member_access(self, member: str, base_type: Type, # TODO remove; no private definitions in mypy return analyse_member_access(member, base_type, context, False, False, self.chk.basic_types(), self.msg) - + def visit_int_expr(self, e: IntExpr) -> Type: """Type check an integer literal (trivial).""" return self.named_type('builtins.int') - + def visit_str_expr(self, e: StrExpr) -> Type: """Type check a string literal (trivial).""" return self.named_type('builtins.str') - + def visit_bytes_expr(self, e: BytesExpr) -> Type: """Type check a bytes literal (trivial).""" return self.named_type('builtins.bytes') - + def visit_unicode_expr(self, e: UnicodeExpr) -> Type: """Type check a unicode literal (trivial).""" return self.named_type('builtins.unicode') - + def visit_float_expr(self, e: FloatExpr) -> Type: """Type check a float literal (trivial).""" return self.named_type('builtins.float') - + def visit_op_expr(self, e: OpExpr) -> Type: """Type check a binary operator expression.""" if e.op == 'and' or e.op == 'or': @@ -783,7 +786,7 @@ def get_operator_method(self, op: str) -> str: return '__div__' else: return nodes.op_methods[op] - + def check_op_local(self, method: str, base_type: Type, arg: Node, context: Context, local_errors: MessageBuilder) -> Tuple[Type, Type]: """Type check a binary operation which maps to a method call. @@ -850,12 +853,12 @@ def get_reverse_op_method(self, method: str) -> str: return '__rdiv__' else: return nodes.reverse_op_methods[method] - + def check_boolean_op(self, e: OpExpr, context: Context) -> Type: """Type check a boolean operation ('and' or 'or').""" # A boolean operation can evaluate to either of the operands. - + # We use the current type context to guide the type inference of of # the left operand. We also use the left operand type to guide the type # inference of the right operand so that expressions such as @@ -863,7 +866,7 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type: ctx = self.chk.type_context[-1] left_type = self.accept(e.left, ctx) right_type = self.accept(e.right, left_type) - + self.check_not_void(left_type, context) self.check_not_void(right_type, context) @@ -885,7 +888,7 @@ def check_list_multiply(self, e: OpExpr) -> Type: result, method_type = self.check_op('__mul__', left_type, e.right, e) e.method_type = method_type return result - + def visit_unary_expr(self, e: UnaryExpr) -> Type: """Type check an unary operation ('not', '-', '+' or '~').""" operand_type = self.accept(e.expr) @@ -960,13 +963,13 @@ def visit_cast_expr(self, expr: CastExpr) -> Type: if not self.is_valid_cast(source_type, target_type): self.msg.invalid_cast(target_type, source_type, expr) return target_type - + def is_valid_cast(self, source_type: Type, target_type: Type) -> bool: """Is a cast from source_type to target_type meaningful?""" return (isinstance(target_type, AnyType) or (not isinstance(source_type, Void) and not isinstance(target_type, Void))) - + def visit_type_application(self, tapp: TypeApplication) -> Type: """Type check a type application (expr[type, ...]).""" expr_type = self.accept(tapp.expr) @@ -985,7 +988,7 @@ def visit_type_application(self, tapp: TypeApplication) -> Type: new_type = AnyType() self.chk.type_map[tapp.expr] = new_type return new_type - + def visit_list_expr(self, e: ListExpr) -> Type: """Type check a list expression [...].""" return self.check_list_or_set_expr(e.items, 'builtins.list', '', @@ -1030,7 +1033,7 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type: self.check_not_void(tt, e) items.append(tt) return TupleType(items) - + def visit_dict_expr(self, e: DictExpr) -> Type: # Translate into type checking a generic function call. tv1 = TypeVar('KT', -1, []) @@ -1055,7 +1058,7 @@ def visit_dict_expr(self, e: DictExpr) -> Type: return self.check_call(constructor, args, [nodes.ARG_POS] * len(args), e)[0] - + def visit_func_expr(self, e: FuncExpr) -> Type: """Type check lambda expression.""" inferred_type = self.infer_lambda_type_using_context(e) @@ -1085,28 +1088,28 @@ def infer_lambda_type_using_context(self, e: FuncExpr) -> Callable: ctx = self.chk.type_context[-1] if not ctx or not isinstance(ctx, Callable): return None - + # The context may have function type variables in it. We replace them # since these are the type variables we are ultimately trying to infer; # they must be considered as indeterminate. We use ErasedType since it # does not affect type inference results (it is for purposes like this # only). ctx = replace_func_type_vars(ctx, ErasedType()) - + callable_ctx = cast(Callable, ctx) - + if callable_ctx.arg_kinds != e.arg_kinds: # Incompatible context; cannot use it to infer types. self.chk.fail(messages.CANNOT_INFER_LAMBDA_TYPE, e) return None - + return callable_ctx - + def visit_super_expr(self, e: SuperExpr) -> Type: """Type check a super expression (non-lvalue).""" t = self.analyse_super(e, False) return t - + def analyse_super(self, e: SuperExpr, is_lvalue: bool) -> Type: """Type check a super expression.""" if e.info and e.info.bases: @@ -1118,11 +1121,11 @@ def analyse_super(self, e: SuperExpr, is_lvalue: bool) -> Type: else: # Invalid super. This has been reported by the semantic analyser. return AnyType() - + def visit_paren_expr(self, e: ParenExpr) -> Type: """Type check a parenthesised expression.""" return self.accept(e.expr, self.chk.type_context[-1]) - + def visit_slice_expr(self, e: SliceExpr) -> Type: for index in [e.begin_index, e.end_index, e.stride]: if index: @@ -1138,12 +1141,12 @@ def visit_list_comprehension(self, e: ListComprehension) -> Type: def visit_generator_expr(self, e: GeneratorExpr) -> Type: return self.check_generator_or_comprehension(e, 'typing.Iterator', '') - + def check_generator_or_comprehension(self, gen: GeneratorExpr, type_name: str, id_for_messages: str) -> Type: """Type check a generator expression or a list comprehension.""" - + self.chk.binder.push_frame() for index, sequence in zip(gen.indices, gen.sequences): sequence_type = self.chk.analyse_iterable_item_type(sequence) @@ -1174,41 +1177,41 @@ def visit_conditional_expr(self, e: ConditionalExpr) -> Type: if_type = self.accept(e.if_expr) else_type = self.accept(e.else_expr, context=if_type) return join.join_types(if_type, else_type, self.chk.basic_types()) - + # # Helpers # - + def accept(self, node: Node, context: Type = None) -> Type: """Type check a node. Alias for TypeChecker.accept.""" return self.chk.accept(node, context) - + def check_not_void(self, typ: Type, context: Context) -> None: """Generate an error if type is Void.""" self.chk.check_not_void(typ, context) - + def is_boolean(self, typ: Type) -> bool: """Is type compatible with bool?""" return is_subtype(typ, self.chk.bool_type()) - + def named_type(self, name: str) -> Instance: """Return an instance type with type given by the name and no type arguments. Alias for TypeChecker.named_type. """ return self.chk.named_type(name) - + def is_valid_var_arg(self, typ: Type) -> bool: """Is a type valid as a *args argument?""" return (isinstance(typ, TupleType) or is_subtype(typ, self.chk.named_generic_type('typing.Iterable', [AnyType()])) or isinstance(typ, AnyType)) - - def is_valid_keyword_var_arg(self, typ: Type) -> bool: + + def is_valid_keyword_var_arg(self, typ: Type) -> bool: """Is a type valid as a **kwargs argument?""" return is_subtype(typ, self.chk.named_generic_type( 'builtins.dict', [self.named_type('builtins.str'), AnyType()])) - + def has_non_method(self, typ: Type, member: str) -> bool: """Does type have a member variable / property with the given name?""" if isinstance(typ, Instance): @@ -1216,7 +1219,7 @@ def has_non_method(self, typ: Type, member: str) -> bool: typ.type.has_readable_member(member)) else: return False - + def has_member(self, typ: Type, member: str) -> bool: """Does type have member with the given name?""" # TODO TupleType => also consider tuple attributes @@ -1229,14 +1232,14 @@ def has_member(self, typ: Type, member: str) -> bool: return result else: return False - + def unwrap(self, e: Node) -> Node: """Unwrap parentheses from an expression node.""" if isinstance(e, ParenExpr): return self.unwrap(e.expr) else: return e - + def unwrap_list(self, a: List[Node]) -> List[Node]: """Unwrap parentheses from a list of expression nodes.""" r = List[Node]() @@ -1367,7 +1370,7 @@ class ArgInferSecondPassQuery(types.TypeQuery): The result is True if the type has a type variable in a callable return type anywhere. For example, the result for Function[[], T] is True if t is a type variable. - """ + """ def __init__(self) -> None: super().__init__(False, types.ANY_TYPE_STRATEGY) diff --git a/mypy/icode.py b/mypy/icode.py index 7e9ae962ea7e..d44a8570db22 100644 --- a/mypy/icode.py +++ b/mypy/icode.py @@ -7,7 +7,7 @@ FuncDef, IntExpr, MypyFile, ReturnStmt, NameExpr, WhileStmt, AssignmentStmt, Node, Var, OpExpr, Block, CallExpr, IfStmt, ParenExpr, UnaryExpr, ExpressionStmt, CoerceExpr, ClassDef, MemberExpr, TypeInfo, - VarDef, SuperExpr, IndexExpr, UndefinedExpr + VarDef, SuperExpr, IndexExpr, UndefinedExpr, YieldFromExpr ) from mypy import nodes from mypy.visitor import NodeVisitor @@ -180,7 +180,7 @@ class Return(Opcode): """Return from function (return rN).""" def __init__(self, retval: int) -> None: self.retval = retval - + def is_exit(self) -> bool: return True @@ -189,10 +189,10 @@ def __str__(self) -> str: class Branch(Opcode): - """Abstract base class for branch opcode.""" + """Abstract base class for branch opcode.""" true_block = Undefined # type: BasicBlock false_block = Undefined # type: BasicBlock - + def is_exit(self) -> bool: return True @@ -203,7 +203,7 @@ def invert(self) -> None: class IfOp(Branch): inversion = {'==': '!=', '!=': '==', '<': '>=', '<=': '>', '>': '<=', '>=': '<'} - + """Conditional operator branch (e.g. if r0 < r1 goto L2 else goto L3).""" def __init__(self, left: int, left_kind: int, @@ -232,7 +232,7 @@ def __str__(self) -> str: class IfR(Branch): """Conditional value branch (if rN goto LN else goto LN). """ negated = False - + def __init__(self, value: int, true_block: BasicBlock, false_block: BasicBlock) -> None: self.value = value @@ -259,7 +259,7 @@ class Goto(Opcode): """Unconditional jump (goto LN).""" def __init__(self, next_block: BasicBlock) -> None: self.next_block = next_block - + def is_exit(self) -> bool: return True @@ -307,7 +307,7 @@ class IcodeBuilder(NodeVisitor[int]): """Generate icode from a parse tree.""" generated = Undefined(Dict[str, FuncIcode]) - + # List of generated blocks in the current scope blocks = Undefined(List[BasicBlock]) # Current basic block @@ -335,9 +335,9 @@ def visit_mypy_file(self, mfile: MypyFile) -> int: # These module are special; their contents are currently all # built-in primitives. return -1 - + self.enter() - + # Initialize non-int global variables. for name in sorted(mfile.names): node = mfile.names[name].node @@ -349,7 +349,7 @@ def visit_mypy_file(self, mfile: MypyFile) -> int: tmp = self.alloc_register() self.add(SetRNone(tmp)) self.add(SetGR(v.fullname(), tmp)) - + for d in mfile.defs: d.accept(self) self.add_implicit_return() @@ -362,7 +362,7 @@ def visit_func_def(self, fdef: FuncDef) -> int: if fdef.name().endswith('*'): # Wrapper functions are not supported yet. return -1 - + self.enter() for arg in fdef.args: @@ -374,12 +374,12 @@ def visit_func_def(self, fdef: FuncDef) -> int: name = '%s.%s' % (fdef.info.name(), fdef.name()) else: name = fdef.name() - + self.generated[name] = FuncIcode(len(fdef.args), self.blocks, self.register_types) self.leave() - + return -1 def add_implicit_return(self, sig: FunctionLike = None) -> None: @@ -400,7 +400,7 @@ def visit_class_def(self, tdef: ClassDef) -> int: # Generate icode for the function that constructs an instance. self.make_class_constructor(tdef) - + return -1 def make_class_constructor(self, tdef: ClassDef) -> None: @@ -409,7 +409,7 @@ def make_class_constructor(self, tdef: ClassDef) -> None: init_argc = len(init.args) - 1 if init.info.fullname() == 'builtins.object': init = None - + self.enter() if init: args = [] # type: List[int] @@ -649,6 +649,9 @@ def visit_unary_expr(self, e: UnaryExpr) -> int: self.add(CallMethod(target, operand, method, inst.type, [])) return target + def visit_yield_from_expr(self, e: YieldFromExpr) -> int: + return self.visit_call_expr(e.callee) + def visit_call_expr(self, e: CallExpr) -> int: args = [] # type: List[int] for arg in e.args: @@ -818,7 +821,7 @@ def add_local(self, node: Var) -> int: reg = self.alloc_register(type) self.lvar_regs[node] = reg return reg - + def set_branches(self, branches: List[Branch], condition: bool, target: BasicBlock) -> None: """Set branch targets for the given condition (True or False). diff --git a/mypy/nodes.py b/mypy/nodes.py index 640fea44c7bb..1f694d2950eb 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -926,6 +926,15 @@ def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_call_expr(self) +class YieldFromExpr(Node): + callee = Undefined(Node) + + def __init__(self, callee: Node) -> None: + self.callee = callee + + def accept(self, visitor: NodeVisitor[T]) -> T: + return visitor.visit_yield_from_expr(self) + class IndexExpr(Node): """Index expression x[y]. diff --git a/mypy/output.py b/mypy/output.py index 522eae758c2b..59f0243fc785 100644 --- a/mypy/output.py +++ b/mypy/output.py @@ -359,6 +359,9 @@ def visit_slice_expr(self, o): self.token(o.repr.colon2) self.node(o.stride) + def visit_yield_from_expr(self, o): + self.visit_call_expr(o.callee) + def visit_call_expr(self, o): r = o.repr self.node(o.callee) diff --git a/mypy/parse.py b/mypy/parse.py index f591dfb1b736..4794b1c1e9a1 100644 --- a/mypy/parse.py +++ b/mypy/parse.py @@ -23,7 +23,8 @@ TupleExpr, GeneratorExpr, ListComprehension, ListExpr, ConditionalExpr, DictExpr, SetExpr, NameExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr, FloatExpr, CallExpr, SuperExpr, MemberExpr, IndexExpr, SliceExpr, OpExpr, - UnaryExpr, FuncExpr, TypeApplication, PrintStmt, ImportBase, YieldFromStmt + UnaryExpr, FuncExpr, TypeApplication, PrintStmt, ImportBase, YieldFromStmt, + YieldFromExpr ) from mypy import nodes from mypy import noderepr @@ -740,6 +741,8 @@ def parse_return_stmt(self) -> ReturnStmt: expr = None # type: Node if not isinstance(self.current(), Break): expr = self.parse_expression() + if isinstance(expr, YieldFromExpr): #cant go a yield from expr + return None br = self.expect_break() node = ReturnStmt(expr) self.set_repr(node, noderepr.SimpleStmtRepr(return_tok, br)) @@ -788,7 +791,7 @@ def parse_yield_from_expr(self) -> CallExpr: y_tok = self.expect("yield") f_tok = self.expect("from") tok = self.parse_expression() # Here comes when yield from is assigned to a variable - return tok + return YieldFromExpr(tok) def parse_del_stmt(self) -> DelStmt: del_tok = self.expect('del') diff --git a/mypy/pprinter.py b/mypy/pprinter.py index a82dffb38403..22d4187a7f78 100644 --- a/mypy/pprinter.py +++ b/mypy/pprinter.py @@ -26,7 +26,7 @@ def __init__(self) -> None: def output(self) -> str: return ''.join(self.result) - + # # Definitions # @@ -34,7 +34,7 @@ def output(self) -> str: def visit_mypy_file(self, file: MypyFile) -> None: for d in file.defs: d.accept(self) - + def visit_class_def(self, tdef: ClassDef) -> None: self.string('class ') self.string(tdef.name) @@ -54,7 +54,7 @@ def visit_class_def(self, tdef: ClassDef) -> None: for d in tdef.defs.body: d.accept(self) self.dedent() - + def visit_func_def(self, fdef: FuncDef) -> None: # FIX varargs, default args, keyword args etc. ftyp = cast(Callable, fdef.type) @@ -74,7 +74,7 @@ def visit_func_def(self, fdef: FuncDef) -> None: self.string(') -> ') self.type(ftyp.ret_type) fdef.body.accept(self) - + def visit_var_def(self, vdef: VarDef) -> None: if vdef.items[0].name() not in nodes.implicit_module_attrs: self.string(vdef.items[0].name()) @@ -84,7 +84,7 @@ def visit_var_def(self, vdef: VarDef) -> None: self.string(' = ') self.node(vdef.init) self.string('\n') - + # # Statements # @@ -97,17 +97,17 @@ def visit_block(self, b): def visit_pass_stmt(self, o): self.string('pass\n') - + def visit_return_stmt(self, o): self.string('return ') if o.expr: self.node(o.expr) self.string('\n') - + def visit_expression_stmt(self, o): self.node(o.expr) self.string('\n') - + def visit_assignment_stmt(self, o): if isinstance(o.rvalue, CallExpr) and isinstance(o.rvalue.analyzed, TypeVarExpr): @@ -140,11 +140,11 @@ def visit_while_stmt(self, o): if o.else_body: self.string('else') self.node(o.else_body) - + # # Expressions # - + def visit_call_expr(self, o): if o.analyzed: o.analyzed.accept(self) @@ -157,16 +157,19 @@ def visit_call_expr(self, o): if i < len(o.args) - 1: self.string(', ') self.string(')') - + + def visit_yield_from_expr(self, o): + self.visit_call_expr(o.callee) + def visit_member_expr(self, o): self.node(o.expr) self.string('.' + o.name) if o.direct: self.string('!') - + def visit_name_expr(self, o): self.string(o.name) - + def visit_coerce_expr(self, o: CoerceExpr) -> None: self.string('{') self.full_type(o.target_type) @@ -176,14 +179,14 @@ def visit_coerce_expr(self, o: CoerceExpr) -> None: self.string(' ') self.node(o.expr) self.string('}') - + def visit_type_expr(self, o: TypeExpr) -> None: # Type expressions are only generated during transformation, so we must # use automatic formatting. self.string('<') self.full_type(o.type) self.string('>') - + def visit_index_expr(self, o): if o.analyzed: o.analyzed.accept(self) @@ -195,7 +198,7 @@ def visit_index_expr(self, o): def visit_int_expr(self, o): self.string(str(o.value)) - + def visit_str_expr(self, o): self.string(repr(o.value)) @@ -214,7 +217,7 @@ def visit_paren_expr(self, o): self.string('(') self.node(o.expr) self.string(')') - + def visit_super_expr(self, o): self.string('super().') self.string(o.name) @@ -233,7 +236,7 @@ def visit_type_application(self, o): def visit_undefined_expr(self, o): # Omit declared type as redundant. self.string('Undefined') - + # # Helpers # @@ -257,13 +260,13 @@ def last_output_char(self) -> str: if self.result: return self.result[-1][-1] return '' - + def type(self, t): """Pretty-print a type with erased type arguments.""" if t: v = TypeErasedPrettyPrintVisitor() self.string(t.accept(v)) - + def full_type(self, t): """Pretty-print a type, includingn type arguments.""" if t: @@ -279,19 +282,19 @@ class TypeErasedPrettyPrintVisitor(TypeVisitor[str]): Note that the translation does not preserve all information about the types, but this is fine since this is only used in test case output. """ - + def visit_any(self, t): return 'Any' - + def visit_void(self, t): return 'None' - + def visit_instance(self, t): return t.type.name() - + def visit_type_var(self, t): return 'Any*' - + def visit_runtime_type_var(self, t): v = PrettyPrintVisitor() t.node.accept(v) @@ -302,27 +305,27 @@ class TypePrettyPrintVisitor(TypeVisitor[str]): """Pretty-print types. Include type variables. - + Note that the translation does not preserve all information about the types, but this is fine since this is only used in test case output. """ - + def visit_any(self, t): return 'Any' - + def visit_void(self, t): return 'None' - + def visit_instance(self, t): s = t.type.name() if t.args: argstr = ', '.join([a.accept(self) for a in t.args]) s += '[%s]' % argstr return s - + def visit_type_var(self, t): return 'Any*' - + def visit_runtime_type_var(self, t): v = PrettyPrintVisitor() t.node.accept(v) diff --git a/mypy/semanal.py b/mypy/semanal.py index 2b99abd36ee3..0d4b1cb3168e 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -56,7 +56,7 @@ SymbolTableNode, TVAR, UNBOUND_TVAR, ListComprehension, GeneratorExpr, FuncExpr, MDEF, FuncBase, Decorator, SetExpr, UndefinedExpr, TypeVarExpr, StrExpr, PrintStmt, ConditionalExpr, DucktypeExpr, DisjointclassExpr, - ARG_POS, ARG_NAMED, MroError, type_aliases, YieldFromStmt + ARG_POS, ARG_NAMED, MroError, type_aliases, YieldFromStmt, YieldFromExpr ) from mypy.visitor import NodeVisitor from mypy.traverser import TraverserVisitor @@ -1141,6 +1141,9 @@ def visit_dict_expr(self, expr: DictExpr) -> None: def visit_paren_expr(self, expr: ParenExpr) -> None: expr.expr.accept(self) + def visit_yield_from_expr(self, expr: YieldFromExpr) -> None: + self.visit_call_expr(expr.callee) + def visit_call_expr(self, expr: CallExpr) -> None: """Analyze a call expression. diff --git a/mypy/stats.py b/mypy/stats.py index 199d2d48237a..0a87904adf4b 100644 --- a/mypy/stats.py +++ b/mypy/stats.py @@ -5,7 +5,7 @@ import re from typing import Any, Dict, List, cast, Tuple - + from mypy.traverser import TraverserVisitor from mypy.types import ( Type, AnyType, Instance, FunctionLike, TupleType, Void, TypeVar, @@ -14,7 +14,7 @@ from mypy import nodes from mypy.nodes import ( Node, FuncDef, TypeApplication, AssignmentStmt, NameExpr, CallExpr, - MemberExpr, OpExpr, IndexExpr, UnaryExpr + MemberExpr, OpExpr, IndexExpr, UnaryExpr, YieldFromExpr ) @@ -29,7 +29,7 @@ def __init__(self, inferred: bool, typemap: Dict[Node, Type] = None, self.inferred = inferred self.typemap = typemap self.all_nodes = all_nodes - + self.num_precise = 0 self.num_imprecise = 0 self.num_any = 0 @@ -46,9 +46,9 @@ def __init__(self, inferred: bool, typemap: Dict[Node, Type] = None, self.line_map = Dict[int, int]() self.output = List[str]() - + TraverserVisitor.__init__(self) - + def visit_func_def(self, o: FuncDef) -> None: self.line = o.line if len(o.expanded) > 1: @@ -109,6 +109,9 @@ def visit_name_expr(self, o: NameExpr) -> None: self.process_node(o) super().visit_name_expr(o) + def visit_yield_from_expr(self, o: YieldFromExpr) -> None: + self.visit_call_expr(o.callee) + def visit_call_expr(self, o: CallExpr) -> None: self.process_node(o) if o.analyzed: diff --git a/mypy/strconv.py b/mypy/strconv.py index c55862892f3f..e835048d083d 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -324,6 +324,9 @@ def visit_member_expr(self, o): return self.dump([o.expr, self.pretty_name(o.name, o.kind, o.fullname, o.is_def)], o) + def visit_yield_from_expr(self, o): + return self.dump([self.visit_call_expr(o.callee)], o) + def visit_call_expr(self, o): if o.analyzed: return o.analyzed.accept(self) diff --git a/mypy/test/data/parse-errors.test b/mypy/test/data/parse-errors.test index b9df789efac7..e24eea980545 100644 --- a/mypy/test/data/parse-errors.test +++ b/mypy/test/data/parse-errors.test @@ -346,4 +346,11 @@ def f(): yield from [out] file: In function "f": -file, line 2: Parse error before end of line \ No newline at end of file +file, line 2: Parse error before end of line + +[case testYielFromAfterReturn] +def f(): + return yield from h() +[out] +file: In function "f": +file, line 2: Parse error before end of line diff --git a/mypy/test/data/parse.test b/mypy/test/data/parse.test index 4dce5a5cb526..5679d686ab73 100644 --- a/mypy/test/data/parse.test +++ b/mypy/test/data/parse.test @@ -1318,10 +1318,10 @@ MypyFile:1( Block:1( AssignmentStmt:2( NameExpr(a) - CallExpr:2( - NameExpr(h) - Args()))))) - + YieldFromExpr:2( + CallExpr:2( + NameExpr(h) + Args())))))) [case testDel] del x diff --git a/mypy/transform.py b/mypy/transform.py index 48d9b5fa83a0..9b1d299bf5a7 100644 --- a/mypy/transform.py +++ b/mypy/transform.py @@ -17,7 +17,7 @@ Node, MypyFile, TypeInfo, ClassDef, VarDef, FuncDef, Var, ReturnStmt, AssignmentStmt, IfStmt, WhileStmt, MemberExpr, NameExpr, MDEF, CallExpr, SuperExpr, TypeExpr, CastExpr, OpExpr, CoerceExpr, GDEF, - SymbolTableNode, IndexExpr, function_type + SymbolTableNode, IndexExpr, function_type, YieldFromExpr ) from mypy.traverser import TraverserVisitor from mypy.types import Type, AnyType, Callable, TypeVarDef, Instance @@ -38,7 +38,7 @@ class DyncheckTransformVisitor(TraverserVisitor): all non-trivial coercions explicit. Also generate generic wrapper classes for coercions between generic types and wrapper methods for overrides and for more efficient access from dynamically typed code. - + This visitor modifies the parse tree in-place. """ @@ -46,23 +46,23 @@ class DyncheckTransformVisitor(TraverserVisitor): modules = Undefined(Dict[str, MypyFile]) is_pretty = False type_tf = Undefined(TypeTransformer) - + # Stack of function return types return_types = Undefined(List[Type]) # Stack of dynamically typed function flags dynamic_funcs = Undefined(List[bool]) - + # Associate a Node with its start end line numbers. line_map = Undefined(Dict[Node, Tuple[int, int]]) - + is_java = False - + # The current type context (or None if not within a type). _type_context = None # type: TypeInfo - + def type_context(self) -> TypeInfo: return self._type_context - + def __init__(self, type_map: Dict[Node, Type], modules: Dict[str, MypyFile], is_pretty: bool, is_java: bool = False) -> None: @@ -74,11 +74,11 @@ def __init__(self, type_map: Dict[Node, Type], self.modules = modules self.is_pretty = is_pretty self.is_java = is_java - + # # Transform definitions # - + def visit_mypy_file(self, o: MypyFile) -> None: """Transform an file.""" res = [] # type: List[Node] @@ -91,7 +91,7 @@ def visit_mypy_file(self, o: MypyFile) -> None: d.accept(self) res.append(d) o.defs = res - + def visit_var_def(self, o: VarDef) -> None: """Transform a variable definition in-place. @@ -99,7 +99,7 @@ def visit_var_def(self, o: VarDef) -> None: transformed in TypeTransformer. """ super().visit_var_def(o) - + if o.init is not None: if o.items[0].type: t = o.items[0].type @@ -107,7 +107,7 @@ def visit_var_def(self, o: VarDef) -> None: t = AnyType() o.init = self.coerce(o.init, t, self.get_type(o.init), self.type_context()) - + def visit_func_def(self, fdef: FuncDef) -> None: """Transform a global function definition in-place. @@ -116,7 +116,7 @@ def visit_func_def(self, fdef: FuncDef) -> None: """ self.prepend_generic_function_tvar_args(fdef) self.transform_function_body(fdef) - + def transform_function_body(self, fdef: FuncDef) -> None: """Transform the body of a function.""" self.dynamic_funcs.append(fdef.is_implicit) @@ -125,14 +125,14 @@ def transform_function_body(self, fdef: FuncDef) -> None: super().visit_func_def(fdef) self.return_types.pop() self.dynamic_funcs.pop() - + def prepend_generic_function_tvar_args(self, fdef: FuncDef) -> None: """Add implicit function type variable arguments if fdef is generic.""" sig = cast(Callable, function_type(fdef)) tvars = sig.variables if not fdef.type: fdef.type = sig - + tv = [] # type: List[Var] ntvars = len(tvars) if fdef.is_method(): @@ -150,20 +150,20 @@ def prepend_generic_function_tvar_args(self, fdef: FuncDef) -> None: AnyType()) fdef.args = tv + fdef.args fdef.init = List[AssignmentStmt]([None]) * ntvars + fdef.init - + # # Transform statements - # - + # + def transform_block(self, block: List[Node]) -> None: for stmt in block: stmt.accept(self) - + def visit_return_stmt(self, s: ReturnStmt) -> None: super().visit_return_stmt(s) s.expr = self.coerce(s.expr, self.return_types[-1], self.get_type(s.expr), self.type_context()) - + def visit_assignment_stmt(self, s: AssignmentStmt) -> None: super().visit_assignment_stmt(s) if isinstance(s.lvalues[0], IndexExpr): @@ -177,23 +177,23 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: lvalue_type = method_callable.arg_types[1] else: lvalue_type = self.get_type(s.lvalues[0]) - + s.rvalue = self.coerce2(s.rvalue, lvalue_type, self.get_type(s.rvalue), self.type_context()) - + # # Transform expressions # - + def visit_member_expr(self, e: MemberExpr) -> None: super().visit_member_expr(e) - + typ = self.get_type(e.expr) - + if self.dynamic_funcs[-1]: e.expr = self.coerce_to_dynamic(e.expr, typ, self.type_context()) typ = AnyType() - + if isinstance(typ, Instance): # Reference to a statically-typed method variant with the suffix # derived from the base object type. @@ -202,7 +202,7 @@ def visit_member_expr(self, e: MemberExpr) -> None: # Reference to a dynamically-typed method variant. suffix = self.dynamic_suffix() e.name += suffix - + def visit_name_expr(self, e: NameExpr) -> None: super().visit_name_expr(e) if e.kind == MDEF and isinstance(e.node, FuncDef): @@ -211,22 +211,25 @@ def visit_name_expr(self, e: NameExpr) -> None: e.name += suffix # Update representation to have the correct name. prefix = e.repr.components[0].pre - + def get_member_reference_suffix(self, name: str, info: TypeInfo) -> str: if info.has_method(name): fdef = cast(FuncDef, info.get_method(name)) return self.type_suffix(fdef) else: return '' - + + def visit_yield_from_expr(self, e: YieldFromExpr) -> None: + self.visit_call_expr(e.callee) + def visit_call_expr(self, e: CallExpr) -> None: if e.analyzed: # This is not an ordinary call. e.analyzed.accept(self) return - + super().visit_call_expr(e) - + # Do no coercions if this is a call to debugging facilities. if self.is_debugging_call_expr(e): return @@ -243,7 +246,7 @@ def visit_call_expr(self, e: CallExpr) -> None: e.args[i] = self.coerce2(e.args[i], arg_type, self.get_type(e.args[i]), self.type_context()) - + # Prepend type argument values to the call as needed. if isinstance(ctype, Callable) and cast(Callable, ctype).bound_vars != []: @@ -258,7 +261,7 @@ def visit_call_expr(self, e: CallExpr) -> None: (cast(SuperExpr, e.callee)).name == '__init__')): # Filter instance type variables; only include function tvars. bound_vars = [(id, t) for id, t in bound_vars if id < 0] - + args = [] # type: List[Node] for i in range(len(bound_vars)): # Compile type variables to runtime type variable expressions. @@ -268,16 +271,16 @@ def visit_call_expr(self, e: CallExpr) -> None: self.is_java) args.append(TypeExpr(tv)) e.args = args + e.args - + def is_debugging_call_expr(self, e): return isinstance(e.callee, NameExpr) and e.callee.name in ['__print'] - + def visit_cast_expr(self, e: CastExpr) -> None: super().visit_cast_expr(e) if isinstance(self.get_type(e), AnyType): e.expr = self.coerce(e.expr, AnyType(), self.get_type(e.expr), self.type_context()) - + def visit_op_expr(self, e: OpExpr) -> None: super().visit_op_expr(e) if e.op in ['and', 'or']: @@ -325,18 +328,18 @@ def visit_index_expr(self, e: IndexExpr) -> None: method_callable = cast(Callable, method_type) e.index = self.coerce(e.index, method_callable.arg_types[0], self.get_type(e.index), self.type_context()) - + # # Helpers - # - + # + def get_type(self, node: Node) -> Type: """Return the type of a node as reported by the type checker.""" return self.type_map[node] - + def set_type(self, node: Node, typ: Type) -> None: self.type_map[node] = typ - + def type_suffix(self, fdef: FuncDef, info: TypeInfo = None) -> str: """Return the suffix for a mangled name. @@ -355,20 +358,20 @@ def type_suffix(self, fdef: FuncDef, info: TypeInfo = None) -> str: return '`' + info.name() else: return '__' + info.name() - + def dynamic_suffix(self) -> str: """Return the suffix of the dynamic wrapper of a method or class.""" return dynamic_suffix(self.is_pretty) - + def wrapper_class_suffix(self) -> str: """Return the suffix of a generic wrapper class.""" return '**' - + def coerce(self, expr: Node, target_type: Type, source_type: Type, context: TypeInfo, is_wrapper_class: bool = False) -> Node: return coerce(expr, target_type, source_type, context, is_wrapper_class, self.is_java) - + def coerce2(self, expr: Node, target_type: Type, source_type: Type, context: TypeInfo, is_wrapper_class: bool = False) -> Node: """Create coercion from source_type to target_type. @@ -384,7 +387,7 @@ def coerce2(self, expr: Node, target_type: Type, source_type: Type, else: return self.coerce(expr, target_type, source_type, context, is_wrapper_class) - + def coerce_to_dynamic(self, expr: Node, source_type: Type, context: TypeInfo) -> Node: if isinstance(source_type, AnyType): @@ -392,7 +395,7 @@ def coerce_to_dynamic(self, expr: Node, source_type: Type, source_type = translate_runtime_type_vars_in_context( source_type, context, self.is_java) return CoerceExpr(expr, AnyType(), source_type, False) - + def add_line_mapping(self, orig_node: Node, new_node: Node) -> None: """Add a line mapping for a wrapper. @@ -403,13 +406,13 @@ def add_line_mapping(self, orig_node: Node, new_node: Node) -> None: start_line = orig_node.line end_line = start_line # TODO use real end line self.line_map[new_node] = (start_line, end_line) - + def named_type(self, name: str) -> Instance: # TODO combine with checker # Assume that the name refers to a type. sym = self.lookup(name, GDEF) return Instance(cast(TypeInfo, sym.node), []) - + def lookup(self, fullname: str, kind: int) -> SymbolTableNode: # TODO combine with checker # TODO remove kind argument @@ -418,7 +421,7 @@ def lookup(self, fullname: str, kind: int) -> SymbolTableNode: for i in range(1, len(parts) - 1): n = cast(MypyFile, ((n.names.get(parts[i], None).node))) return n.names[parts[-1]] - + def object_member_name(self) -> str: if self.is_java: return '__o_{}'.format(self.type_context().name()) diff --git a/mypy/traverser.py b/mypy/traverser.py index 6d9cde91be57..40d9ca73973a 100644 --- a/mypy/traverser.py +++ b/mypy/traverser.py @@ -10,7 +10,7 @@ TryStmt, WithStmt, ParenExpr, MemberExpr, OpExpr, SliceExpr, CastExpr, UnaryExpr, ListExpr, TupleExpr, DictExpr, SetExpr, IndexExpr, GeneratorExpr, ListComprehension, ConditionalExpr, TypeApplication, - FuncExpr, OverloadedFuncDef, YieldFromStmt + FuncExpr, OverloadedFuncDef, YieldFromStmt, YieldFromExpr ) @@ -150,6 +150,9 @@ def visit_paren_expr(self, o: ParenExpr) -> T: def visit_member_expr(self, o: MemberExpr) -> T: o.expr.accept(self) + def visit_yield_from_expr(self, o: YieldFromExpr) -> T: + self.visit_call_expr(o.callee) + def visit_call_expr(self, o: CallExpr) -> T: for a in o.args: a.accept(self) diff --git a/mypy/treetransform.py b/mypy/treetransform.py index f6cde34880f9..1699fe69e77d 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -16,7 +16,8 @@ UnicodeExpr, FloatExpr, CallExpr, SuperExpr, MemberExpr, IndexExpr, SliceExpr, OpExpr, UnaryExpr, FuncExpr, TypeApplication, PrintStmt, SymbolTable, RefExpr, UndefinedExpr, TypeVarExpr, DucktypeExpr, - DisjointclassExpr, CoerceExpr, TypeExpr, JavaCast, TempNode, YieldFromStmt + DisjointclassExpr, CoerceExpr, TypeExpr, JavaCast, TempNode, YieldFromStmt, + YieldFromExpr ) from mypy.types import Type from mypy.visitor import NodeVisitor @@ -305,6 +306,9 @@ def copy_ref(self, new: RefExpr, original: RefExpr) -> None: new.node = target new.is_def = original.is_def + def visit_yield_from_expr(self, node: YieldFromExpr) -> Node: + return YieldFromExpr(self.node(node.callee)) + def visit_call_expr(self, node: CallExpr) -> Node: return CallExpr(self.node(node.callee), self.nodes(node.args), diff --git a/mypy/visitor.py b/mypy/visitor.py index 50bc20cc6ff3..ead04a06b456 100644 --- a/mypy/visitor.py +++ b/mypy/visitor.py @@ -110,6 +110,8 @@ def visit_name_expr(self, o: 'mypy.nodes.NameExpr') -> T: pass def visit_member_expr(self, o: 'mypy.nodes.MemberExpr') -> T: pass + def visit_yield_from_expr(self, o: 'mypy.nodes.YieldFromExpr') -> T: + pass def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T: pass def visit_op_expr(self, o: 'mypy.nodes.OpExpr') -> T: From 55123279f8534bf8e2cf5cc6e037e4a5ab69d053 Mon Sep 17 00:00:00 2001 From: Rock Neurotiko Date: Wed, 30 Jul 2014 02:33:18 +0200 Subject: [PATCH 04/12] Errors Correct ======== "yield lambda" error correct ------- **parse.py** -> - Modify visit_yield: Check if the next token is "from" before to "expect" it. "Can't yield from a NameExpr" error correct -------- **checkexpr.py** -> - Check if expression of yield from is a CallFunc or a NameExpr and visit the correct one. **icode.py** -> - Check if expression of yield from is a CallFunc or a NameExpr and visit the correct one. **output.py** -> - Check if expression of yield from is a CallFunc or a NameExpr and visit the correct one. **pprinter.py** -> - Check if expression of yield from is a CallFunc or a NameExpr and visit the correct one. **semanal.py** -> - Check if expression of yield from is a CallFunc or a NameExpr and visit the correct one. **stats.py** -> - Check if expression of yield from is a CallFunc or a NameExpr and visit the correct one. **strconv.py** -> - Check if expression of yield from is a CallFunc or a NameExpr and visit the correct one. **transform.py** -> - Check if expression of yield from is a CallFunc or a NameExpr and visit the correct one. **traverser.py** -> - Add NameExpr to import - Check if expression of yield from is a CallFunc or a NameExpr and visit the correct one. --- mypy/checkexpr.py | 7 ++++++- mypy/icode.py | 5 ++++- mypy/output.py | 5 ++++- mypy/parse.py | 6 +++--- mypy/pprinter.py | 5 ++++- mypy/semanal.py | 5 ++++- mypy/stats.py | 5 ++++- mypy/strconv.py | 5 ++++- mypy/transform.py | 5 ++++- mypy/traverser.py | 8 ++++++-- 10 files changed, 43 insertions(+), 13 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 3e7bf99ed32d..91a7a8776068 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -99,7 +99,12 @@ def analyse_var_ref(self, var: Var, context: Context) -> Type: return val def visit_yield_from_expr(self, e: YieldFromExpr) -> Type: - return self.visit_call_expr(e.callee) + if isinstance(e.callee, CallExpr): + return self.visit_call_expr(e.callee) + elif isinstance(e.callee, NameExpr): + return self.visit_name_expr(e.callee) + + # return self.visit_call_expr(e.callee) def visit_call_expr(self, e: CallExpr) -> Type: """Type check a call expression.""" diff --git a/mypy/icode.py b/mypy/icode.py index d44a8570db22..e64f80539787 100644 --- a/mypy/icode.py +++ b/mypy/icode.py @@ -650,7 +650,10 @@ def visit_unary_expr(self, e: UnaryExpr) -> int: return target def visit_yield_from_expr(self, e: YieldFromExpr) -> int: - return self.visit_call_expr(e.callee) + if isinstance(e.callee, CallExpr): + return self.visit_call_expr(e.callee) + elif isinstance(e.callee, NameExpr): + return self.visit_name_expr(e.callee) def visit_call_expr(self, e: CallExpr) -> int: args = [] # type: List[int] diff --git a/mypy/output.py b/mypy/output.py index 07802945ef28..d51dd777c16d 100644 --- a/mypy/output.py +++ b/mypy/output.py @@ -360,7 +360,10 @@ def visit_slice_expr(self, o): self.node(o.stride) def visit_yield_from_expr(self, o): - self.visit_call_expr(o.callee) + if isinstance(o.callee, CallExpr): + self.visit_call_expr(o.callee) + elif isinstance(o.callee, NameExpr): + self.visit_name_expr(o.callee) def visit_call_expr(self, o): r = o.repr diff --git a/mypy/parse.py b/mypy/parse.py index fac629f7334a..ba7f97399f5f 100644 --- a/mypy/parse.py +++ b/mypy/parse.py @@ -776,9 +776,9 @@ def parse_yield_stmt(self) -> YieldStmt: expr = None # type: Node node = YieldStmt(expr) if not isinstance(self.current(), Break): - if isinstance(self.current(), Keyword): + if isinstance(self.current(), Keyword) and self.current_str() == "from": # Not go if it's not from from_tok = self.expect("from") - expr = self.parse_expression() # Here comes when yield from is not assigned + expr = self.parse_expression() # Here comes when yield from is not assigned node = YieldFromStmt(expr) else: expr = self.parse_expression() @@ -790,7 +790,7 @@ def parse_yield_stmt(self) -> YieldStmt: def parse_yield_from_expr(self) -> CallExpr: y_tok = self.expect("yield") f_tok = self.expect("from") - tok = self.parse_expression() # Here comes when yield from is assigned to a variable + tok = self.parse_expression() # Here comes when yield from is assigned to a variable return YieldFromExpr(tok) def parse_del_stmt(self) -> DelStmt: diff --git a/mypy/pprinter.py b/mypy/pprinter.py index 22d4187a7f78..6baa232e9937 100644 --- a/mypy/pprinter.py +++ b/mypy/pprinter.py @@ -159,7 +159,10 @@ def visit_call_expr(self, o): self.string(')') def visit_yield_from_expr(self, o): - self.visit_call_expr(o.callee) + if isinstance(o.callee, CallExpr): + self.visit_call_expr(o.callee) + elif isinstance(o.callee, NameExpr): + self.visit_name_expr(o.callee) def visit_member_expr(self, o): self.node(o.expr) diff --git a/mypy/semanal.py b/mypy/semanal.py index 62a8593661c8..b47c5c985c8e 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -1142,7 +1142,10 @@ def visit_paren_expr(self, expr: ParenExpr) -> None: expr.expr.accept(self) def visit_yield_from_expr(self, expr: YieldFromExpr) -> None: - self.visit_call_expr(expr.callee) + if not self.is_func_scope(): # not sure + self.fail("'yield from' outside function", s) + if expr.callee: + expr.callee.accept(self) def visit_call_expr(self, expr: CallExpr) -> None: """Analyze a call expression. diff --git a/mypy/stats.py b/mypy/stats.py index 0a87904adf4b..1889de1ed4aa 100644 --- a/mypy/stats.py +++ b/mypy/stats.py @@ -110,7 +110,10 @@ def visit_name_expr(self, o: NameExpr) -> None: super().visit_name_expr(o) def visit_yield_from_expr(self, o: YieldFromExpr) -> None: - self.visit_call_expr(o.callee) + if isinstance(o.callee, CallExpr): + self.visit_call_expr(o.callee) + elif isinstance(o.callee, NameExpr): + self.visit_name_expr(o.callee) def visit_call_expr(self, o: CallExpr) -> None: self.process_node(o) diff --git a/mypy/strconv.py b/mypy/strconv.py index 95094a029de1..714c8255bc40 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -325,7 +325,10 @@ def visit_member_expr(self, o): o.is_def)], o) def visit_yield_from_expr(self, o): - return self.dump([self.visit_call_expr(o.callee)], o) + if isinstance(o.callee, mypy.nodes.CallExpr): + return self.dump([self.visit_call_expr(o.callee)], o) + elif isinstance(o.callee, mypy.nodes.NameExpr): + return self.dump([self.visit_name_expr(o.callee)], o) def visit_call_expr(self, o): if o.analyzed: diff --git a/mypy/transform.py b/mypy/transform.py index 9b1d299bf5a7..06cf170cce05 100644 --- a/mypy/transform.py +++ b/mypy/transform.py @@ -220,7 +220,10 @@ def get_member_reference_suffix(self, name: str, info: TypeInfo) -> str: return '' def visit_yield_from_expr(self, e: YieldFromExpr) -> None: - self.visit_call_expr(e.callee) + if isinstance(e.callee, CallExpr): + self.visit_call_expr(e.callee) + elif isinstance(e.callee, NameExpr): + self.visit_name_expr(e.callee) def visit_call_expr(self, e: CallExpr) -> None: if e.analyzed: diff --git a/mypy/traverser.py b/mypy/traverser.py index aef50491d357..f9c86389ce98 100644 --- a/mypy/traverser.py +++ b/mypy/traverser.py @@ -10,7 +10,7 @@ TryStmt, WithStmt, ParenExpr, MemberExpr, OpExpr, SliceExpr, CastExpr, UnaryExpr, ListExpr, TupleExpr, DictExpr, SetExpr, IndexExpr, GeneratorExpr, ListComprehension, ConditionalExpr, TypeApplication, - FuncExpr, OverloadedFuncDef, YieldFromStmt, YieldFromExpr + FuncExpr, OverloadedFuncDef, YieldFromStmt, YieldFromExpr, NameExpr ) @@ -151,7 +151,11 @@ def visit_member_expr(self, o: MemberExpr) -> T: o.expr.accept(self) def visit_yield_from_expr(self, o: YieldFromExpr) -> T: - self.visit_call_expr(o.callee) + if isinstance(o.callee, CallExpr): + self.visit_call_expr(o.callee) + elif isinstance(o.callee, NameExpr): + self.visit_name_expr(o.callee) + def visit_call_expr(self, o: CallExpr) -> T: for a in o.args: From 873b1e5243e127e2c3a97d0aa7cedd00df0a4c8e Mon Sep 17 00:00:00 2001 From: Rock Neurotiko Date: Sat, 9 Aug 2014 15:35:53 +0200 Subject: [PATCH 05/12] Yield from stmt and expr wit type checks --- mypy/checker.py | 92 +++++++++++++- mypy/checkexpr.py | 7 +- mypy/icode.py | 5 +- mypy/messages.py | 120 ++++++++++-------- mypy/nodes.py | 7 +- mypy/output.py | 5 +- mypy/parse.py | 22 +++- mypy/pprinter.py | 6 +- mypy/semanal.py | 10 +- mypy/stats.py | 6 +- mypy/strconv.py | 8 +- mypy/test/data/check-statements.test | 92 +++++++++++++- mypy/transform.py | 6 +- mypy/traverser.py | 8 +- mypy/treetransform.py | 2 +- stubs/3.4/asyncio/__init__.py | 8 +- stubs/3.4/asyncio/events.py | 46 +++++-- stubs/3.4/asyncio/examples/README.md | 9 ++ stubs/3.4/asyncio/examples/example_1.py | 20 +++ stubs/3.4/asyncio/examples/example_2.py | 25 ++++ stubs/3.4/asyncio/examples/example_3.py | 20 +++ stubs/3.4/asyncio/examples/example_4.py | 34 +++++ stubs/3.4/asyncio/examples/example_5.py | 22 ++++ stubs/3.4/asyncio/examples/example_6.py | 35 +++++ stubs/3.4/asyncio/examples/example_7.py | 36 ++++++ stubs/3.4/asyncio/examples/example_8.py | 25 ++++ stubs/3.4/asyncio/examples/example_error_1.py | 28 ++++ stubs/3.4/asyncio/examples/example_error_2.py | 30 +++++ stubs/3.4/asyncio/examples/example_error_3.py | 57 +++++++++ stubs/3.4/asyncio/examples/example_error_4.py | 33 +++++ stubs/3.4/asyncio/examples/example_error_7.py | 40 ++++++ stubs/3.4/asyncio/examples/example_error_8.py | 31 +++++ stubs/3.4/asyncio/futures.py | 15 ++- stubs/3.4/asyncio/tasks.py | 42 ++++-- 34 files changed, 804 insertions(+), 148 deletions(-) create mode 100644 stubs/3.4/asyncio/examples/README.md create mode 100644 stubs/3.4/asyncio/examples/example_1.py create mode 100644 stubs/3.4/asyncio/examples/example_2.py create mode 100644 stubs/3.4/asyncio/examples/example_3.py create mode 100644 stubs/3.4/asyncio/examples/example_4.py create mode 100644 stubs/3.4/asyncio/examples/example_5.py create mode 100644 stubs/3.4/asyncio/examples/example_6.py create mode 100644 stubs/3.4/asyncio/examples/example_7.py create mode 100644 stubs/3.4/asyncio/examples/example_8.py create mode 100644 stubs/3.4/asyncio/examples/example_error_1.py create mode 100644 stubs/3.4/asyncio/examples/example_error_2.py create mode 100644 stubs/3.4/asyncio/examples/example_error_3.py create mode 100644 stubs/3.4/asyncio/examples/example_error_4.py create mode 100644 stubs/3.4/asyncio/examples/example_error_7.py create mode 100644 stubs/3.4/asyncio/examples/example_error_8.py diff --git a/mypy/checker.py b/mypy/checker.py index b616cb1bc2f0..843f30d14b03 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -2,7 +2,7 @@ import itertools -from typing import Undefined, Any, Dict, Set, List, cast, overload, Tuple, Function, typevar +from typing import Undefined, Any, Dict, Set, List, cast, overload, Tuple, typevar from mypy.errors import Errors from mypy.nodes import ( @@ -15,14 +15,14 @@ BytesExpr, UnicodeExpr, FloatExpr, OpExpr, UnaryExpr, CastExpr, SuperExpr, TypeApplication, DictExpr, SliceExpr, FuncExpr, TempNode, SymbolTableNode, Context, ListComprehension, ConditionalExpr, GeneratorExpr, - Decorator, SetExpr, PassStmt, TypeVarExpr, UndefinedExpr, PrintStmt, - LITERAL_TYPE, BreakStmt, ContinueStmt, YieldFromExpr + Decorator, SetExpr, TypeVarExpr, UndefinedExpr, PrintStmt, + LITERAL_TYPE, BreakStmt, ContinueStmt, YieldFromExpr, YieldFromStmt ) from mypy.nodes import function_type, method_type from mypy import nodes from mypy.types import ( Type, AnyType, Callable, Void, FunctionLike, Overloaded, TupleType, - Instance, NoneTyp, UnboundType, ErrorType, TypeTranslator, BasicTypes, + Instance, NoneTyp, UnboundType, ErrorType, BasicTypes, strip_type, UnionType ) from mypy.sametypes import is_same_type @@ -39,7 +39,7 @@ from mypy.visitor import NodeVisitor from mypy.join import join_simple, join_types from mypy.treetransform import TransformVisitor -from mypy.meet import meet_simple, meet_simple_away, nearest_builtin_ancestor, is_overlapping_types +from mypy.meet import meet_simple, nearest_builtin_ancestor, is_overlapping_types # Kinds of isinstance checks. @@ -1207,18 +1207,37 @@ def visit_return_stmt(self, s: ReturnStmt) -> Type: if not isinstance(self.function_stack[-1], FuncExpr): self.fail(messages.NO_RETURN_VALUE_EXPECTED, s) else: + if self.function_stack[-1].is_coroutine: # Something similar will be needed to mix return and yield + #If the function is a coroutine, wrap the return type in a Future + typ = self.wrap_generic_type(typ, self.return_types[-1], 'asyncio.futures.Future') self.check_subtype( typ, self.return_types[-1], s, messages.INCOMPATIBLE_RETURN_VALUE_TYPE + ": expected {}, got {}".format(self.return_types[-1], typ) ) else: - # Return without a value. It's valid in a generator function. - if not self.function_stack[-1].is_generator: + # Return without a value. It's valid in a generator and coroutine function. + if not self.function_stack[-1].is_generator and not self.function_stack[-1].is_coroutine: if (not isinstance(self.return_types[-1], Void) and not self.is_dynamic_function()): self.fail(messages.RETURN_VALUE_EXPECTED, s) + def wrap_generic_type(self, typ: Type, rtyp: Type, check_type: str) -> Type: + n_diff = self.count_concatenated_types(rtyp, check_type) - self.count_concatenated_types(typ, check_type) + if n_diff >= 1: + return self.named_generic_type(check_type, [typ]) + return typ + + def count_concatenated_types(self, typ: Type, check_type: str) -> int: + c = 0 + while is_subtype(typ, self.named_type(check_type)): + c += 1 + if hasattr(typ, 'args') and typ.args: + typ = typ.args[0] + else: + return c + return c + def visit_yield_stmt(self, s: YieldStmt) -> Type: return_type = self.return_types[-1] if isinstance(return_type, Instance): @@ -1239,6 +1258,55 @@ def visit_yield_stmt(self, s: YieldStmt) -> Type: messages.INCOMPATIBLE_TYPES_IN_YIELD, 'actual type', 'expected type') + def visit_yield_from_stmt(self, s: YieldFromStmt) -> Type: + return_type = self.return_types[-1] + type_func = self.accept(s.expr, return_type) + if isinstance(type_func, Instance): + if hasattr(type_func, 'type') and hasattr(type_func.type, 'fullname') and type_func.type.fullname() == 'asyncio.futures.Future': + # if is a Future, in stmt don't need to do nothing + # because the type Future[Some] jus matters to the main loop + # that python executes, in statement we shouldn't get the Future, + # is just for async purposes. + self.function_stack[-1].is_coroutine = True # Set the function as coroutine + elif is_subtype(type_func, self.named_type('typing.Iterable')): + # If it's and Iterable-Like, let's check the types. + # Maybe just check if have __iter__? (like in analyse_iterable) + self.check_iterable_yf(s) + else: + self.msg.yield_from_not_valid_applied(type_func, s) + elif isinstance(type_func, AnyType): + self.check_iterable_yf(s) + else: + self.msg.yield_from_not_valid_applied(type_func, s) + + def check_iterable_yf(self, s: YieldFromStmt) -> Type: + """ + Check that return type is super type of Iterable (Maybe just check if have __iter__?) + and compare it with the type of the expression + """ + expected_item_type = self.return_types[-1] + if isinstance(expected_item_type, Instance): + if not is_subtype(expected_item_type, self.named_type('typing.Iterable')): + self.fail(messages.INVALID_RETURN_TYPE_FOR_YIELD_FROM, s) + return None + elif hasattr(expected_item_type, 'args') and expected_item_type.args: + expected_item_type = expected_item_type.args[0] # Take the item inside the iterator + # expected_item_type = expected_item_type + elif isinstance(expected_item_type, AnyType): + expected_item_type = AnyType() + else: + self.fail(messages.INVALID_RETURN_TYPE_FOR_YIELD_FROM, s) + return None + if s.expr is None: + actual_item_type = Void() + else: + actual_item_type = self.accept(s.expr, expected_item_type) + if hasattr(actual_item_type, 'args') and actual_item_type.args: + actual_item_type = actual_item_type.args[0] # Take the item inside the iterator + self.check_subtype(actual_item_type, expected_item_type, s, + messages.INCOMPATIBLE_TYPES_IN_YIELD_FROM, + 'actual type', 'expected type') + def visit_if_stmt(self, s: IfStmt) -> Type: """Type check an if statement.""" broken = True @@ -1551,6 +1619,16 @@ def visit_call_expr(self, e: CallExpr) -> Type: def visit_yield_from_expr(self, e: YieldFromExpr) -> Type: result = self.expr_checker.visit_yield_from_expr(e) + if hasattr(result, 'type') and result.type.fullname() == "asyncio.futures.Future": + self.function_stack[-1].is_coroutine = True # Set the function as coroutine + result = result.args[0] # Set the return type as the type inside + elif is_subtype(result, self.named_type('typing.Iterable')): + # TODO + # Check return type Iterator[Some] + # Maybe set result like in the Future + pass + else: + self.msg.yield_from_not_valid_applied(e.expr, e) self.breaking_out = False return result diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 91a7a8776068..059912810684 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -99,12 +99,7 @@ def analyse_var_ref(self, var: Var, context: Context) -> Type: return val def visit_yield_from_expr(self, e: YieldFromExpr) -> Type: - if isinstance(e.callee, CallExpr): - return self.visit_call_expr(e.callee) - elif isinstance(e.callee, NameExpr): - return self.visit_name_expr(e.callee) - - # return self.visit_call_expr(e.callee) + return e.expr.accept(self) # move it to checker? def visit_call_expr(self, e: CallExpr) -> Type: """Type check a call expression.""" diff --git a/mypy/icode.py b/mypy/icode.py index e64f80539787..6ab6202eff7d 100644 --- a/mypy/icode.py +++ b/mypy/icode.py @@ -650,10 +650,7 @@ def visit_unary_expr(self, e: UnaryExpr) -> int: return target def visit_yield_from_expr(self, e: YieldFromExpr) -> int: - if isinstance(e.callee, CallExpr): - return self.visit_call_expr(e.callee) - elif isinstance(e.callee, NameExpr): - return self.visit_name_expr(e.callee) + return e.expr.accept(self) def visit_call_expr(self, e: CallExpr) -> int: args = [] # type: List[int] diff --git a/mypy/messages.py b/mypy/messages.py index 22d2b2036d6e..428012060a91 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -32,9 +32,12 @@ INVALID_EXCEPTION_TYPE = 'Exception type must be derived from BaseException' INVALID_RETURN_TYPE_FOR_YIELD = \ 'Iterator function return type expected for "yield"' +INVALID_RETURN_TYPE_FOR_YIELD_FROM = \ + 'Iterable function return type expected for "yield from"' INCOMPATIBLE_TYPES = 'Incompatible types' INCOMPATIBLE_TYPES_IN_ASSIGNMENT = 'Incompatible types in assignment' INCOMPATIBLE_TYPES_IN_YIELD = 'Incompatible types in yield' +INCOMPATIBLE_TYPES_IN_YIELD_FROM = 'Incompatible types in "yield from"' INIT_MUST_NOT_HAVE_RETURN_TYPE = 'Cannot define return type for "__init__"' GETTER_TYPE_INCOMPATIBLE_WITH_SETTER = \ 'Type of getter incompatible with setter' @@ -63,25 +66,25 @@ class MessageBuilder: """Helper class for reporting type checker error messages with parameters. - + The methods of this class need to be provided with the context within a file; the errors member manages the wider context. - + IDEA: Support a 'verbose mode' that includes full information about types in error messages and that may otherwise produce more detailed error messages. """ - + # Report errors using this instance. It knows about the current file and # import context. errors = Undefined(Errors) - + # Number of times errors have been disabled. disable_count = 0 # Hack to deduplicate error messages from union types disable_type_names = 0 - + def __init__(self, errors: Errors) -> None: self.errors = errors self.disable_count = 0 @@ -106,12 +109,12 @@ def enable_errors(self) -> None: def is_errors(self) -> bool: return self.errors.is_errors() - + def fail(self, msg: str, context: Context) -> None: """Report an error message (unless disabled).""" if self.disable_count <= 0: self.errors.report(context.get_line(), msg.strip()) - + def format(self, typ: Type) -> str: """Convert a type to a relatively short string that is suitable for error messages. Mostly behave like format_simple @@ -141,13 +144,13 @@ def format(self, typ: Type) -> str: else: # Default case; we simply have to return something meaningful here. return 'object' - + def format_simple(self, typ: Type) -> str: """Convert simple types to string that is suitable for error messages. - + Return "" for complex types. Try to keep the length of the result relatively short to avoid overly long error messages. - + Examples: builtins.int -> 'int' Any type -> 'Any' @@ -219,11 +222,11 @@ def format_simple(self, typ: Type) -> str: # # Specific operations # - + # The following operations are for genering specific error messages. They # get some information as arguments, and they build an error message based # on them. - + def has_no_attr(self, typ: Type, member: str, context: Context) -> Type: """Report a missing or non-accessible member. @@ -271,36 +274,36 @@ def has_no_attr(self, typ: Type, member: str, context: Context) -> Type: self.fail('Some element of union has no attribute "{}"'.format( member), context) return AnyType() - + def unsupported_operand_types(self, op: str, left_type: Any, right_type: Any, context: Context) -> None: """Report unsupported operand types for a binary operation. - + Types can be Type objects or strings. """ if isinstance(left_type, Void) or isinstance(right_type, Void): self.check_void(left_type, context) self.check_void(right_type, context) - return + return left_str = '' if isinstance(left_type, str): left_str = left_type else: left_str = self.format(left_type) - + right_str = '' if isinstance(right_type, str): right_str = right_type else: right_str = self.format(right_type) - + if self.disable_type_names: msg = 'Unsupported operand types for {} (likely involving Union)'.format(op) else: msg = 'Unsupported operand types for {} ({} and {})'.format( op, left_str, right_str) self.fail(msg, context) - + def unsupported_left_operand(self, op: str, typ: Type, context: Context) -> None: if not self.check_void(typ, context): @@ -310,14 +313,14 @@ def unsupported_left_operand(self, op: str, typ: Type, msg = 'Unsupported left operand type for {} ({})'.format( op, self.format(typ)) self.fail(msg, context) - + def type_expected_as_right_operand_of_is(self, context: Context) -> None: self.fail('Type expected as right operand of "is"', context) - + def not_callable(self, typ: Type, context: Context) -> Type: self.fail('{} not callable'.format(self.format(typ)), context) return AnyType() - + def incompatible_argument(self, n: int, callee: Callable, arg_type: Type, context: Context) -> None: """Report an error about an incompatible argument type. @@ -331,7 +334,7 @@ def incompatible_argument(self, n: int, callee: Callable, arg_type: Type, if callee.name: name = callee.name base = extract_type(name) - + for op, method in op_methods.items(): for variant in method, '__r' + method[2:]: if name.startswith('"{}" of'.format(variant)): @@ -342,21 +345,21 @@ def incompatible_argument(self, n: int, callee: Callable, arg_type: Type, else: self.unsupported_operand_types(op, base, arg_type, context) - return - + return + if name.startswith('"__getitem__" of'): self.invalid_index_type(arg_type, base, context) - return - + return + if name.startswith('"__setitem__" of'): if n == 1: self.invalid_index_type(arg_type, base, context) else: self.fail(INCOMPATIBLE_TYPES_IN_ASSIGNMENT, context) - return - + return + target = 'to {} '.format(name) - + msg = '' if callee.name == '': name = callee.name[1:-1] @@ -376,31 +379,31 @@ def incompatible_argument(self, n: int, callee: Callable, arg_type: Type, msg = 'Argument {} {}has incompatible type {}; expected {}'.format( n, target, self.format(arg_type), self.format(expected_type)) self.fail(msg, context) - + def invalid_index_type(self, index_type: Type, base_str: str, context: Context) -> None: self.fail('Invalid index type {} for {}'.format( self.format(index_type), base_str), context) - + def invalid_argument_count(self, callee: Callable, num_args: int, context: Context) -> None: if num_args < len(callee.arg_types): self.too_few_arguments(callee, context) else: self.too_many_arguments(callee, context) - + def too_few_arguments(self, callee: Callable, context: Context) -> None: msg = 'Too few arguments' if callee.name: msg += ' for {}'.format(callee.name) self.fail(msg, context) - + def too_many_arguments(self, callee: Callable, context: Context) -> None: msg = 'Too many arguments' if callee.name: msg += ' for {}'.format(callee.name) self.fail(msg, context) - + def too_many_positional_arguments(self, callee: Callable, context: Context) -> None: msg = 'Too many positional arguments' @@ -413,14 +416,14 @@ def unexpected_keyword_argument(self, callee: Callable, name: str, msg = 'Unexpected keyword argument "{}"'.format(name) if callee.name: msg += ' for {}'.format(callee.name) - self.fail(msg, context) + self.fail(msg, context) def duplicate_argument_value(self, callee: Callable, index: int, context: Context) -> None: self.fail('{} gets multiple values for keyword argument "{}"'. format(capitalize(callable_name(callee)), callee.arg_names[index]), context) - + def does_not_return_value(self, void_type: Type, context: Context) -> None: """Report an error about a void type in a non-void context. @@ -433,7 +436,7 @@ def does_not_return_value(self, void_type: Type, context: Context) -> None: else: self.fail('{} does not return a value'.format( capitalize((cast(Void, void_type)).source)), context) - + def no_variant_matches_arguments(self, overload: Overloaded, context: Context) -> None: if overload.name(): @@ -441,23 +444,23 @@ def no_variant_matches_arguments(self, overload: Overloaded, .format(overload.name()), context) else: self.fail('No overload variant matches argument types', context) - + def function_variants_overlap(self, n1: int, n2: int, context: Context) -> None: self.fail('Function signature variants {} and {} overlap'.format( n1 + 1, n2 + 1), context) - + def invalid_cast(self, target_type: Type, source_type: Type, context: Context) -> None: if not self.check_void(source_type, context): self.fail('Cannot cast from {} to {}'.format( self.format(source_type), self.format(target_type)), context) - + def incompatible_operator_assignment(self, op: str, context: Context) -> None: self.fail('Result type of {} incompatible in assignment'.format(op), context) - + def incompatible_value_count_in_assignment(self, lvalue_count: int, rvalue_count: int, context: Context) -> None: @@ -465,26 +468,26 @@ def incompatible_value_count_in_assignment(self, lvalue_count: int, self.fail('Need {} values to assign'.format(lvalue_count), context) elif rvalue_count > lvalue_count: self.fail('Too many values to assign', context) - + def type_incompatible_with_supertype(self, name: str, supertype: TypeInfo, context: Context) -> None: self.fail('Type of "{}" incompatible with supertype "{}"'.format( name, supertype.name), context) - + def signature_incompatible_with_supertype( self, name: str, name_in_super: str, supertype: str, context: Context) -> None: target = self.override_target(name, name_in_super, supertype) self.fail('Signature of "{}" incompatible with {}'.format( name, target), context) - + def argument_incompatible_with_supertype( self, arg_num: int, name: str, name_in_supertype: str, supertype: str, context: Context) -> None: target = self.override_target(name, name_in_supertype, supertype) self.fail('Argument {} of "{}" incompatible with {}' .format(arg_num, name, target), context) - + def return_type_incompatible_with_supertype( self, name: str, name_in_supertype: str, supertype: str, context: Context) -> None: @@ -497,13 +500,13 @@ def override_target(self, name: str, name_in_super: str, target = 'supertype "{}"'.format(supertype) if name_in_super != name: target = '"{}" of {}'.format(name_in_super, target) - return target - + return target + def boolean_return_value_expected(self, method: str, context: Context) -> None: self.fail('Boolean return value expected for method "{}"'.format( method), context) - + def incompatible_type_application(self, expected_arg_count: int, actual_arg_count: int, context: Context) -> None: @@ -516,12 +519,12 @@ def incompatible_type_application(self, expected_arg_count: int, else: self.fail('Type application has too few types ({} expected)' .format(expected_arg_count), context) - + def incompatible_array_item_type(self, typ: Type, index: int, context: Context) -> None: self.fail('Array item {} has incompatible type {}'.format( index, self.format(typ)), context) - + def could_not_infer_type_arguments(self, callee_type: Callable, n: int, context: Context) -> None: if callee_type.name and n > 0: @@ -529,10 +532,10 @@ def could_not_infer_type_arguments(self, callee_type: Callable, n: int, n, callee_type.name), context) else: self.fail('Cannot infer function type argument', context) - + def invalid_var_arg(self, typ: Type, context: Context) -> None: self.fail('List or tuple expected as variable arguments', context) - + def invalid_keyword_var_arg(self, typ: Type, context: Context) -> None: if isinstance(typ, Instance) and ( (cast(Instance, typ)).type.fullname() == 'builtins.dict'): @@ -540,18 +543,18 @@ def invalid_keyword_var_arg(self, typ: Type, context: Context) -> None: else: self.fail('Argument after ** must be a dictionary', context) - + def incomplete_type_var_match(self, member: str, context: Context) -> None: self.fail('"{}" has incomplete match to supertype type variable' .format(member), context) - + def not_implemented(self, msg: str, context: Context) -> Type: self.fail('Feature not implemented yet ({})'.format(msg), context) return AnyType() - + def undefined_in_superclass(self, member: str, context: Context) -> None: self.fail('"{}" undefined in superclass'.format(member), context) - + def check_void(self, typ: Type, context: Context) -> bool: """If type is void, report an error such as '.. does not return a value' and return True. Otherwise, return False. @@ -633,6 +636,11 @@ def signatures_incompatible(self, method: str, other_method: str, self.fail('Signatures of "{}" and "{}" are incompatible'.format( method, other_method), context) + def yield_from_not_valid_applied(self, expr: Type, context: Context) -> Type: + text = self.format(expr) if self.format(expr) != 'object' else expr + self.fail('"yield from" can\'t be applied to {}'.format(text), context) + return AnyType() + def capitalize(s: str) -> str: """Capitalize the first character of a string.""" diff --git a/mypy/nodes.py b/mypy/nodes.py index d5e476517147..3e9f9ee8d4f4 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -237,6 +237,7 @@ class FuncItem(FuncBase): is_overload = False # Is this an overload variant of function with # more than one overload variant? is_generator = False # Contains a yield statement? + is_coroutine = False # Contains @coroutine or yield from Future is_static = False # Uses @staticmethod? is_class = False # Uses @classmethod? expanded = Undefined(List['FuncItem']) # Variants of function with type @@ -927,10 +928,10 @@ def accept(self, visitor: NodeVisitor[T]) -> T: class YieldFromExpr(Node): - callee = Undefined(Node) + expr = Undefined(Node) - def __init__(self, callee: Node) -> None: - self.callee = callee + def __init__(self, expr: Node) -> None: + self.expr = expr def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_yield_from_expr(self) diff --git a/mypy/output.py b/mypy/output.py index d51dd777c16d..5f8a36414842 100644 --- a/mypy/output.py +++ b/mypy/output.py @@ -360,10 +360,7 @@ def visit_slice_expr(self, o): self.node(o.stride) def visit_yield_from_expr(self, o): - if isinstance(o.callee, CallExpr): - self.visit_call_expr(o.callee) - elif isinstance(o.callee, NameExpr): - self.visit_name_expr(o.callee) + o.expr.accept(self) def visit_call_expr(self, o): r = o.repr diff --git a/mypy/parse.py b/mypy/parse.py index ba7f97399f5f..de636eefeabd 100644 --- a/mypy/parse.py +++ b/mypy/parse.py @@ -787,11 +787,21 @@ def parse_yield_stmt(self) -> YieldStmt: self.set_repr(node, noderepr.SimpleStmtRepr(yield_tok, br)) return node - def parse_yield_from_expr(self) -> CallExpr: + def parse_yield_from_expr(self) -> CallExpr: # Maybe the name should be yield_expr y_tok = self.expect("yield") - f_tok = self.expect("from") - tok = self.parse_expression() # Here comes when yield from is assigned to a variable - return YieldFromExpr(tok) + expr = None # type: Node + node = YieldFromExpr(expr) + if self.current_str() == "from": + f_tok = self.expect("from") + tok = self.parse_expression() # Here comes when yield from is assigned to a variable + node = YieldFromExpr(tok) + else: + # TODO + # Here comes the yield expression (ex: x = yield 3 ) + # tok = self.parse_expression() + # node = YieldExpr(tok) # Doesn't exist now + pass + return node def parse_del_stmt(self) -> DelStmt: del_tok = self.expect('del') @@ -1069,8 +1079,8 @@ def parse_expression(self, prec: int = 0) -> Node: expr = self.parse_unicode_literal() elif isinstance(self.current(), FloatLit): expr = self.parse_float_expr() - elif isinstance(t, Keyword) and s == "yield": #maybe check that next is from - expr = self.parse_yield_from_expr() # The expression yield from to assign + elif isinstance(t, Keyword) and s == "yield": + expr = self.parse_yield_from_expr() # The expression yield from and yield to assign else: # Invalid expression. self.parse_error() diff --git a/mypy/pprinter.py b/mypy/pprinter.py index 6baa232e9937..9d3532f0f44a 100644 --- a/mypy/pprinter.py +++ b/mypy/pprinter.py @@ -159,10 +159,8 @@ def visit_call_expr(self, o): self.string(')') def visit_yield_from_expr(self, o): - if isinstance(o.callee, CallExpr): - self.visit_call_expr(o.callee) - elif isinstance(o.callee, NameExpr): - self.visit_name_expr(o.callee) + if o.expr: + o.expr.accept(self) def visit_member_expr(self, o): self.node(o.expr) diff --git a/mypy/semanal.py b/mypy/semanal.py index b47c5c985c8e..75dae5045dba 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -945,6 +945,9 @@ def visit_decorator(self, dec: Decorator) -> None: removed.append(i) dec.func.is_abstract = True self.check_decorated_function_is_method('abstractmethod', dec) + elif refers_to_fullname(d, 'asyncio.tasks.coroutine'): + removed.append(i) + dec.func.is_coroutine = True elif refers_to_fullname(d, 'builtins.staticmethod'): removed.append(i) dec.func.is_static = True @@ -1001,7 +1004,6 @@ def visit_yield_stmt(self, s: YieldStmt) -> None: def visit_yield_from_stmt(self, s: YieldFromStmt) -> None: if not self.is_func_scope(): self.fail("'yield from' outside function", s) - #Check coroutine?? if s.expr: s.expr.accept(self) @@ -1141,11 +1143,11 @@ def visit_dict_expr(self, expr: DictExpr) -> None: def visit_paren_expr(self, expr: ParenExpr) -> None: expr.expr.accept(self) - def visit_yield_from_expr(self, expr: YieldFromExpr) -> None: + def visit_yield_from_expr(self, e: YieldFromExpr) -> None: if not self.is_func_scope(): # not sure self.fail("'yield from' outside function", s) - if expr.callee: - expr.callee.accept(self) + if e.expr: + e.expr.accept(self) def visit_call_expr(self, expr: CallExpr) -> None: """Analyze a call expression. diff --git a/mypy/stats.py b/mypy/stats.py index 1889de1ed4aa..af8667043fc1 100644 --- a/mypy/stats.py +++ b/mypy/stats.py @@ -110,10 +110,8 @@ def visit_name_expr(self, o: NameExpr) -> None: super().visit_name_expr(o) def visit_yield_from_expr(self, o: YieldFromExpr) -> None: - if isinstance(o.callee, CallExpr): - self.visit_call_expr(o.callee) - elif isinstance(o.callee, NameExpr): - self.visit_name_expr(o.callee) + if o.expr: + o.expr.accept(self) def visit_call_expr(self, o: CallExpr) -> None: self.process_node(o) diff --git a/mypy/strconv.py b/mypy/strconv.py index 714c8255bc40..e08eca598b3f 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -325,10 +325,10 @@ def visit_member_expr(self, o): o.is_def)], o) def visit_yield_from_expr(self, o): - if isinstance(o.callee, mypy.nodes.CallExpr): - return self.dump([self.visit_call_expr(o.callee)], o) - elif isinstance(o.callee, mypy.nodes.NameExpr): - return self.dump([self.visit_name_expr(o.callee)], o) + if o.expr: + return self.dump([o.expr.accept(self)], o) + else: + return self.dump([], o) def visit_call_expr(self, o): if o.analyzed: diff --git a/mypy/test/data/check-statements.test b/mypy/test/data/check-statements.test index 51233af0c935..fdc7d2908fa4 100644 --- a/mypy/test/data/check-statements.test +++ b/mypy/test/data/check-statements.test @@ -36,9 +36,9 @@ main, line 3: Incompatible return value type: expected __main__.B, got __main__. [case testReturnWithoutAValue] import typing def f() -> 'A': - return + return def g() -> None: - return + return class A: pass [out] @@ -440,7 +440,7 @@ else: [case testExceptWithoutType] import typing -try: +try: -None # E: Unsupported operand type for unary - (None) except: ~None # E: Unsupported operand type for ~ (None) @@ -515,7 +515,7 @@ class E(BaseException): def __init__(self) -> None: pass @overload def __init__(self, x) -> None: pass -try: +try: pass except E as e: e = E() @@ -596,7 +596,7 @@ def f() -> 'Iterator[List[int]]': [builtins fixtures/for.py] [out] main: In function "f": - + [case testYieldAndReturnWithoutValue] from typing import Iterator def f() -> Iterator[int]: @@ -611,6 +611,88 @@ def f() -> Iterator[None]: [builtins fixtures/for.py] +-- Yield from statement +-- -------------------- + +-- Iterables +-- ---------- + +[case testSimpleYFIter] +from typing import Iterator +def g() -> Iterator[str]: + yield '42' +def h() -> Iterator[int]: + yield 42 +def f() -> Iterator[str]: + yield from g() + yield from h() # E: Incompatible types in "yield from" (actual type "int", expected type "str") +[out] +main: In function "f": + +[case testYFAppliedToAny] +from typing import Any +def g() -> Any: + yield object() +def f() -> Any: + yield from g() +[out] + +[case testYFInFunctionReturningFunction] +from typing import Iterator, Function +def g() -> Iterator[int]: + yield 42 +def f() -> Function[[], None]: + yield from g() # E: Iterable function return type expected for "yield from" +[out] +main: In function "f": + +[case testGoodYFNotIterableReturnType] +from typing import Iterator +def g() -> Iterator[int]: + yield 42 +def f() -> int: + yield from g() # E: Iterable function return type expected for "yield from" +[out] +main: In function "f": + +[case testYFNotAppliedIter] +from typing import Iterator +def g() -> int: + return 42 +def f() -> Iterator[int]: + yield from g() # E: "yield from" can't be applied to "int" +[out] +main: In function "f": + +[case testYFCheckIncompatibleTypesTwoIterables] +from typing import List, Iterator +def g() -> Iterator[List[int]]: + yield [2, 3, 4] +def f() -> Iterator[List[int]]: + yield from g() + yield from [1, 2, 3] # E: Incompatible types in "yield from" (actual type "int", expected type List[int]) +[builtins fixtures/for.py] +[out] +main: In function "f": + +[case testYFNotAppliedToNothing] +def h(): + yield from # E: Parse error before end of line +[out] +main: In function "h": + +[case testYFAndYieldTogether] +from typing import Iterator +def f() -> Iterator[str]: + yield "g1 ham" + yield from g() + yield "g1 eggs" +def g() -> Iterator[str]: + yield "g2 spam" + yield "g2 more spam" +[out] + + -- With statement -- -------------- diff --git a/mypy/transform.py b/mypy/transform.py index 06cf170cce05..11058683f3ba 100644 --- a/mypy/transform.py +++ b/mypy/transform.py @@ -220,10 +220,8 @@ def get_member_reference_suffix(self, name: str, info: TypeInfo) -> str: return '' def visit_yield_from_expr(self, e: YieldFromExpr) -> None: - if isinstance(e.callee, CallExpr): - self.visit_call_expr(e.callee) - elif isinstance(e.callee, NameExpr): - self.visit_name_expr(e.callee) + if e.expr: + e.expr.accept(self) def visit_call_expr(self, e: CallExpr) -> None: if e.analyzed: diff --git a/mypy/traverser.py b/mypy/traverser.py index f9c86389ce98..c3b338bd8eb0 100644 --- a/mypy/traverser.py +++ b/mypy/traverser.py @@ -10,7 +10,7 @@ TryStmt, WithStmt, ParenExpr, MemberExpr, OpExpr, SliceExpr, CastExpr, UnaryExpr, ListExpr, TupleExpr, DictExpr, SetExpr, IndexExpr, GeneratorExpr, ListComprehension, ConditionalExpr, TypeApplication, - FuncExpr, OverloadedFuncDef, YieldFromStmt, YieldFromExpr, NameExpr + FuncExpr, OverloadedFuncDef, YieldFromStmt, YieldFromExpr ) @@ -151,11 +151,7 @@ def visit_member_expr(self, o: MemberExpr) -> T: o.expr.accept(self) def visit_yield_from_expr(self, o: YieldFromExpr) -> T: - if isinstance(o.callee, CallExpr): - self.visit_call_expr(o.callee) - elif isinstance(o.callee, NameExpr): - self.visit_name_expr(o.callee) - + o.expr.accept(self) def visit_call_expr(self, o: CallExpr) -> T: for a in o.args: diff --git a/mypy/treetransform.py b/mypy/treetransform.py index 0a151754f119..4fde312b4936 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -307,7 +307,7 @@ def copy_ref(self, new: RefExpr, original: RefExpr) -> None: new.is_def = original.is_def def visit_yield_from_expr(self, node: YieldFromExpr) -> Node: - return YieldFromExpr(self.node(node.callee)) + return YieldFromExpr(self.node(node.expr)) def visit_call_expr(self, node: CallExpr) -> Node: return CallExpr(self.node(node.callee), diff --git a/stubs/3.4/asyncio/__init__.py b/stubs/3.4/asyncio/__init__.py index cafecf5155ce..98d3092224a9 100644 --- a/stubs/3.4/asyncio/__init__.py +++ b/stubs/3.4/asyncio/__init__.py @@ -1,7 +1,9 @@ """The asyncio package, tracking PEP 3156.""" -from asyncio.futures import * -from asyncio.tasks import * -from asyncio.events import * +from asyncio.futures import Future +from asyncio.tasks import (coroutine, sleep, Task, FIRST_COMPLETED, + FIRST_EXCEPTION, ALL_COMPLETED, wait, wait_for) +from asyncio.events import (AbstractEventLoopPolicy, AbstractEventLoop, + Handle, get_event_loop) __all__ = (futures.__all__, tasks.__all__, diff --git a/stubs/3.4/asyncio/events.py b/stubs/3.4/asyncio/events.py index cada2a42da62..48b7ce2862d5 100644 --- a/stubs/3.4/asyncio/events.py +++ b/stubs/3.4/asyncio/events.py @@ -1,21 +1,24 @@ -from typing import Any, typevar, List, Function, Tuple, Union, Dict +from typing import Any, typevar, List, Function, Tuple, Union, Dict, Undefined from abc import ABCMeta, abstractmethod from asyncio.futures import Future -import socket, subprocess -# __all__ = ['AbstractEventLoopPolicy', -# 'AbstractEventLoop', 'AbstractServer', -# 'Handle', 'TimerHandle', +# __all__ = ['AbstractServer', +# 'TimerHandle', # 'get_event_loop_policy', 'set_event_loop_policy', -# 'get_event_loop', 'set_event_loop', 'new_event_loop', +# 'set_event_loop', 'new_event_loop', # 'get_child_watcher', 'set_child_watcher', # ] -__all__ = ['AbstractEventLoop', 'Handle', 'get_event_loop'] +__all__ = ['AbstractEventLoopPolicy', 'AbstractEventLoop', 'Handle', 'get_event_loop'] T = typevar('T') +PIPE = Undefined(Any) # from subprocess.PIPE + +AF_UNSPEC = 0 # from socket +AI_PASSIVE = 0 + class Handle: __slots__ = [] # type: List[str] _cancelled = False @@ -69,7 +72,7 @@ def create_connection(self, protocol_factory: Any, host: str=None, port: int=Non # return (Transport, Protocol) @abstractmethod def create_server(self, protocol_factory: Any, host: str=None, port: int=None, *, - family: int=socket.AF_UNSPEC, flags: int=socket.AI_PASSIVE, + family: int=AF_UNSPEC, flags: int=AI_PASSIVE, sock: Any=None, backlog: int=100, ssl: Any=None, reuse_address: Any=None) -> Any: pass # ?? check Any # return Server @@ -100,14 +103,14 @@ def connect_write_pipe(self, protocol_factory: Any, pipe: Any) -> tuple: pass #?? check Any # return (Transport, Protocol) @abstractmethod - def subprocess_shell(self, protocol_factory: Any, cmd: Union[bytes,str], *, stdin: Any=subprocess.PIPE, - stdout: Any=subprocess.PIPE, stderr: Any=subprocess.PIPE, + def subprocess_shell(self, protocol_factory: Any, cmd: Union[bytes,str], *, stdin: Any=PIPE, + stdout: Any=PIPE, stderr: Any=PIPE, **kwargs: Dict[str, Any]) -> tuple: pass #?? check Any # return (Transport, Protocol) @abstractmethod - def subprocess_exec(self, protocol_factory: Any, *args: List[Any], stdin: Any=subprocess.PIPE, - stdout: Any=subprocess.PIPE, stderr: Any=subprocess.PIPE, + def subprocess_exec(self, protocol_factory: Any, *args: List[Any], stdin: Any=PIPE, + stdout: Any=PIPE, stderr: Any=PIPE, **kwargs: Dict[str, Any]) -> tuple: pass #?? check Any # return (Transport, Protocol) @@ -146,5 +149,24 @@ def get_debug(self) -> bool: pass @abstractmethod def set_debug(self, enabled: bool) -> None: pass +class AbstractEventLoopPolicy(metaclass=ABCMeta): + @abstractmethod + def get_event_loop(self) -> AbstractEventLoop: pass + @abstractmethod + def set_event_loop(self, loop: AbstractEventLoop): pass + @abstractmethod + def new_event_loop(self) -> Any: pass # return selector_events.BaseSelectorEventLoop + # Child processes handling (Unix only). + @abstractmethod + def get_child_watcher(self) -> Any: pass # return unix_events.AbstractChildWatcher + @abstractmethod + def set_child_watcher(self, watcher: Any) -> None: pass # gen unix_events.AbstractChildWatcher + +class BaseDefaultEventLoopPolicy(AbstractEventLoopPolicy): + def __init__(self) -> None: pass + def get_event_loop(self) -> AbstractEventLoop: pass + def set_event_loop(self, loop: AbstractEventLoop): pass + def new_event_loop(self) -> Any: pass # Same return than AbstractEventLoop + def get_event_loop() -> AbstractEventLoop: pass \ No newline at end of file diff --git a/stubs/3.4/asyncio/examples/README.md b/stubs/3.4/asyncio/examples/README.md new file mode 100644 index 000000000000..6bbb1722fc87 --- /dev/null +++ b/stubs/3.4/asyncio/examples/README.md @@ -0,0 +1,9 @@ +Examples with Futures and asyncio +================================= + +There are two types of files: + +- example\_\*.py: +That ones have **good** examples, with type types well writed and ewerything working. +- example\_error\_\*.py: +That ones are **errors**, they are type-check errors of all kind. diff --git a/stubs/3.4/asyncio/examples/example_1.py b/stubs/3.4/asyncio/examples/example_1.py new file mode 100644 index 000000000000..eeb8ec80cd6f --- /dev/null +++ b/stubs/3.4/asyncio/examples/example_1.py @@ -0,0 +1,20 @@ +from typing import Any +import asyncio +from asyncio import Future + +@asyncio.coroutine +def greet_every_two_seconds() -> 'Future[None]': + """ + That function won't return nothing, but can be applied to + yield from or sended to the main_loop (run_until_complete in this case) + for that reason, the type is Future[None] + """ + while True: + print('Hello World') + yield from asyncio.sleep(2) + +loop = asyncio.get_event_loop() +try: + loop.run_until_complete(greet_every_two_seconds()) +finally: + loop.close() diff --git a/stubs/3.4/asyncio/examples/example_2.py b/stubs/3.4/asyncio/examples/example_2.py new file mode 100644 index 000000000000..5ddb636b9c4c --- /dev/null +++ b/stubs/3.4/asyncio/examples/example_2.py @@ -0,0 +1,25 @@ +import asyncio +from asyncio import Future + +@asyncio.coroutine +def compute(x: int, y: int) -> 'Future[int]': + """ + That function will return a int, but can be "yielded from", so + the type is Future[int] + The return type (int) will be wrapped into a Future. + """ + print("Compute %s + %s ..." % (x, y)) + yield from asyncio.sleep(1.0) + return x + y # Here the int is wrapped in Future[int] + +@asyncio.coroutine +def print_sum(x: int, y: int) -> 'Future[None]': + """ + Don't return nothing, but can be "yielded from", so is a Future. + """ + result = yield from compute(x, y) # The type of result will be int (is extracted from Future[int] + print("%s + %s = %s" % (x, y, result)) + +loop = asyncio.get_event_loop() +loop.run_until_complete(print_sum(1, 2)) +loop.close() diff --git a/stubs/3.4/asyncio/examples/example_3.py b/stubs/3.4/asyncio/examples/example_3.py new file mode 100644 index 000000000000..7bdb1d52a2d2 --- /dev/null +++ b/stubs/3.4/asyncio/examples/example_3.py @@ -0,0 +1,20 @@ +""" +Simple example about the Future instance. +At Future[str] is declared out and passed to the function. +Inside the function the result is setted to a str. +""" + +import asyncio +from asyncio import Future + +@asyncio.coroutine +def slow_operation(future: 'Future[str]') -> 'Future[None]': + yield from asyncio.sleep(1) + future.set_result('Future is done!') + +loop = asyncio.get_event_loop() +future = asyncio.Future() # type: Future[str] +asyncio.Task(slow_operation(future)) +loop.run_until_complete(future) +print(future.result()) +loop.close() diff --git a/stubs/3.4/asyncio/examples/example_4.py b/stubs/3.4/asyncio/examples/example_4.py new file mode 100644 index 000000000000..816fef3dc5f6 --- /dev/null +++ b/stubs/3.4/asyncio/examples/example_4.py @@ -0,0 +1,34 @@ +""" +In this example, we have a coroutine function that is wrapped in a Task. +We also have a Future[str] with a callback function. +""" +import typing +import asyncio +from asyncio import Future, AbstractEventLoop + +@asyncio.coroutine +def slow_operation(future: 'Future[str]') -> 'Future[None]': + """ + Simple coroutine (explained in examples before) + """ + yield from asyncio.sleep(1) + future.set_result('Future is done!') + +def got_result(future: 'Future[str]') -> None: + """ + This is a normal function, so it's not a Future. + This function is setted as callback to the future, + the type of the callback functions is: + Function[[Future[T]], Any] + """ + print(future.result()) + loop.stop() + +loop = asyncio.get_event_loop() # type: AbstractEventLoop +future = asyncio.Future() # type: Future[str] +asyncio.Task(slow_operation(future)) # Here create a task with the function. (The Task need a Future[T] as first argument) +future.add_done_callback(got_result) # and assignt the callback to the future +try: + loop.run_forever() +finally: + loop.close() diff --git a/stubs/3.4/asyncio/examples/example_5.py b/stubs/3.4/asyncio/examples/example_5.py new file mode 100644 index 000000000000..c9da40d10abe --- /dev/null +++ b/stubs/3.4/asyncio/examples/example_5.py @@ -0,0 +1,22 @@ +""" +Example with multiple tasks. +""" +import typing +import asyncio +from asyncio import Task, Future +@asyncio.coroutine +def factorial(name, number) -> 'Future[None]': + f = 1 + for i in range(2, number+1): + print("Task %s: Compute factorial(%s)..." % (name, i)) + yield from asyncio.sleep(1) + f *= i + print("Task %s: factorial(%s) = %s" % (name, number, f)) + +loop = asyncio.get_event_loop() +tasks = [ + asyncio.Task(factorial("A", 2)), + asyncio.Task(factorial("B", 3)), + asyncio.Task(factorial("C", 4))] +loop.run_until_complete(asyncio.wait(tasks)) +loop.close() diff --git a/stubs/3.4/asyncio/examples/example_6.py b/stubs/3.4/asyncio/examples/example_6.py new file mode 100644 index 000000000000..469df703fd7e --- /dev/null +++ b/stubs/3.4/asyncio/examples/example_6.py @@ -0,0 +1,35 @@ +""" +Example with concatenated coroutines. +""" +import typing +import asyncio +from asyncio import Future + +@asyncio.coroutine +def h4() -> 'Future[int]': + x = yield from future + return x + +@asyncio.coroutine +def h3() -> 'Future[int]': + x = yield from h4() + print("h3: %s" % x) + return x + +@asyncio.coroutine +def h2() -> 'Future[int]': + x = yield from h3() + print("h2: %s" % x) + return x + +@asyncio.coroutine +def h() -> 'Future[None]': + x = yield from h2() + print("h: %s" % x) + +loop = asyncio.get_event_loop() +future = asyncio.Future() # type: Future[int] +future.set_result(42) +loop.run_until_complete(h()) +print("Outside %s" % future.result()) +loop.close() diff --git a/stubs/3.4/asyncio/examples/example_7.py b/stubs/3.4/asyncio/examples/example_7.py new file mode 100644 index 000000000000..e664621579af --- /dev/null +++ b/stubs/3.4/asyncio/examples/example_7.py @@ -0,0 +1,36 @@ +""" +Example with concatenated Futures. +The function return type always have one more Future[]. +""" +import typing +import asyncio +from asyncio import Future + +@asyncio.coroutine +def h4() -> 'Future[Future[int]]': + yield from asyncio.sleep(1) + f = asyncio.Future() #type: Future[int] + return f + +@asyncio.coroutine +def h3() -> 'Future[Future[Future[int]]]': + x = yield from h4() + x.set_result(42) + f = asyncio.Future() #type: Future[Future[int]] + f.set_result(x) + return f + +@asyncio.coroutine +def h() -> 'Future[None]': + print("Before") + x = yield from h3() + y = yield from x + z = yield from y + print(z) + print(y) + print(x) + +loop = asyncio.get_event_loop() +loop.run_until_complete(h()) +# loop.run_forever() +loop.close() diff --git a/stubs/3.4/asyncio/examples/example_8.py b/stubs/3.4/asyncio/examples/example_8.py new file mode 100644 index 000000000000..ad677c34babd --- /dev/null +++ b/stubs/3.4/asyncio/examples/example_8.py @@ -0,0 +1,25 @@ +""" +Example with a Future that have an own class +as type (Future[A]) +""" +import typing +import asyncio +from asyncio import Future + +class A: + def __init__(self, x: int) -> None: + self.x = x + + +@asyncio.coroutine +def h() -> 'Future[None]': + x = yield from future + print("h: %s" % x.x) + + +loop = asyncio.get_event_loop() +future = asyncio.Future() # type: Future[A] +future.set_result(A(42)) +loop.run_until_complete(h()) +print("Outside %s" % future.result().x) +loop.close() diff --git a/stubs/3.4/asyncio/examples/example_error_1.py b/stubs/3.4/asyncio/examples/example_error_1.py new file mode 100644 index 000000000000..a5c298156e9b --- /dev/null +++ b/stubs/3.4/asyncio/examples/example_error_1.py @@ -0,0 +1,28 @@ +from typing import Any +import asyncio +from asyncio import Future + +@asyncio.coroutine +def greet() -> 'Future[None]': + """ + The function don't return nothing, but is a coroutine, so the + type is Future[None]. + """ + yield from asyncio.sleep(2) + print('Hello World') + +@asyncio.coroutine +def test() -> 'Future[None]': + """ + The type of greet() is Future[None], so, we can do "yield from greet()" + but we can't do "x = yield from greet()", because the function don't return nothing, + we can't assign to a variable. + """ + yield from greet() + x = yield from greet() # E: Function does not return a value + +loop = asyncio.get_event_loop() +try: + loop.run_until_complete(test()) +finally: + loop.close() diff --git a/stubs/3.4/asyncio/examples/example_error_2.py b/stubs/3.4/asyncio/examples/example_error_2.py new file mode 100644 index 000000000000..d94d73007c7c --- /dev/null +++ b/stubs/3.4/asyncio/examples/example_error_2.py @@ -0,0 +1,30 @@ +""" +Simple error about return types. +The function return type is Future[int] +we are trying to return a str (that is wrapped in a Future[str]) +and the type-check fail. +""" +import asyncio +from asyncio import Future + +@asyncio.coroutine +def compute(x: int, y: int) -> 'Future[int]': + """ + This function will try to return a str, will be wrapped in a Future[str] and + will fail the type check with Future[int] + """ + print("Compute %s + %s ..." % (x, y)) + yield from asyncio.sleep(1.0) + return str(x + y) # E: Incompatible return value type: expected asyncio.futures.Future[builtins.int], got asyncio.futures.Future[builtins.str] + +@asyncio.coroutine +def print_sum(x: int, y: int) -> 'Future[None]': + """ + Don't return nothing, but is a coroutine, so is a Future. + """ + result = yield from compute(x, y) + print("%s + %s = %s" % (x, y, result)) + +loop = asyncio.get_event_loop() +loop.run_until_complete(print_sum(1, 2)) +loop.close() diff --git a/stubs/3.4/asyncio/examples/example_error_3.py b/stubs/3.4/asyncio/examples/example_error_3.py new file mode 100644 index 000000000000..2284a555e059 --- /dev/null +++ b/stubs/3.4/asyncio/examples/example_error_3.py @@ -0,0 +1,57 @@ +""" +Errors about futures. +slow_operation() is the only function that will work. +The other three have errors. +""" + +import asyncio +from asyncio import Future + +@asyncio.coroutine +def slow_operation(future: 'Future[str]') -> 'Future[None]': + """ + This function is OK. + """ + yield from asyncio.sleep(1) + future.set_result('42') + +@asyncio.coroutine +def slow_operation_2(future: 'Future[str]') -> 'Future[None]': + """ + This function fail trying to set an int as result. + """ + yield from asyncio.sleep(1) + future.set_result(42) #Try to set an int as result to a Future[str] + +@asyncio.coroutine +def slow_operation_3(future: 'Future[int]') -> 'Future[None]': + """ + This function fail because try to get a Future[int] and a Future[str] + is given. + """ + yield from asyncio.sleep(1) + future.set_result(42) + + +@asyncio.coroutine +def slow_operation_4(future: 'Future[int]') -> 'Future[None]': + """ + This function fail because try to get a Future[int] and a Future[str] + is given. + This function fail trying to set an str as result. + """ + yield from asyncio.sleep(1) + future.set_result('42') #Try to set an str as result to a Future[int] + +loop = asyncio.get_event_loop() +future = asyncio.Future() # type: Future[str] +future2 = asyncio.Future() # type: Future[str] +future3 = asyncio.Future() # type: Future[str] +future4 = asyncio.Future() # type: Future[str] +asyncio.Task(slow_operation(future)) +asyncio.Task(slow_operation_2(future2)) +asyncio.Task(slow_operation_3(future3)) +asyncio.Task(slow_operation_4(future4)) +loop.run_until_complete(future) +print(future.result()) +loop.close() diff --git a/stubs/3.4/asyncio/examples/example_error_4.py b/stubs/3.4/asyncio/examples/example_error_4.py new file mode 100644 index 000000000000..71b4df32d3bf --- /dev/null +++ b/stubs/3.4/asyncio/examples/example_error_4.py @@ -0,0 +1,33 @@ +""" +In this example, we have a coroutine function that is wrapped in a Task. +We also have a Future[str] with a callback function. +""" +import typing +import asyncio +from asyncio import Future, AbstractEventLoop + +@asyncio.coroutine +def slow_operation(future: 'Future[str]') -> 'Future[None]': + """ + Simple coroutine (explained in examples before) + """ + yield from asyncio.sleep(1) + future.set_result('Future is done!') + +def got_result(future: 'Future[int]') -> None: + """ + We say that we are expecting a Future[int] + but is assigned to a Future[str], so fails in the add_done_callback() + """ + print(future.result()) + loop.stop() + +loop = asyncio.get_event_loop() # type: AbstractEventLoop +future = asyncio.Future() # type: Future[str] +asyncio.Task(slow_operation(future)) # Here create a task with the function. (The Task need a Future[T] as first argument) +future.add_done_callback(got_result) # E: Argument 1 to "add_done_callback" of "Future" has incompatible type Function[[Future[int]] -> None]; expected Function[[Future[str]] -> "Any"] + +try: + loop.run_forever() +finally: + loop.close() diff --git a/stubs/3.4/asyncio/examples/example_error_7.py b/stubs/3.4/asyncio/examples/example_error_7.py new file mode 100644 index 000000000000..ae5de64ed58f --- /dev/null +++ b/stubs/3.4/asyncio/examples/example_error_7.py @@ -0,0 +1,40 @@ +""" +Simple error in concatenated Futures. +In h3() one more Future[] is given in the function type, so the type check fails. +In h2() one less Future[] is given in the function type. +""" +import typing +import asyncio +from asyncio import Future + +@asyncio.coroutine +def h4() -> 'Future[Future[int]]': + yield from asyncio.sleep(1) + f = asyncio.Future() #type: Future[int] + return f + +@asyncio.coroutine +def h3() -> 'Future[Future[Future[Future[int]]]]': + x = yield from h4() + x.set_result(42) + f = asyncio.Future() #type: Future[Future[int]] + f.set_result(x) + return f + """ + Incompatible return value type: expected asyncio.futures.Future[asyncio.futures.Future[asyncio.futures.Future[asyncio.futures.Future[builtins.int]]]], got asyncio.futures.Future[asyncio.futures.Future[asyncio.futures.Future[builtins.int]]] + """ + +@asyncio.coroutine +def h() -> 'Future[None]': + print("Before") + x = yield from h3() + y = yield from x + z = yield from y + print(z) + print(y) + print(x) + +loop = asyncio.get_event_loop() +loop.run_until_complete(h()) +# loop.run_forever() +loop.close() diff --git a/stubs/3.4/asyncio/examples/example_error_8.py b/stubs/3.4/asyncio/examples/example_error_8.py new file mode 100644 index 000000000000..361effb347fa --- /dev/null +++ b/stubs/3.4/asyncio/examples/example_error_8.py @@ -0,0 +1,31 @@ +""" +An error because we try to say that we get a 'B' type in the yield from future, +when we are getting an 'A' type +""" +import typing +import asyncio +from asyncio import Future + + +class A: + def __init__(self, x: int) -> None: + self.x = x + + +class B: + def __init__(self, x: int) -> None: + self.x = x + + +@asyncio.coroutine +def h() -> 'Future[None]': + x = yield from future # type: B # E: Incompatible types in assignment (expression has type "A", variable has type "B") + print("h: %s" % x.x) + + +loop = asyncio.get_event_loop() +future = asyncio.Future() # type: Future[A] +future.set_result(A(42)) +loop.run_until_complete(h()) +print("Outside %s" % future.result().x) +loop.close() diff --git a/stubs/3.4/asyncio/futures.py b/stubs/3.4/asyncio/futures.py index 8b78b5f63a1b..b818102eb1b9 100644 --- a/stubs/3.4/asyncio/futures.py +++ b/stubs/3.4/asyncio/futures.py @@ -1,8 +1,8 @@ -from typing import Any, Function, typevar, Generic, List +from typing import Any, Function, typevar, List, Generic, Iterable, Iterator from asyncio.events import AbstractEventLoop # __all__ = ['CancelledError', 'TimeoutError', # 'InvalidStateError', -# 'Future', 'wrap_future', +# 'wrap_future', # ] __all__ = ['Future'] @@ -17,13 +17,13 @@ def activate(self) -> None: pass def clear(self) -> None: pass def __del__(self) -> None: pass -class Future(Generic[T]): +class Future(Iterator[T], Generic[T]): # (Iterable[T], Generic[T]) _state = '' _exception = Any #Exception _blocking = False _log_traceback = False _tb_logger = _TracebackLogger - def __init__(self, loop: AbstractEventLoop) -> None: pass + def __init__(self, *, loop: AbstractEventLoop = None) -> None: pass def __repr__(self) -> str: pass def __del__(self) -> None: pass def cancel(self) -> bool: pass @@ -32,9 +32,10 @@ def cancelled(self) -> bool: pass def done(self) -> bool: pass def result(self) -> T: pass def exception(self) -> Any: pass - def add_done_callback(self, fn: Function[[],Any]) -> None: pass - def remove_done_callback(self, fn: Function[[], Any]) -> int: pass + def add_done_callback(self, fn: Function[[Future[T]],Any]) -> None: pass + def remove_done_callback(self, fn: Function[[Future[T]], Any]) -> int: pass def set_result(self, result: T) -> None: pass def set_exception(self, exception: Any) -> None: pass def _copy_state(self, other: Any) -> None: pass - def __iter__(self) -> Any: pass \ No newline at end of file + def __iter__(self) -> 'Iterator[T]': pass + def __next__(self) -> 'T': pass diff --git a/stubs/3.4/asyncio/tasks.py b/stubs/3.4/asyncio/tasks.py index 6920a0229ccb..94db71a814c5 100644 --- a/stubs/3.4/asyncio/tasks.py +++ b/stubs/3.4/asyncio/tasks.py @@ -1,14 +1,40 @@ -from typing import Any, typevar +from typing import Any, Iterable, typevar, Set, Dict, List, TextIO, Union, Tuple, Generic, Function from asyncio.events import AbstractEventLoop -# __all__ = ['coroutine', 'Task', -# 'iscoroutinefunction', 'iscoroutine', -# 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', -# 'wait', 'wait_for', 'as_completed', 'sleep', 'async', +from asyncio.futures import Future +# __all__ = ['iscoroutinefunction', 'iscoroutine', +# 'as_completed', 'async', # 'gather', 'shield', # ] -__all__ = ['coroutine', 'sleep'] +__all__ = ['coroutine', 'Task', 'sleep', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'wait_for'] +FIRST_EXCEPTION = 'FIRST_EXCEPTION' +FIRST_COMPLETED = 'FIRST_COMPLETED' +ALL_COMPLETED = 'ALL_COMPLETED' T = typevar('T') -def coroutine(f: Any) -> Any: pass -def sleep(delay: float, result: T=None, loop: AbstractEventLoop=None) -> T: pass +def coroutine(f: Any) -> Any: pass # Here comes and go a function +def sleep(delay: float, result: T=None, loop: AbstractEventLoop=None) -> Future[T]: pass +def wait(fs: List[Any], *, loop: AbstractEventLoop=None, + timeout: float=None, return_when: str=ALL_COMPLETED) -> Future[Tuple[Set[Future[T]], Set[Future[T]]]]: pass +def wait_for(fut: Future[T], timeout: float, *, loop: AbstractEventLoop=None) -> Future[T]: pass +# def wait(fs: Union[List[Iterable], List[Future[T]]], *, loop: AbstractEventLoop=None, +# timeout: int=None, return_when: str=ALL_COMPLETED) -> Future[Tuple[Set[Future[T]], Set[Future[T]]]]: pass + +class Task(Future[T], Generic[T]): + _all_tasks = None # type: Set[Task] + _current_tasks = {} # type: Dict[AbstractEventLoop, Task] + @classmethod + def current_task(cls, loop: AbstractEventLoop=None) -> Task: pass + @classmethod + def all_tasks(cls, loop: AbstractEventLoop=None) -> Set[Task]: pass + # def __init__(self, coro: Union[Iterable[T], Future[T]], *, loop: AbstractEventLoop=None) -> None: pass + def __init__(self, coro: Future[T], *, loop: AbstractEventLoop=None) -> None: pass + def __repr__(self) -> str: pass + def get_stack(self, *, limit: int=None) -> List[Any]: pass # return List[stackframe] + def print_stack(self, *, limit: int=None, file: TextIO=None) -> None: pass + def cancel(self) -> bool: pass + def _step(self, value: Any=None, exc: Exception=None) -> None: pass + def _wakeup(self, future: Future[Any]) -> None: pass + From 7928d772becf73ca67b5d339e24b18496423b6fe Mon Sep 17 00:00:00 2001 From: Rock Neurotiko Date: Sat, 9 Aug 2014 16:36:32 +0200 Subject: [PATCH 06/12] Add fail if the difference of Futures are 0 --- mypy/checker.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 843f30d14b03..210526ed2540 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1209,7 +1209,7 @@ def visit_return_stmt(self, s: ReturnStmt) -> Type: else: if self.function_stack[-1].is_coroutine: # Something similar will be needed to mix return and yield #If the function is a coroutine, wrap the return type in a Future - typ = self.wrap_generic_type(typ, self.return_types[-1], 'asyncio.futures.Future') + typ = self.wrap_generic_type(typ, self.return_types[-1], 'asyncio.futures.Future', s) self.check_subtype( typ, self.return_types[-1], s, messages.INCOMPATIBLE_RETURN_VALUE_TYPE @@ -1222,10 +1222,14 @@ def visit_return_stmt(self, s: ReturnStmt) -> Type: not self.is_dynamic_function()): self.fail(messages.RETURN_VALUE_EXPECTED, s) - def wrap_generic_type(self, typ: Type, rtyp: Type, check_type: str) -> Type: + def wrap_generic_type(self, typ: Type, rtyp: Type, check_type: str, context: Context) -> Type: n_diff = self.count_concatenated_types(rtyp, check_type) - self.count_concatenated_types(typ, check_type) if n_diff >= 1: return self.named_generic_type(check_type, [typ]) + elif n_diff == 0: + self.fail(messages.INCOMPATIBLE_RETURN_VALUE_TYPE + + ": expected {}, got {}".format(rtyp, typ), context) + return typ return typ def count_concatenated_types(self, typ: Type, check_type: str) -> int: From 6e11a96fc357a56b8c7b8c8f6d8611041f88bb84 Mon Sep 17 00:00:00 2001 From: Rock Neurotiko Date: Mon, 15 Sep 2014 11:15:03 +0200 Subject: [PATCH 07/12] remove files created --- mypy/checker.py.orig | 2123 ----------------------------- mypy/checkexpr.py.orig | 1453 -------------------- mypy/messages.py.orig | 694 ---------- mypy/nodes.py.orig | 1845 ------------------------- mypy/output.py.orig | 629 --------- mypy/parse.py.orig | 1845 ------------------------- mypy/pprinter.py.orig | 342 ----- mypy/semanal.py.orig | 1870 ------------------------- mypy/stats.py.orig | 358 ----- mypy/strconv.py.orig | 444 ------ mypy/transform.py.orig | 449 ------ mypy/traverser.py.orig | 243 ---- mypy/treetransform.py.orig | 499 ------- mypy/visitor.py.orig | 225 --- stubs/3.4/asyncio/futures.py.orig | 54 - stubs/3.4/asyncio/tasks.py.orig | 46 - 16 files changed, 13119 deletions(-) delete mode 100644 mypy/checker.py.orig delete mode 100644 mypy/checkexpr.py.orig delete mode 100644 mypy/messages.py.orig delete mode 100644 mypy/nodes.py.orig delete mode 100644 mypy/output.py.orig delete mode 100755 mypy/parse.py.orig delete mode 100644 mypy/pprinter.py.orig delete mode 100644 mypy/semanal.py.orig delete mode 100644 mypy/stats.py.orig delete mode 100644 mypy/strconv.py.orig delete mode 100644 mypy/transform.py.orig delete mode 100644 mypy/traverser.py.orig delete mode 100644 mypy/treetransform.py.orig delete mode 100644 mypy/visitor.py.orig delete mode 100644 stubs/3.4/asyncio/futures.py.orig delete mode 100644 stubs/3.4/asyncio/tasks.py.orig diff --git a/mypy/checker.py.orig b/mypy/checker.py.orig deleted file mode 100644 index 0aff6208ffca..000000000000 --- a/mypy/checker.py.orig +++ /dev/null @@ -1,2123 +0,0 @@ -"""Mypy type checker.""" - -import itertools - -<<<<<<< HEAD -from typing import Undefined, Any, Dict, Set, List, cast, overload, Tuple, typevar -======= -from typing import Undefined, Any, Dict, Set, List, cast, overload, Tuple, Function, typevar ->>>>>>> master - -from mypy.errors import Errors -from mypy.nodes import ( - SymbolTable, Node, MypyFile, VarDef, LDEF, Var, - OverloadedFuncDef, FuncDef, FuncItem, FuncBase, TypeInfo, - ClassDef, GDEF, Block, AssignmentStmt, NameExpr, MemberExpr, IndexExpr, - TupleExpr, ListExpr, ParenExpr, ExpressionStmt, ReturnStmt, IfStmt, - WhileStmt, OperatorAssignmentStmt, YieldStmt, WithStmt, AssertStmt, - RaiseStmt, TryStmt, ForStmt, DelStmt, CallExpr, IntExpr, StrExpr, - BytesExpr, UnicodeExpr, FloatExpr, OpExpr, UnaryExpr, CastExpr, SuperExpr, - TypeApplication, DictExpr, SliceExpr, FuncExpr, TempNode, SymbolTableNode, - Context, ListComprehension, ConditionalExpr, GeneratorExpr, -<<<<<<< HEAD - Decorator, SetExpr, TypeVarExpr, UndefinedExpr, PrintStmt, - LITERAL_TYPE, BreakStmt, ContinueStmt, YieldFromExpr, YieldFromStmt -======= - Decorator, SetExpr, PassStmt, TypeVarExpr, UndefinedExpr, PrintStmt, - LITERAL_TYPE, BreakStmt, ContinueStmt, ComparisonExpr ->>>>>>> master -) -from mypy.nodes import function_type, method_type -from mypy import nodes -from mypy.types import ( - Type, AnyType, Callable, Void, FunctionLike, Overloaded, TupleType, - Instance, NoneTyp, UnboundType, ErrorType, BasicTypes, - strip_type, UnionType -) -from mypy.sametypes import is_same_type -from mypy.messages import MessageBuilder -import mypy.checkexpr -from mypy import messages -from mypy.subtypes import ( - is_subtype, is_equivalent, map_instance_to_supertype, is_proper_subtype, - is_more_precise, restrict_subtype_away -) -from mypy.semanal import self_type, set_callable_name, refers_to_fullname -from mypy.erasetype import erase_typevars -from mypy.expandtype import expand_type_by_instance, expand_type -from mypy.visitor import NodeVisitor -from mypy.join import join_simple, join_types -from mypy.treetransform import TransformVisitor -from mypy.meet import meet_simple, nearest_builtin_ancestor, is_overlapping_types - - -# Kinds of isinstance checks. -ISINSTANCE_OVERLAPPING = 0 -ISINSTANCE_ALWAYS_TRUE = 1 -ISINSTANCE_ALWAYS_FALSE = 2 - -T = typevar('T') - - -def min_with_None_large(x: T, y: T) -> T: - """Return min(x, y) but with a < None for all variables a that are not None""" - if x is None: - return y - return min(x, x if y is None else y) - - -class Frame(Dict[Any, Type]): - pass - - -class Key(AnyType): - pass - - -class ConditionalTypeBinder: - """Keep track of conditional types of variables.""" - - def __init__(self, basic_types_fn) -> None: - self.frames = List[Frame]() - # The first frame is special: it's the declared types of variables. - self.frames.append(Frame()) - self.dependencies = Dict[Key, Set[Key]]() # Set of other keys to invalidate if a key - # is changed - self._added_dependencies = Set[Key]() # Set of keys with dependencies added already - self.basic_types_fn = basic_types_fn - - self.frames_on_escape = Dict[int, List[Frame]]() - - self.try_frames = Set[int]() - self.loop_frames = List[int]() - - def _add_dependencies(self, key: Key, value: Key = None) -> None: - if value is None: - value = key - if value in self._added_dependencies: - return - self._added_dependencies.add(value) - if isinstance(key, tuple): - key = cast(Any, key) # XXX sad - if key != value: - self.dependencies[key] = Set[Key]() - self.dependencies.setdefault(key, Set[Key]()).add(value) - for elt in cast(Any, key): - self._add_dependencies(elt, value) - - def push_frame(self) -> Frame: - d = Frame() - self.frames.append(d) - return d - - def _push(self, key: Key, type: Type, index: int=-1) -> None: - self._add_dependencies(key) - self.frames[index][key] = type - - def _get(self, key: Key, index: int=-1) -> Type: - if index < 0: - index += len(self.frames) - for i in range(index, -1, -1): - if key in self.frames[i]: - return self.frames[i][key] - return None - - def push(self, expr: Node, type: Type) -> None: - if not expr.literal: - return - key = expr.literal_hash - self.frames[0][key] = self.get_declaration(expr) - self._push(key, type) - - def get(self, expr: Node) -> Type: - return self._get(expr.literal_hash) - - def update_from_options(self, frames: List[Frame]) -> bool: - """Update the frame to reflect that each key will be updated - as in one of the frames. Return whether any item changes.""" - - changed = False - keys = set(key for f in frames for key in f) - - for key in keys: - current_value = self._get(key) - resulting_values = [f.get(key, current_value) for f in frames] - if any(x is None for x in resulting_values): - continue - - type = resulting_values[0] - for other in resulting_values[1:]: - type = join_simple(self.frames[0][key], type, - other, self.basic_types_fn()) - if not is_same_type(type, current_value): - self._push(key, type) - changed = True - - return changed - - def update_expand(self, frame: Frame, index: int = -1) -> bool: - """Update frame to include another one, if that other one is larger than the current value. - - Return whether anything changed.""" - result = False - - for key in frame: - old_type = self._get(key, index) - if old_type is None: - continue - replacement = join_simple(self.frames[0][key], old_type, frame[key], - self.basic_types_fn()) - - if not is_same_type(replacement, old_type): - self._push(key, replacement, index) - result = True - return result - - def pop_frame(self, canskip=True, fallthrough=False) -> Tuple[bool, Frame]: - """Pop a frame. - - If canskip, then allow types to skip all the inner frame - blocks. - - If fallthrough, then allow types to escape from the inner - frame to the resulting frame. - - Return whether the newly innermost frame was modified since it - was last on top, and what it would be if the block had run to - completion. - """ - result = self.frames.pop() - - options = self.frames_on_escape.get(len(self.frames) - 1, []) - if canskip: - options.append(self.frames[-1]) - if fallthrough: - options.append(result) - - changed = self.update_from_options(options) - - return (changed, result) - - def get_declaration(self, expr: Any) -> Type: - if hasattr(expr, 'node') and isinstance(expr.node, Var): - return expr.node.type - else: - return self.frames[0].get(expr.literal_hash) - - def assign_type(self, expr: Node, type: Type) -> None: - if not expr.literal: - return - self.invalidate_dependencies(expr) - - declared_type = self.get_declaration(expr) - - if declared_type is None: - # Not sure why this happens. It seems to mainly happen in - # member initialization. - return - if not is_subtype(type, declared_type): - # Pretty sure this is only happens when there's a type error. - - # Ideally this function wouldn't be called if the - # expression has a type error, though -- do other kinds of - # errors cause this function to get called at invalid - # times? - - return - - # If x is Any and y is int, after x = y we do not infer that x is int. - # This could be changed. - - if isinstance(self.most_recent_enclosing_type(expr, type), AnyType): - pass - elif isinstance(type, AnyType): - self.push(expr, declared_type) - else: - self.push(expr, type) - - for i in self.try_frames: - # XXX This should probably not copy the entire frame, but - # just copy this variable into a single stored frame. - self.allow_jump(i) - - def invalidate_dependencies(self, expr: Node) -> None: - """Invalidate knowledge of types that include expr, but not expr itself. - - For example, when expr is foo.bar, invalidate foo.bar.baz and - foo.bar[0]. - - It is overly conservative: it invalidates globally, including - in code paths unreachable from here. - """ - for dep in self.dependencies.get(expr.literal_hash, Set[Key]()): - for f in self.frames: - if dep in f: - del f[dep] - - def most_recent_enclosing_type(self, expr: Node, type: Type) -> Type: - if isinstance(type, AnyType): - return self.get_declaration(expr) - key = expr.literal_hash - enclosers = ([self.get_declaration(expr)] + - [f[key] for f in self.frames - if key in f and is_subtype(type, f[key])]) - return enclosers[-1] - - def allow_jump(self, index: int) -> None: - new_frame = Frame() - for f in self.frames[index + 1:]: - for k in f: - new_frame[k] = f[k] - - self.frames_on_escape.setdefault(index, []).append(new_frame) - - def push_loop_frame(self): - self.loop_frames.append(len(self.frames) - 1) - - def pop_loop_frame(self): - self.loop_frames.pop() - - -def meet_frames(basic_types: BasicTypes, *frames: Frame) -> Frame: - answer = Frame() - for f in frames: - for key in f: - if key in answer: - answer[key] = meet_simple(answer[key], f[key], basic_types) - else: - answer[key] = f[key] - return answer - - -class TypeChecker(NodeVisitor[Type]): - """Mypy type checker. - - Type check mypy source files that have been semantically analysed. - """ - - # Target Python major version - pyversion = 3 - # Error message reporting - errors = Undefined(Errors) - # SymbolNode table for the whole program - symtable = Undefined(SymbolTable) - # Utility for generating messages - msg = Undefined(MessageBuilder) - # Types of type checked nodes - type_map = Undefined(Dict[Node, Type]) - - # Helper for managing conditional types - binder = Undefined(ConditionalTypeBinder) - # Helper for type checking expressions - expr_checker = Undefined('mypy.checkexpr.ExpressionChecker') - - # Stack of function return types - return_types = Undefined(List[Type]) - # Type context for type inference - type_context = Undefined(List[Type]) - # Flags; true for dynamically typed functions - dynamic_funcs = Undefined(List[bool]) - # Stack of functions being type checked - function_stack = Undefined(List[FuncItem]) - # Set to True on return/break/raise, False on blocks that can block any of them - breaking_out = False - - globals = Undefined(SymbolTable) - locals = Undefined(SymbolTable) - modules = Undefined(Dict[str, MypyFile]) - - def __init__(self, errors: Errors, modules: Dict[str, MypyFile], - pyversion: int = 3) -> None: - """Construct a type checker. - - Use errors to report type check errors. Assume symtable has been - populated by the semantic analyzer. - """ - self.expr_checker - self.errors = errors - self.modules = modules - self.pyversion = pyversion - self.msg = MessageBuilder(errors) - self.type_map = {} - self.binder = ConditionalTypeBinder(self.basic_types) - self.binder.push_frame() - self.expr_checker = mypy.checkexpr.ExpressionChecker(self, self.msg) - self.return_types = [] - self.type_context = [] - self.dynamic_funcs = [] - self.function_stack = [] - - def visit_file(self, file_node: MypyFile, path: str) -> None: - """Type check a mypy file with the given path.""" - self.errors.set_file(path) - self.globals = file_node.names - self.locals = None - - for d in file_node.defs: - self.accept(d) - - def accept(self, node: Node, type_context: Type = None) -> Type: - """Type check a node in the given type context.""" - self.type_context.append(type_context) - typ = node.accept(self) - self.type_context.pop() - self.store_type(node, typ) - if self.is_dynamic_function(): - return AnyType() - else: - return typ - - def accept_in_frame(self, node: Node, type_context: Type = None, - repeat_till_fixed: bool = False) -> Type: - """Type check a node in the given type context in a new frame of inferred types.""" - while True: - self.binder.push_frame() - answer = self.accept(node, type_context) - changed, _ = self.binder.pop_frame(True, True) - self.breaking_out = False - if not repeat_till_fixed or not changed: - break - - return answer - - # - # Definitions - # - - def visit_var_def(self, defn: VarDef) -> Type: - """Type check a variable definition. - - It can be of any kind: local, member or global. - """ - # Type check initializer. - if defn.init: - # There is an initializer. - if defn.items[0].type: - # Explicit types. - if len(defn.items) == 1: - self.check_single_assignment(defn.items[0].type, None, - defn.init, defn.init) - else: - # Multiple assignment. - lvt = List[Type]() - for v in defn.items: - lvt.append(v.type) - self.check_multi_assignment( - lvt, [None] * len(lvt), - defn.init, defn.init) - else: - init_type = self.accept(defn.init) - if defn.kind == LDEF and not defn.is_top_level: - # Infer local variable type if there is an initializer - # except if the definition is at the top level (outside a - # function). - self.infer_local_variable_type(defn.items, init_type, defn) - else: - # No initializer - if (defn.kind == LDEF and not defn.items[0].type and - not defn.is_top_level and not self.is_dynamic_function()): - self.fail(messages.NEED_ANNOTATION_FOR_VAR, defn) - - def infer_local_variable_type(self, x, y, z): - # TODO - raise RuntimeError('Not implemented') - - def visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> Type: - num_abstract = 0 - for fdef in defn.items: - self.check_func_item(fdef.func, name=fdef.func.name()) - if fdef.func.is_abstract: - num_abstract += 1 - if num_abstract not in (0, len(defn.items)): - self.fail(messages.INCONSISTENT_ABSTRACT_OVERLOAD, defn) - if defn.info: - self.check_method_override(defn) - self.check_inplace_operator_method(defn) - self.check_overlapping_overloads(defn) - - def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: - for i, item in enumerate(defn.items): - for j, item2 in enumerate(defn.items[i + 1:]): - # TODO overloads involving decorators - sig1 = function_type(item.func) - sig2 = function_type(item2.func) - if is_unsafe_overlapping_signatures(sig1, sig2): - self.msg.overloaded_signatures_overlap(i + 1, j + 2, - item.func) - - def visit_func_def(self, defn: FuncDef) -> Type: - """Type check a function definition.""" - self.check_func_item(defn, name=defn.name()) - if defn.info: - self.check_method_override(defn) - self.check_inplace_operator_method(defn) - if defn.original_def: - if not is_same_type(function_type(defn), - function_type(defn.original_def)): - self.msg.incompatible_conditional_function_def(defn) - - def check_func_item(self, defn: FuncItem, - type_override: Callable = None, - name: str = None) -> Type: - """Type check a function. - - If type_override is provided, use it as the function type. - """ - # We may be checking a function definition or an anonymous function. In - # the first case, set up another reference with the precise type. - fdef = None # type: FuncDef - if isinstance(defn, FuncDef): - fdef = defn - - self.function_stack.append(defn) - self.dynamic_funcs.append(defn.type is None and not type_override) - - if fdef: - self.errors.push_function(fdef.name()) - - typ = function_type(defn) - if type_override: - typ = type_override - if isinstance(typ, Callable): - self.check_func_def(defn, typ, name) - else: - raise RuntimeError('Not supported') - - if fdef: - self.errors.pop_function() - - self.dynamic_funcs.pop() - self.function_stack.pop() - - def check_func_def(self, defn: FuncItem, typ: Callable, name: str) -> None: - """Type check a function definition.""" - # Expand type variables with value restrictions to ordinary types. - for item, typ in self.expand_typevars(defn, typ): - old_binder = self.binder - self.binder = ConditionalTypeBinder(self.basic_types) - self.binder.push_frame() - defn.expanded.append(item) - - # We may be checking a function definition or an anonymous - # function. In the first case, set up another reference with the - # precise type. - if isinstance(item, FuncDef): - fdef = item - else: - fdef = None - - self.enter() - - if fdef: - # Check if __init__ has an invalid, non-None return type. - if (fdef.info and fdef.name() == '__init__' and - not isinstance(typ.ret_type, Void) and - not self.dynamic_funcs[-1]): - self.fail(messages.INIT_MUST_NOT_HAVE_RETURN_TYPE, - item.type) - - if name in nodes.reverse_op_method_set: - self.check_reverse_op_method(item, typ, name) - - # Push return type. - self.return_types.append(typ.ret_type) - - # Store argument types. - nargs = len(item.args) - for i in range(len(typ.arg_types)): - arg_type = typ.arg_types[i] - if typ.arg_kinds[i] == nodes.ARG_STAR: - arg_type = self.named_generic_type('builtins.list', - [arg_type]) - elif typ.arg_kinds[i] == nodes.ARG_STAR2: - arg_type = self.named_generic_type('builtins.dict', - [self.str_type(), - arg_type]) - item.args[i].type = arg_type - - # Type check initialization expressions. - for j in range(len(item.init)): - if item.init[j]: - self.accept(item.init[j]) - - # Clear out the default assignments from the binder - self.binder.pop_frame() - self.binder.push_frame() - # Type check body in a new scope. - self.accept_in_frame(item.body) - - self.return_types.pop() - - self.leave() - self.binder = old_binder - - def check_reverse_op_method(self, defn: FuncItem, typ: Callable, - method: str) -> None: - """Check a reverse operator method such as __radd__.""" - - # If the argument of a reverse operator method such as __radd__ - # does not define the corresponding non-reverse method such as __add__ - # the return type of __radd__ may not reliably represent the value of - # the corresponding operation even in a fully statically typed program. - # - # This example illustrates the issue: - # - # class A: pass - # class B: - # def __radd__(self, x: A) -> int: # Note that A does not define - # return 1 # __add__! - # class C(A): - # def __add__(self, x: Any) -> str: return 'x' - # a = Undefined(A) - # a = C() - # a + B() # Result would be 'x', even though static type seems to - # # be int! - - if method in ('__eq__', '__ne__'): - # These are defined for all objects => can't cause trouble. - return - - # With 'Any' or 'object' return type we are happy, since any possible - # return value is valid. - ret_type = typ.ret_type - if isinstance(ret_type, AnyType): - return - if isinstance(ret_type, Instance): - if ret_type.type.fullname() == 'builtins.object': - return - # Plausibly the method could have too few arguments, which would result - # in an error elsewhere. - if len(typ.arg_types) <= 2: - # TODO check self argument kind - - # Check for the issue described above. - arg_type = typ.arg_types[1] - other_method = nodes.normal_from_reverse_op[method] - fail = False - if isinstance(arg_type, Instance): - if not arg_type.type.has_readable_member(other_method): - fail = True - elif isinstance(arg_type, AnyType): - self.msg.reverse_operator_method_with_any_arg_must_return_any( - method, defn) - return - elif isinstance(arg_type, UnionType): - if not arg_type.has_readable_member(other_method): - fail = True - else: - fail = True - if fail: - self.msg.invalid_reverse_operator_signature( - method, other_method, defn) - return - - typ2 = self.expr_checker.analyse_external_member_access( - other_method, arg_type, defn) - self.check_overlapping_op_methods( - typ, method, defn.info, - typ2, other_method, cast(Instance, arg_type), - defn) - - def check_overlapping_op_methods(self, - reverse_type: Callable, - reverse_name: str, - reverse_class: TypeInfo, - forward_type: Type, - forward_name: str, - forward_base: Instance, - context: Context) -> None: - """Check for overlapping method and reverse method signatures. - - Assume reverse method has valid argument count and kinds. - """ - - # Reverse operator method that overlaps unsafely with the - # forward operator method can result in type unsafety. This is - # similar to overlapping overload variants. - # - # This example illustrates the issue: - # - # class X: pass - # class A: - # def __add__(self, x: X) -> int: - # if isinstance(x, X): - # return 1 - # return NotImplemented - # class B: - # def __radd__(self, x: A) -> str: return 'x' - # class C(X, B): pass - # b = Undefined(B) - # b = C() - # A() + b # Result is 1, even though static type seems to be str! - # - # The reason for the problem is that B and X are overlapping - # types, and the return types are different. Also, if the type - # of x in __radd__ would not be A, the methods could be - # non-overlapping. - - if isinstance(forward_type, Callable): - # TODO check argument kinds - if len(forward_type.arg_types) < 1: - # Not a valid operator method -- can't succeed anyway. - return - - # Construct normalized function signatures corresponding to the - # operator methods. The first argument is the left operand and the - # second operatnd is the right argument -- we switch the order of - # the arguments of the reverse method. - forward_tweaked = Callable([forward_base, - forward_type.arg_types[0]], - [nodes.ARG_POS] * 2, - [None] * 2, - forward_type.ret_type, - is_type_obj=False, - name=forward_type.name) - reverse_args = reverse_type.arg_types - reverse_tweaked = Callable([reverse_args[1], reverse_args[0]], - [nodes.ARG_POS] * 2, - [None] * 2, - reverse_type.ret_type, - is_type_obj=False, - name=reverse_type.name) - - if is_unsafe_overlapping_signatures(forward_tweaked, - reverse_tweaked): - self.msg.operator_method_signatures_overlap( - reverse_class.name(), reverse_name, - forward_base.type.name(), forward_name, context) - elif isinstance(forward_type, Overloaded): - for item in forward_type.items(): - self.check_overlapping_op_methods( - reverse_type, reverse_name, reverse_class, - item, forward_name, forward_base, context) - else: - # TODO what about this? - assert False, 'Forward operator method type is not Callable' - - def check_inplace_operator_method(self, defn: FuncBase) -> None: - """Check an inplace operator method such as __iadd__. - - They cannot arbitrarily overlap with __add__. - """ - method = defn.name() - if method not in nodes.inplace_operator_methods: - return - typ = method_type(defn) - cls = defn.info - other_method = '__' + method[3:] - if cls.has_readable_member(other_method): - instance = self_type(cls) - typ2 = self.expr_checker.analyse_external_member_access( - other_method, instance, defn) - fail = False - if isinstance(typ2, FunctionLike): - if not is_more_general_arg_prefix(typ, typ2): - fail = True - else: - # TODO overloads - fail = True - if fail: - self.msg.signatures_incompatible(method, other_method, defn) - - def expand_typevars(self, defn: FuncItem, - typ: Callable) -> List[Tuple[FuncItem, Callable]]: - # TODO use generator - subst = List[List[Tuple[int, Type]]]() - tvars = typ.variables or [] - tvars = tvars[:] - if defn.info: - # Class type variables - tvars += defn.info.defn.type_vars or [] - for tvar in tvars: - if tvar.values: - subst.append([(tvar.id, value) - for value in tvar.values]) - if subst: - result = List[Tuple[FuncItem, Callable]]() - for substitutions in itertools.product(*subst): - mapping = dict(substitutions) - expanded = cast(Callable, expand_type(typ, mapping)) - result.append((expand_func(defn, mapping), expanded)) - return result - else: - return [(defn, typ)] - - def check_method_override(self, defn: FuncBase) -> None: - """Check if function definition is compatible with base classes.""" - # Check against definitions in base classes. - for base in defn.info.mro[1:]: - self.check_method_or_accessor_override_for_base(defn, base) - - def check_method_or_accessor_override_for_base(self, defn: FuncBase, - base: TypeInfo) -> None: - """Check if method definition is compatible with a base class.""" - if base: - name = defn.name() - if name != '__init__': - # Check method override (__init__ is special). - self.check_method_override_for_base_with_name(defn, name, base) - if name in nodes.inplace_operator_methods: - # Figure out the name of the corresponding operator method. - method = '__' + name[3:] - # An inplace overator method such as __iadd__ might not be - # always introduced safely if a base class defined __add__. - # TODO can't come up with an example where this is - # necessary; now it's "just in case" - self.check_method_override_for_base_with_name(defn, method, - base) - - def check_method_override_for_base_with_name( - self, defn: FuncBase, name: str, base: TypeInfo) -> None: - base_attr = base.names.get(name) - if base_attr: - # The name of the method is defined in the base class. - - # Construct the type of the overriding method. - typ = method_type(defn) - # Map the overridden method type to subtype context so that - # it can be checked for compatibility. - original_type = base_attr.type - if original_type is None and isinstance(base_attr.node, - FuncDef): - original_type = function_type(cast(FuncDef, - base_attr.node)) - if isinstance(original_type, FunctionLike): - original = map_type_from_supertype( - method_type(original_type), - defn.info, base) - # Check that the types are compatible. - # TODO overloaded signatures - self.check_override(cast(FunctionLike, typ), - cast(FunctionLike, original), - defn.name(), - name, - base.name(), - defn) - else: - assert original_type is not None - self.msg.signature_incompatible_with_supertype( - defn.name(), name, base.name(), defn) - - def check_override(self, override: FunctionLike, original: FunctionLike, - name: str, name_in_super: str, supertype: str, - node: Context) -> None: - """Check a method override with given signatures. - - Arguments: - override: The signature of the overriding method. - original: The signature of the original supertype method. - name: The name of the subtype. This and the next argument are - only used for generating error messages. - supertype: The name of the supertype. - """ - if (isinstance(override, Overloaded) or - isinstance(original, Overloaded) or - len(cast(Callable, override).arg_types) != - len(cast(Callable, original).arg_types) or - cast(Callable, override).min_args != - cast(Callable, original).min_args): - # Use boolean variable to clarify code. - fail = False - if not is_subtype(override, original): - fail = True - elif (not isinstance(original, Overloaded) and - isinstance(override, Overloaded) and - name in nodes.reverse_op_methods.keys()): - # Operator method overrides cannot introduce overloading, as - # this could be unsafe with reverse operator methods. - fail = True - if fail: - self.msg.signature_incompatible_with_supertype( - name, name_in_super, supertype, node) - return - else: - # Give more detailed messages for the common case of both - # signatures having the same number of arguments and no - # overloads. - - coverride = cast(Callable, override) - coriginal = cast(Callable, original) - - for i in range(len(coverride.arg_types)): - if not is_subtype(coriginal.arg_types[i], - coverride.arg_types[i]): - self.msg.argument_incompatible_with_supertype( - i + 1, name, name_in_super, supertype, node) - - if not is_subtype(coverride.ret_type, coriginal.ret_type): - self.msg.return_type_incompatible_with_supertype( - name, name_in_super, supertype, node) - - def visit_class_def(self, defn: ClassDef) -> Type: - """Type check a class definition.""" - typ = defn.info - self.errors.push_type(defn.name) - old_binder = self.binder - self.binder = ConditionalTypeBinder(self.basic_types) - self.binder.push_frame() - self.accept(defn.defs) - self.binder = old_binder - self.check_multiple_inheritance(typ) - self.errors.pop_type() - - def check_multiple_inheritance(self, typ: TypeInfo) -> None: - """Check for multiple inheritance related errors.""" - - if len(typ.bases) <= 1: - # No multiple inheritance. - return - # Verify that inherited attributes are compatible. - mro = typ.mro[1:] - for i, base in enumerate(mro): - for name in base.names: - for base2 in mro[i + 1:]: - # We only need to check compatibility of attributes from classes not - # in a subclass relationship. For subclasses, normal (single inheritance) - # checks suffice (these are implemented elsewhere). - if name in base2.names and not base2 in base.mro: - self.check_compatibility(name, base, base2, typ) - # Verify that base class layouts are compatible. - builtin_bases = [nearest_builtin_ancestor(base.type) - for base in typ.bases] - for base1 in builtin_bases: - for base2 in builtin_bases: - if not (base1 in base2.mro or base2 in base1.mro): - self.fail(messages.INSTANCE_LAYOUT_CONFLICT, typ) - # Verify that no disjointclass constraints are violated. - for base in typ.mro: - for disjoint in base.disjointclass_decls: - if disjoint in typ.mro: - self.msg.disjointness_violation(base, disjoint, typ) - - def check_compatibility(self, name: str, base1: TypeInfo, - base2: TypeInfo, ctx: Context) -> None: - if name == '__init__': - # __init__ can be incompatible -- it's a special case. - return - first = base1[name] - second = base2[name] - first_type = first.type - second_type = second.type - if (isinstance(first_type, FunctionLike) and - isinstance(second_type, FunctionLike)): - # Method override - first_sig = method_type(cast(FunctionLike, first_type)) - second_sig = method_type(cast(FunctionLike, second_type)) - # TODO Can we relax the equivalency requirement? - ok = is_equivalent(first_sig, second_sig) - else: - ok = is_equivalent(first_type, second_type) - if not ok: - self.msg.base_class_definitions_incompatible(name, base1, base2, - ctx) - - # - # Statements - # - - def visit_block(self, b: Block) -> Type: - if b.is_unreachable: - return None - for s in b.body: - self.accept(s) - if self.breaking_out: - break - - def visit_assignment_stmt(self, s: AssignmentStmt) -> Type: - """Type check an assignment statement. - - Handle all kinds of assignment statements (simple, indexed, multiple). - """ - self.check_assignments(self.expand_lvalues(s.lvalues[-1]), s.rvalue, - s.type) - if len(s.lvalues) > 1: - # Chained assignment (e.g. x = y = ...). - # Make sure that rvalue type will not be reinferred. - rvalue = self.temp_node(self.type_map[s.rvalue], s) - for lv in s.lvalues[:-1]: - self.check_assignments(self.expand_lvalues(lv), rvalue, - s.type) - - def check_assignments(self, lvalues: List[Node], - rvalue: Node, force_rvalue_type: Type=None) -> None: - # Collect lvalue types. Index lvalues require special consideration, - # since we cannot typecheck them until we know the rvalue type. - # For each lvalue, one of lvalue_types[i] or index_lvalues[i] is not - # None. - lvalue_types = [] # type: List[Type] # Each may be None - index_lvalues = [] # type: List[IndexExpr] # Each may be None - inferred = [] # type: List[Var] - is_inferred = False - - for lv in lvalues: - if self.is_definition(lv): - is_inferred = True - if isinstance(lv, NameExpr): - inferred.append(cast(Var, lv.node)) - else: - m = cast(MemberExpr, lv) - self.accept(m.expr) - inferred.append(m.def_var) - lvalue_types.append(None) - index_lvalues.append(None) - elif isinstance(lv, IndexExpr): - lvalue_types.append(None) - index_lvalues.append(lv) - inferred.append(None) - elif isinstance(lv, MemberExpr): - lvalue_types.append( - self.expr_checker.analyse_ordinary_member_access(lv, - True)) - self.store_type(lv, lvalue_types[-1]) - index_lvalues.append(None) - inferred.append(None) - elif isinstance(lv, NameExpr): - lvalue_types.append(self.expr_checker.analyse_ref_expr(lv)) - self.store_type(lv, lvalue_types[-1]) - index_lvalues.append(None) - inferred.append(None) - else: - lvalue_types.append(self.accept(lv)) - index_lvalues.append(None) - inferred.append(None) - - if len(lvalues) == 1: - # Single lvalue. - rvalue_type = self.check_single_assignment(lvalue_types[0], - index_lvalues[0], rvalue, rvalue) - if rvalue_type and not force_rvalue_type: - self.binder.assign_type(lvalues[0], rvalue_type) - else: - rvalue_types = self.check_multi_assignment(lvalue_types, index_lvalues, - rvalue, rvalue) - if rvalue_types and not force_rvalue_type: - for lv, rt in zip(lvalues, rvalue_types): - self.binder.assign_type(lv, rt) - if is_inferred: - self.infer_variable_type(inferred, lvalues, self.accept(rvalue), - rvalue) - - def is_definition(self, s: Node) -> bool: - if isinstance(s, NameExpr): - if s.is_def: - return True - # If the node type is not defined, this must the first assignment - # that we process => this is a definition, even though the semantic - # analyzer did not recognize this as such. This can arise in code - # that uses isinstance checks, if type checking of the primary - # definition is skipped due to an always False type check. - node = s.node - if isinstance(node, Var): - return node.type is None - elif isinstance(s, MemberExpr): - return s.is_def - return False - - def expand_lvalues(self, n: Node) -> List[Node]: - if isinstance(n, TupleExpr): - return self.expr_checker.unwrap_list(n.items) - elif isinstance(n, ListExpr): - return self.expr_checker.unwrap_list(n.items) - elif isinstance(n, ParenExpr): - return self.expand_lvalues(n.expr) - else: - return [n] - - def infer_variable_type(self, names: List[Var], lvalues: List[Node], - init_type: Type, context: Context) -> None: - """Infer the type of initialized variables from initializer type.""" - if isinstance(init_type, Void): - self.check_not_void(init_type, context) - elif not self.is_valid_inferred_type(init_type): - # We cannot use the type of the initialization expression for type - # inference (it's not specific enough). - self.fail(messages.NEED_ANNOTATION_FOR_VAR, context) - else: - # Infer type of the target. - - # Make the type more general (strip away function names etc.). - init_type = strip_type(init_type) - - if len(names) > 1: - if isinstance(init_type, TupleType): - # Initializer with a tuple type. - if len(init_type.items) == len(names): - for i in range(len(names)): - self.set_inferred_type(names[i], lvalues[i], - init_type.items[i]) - else: - self.msg.incompatible_value_count_in_assignment( - len(names), len(init_type.items), context) - elif (isinstance(init_type, Instance) and - is_subtype(init_type, - self.named_generic_type('typing.Iterable', - [AnyType()]))): - # Initializer with an iterable type. - item_type = self.iterable_item_type(cast(Instance, - init_type)) - for i in range(len(names)): - self.set_inferred_type(names[i], lvalues[i], item_type) - elif isinstance(init_type, AnyType): - for i in range(len(names)): - self.set_inferred_type(names[i], lvalues[i], AnyType()) - else: - self.fail(messages.INCOMPATIBLE_TYPES_IN_ASSIGNMENT, - context) - else: - for v in names: - self.set_inferred_type(v, lvalues[0], init_type) - - def set_inferred_type(self, var: Var, lvalue: Node, type: Type) -> None: - """Store inferred variable type. - - Store the type to both the variable node and the expression node that - refers to the variable (lvalue). If var is None, do nothing. - """ - if var: - var.type = type - self.store_type(lvalue, type) - - def is_valid_inferred_type(self, typ: Type) -> bool: - """Is an inferred type invalid? - - Examples include the None type or a type with a None component. - """ - if is_same_type(typ, NoneTyp()): - return False - elif isinstance(typ, Instance): - for arg in typ.args: - if not self.is_valid_inferred_type(arg): - return False - elif isinstance(typ, TupleType): - for item in typ.items: - if not self.is_valid_inferred_type(item): - return False - return True - - def narrow_type_from_binder(self, expr: Node, known_type: Type) -> Type: - if expr.literal >= LITERAL_TYPE: - restriction = self.binder.get(expr) - if restriction: - ans = meet_simple(known_type, restriction, self.basic_types()) - return ans - return known_type - - def check_multi_assignment(self, lvalue_types: List[Type], - index_lvalues: List[IndexExpr], - rvalue: Node, - context: Context, - msg: str = None) -> List[Type]: - if not msg: - msg = messages.INCOMPATIBLE_TYPES_IN_ASSIGNMENT - # First handle case where rvalue is of form Undefined, ... - rvalue_type = get_undefined_tuple(rvalue) - undefined_rvalue = True - if not rvalue_type: - # Infer the type of an ordinary rvalue expression. - rvalue_type = self.accept(rvalue) # TODO maybe elsewhere; redundant - undefined_rvalue = False - # Try to expand rvalue to lvalue(s). - rvalue_types = None # type: List[Type] - if isinstance(rvalue_type, AnyType): - pass - elif isinstance(rvalue_type, TupleType): - # Rvalue with tuple type. - items = [] # type: List[Type] - for i in range(len(lvalue_types)): - if lvalue_types[i]: - items.append(lvalue_types[i]) - elif i < len(rvalue_type.items): - # TODO Figure out more precise type context, probably - # based on the type signature of the _set method. - items.append(rvalue_type.items[i]) - if not undefined_rvalue: - # Infer rvalue again, now in the correct type context. - rvalue_type = cast(TupleType, self.accept(rvalue, - TupleType(items))) - if len(rvalue_type.items) != len(lvalue_types): - self.msg.incompatible_value_count_in_assignment( - len(lvalue_types), len(rvalue_type.items), context) - else: - # The number of values is compatible. Check their types. - for j in range(len(lvalue_types)): - self.check_single_assignment( - lvalue_types[j], index_lvalues[j], - self.temp_node(rvalue_type.items[j]), context, msg) - rvalue_types = rvalue_type.items - elif (is_subtype(rvalue_type, - self.named_generic_type('typing.Iterable', - [AnyType()])) and - isinstance(rvalue_type, Instance)): - # Rvalue is iterable. - rvalue_types = [] - item_type = self.iterable_item_type(cast(Instance, rvalue_type)) - for k in range(len(lvalue_types)): - type = self.check_single_assignment(lvalue_types[k], - index_lvalues[k], - self.temp_node(item_type), - context, msg) - rvalue_types.append(type) - else: - self.fail(msg, context) - return rvalue_types - - def check_single_assignment(self, - lvalue_type: Type, index_lvalue: IndexExpr, - rvalue: Node, context: Context, - msg: str = messages.INCOMPATIBLE_TYPES_IN_ASSIGNMENT) -> Type: - """Type check an assignment. - - If lvalue_type is None, the index_lvalue argument must be the - index expr for indexed assignment (__setitem__). - Otherwise, lvalue_type is used as the type of the lvalue. - """ - if lvalue_type: - if refers_to_fullname(rvalue, 'typing.Undefined'): - # The rvalue is just 'Undefined'; this is always valid. - # Infer the type of 'Undefined' from the lvalue type. - self.store_type(rvalue, lvalue_type) - return None - rvalue_type = self.accept(rvalue, lvalue_type) - self.check_subtype(rvalue_type, lvalue_type, context, msg, - 'expression has type', 'variable has type') - return rvalue_type - elif index_lvalue: - self.check_indexed_assignment(index_lvalue, rvalue, context) - - def check_indexed_assignment(self, lvalue: IndexExpr, - rvalue: Node, context: Context) -> None: - """Type check indexed assignment base[index] = rvalue. - - The lvalue argument is the base[index] expression. - """ - basetype = self.accept(lvalue.base) - method_type = self.expr_checker.analyse_external_member_access( - '__setitem__', basetype, context) - lvalue.method_type = method_type - self.expr_checker.check_call(method_type, [lvalue.index, rvalue], - [nodes.ARG_POS, nodes.ARG_POS], - context) - - def visit_expression_stmt(self, s: ExpressionStmt) -> Type: - self.accept(s.expr) - - def visit_return_stmt(self, s: ReturnStmt) -> Type: - """Type check a return statement.""" - self.breaking_out = True - if self.is_within_function(): - if s.expr: - # Return with a value. - typ = self.accept(s.expr, self.return_types[-1]) - # Returning a value of type Any is always fine. - if not isinstance(typ, AnyType): - if isinstance(self.return_types[-1], Void): - # FuncExpr (lambda) may have a Void return. - # Function returning a value of type None may have a Void return. - if (not isinstance(self.function_stack[-1], FuncExpr) and - not isinstance(typ, NoneTyp)): - self.fail(messages.NO_RETURN_VALUE_EXPECTED, s) - else: - if self.function_stack[-1].is_coroutine: # Something similar will be needed to mix return and yield - #If the function is a coroutine, wrap the return type in a Future - typ = self.wrap_generic_type(typ, self.return_types[-1], 'asyncio.futures.Future', s) - self.check_subtype( - typ, self.return_types[-1], s, - messages.INCOMPATIBLE_RETURN_VALUE_TYPE - + ": expected {}, got {}".format(self.return_types[-1], typ) - ) - else: - # Return without a value. It's valid in a generator and coroutine function. - if not self.function_stack[-1].is_generator and not self.function_stack[-1].is_coroutine: - if (not isinstance(self.return_types[-1], Void) and -<<<<<<< HEAD - not self.is_dynamic_function()): - self.fail(messages.RETURN_VALUE_EXPECTED, s) - - def wrap_generic_type(self, typ: Type, rtyp: Type, check_type: str, context: Context) -> Type: - n_diff = self.count_concatenated_types(rtyp, check_type) - self.count_concatenated_types(typ, check_type) - if n_diff >= 1: - return self.named_generic_type(check_type, [typ]) - elif n_diff == 0: - self.fail(messages.INCOMPATIBLE_RETURN_VALUE_TYPE - + ": expected {}, got {}".format(rtyp, typ), context) - return typ - return typ - - def count_concatenated_types(self, typ: Type, check_type: str) -> int: - c = 0 - while is_subtype(typ, self.named_type(check_type)): - c += 1 - if hasattr(typ, 'args') and typ.args: - typ = typ.args[0] - else: - return c - return c -======= - not self.is_dynamic_function()): - self.fail(messages.RETURN_VALUE_EXPECTED, s) ->>>>>>> master - - def visit_yield_stmt(self, s: YieldStmt) -> Type: - return_type = self.return_types[-1] - if isinstance(return_type, Instance): - if return_type.type.fullname() != 'typing.Iterator': - self.fail(messages.INVALID_RETURN_TYPE_FOR_YIELD, s) - return None - expected_item_type = return_type.args[0] - elif isinstance(return_type, AnyType): - expected_item_type = AnyType() - else: - self.fail(messages.INVALID_RETURN_TYPE_FOR_YIELD, s) - return None - if s.expr is None: - actual_item_type = Void() # type: Type - else: - actual_item_type = self.accept(s.expr, expected_item_type) - self.check_subtype(actual_item_type, expected_item_type, s, - messages.INCOMPATIBLE_TYPES_IN_YIELD, - 'actual type', 'expected type') - -<<<<<<< HEAD - def visit_yield_from_stmt(self, s: YieldFromStmt) -> Type: - return_type = self.return_types[-1] - type_func = self.accept(s.expr, return_type) - if isinstance(type_func, Instance): - if hasattr(type_func, 'type') and hasattr(type_func.type, 'fullname') and type_func.type.fullname() == 'asyncio.futures.Future': - # if is a Future, in stmt don't need to do nothing - # because the type Future[Some] jus matters to the main loop - # that python executes, in statement we shouldn't get the Future, - # is just for async purposes. - self.function_stack[-1].is_coroutine = True # Set the function as coroutine - elif is_subtype(type_func, self.named_type('typing.Iterable')): - # If it's and Iterable-Like, let's check the types. - # Maybe just check if have __iter__? (like in analyse_iterable) - self.check_iterable_yf(s) - else: - self.msg.yield_from_not_valid_applied(type_func, s) - elif isinstance(type_func, AnyType): - self.check_iterable_yf(s) - else: - self.msg.yield_from_not_valid_applied(type_func, s) - - def check_iterable_yf(self, s: YieldFromStmt) -> Type: - """ - Check that return type is super type of Iterable (Maybe just check if have __iter__?) - and compare it with the type of the expression - """ - expected_item_type = self.return_types[-1] - if isinstance(expected_item_type, Instance): - if not is_subtype(expected_item_type, self.named_type('typing.Iterable')): - self.fail(messages.INVALID_RETURN_TYPE_FOR_YIELD_FROM, s) - return None - elif hasattr(expected_item_type, 'args') and expected_item_type.args: - expected_item_type = expected_item_type.args[0] # Take the item inside the iterator - # expected_item_type = expected_item_type - elif isinstance(expected_item_type, AnyType): - expected_item_type = AnyType() - else: - self.fail(messages.INVALID_RETURN_TYPE_FOR_YIELD_FROM, s) - return None - if s.expr is None: - actual_item_type = Void() - else: - actual_item_type = self.accept(s.expr, expected_item_type) - if hasattr(actual_item_type, 'args') and actual_item_type.args: - actual_item_type = actual_item_type.args[0] # Take the item inside the iterator - self.check_subtype(actual_item_type, expected_item_type, s, - messages.INCOMPATIBLE_TYPES_IN_YIELD_FROM, - 'actual type', 'expected type') - -======= ->>>>>>> master - def visit_if_stmt(self, s: IfStmt) -> Type: - """Type check an if statement.""" - broken = True - ending_frames = List[Frame]() - clauses_frame = self.binder.push_frame() - for e, b in zip(s.expr, s.body): - t = self.accept(e) - self.check_not_void(t, e) - var, type, elsetype, kind = find_isinstance_check(e, self.type_map) - if kind == ISINSTANCE_ALWAYS_FALSE: - # XXX should issue a warning? - pass - else: - # Only type check body if the if condition can be true. - self.binder.push_frame() - if var: - self.binder.push(var, type) - self.accept(b) - _, frame = self.binder.pop_frame() - self.binder.allow_jump(len(self.binder.frames) - 1) - if not self.breaking_out: - broken = False - ending_frames.append(meet_frames(self.basic_types(), clauses_frame, frame)) - - self.breaking_out = False - - if var: - self.binder.push(var, elsetype) - if kind == ISINSTANCE_ALWAYS_TRUE: - # The condition is always true => remaining elif/else blocks - # can never be reached. - - # Might also want to issue a warning - # print("Warning: isinstance always true") - if broken: - self.binder.pop_frame() - self.breaking_out = True - return None - break - else: - if s.else_body: - self.accept(s.else_body) - - if self.breaking_out and broken: - self.binder.pop_frame() - return None - - if not self.breaking_out: - ending_frames.append(clauses_frame) - - self.breaking_out = False - else: - ending_frames.append(clauses_frame) - - self.binder.pop_frame() - self.binder.update_from_options(ending_frames) - - def visit_while_stmt(self, s: WhileStmt) -> Type: - """Type check a while statement.""" - self.binder.push_frame() - self.binder.push_loop_frame() - self.accept_in_frame(IfStmt([s.expr], [s.body], None), - repeat_till_fixed=True) - self.binder.pop_loop_frame() - if s.else_body: - self.accept(s.else_body) - self.binder.pop_frame(False, True) - - def visit_operator_assignment_stmt(self, - s: OperatorAssignmentStmt) -> Type: - """Type check an operator assignment statement, e.g. x += 1.""" - lvalue_type = self.accept(s.lvalue) - method = infer_operator_assignment_method(lvalue_type, s.op) - rvalue_type, method_type = self.expr_checker.check_op( - method, lvalue_type, s.rvalue, s) - - if isinstance(s.lvalue, IndexExpr): - lv = cast(IndexExpr, s.lvalue) - self.check_single_assignment(None, lv, s.rvalue, s.rvalue) - else: - if not is_subtype(rvalue_type, lvalue_type): - self.msg.incompatible_operator_assignment(s.op, s) - - def visit_assert_stmt(self, s: AssertStmt) -> Type: - self.accept(s.expr) - - def visit_raise_stmt(self, s: RaiseStmt) -> Type: - """Type check a raise statement.""" - self.breaking_out = True - if s.expr: - typ = self.accept(s.expr) - if isinstance(typ, FunctionLike): - if typ.is_type_obj(): - # Cases like "raise ExceptionClass". - typeinfo = typ.type_object() - base = self.lookup_typeinfo('builtins.BaseException') - if base in typeinfo.mro: - # Good! - return None - # Else fall back to the check below (which will fail). - self.check_subtype(typ, - self.named_type('builtins.BaseException'), s, - messages.INVALID_EXCEPTION) - - def visit_try_stmt(self, s: TryStmt) -> Type: - """Type check a try statement.""" - completed_frames = List[Frame]() - self.binder.push_frame() - self.binder.try_frames.add(len(self.binder.frames) - 2) - self.accept(s.body) - self.binder.try_frames.remove(len(self.binder.frames) - 2) - if s.else_body: - self.accept(s.else_body) - changed, frame_on_completion = self.binder.pop_frame() - completed_frames.append(frame_on_completion) - - for i in range(len(s.handlers)): - if s.types[i]: - t = self.exception_type(s.types[i]) - if s.vars[i]: - self.check_assignments([s.vars[i]], - self.temp_node(t, s.vars[i])) - self.binder.push_frame() - self.accept(s.handlers[i]) - changed, frame_on_completion = self.binder.pop_frame() - completed_frames.append(frame_on_completion) - if s.else_body: - self.binder.push_frame() - self.accept(s.else_body) - changed, frame_on_completion = self.binder.pop_frame() - completed_frames.append(frame_on_completion) - - self.binder.update_from_options(completed_frames) - - if s.finally_body: - self.accept(s.finally_body) - - def exception_type(self, n: Node) -> Type: - if isinstance(n, ParenExpr): - # Multiple exception types (...). - unwrapped = self.expr_checker.unwrap(n) - if isinstance(unwrapped, TupleExpr): - t = None # type: Type - for item in unwrapped.items: - tt = self.exception_type(item) - if t: - t = join_types(t, tt, self.basic_types()) - else: - t = tt - return t - else: - # A single exception type; should evaluate to a type object type. - type = self.accept(n) - return self.check_exception_type(type, n) - self.fail('Unsupported exception', n) - return AnyType() - - @overload - def check_exception_type(self, type: FunctionLike, - context: Context) -> Type: - item = type.items()[0] - ret = item.ret_type - if (is_subtype(ret, self.named_type('builtins.BaseException')) - and item.is_type_obj()): - return ret - else: - self.fail(messages.INVALID_EXCEPTION_TYPE, context) - return AnyType() - - @overload - def check_exception_type(self, type: AnyType, context: Context) -> Type: - return AnyType() - - @overload - def check_exception_type(self, type: Type, context: Context) -> Type: - self.fail(messages.INVALID_EXCEPTION_TYPE, context) - return AnyType() - - def visit_for_stmt(self, s: ForStmt) -> Type: - """Type check a for statement.""" - item_type = self.analyse_iterable_item_type(s.expr) - self.analyse_index_variables(s.index, s.is_annotated(), item_type, s) - self.binder.push_frame() - self.binder.push_loop_frame() - self.accept_in_frame(s.body, repeat_till_fixed=True) - self.binder.pop_loop_frame() - if s.else_body: - self.accept(s.else_body) - self.binder.pop_frame(False, True) - - def analyse_iterable_item_type(self, expr: Node) -> Type: - """Analyse iterable expression and return iterator item type.""" - iterable = self.accept(expr) - - self.check_not_void(iterable, expr) - if isinstance(iterable, TupleType): - joined = NoneTyp() # type: Type - for item in iterable.items: - joined = join_types(joined, item, self.basic_types()) - if isinstance(joined, ErrorType): - self.fail(messages.CANNOT_INFER_ITEM_TYPE, expr) - return AnyType() - return joined - else: - # Non-tuple iterable. - self.check_subtype(iterable, - self.named_generic_type('typing.Iterable', - [AnyType()]), - expr, messages.ITERABLE_EXPECTED) - - echk = self.expr_checker - method = echk.analyse_external_member_access('__iter__', iterable, - expr) - iterator = echk.check_call(method, [], [], expr)[0] - if self.pyversion >= 3: - nextmethod = '__next__' - else: - nextmethod = 'next' - method = echk.analyse_external_member_access(nextmethod, iterator, - expr) - return echk.check_call(method, [], [], expr)[0] - - def analyse_index_variables(self, index: List[NameExpr], - is_annotated: bool, - item_type: Type, context: Context) -> None: - """Type check or infer for loop or list comprehension index vars.""" - if not is_annotated: - # Create a temporary copy of variables with Node item type. - # TODO this is ugly - node_index = [] # type: List[Node] - for i in index: - node_index.append(i) - self.check_assignments(node_index, - self.temp_node(item_type, context)) - elif len(index) == 1: - v = cast(Var, index[0].node) - if v.type: - self.check_single_assignment(v.type, None, - self.temp_node(item_type), context, - messages.INCOMPATIBLE_TYPES_IN_FOR) - else: - t = [] # type: List[Type] - for ii in index: - v = cast(Var, ii.node) - if v.type: - t.append(v.type) - else: - t.append(AnyType()) - self.check_multi_assignment(t, [None] * len(index), - self.temp_node(item_type), context, - messages.INCOMPATIBLE_TYPES_IN_FOR) - - def visit_del_stmt(self, s: DelStmt) -> Type: - if isinstance(s.expr, IndexExpr): - e = cast(IndexExpr, s.expr) # Cast - m = MemberExpr(e.base, '__delitem__') - m.line = s.line - c = CallExpr(m, [e.index], [nodes.ARG_POS], [None]) - c.line = s.line - return c.accept(self) - else: - s.expr.accept(self) - return None - - def visit_decorator(self, e: Decorator) -> Type: - e.func.accept(self) - sig = function_type(e.func) # type: Type - # Process decorators from the inside out. - for i in range(len(e.decorators)): - n = len(e.decorators) - 1 - i - dec = self.accept(e.decorators[n]) - temp = self.temp_node(sig) - sig, t2 = self.expr_checker.check_call(dec, [temp], - [nodes.ARG_POS], e) - sig = set_callable_name(sig, e.func) - e.var.type = sig - e.var.is_ready = True - - def visit_with_stmt(self, s: WithStmt) -> Type: - echk = self.expr_checker - for expr, name in zip(s.expr, s.name): - ctx = self.accept(expr) - enter = echk.analyse_external_member_access('__enter__', ctx, expr) - obj = echk.check_call(enter, [], [], expr)[0] - if name: - self.check_assignments([name], self.temp_node(obj, expr)) - exit = echk.analyse_external_member_access('__exit__', ctx, expr) - arg = self.temp_node(AnyType(), expr) - echk.check_call(exit, [arg] * 3, [nodes.ARG_POS] * 3, expr) - self.accept(s.body) - - def visit_print_stmt(self, s: PrintStmt) -> Type: - for arg in s.args: - self.accept(arg) - - # - # Expressions - # - - def visit_name_expr(self, e: NameExpr) -> Type: - return self.expr_checker.visit_name_expr(e) - - def visit_paren_expr(self, e: ParenExpr) -> Type: - return self.expr_checker.visit_paren_expr(e) - - def visit_call_expr(self, e: CallExpr) -> Type: - result = self.expr_checker.visit_call_expr(e) - self.breaking_out = False - return result - -<<<<<<< HEAD - def visit_yield_from_expr(self, e: YieldFromExpr) -> Type: - result = self.expr_checker.visit_yield_from_expr(e) - if hasattr(result, 'type') and result.type.fullname() == "asyncio.futures.Future": - self.function_stack[-1].is_coroutine = True # Set the function as coroutine - result = result.args[0] # Set the return type as the type inside - elif is_subtype(result, self.named_type('typing.Iterable')): - # TODO - # Check return type Iterator[Some] - # Maybe set result like in the Future - pass - else: - self.msg.yield_from_not_valid_applied(e.expr, e) - self.breaking_out = False - return result - -======= ->>>>>>> master - def visit_member_expr(self, e: MemberExpr) -> Type: - return self.expr_checker.visit_member_expr(e) - - def visit_break_stmt(self, s: BreakStmt) -> Type: - self.breaking_out = True - self.binder.allow_jump(self.binder.loop_frames[-1] - 1) - return None - - def visit_continue_stmt(self, s: ContinueStmt) -> Type: - self.breaking_out = True - self.binder.allow_jump(self.binder.loop_frames[-1]) - return None - - def visit_int_expr(self, e: IntExpr) -> Type: - return self.expr_checker.visit_int_expr(e) - - def visit_str_expr(self, e: StrExpr) -> Type: - return self.expr_checker.visit_str_expr(e) - - def visit_bytes_expr(self, e: BytesExpr) -> Type: - return self.expr_checker.visit_bytes_expr(e) - - def visit_unicode_expr(self, e: UnicodeExpr) -> Type: - return self.expr_checker.visit_unicode_expr(e) - - def visit_float_expr(self, e: FloatExpr) -> Type: - return self.expr_checker.visit_float_expr(e) - - def visit_op_expr(self, e: OpExpr) -> Type: - return self.expr_checker.visit_op_expr(e) - -<<<<<<< HEAD -======= - def visit_comparison_expr(self, e: ComparisonExpr) -> Type: - return self.expr_checker.visit_comparison_expr(e) - ->>>>>>> master - def visit_unary_expr(self, e: UnaryExpr) -> Type: - return self.expr_checker.visit_unary_expr(e) - - def visit_index_expr(self, e: IndexExpr) -> Type: - return self.expr_checker.visit_index_expr(e) - - def visit_cast_expr(self, e: CastExpr) -> Type: - return self.expr_checker.visit_cast_expr(e) - - def visit_super_expr(self, e: SuperExpr) -> Type: - return self.expr_checker.visit_super_expr(e) - - def visit_type_application(self, e: TypeApplication) -> Type: - return self.expr_checker.visit_type_application(e) - - def visit_type_var_expr(self, e: TypeVarExpr) -> Type: - # TODO Perhaps return a special type used for type variables only? - return AnyType() - - def visit_list_expr(self, e: ListExpr) -> Type: - return self.expr_checker.visit_list_expr(e) - - def visit_set_expr(self, e: SetExpr) -> Type: - return self.expr_checker.visit_set_expr(e) - - def visit_tuple_expr(self, e: TupleExpr) -> Type: - return self.expr_checker.visit_tuple_expr(e) - - def visit_dict_expr(self, e: DictExpr) -> Type: - return self.expr_checker.visit_dict_expr(e) - - def visit_slice_expr(self, e: SliceExpr) -> Type: - return self.expr_checker.visit_slice_expr(e) - - def visit_func_expr(self, e: FuncExpr) -> Type: - return self.expr_checker.visit_func_expr(e) - - def visit_list_comprehension(self, e: ListComprehension) -> Type: - return self.expr_checker.visit_list_comprehension(e) - - def visit_generator_expr(self, e: GeneratorExpr) -> Type: - return self.expr_checker.visit_generator_expr(e) - - def visit_undefined_expr(self, e: UndefinedExpr) -> Type: - return self.expr_checker.visit_undefined_expr(e) - - def visit_temp_node(self, e: TempNode) -> Type: - return e.type - - def visit_conditional_expr(self, e: ConditionalExpr) -> Type: - return self.expr_checker.visit_conditional_expr(e) - - # - # Helpers - # - - def check_subtype(self, subtype: Type, supertype: Type, context: Context, - msg: str = messages.INCOMPATIBLE_TYPES, - subtype_label: str = None, - supertype_label: str = None) -> None: - """Generate an error if the subtype is not compatible with - supertype.""" - if not is_subtype(subtype, supertype): - if isinstance(subtype, Void): - self.msg.does_not_return_value(subtype, context) - else: - extra_info = [] # type: List[str] - if subtype_label is not None: - extra_info.append(subtype_label + ' ' + self.msg.format_simple(subtype)) - if supertype_label is not None: - extra_info.append(supertype_label + ' ' + self.msg.format_simple(supertype)) - if extra_info: - msg += ' (' + ', '.join(extra_info) + ')' - self.fail(msg, context) - - def named_type(self, name: str) -> Instance: - """Return an instance type with type given by the name and no - type arguments. For example, named_type('builtins.object') - produces the object type. - """ - # Assume that the name refers to a type. - sym = self.lookup_qualified(name) - return Instance(cast(TypeInfo, sym.node), []) - - def named_type_if_exists(self, name: str) -> Type: - """Return named instance type, or UnboundType if the type was - not defined. - - This is used to simplify test cases by avoiding the need to - define basic types not needed in specific test cases (tuple - etc.). - """ - try: - # Assume that the name refers to a type. - sym = self.lookup_qualified(name) - return Instance(cast(TypeInfo, sym.node), []) - except KeyError: - return UnboundType(name) - - def named_generic_type(self, name: str, args: List[Type]) -> Instance: - """Return an instance with the given name and type arguments. - - Assume that the number of arguments is correct. Assume that - the name refers to a compatible generic type. - """ - return Instance(self.lookup_typeinfo(name), args) - - def lookup_typeinfo(self, fullname: str) -> TypeInfo: - # Assume that the name refers to a class. - sym = self.lookup_qualified(fullname) - return cast(TypeInfo, sym.node) - - def type_type(self) -> Instance: - """Return instance type 'type'.""" - return self.named_type('builtins.type') - - def object_type(self) -> Instance: - """Return instance type 'object'.""" - return self.named_type('builtins.object') - - def bool_type(self) -> Instance: - """Return instance type 'bool'.""" - return self.named_type('builtins.bool') - - def str_type(self) -> Instance: - """Return instance type 'str'.""" - return self.named_type('builtins.str') - - def tuple_type(self) -> Type: - """Return instance type 'tuple'.""" - # We need the tuple for analysing member access. We want to be able to - # do this even if tuple type is not available (useful in test cases), - # so we return an unbound type if there is no tuple type. - return self.named_type_if_exists('builtins.tuple') - - def check_type_equivalency(self, t1: Type, t2: Type, node: Context, - msg: str = messages.INCOMPATIBLE_TYPES) -> None: - """Generate an error if the types are not equivalent. The - dynamic type is equivalent with all types. - """ - if not is_equivalent(t1, t2): - self.fail(msg, node) - - def store_type(self, node: Node, typ: Type) -> None: - """Store the type of a node in the type map.""" - self.type_map[node] = typ - - def is_dynamic_function(self) -> bool: - return len(self.dynamic_funcs) > 0 and self.dynamic_funcs[-1] - - def lookup(self, name: str, kind: int) -> SymbolTableNode: - """Look up a definition from the symbol table with the given name. - TODO remove kind argument - """ - if self.locals is not None and name in self.locals: - return self.locals[name] - elif name in self.globals: - return self.globals[name] - else: - b = self.globals.get('__builtins__', None) - if b: - table = cast(MypyFile, b.node).names - if name in table: - return table[name] - raise KeyError('Failed lookup: {}'.format(name)) - - def lookup_qualified(self, name: str) -> SymbolTableNode: - if '.' not in name: - return self.lookup(name, GDEF) # FIX kind - else: - parts = name.split('.') - n = self.modules[parts[0]] - for i in range(1, len(parts) - 1): - n = cast(MypyFile, ((n.names.get(parts[i], None).node))) - return n.names[parts[-1]] - - def enter(self) -> None: - self.locals = SymbolTable() - - def leave(self) -> None: - self.locals = None - - def basic_types(self) -> BasicTypes: - """Return a BasicTypes instance that contains primitive types that are - needed for certain type operations (joins, for example). - """ - return BasicTypes(self.object_type(), self.named_type('builtins.type'), - self.named_type_if_exists('builtins.tuple'), - self.named_type_if_exists('builtins.function')) - - def is_within_function(self) -> bool: - """Are we currently type checking within a function? - - I.e. not at class body or at the top level. - """ - return self.return_types != [] - - def check_not_void(self, typ: Type, context: Context) -> None: - """Generate an error if the type is Void.""" - if isinstance(typ, Void): - self.msg.does_not_return_value(typ, context) - - def temp_node(self, t: Type, context: Context = None) -> Node: - """Create a temporary node with the given, fixed type.""" - temp = TempNode(t) - if context: - temp.set_line(context.get_line()) - return temp - - def fail(self, msg: str, context: Context) -> None: - """Produce an error message.""" - self.msg.fail(msg, context) - - def iterable_item_type(self, instance: Instance) -> Type: - iterable = map_instance_to_supertype( - instance, - self.lookup_typeinfo('typing.Iterable')) - return iterable.args[0] - - -def map_type_from_supertype(typ: Type, sub_info: TypeInfo, - super_info: TypeInfo) -> Type: - """Map type variables in a type defined in a supertype context to be valid - in the subtype context. Assume that the result is unique; if more than - one type is possible, return one of the alternatives. - - For example, assume - - class D(Generic[S]) ... - class C(D[E[T]], Generic[T]) ... - - Now S in the context of D would be mapped to E[T] in the context of C. - """ - # Create the type of self in subtype, of form t[a1, ...]. - inst_type = self_type(sub_info) - # Map the type of self to supertype. This gets us a description of the - # supertype type variables in terms of subtype variables, i.e. t[t1, ...] - # so that any type variables in tN are to be interpreted in subtype - # context. - inst_type = map_instance_to_supertype(inst_type, super_info) - # Finally expand the type variables in type with those in the previously - # constructed type. Note that both type and inst_type may have type - # variables, but in type they are interpreterd in supertype context while - # in inst_type they are interpreted in subtype context. This works even if - # the names of type variables in supertype and subtype overlap. - return expand_type_by_instance(typ, inst_type) - - -def get_undefined_tuple(rvalue: Node) -> Type: - """Get tuple type corresponding to a tuple of Undefined values. - - The type is Tuple[Any, ...]. If rvalue is not of the right form, return - None. - """ - if isinstance(rvalue, TupleExpr): - for item in rvalue.items: - if not refers_to_fullname(item, 'typing.Undefined'): - break - else: - return TupleType([AnyType()] * len(rvalue.items)) - return None - - -def find_isinstance_check(node: Node, - type_map: Dict[Node, Type]) -> Tuple[Node, Type, Type, int]: - """Check if node is an isinstance(variable, type) check. - - If successful, return tuple (variable, target-type, else-type, - kind); otherwise, return (None, AnyType, AnyType, -1). - - When successful, the kind takes one of these values: - - ISINSTANCE_OVERLAPPING: The type of variable and the target type are - partially overlapping => the test result can be True or False. - ISINSTANCE_ALWAYS_TRUE: The target type at least as general as the - variable type => the test is always True. - ISINSTANCE_ALWAYS_FALSE: The target type and the variable type are not - overlapping => the test is always False. - """ - if isinstance(node, CallExpr): - if refers_to_fullname(node.callee, 'builtins.isinstance'): - expr = node.args[0] - if expr.literal == LITERAL_TYPE: - type = get_isinstance_type(node.args[1], type_map) - if type: - vartype = type_map[expr] - kind = ISINSTANCE_OVERLAPPING - elsetype = vartype - if vartype: - if is_proper_subtype(vartype, type): - kind = ISINSTANCE_ALWAYS_TRUE - elsetype = None - elif not is_overlapping_types(vartype, type): - kind = ISINSTANCE_ALWAYS_FALSE - else: - elsetype = restrict_subtype_away(vartype, type) - return expr, type, elsetype, kind - # Not a supported isinstance check - return None, AnyType(), AnyType(), -1 - - -def get_isinstance_type(node: Node, type_map: Dict[Node, Type]) -> Type: - type = type_map[node] - if isinstance(type, FunctionLike): - if type.is_type_obj(): - # Type variables may be present -- erase them, which is the best - # we can do (outside disallowing them here). - return erase_typevars(type.items()[0].ret_type) - return None - - -def expand_node(defn: Node, map: Dict[int, Type]) -> Node: - visitor = TypeTransformVisitor(map) - return defn.accept(visitor) - - -def expand_func(defn: FuncItem, map: Dict[int, Type]) -> FuncItem: - return cast(FuncItem, expand_node(defn, map)) - - -class TypeTransformVisitor(TransformVisitor): - def __init__(self, map: Dict[int, Type]) -> None: - super().__init__() - self.map = map - - def type(self, type: Type) -> Type: - return expand_type(type, self.map) - - -def is_unsafe_overlapping_signatures(signature: Type, other: Type) -> bool: - """Check if two signatures may be unsafely overlapping. - - Two signatures s and t are overlapping if both can be valid for the same - statically typed values and the return types are incompatible. - - Assume calls are first checked against 'signature', then against 'other'. - Thus if 'signature' is more general than 'other', there is no unsafe - overlapping. - - TODO If argument types vary covariantly, the return type may vary - covariantly as well. - """ - if isinstance(signature, Callable): - if isinstance(other, Callable): - # TODO varargs - # TODO keyword args - # TODO erasure - # TODO allow to vary covariantly - # Check if the argument counts are overlapping. - min_args = max(signature.min_args, other.min_args) - max_args = min(len(signature.arg_types), len(other.arg_types)) - if min_args > max_args: - # Argument counts are not overlapping. - return False - # Signatures are overlapping iff if they are overlapping for the - # smallest common argument count. - for i in range(min_args): - t1 = signature.arg_types[i] - t2 = other.arg_types[i] - if not is_overlapping_types(t1, t2): - return False - # All arguments types for the smallest common argument count are - # overlapping => the signature is overlapping. The overlapping is - # safe if the return types are identical. - if is_same_type(signature.ret_type, other.ret_type): - return False - # If the first signature has more general argument types, the - # latter will never be called - if is_more_general_arg_prefix(signature, other): - return False - return not is_more_precise_signature(signature, other) - return True - - -def is_more_general_arg_prefix(t: FunctionLike, s: FunctionLike) -> bool: - """Does t have wider arguments than s?""" - # TODO should an overload with additional items be allowed to be more - # general than one with fewer items (or just one item)? - # TODO check argument kinds - if isinstance(t, Callable): - if isinstance(s, Callable): - return all(is_proper_subtype(args, argt) - for argt, args in zip(t.arg_types, s.arg_types)) - elif isinstance(t, FunctionLike): - if isinstance(s, FunctionLike): - if len(t.items()) == len(s.items()): - return all(is_same_arg_prefix(items, itemt) - for items, itemt in zip(t.items(), s.items())) - return False - - -def is_same_arg_prefix(t: Callable, s: Callable) -> bool: - # TODO check argument kinds - return all(is_same_type(argt, args) - for argt, args in zip(t.arg_types, s.arg_types)) - - -def is_more_precise_signature(t: Callable, s: Callable) -> bool: - """Is t more precise than s? - - A signature t is more precise than s if all argument types and the return - type of t are more precise than the corresponding types in s. - - Assume that the argument kinds and names are compatible, and that the - argument counts are overlapping. - """ - # TODO generic function types - # Only consider the common prefix of argument types. - for argt, args in zip(t.arg_types, s.arg_types): - if not is_more_precise(argt, args): - return False - return is_more_precise(t.ret_type, s.ret_type) - - -def infer_operator_assignment_method(type: Type, operator: str) -> str: - """Return the method used for operator assignment for given value type. - - For example, if operator is '+', return '__iadd__' or '__add__' depending - on which method is supported by the type. - """ - method = nodes.op_methods[operator] - if isinstance(type, Instance): - if operator in nodes.ops_with_inplace_method: - inplace = '__i' + method[2:] - if type.type.has_readable_member(inplace): - method = inplace - return method diff --git a/mypy/checkexpr.py.orig b/mypy/checkexpr.py.orig deleted file mode 100644 index 117f072a38ac..000000000000 --- a/mypy/checkexpr.py.orig +++ /dev/null @@ -1,1453 +0,0 @@ -"""Expression type checker. This file is conceptually part of TypeChecker.""" - -from typing import Undefined, cast, List, Tuple, Dict, Function - -from mypy.types import ( - Type, AnyType, Callable, Overloaded, NoneTyp, Void, TypeVarDef, - TupleType, Instance, TypeVar, TypeTranslator, ErasedType, FunctionLike, UnionType -) -from mypy.nodes import ( - NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr, - Node, MemberExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr, FloatExpr, - OpExpr, UnaryExpr, IndexExpr, CastExpr, TypeApplication, ListExpr, - TupleExpr, DictExpr, FuncExpr, SuperExpr, ParenExpr, SliceExpr, Context, - ListComprehension, GeneratorExpr, SetExpr, MypyFile, Decorator, -<<<<<<< HEAD - UndefinedExpr, ConditionalExpr, TempNode, LITERAL_TYPE, YieldFromExpr -======= - UndefinedExpr, ConditionalExpr, ComparisonExpr, TempNode, LITERAL_TYPE ->>>>>>> master -) -from mypy.errors import Errors -from mypy.nodes import function_type, method_type -from mypy import nodes -import mypy.checker -from mypy import types -from mypy.sametypes import is_same_type -from mypy.replacetvars import replace_func_type_vars, replace_type_vars -from mypy.messages import MessageBuilder -from mypy import messages -from mypy.infer import infer_type_arguments, infer_function_type_arguments -from mypy import join -from mypy.expandtype import expand_type, expand_caller_var_args -from mypy.subtypes import is_subtype -from mypy import erasetype -from mypy.checkmember import analyse_member_access, type_object_type -from mypy.semanal import self_type -from mypy.constraints import get_actual_type - - -class ExpressionChecker: - """Expression type checker. - - This class works closely together with checker.TypeChecker. - """ - - # Some services are provided by a TypeChecker instance. - chk = Undefined('mypy.checker.TypeChecker') - # This is shared with TypeChecker, but stored also here for convenience. - msg = Undefined(MessageBuilder) - - def __init__(self, - chk: 'mypy.checker.TypeChecker', - msg: MessageBuilder) -> None: - """Construct an expression type checker.""" - self.chk = chk - self.msg = msg - - def visit_name_expr(self, e: NameExpr) -> Type: - """Type check a name expression. - - It can be of any kind: local, member or global. - """ - result = self.analyse_ref_expr(e) - return self.chk.narrow_type_from_binder(e, result) - - def analyse_ref_expr(self, e: RefExpr) -> Type: - result = Undefined(Type) - node = e.node - if isinstance(node, Var): - # Variable reference. - result = self.analyse_var_ref(node, e) - elif isinstance(node, FuncDef): - # Reference to a global function. - result = function_type(node) - elif isinstance(node, OverloadedFuncDef): - result = node.type - elif isinstance(node, TypeInfo): - # Reference to a type object. - result = type_object_type(node, self.chk.type_type) - elif isinstance(node, MypyFile): - # Reference to a module object. - result = self.chk.named_type('builtins.module') - elif isinstance(node, Decorator): - result = self.analyse_var_ref(node.var, e) - else: - # Unknown reference; use any type implicitly to avoid - # generating extra type errors. - result = AnyType() - return result - - def analyse_var_ref(self, var: Var, context: Context) -> Type: - if not var.type: - if not var.is_ready: - self.msg.cannot_determine_type(var.name(), context) - # Implicit 'Any' type. - return AnyType() - else: - # Look up local type of variable with type (inferred or explicit). - val = self.chk.binder.get(var) - if val is None: - return var.type - else: - return val - -<<<<<<< HEAD - def visit_yield_from_expr(self, e: YieldFromExpr) -> Type: - return e.expr.accept(self) # move it to checker? - -======= ->>>>>>> master - def visit_call_expr(self, e: CallExpr) -> Type: - """Type check a call expression.""" - if e.analyzed: - # It's really a special form that only looks like a call. - return self.accept(e.analyzed) - self.accept(e.callee) - # Access callee type directly, since accept may return the Any type - # even if the type is known (in a dynamically typed function). This - # way we get a more precise callee in dynamically typed functions. - callee_type = self.chk.type_map[e.callee] - return self.check_call_expr_with_callee_type(callee_type, e) - - def check_call_expr_with_callee_type(self, callee_type: Type, - e: CallExpr) -> Type: - """Type check call expression. - - The given callee type overrides the type of the callee - expression. - """ - return self.check_call(callee_type, e.args, e.arg_kinds, e, - e.arg_names, callable_node=e.callee)[0] - - def check_call(self, callee: Type, args: List[Node], - arg_kinds: List[int], context: Context, - arg_names: List[str] = None, - callable_node: Node = None, - arg_messages: MessageBuilder = None) -> Tuple[Type, Type]: - """Type check a call. - - Also infer type arguments if the callee is a generic function. - - Return (result type, inferred callee type). - - Arguments: - callee: type of the called value - args: actual argument expressions - arg_kinds: contains nodes.ARG_* constant for each argument in args - describing whether the argument is positional, *arg, etc. - arg_names: names of arguments (optional) - callable_node: associate the inferred callable type to this node, - if specified - arg_messages: TODO - """ - arg_messages = arg_messages or self.msg - is_var_arg = nodes.ARG_STAR in arg_kinds - if isinstance(callee, Callable): - if callee.is_type_obj(): - t = callee.type_object() - if callee.is_type_obj() and callee.type_object().is_abstract: - type = callee.type_object() - self.msg.cannot_instantiate_abstract_class( - callee.type_object().name(), type.abstract_attributes, - context) - - formal_to_actual = map_actuals_to_formals( - arg_kinds, arg_names, - callee.arg_kinds, callee.arg_names, - lambda i: self.accept(args[i])) - - if callee.is_generic(): - callee = self.infer_function_type_arguments_using_context( - callee, context) - callee = self.infer_function_type_arguments( - callee, args, arg_kinds, formal_to_actual, context) - - arg_types = self.infer_arg_types_in_context2( - callee, args, arg_kinds, formal_to_actual) - - self.check_argument_count(callee, arg_types, arg_kinds, - arg_names, formal_to_actual, context) - - self.check_argument_types(arg_types, arg_kinds, callee, - formal_to_actual, context, - messages=arg_messages) - if callable_node: - # Store the inferred callable type. - self.chk.store_type(callable_node, callee) - return callee.ret_type, callee - elif isinstance(callee, Overloaded): - # Type check arguments in empty context. They will be checked again - # later in a context derived from the signature; these types are - # only used to pick a signature variant. - self.msg.disable_errors() - arg_types = self.infer_arg_types_in_context(None, args) - self.msg.enable_errors() - - target = self.overload_call_target(arg_types, is_var_arg, - callee, context, - messages=arg_messages) - return self.check_call(target, args, arg_kinds, context, arg_names, - arg_messages=arg_messages) - elif isinstance(callee, AnyType) or self.chk.is_dynamic_function(): - self.infer_arg_types_in_context(None, args) - return AnyType(), AnyType() - elif isinstance(callee, UnionType): - self.msg.disable_type_names += 1 - results = [self.check_call(subtype, args, arg_kinds, context, arg_names, - arg_messages=arg_messages) - for subtype in callee.items] - self.msg.disable_type_names -= 1 - return (UnionType.make_simplified_union([res[0] for res in results]), - callee) - else: - return self.msg.not_callable(callee, context), AnyType() - - def infer_arg_types_in_context(self, callee: Callable, - args: List[Node]) -> List[Type]: - """Infer argument expression types using a callable type as context. - - For example, if callee argument 2 has type List[int], infer the - argument expression with List[int] type context. - """ - # TODO Always called with callee as None, i.e. empty context. -<<<<<<< HEAD - res = [] # type: List[Type] -======= - res = [] # type: List[Type] ->>>>>>> master - - fixed = len(args) - if callee: - fixed = min(fixed, callee.max_fixed_args()) - - arg_type = None # type: Type - ctx = None # type: Type - for i, arg in enumerate(args): - if i < fixed: - if callee and i < len(callee.arg_types): - ctx = callee.arg_types[i] - arg_type = self.accept(arg, ctx) - else: - if callee and callee.is_var_arg: - arg_type = self.accept(arg, callee.arg_types[-1]) - else: - arg_type = self.accept(arg) - if isinstance(arg_type, ErasedType): - res.append(NoneTyp()) - else: - res.append(arg_type) - return res - - def infer_arg_types_in_context2( - self, callee: Callable, args: List[Node], arg_kinds: List[int], - formal_to_actual: List[List[int]]) -> List[Type]: - """Infer argument expression types using a callable type as context. - - For example, if callee argument 2 has type List[int], infer the - argument exprsession with List[int] type context. - - Returns the inferred types of *actual arguments*. - """ - res = [None] * len(args) # type: List[Type] - - for i, actuals in enumerate(formal_to_actual): - for ai in actuals: - if arg_kinds[ai] != nodes.ARG_STAR: - res[ai] = self.accept(args[ai], callee.arg_types[i]) - - # Fill in the rest of the argument types. - for i, t in enumerate(res): - if not t: - res[i] = self.accept(args[i]) - return res - - def infer_function_type_arguments_using_context( - self, callable: Callable, error_context: Context) -> Callable: - """Unify callable return type to type context to infer type vars. - - For example, if the return type is set[t] where 't' is a type variable - of callable, and if the context is set[int], return callable modified - by substituting 't' with 'int'. - """ - ctx = self.chk.type_context[-1] - if not ctx: - return callable - # The return type may have references to function type variables that - # we are inferring right now. We must consider them as indeterminate - # and they are not potential results; thus we replace them with the - # special ErasedType type. On the other hand, class type variables are - # valid results. - erased_ctx = replace_func_type_vars(ctx, ErasedType()) - ret_type = callable.ret_type - if isinstance(ret_type, TypeVar): - if ret_type.values: - # The return type is a type variable with values, but we can't easily restrict - # type inference to conform to the valid values. Give up and just use function - # arguments for type inference. - ret_type = NoneTyp() - args = infer_type_arguments(callable.type_var_ids(), ret_type, - erased_ctx, self.chk.basic_types()) - # Only substite non-None and non-erased types. - new_args = [] # type: List[Type] - for arg in args: - if isinstance(arg, NoneTyp) or has_erased_component(arg): - new_args.append(None) - else: - new_args.append(arg) - return cast(Callable, self.apply_generic_arguments(callable, new_args, - error_context)) - - def infer_function_type_arguments(self, callee_type: Callable, - args: List[Node], - arg_kinds: List[int], - formal_to_actual: List[List[int]], - context: Context) -> Callable: - """Infer the type arguments for a generic callee type. - - Infer based on the types of arguments. - - Return a derived callable type that has the arguments applied (and - stored as implicit type arguments). - """ - if not self.chk.is_dynamic_function(): - # Disable type errors during type inference. There may be errors - # due to partial available context information at this time, but - # these errors can be safely ignored as the arguments will be - # inferred again later. - self.msg.disable_errors() - - arg_types = self.infer_arg_types_in_context2( - callee_type, args, arg_kinds, formal_to_actual) - - self.msg.enable_errors() - - arg_pass_nums = self.get_arg_infer_passes( - callee_type.arg_types, formal_to_actual, len(args)) - - pass1_args = [] # type: List[Type] - for i, arg in enumerate(arg_types): - if arg_pass_nums[i] > 1: - pass1_args.append(None) - else: - pass1_args.append(arg) - - inferred_args = infer_function_type_arguments( - callee_type, pass1_args, arg_kinds, formal_to_actual, - self.chk.basic_types()) # type: List[Type] - - if 2 in arg_pass_nums: - # Second pass of type inference. - (callee_type, - inferred_args) = self.infer_function_type_arguments_pass2( - callee_type, args, arg_kinds, formal_to_actual, - inferred_args, context) - else: - # In dynamically typed functions use implicit 'Any' types for - # type variables. - inferred_args = [AnyType()] * len(callee_type.variables) - return self.apply_inferred_arguments(callee_type, inferred_args, - context) - - def infer_function_type_arguments_pass2( - self, callee_type: Callable, - args: List[Node], - arg_kinds: List[int], - formal_to_actual: List[List[int]], - inferred_args: List[Type], - context: Context) -> Tuple[Callable, List[Type]]: - """Perform second pass of generic function type argument inference. - - The second pass is needed for arguments with types such as func, - where both s and t are type variables, when the actual argument is a - lambda with inferred types. The idea is to infer the type variable t - in the first pass (based on the types of other arguments). This lets - us infer the argument and return type of the lambda expression and - thus also the type variable s in this second pass. - - Return (the callee with type vars applied, inferred actual arg types). - """ - # None or erased types in inferred types mean that there was not enough - # information to infer the argument. Replace them with None values so - # that they are not applied yet below. - for i, arg in enumerate(inferred_args): - if isinstance(arg, NoneTyp) or isinstance(arg, ErasedType): - inferred_args[i] = None - - callee_type = cast(Callable, self.apply_generic_arguments( - callee_type, inferred_args, context)) - arg_types = self.infer_arg_types_in_context2( - callee_type, args, arg_kinds, formal_to_actual) - - inferred_args = infer_function_type_arguments( - callee_type, arg_types, arg_kinds, formal_to_actual, - self.chk.basic_types()) - - return callee_type, inferred_args - - def get_arg_infer_passes(self, arg_types: List[Type], - formal_to_actual: List[List[int]], - num_actuals: int) -> List[int]: - """Return pass numbers for args for two-pass argument type inference. - - For each actual, the pass number is either 1 (first pass) or 2 (second - pass). - - Two-pass argument type inference primarily lets us infer types of - lambdas more effectively. - """ - res = [1] * num_actuals - for i, arg in enumerate(arg_types): - if arg.accept(ArgInferSecondPassQuery()): - for j in formal_to_actual[i]: - res[j] = 2 - return res - - def apply_inferred_arguments(self, callee_type: Callable, - inferred_args: List[Type], - context: Context) -> Callable: - """Apply inferred values of type arguments to a generic function. - - Inferred_args contains the values of function type arguments. - """ - # Report error if some of the variables could not be solved. In that - # case assume that all variables have type Any to avoid extra - # bogus error messages. - for i, inferred_type in enumerate(inferred_args): - if not inferred_type: - # Could not infer a non-trivial type for a type variable. - self.msg.could_not_infer_type_arguments( - callee_type, i + 1, context) - inferred_args = [AnyType()] * len(inferred_args) - # Apply the inferred types to the function type. In this case the - # return type must be Callable, since we give the right number of type - # arguments. - return cast(Callable, self.apply_generic_arguments(callee_type, - inferred_args, context)) - - def check_argument_count(self, callee: Callable, actual_types: List[Type], - actual_kinds: List[int], actual_names: List[str], - formal_to_actual: List[List[int]], - context: Context) -> None: - """Check that the number of arguments to a function are valid. - - Also check that there are no duplicate values for arguments. - """ - formal_kinds = callee.arg_kinds - - # Collect list of all actual arguments matched to formal arguments. - all_actuals = [] # type: List[int] - for actuals in formal_to_actual: - all_actuals.extend(actuals) - - is_error = False # Keep track of errors to avoid duplicate errors. - for i, kind in enumerate(actual_kinds): - if i not in all_actuals and ( - kind != nodes.ARG_STAR or - not is_empty_tuple(actual_types[i])): - # Extra actual: not matched by a formal argument. - if kind != nodes.ARG_NAMED: - self.msg.too_many_arguments(callee, context) - else: - self.msg.unexpected_keyword_argument( - callee, actual_names[i], context) - is_error = True - elif kind == nodes.ARG_STAR and ( - nodes.ARG_STAR not in formal_kinds): - actual_type = actual_types[i] - if isinstance(actual_type, TupleType): - if all_actuals.count(i) < len(actual_type.items): - # Too many tuple items as some did not match. - self.msg.too_many_arguments(callee, context) - # *args can be applied even if the function takes a fixed - # number of positional arguments. This may succeed at runtime. - - for i, kind in enumerate(formal_kinds): - if kind == nodes.ARG_POS and (not formal_to_actual[i] and - not is_error): - # No actual for a mandatory positional formal. - self.msg.too_few_arguments(callee, context) - elif kind in [nodes.ARG_POS, nodes.ARG_OPT, - nodes.ARG_NAMED] and is_duplicate_mapping( - formal_to_actual[i], actual_kinds): - self.msg.duplicate_argument_value(callee, i, context) - elif (kind == nodes.ARG_NAMED and formal_to_actual[i] and - actual_kinds[formal_to_actual[i][0]] != nodes.ARG_NAMED): - # Positional argument when expecting a keyword argument. - self.msg.too_many_positional_arguments(callee, context) - - def check_argument_types(self, arg_types: List[Type], arg_kinds: List[int], - callee: Callable, - formal_to_actual: List[List[int]], - context: Context, - messages: MessageBuilder = None) -> None: - """Check argument types against a callable type. - - Report errors if the argument types are not compatible. - """ - messages = messages or self.msg - # Keep track of consumed tuple *arg items. - tuple_counter = [0] - for i, actuals in enumerate(formal_to_actual): - for actual in actuals: - arg_type = arg_types[actual] - # Check that a *arg is valid as varargs. - if (arg_kinds[actual] == nodes.ARG_STAR and - not self.is_valid_var_arg(arg_type)): - messages.invalid_var_arg(arg_type, context) - if (arg_kinds[actual] == nodes.ARG_STAR2 and - not self.is_valid_keyword_var_arg(arg_type)): - messages.invalid_keyword_var_arg(arg_type, context) - # Get the type of an inidividual actual argument (for *args - # and **args this is the item type, not the collection type). - actual_type = get_actual_type(arg_type, arg_kinds[actual], - tuple_counter) - self.check_arg(actual_type, arg_type, - callee.arg_types[i], - actual + 1, callee, context, messages) - - # There may be some remaining tuple varargs items that haven't - # been checked yet. Handle them. - if (callee.arg_kinds[i] == nodes.ARG_STAR and - arg_kinds[actual] == nodes.ARG_STAR and - isinstance(arg_types[actual], TupleType)): - tuplet = cast(TupleType, arg_types[actual]) - while tuple_counter[0] < len(tuplet.items): - actual_type = get_actual_type(arg_type, - arg_kinds[actual], - tuple_counter) - self.check_arg(actual_type, arg_type, - callee.arg_types[i], - actual + 1, callee, context, messages) - - def check_arg(self, caller_type: Type, original_caller_type: Type, - callee_type: Type, n: int, callee: Callable, - context: Context, messages: MessageBuilder) -> None: - """Check the type of a single argument in a call.""" - if isinstance(caller_type, Void): - messages.does_not_return_value(caller_type, context) - elif not is_subtype(caller_type, callee_type): - messages.incompatible_argument(n, callee, original_caller_type, - context) - - def overload_call_target(self, arg_types: List[Type], is_var_arg: bool, - overload: Overloaded, context: Context, - messages: MessageBuilder = None) -> Type: - """Infer the correct overload item to call with given argument types. - - The return value may be Callable or AnyType (if an unique item - could not be determined). If is_var_arg is True, the caller - uses varargs. - """ - messages = messages or self.msg - # TODO also consider argument names and kinds - # TODO for overlapping signatures we should try to get a more precise - # result than 'Any' - match = [] # type: List[Callable] - for typ in overload.items(): - if self.matches_signature_erased(arg_types, is_var_arg, typ): - if (match and not is_same_type(match[-1].ret_type, - typ.ret_type) and - not mypy.checker.is_more_precise_signature( - match[-1], typ)): - # Ambiguous return type. Either the function overload is - # overlapping (which results in an error elsewhere) or the - # caller has provided some Any argument types; in - # either case can only infer the type to be Any, as it is - # not an error to use Any types in calls. - # - # Overlapping overload items are fine if the items are - # covariant in both argument types and return types with - # respect to type precision. - return AnyType() - else: - match.append(typ) - if not match: - messages.no_variant_matches_arguments(overload, context) - return AnyType() - else: - if len(match) == 1: - return match[0] - else: - # More than one signature matches. Pick the first *non-erased* - # matching signature, or default to the first one if none - # match. - for m in match: - if self.match_signature_types(arg_types, is_var_arg, m): - return m - return match[0] - - def matches_signature_erased(self, arg_types: List[Type], is_var_arg: bool, - callee: Callable) -> bool: - """Determine whether arguments could match the signature at runtime. - - If is_var_arg is True, the caller uses varargs. This is used for - overload resolution. - """ - if not is_valid_argc(len(arg_types), False, callee): - return False - - if is_var_arg: - if not self.is_valid_var_arg(arg_types[-1]): - return False - arg_types, rest = expand_caller_var_args(arg_types, - callee.max_fixed_args()) - - # Fixed function arguments. - func_fixed = callee.max_fixed_args() - for i in range(min(len(arg_types), func_fixed)): - if not is_subtype(self.erase(arg_types[i]), - self.erase( - callee.arg_types[i])): - return False - # Function varargs. - if callee.is_var_arg: - for i in range(func_fixed, len(arg_types)): - if not is_subtype(self.erase(arg_types[i]), - self.erase(callee.arg_types[func_fixed])): - return False - return True - - def match_signature_types(self, arg_types: List[Type], is_var_arg: bool, - callee: Callable) -> bool: - """Determine whether arguments types match the signature. - - If is_var_arg is True, the caller uses varargs. Assume that argument - counts are compatible. - """ - if is_var_arg: - arg_types, rest = expand_caller_var_args(arg_types, - callee.max_fixed_args()) - - # Fixed function arguments. - func_fixed = callee.max_fixed_args() - for i in range(min(len(arg_types), func_fixed)): - if not is_subtype(arg_types[i], callee.arg_types[i]): - return False - # Function varargs. - if callee.is_var_arg: - for i in range(func_fixed, len(arg_types)): - if not is_subtype(arg_types[i], - callee.arg_types[func_fixed]): - return False - return True - - def apply_generic_arguments(self, callable: Callable, types: List[Type], - context: Context) -> Type: - """Apply generic type arguments to a callable type. - - For example, applying [int] to 'def [T] (T) -> T' results in - 'def [-1:int] (int) -> int'. Here '[-1:int]' is an implicit bound type - variable. - - Note that each type can be None; in this case, it will not be applied. - """ - tvars = callable.variables - if len(tvars) != len(types): - self.msg.incompatible_type_application(len(tvars), len(types), - context) - return AnyType() - - # Check that inferred type variable values are compatible with allowed - # values. Also, promote subtype values to allowed values. - types = types[:] - for i, type in enumerate(types): - values = callable.variables[i].values - if values and type: - if isinstance(type, AnyType): - continue - for value in values: - if is_subtype(type, value): - types[i] = value - break - else: - self.msg.incompatible_typevar_value( - callable, i + 1, type, context) - - # Create a map from type variable id to target type. - id_to_type = {} # type: Dict[int, Type] - for i, tv in enumerate(tvars): - if types[i]: - id_to_type[tv.id] = types[i] - - # Apply arguments to argument types. - arg_types = [expand_type(at, id_to_type) for at in callable.arg_types] - - bound_vars = [(tv.id, id_to_type[tv.id]) - for tv in tvars - if tv.id in id_to_type] - - # The callable may retain some type vars if only some were applied. - remaining_tvars = [tv for tv in tvars if tv.id not in id_to_type] - - return Callable(arg_types, - callable.arg_kinds, - callable.arg_names, - expand_type(callable.ret_type, id_to_type), - callable.is_type_obj(), - callable.name, - remaining_tvars, - callable.bound_vars + bound_vars, - callable.line, callable.repr) - - def apply_generic_arguments2(self, overload: Overloaded, types: List[Type], - context: Context) -> Type: - items = [] # type: List[Callable] - for item in overload.items(): - applied = self.apply_generic_arguments(item, types, context) - if isinstance(applied, Callable): - items.append(applied) - else: - # There was an error. - return AnyType() - return Overloaded(items) - - def visit_member_expr(self, e: MemberExpr) -> Type: - """Visit member expression (of form e.id).""" - result = self.analyse_ordinary_member_access(e, False) - return self.chk.narrow_type_from_binder(e, result) - - def analyse_ordinary_member_access(self, e: MemberExpr, - is_lvalue: bool) -> Type: - """Analyse member expression or member lvalue.""" - if e.kind is not None: - # This is a reference to a module attribute. - return self.analyse_ref_expr(e) - else: - # This is a reference to a non-module attribute. - return analyse_member_access(e.name, self.accept(e.expr), e, - is_lvalue, False, - self.chk.basic_types(), self.msg) - - def analyse_external_member_access(self, member: str, base_type: Type, - context: Context) -> Type: - """Analyse member access that is external, i.e. it cannot - refer to private definitions. Return the result type. - """ - # TODO remove; no private definitions in mypy - return analyse_member_access(member, base_type, context, False, False, - self.chk.basic_types(), self.msg) - - def visit_int_expr(self, e: IntExpr) -> Type: - """Type check an integer literal (trivial).""" - return self.named_type('builtins.int') - - def visit_str_expr(self, e: StrExpr) -> Type: - """Type check a string literal (trivial).""" - return self.named_type('builtins.str') - - def visit_bytes_expr(self, e: BytesExpr) -> Type: - """Type check a bytes literal (trivial).""" - return self.named_type('builtins.bytes') - - def visit_unicode_expr(self, e: UnicodeExpr) -> Type: - """Type check a unicode literal (trivial).""" - return self.named_type('builtins.unicode') - - def visit_float_expr(self, e: FloatExpr) -> Type: - """Type check a float literal (trivial).""" - return self.named_type('builtins.float') - - def visit_op_expr(self, e: OpExpr) -> Type: - """Type check a binary operator expression.""" - if e.op == 'and' or e.op == 'or': - return self.check_boolean_op(e, e) - if e.op == '*' and isinstance(e.left, ListExpr): - # Expressions of form [...] * e get special type inference. - return self.check_list_multiply(e) - left_type = self.accept(e.left) - - if e.op in nodes.op_methods: - method = self.get_operator_method(e.op) - result, method_type = self.check_op(method, left_type, e.right, e, - allow_reverse=True) - e.method_type = method_type - return result - else: - raise RuntimeError('Unknown operator {}'.format(e.op)) - - def visit_comparison_expr(self, e: ComparisonExpr) -> Type: - """Type check a comparison expression. - - Comparison expressions are type checked consecutive-pair-wise - That is, 'a < b > c == d' is check as 'a < b and b > c and c == d' - """ - result = None # type: mypy.types.Type - - # Check each consecutive operand pair and their operator - for left, right, operator in zip(e.operands, e.operands[1:], e.operators): - left_type = self.accept(left) - - method_type = None # type: mypy.types.Type - - if operator == 'in' or operator == 'not in': - right_type = self.accept(right) # TODO only evaluate if needed - - local_errors = self.msg.copy() - sub_result, method_type = self.check_op_local('__contains__', right_type, - left, e, local_errors) - if (local_errors.is_errors() and - # is_valid_var_arg is True for any Iterable - self.is_valid_var_arg(right_type)): - itertype = self.chk.analyse_iterable_item_type(right) - method_type = Callable([left_type], [nodes.ARG_POS], [None], - self.chk.bool_type(), False) - sub_result = self.chk.bool_type() - if not is_subtype(left_type, itertype): - self.msg.unsupported_operand_types('in', left_type, right_type, e) - else: - self.msg.add_errors(local_errors) - if operator == 'not in': - sub_result = self.chk.bool_type() - elif operator in nodes.op_methods: - method = self.get_operator_method(operator) - sub_result, method_type = self.check_op(method, left_type, right, e, - allow_reverse=True) - - elif operator == 'is' or operator == 'is not': - sub_result = self.chk.bool_type() - method_type = None - else: - raise RuntimeError('Unknown comparison operator {}'.format(operator)) - - e.method_types.append(method_type) - - # Determine type of boolean-and of result and sub_result - if result == None: - result = sub_result - else: - # TODO: check on void needed? - self.check_not_void(sub_result, e) - result = join.join_types(result, sub_result, self.chk.basic_types()) - - return result - - def get_operator_method(self, op: str) -> str: - if op == '/' and self.chk.pyversion == 2: - # TODO also check for "from __future__ import division" - return '__div__' - else: - return nodes.op_methods[op] - - def check_op_local(self, method: str, base_type: Type, arg: Node, - context: Context, local_errors: MessageBuilder) -> Tuple[Type, Type]: - """Type check a binary operation which maps to a method call. - - Return tuple (result type, inferred operator method type). - """ - method_type = analyse_member_access(method, base_type, context, False, False, - self.chk.basic_types(), local_errors) - return self.check_call(method_type, [arg], [nodes.ARG_POS], - context, arg_messages=local_errors) - - def check_op(self, method: str, base_type: Type, arg: Node, - context: Context, - allow_reverse: bool = False) -> Tuple[Type, Type]: - """Type check a binary operation which maps to a method call. - - Return tuple (result type, inferred operator method type). - """ - # Use a local error storage for errors related to invalid argument - # type (but NOT other errors). This error may need to be suppressed - # for operators which support __rX methods. - local_errors = self.msg.copy() - if not allow_reverse or self.has_member(base_type, method): - result = self.check_op_local(method, base_type, arg, context, - local_errors) - if allow_reverse: - arg_type = self.chk.type_map[arg] - if isinstance(arg_type, AnyType): - # If the right operand has type Any, we can't make any - # conjectures about the type of the result, since the - # operand could have a __r method that returns anything. - result = AnyType(), result[1] - success = not local_errors.is_errors() - else: - result = AnyType(), AnyType() - success = False - if success or not allow_reverse or isinstance(base_type, AnyType): - # We were able to call the normal variant of the operator method, - # or there was some problem not related to argument type - # validity, or the operator has no __rX method. In any case, we - # don't need to consider the __rX method. - self.msg.add_errors(local_errors) - return result - else: - # Calling the operator method was unsuccessful. Try the __rX - # method of the other operand instead. - rmethod = self.get_reverse_op_method(method) - arg_type = self.accept(arg) - if self.has_member(arg_type, rmethod): - method_type = self.analyse_external_member_access( - rmethod, arg_type, context) - temp = TempNode(base_type) - return self.check_call(method_type, [temp], [nodes.ARG_POS], - context) - else: - # No __rX method either. Do deferred type checking to produce - # error message that we may have missed previously. - # TODO Fix type checking an expression more than once. - return self.check_op_local(method, base_type, arg, context, - self.msg) - - def get_reverse_op_method(self, method: str) -> str: - if method == '__div__' and self.chk.pyversion == 2: - return '__rdiv__' - else: - return nodes.reverse_op_methods[method] - - def check_boolean_op(self, e: OpExpr, context: Context) -> Type: - """Type check a boolean operation ('and' or 'or').""" - - # A boolean operation can evaluate to either of the operands. - - # We use the current type context to guide the type inference of of - # the left operand. We also use the left operand type to guide the type - # inference of the right operand so that expressions such as - # '[1] or []' are inferred correctly. - ctx = self.chk.type_context[-1] - left_type = self.accept(e.left, ctx) - right_type = self.accept(e.right, left_type) - - self.check_not_void(left_type, context) - self.check_not_void(right_type, context) - - return join.join_types(left_type, right_type, - self.chk.basic_types()) - - def check_list_multiply(self, e: OpExpr) -> Type: - """Type check an expression of form '[...] * e'. - - Type inference is special-cased for this common construct. - """ - right_type = self.accept(e.right) - if is_subtype(right_type, self.chk.named_type('builtins.int')): - # Special case: [...] * . Use the type context of the - # OpExpr, since the multiplication does not affect the type. - left_type = self.accept(e.left, context=self.chk.type_context[-1]) - else: - left_type = self.accept(e.left) - result, method_type = self.check_op('__mul__', left_type, e.right, e) - e.method_type = method_type - return result - - def visit_unary_expr(self, e: UnaryExpr) -> Type: - """Type check an unary operation ('not', '-', '+' or '~').""" - operand_type = self.accept(e.expr) - op = e.op - if op == 'not': - self.check_not_void(operand_type, e) - result = self.chk.bool_type() # type: Type - elif op == '-': - method_type = self.analyse_external_member_access('__neg__', - operand_type, e) - result, method_type = self.check_call(method_type, [], [], e) - e.method_type = method_type - elif op == '+': - method_type = self.analyse_external_member_access('__pos__', - operand_type, e) - result, method_type = self.check_call(method_type, [], [], e) - e.method_type = method_type - else: - assert op == '~', "unhandled unary operator" - method_type = self.analyse_external_member_access('__invert__', - operand_type, e) - result, method_type = self.check_call(method_type, [], [], e) - e.method_type = method_type - return result - - def visit_index_expr(self, e: IndexExpr) -> Type: - """Type check an index expression (base[index]). - - It may also represent type application. - """ - result = self.visit_index_expr_helper(e) - return self.chk.narrow_type_from_binder(e, result) - - def visit_index_expr_helper(self, e: IndexExpr) -> Type: - if e.analyzed: - # It's actually a type application. - return self.accept(e.analyzed) - left_type = self.accept(e.base) - if isinstance(left_type, TupleType): - # Special case for tuples. They support indexing only by integer - # literals. - index = self.unwrap(e.index) - ok = False - if isinstance(index, IntExpr): - n = index.value - ok = True - elif isinstance(index, UnaryExpr): - if index.op == '-': - operand = index.expr - if isinstance(operand, IntExpr): - n = len(left_type.items) - operand.value - ok = True - if ok: - if n >= 0 and n < len(left_type.items): - return left_type.items[n] - else: - self.chk.fail(messages.TUPLE_INDEX_OUT_OF_RANGE, e) - return AnyType() - else: - self.chk.fail(messages.TUPLE_INDEX_MUST_BE_AN_INT_LITERAL, e) - return AnyType() - else: - result, method_type = self.check_op('__getitem__', left_type, - e.index, e) - e.method_type = method_type - return result - - def visit_cast_expr(self, expr: CastExpr) -> Type: - """Type check a cast expression.""" - source_type = self.accept(expr.expr) - target_type = expr.type - if not self.is_valid_cast(source_type, target_type): - self.msg.invalid_cast(target_type, source_type, expr) - return target_type - - def is_valid_cast(self, source_type: Type, target_type: Type) -> bool: - """Is a cast from source_type to target_type meaningful?""" - return (isinstance(target_type, AnyType) or - (not isinstance(source_type, Void) and - not isinstance(target_type, Void))) - - def visit_type_application(self, tapp: TypeApplication) -> Type: - """Type check a type application (expr[type, ...]).""" - expr_type = self.accept(tapp.expr) - if isinstance(expr_type, Callable): - new_type = self.apply_generic_arguments(expr_type, - tapp.types, tapp) - elif isinstance(expr_type, Overloaded): - overload = expr_type - # Only target items with the right number of generic type args. - items = [c for c in overload.items() - if len(c.variables) == len(tapp.types)] - new_type = self.apply_generic_arguments2(Overloaded(items), - tapp.types, tapp) - else: - self.chk.fail(messages.INVALID_TYPE_APPLICATION_TARGET_TYPE, tapp) - new_type = AnyType() - self.chk.type_map[tapp.expr] = new_type - return new_type - - def visit_list_expr(self, e: ListExpr) -> Type: - """Type check a list expression [...].""" - return self.check_list_or_set_expr(e.items, 'builtins.list', '', - e) - - def visit_set_expr(self, e: SetExpr) -> Type: - return self.check_list_or_set_expr(e.items, 'builtins.set', '', e) - - def check_list_or_set_expr(self, items: List[Node], fullname: str, - tag: str, context: Context) -> Type: - # Translate into type checking a generic function call. - tv = TypeVar('T', -1, []) - constructor = Callable([tv], - [nodes.ARG_STAR], - [None], - self.chk.named_generic_type(fullname, - [tv]), - False, - tag, - [TypeVarDef('T', -1, None)]) - return self.check_call(constructor, - items, - [nodes.ARG_POS] * len(items), context)[0] - - def visit_tuple_expr(self, e: TupleExpr) -> Type: - """Type check a tuple expression.""" - ctx = None # type: TupleType - # Try to determine type context for type inference. - if isinstance(self.chk.type_context[-1], TupleType): - t = cast(TupleType, self.chk.type_context[-1]) - if len(t.items) == len(e.items): - ctx = t - # Infer item types. - items = [] # type: List[Type] - for i in range(len(e.items)): - item = e.items[i] - tt = Undefined # type: Type - if not ctx: - tt = self.accept(item) - else: - tt = self.accept(item, ctx.items[i]) - self.check_not_void(tt, e) - items.append(tt) - return TupleType(items) - - def visit_dict_expr(self, e: DictExpr) -> Type: - # Translate into type checking a generic function call. - tv1 = TypeVar('KT', -1, []) - tv2 = TypeVar('VT', -2, []) - constructor = Undefined(Callable) - # The callable type represents a function like this: - # - # def (*v: Tuple[kt, vt]) -> Dict[kt, vt]: ... - constructor = Callable([TupleType([tv1, tv2])], - [nodes.ARG_STAR], - [None], - self.chk.named_generic_type('builtins.dict', - [tv1, tv2]), - False, - '', - [TypeVarDef('KT', -1, None), - TypeVarDef('VT', -2, None)]) - # Synthesize function arguments. - args = List[Node]() - for key, value in e.items: - args.append(TupleExpr([key, value])) - return self.check_call(constructor, - args, - [nodes.ARG_POS] * len(args), e)[0] - - def visit_func_expr(self, e: FuncExpr) -> Type: - """Type check lambda expression.""" - inferred_type = self.infer_lambda_type_using_context(e) - if not inferred_type: - # No useful type context. - ret_type = e.expr().accept(self.chk) - if not e.args: - # Form 'lambda: e'; just use the inferred return type. - return Callable([], [], [], ret_type, is_type_obj=False) - else: - # TODO: Consider reporting an error. However, this is fine if - # we are just doing the first pass in contextual type - # inference. - return AnyType() - else: - # Type context available. - self.chk.check_func_item(e, type_override=inferred_type) - ret_type = self.chk.type_map[e.expr()] - return replace_callable_return_type(inferred_type, ret_type) - - def infer_lambda_type_using_context(self, e: FuncExpr) -> Callable: - """Try to infer lambda expression type using context. - - Return None if could not infer type. - """ - # TODO also accept 'Any' context - ctx = self.chk.type_context[-1] - if not ctx or not isinstance(ctx, Callable): - return None - - # The context may have function type variables in it. We replace them - # since these are the type variables we are ultimately trying to infer; - # they must be considered as indeterminate. We use ErasedType since it - # does not affect type inference results (it is for purposes like this - # only). - ctx = replace_func_type_vars(ctx, ErasedType()) - - callable_ctx = cast(Callable, ctx) - - if callable_ctx.arg_kinds != e.arg_kinds: - # Incompatible context; cannot use it to infer types. - self.chk.fail(messages.CANNOT_INFER_LAMBDA_TYPE, e) - return None - - return callable_ctx - - def visit_super_expr(self, e: SuperExpr) -> Type: - """Type check a super expression (non-lvalue).""" - t = self.analyse_super(e, False) - return t - - def analyse_super(self, e: SuperExpr, is_lvalue: bool) -> Type: - """Type check a super expression.""" - if e.info and e.info.bases: - # TODO fix multiple inheritance etc - return analyse_member_access(e.name, self_type(e.info), e, - is_lvalue, True, - self.chk.basic_types(), self.msg, - e.info.mro[1]) - else: - # Invalid super. This has been reported by the semantic analyser. - return AnyType() - - def visit_paren_expr(self, e: ParenExpr) -> Type: - """Type check a parenthesised expression.""" - return self.accept(e.expr, self.chk.type_context[-1]) - - def visit_slice_expr(self, e: SliceExpr) -> Type: - for index in [e.begin_index, e.end_index, e.stride]: - if index: - t = self.accept(index) - self.chk.check_subtype(t, self.named_type('builtins.int'), - index, messages.INVALID_SLICE_INDEX) - return self.named_type('builtins.slice') - - def visit_list_comprehension(self, e: ListComprehension) -> Type: - return self.check_generator_or_comprehension( - e.generator, 'builtins.list', '') - - def visit_generator_expr(self, e: GeneratorExpr) -> Type: - return self.check_generator_or_comprehension(e, 'typing.Iterator', - '') - - def check_generator_or_comprehension(self, gen: GeneratorExpr, - type_name: str, - id_for_messages: str) -> Type: - """Type check a generator expression or a list comprehension.""" - - self.chk.binder.push_frame() - for index, sequence, conditions in zip(gen.indices, gen.sequences, - gen.condlists): - sequence_type = self.chk.analyse_iterable_item_type(sequence) - self.chk.analyse_index_variables(index, False, sequence_type, gen) - for condition in conditions: - self.accept(condition) - self.chk.binder.pop_frame() - - # Infer the type of the list comprehension by using a synthetic generic - # callable type. - tv = TypeVar('T', -1, []) - constructor = Callable([tv], - [nodes.ARG_POS], - [None], - self.chk.named_generic_type(type_name, [tv]), - False, - id_for_messages, - [TypeVarDef('T', -1, None)]) - return self.check_call(constructor, - [gen.left_expr], [nodes.ARG_POS], gen)[0] - - def visit_undefined_expr(self, e: UndefinedExpr) -> Type: - return e.type - - def visit_conditional_expr(self, e: ConditionalExpr) -> Type: - cond_type = self.accept(e.cond) - self.check_not_void(cond_type, e) - if_type = self.accept(e.if_expr) - else_type = self.accept(e.else_expr, context=if_type) - return join.join_types(if_type, else_type, self.chk.basic_types()) - - # - # Helpers - # - - def accept(self, node: Node, context: Type = None) -> Type: - """Type check a node. Alias for TypeChecker.accept.""" - return self.chk.accept(node, context) - - def check_not_void(self, typ: Type, context: Context) -> None: - """Generate an error if type is Void.""" - self.chk.check_not_void(typ, context) - - def is_boolean(self, typ: Type) -> bool: - """Is type compatible with bool?""" - return is_subtype(typ, self.chk.bool_type()) - - def named_type(self, name: str) -> Instance: - """Return an instance type with type given by the name and no type - arguments. Alias for TypeChecker.named_type. - """ - return self.chk.named_type(name) - - def is_valid_var_arg(self, typ: Type) -> bool: - """Is a type valid as a *args argument?""" - return (isinstance(typ, TupleType) or - is_subtype(typ, self.chk.named_generic_type('typing.Iterable', - [AnyType()])) or - isinstance(typ, AnyType)) - - def is_valid_keyword_var_arg(self, typ: Type) -> bool: - """Is a type valid as a **kwargs argument?""" - return is_subtype(typ, self.chk.named_generic_type( - 'builtins.dict', [self.named_type('builtins.str'), AnyType()])) - - def has_non_method(self, typ: Type, member: str) -> bool: - """Does type have a member variable / property with the given name?""" - if isinstance(typ, Instance): - return (not typ.type.has_method(member) and - typ.type.has_readable_member(member)) - else: - return False - - def has_member(self, typ: Type, member: str) -> bool: - """Does type have member with the given name?""" - # TODO TupleType => also consider tuple attributes - if isinstance(typ, Instance): - return typ.type.has_readable_member(member) - elif isinstance(typ, AnyType): - return True - elif isinstance(typ, UnionType): - result = all(self.has_member(x, member) for x in typ.items) - return result - else: - return False - - def unwrap(self, e: Node) -> Node: - """Unwrap parentheses from an expression node.""" - if isinstance(e, ParenExpr): - return self.unwrap(e.expr) - else: - return e - - def unwrap_list(self, a: List[Node]) -> List[Node]: - """Unwrap parentheses from a list of expression nodes.""" - r = List[Node]() - for n in a: - r.append(self.unwrap(n)) - return r - - def erase(self, type: Type) -> Type: - """Replace type variable types in type with Any.""" - return erasetype.erase_type(type, self.chk.basic_types()) - - -def is_valid_argc(nargs: int, is_var_arg: bool, callable: Callable) -> bool: - """Return a boolean indicating whether a call expression has a - (potentially) compatible number of arguments for calling a function. - Varargs at caller are not checked. - """ - if is_var_arg: - if callable.is_var_arg: - return True - else: - return nargs - 1 <= callable.max_fixed_args() - elif callable.is_var_arg: - return nargs >= callable.min_args - else: - # Neither has varargs. - return nargs <= len(callable.arg_types) and nargs >= callable.min_args - - -def map_actuals_to_formals(caller_kinds: List[int], - caller_names: List[str], - callee_kinds: List[int], - callee_names: List[str], - caller_arg_type: Function[[int], - Type]) -> List[List[int]]: - """Calculate mapping between actual (caller) args and formals. - - The result contains a list of caller argument indexes mapping to each - callee argument index, indexed by callee index. - - The caller_arg_type argument should evaluate to the type of the actual - argument type with the given index. - """ - ncallee = len(callee_kinds) - map = [None] * ncallee # type: List[List[int]] - for i in range(ncallee): - map[i] = [] - j = 0 - for i, kind in enumerate(caller_kinds): - if kind == nodes.ARG_POS: - if j < ncallee: - if callee_kinds[j] in [nodes.ARG_POS, nodes.ARG_OPT, - nodes.ARG_NAMED]: - map[j].append(i) - j += 1 - elif callee_kinds[j] == nodes.ARG_STAR: - map[j].append(i) - elif kind == nodes.ARG_STAR: - # We need to to know the actual type to map varargs. - argt = caller_arg_type(i) - if isinstance(argt, TupleType): - # A tuple actual maps to a fixed number of formals. - for k in range(len(argt.items)): - if j < ncallee: - if callee_kinds[j] != nodes.ARG_STAR2: - map[j].append(i) - else: - raise NotImplementedError() - j += 1 - else: - # Assume that it is an iterable (if it isn't, there will be - # an error later). - while j < ncallee: - if callee_kinds[j] in (nodes.ARG_NAMED, nodes.ARG_STAR2): - break - else: - map[j].append(i) - j += 1 - elif kind == nodes.ARG_NAMED: - name = caller_names[i] - if name in callee_names: - map[callee_names.index(name)].append(i) - elif nodes.ARG_STAR2 in callee_kinds: - map[callee_kinds.index(nodes.ARG_STAR2)].append(i) - else: - assert kind == nodes.ARG_STAR2 - for j in range(ncallee): - # TODO tuple varargs complicate this - no_certain_match = ( - not map[j] or caller_kinds[map[j][0]] == nodes.ARG_STAR) - if ((callee_names[j] and no_certain_match) - or callee_kinds[j] == nodes.ARG_STAR2): - map[j].append(i) - return map - - -def is_empty_tuple(t: Type) -> bool: - return isinstance(t, TupleType) and not cast(TupleType, t).items - - -def is_duplicate_mapping(mapping: List[int], actual_kinds: List[int]) -> bool: - # Multiple actuals can map to the same formal only if they both come from - # varargs (*args and **kwargs); in this case at runtime it is possible that - # there are no duplicates. We need to allow this, as the convention - # f(..., *args, **kwargs) is common enough. - return len(mapping) > 1 and not ( - len(mapping) == 2 and - actual_kinds[mapping[0]] == nodes.ARG_STAR and - actual_kinds[mapping[1]] == nodes.ARG_STAR2) - - -def replace_callable_return_type(c: Callable, new_ret_type: Type) -> Callable: - """Return a copy of a callable type with a different return type.""" - return Callable(c.arg_types, - c.arg_kinds, - c.arg_names, - new_ret_type, - c.is_type_obj(), - c.name, - c.variables, - c.bound_vars, - c.line) - - -class ArgInferSecondPassQuery(types.TypeQuery): - """Query whether an argument type should be inferred in the second pass. - - The result is True if the type has a type variable in a callable return - type anywhere. For example, the result for Function[[], T] is True if t is - a type variable. - """ - def __init__(self) -> None: - super().__init__(False, types.ANY_TYPE_STRATEGY) - - def visit_callable(self, t: Callable) -> bool: - return self.query_types(t.arg_types) or t.accept(HasTypeVarQuery()) - - -class HasTypeVarQuery(types.TypeQuery): - """Visitor for querying whether a type has a type variable component.""" - def __init__(self) -> None: - super().__init__(False, types.ANY_TYPE_STRATEGY) - - def visit_type_var(self, t: TypeVar) -> bool: - return True - - -def has_erased_component(t: Type) -> bool: - return t is not None and t.accept(HasErasedComponentsQuery()) - - -class HasErasedComponentsQuery(types.TypeQuery): - """Visitor for querying whether a type has an erased component.""" - def __init__(self) -> None: - super().__init__(False, types.ANY_TYPE_STRATEGY) - - def visit_erased_type(self, t: ErasedType) -> bool: - return True diff --git a/mypy/messages.py.orig b/mypy/messages.py.orig deleted file mode 100644 index 5508bc37f4df..000000000000 --- a/mypy/messages.py.orig +++ /dev/null @@ -1,694 +0,0 @@ -"""Facilities and constants for generating error messages during type checking. - -The type checker itself does not deal with message string literals to -improve code clarity and to simplify localization (in the future).""" - -import re - -from typing import Undefined, cast, List, Any, Sequence, Iterable - -from mypy.errors import Errors -from mypy.types import ( - Type, Callable, Instance, TypeVar, TupleType, UnionType, Void, NoneTyp, AnyType, - Overloaded, FunctionLike -) -from mypy.nodes import ( - TypeInfo, Context, op_methods, FuncDef, reverse_type_aliases -) - - -# Constants that represent simple type checker error message, i.e. messages -# that do not have any parameters. - -NO_RETURN_VALUE_EXPECTED = 'No return value expected' -INCOMPATIBLE_RETURN_VALUE_TYPE = 'Incompatible return value type' -RETURN_VALUE_EXPECTED = 'Return value expected' -BOOLEAN_VALUE_EXPECTED = 'Boolean value expected' -BOOLEAN_EXPECTED_FOR_IF = 'Boolean value expected for if condition' -BOOLEAN_EXPECTED_FOR_WHILE = 'Boolean value expected for while condition' -BOOLEAN_EXPECTED_FOR_UNTIL = 'Boolean value expected for until condition' -BOOLEAN_EXPECTED_FOR_NOT = 'Boolean value expected for not operand' -INVALID_EXCEPTION = 'Exception must be derived from BaseException' -INVALID_EXCEPTION_TYPE = 'Exception type must be derived from BaseException' -INVALID_RETURN_TYPE_FOR_YIELD = \ -<<<<<<< HEAD - 'Iterator function return type expected for "yield"' -INVALID_RETURN_TYPE_FOR_YIELD_FROM = \ - 'Iterable function return type expected for "yield from"' -======= - 'Iterator function return type expected for "yield"' ->>>>>>> master -INCOMPATIBLE_TYPES = 'Incompatible types' -INCOMPATIBLE_TYPES_IN_ASSIGNMENT = 'Incompatible types in assignment' -INCOMPATIBLE_TYPES_IN_YIELD = 'Incompatible types in yield' -INCOMPATIBLE_TYPES_IN_YIELD_FROM = 'Incompatible types in "yield from"' -INIT_MUST_NOT_HAVE_RETURN_TYPE = 'Cannot define return type for "__init__"' -GETTER_TYPE_INCOMPATIBLE_WITH_SETTER = \ - 'Type of getter incompatible with setter' -TUPLE_INDEX_MUST_BE_AN_INT_LITERAL = 'Tuple index must an integer literal' -TUPLE_INDEX_OUT_OF_RANGE = 'Tuple index out of range' -TYPE_CONSTANT_EXPECTED = 'Type "Constant" or initializer expected' -INCOMPATIBLE_PAIR_ITEM_TYPE = 'Incompatible Pair item type' -INVALID_TYPE_APPLICATION_TARGET_TYPE = 'Invalid type application target type' -INCOMPATIBLE_TUPLE_ITEM_TYPE = 'Incompatible tuple item type' -INCOMPATIBLE_KEY_TYPE = 'Incompatible dictionary key type' -INCOMPATIBLE_VALUE_TYPE = 'Incompatible dictionary value type' -NEED_ANNOTATION_FOR_VAR = 'Need type annotation for variable' -ITERABLE_EXPECTED = 'Iterable expected' -INCOMPATIBLE_TYPES_IN_FOR = 'Incompatible types in for statement' -INCOMPATIBLE_ARRAY_VAR_ARGS = 'Incompatible variable arguments in call' -INVALID_SLICE_INDEX = 'Slice index must be an integer or None' -CANNOT_INFER_LAMBDA_TYPE = 'Cannot infer type of lambda' -CANNOT_INFER_ITEM_TYPE = 'Cannot infer iterable item type' -CANNOT_ACCESS_INIT = 'Cannot access "__init__" directly' -CANNOT_ASSIGN_TO_METHOD = 'Cannot assign to a method' -CANNOT_ASSIGN_TO_TYPE = 'Cannot assign to a type' -INCONSISTENT_ABSTRACT_OVERLOAD = \ - 'Overloaded method has both abstract and non-abstract variants' -INSTANCE_LAYOUT_CONFLICT = 'Instance layout conflict in multiple inheritance' - - -class MessageBuilder: - """Helper class for reporting type checker error messages with parameters. - - The methods of this class need to be provided with the context within a - file; the errors member manages the wider context. - - IDEA: Support a 'verbose mode' that includes full information about types - in error messages and that may otherwise produce more detailed error - messages. - """ - - # Report errors using this instance. It knows about the current file and - # import context. - errors = Undefined(Errors) - - # Number of times errors have been disabled. - disable_count = 0 - - # Hack to deduplicate error messages from union types - disable_type_names = 0 - - def __init__(self, errors: Errors) -> None: - self.errors = errors - self.disable_count = 0 - self.disable_type_names = 0 - - # - # Helpers - # - - def copy(self) -> 'MessageBuilder': - new = MessageBuilder(self.errors.copy()) - new.disable_count = self.disable_count - return new - - def add_errors(self, messages: 'MessageBuilder') -> None: - """Add errors in messages to this builder.""" - self.errors.error_info.extend(messages.errors.error_info) - - def disable_errors(self) -> None: - self.disable_count += 1 - - def enable_errors(self) -> None: - self.disable_count -= 1 - - def is_errors(self) -> bool: - return self.errors.is_errors() - - def fail(self, msg: str, context: Context) -> None: - """Report an error message (unless disabled).""" - if self.disable_count <= 0: - self.errors.report(context.get_line(), msg.strip()) - - def format(self, typ: Type) -> str: - """Convert a type to a relatively short string that is - suitable for error messages. Mostly behave like format_simple - below, but never return an empty string. - """ - s = self.format_simple(typ) - if s != '': - # If format_simple returns a non-trivial result, use that. - return s - elif isinstance(typ, FunctionLike): - func = cast(FunctionLike, typ) - if func.is_type_obj(): - # The type of a type object type can be derived from the - # return type (this always works). - itype = cast(Instance, func.items()[0].ret_type) - return self.format(itype) - elif isinstance(func, Callable): - arg_types = map(self.format, func.arg_types) - return_type = self.format(func.ret_type) - return 'Function[[{}] -> {}]'.format(", ".join(arg_types), return_type) - else: - # Use a simple representation for function types; proper - # function types may result in long and difficult-to-read - # error messages. - return 'functionlike' - else: - # Default case; we simply have to return something meaningful here. - return 'object' - - def format_simple(self, typ: Type) -> str: - """Convert simple types to string that is suitable for error messages. - - Return "" for complex types. Try to keep the length of the result - relatively short to avoid overly long error messages. - - Examples: - builtins.int -> 'int' - Any type -> 'Any' - void -> None - function type -> "" (empty string) - """ - if isinstance(typ, Instance): - itype = cast(Instance, typ) - # Get the short name of the type. - base_str = itype.type.name() - if itype.args == []: - # No type arguments. Place the type name in quotes to avoid - # potential for confusion: otherwise, the type name could be - # interpreted as a normal word. - return '"{}"'.format(base_str) - elif itype.type.fullname() in reverse_type_aliases: - alias = reverse_type_aliases[itype.type.fullname()] - alias = alias.split('.')[-1] - items = [strip_quotes(self.format(arg)) for arg in itype.args] - return '{}[{}]'.format(alias, ', '.join(items)) - else: - # There are type arguments. Convert the arguments to strings - # (using format() instead of format_simple() to avoid empty - # strings). If the result is too long, replace arguments - # with [...]. - a = [] # type: List[str] - for arg in itype.args: - a.append(strip_quotes(self.format(arg))) - s = ', '.join(a) - if len((base_str + s)) < 25: - return '{}[{}]'.format(base_str, s) - else: - return '{}[...]'.format(base_str) - elif isinstance(typ, TypeVar): - # This is similar to non-generic instance types. - return '"{}"'.format((cast(TypeVar, typ)).name) - elif isinstance(typ, TupleType): - items = [] - for t in (cast(TupleType, typ)).items: - items.append(strip_quotes(self.format(t))) - s = '"Tuple[{}]"'.format(', '.join(items)) - if len(s) < 40: - return s - else: - return 'tuple(length {})'.format(len(items)) - elif isinstance(typ, UnionType): - items = [] - for t in (cast(UnionType, typ)).items: - items.append(strip_quotes(self.format(t))) - s = '"Union[{}]"'.format(', '.join(items)) - if len(s) < 40: - return s - else: - return 'union(length {})'.format(len(items)) - elif isinstance(typ, Void): - return 'None' - elif isinstance(typ, NoneTyp): - return 'None' - elif isinstance(typ, AnyType): - return '"Any"' - elif typ is None: - raise RuntimeError('Type is None') - else: - # No simple representation for this type that would convey very - # useful information. No need to mention the type explicitly in a - # message. - return '' - - # - # Specific operations - # - - # The following operations are for genering specific error messages. They - # get some information as arguments, and they build an error message based - # on them. - - def has_no_attr(self, typ: Type, member: str, context: Context) -> Type: - """Report a missing or non-accessible member. - - The type argument is the base type. If member corresponds to - an operator, use the corresponding operator name in the - messages. Return type Any. - """ - if (isinstance(typ, Instance) and - (cast(Instance, typ)).type.has_readable_member(member)): - self.fail('Member "{}" is not assignable'.format(member), context) - elif isinstance(typ, Void): - self.check_void(typ, context) - elif member == '__contains__': - self.fail('Unsupported right operand type for in ({})'.format( - self.format(typ)), context) - elif member in op_methods.values(): - # Access to a binary operator member (e.g. _add). This case does - # not handle indexing operations. - for op, method in op_methods.items(): - if method == member: - self.unsupported_left_operand(op, typ, context) - break - elif member == '__neg__': - self.fail('Unsupported operand type for unary - ({})'.format( - self.format(typ)), context) - elif member == '__pos__': - self.fail('Unsupported operand type for unary + ({})'.format( - self.format(typ)), context) - elif member == '__invert__': - self.fail('Unsupported operand type for ~ ({})'.format( - self.format(typ)), context) - elif member == '__getitem__': - # Indexed get. - self.fail('Value of type {} is not indexable'.format( - self.format(typ)), context) - elif member == '__setitem__': - # Indexed set. - self.fail('Unsupported target for indexed assignment', context) - else: - # The non-special case: a missing ordinary attribute. - if not self.disable_type_names: - self.fail('{} has no attribute "{}"'.format(self.format(typ), - member), context) - else: - self.fail('Some element of union has no attribute "{}"'.format( - member), context) - return AnyType() - - def unsupported_operand_types(self, op: str, left_type: Any, - right_type: Any, context: Context) -> None: - """Report unsupported operand types for a binary operation. - - Types can be Type objects or strings. - """ - if isinstance(left_type, Void) or isinstance(right_type, Void): - self.check_void(left_type, context) - self.check_void(right_type, context) - return - left_str = '' - if isinstance(left_type, str): - left_str = left_type - else: - left_str = self.format(left_type) - - right_str = '' - if isinstance(right_type, str): - right_str = right_type - else: - right_str = self.format(right_type) - - if self.disable_type_names: - msg = 'Unsupported operand types for {} (likely involving Union)'.format(op) - else: - msg = 'Unsupported operand types for {} ({} and {})'.format( - op, left_str, right_str) - self.fail(msg, context) - - def unsupported_left_operand(self, op: str, typ: Type, - context: Context) -> None: - if not self.check_void(typ, context): - if self.disable_type_names: - msg = 'Unsupported left operand type for {} (some union)'.format(op) - else: - msg = 'Unsupported left operand type for {} ({})'.format( - op, self.format(typ)) - self.fail(msg, context) - - def type_expected_as_right_operand_of_is(self, context: Context) -> None: - self.fail('Type expected as right operand of "is"', context) - - def not_callable(self, typ: Type, context: Context) -> Type: - self.fail('{} not callable'.format(self.format(typ)), context) - return AnyType() - - def incompatible_argument(self, n: int, callee: Callable, arg_type: Type, - context: Context) -> None: - """Report an error about an incompatible argument type. - - The argument type is arg_type, argument number is n and the - callee type is 'callee'. If the callee represents a method - that corresponds to an operator, use the corresponding - operator name in the messages. - """ - target = '' - if callee.name: - name = callee.name - base = extract_type(name) - - for op, method in op_methods.items(): - for variant in method, '__r' + method[2:]: - if name.startswith('"{}" of'.format(variant)): - if op == 'in' or variant != method: - # Reversed order of base/argument. - self.unsupported_operand_types(op, arg_type, base, - context) - else: - self.unsupported_operand_types(op, base, arg_type, - context) - return - - if name.startswith('"__getitem__" of'): - self.invalid_index_type(arg_type, base, context) - return - - if name.startswith('"__setitem__" of'): - if n == 1: - self.invalid_index_type(arg_type, base, context) - else: - self.fail(INCOMPATIBLE_TYPES_IN_ASSIGNMENT, context) - return - - target = 'to {} '.format(name) - - msg = '' - if callee.name == '': - name = callee.name[1:-1] - msg = '{} item {} has incompatible type {}'.format( - name[0].upper() + name[1:], n, self.format_simple(arg_type)) - elif callee.name == '': - msg = 'List comprehension has incompatible type List[{}]'.format( - strip_quotes(self.format(arg_type))) - elif callee.name == '': - msg = 'Generator has incompatible item type {}'.format( - self.format_simple(arg_type)) - else: - try: - expected_type = callee.arg_types[n - 1] - except IndexError: # Varargs callees - expected_type = callee.arg_types[-1] - msg = 'Argument {} {}has incompatible type {}; expected {}'.format( - n, target, self.format(arg_type), self.format(expected_type)) - self.fail(msg, context) - - def invalid_index_type(self, index_type: Type, base_str: str, - context: Context) -> None: - self.fail('Invalid index type {} for {}'.format( - self.format(index_type), base_str), context) - - def invalid_argument_count(self, callee: Callable, num_args: int, - context: Context) -> None: - if num_args < len(callee.arg_types): - self.too_few_arguments(callee, context) - else: - self.too_many_arguments(callee, context) - - def too_few_arguments(self, callee: Callable, context: Context) -> None: - msg = 'Too few arguments' - if callee.name: - msg += ' for {}'.format(callee.name) - self.fail(msg, context) - - def too_many_arguments(self, callee: Callable, context: Context) -> None: - msg = 'Too many arguments' - if callee.name: - msg += ' for {}'.format(callee.name) - self.fail(msg, context) - - def too_many_positional_arguments(self, callee: Callable, - context: Context) -> None: - msg = 'Too many positional arguments' - if callee.name: - msg += ' for {}'.format(callee.name) - self.fail(msg, context) - - def unexpected_keyword_argument(self, callee: Callable, name: str, - context: Context) -> None: - msg = 'Unexpected keyword argument "{}"'.format(name) - if callee.name: - msg += ' for {}'.format(callee.name) - self.fail(msg, context) - - def duplicate_argument_value(self, callee: Callable, index: int, - context: Context) -> None: - self.fail('{} gets multiple values for keyword argument "{}"'. - format(capitalize(callable_name(callee)), - callee.arg_names[index]), context) - - def does_not_return_value(self, void_type: Type, context: Context) -> None: - """Report an error about a void type in a non-void context. - - The first argument must be a void type. If the void type has a - source in it, report it in the error message. This allows - giving messages such as 'Foo does not return a value'. - """ - if (cast(Void, void_type)).source is None: - self.fail('Function does not return a value', context) - else: - self.fail('{} does not return a value'.format( - capitalize((cast(Void, void_type)).source)), context) - - def no_variant_matches_arguments(self, overload: Overloaded, - context: Context) -> None: - if overload.name(): - self.fail('No overload variant of {} matches argument types' - .format(overload.name()), context) - else: - self.fail('No overload variant matches argument types', context) - - def function_variants_overlap(self, n1: int, n2: int, - context: Context) -> None: - self.fail('Function signature variants {} and {} overlap'.format( - n1 + 1, n2 + 1), context) - - def invalid_cast(self, target_type: Type, source_type: Type, - context: Context) -> None: - if not self.check_void(source_type, context): - self.fail('Cannot cast from {} to {}'.format( - self.format(source_type), self.format(target_type)), context) - - def incompatible_operator_assignment(self, op: str, - context: Context) -> None: - self.fail('Result type of {} incompatible in assignment'.format(op), - context) - - def incompatible_value_count_in_assignment(self, lvalue_count: int, - rvalue_count: int, - context: Context) -> None: - if rvalue_count < lvalue_count: - self.fail('Need {} values to assign'.format(lvalue_count), context) - elif rvalue_count > lvalue_count: - self.fail('Too many values to assign', context) - - def type_incompatible_with_supertype(self, name: str, supertype: TypeInfo, - context: Context) -> None: - self.fail('Type of "{}" incompatible with supertype "{}"'.format( - name, supertype.name), context) - - def signature_incompatible_with_supertype( - self, name: str, name_in_super: str, supertype: str, - context: Context) -> None: - target = self.override_target(name, name_in_super, supertype) - self.fail('Signature of "{}" incompatible with {}'.format( - name, target), context) - - def argument_incompatible_with_supertype( - self, arg_num: int, name: str, name_in_supertype: str, - supertype: str, context: Context) -> None: - target = self.override_target(name, name_in_supertype, supertype) - self.fail('Argument {} of "{}" incompatible with {}' - .format(arg_num, name, target), context) - - def return_type_incompatible_with_supertype( - self, name: str, name_in_supertype: str, supertype: str, - context: Context) -> None: - target = self.override_target(name, name_in_supertype, supertype) - self.fail('Return type of "{}" incompatible with {}' - .format(name, target), context) - - def override_target(self, name: str, name_in_super: str, - supertype: str) -> str: - target = 'supertype "{}"'.format(supertype) - if name_in_super != name: - target = '"{}" of {}'.format(name_in_super, target) - return target - - def boolean_return_value_expected(self, method: str, - context: Context) -> None: - self.fail('Boolean return value expected for method "{}"'.format( - method), context) - - def incompatible_type_application(self, expected_arg_count: int, - actual_arg_count: int, - context: Context) -> None: - if expected_arg_count == 0: - self.fail('Type application targets a non-generic function', - context) - elif actual_arg_count > expected_arg_count: - self.fail('Type application has too many types ({} expected)' - .format(expected_arg_count), context) - else: - self.fail('Type application has too few types ({} expected)' - .format(expected_arg_count), context) - - def incompatible_array_item_type(self, typ: Type, index: int, - context: Context) -> None: - self.fail('Array item {} has incompatible type {}'.format( - index, self.format(typ)), context) - - def could_not_infer_type_arguments(self, callee_type: Callable, n: int, - context: Context) -> None: - if callee_type.name and n > 0: - self.fail('Cannot infer type argument {} of {}'.format( - n, callee_type.name), context) - else: - self.fail('Cannot infer function type argument', context) - - def invalid_var_arg(self, typ: Type, context: Context) -> None: - self.fail('List or tuple expected as variable arguments', context) - - def invalid_keyword_var_arg(self, typ: Type, context: Context) -> None: - if isinstance(typ, Instance) and ( - (cast(Instance, typ)).type.fullname() == 'builtins.dict'): - self.fail('Keywords must be strings', context) - else: - self.fail('Argument after ** must be a dictionary', - context) - - def incomplete_type_var_match(self, member: str, context: Context) -> None: - self.fail('"{}" has incomplete match to supertype type variable' - .format(member), context) - - def not_implemented(self, msg: str, context: Context) -> Type: - self.fail('Feature not implemented yet ({})'.format(msg), context) - return AnyType() - - def undefined_in_superclass(self, member: str, context: Context) -> None: - self.fail('"{}" undefined in superclass'.format(member), context) - - def check_void(self, typ: Type, context: Context) -> bool: - """If type is void, report an error such as '.. does not - return a value' and return True. Otherwise, return False. - """ - if isinstance(typ, Void): - self.does_not_return_value(typ, context) - return True - else: - return False - - def cannot_determine_type(self, name: str, context: Context) -> None: - self.fail("Cannot determine type of '%s'" % name, context) - - def invalid_method_type(self, sig: Callable, context: Context) -> None: - self.fail('Invalid method type', context) - - def incompatible_conditional_function_def(self, defn: FuncDef) -> None: - self.fail('All conditional function variants must have identical ' - 'signatures', defn) - - def cannot_instantiate_abstract_class(self, class_name: str, - abstract_attributes: List[str], - context: Context) -> None: - attrs = format_string_list("'%s'" % a for a in abstract_attributes[:5]) - self.fail("Cannot instantiate abstract class '%s' with abstract " - "method%s %s" % (class_name, plural_s(abstract_attributes), - attrs), - context) - - def base_class_definitions_incompatible(self, name: str, base1: TypeInfo, - base2: TypeInfo, - context: Context) -> None: - self.fail('Definition of "{}" in base class "{}" is incompatible ' - 'with definition in base class "{}"'.format( - name, base1.name(), base2.name()), context) - - def cant_assign_to_method(self, context: Context) -> None: - self.fail(CANNOT_ASSIGN_TO_METHOD, context) - - def read_only_property(self, name: str, type: TypeInfo, - context: Context) -> None: - self.fail('Property "{}" defined in "{}" is read-only'.format( - name, type.name()), context) - - def incompatible_typevar_value(self, callee: Callable, index: int, - type: Type, context: Context) -> None: - self.fail('Type argument {} of {} has incompatible value {}'.format( - index, callable_name(callee), self.format(type)), context) - - def disjointness_violation(self, cls: TypeInfo, disjoint: TypeInfo, - context: Context) -> None: - self.fail('disjointclass constraint of class {} disallows {} as a ' - 'base class'.format(cls.name(), disjoint.name()), context) - - def overloaded_signatures_overlap(self, index1: int, index2: int, - context: Context) -> None: - self.fail('Overloaded function signatures {} and {} overlap with ' - 'incompatible return types'.format(index1, index2), context) - - def invalid_reverse_operator_signature(self, reverse: str, other: str, - context: Context) -> None: - self.fail('"Any" return type expected since argument to {} does not ' - 'support {}'.format(reverse, other), context) - - def reverse_operator_method_with_any_arg_must_return_any( - self, method: str, context: Context) -> None: - self.fail('"Any" return type expected since argument to {} has type ' - '"Any"'.format(method), context) - - def operator_method_signatures_overlap( - self, reverse_class: str, reverse_method: str, forward_class: str, - forward_method: str, context: Context) -> None: - self.fail('Signatures of "{}" of "{}" and "{}" of "{}" are unsafely ' - 'overlapping'.format(reverse_method, reverse_class, - forward_method, forward_class), context) - - def signatures_incompatible(self, method: str, other_method: str, - context: Context) -> None: - self.fail('Signatures of "{}" and "{}" are incompatible'.format( - method, other_method), context) - - def yield_from_not_valid_applied(self, expr: Type, context: Context) -> Type: - text = self.format(expr) if self.format(expr) != 'object' else expr - self.fail('"yield from" can\'t be applied to {}'.format(text), context) - return AnyType() - - -def capitalize(s: str) -> str: - """Capitalize the first character of a string.""" - if s == '': - return '' - else: - return s[0].upper() + s[1:] - - -def extract_type(name: str) -> str: - """If the argument is the name of a method (of form C.m), return - the type portion in quotes (e.g. "y"). Otherwise, return the string - unmodified. - """ - name = re.sub('^"[a-zA-Z0-9_]+" of ', '', name) - return name - - -def strip_quotes(s: str) -> str: - """Strip a double quote at the beginning and end of the string, if any.""" - s = re.sub('^"', '', s) - s = re.sub('"$', '', s) - return s - - -def plural_s(s: Sequence[Any]) -> str: - if len(s) > 1: - return 's' - else: - return '' - - -def format_string_list(s: Iterable[str]) -> str: - l = list(s) - assert len(l) > 0 - if len(l) == 1: - return l[0] - else: - return '%s and %s' % (', '.join(l[:-1]), l[-1]) - - -def callable_name(type: Callable) -> str: - if type.name: - return type.name - else: - return 'function' diff --git a/mypy/nodes.py.orig b/mypy/nodes.py.orig deleted file mode 100644 index 998e739412ac..000000000000 --- a/mypy/nodes.py.orig +++ /dev/null @@ -1,1845 +0,0 @@ -"""Abstract syntax tree node classes (i.e. parse tree).""" - -import re -from abc import abstractmethod, ABCMeta - -from typing import ( - Any, overload, typevar, Undefined, List, Tuple, cast, Set, Dict -) - -from mypy.lex import Token -import mypy.strconv -from mypy.visitor import NodeVisitor -from mypy.util import dump_tagged, short_type - - -class Context(metaclass=ABCMeta): - """Base type for objects that are valid as error message locations.""" - #@abstractmethod - def get_line(self) -> int: pass - - -import mypy.types - - -T = typevar('T') - - -# Variable kind constants -# TODO rename to use more descriptive names - -LDEF = 0 # type: int -GDEF = 1 # type: int -MDEF = 2 # type: int -MODULE_REF = 3 # type: int -# Type variable declared using typevar(...) has kind UNBOUND_TVAR. It's not -# valid as a type. A type variable is valid as a type (kind TVAR) within -# (1) a generic class that uses the type variable as a type argument or -# (2) a generic function that refers to the type variable in its signature. -UNBOUND_TVAR = 4 # type: 'int' -TVAR = 5 # type: int - - -LITERAL_YES = 2 -LITERAL_TYPE = 1 -LITERAL_NO = 0 - -node_kinds = { - LDEF: 'Ldef', - GDEF: 'Gdef', - MDEF: 'Mdef', - MODULE_REF: 'ModuleRef', - UNBOUND_TVAR: 'UnboundTvar', - TVAR: 'Tvar', -} - - -implicit_module_attrs = ['__name__', '__doc__', '__file__'] - - -type_aliases = { - 'typing.List': '__builtins__.list', - 'typing.Dict': '__builtins__.dict', - 'typing.Set': '__builtins__.set', -} - -reverse_type_aliases = dict((name.replace('__builtins__', 'builtins'), alias) - for alias, name in type_aliases.items()) - - -class Node(Context): - """Common base class for all non-type parse tree nodes.""" - - line = -1 - # Textual representation - repr = None # type: Any - - literal = LITERAL_NO - literal_hash = None # type: Any - - def __str__(self) -> str: - ans = self.accept(mypy.strconv.StrConv()) - if ans is None: - return repr(self) - return ans - - @overload - def set_line(self, tok: Token) -> 'Node': - self.line = tok.line - return self - - @overload - def set_line(self, line: int) -> 'Node': - self.line = line - return self - - def get_line(self) -> int: - # TODO this should be just 'line' - return self.line - - def accept(self, visitor: NodeVisitor[T]) -> T: - raise RuntimeError('Not implemented') - - -class SymbolNode(Node): - # Nodes that can be stored in a symbol table. - - # TODO do not use methods for these - - @abstractmethod - def name(self) -> str: pass - - @abstractmethod - def fullname(self) -> str: pass - - -class MypyFile(SymbolNode): - """The abstract syntax tree of a single source file.""" - -<<<<<<< HEAD - _name = None # type: str # Module name ('__main__' for initial file) - _fullname = None # type: str # Qualified module name - path = '' # Path to the file (None if not known) - defs = Undefined # type: List[Node] # Global definitions and statements - is_bom = False # Is there a UTF-8 BOM at the start? -======= - _name = None # type: str # Module name ('__main__' for initial file) - _fullname = None # type: str # Qualified module name - path = '' # Path to the file (None if not known) - defs = Undefined # type: List[Node] # Global definitions and statements - is_bom = False # Is there a UTF-8 BOM at the start? ->>>>>>> master - names = Undefined('SymbolTable') - imports = Undefined(List['ImportBase']) # All import nodes within the file - - def __init__(self, defs: List[Node], imports: List['ImportBase'], - is_bom: bool = False) -> None: - self.defs = defs - self.line = 1 # Dummy line number - self.imports = imports - self.is_bom = is_bom - - def name(self) -> str: - return self._name - - def fullname(self) -> str: - return self._fullname - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_mypy_file(self) - - -class ImportBase(Node): - """Base class for all import statements.""" - is_unreachable = False - - -class Import(ImportBase): - """import m [as n]""" - - ids = Undefined(List[Tuple[str, str]]) # (module id, as id) - - def __init__(self, ids: List[Tuple[str, str]]) -> None: - self.ids = ids - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_import(self) - - -class ImportFrom(ImportBase): - """from m import x, ...""" - -<<<<<<< HEAD - names = Undefined(List[Tuple[str, str]]) # Tuples (name, as name) -======= - names = Undefined(List[Tuple[str, str]]) # Tuples (name, as name) ->>>>>>> master - - def __init__(self, id: str, names: List[Tuple[str, str]]) -> None: - self.id = id - self.names = names - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_import_from(self) - - -class ImportAll(ImportBase): - """from m import *""" - - def __init__(self, id: str) -> None: - self.id = id - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_import_all(self) - - -class FuncBase(SymbolNode): - """Abstract base class for function-like nodes""" - - # Type signature (Callable or Overloaded) - type = None # type: mypy.types.Type - # If method, reference to TypeInfo - info = None # type: TypeInfo - - @abstractmethod - def name(self) -> str: pass - - def fullname(self) -> str: - return self.name() - - def is_method(self) -> bool: - return bool(self.info) - - -class OverloadedFuncDef(FuncBase): - """A logical node representing all the variants of an overloaded function. - - This node has no explicit representation in the source program. - Overloaded variants must be consecutive in the source file. - """ - - items = Undefined(List['Decorator']) -<<<<<<< HEAD - _fullname = None # type: str -======= - _fullname = None # type: str ->>>>>>> master - - def __init__(self, items: List['Decorator']) -> None: - self.items = items - self.set_line(items[0].line) - - def name(self) -> str: - return self.items[1].func.name() - - def fullname(self) -> str: - return self._fullname - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_overloaded_func_def(self) - - -class FuncItem(FuncBase): -<<<<<<< HEAD - args = Undefined(List['Var']) # Argument names - arg_kinds = Undefined(List[int]) # Kinds of arguments (ARG_*) -======= - args = Undefined(List['Var']) # Argument names - arg_kinds = Undefined(List[int]) # Kinds of arguments (ARG_*) ->>>>>>> master - - # Initialization expessions for fixed args; None if no initialiser - init = Undefined(List['AssignmentStmt']) - min_args = 0 # Minimum number of arguments - max_pos = 0 # Maximum number of positional arguments, -1 if - # no explicit limit (*args not included) - body = Undefined('Block') - is_implicit = False # Implicit dynamic types? - is_overload = False # Is this an overload variant of function with - # more than one overload variant? - is_generator = False # Contains a yield statement? - is_coroutine = False # Contains @coroutine or yield from Future - is_static = False # Uses @staticmethod? - is_class = False # Uses @classmethod? - expanded = Undefined(List['FuncItem']) # Variants of function with type - # variables with values expanded - - def __init__(self, args: List['Var'], arg_kinds: List[int], - init: List[Node], body: 'Block', - typ: 'mypy.types.Type' = None) -> None: - self.args = args - self.arg_kinds = arg_kinds - self.max_pos = arg_kinds.count(ARG_POS) + arg_kinds.count(ARG_OPT) - self.body = body - self.type = typ - self.expanded = [] - - i2 = List[AssignmentStmt]() - self.min_args = 0 - for i in range(len(init)): - if init[i] is not None: - rvalue = init[i] - lvalue = NameExpr(args[i].name()).set_line(rvalue.line) - assign = AssignmentStmt([lvalue], rvalue) - assign.set_line(rvalue.line) - i2.append(assign) - else: - i2.append(None) - if i < self.max_fixed_argc(): - self.min_args = i + 1 - self.init = i2 - - def max_fixed_argc(self) -> int: - return self.max_pos - - @overload - def set_line(self, tok: Token) -> Node: - super().set_line(tok) - for n in self.args: - n.line = self.line - return self - - @overload - def set_line(self, tok: int) -> Node: - super().set_line(tok) - for n in self.args: - n.line = self.line - return self - - def init_expressions(self) -> List[Node]: - res = List[Node]() - for i in self.init: - if i is not None: - res.append(i.rvalue) - else: - res.append(None) - return res - - -class FuncDef(FuncItem): - """Function definition. - - This is a non-lambda function defined using 'def'. - """ - -<<<<<<< HEAD - _fullname = None # type: str # Name with module prefix -======= - _fullname = None # type: str # Name with module prefix ->>>>>>> master - is_decorated = False - is_conditional = False # Defined conditionally (within block)? - is_abstract = False - is_property = False -<<<<<<< HEAD - original_def = None # type: FuncDef # Original conditional definition -======= - original_def = None # type: FuncDef # Original conditional definition ->>>>>>> master - - def __init__(self, - name: str, # Function name - args: List['Var'], # Argument names - arg_kinds: List[int], # Arguments kinds (nodes.ARG_*) - init: List[Node], # Initializers (each may be None) - body: 'Block', - typ: 'mypy.types.Type' = None) -> None: - super().__init__(args, arg_kinds, init, body, typ) - self._name = name - - def name(self) -> str: - return self._name - - def fullname(self) -> str: - return self._fullname - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_func_def(self) - - def is_constructor(self) -> bool: - return self.info is not None and self._name == '__init__' - - def get_name(self) -> str: - """TODO merge with name()""" - return self._name - - -class Decorator(SymbolNode): - """A decorated function. - - A single Decorator object can include any number of function decorators. - """ - -<<<<<<< HEAD - func = Undefined(FuncDef) # Decorated function - decorators = Undefined(List[Node]) # Decorators, at least one - var = Undefined('Var') # Represents the decorated function obj -======= - func = Undefined(FuncDef) # Decorated function - decorators = Undefined(List[Node]) # Decorators, at least one - var = Undefined('Var') # Represents the decorated function obj ->>>>>>> master - is_overload = False - - def __init__(self, func: FuncDef, decorators: List[Node], - var: 'Var') -> None: - self.func = func - self.decorators = decorators - self.var = var - self.is_overload = False - - def name(self) -> str: - return self.func.name() - - def fullname(self) -> str: - return self.func.fullname() - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_decorator(self) - - -class Var(SymbolNode): - """A variable. - - It can refer to global/local variable or a data attribute. - """ - -<<<<<<< HEAD - _name = None # type: str # Name without module prefix - _fullname = None # type: str # Name with module prefix -======= - _name = None # type: str # Name without module prefix - _fullname = None # type: str # Name with module prefix ->>>>>>> master - info = Undefined('TypeInfo') # Defining class (for member variables) - type = None # type: mypy.types.Type # Declared or inferred type, or None - is_self = False # Is this the first argument to an ordinary method - # (usually "self")? - is_ready = False # If inferred, is the inferred type available? - # Is this initialized explicitly to a non-None value in class body? - is_initialized_in_class = False - is_staticmethod = False - is_classmethod = False - is_property = False - - def __init__(self, name: str, type: 'mypy.types.Type' = None) -> None: - self._name = name - self.type = type - self.is_self = False - self.is_ready = True - self.is_initialized_in_class = False - - def name(self) -> str: - return self._name - - def fullname(self) -> str: - return self._fullname - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_var(self) - - -class ClassDef(Node): - """Class definition""" - -<<<<<<< HEAD - name = Undefined(str) # Name of the class without module prefix - fullname = None # type: str # Fully qualified name of the class -======= - name = Undefined(str) # Name of the class without module prefix - fullname = None # type: str # Fully qualified name of the class ->>>>>>> master - defs = Undefined('Block') - type_vars = Undefined(List['mypy.types.TypeVarDef']) - # Base classes (Instance or UnboundType). - base_types = Undefined(List['mypy.types.Type']) - info = None # type: TypeInfo # Related TypeInfo - metaclass = '' - decorators = Undefined(List[Node]) - # Built-in/extension class? (single implementation inheritance only) - is_builtinclass = False - - def __init__(self, name: str, defs: 'Block', - type_vars: List['mypy.types.TypeVarDef'] = None, - base_types: List['mypy.types.Type'] = None, - metaclass: str = None) -> None: - if not base_types: - base_types = [] - self.name = name - self.defs = defs - self.type_vars = type_vars or [] - self.base_types = base_types - self.metaclass = metaclass - self.decorators = [] - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_class_def(self) - - def is_generic(self) -> bool: - return self.info.is_generic() - - -class VarDef(Node): - """Variable definition with explicit types""" - - items = Undefined(List[Var]) -<<<<<<< HEAD - kind = None # type: int # LDEF/GDEF/MDEF/... - init = Undefined(Node) # Expression or None - is_top_level = False # Is the definition at the top level (not within - # a function or a type)? -======= - kind = None # type: int # LDEF/GDEF/MDEF/... - init = Undefined(Node) # Expression or None - is_top_level = False # Is the definition at the top level (not within - # a function or a type)? ->>>>>>> master - - def __init__(self, items: List[Var], is_top_level: bool, - init: Node = None) -> None: - self.items = items - self.is_top_level = is_top_level - self.init = init - - def info(self) -> 'TypeInfo': - return self.items[0].info - - @overload - def set_line(self, tok: Token) -> Node: - super().set_line(tok) - for n in self.items: - n.line = self.line - return self - - @overload - def set_line(self, tok: int) -> Node: - super().set_line(tok) - for n in self.items: - n.line = self.line - return self - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_var_def(self) - - -class GlobalDecl(Node): - """Declaration global x, y, ...""" - - names = Undefined(List[str]) - - def __init__(self, names: List[str]) -> None: - self.names = names - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_global_decl(self) - - -class Block(Node): - body = Undefined(List[Node]) - # True if we can determine that this block is not executed. For example, - # this applies to blocks that are protected by something like "if PY3:" - # when using Python 2. - is_unreachable = False - - def __init__(self, body: List[Node]) -> None: - self.body = body - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_block(self) - - -# Statements - - -class ExpressionStmt(Node): - """An expression as a statament, such as print(s).""" - expr = Undefined(Node) - - def __init__(self, expr: Node) -> None: - self.expr = expr - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_expression_stmt(self) - - -class AssignmentStmt(Node): - """Assignment statement - - The same node class is used for single assignment, multiple assignment - (e.g. x, y = z) and chained assignment (e.g. x = y = z), assignments - that define new names, and assignments with explicit types (# type). - - An lvalue can be NameExpr, TupleExpr, ListExpr, MemberExpr, IndexExpr or - ParenExpr. - """ - - lvalues = Undefined(List[Node]) - rvalue = Undefined(Node) - type = None # type: mypy.types.Type # Declared type in a comment, - # may be None. - - def __init__(self, lvalues: List[Node], rvalue: Node, - type: 'mypy.types.Type' = None) -> None: - self.lvalues = lvalues - self.rvalue = rvalue - self.type = type - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_assignment_stmt(self) - - -class OperatorAssignmentStmt(Node): - """Operator assignment statement such as x += 1""" - - op = '' - lvalue = Undefined(Node) - rvalue = Undefined(Node) - - def __init__(self, op: str, lvalue: Node, rvalue: Node) -> None: - self.op = op - self.lvalue = lvalue - self.rvalue = rvalue - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_operator_assignment_stmt(self) - - -class WhileStmt(Node): - expr = Undefined(Node) - body = Undefined(Block) - else_body = Undefined(Block) - - def __init__(self, expr: Node, body: Block, else_body: Block) -> None: - self.expr = expr - self.body = body - self.else_body = else_body - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_while_stmt(self) - - -class ForStmt(Node): - # Index variables - index = Undefined(List['NameExpr']) - # Index variable types (each may be None) - types = Undefined(List['mypy.types.Type']) - # Expression to iterate - expr = Undefined(Node) - body = Undefined(Block) - else_body = Undefined(Block) - - def __init__(self, index: List['NameExpr'], expr: Node, body: Block, - else_body: Block, - types: List['mypy.types.Type'] = None) -> None: - self.index = index - self.expr = expr - self.body = body - self.else_body = else_body - self.types = types - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_for_stmt(self) - - def is_annotated(self) -> bool: - ann = False - for t in self.types: - if t is not None: - ann = True - return ann - - -class ReturnStmt(Node): - expr = Undefined(Node) # Expression or None - - def __init__(self, expr: Node) -> None: - self.expr = expr - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_return_stmt(self) - - -class AssertStmt(Node): - expr = Undefined(Node) - - def __init__(self, expr: Node) -> None: - self.expr = expr - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_assert_stmt(self) - - -class YieldStmt(Node): - expr = Undefined(Node) - - def __init__(self, expr: Node) -> None: - self.expr = expr - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_yield_stmt(self) - - -class YieldFromStmt(Node): - expr = Undefined(Node) - - def __init__(self, expr: Node) -> None: - self.expr = expr - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_yield_from_stmt(self) - - -class DelStmt(Node): - expr = Undefined(Node) - - def __init__(self, expr: Node) -> None: - self.expr = expr - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_del_stmt(self) - - -class BreakStmt(Node): - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_break_stmt(self) - - -class ContinueStmt(Node): - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_continue_stmt(self) - - -class PassStmt(Node): - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_pass_stmt(self) - - -class IfStmt(Node): - expr = Undefined(List[Node]) - body = Undefined(List[Block]) - else_body = Undefined(Block) - - def __init__(self, expr: List[Node], body: List[Block], - else_body: Block) -> None: - self.expr = expr - self.body = body - self.else_body = else_body - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_if_stmt(self) - - -class RaiseStmt(Node): - expr = Undefined(Node) - from_expr = Undefined(Node) - - def __init__(self, expr: Node, from_expr: Node = None) -> None: - self.expr = expr - self.from_expr = from_expr - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_raise_stmt(self) - - -class TryStmt(Node): - body = Undefined(Block) # Try body - types = Undefined(List[Node]) # Except type expressions - vars = Undefined(List['NameExpr']) # Except variable names - handlers = Undefined(List[Block]) # Except bodies - else_body = Undefined(Block) - finally_body = Undefined(Block) - - def __init__(self, body: Block, vars: List['NameExpr'], types: List[Node], - handlers: List[Block], else_body: Block, - finally_body: Block) -> None: - self.body = body - self.vars = vars - self.types = types - self.handlers = handlers - self.else_body = else_body - self.finally_body = finally_body - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_try_stmt(self) - - -class WithStmt(Node): - expr = Undefined(List[Node]) - name = Undefined(List['NameExpr']) - body = Undefined(Block) - - def __init__(self, expr: List[Node], name: List['NameExpr'], - body: Block) -> None: - self.expr = expr - self.name = name - self.body = body - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_with_stmt(self) - - -class PrintStmt(Node): - """Python 2 print statement""" - - args = Undefined(List[Node]) - newline = False - - def __init__(self, args: List[Node], newline: bool) -> None: - self.args = args - self.newline = newline - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_print_stmt(self) - - -# Expressions - - -class IntExpr(Node): - """Integer literal""" - - value = 0 - literal = LITERAL_YES - - def __init__(self, value: int) -> None: - self.value = value - self.literal_hash = value - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_int_expr(self) - - -class StrExpr(Node): - """String literal""" - - value = '' - literal = LITERAL_YES - - def __init__(self, value: str) -> None: - self.value = value - self.literal_hash = value - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_str_expr(self) - - -class BytesExpr(Node): - """Bytes literal""" - -<<<<<<< HEAD - value = '' # TODO use bytes -======= - value = '' # TODO use bytes ->>>>>>> master - literal = LITERAL_YES - - def __init__(self, value: str) -> None: - self.value = value - self.literal_hash = value - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_bytes_expr(self) - - -class UnicodeExpr(Node): - """Unicode literal (Python 2.x)""" - -<<<<<<< HEAD - value = '' # TODO use bytes -======= - value = '' # TODO use bytes ->>>>>>> master - literal = LITERAL_YES - - def __init__(self, value: str) -> None: - self.value = value - self.literal_hash = value - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_unicode_expr(self) - - -class FloatExpr(Node): - """Float literal""" - - value = 0.0 - literal = LITERAL_YES - - def __init__(self, value: float) -> None: - self.value = value - self.literal_hash = value - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_float_expr(self) - - -class ParenExpr(Node): - """Parenthesised expression""" - - expr = Undefined(Node) - - def __init__(self, expr: Node) -> None: - self.expr = expr - self.literal = self.expr.literal - self.literal_hash = ('Paren', expr.literal_hash,) - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_paren_expr(self) - - -class RefExpr(Node): - """Abstract base class for name-like constructs""" - -<<<<<<< HEAD - kind = None # type: int # LDEF/GDEF/MDEF/... (None if not available) - node = Undefined(Node) # Var, FuncDef or TypeInfo that describes this - fullname = None # type: str # Fully qualified name (or name if not global) -======= - kind = None # type: int # LDEF/GDEF/MDEF/... (None if not available) - node = Undefined(Node) # Var, FuncDef or TypeInfo that describes this - fullname = None # type: str # Fully qualified name (or name if not global) ->>>>>>> master - - # Does this define a new name with inferred type? - # - # For members, after semantic analysis, this does not take base - # classes into consideration at all; the type checker deals with these. - is_def = False - - -class NameExpr(RefExpr): - """Name expression - - This refers to a local name, global name or a module. - """ - -<<<<<<< HEAD - name = None # type: str # Name referred to (may be qualified) - info = Undefined('TypeInfo') # TypeInfo of class surrounding expression - # (may be None) -======= - name = None # type: str # Name referred to (may be qualified) - info = Undefined('TypeInfo') # TypeInfo of class surrounding expression - # (may be None) ->>>>>>> master - literal = LITERAL_TYPE - - def __init__(self, name: str) -> None: - self.name = name - self.literal_hash = ('Var', name,) - - def type_node(self): - return cast('TypeInfo', self.node) - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_name_expr(self) - - -class MemberExpr(RefExpr): - """Member access expression x.y""" - - expr = Undefined(Node) - name = None # type: str - # The variable node related to a definition. - def_var = None # type: Var - # Is this direct assignment to a data member (bypassing accessors)? - direct = False - - def __init__(self, expr: Node, name: str, direct: bool = False) -> None: - self.expr = expr - self.name = name - self.direct = direct - self.literal = self.expr.literal - self.literal_hash = ('Member', expr.literal_hash, name, direct) - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_member_expr(self) - - -# Kinds of arguments - -# Positional argument -ARG_POS = 0 # type: int -# Positional, optional argument (functions only, not calls) -ARG_OPT = 1 # type: int -# *arg argument -ARG_STAR = 2 # type: int -# Keyword argument x=y in call, or keyword-only function arg -ARG_NAMED = 3 # type: int -# **arg argument -ARG_STAR2 = 4 # type: int - - -class CallExpr(Node): - """Call expression. - - This can also represent several special forms that are syntactically calls - such as cast(...) and Undefined(...). - """ - - callee = Undefined(Node) - args = Undefined(List[Node]) -<<<<<<< HEAD - arg_kinds = Undefined(List[int]) # ARG_ constants - arg_names = Undefined(List[str]) # Each name can be None if not a keyword - # argument. - analyzed = Undefined(Node) # If not None, the node that represents - # the meaning of the CallExpr. For - # cast(...) this is a CastExpr. -======= - arg_kinds = Undefined(List[int]) # ARG_ constants - arg_names = Undefined(List[str]) # Each name can be None if not a keyword - # argument. - analyzed = Undefined(Node) # If not None, the node that represents - # the meaning of the CallExpr. For - # cast(...) this is a CastExpr. ->>>>>>> master - - def __init__(self, callee: Node, args: List[Node], arg_kinds: List[int], - arg_names: List[str] = None, analyzed: Node = None) -> None: - if not arg_names: - arg_names = [None] * len(args) - self.callee = callee - self.args = args - self.arg_kinds = arg_kinds - self.arg_names = arg_names - self.analyzed = analyzed - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_call_expr(self) - - -class YieldFromExpr(Node): - expr = Undefined(Node) - - def __init__(self, expr: Node) -> None: - self.expr = expr - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_yield_from_expr(self) - -class IndexExpr(Node): - """Index expression x[y]. - - Also wraps type application as a special form. - """ - - base = Undefined(Node) - index = Undefined(Node) - # Inferred __getitem__ method type - method_type = None # type: mypy.types.Type - # If not None, this is actually semantically a type application - # Class[type, ...]. - analyzed = Undefined('TypeApplication') - - def __init__(self, base: Node, index: Node) -> None: - self.base = base - self.index = index - self.analyzed = None - if self.index.literal == LITERAL_YES: - self.literal = self.base.literal - self.literal_hash = ('Member', base.literal_hash, - index.literal_hash) - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_index_expr(self) - - -class UnaryExpr(Node): - """Unary operation""" - - op = '' - expr = Undefined(Node) - # Inferred operator method type -<<<<<<< HEAD - method_type = None # type: mypy.types.Type -======= - method_type = None # type: mypy.types.Type ->>>>>>> master - - def __init__(self, op: str, expr: Node) -> None: - self.op = op - self.expr = expr - self.literal = self.expr.literal - self.literal_hash = ('Unary', op, expr.literal_hash) - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_unary_expr(self) - - -# Map from binary operator id to related method name (in Python 3). -op_methods = { - '+': '__add__', - '-': '__sub__', - '*': '__mul__', - '/': '__truediv__', - '%': '__mod__', - '//': '__floordiv__', - '**': '__pow__', - '&': '__and__', - '|': '__or__', - '^': '__xor__', - '<<': '__lshift__', - '>>': '__rshift__', - '==': '__eq__', - '!=': '__ne__', - '<': '__lt__', - '>=': '__ge__', - '>': '__gt__', - '<=': '__le__', - 'in': '__contains__', -} - -ops_with_inplace_method = { - '+', '-', '*', '/', '%', '//', '**', '&', '|', '^', '<<', '>>'} - -inplace_operator_methods = set( - '__i' + op_methods[op][2:] for op in ops_with_inplace_method) - -reverse_op_methods = { - '__add__': '__radd__', - '__sub__': '__rsub__', - '__mul__': '__rmul__', - '__truediv__': '__rtruediv__', - '__mod__': '__rmod__', - '__floordiv__': '__rfloordiv__', - '__pow__': '__rpow__', - '__and__': '__rand__', - '__or__': '__ror__', - '__xor__': '__rxor__', - '__lshift__': '__rlshift__', - '__rshift__': '__rrshift__', - '__eq__': '__eq__', - '__ne__': '__ne__', - '__lt__': '__gt__', - '__ge__': '__le__', - '__gt__': '__lt__', - '__le__': '__ge__', -} - -normal_from_reverse_op = dict((m, n) for n, m in reverse_op_methods.items()) -reverse_op_method_set = set(reverse_op_methods.values()) - - -class OpExpr(Node): -<<<<<<< HEAD - """Binary operation (other than . or [], which have specific nodes).""" -======= - """Binary operation (other than . or [] or comparison operators, - which have specific nodes).""" ->>>>>>> master - - op = '' - left = Undefined(Node) - right = Undefined(Node) -<<<<<<< HEAD - # Inferred type for the operator method type (when relevant; None for - # 'is'). - method_type = None # type: mypy.types.Type -======= - # Inferred type for the operator method type (when relevant). - method_type = None # type: mypy.types.Type ->>>>>>> master - - def __init__(self, op: str, left: Node, right: Node) -> None: - self.op = op - self.left = left - self.right = right - self.literal = min(self.left.literal, self.right.literal) - self.literal_hash = ('Binary', op, left.literal_hash, right.literal_hash) - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_op_expr(self) - - -class ComparisonExpr(Node): - """Comparison expression (e.g. a < b > c < d).""" - - operators = Undefined(List[str]) - operands = Undefined(List[Node]) - # Inferred type for the operator methods (when relevant; None for 'is'). - method_types = Undefined(List["mypy.types.Type"]) - - def __init__(self, operators: List[str], operands: List[Node]) -> None: - self.operators = operators - self.operands = operands - self.method_types = [] - self.literal = min(o.literal for o in self.operands) - self.literal_hash = ( ('Comparison',) + tuple(operators) + - tuple(o.literal_hash for o in operands) ) - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_comparison_expr(self) - - -class SliceExpr(Node): - """Slice expression (e.g. 'x:y', 'x:', '::2' or ':'). - - This is only valid as index in index expressions. - """ - - begin_index = Undefined(Node) # May be None - end_index = Undefined(Node) # May be None - stride = Undefined(Node) # May be None - - def __init__(self, begin_index: Node, end_index: Node, - stride: Node) -> None: - self.begin_index = begin_index - self.end_index = end_index - self.stride = stride - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_slice_expr(self) - - -class CastExpr(Node): - """Cast expression cast(type, expr).""" - - expr = Undefined(Node) - type = Undefined('mypy.types.Type') - - def __init__(self, expr: Node, typ: 'mypy.types.Type') -> None: - self.expr = expr - self.type = typ - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_cast_expr(self) - - -class SuperExpr(Node): - """Expression super().name""" - - name = '' -<<<<<<< HEAD - info = Undefined('TypeInfo') # Type that contains this super expression -======= - info = Undefined('TypeInfo') # Type that contains this super expression ->>>>>>> master - - def __init__(self, name: str) -> None: - self.name = name - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_super_expr(self) - - -class FuncExpr(FuncItem): - """Lambda expression""" - - def name(self) -> str: - return '' - - def expr(self) -> Node: - """Return the expression (the body) of the lambda.""" - ret = cast(ReturnStmt, self.body.body[0]) - return ret.expr - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_func_expr(self) - - -class ListExpr(Node): - """List literal expression [...].""" - -<<<<<<< HEAD - items = Undefined(List[Node] ) -======= - items = Undefined(List[Node]) ->>>>>>> master - - def __init__(self, items: List[Node]) -> None: - self.items = items - if all(x.literal == LITERAL_YES for x in items): - self.literal = LITERAL_YES - self.literal_hash = ('List',) + tuple(x.literal_hash for x in items) - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_list_expr(self) - - -class DictExpr(Node): - """Dictionary literal expression {key: value, ...}.""" - - items = Undefined(List[Tuple[Node, Node]]) - - def __init__(self, items: List[Tuple[Node, Node]]) -> None: - self.items = items - if all(x[0].literal == LITERAL_YES and x[1].literal == LITERAL_YES - for x in items): - self.literal = LITERAL_YES - self.literal_hash = ('Dict',) + tuple((x[0].literal_hash, x[1].literal_hash) for x in items) - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_dict_expr(self) - - -class TupleExpr(Node): - """Tuple literal expression (..., ...)""" - - items = Undefined(List[Node]) - - def __init__(self, items: List[Node]) -> None: - self.items = items - if all(x.literal == LITERAL_YES for x in items): - self.literal = LITERAL_YES - self.literal_hash = ('Tuple',) + tuple(x.literal_hash for x in items) - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_tuple_expr(self) - - -class SetExpr(Node): - """Set literal expression {value, ...}.""" - - items = Undefined(List[Node]) - - def __init__(self, items: List[Node]) -> None: - self.items = items - if all(x.literal == LITERAL_YES for x in items): - self.literal = LITERAL_YES - self.literal_hash = ('Set',) + tuple(x.literal_hash for x in items) - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_set_expr(self) - - -class GeneratorExpr(Node): - """Generator expression ... for ... in ... [ for ... in ... ] [ if ... ].""" - - left_expr = Undefined(Node) - sequences_expr = Undefined(List[Node]) - condlists = Undefined(List[List[Node]]) - indices = Undefined(List[List[NameExpr]]) - types = Undefined(List[List['mypy.types.Type']]) - - def __init__(self, left_expr: Node, indices: List[List[NameExpr]], - types: List[List['mypy.types.Type']], sequences: List[Node], - condlists: List[List[Node]]) -> None: - self.left_expr = left_expr - self.sequences = sequences - self.condlists = condlists - self.indices = indices - self.types = types - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_generator_expr(self) - - -class ListComprehension(Node): - """List comprehension (e.g. [x + 1 for x in a])""" - - generator = Undefined(GeneratorExpr) - - def __init__(self, generator: GeneratorExpr) -> None: - self.generator = generator - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_list_comprehension(self) - - -class ConditionalExpr(Node): - """Conditional expression (e.g. x if y else z)""" - - cond = Undefined(Node) - if_expr = Undefined(Node) - else_expr = Undefined(Node) - - def __init__(self, cond: Node, if_expr: Node, else_expr: Node) -> None: - self.cond = cond - self.if_expr = if_expr - self.else_expr = else_expr - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_conditional_expr(self) - - -class UndefinedExpr(Node): - """Expression Undefined(type), used as an initializer. - - This is used to declare the type of a variable without initializing with - a proper value. For example: - - x = Undefined(List[int]) - """ - - def __init__(self, type: 'mypy.types.Type') -> None: - self.type = type - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_undefined_expr(self) - - -class TypeApplication(Node): - """Type application expr[type, ...]""" - - expr = Undefined(Node) - types = Undefined(List['mypy.types.Type']) - - def __init__(self, expr: Node, types: List['mypy.types.Type']) -> None: - self.expr = expr - self.types = types - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_type_application(self) - - -class TypeVarExpr(SymbolNode): - """Type variable expression typevar(...).""" - - _name = '' - _fullname = '' - # Value restriction: only types in the list are valid as values. If the - # list is empty, there is no restriction. - values = Undefined(List['mypy.types.Type']) - - def __init__(self, name: str, fullname: str, - values: List['mypy.types.Type']) -> None: - self._name = name - self._fullname = fullname - self.values = values - - def name(self) -> str: - return self._name - - def fullname(self) -> str: - return self._fullname - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_type_var_expr(self) - - -class DucktypeExpr(Node): - """Ducktype class decorator expression ducktype(...).""" - - type = Undefined('mypy.types.Type') - - def __init__(self, type: 'mypy.types.Type') -> None: - self.type = type - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_ducktype_expr(self) - - -class DisjointclassExpr(Node): - """Disjoint class class decorator expression disjointclass(cls).""" - - cls = Undefined(RefExpr) - - def __init__(self, cls: RefExpr) -> None: - self.cls = cls - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_disjointclass_expr(self) - - -class CoerceExpr(Node): - """Implicit coercion expression. - - This is used only when compiling/transforming. These are inserted - after type checking. - """ - - expr = Undefined(Node) - target_type = Undefined('mypy.types.Type') - source_type = Undefined('mypy.types.Type') - is_wrapper_class = False - - def __init__(self, expr: Node, target_type: 'mypy.types.Type', - source_type: 'mypy.types.Type', - is_wrapper_class: bool) -> None: - self.expr = expr - self.target_type = target_type - self.source_type = source_type - self.is_wrapper_class = is_wrapper_class - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_coerce_expr(self) - - -class JavaCast(Node): - # TODO obsolete; remove - expr = Undefined(Node) - target = Undefined('mypy.types.Type') - - def __init__(self, expr: Node, target: 'mypy.types.Type') -> None: - self.expr = expr - self.target = target - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_java_cast(self) - - -class TypeExpr(Node): - """Expression that evaluates to a runtime representation of a type. - - This is used only for runtime type checking. This node is always generated - only after type checking. - """ - - type = Undefined('mypy.types.Type') - - def __init__(self, typ: 'mypy.types.Type') -> None: - self.type = typ - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_type_expr(self) - - -class TempNode(Node): - """Temporary dummy node used during type checking. - - This node is not present in the original program; it is just an artifact - of the type checker implementation. It only represents an opaque node with - some fixed type. - """ - - type = Undefined('mypy.types.Type') - - def __init__(self, typ: 'mypy.types.Type') -> None: - self.type = typ - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_temp_node(self) - - -class TypeInfo(SymbolNode): - """Class representing the type structure of a single class. - - The corresponding ClassDef instance represents the parse tree of - the class. - """ - -<<<<<<< HEAD - _fullname = None # type: str # Fully qualified name - defn = Undefined(ClassDef) # Corresponding ClassDef -======= - _fullname = None # type: str # Fully qualified name - defn = Undefined(ClassDef) # Corresponding ClassDef ->>>>>>> master - # Method Resolution Order: the order of looking up attributes. The first - # value always to refers to self. - mro = Undefined(List['TypeInfo']) - subtypes = Undefined(Set['TypeInfo']) # Direct subclasses - names = Undefined('SymbolTable') # Names defined directly in this type - is_abstract = False # Does the class have any abstract attributes? - abstract_attributes = Undefined(List[str]) - # All classes in this build unit that are disjoint with this class. - disjoint_classes = Undefined(List['TypeInfo']) - # Targets of disjointclass declarations present in this class only (for - # generating error messages). - disjointclass_decls = Undefined(List['TypeInfo']) - - # Information related to type annotations. - - # Generic type variable names - type_vars = Undefined(List[str]) - - # Direct base classes. - bases = Undefined(List['mypy.types.Instance']) - - # Duck type compatibility (ducktype decorator) -<<<<<<< HEAD - ducktype = None # type: mypy.types.Type -======= - ducktype = None # type: mypy.types.Type ->>>>>>> master - - def __init__(self, names: 'SymbolTable', defn: ClassDef) -> None: - """Initialize a TypeInfo.""" - self.names = names - self.defn = defn - self.subtypes = set() - self.mro = [] - self.type_vars = [] - self.bases = [] - self._fullname = defn.fullname - self.is_abstract = False - self.abstract_attributes = [] - self.disjoint_classes = [] - self.disjointclass_decls = [] - if defn.type_vars: - for vd in defn.type_vars: - self.type_vars.append(vd.name) - - def name(self) -> str: - """Short name.""" - return self.defn.name - - def fullname(self) -> str: - return self._fullname - - def is_generic(self) -> bool: - """Is the type generic (i.e. does it have type variables)?""" - return self.type_vars is not None and len(self.type_vars) > 0 - - def get(self, name: str) -> 'SymbolTableNode': - for cls in self.mro: - n = cls.names.get(name) - if n: - return n - return None - - def __getitem__(self, name: str) -> 'SymbolTableNode': - n = self.get(name) - if n: - return n - else: - raise KeyError(name) - - def __repr__(self) -> str: - return '' % self.fullname() - -<<<<<<< HEAD - -======= ->>>>>>> master - # IDEA: Refactor the has* methods to be more consistent and document - # them. - - def has_readable_member(self, name: str) -> bool: - return self.get(name) is not None - - def has_writable_member(self, name: str) -> bool: - return self.has_var(name) - - def has_var(self, name: str) -> bool: - return self.get_var(name) is not None - - def has_method(self, name: str) -> bool: - return self.get_method(name) is not None - - def get_var(self, name: str) -> Var: - for cls in self.mro: - if name in cls.names: - node = cls.names[name].node - if isinstance(node, Var): - return cast(Var, node) - else: - return None - return None - - def get_var_or_getter(self, name: str) -> SymbolNode: - # TODO getter - return self.get_var(name) - - def get_var_or_setter(self, name: str) -> SymbolNode: - # TODO setter - return self.get_var(name) - - def get_method(self, name: str) -> FuncBase: - for cls in self.mro: - if name in cls.names: - node = cls.names[name].node - if isinstance(node, FuncBase): - return node - else: - return None - return None - - def calculate_mro(self) -> None: - """Calculate and set mro (method resolution order). - - Raise MroError if cannot determine mro. - """ - self.mro = linearize_hierarchy(self) - - def has_base(self, fullname: str) -> bool: - """Return True if type has a base type with the specified name. - - This can be either via extension or via implementation. - """ - for cls in self.mro: - if cls.fullname() == fullname: - return True - return False - - def all_subtypes(self) -> 'Set[TypeInfo]': - """Return TypeInfos of all subtypes, including this type, as a set.""" - subtypes = set([self]) - for subt in self.subtypes: - for t in subt.all_subtypes(): - subtypes.add(t) - return subtypes - - def all_base_classes(self) -> 'List[TypeInfo]': - """Return a list of base classes, including indirect bases.""" - assert False - - def direct_base_classes(self) -> 'List[TypeInfo]': - """Return a direct base classes. - - Omit base classes of other base classes. - """ - return [base.type for base in self.bases] - - def __str__(self) -> str: - """Return a string representation of the type. - - This includes the most important information about the type. - """ - base = None # type: str - if self.bases: - base = 'Bases({})'.format(', '.join(str(base) - for base in self.bases)) - return dump_tagged(['Name({})'.format(self.fullname()), - base, - ('Names', sorted(self.names.keys()))], - 'TypeInfo') - - -class SymbolTableNode: - # LDEF/GDEF/MDEF/UNBOUND_TVAR/TVAR/... - kind = None # type: int - # AST node of definition (FuncDef/Var/TypeInfo/Decorator/TypeVarExpr, - # or None for a bound type variable). - node = Undefined(SymbolNode) - # Type variable id (for bound type variables only) - tvar_id = 0 - # Module id (e.g. "foo.bar") or None - mod_id = '' - # If None, fall back to type of node - type_override = Undefined('mypy.types.Type') - - def __init__(self, kind: int, node: SymbolNode, mod_id: str = None, - typ: 'mypy.types.Type' = None, tvar_id: int = 0) -> None: - self.kind = kind - self.node = node - self.type_override = typ - self.mod_id = mod_id - self.tvar_id = tvar_id - - @property - def fullname(self) -> str: - if self.node is not None: - return self.node.fullname() - else: - return None - - @property - def type(self) -> 'mypy.types.Type': - # IDEA: Get rid of the Any type. - node = self.node # type: Any - if self.type_override is not None: - return self.type_override - elif ((isinstance(node, Var) or isinstance(node, FuncDef)) - and node.type is not None): - return node.type - elif isinstance(node, Decorator): - return (cast(Decorator, node)).var.type - else: - return None - - def __str__(self) -> str: - s = '{}/{}'.format(node_kinds[self.kind], short_type(self.node)) - if self.mod_id is not None: - s += ' ({})'.format(self.mod_id) - # Include declared type of variables and functions. - if self.type is not None: - s += ' : {}'.format(self.type) - return s - - -class SymbolTable(Dict[str, SymbolTableNode]): - def __str__(self) -> str: - a = List[str]() - for key, value in self.items(): - # Filter out the implicit import of builtins. - if isinstance(value, SymbolTableNode): - if (value.fullname != 'builtins' and - value.fullname.split('.')[-1] not in - implicit_module_attrs): - a.append(' ' + str(key) + ' : ' + str(value)) - else: - a.append(' ') - a = sorted(a) - a.insert(0, 'SymbolTable(') - a[-1] += ')' - return '\n'.join(a) - - -def clean_up(s: str) -> str: - # TODO remove - return re.sub('.*::', '', s) - - -def function_type(func: FuncBase) -> 'mypy.types.FunctionLike': - if func.type: - return cast(mypy.types.FunctionLike, func.type) - else: - # Implicit type signature with dynamic types. - # Overloaded functions always have a signature, so func must be an - # ordinary function. - fdef = cast(FuncDef, func) - name = func.name() - if name: - name = '"{}"'.format(name) - names = [] # type: List[str] - for arg in fdef.args: - names.append(arg.name()) - return mypy.types.Callable([mypy.types.AnyType()] * len(fdef.args), - fdef.arg_kinds, - names, - mypy.types.AnyType(), - False, - name) - - -@overload -def method_type(func: FuncBase) -> 'mypy.types.FunctionLike': - """Return the signature of a method (omit self).""" - return method_type(function_type(func)) - - -@overload -def method_type(sig: 'mypy.types.FunctionLike') -> 'mypy.types.FunctionLike': - if isinstance(sig, mypy.types.Callable): - csig = cast(mypy.types.Callable, sig) - return method_callable(csig) - else: - osig = cast(mypy.types.Overloaded, sig) - items = List[mypy.types.Callable]() - for c in osig.items(): - items.append(method_callable(c)) - return mypy.types.Overloaded(items) - - -def method_callable(c: 'mypy.types.Callable') -> 'mypy.types.Callable': - return mypy.types.Callable(c.arg_types[1:], - c.arg_kinds[1:], - c.arg_names[1:], - c.ret_type, - c.is_type_obj(), - c.name, - c.variables, - c.bound_vars) - - -class MroError(Exception): - """Raised if a consistent mro cannot be determined for a class.""" - - -def linearize_hierarchy(info: TypeInfo) -> List[TypeInfo]: - # TODO describe - if info.mro: - return info.mro - bases = info.direct_base_classes() - return [info] + merge([linearize_hierarchy(base) for base in bases] + - [bases]) - - -def merge(seqs: List[List[TypeInfo]]) -> List[TypeInfo]: - seqs = [s[:] for s in seqs] - result = List[TypeInfo]() - while True: - seqs = [s for s in seqs if s] - if not seqs: - return result - for seq in seqs: - head = seq[0] - if not [s for s in seqs if head in s[1:]]: - break - else: - raise MroError() - result.append(head) - for s in seqs: - if s[0] is head: - del s[0] diff --git a/mypy/output.py.orig b/mypy/output.py.orig deleted file mode 100644 index 36fbde86e010..000000000000 --- a/mypy/output.py.orig +++ /dev/null @@ -1,629 +0,0 @@ -"""Parse tree pretty printer.""" - -import re - -import typing - -from mypy import nodes -from mypy.visitor import NodeVisitor -from mypy.typerepr import CommonTypeRepr - - -class OutputVisitor(NodeVisitor): - """Parse tree Node visitor that outputs the original, formatted - source code. You can implement custom transformations by - subclassing this class. - """ - def __init__(self): - super().__init__() - self.result = [] # strings - self.line_number = 1 - # If True, omit the next character if it is a space - self.omit_next_space = False - # Number of spaces of indent right now - self.indent = 0 - # Number of spaces of extra indent to add when encountering a line - # break - self.extra_indent = 0 - self.block_depth = 0 - - def output(self): - """Return a string representation of the output.""" - return ''.join(self.result) - - def visit_mypy_file(self, o): - self.nodes(o.defs) - self.token(o.repr.eof) - - def visit_import(self, o): - r = o.repr - self.token(r.import_tok) - for i in range(len(r.components)): - self.tokens(r.components[i]) - if r.as_names[i]: - self.tokens(r.as_names[i]) - if i < len(r.commas): - self.token(r.commas[i]) - self.token(r.br) - - def visit_import_from(self, o): - self.output_import_from_or_all(o) - - def visit_import_all(self, o): - self.output_import_from_or_all(o) - - def output_import_from_or_all(self, o): - r = o.repr - self.token(r.from_tok) - self.tokens(r.components) - self.token(r.import_tok) - self.token(r.lparen) - for misc, comma in r.names: - self.tokens(misc) - self.token(comma) - self.token(r.rparen) - self.token(r.br) - - def visit_class_def(self, o): - r = o.repr - self.tokens([r.class_tok, r.name]) - self.type_vars(o.type_vars) - self.token(r.lparen) - for i in range(len(o.base_types)): - if o.base_types[i].repr: - self.type(o.base_types[i]) - if i < len(r.commas): - self.token(r.commas[i]) - self.token(r.rparen) - self.node(o.defs) - - def type_vars(self, v): - # IDEA: Combine this with type_vars in TypeOutputVisitor. - if v and v.repr: - r = v.repr - self.token(r.langle) - for i in range(len(v.items)): - d = v.items[i] - self.token(d.repr.name) - self.token(d.repr.is_tok) - if d.bound: - self.type(d.bound) - if i < len(r.commas): - self.token(r.commas[i]) - self.token(r.rangle) - - def visit_func_def(self, o): - r = o.repr - - if r.def_tok: - self.token(r.def_tok) - else: - self.type(o.type.items()[0].ret_type) - - self.token(r.name) - - self.function_header(o, r.args, o.arg_kinds) - - self.node(o.body) - - def visit_overloaded_func_def(self, o): - for f in o.items: - f.accept(self) - - def function_header(self, o, arg_repr, arg_kinds, pre_args_func=None, - erase_type=False, strip_space_before_first_arg=False): - r = o.repr - - t = None - if o.type and not erase_type: - t = o.type - - init = o.init - - if t: - self.type_vars(t.variables) - - self.token(arg_repr.lseparator) - if pre_args_func: - pre_args_func() - asterisk = 0 - for i in range(len(arg_repr.arg_names)): - if t: - if t.arg_types[i].repr: - self.type(t.arg_types[i]) - if arg_kinds[i] in [nodes.ARG_STAR, nodes.ARG_STAR2]: - self.token(arg_repr.asterisk[asterisk]) - asterisk += 1 - if not erase_type: - self.token(arg_repr.arg_names[i]) - else: - n = arg_repr.arg_names[i].rep() - if i == 0 and strip_space_before_first_arg: - # Remove spaces before the first argument name. Generally - # spaces are only present after a type, and if we erase the - # type, we should also erase also the spaces. - n = re.sub(' +([a-zA-Z0-9_]+)$', '\\1', n) - self.string(n) - if i < len(arg_repr.assigns): - self.token(arg_repr.assigns[i]) - if init and i < len(init) and init[i]: - self.node(init[i].rvalue) - if i < len(arg_repr.commas): - self.token(arg_repr.commas[i]) - self.token(arg_repr.rseparator) - - def visit_var_def(self, o): - r = o.repr - if r: - for v in o.items: - self.type(v.type) - self.node(v) - self.token(r.assign) - self.node(o.init) - self.token(r.br) - - def visit_var(self, o): - r = o.repr - self.token(r.name) - self.token(r.comma) - - def visit_decorator(self, o): - for at, br, dec in zip(o.repr.ats, o.repr.brs, o.decorators): - self.token(at) - self.node(dec) - self.token(br) - self.node(o.func) - - # Statements - - def visit_block(self, o): - r = o.repr - self.tokens([r.colon, r.br, r.indent]) - self.block_depth += 1 - old_indent = self.indent - self.indent = len(r.indent.string) - self.nodes(o.body) - self.token(r.dedent) - self.indent = old_indent - self.block_depth -= 1 - - def visit_global_decl(self, o): - r = o.repr - self.token(r.global_tok) - for i in range(len(r.names)): - self.token(r.names[i]) - if i < len(r.commas): - self.token(r.commas[i]) - self.token(r.br) - - def visit_expression_stmt(self, o): - self.node(o.expr) - self.token(o.repr.br) - - def visit_assignment_stmt(self, o): - r = o.repr - i = 0 - for lv in o.lvalues: - self.node(lv) - self.token(r.assigns[i]) - i += 1 - self.node(o.rvalue) - self.token(r.br) - - def visit_operator_assignment_stmt(self, o): - r = o.repr - self.node(o.lvalue) - self.token(r.assign) - self.node(o.rvalue) - self.token(r.br) - - def visit_return_stmt(self, o): - self.simple_stmt(o, o.expr) - - def visit_assert_stmt(self, o): - self.simple_stmt(o, o.expr) - - def visit_yield_stmt(self, o): - self.simple_stmt(o, o.expr) - -<<<<<<< HEAD - def visit_yield_from_stmt(self, o): - self.simple_stmt(o, o.expr) - -======= ->>>>>>> master - def visit_del_stmt(self, o): - self.simple_stmt(o, o.expr) - - def visit_break_stmt(self, o): - self.simple_stmt(o) - - def visit_continue_stmt(self, o): - self.simple_stmt(o) - - def visit_pass_stmt(self, o): - self.simple_stmt(o) - - def simple_stmt(self, o, expr=None): - self.token(o.repr.keyword) - self.node(expr) - self.token(o.repr.br) - - def visit_raise_stmt(self, o): - self.token(o.repr.raise_tok) - self.node(o.expr) - if o.from_expr: - self.token(o.repr.from_tok) - self.node(o.from_expr) - self.token(o.repr.br) - - def visit_while_stmt(self, o): - self.token(o.repr.while_tok) - self.node(o.expr) - self.node(o.body) - if o.else_body: - self.token(o.repr.else_tok) - self.node(o.else_body) - - def visit_for_stmt(self, o): - r = o.repr - self.token(r.for_tok) - for i in range(len(o.index)): - self.type(o.types[i]) - self.node(o.index[i]) - self.token(r.commas[i]) - self.token(r.in_tok) - self.node(o.expr) - - self.node(o.body) - if o.else_body: - self.token(r.else_tok) - self.node(o.else_body) - - def visit_if_stmt(self, o): - r = o.repr - self.token(r.if_tok) - self.node(o.expr[0]) - self.node(o.body[0]) - for i in range(1, len(o.expr)): - self.token(r.elif_toks[i - 1]) - self.node(o.expr[i]) - self.node(o.body[i]) - self.token(r.else_tok) - if o.else_body: - self.node(o.else_body) - - def visit_try_stmt(self, o): - r = o.repr - self.token(r.try_tok) - self.node(o.body) - for i in range(len(o.types)): - self.token(r.except_toks[i]) - self.node(o.types[i]) - self.token(r.as_toks[i]) - self.node(o.vars[i]) - self.node(o.handlers[i]) - if o.else_body: - self.token(r.else_tok) - self.node(o.else_body) - if o.finally_body: - self.token(r.finally_tok) - self.node(o.finally_body) - - def visit_with_stmt(self, o): - self.token(o.repr.with_tok) - for i in range(len(o.expr)): - self.node(o.expr[i]) - self.token(o.repr.as_toks[i]) - self.node(o.name[i]) - if i < len(o.repr.commas): - self.token(o.repr.commas[i]) - self.node(o.body) - - # Expressions - - def visit_int_expr(self, o): - self.token(o.repr.int) - - def visit_str_expr(self, o): - self.tokens(o.repr.string) - - def visit_bytes_expr(self, o): - self.tokens(o.repr.string) - - def visit_float_expr(self, o): - self.token(o.repr.float) - - def visit_paren_expr(self, o): - self.token(o.repr.lparen) - self.node(o.expr) - self.token(o.repr.rparen) - - def visit_name_expr(self, o): - # Supertype references may not have a representation. - if o.repr: - self.token(o.repr.id) - - def visit_member_expr(self, o): - self.node(o.expr) - self.token(o.repr.dot) - self.token(o.repr.name) - - def visit_index_expr(self, o): - self.node(o.base) - self.token(o.repr.lbracket) - self.node(o.index) - self.token(o.repr.rbracket) - - def visit_slice_expr(self, o): - self.node(o.begin_index) - self.token(o.repr.colon) - self.node(o.end_index) - self.token(o.repr.colon2) - self.node(o.stride) - -<<<<<<< HEAD - def visit_yield_from_expr(self, o): - o.expr.accept(self) - -======= ->>>>>>> master - def visit_call_expr(self, o): - r = o.repr - self.node(o.callee) - self.token(r.lparen) - nargs = len(o.args) - nkeyword = 0 - for i in range(nargs): - if o.arg_kinds[i] == nodes.ARG_STAR: - self.token(r.star) - elif o.arg_kinds[i] == nodes.ARG_STAR2: - self.token(r.star2) - elif o.arg_kinds[i] == nodes.ARG_NAMED: - self.tokens(r.keywords[nkeyword]) - nkeyword += 1 - self.node(o.args[i]) - if i < len(r.commas): - self.token(r.commas[i]) - self.token(r.rparen) - - def visit_op_expr(self, o): - self.node(o.left) - self.tokens([o.repr.op]) - self.node(o.right) - -<<<<<<< HEAD -======= - def visit_comparison_expr(self, o): - self.node(o.operands[0]) - for ops, operand in zip(o.repr.operators, o.operands[1:]): - # ops = op, op2 - self.tokens(list(ops)) - self.node(operand) - ->>>>>>> master - def visit_cast_expr(self, o): - self.token(o.repr.lparen) - self.type(o.type) - self.token(o.repr.rparen) - self.node(o.expr) - - def visit_super_expr(self, o): - r = o.repr - self.tokens([r.super_tok, r.lparen, r.rparen, r.dot, r.name]) - - def visit_unary_expr(self, o): - self.token(o.repr.op) - self.node(o.expr) - - def visit_list_expr(self, o): - r = o.repr - self.token(r.lbracket) - self.comma_list(o.items, r.commas) - self.token(r.rbracket) - - def visit_set_expr(self, o): - self.visit_list_expr(o) - - def visit_tuple_expr(self, o): - r = o.repr - self.token(r.lparen) - self.comma_list(o.items, r.commas) - self.token(r.rparen) - - def visit_dict_expr(self, o): - r = o.repr - self.token(r.lbrace) - i = 0 - for k, v in o.items: - self.node(k) - self.token(r.colons[i]) - self.node(v) - if i < len(r.commas): - self.token(r.commas[i]) - i += 1 - self.token(r.rbrace) - - def visit_func_expr(self, o): - r = o.repr - self.token(r.lambda_tok) - self.function_header(o, r.args, o.arg_kinds) - self.token(r.colon) - self.node(o.body.body[0].expr) - - def visit_type_application(self, o): - self.node(o.expr) - self.token(o.repr.langle) - self.type_list(o.types, o.repr.commas) - self.token(o.repr.rangle) - - def visit_generator_expr(self, o): - r = o.repr - self.node(o.left_expr) - for i in range(len(o.indices)): - self.token(r.for_toks[i]) - for j in range(len(o.indices[i])): - self.node(o.types[i][j]) - self.node(o.indices[i][j]) - if j < len(o.indices[i]) - 1: - self.token(r.commas[0]) - self.token(r.in_toks[i]) - self.node(o.sequences[i]) - for cond, if_tok in zip(o.condlists[i], r.if_toklists[i]): - self.token(if_tok) - self.node(cond) - - def visit_list_comprehension(self, o): - self.token(o.repr.lbracket) - self.node(o.generator) - self.token(o.repr.rbracket) - - # Helpers - - def line(self): - return self.line_number - - def string(self, s): - """Output a string.""" - if self.omit_next_space: - if s.startswith(' '): - s = s[1:] - self.omit_next_space = False - self.line_number += s.count('\n') - if s != '': - s = s.replace('\n', '\n' + ' ' * self.extra_indent) - self.result.append(s) - - def token(self, t): - """Output a token.""" - self.string(t.rep()) - - def tokens(self, a): - """Output an array of tokens.""" - for t in a: - self.token(t) - - def node(self, n): - """Output a node.""" - if n: n.accept(self) - - def nodes(self, a): - """Output an array of nodes.""" - for n in a: - self.node(n) - - def comma_list(self, items, commas): - for i in range(len(items)): - self.node(items[i]) - if i < len(commas): - self.token(commas[i]) - - def type_list(self, items, commas): - for i in range(len(items)): - self.type(items[i]) - if i < len(commas): - self.token(commas[i]) - - def type(self, t): - """Output a type.""" - if t: - v = TypeOutputVisitor() - t.accept(v) - self.string(v.output()) - - def last_output_char(self): - if self.result and self.result[-1]: - return self.result[-1][-1] - else: - return '' - - -class TypeOutputVisitor: - """Type visitor that outputs source code.""" - def __init__(self): - self.result = [] # strings - - def output(self): - """Return a string representation of the output.""" - return ''.join(self.result) - - def visit_unbound_type(self, t): - self.visit_instance(t) - - def visit_any(self, t): - if t.repr: - self.token(t.repr.any_tok) - - def visit_void(self, t): - if t.repr: - self.token(t.repr.void) - - def visit_instance(self, t): - r = t.repr - if isinstance(r, CommonTypeRepr): - self.tokens(r.components) - self.token(r.langle) - self.comma_list(t.args, r.commas) - self.token(r.rangle) - else: - # List type t[]. - assert len(t.args) == 1 - self.comma_list(t.args, []) - self.tokens([r.lbracket, r.rbracket]) - - def visit_type_var(self, t): - self.token(t.repr.name) - - def visit_tuple_type(self, t): - r = t.repr - self.tokens(r.components) - self.token(r.langle) - self.comma_list(t.items, r.commas) - self.token(r.rangle) - - def visit_callable(self, t): - r = t.repr - self.tokens([r.func, r.langle]) - t.ret_type.accept(self) - self.token(r.lparen) - self.comma_list(t.arg_types, r.commas) - self.tokens([r.rparen, r.rangle]) - - def type_vars(self, v): - if v and v.repr: - r = v.repr - self.token(r.langle) - for i in range(len(v.items)): - d = v.items[i] - self.token(d.repr.name) - self.token(d.repr.is_tok) - if d.bound: - self.type(d.bound) - if i < len(r.commas): - self.token(r.commas[i]) - self.token(r.rangle) - - # Helpers - - def string(self, s): - """Output a string.""" - self.result.append(s) - - def token(self, t): - """Output a token.""" - self.result.append(t.rep()) - - def tokens(self, a): - """Output an array of tokens.""" - for t in a: - self.token(t) - - def type(self, n): - """Output a type.""" - if n: n.accept(self) - - def comma_list(self, items, commas): - for i in range(len(items)): - self.type(items[i]) - if i < len(commas): - self.token(commas[i]) diff --git a/mypy/parse.py.orig b/mypy/parse.py.orig deleted file mode 100755 index d696cd3193e2..000000000000 --- a/mypy/parse.py.orig +++ /dev/null @@ -1,1845 +0,0 @@ -"""Mypy parser. - -Constructs a parse tree (abstract syntax tree) based on a string -representing a source file. Performs only minimal semantic checks. -""" - -import re - -from typing import Undefined, List, Tuple, Any, Set, cast - -from mypy import lex -from mypy.lex import ( - Token, Eof, Bom, Break, Name, Colon, Dedent, IntLit, StrLit, BytesLit, - UnicodeLit, FloatLit, Op, Indent, Keyword, Punct, LexError -) -import mypy.types -from mypy.nodes import ( - MypyFile, Import, Node, ImportAll, ImportFrom, FuncDef, OverloadedFuncDef, - ClassDef, Decorator, Block, Var, VarDef, OperatorAssignmentStmt, - ExpressionStmt, AssignmentStmt, ReturnStmt, RaiseStmt, AssertStmt, - YieldStmt, DelStmt, BreakStmt, ContinueStmt, PassStmt, GlobalDecl, - WhileStmt, ForStmt, IfStmt, TryStmt, WithStmt, CastExpr, ParenExpr, - TupleExpr, GeneratorExpr, ListComprehension, ListExpr, ConditionalExpr, - DictExpr, SetExpr, NameExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr, - FloatExpr, CallExpr, SuperExpr, MemberExpr, IndexExpr, SliceExpr, OpExpr, -<<<<<<< HEAD - UnaryExpr, FuncExpr, TypeApplication, PrintStmt, ImportBase, YieldFromStmt, - YieldFromExpr -======= - UnaryExpr, FuncExpr, TypeApplication, PrintStmt, ImportBase, ComparisonExpr ->>>>>>> master -) -from mypy import nodes -from mypy import noderepr -from mypy.errors import Errors, CompileError -from mypy.types import Void, Type, Callable, AnyType, UnboundType -from mypy.parsetype import ( - parse_type, parse_types, parse_signature, TypeParseError -) - - -precedence = { - '**': 16, - '-u': 15, '+u': 15, '~': 15, # unary operators (-, + and ~) - '': 14, - '*': 13, '/': 13, '//': 13, '%': 13, - '+': 12, '-': 12, - '>>': 11, '<<': 11, - '&': 10, - '^': 9, - '|': 8, - '==': 7, '!=': 7, '<': 7, '>': 7, '<=': 7, '>=': 7, 'is': 7, 'in': 7, - 'not': 6, - 'and': 5, - 'or': 4, - '': 3, # conditional expression - '': 2, # list comprehension - ',': 1} - - -op_assign = set([ - '+=', '-=', '*=', '/=', '//=', '%=', '**=', '|=', '&=', '^=', '>>=', - '<<=']) - -op_comp = set([ - '>', '<', '==', '>=', '<=', '<>', '!=', 'is', 'is', 'in', 'not']) - -none = Token('') # Empty token - - -def parse(s: str, fnam: str = None, errors: Errors = None, - pyversion: int = 3, custom_typing_module: str = None) -> MypyFile: - """Parse a source file, without doing any semantic analysis. - - Return the parse tree. If errors is not provided, raise ParseError - on failure. Otherwise, use the errors object to report parse errors. - - The pyversion argument determines the Python syntax variant (2 for 2.x and - 3 for 3.x). - """ - parser = Parser(fnam, errors, pyversion, custom_typing_module) - tree = parser.parse(s) - tree.path = fnam - return tree - - -class Parser: - tok = Undefined(List[Token]) - ind = 0 - errors = Undefined(Errors) - raise_on_error = False - - # Are we currently parsing the body of a class definition? - is_class_body = False - # All import nodes encountered so far in this parse unit. - imports = Undefined(List[ImportBase]) - # Names imported from __future__. - future_options = Undefined(List[str]) - - def __init__(self, fnam: str, errors: Errors, pyversion: int, - custom_typing_module: str = None) -> None: - self.raise_on_error = errors is None - self.pyversion = pyversion - self.custom_typing_module = custom_typing_module - if errors is not None: - self.errors = errors - else: - self.errors = Errors() - if fnam is not None: - self.errors.set_file(fnam) - else: - self.errors.set_file('') - - def parse(self, s: str) -> MypyFile: - self.tok = lex.lex(s) - self.ind = 0 - self.imports = [] - self.future_options = [] - file = self.parse_file() - if self.raise_on_error and self.errors.is_errors(): - self.errors.raise_error() - return file - - def parse_file(self) -> MypyFile: - """Parse a mypy source file.""" - is_bom = self.parse_bom() - defs = self.parse_defs() - eof = self.expect_type(Eof) - node = MypyFile(defs, self.imports, is_bom) - self.set_repr(node, noderepr.MypyFileRepr(eof)) - return node - - # Parse the initial part - - def parse_bom(self) -> bool: - """Parse the optional byte order mark at the beginning of a file.""" - if isinstance(self.current(), Bom): - self.expect_type(Bom) - if isinstance(self.current(), Break): - self.expect_break() - return True - else: - return False - - def parse_import(self) -> Import: - import_tok = self.expect('import') - ids = List[Tuple[str, str]]() - id_toks = List[List[Token]]() - commas = List[Token]() - as_names = List[Tuple[Token, Token]]() - while True: - id, components = self.parse_qualified_name() - if id == self.custom_typing_module: - id = 'typing' - id_toks.append(components) - as_id = id - if self.current_str() == 'as': - as_tok = self.expect('as') - name_tok = self.expect_type(Name) - as_id = name_tok.string - as_names.append((as_tok, name_tok)) - else: - as_names.append(None) - ids.append((id, as_id)) - if self.current_str() != ',': - break - commas.append(self.expect(',')) - br = self.expect_break() - node = Import(ids) - self.imports.append(node) - self.set_repr(node, noderepr.ImportRepr(import_tok, id_toks, as_names, - commas, br)) - return node - - def parse_import_from(self) -> Node: - from_tok = self.expect('from') - name, components = self.parse_qualified_name() - if name == self.custom_typing_module: - name = 'typing' - import_tok = self.expect('import') - name_toks = List[Tuple[List[Token], Token]]() - lparen = none - rparen = none - node = None # type: ImportBase - if self.current_str() == '*': - name_toks.append(([self.skip()], none)) - node = ImportAll(name) - else: - is_paren = self.current_str() == '(' - if is_paren: - lparen = self.expect('(') - targets = List[Tuple[str, str]]() - while True: - id, as_id, toks = self.parse_import_name() - if '%s.%s' % (name, id) == self.custom_typing_module: - if targets or self.current_str() == ',': - self.fail('You cannot import any other modules when you ' - 'import a custom typing module', - toks[0].line) - node = Import([('typing', as_id)]) - self.skip_until_break() - break - targets.append((id, as_id)) - if self.current_str() != ',': - name_toks.append((toks, none)) - break - name_toks.append((toks, self.expect(','))) - if is_paren and self.current_str() == ')': - break - if is_paren: - rparen = self.expect(')') - if node is None: - node = ImportFrom(name, targets) - br = self.expect_break() - self.imports.append(node) - # TODO: Fix representation if there is a custom typing module import. - self.set_repr(node, noderepr.ImportFromRepr( - from_tok, components, import_tok, lparen, name_toks, rparen, br)) - if name == '__future__': - self.future_options.extend(target[0] for target in targets) - return node - - def parse_import_name(self) -> Tuple[str, str, List[Token]]: - tok = self.expect_type(Name) - name = tok.string - tokens = [tok] - if self.current_str() == 'as': - tokens.append(self.skip()) - as_name = self.expect_type(Name) - tokens.append(as_name) - return name, as_name.string, tokens - else: - return name, name, tokens - - def parse_qualified_name(self) -> Tuple[str, List[Token]]: - """Parse a name with an optional module qualifier. - - Return a tuple with the name as a string and a token array - containing all the components of the name. - """ - components = List[Token]() - tok = self.expect_type(Name) - n = tok.string - components.append(tok) - while self.current_str() == '.': - components.append(self.expect('.')) - tok = self.expect_type(Name) - n += '.' + tok.string - components.append(tok) - return n, components - - # Parsing global definitions - - def parse_defs(self) -> List[Node]: - defs = List[Node]() - while not self.eof(): - try: - defn = self.parse_statement() - if defn is not None: - if not self.try_combine_overloads(defn, defs): - defs.append(defn) - except ParseError: - pass - return defs - - def parse_class_def(self) -> ClassDef: - old_is_class_body = self.is_class_body - self.is_class_body = True - - type_tok = self.expect('class') - lparen = none - rparen = none -<<<<<<< HEAD - metaclass = None # type: str -======= - metaclass = None # type: str ->>>>>>> master - - try: - commas, base_types = List[Token](), List[Type]() - try: - name_tok = self.expect_type(Name) - name = name_tok.string - - self.errors.push_type(name) - - if self.current_str() == '(': - lparen = self.skip() - while True: - if self.current_str() == 'metaclass': - metaclass = self.parse_metaclass() - break - base_types.append(self.parse_super_type()) - if self.current_str() != ',': - break - commas.append(self.skip()) - rparen = self.expect(')') - except ParseError: - pass - - defs, _ = self.parse_block() - - node = ClassDef(name, defs, None, base_types, metaclass=metaclass) - self.set_repr(node, noderepr.TypeDefRepr(type_tok, name_tok, - lparen, commas, rparen)) - return node - finally: - self.errors.pop_type() - self.is_class_body = old_is_class_body - - def parse_super_type(self) -> Type: - if (isinstance(self.current(), Name) and self.current_str() != 'void'): - return self.parse_type() - else: - self.parse_error() - - def parse_metaclass(self) -> str: - self.expect('metaclass') - self.expect('=') - return self.parse_qualified_name()[0] - - def parse_decorated_function_or_class(self) -> Node: - ats = List[Token]() - brs = List[Token]() - decorators = List[Node]() - while self.current_str() == '@': - ats.append(self.expect('@')) - decorators.append(self.parse_expression()) - brs.append(self.expect_break()) - if self.current_str() != 'class': - func = self.parse_function() - func.is_decorated = True - var = Var(func.name()) - # Types of decorated functions must always be inferred. - var.is_ready = False - var.set_line(decorators[0].line) - node = Decorator(func, decorators, var) - self.set_repr(node, noderepr.DecoratorRepr(ats, brs)) - return node - else: - cls = self.parse_class_def() - cls.decorators = decorators - return cls - - def parse_function(self) -> FuncDef: - def_tok = self.expect('def') - is_method = self.is_class_body - self.is_class_body = False - try: - (name, args, init, kinds, - typ, is_error, toks) = self.parse_function_header() - - body, comment_type = self.parse_block(allow_type=True) - if comment_type: - # The function has a # type: ... signature. - if typ: - self.errors.report( - def_tok.line, 'Function has duplicate type signatures') - sig = cast(Callable, comment_type) - if is_method: - self.check_argument_kinds(kinds, - [nodes.ARG_POS] + sig.arg_kinds, - def_tok.line) - # Add implicit 'self' argument to signature. - typ = Callable(List[Type]([AnyType()]) + sig.arg_types, - kinds, - [arg.name() for arg in args], - sig.ret_type, - False) - else: - self.check_argument_kinds(kinds, sig.arg_kinds, - def_tok.line) - typ = Callable(sig.arg_types, - kinds, - [arg.name() for arg in args], - sig.ret_type, - False) - - # If there was a serious error, we really cannot build a parse tree - # node. - if is_error: - return None - - node = FuncDef(name, args, kinds, init, body, typ) - name_tok, arg_reprs = toks - node.set_line(name_tok) - self.set_repr(node, noderepr.FuncRepr(def_tok, name_tok, - arg_reprs)) - return node - finally: - self.errors.pop_function() - self.is_class_body = is_method - - def check_argument_kinds(self, funckinds: List[int], sigkinds: List[int], - line: int) -> None: - """Check that * and ** arguments are consistent. - - Arguments: - funckinds: kinds of arguments in function definition - sigkinds: kinds of arguments in signature (after # type:) - """ - for kind, token in [(nodes.ARG_STAR, '*'), - (nodes.ARG_STAR2, '**')]: - if ((kind in funckinds and - sigkinds[funckinds.index(kind)] != kind) or - (funckinds.count(kind) != sigkinds.count(kind))): - self.fail( - "Inconsistent use of '{}' in function " - "signature".format(token), line) - - def parse_function_header(self) -> Tuple[str, List[Var], List[Node], - List[int], Type, bool, - Tuple[Token, Any]]: - """Parse function header (a name followed by arguments) - - Returns a 7-tuple with the following items: - name - arguments - initializers - kinds - signature (annotation) - error flag (True if error) - (name token, representation of arguments) - """ - name_tok = none - - try: - name_tok = self.expect_type(Name) - name = name_tok.string - - self.errors.push_function(name) - - (args, init, kinds, typ, arg_repr) = self.parse_args() - except ParseError: - if not isinstance(self.current(), Break): - self.ind -= 1 # Kludge: go back to the Break token - # Resynchronise parsing by going back over :, if present. - if isinstance(self.tok[self.ind - 1], Colon): - self.ind -= 1 - return (name, [], [], [], None, True, (name_tok, None)) - - return (name, args, init, kinds, typ, False, (name_tok, arg_repr)) - - def parse_args(self) -> Tuple[List[Var], List[Node], List[int], Type, - noderepr.FuncArgsRepr]: - """Parse a function signature (...) [-> t].""" - lparen = self.expect('(') - - # Parse the argument list (everything within '(' and ')'). - (args, init, kinds, - has_inits, arg_names, - commas, asterisk, - assigns, arg_types) = self.parse_arg_list() - - rparen = self.expect(')') - - if self.current_str() == '->': - self.skip() - ret_type = self.parse_type() - else: - ret_type = None - - self.verify_argument_kinds(kinds, lparen.line) - -<<<<<<< HEAD - names = [] # type: List[str] -======= - names = [] # type: List[str] ->>>>>>> master - for arg in args: - names.append(arg.name()) - - annotation = self.build_func_annotation( - ret_type, arg_types, kinds, names, lparen.line) - - return (args, init, kinds, annotation, - noderepr.FuncArgsRepr(lparen, rparen, arg_names, commas, - assigns, asterisk)) - - def build_func_annotation(self, ret_type: Type, arg_types: List[Type], - kinds: List[int], names: List[str], - line: int, is_default_ret: bool = False) -> Type: - # Are there any type annotations? - if ((ret_type and not is_default_ret) - or arg_types != [None] * len(arg_types)): - # Yes. Construct a type for the function signature. - return self.construct_function_type(arg_types, kinds, names, - ret_type, line) - else: - return None - - def parse_arg_list( - self, allow_signature: bool = True) -> Tuple[List[Var], List[Node], - List[int], bool, - List[Token], List[Token], - List[Token], List[Token], - List[Type]]: - """Parse function definition argument list. - - This includes everything between '(' and ')'). - - Return a 9-tuple with these items: - arguments, initializers, kinds, has inits, arg name tokens, - comma tokens, asterisk tokens, assignment tokens, argument types - """ - args = [] # type: List[Var] - kinds = [] # type: List[int] - names = [] # type: List[str] - init = [] # type: List[Node] - has_inits = False -<<<<<<< HEAD - arg_types = [] # type: List[Type] - - arg_names = [] # type: List[Token] - commas = [] # type: List[Token] - asterisk = [] # type: List[Token] - assigns = [] # type: List[Token] -======= - arg_types = [] # type: List[Type] - - arg_names = [] # type: List[Token] - commas = [] # type: List[Token] - asterisk = [] # type: List[Token] - assigns = [] # type: List[Token] ->>>>>>> master - - require_named = False - bare_asterisk_before = -1 - - if self.current_str() != ')' and self.current_str() != ':': - while self.current_str() != ')': - if self.current_str() == '*' and self.peek().string == ',': - if require_named: - # can only have one bare star, must be before any - # *args or **args - self.parse_error() - self.expect('*') - require_named = True - bare_asterisk_before = len(args) - elif self.current_str() in ['*', '**']: - if bare_asterisk_before == len(args): - # named arguments must follow bare * - self.parse_error() - asterisk.append(self.skip()) - isdict = asterisk[-1].string == '**' - name = self.expect_type(Name) - arg_names.append(name) - names.append(name.string) - var_arg = Var(name.string) - self.set_repr(var_arg, noderepr.VarRepr(name, none)) - args.append(var_arg) - init.append(None) - assigns.append(none) - if isdict: - kinds.append(nodes.ARG_STAR2) - else: - kinds.append(nodes.ARG_STAR) - arg_types.append(self.parse_arg_type(allow_signature)) - require_named = True - else: - name = self.expect_type(Name) - arg_names.append(name) - args.append(Var(name.string)) - arg_types.append(self.parse_arg_type(allow_signature)) - - if self.current_str() == '=': - assigns.append(self.expect('=')) - init.append(self.parse_expression(precedence[','])) - has_inits = True - if require_named: - kinds.append(nodes.ARG_NAMED) - else: - kinds.append(nodes.ARG_OPT) - else: - init.append(None) - assigns.append(none) - if require_named: - # required keyword-only argument - kinds.append(nodes.ARG_NAMED) - else: - kinds.append(nodes.ARG_POS) - - if self.current().string != ',': - break - commas.append(self.expect(',')) - - return (args, init, kinds, has_inits, arg_names, commas, asterisk, - assigns, arg_types) - - def parse_arg_type(self, allow_signature: bool) -> Type: - if self.current_str() == ':' and allow_signature: - self.skip() - return self.parse_type() - else: - return None - - def verify_argument_kinds(self, kinds: List[int], line: int) -> None: - found = Set[int]() - for i, kind in enumerate(kinds): - if kind == nodes.ARG_POS and found & set([nodes.ARG_OPT, - nodes.ARG_STAR, - nodes.ARG_STAR2]): - self.fail('Invalid argument list', line) - elif kind == nodes.ARG_STAR and nodes.ARG_STAR in found: - self.fail('Invalid argument list', line) - elif kind == nodes.ARG_STAR2 and i != len(kinds) - 1: - self.fail('Invalid argument list', line) - found.add(kind) - - def construct_function_type(self, arg_types: List[Type], kinds: List[int], - names: List[str], ret_type: Type, - line: int) -> Callable: - # Complete the type annotation by replacing omitted types with 'Any'. - arg_types = arg_types[:] - for i in range(len(arg_types)): - if arg_types[i] is None: - arg_types[i] = AnyType() - if ret_type is None: - ret_type = AnyType() - return Callable(arg_types, kinds, names, ret_type, False, None, - None, [], line, None) - - # Parsing statements - - def parse_block(self, allow_type: bool = False) -> Tuple[Block, Type]: - colon = self.expect(':') - if not isinstance(self.current(), Break): - # Block immediately after ':'. - node = Block([self.parse_statement()]).set_line(colon) - self.set_repr(node, noderepr.BlockRepr(colon, none, none, none)) - return cast(Block, node), None - else: - # Indented block. - br = self.expect_break() - type = self.parse_type_comment(br, signature=True) - indent = self.expect_indent() - stmt = [] # type: List[Node] - while (not isinstance(self.current(), Dedent) and - not isinstance(self.current(), Eof)): - try: - s = self.parse_statement() - if s is not None: - if not self.try_combine_overloads(s, stmt): - stmt.append(s) - except ParseError: - pass - dedent = none - if isinstance(self.current(), Dedent): - dedent = self.skip() - node = Block(stmt).set_line(colon) - self.set_repr(node, noderepr.BlockRepr(colon, br, indent, dedent)) - return cast(Block, node), type - - def try_combine_overloads(self, s: Node, stmt: List[Node]) -> bool: - if isinstance(s, Decorator) and stmt: - fdef = cast(Decorator, s) - n = fdef.func.name() - if (isinstance(stmt[-1], Decorator) and - (cast(Decorator, stmt[-1])).func.name() == n): - stmt[-1] = OverloadedFuncDef([cast(Decorator, stmt[-1]), fdef]) - return True - elif (isinstance(stmt[-1], OverloadedFuncDef) and - (cast(OverloadedFuncDef, stmt[-1])).name() == n): - (cast(OverloadedFuncDef, stmt[-1])).items.append(fdef) - return True - return False - - def parse_statement(self) -> Node: - stmt = Undefined # type: Node - t = self.current() - ts = self.current_str() - if ts == 'if': - stmt = self.parse_if_stmt() - elif ts == 'def': - stmt = self.parse_function() - elif ts == 'while': - stmt = self.parse_while_stmt() - elif ts == 'return': - stmt = self.parse_return_stmt() - elif ts == 'for': - stmt = self.parse_for_stmt() - elif ts == 'try': - stmt = self.parse_try_stmt() - elif ts == 'break': - stmt = self.parse_break_stmt() - elif ts == 'continue': - stmt = self.parse_continue_stmt() - elif ts == 'pass': - stmt = self.parse_pass_stmt() - elif ts == 'raise': - stmt = self.parse_raise_stmt() - elif ts == 'import': - stmt = self.parse_import() - elif ts == 'from': - stmt = self.parse_import_from() - elif ts == 'class': - stmt = self.parse_class_def() - elif ts == 'global': - stmt = self.parse_global_decl() - elif ts == 'assert': - stmt = self.parse_assert_stmt() - elif ts == 'yield': - stmt = self.parse_yield_stmt() - elif ts == 'del': - stmt = self.parse_del_stmt() - elif ts == 'with': - stmt = self.parse_with_stmt() - elif ts == '@': - stmt = self.parse_decorated_function_or_class() - elif ts == 'print' and (self.pyversion == 2 and - 'print_function' not in self.future_options): - stmt = self.parse_print_stmt() - else: - stmt = self.parse_expression_or_assignment() - if stmt is not None: - stmt.set_line(t) - return stmt - - def parse_expression_or_assignment(self) -> Node: - e = self.parse_expression() - if self.current_str() == '=': - return self.parse_assignment(e) - elif self.current_str() in op_assign: - # Operator assignment statement. - op = self.current_str()[:-1] - assign = self.skip() - r = self.parse_expression() - br = self.expect_break() - node = OperatorAssignmentStmt(op, e, r) - self.set_repr(node, - noderepr.OperatorAssignmentStmtRepr(assign, br)) - return node - else: - # Expression statement. - br = self.expect_break() - expr = ExpressionStmt(e) - self.set_repr(expr, noderepr.ExpressionStmtRepr(br)) - return expr - - def parse_assignment(self, lv: Any) -> Node: - """Parse an assignment statement. - - Assume that lvalue has been parsed already, and the current token is =. - Also parse an optional '# type:' comment. - """ - assigns = [self.expect('=')] - lvalues = [lv] - - e = self.parse_expression() - while self.current_str() == '=': - lvalues.append(e) - assigns.append(self.skip()) - e = self.parse_expression() - br = self.expect_break() - - type = self.parse_type_comment(br, signature=False) - assignment = AssignmentStmt(lvalues, e, type) - self.set_repr(assignment, noderepr.AssignmentStmtRepr(assigns, br)) - return assignment - - def parse_return_stmt(self) -> ReturnStmt: - return_tok = self.expect('return') - expr = None # type: Node - if not isinstance(self.current(), Break): - expr = self.parse_expression() - if isinstance(expr, YieldFromExpr): #cant go a yield from expr - return None - br = self.expect_break() - node = ReturnStmt(expr) - self.set_repr(node, noderepr.SimpleStmtRepr(return_tok, br)) - return node - - def parse_raise_stmt(self) -> RaiseStmt: - raise_tok = self.expect('raise') - expr = None # type: Node - from_expr = None # type: Node - from_tok = none - if not isinstance(self.current(), Break): - expr = self.parse_expression() - if self.current_str() == 'from': - from_tok = self.expect('from') - from_expr = self.parse_expression() - br = self.expect_break() - node = RaiseStmt(expr, from_expr) - self.set_repr(node, noderepr.RaiseStmtRepr(raise_tok, from_tok, br)) - return node - - def parse_assert_stmt(self) -> AssertStmt: - assert_tok = self.expect('assert') - expr = self.parse_expression() - br = self.expect_break() - node = AssertStmt(expr) - self.set_repr(node, noderepr.SimpleStmtRepr(assert_tok, br)) - return node - - def parse_yield_stmt(self) -> YieldStmt: - yield_tok = self.expect('yield') -<<<<<<< HEAD - expr = None # type: Node - node = YieldStmt(expr) -======= - expr = None # type: Node ->>>>>>> master - if not isinstance(self.current(), Break): - if isinstance(self.current(), Keyword) and self.current_str() == "from": # Not go if it's not from - from_tok = self.expect("from") - expr = self.parse_expression() # Here comes when yield from is not assigned - node = YieldFromStmt(expr) - else: - expr = self.parse_expression() - node = YieldStmt(expr) - br = self.expect_break() - self.set_repr(node, noderepr.SimpleStmtRepr(yield_tok, br)) - return node - -<<<<<<< HEAD - def parse_yield_from_expr(self) -> CallExpr: # Maybe the name should be yield_expr - y_tok = self.expect("yield") - expr = None # type: Node - node = YieldFromExpr(expr) - if self.current_str() == "from": - f_tok = self.expect("from") - tok = self.parse_expression() # Here comes when yield from is assigned to a variable - node = YieldFromExpr(tok) - else: - # TODO - # Here comes the yield expression (ex: x = yield 3 ) - # tok = self.parse_expression() - # node = YieldExpr(tok) # Doesn't exist now - pass - return node - -======= ->>>>>>> master - def parse_del_stmt(self) -> DelStmt: - del_tok = self.expect('del') - expr = self.parse_expression() - br = self.expect_break() - node = DelStmt(expr) - self.set_repr(node, noderepr.SimpleStmtRepr(del_tok, br)) - return node - - def parse_break_stmt(self) -> BreakStmt: - break_tok = self.expect('break') - br = self.expect_break() - node = BreakStmt() - self.set_repr(node, noderepr.SimpleStmtRepr(break_tok, br)) - return node - - def parse_continue_stmt(self) -> ContinueStmt: - continue_tok = self.expect('continue') - br = self.expect_break() - node = ContinueStmt() - self.set_repr(node, noderepr.SimpleStmtRepr(continue_tok, br)) - return node - - def parse_pass_stmt(self) -> PassStmt: - pass_tok = self.expect('pass') - br = self.expect_break() - node = PassStmt() - self.set_repr(node, noderepr.SimpleStmtRepr(pass_tok, br)) - return node - - def parse_global_decl(self) -> GlobalDecl: - global_tok = self.expect('global') - names = List[str]() - name_toks = List[Token]() - commas = List[Token]() - while True: - n = self.expect_type(Name) - names.append(n.string) - name_toks.append(n) - if self.current_str() != ',': - break - commas.append(self.skip()) - br = self.expect_break() - node = GlobalDecl(names) - self.set_repr(node, noderepr.GlobalDeclRepr(global_tok, name_toks, - commas, br)) - return node - - def parse_while_stmt(self) -> WhileStmt: - is_error = False - while_tok = self.expect('while') - try: - expr = self.parse_expression() - except ParseError: - is_error = True - body, _ = self.parse_block() - if self.current_str() == 'else': - else_tok = self.expect('else') - else_body, _ = self.parse_block() - else: - else_body = None - else_tok = none - if is_error is not None: - node = WhileStmt(expr, body, else_body) - self.set_repr(node, noderepr.WhileStmtRepr(while_tok, else_tok)) - return node - else: - return None - - def parse_for_stmt(self) -> ForStmt: - for_tok = self.expect('for') - index, types, commas = self.parse_for_index_variables() - in_tok = self.expect('in') - expr = self.parse_expression() - - body, _ = self.parse_block() - - if self.current_str() == 'else': - else_tok = self.expect('else') - else_body, _ = self.parse_block() - else: - else_body = None - else_tok = none - - node = ForStmt(index, expr, body, else_body, types) - self.set_repr(node, noderepr.ForStmtRepr(for_tok, commas, in_tok, - else_tok)) - return node - - def parse_for_index_variables(self) -> Tuple[List[NameExpr], List[Type], - List[Token]]: - # Parse index variables of a 'for' statement. - index = List[NameExpr]() - types = List[Type]() - commas = List[Token]() - - is_paren = self.current_str() == '(' - if is_paren: - self.skip() - - while True: - v = self.parse_name_expr() - index.append(v) - types.append(None) - if self.current_str() != ',': - commas.append(none) - break - commas.append(self.skip()) - - if is_paren: - self.expect(')') - - return index, types, commas - - def parse_if_stmt(self) -> IfStmt: - is_error = False - - if_tok = self.expect('if') - expr = List[Node]() - try: - expr.append(self.parse_expression()) - except ParseError: - is_error = True - - body = [self.parse_block()[0]] - - elif_toks = List[Token]() - while self.current_str() == 'elif': - elif_toks.append(self.expect('elif')) - try: - expr.append(self.parse_expression()) - except ParseError: - is_error = True - body.append(self.parse_block()[0]) - - if self.current_str() == 'else': - else_tok = self.expect('else') - else_body, _ = self.parse_block() - else: - else_tok = none - else_body = None - - if not is_error: - node = IfStmt(expr, body, else_body) - self.set_repr(node, noderepr.IfStmtRepr(if_tok, elif_toks, - else_tok)) - return node - else: - return None - - def parse_try_stmt(self) -> Node: - try_tok = self.expect('try') - body, _ = self.parse_block() - is_error = False - vars = List[NameExpr]() - types = List[Node]() - handlers = List[Block]() - except_toks, name_toks, as_toks, except_brs = (List[Token](), - List[Token](), - List[Token](), - List[Token]()) - while self.current_str() == 'except': - except_toks.append(self.expect('except')) - if not isinstance(self.current(), Colon): - try: - t = self.current() - types.append(self.parse_expression().set_line(t)) - if self.current_str() == 'as': - as_toks.append(self.expect('as')) - vars.append(self.parse_name_expr()) - else: - name_toks.append(none) - vars.append(None) - as_toks.append(none) - except ParseError: - is_error = True - else: - types.append(None) - vars.append(None) - as_toks.append(none) - handlers.append(self.parse_block()[0]) - if not is_error: - if self.current_str() == 'else': - else_tok = self.skip() - else_body, _ = self.parse_block() - else: - else_tok = none - else_body = None - if self.current_str() == 'finally': - finally_tok = self.expect('finally') - finally_body, _ = self.parse_block() - else: - finally_tok = none - finally_body = None - node = TryStmt(body, vars, types, handlers, else_body, - finally_body) - self.set_repr(node, noderepr.TryStmtRepr(try_tok, except_toks, - name_toks, as_toks, - else_tok, finally_tok)) - return node - else: - return None - - def parse_with_stmt(self) -> WithStmt: - with_tok = self.expect('with') - as_toks = List[Token]() - commas = List[Token]() - expr = List[Node]() - name = List[NameExpr]() - while True: - e = self.parse_expression(precedence[',']) - if self.current_str() == 'as': - as_toks.append(self.expect('as')) - n = self.parse_name_expr() - else: - as_toks.append(none) - n = None - expr.append(e) - name.append(n) - if self.current_str() != ',': - break - commas.append(self.expect(',')) - body, _ = self.parse_block() - node = WithStmt(expr, name, body) - self.set_repr(node, noderepr.WithStmtRepr(with_tok, as_toks, commas)) - return node - - def parse_print_stmt(self) -> PrintStmt: - self.expect('print') - args = List[Node]() - while not isinstance(self.current(), Break): - args.append(self.parse_expression(precedence[','])) - if self.current_str() == ',': - comma = True - self.skip() - else: - comma = False - break - self.expect_break() - return PrintStmt(args, newline=not comma) - - # Parsing expressions - - def parse_expression(self, prec: int = 0) -> Node: - """Parse a subexpression within a specific precedence context.""" -<<<<<<< HEAD - expr = Undefined # type: Node - t = self.current() # Remember token for setting the line number. -======= - expr = Undefined # type: Node - t = self.current() # Remember token for setting the line number. ->>>>>>> master - - # Parse a "value" expression or unary operator expression and store - # that in expr. - s = self.current_str() - if s == '(': - # Parerenthesised expression or cast. - expr = self.parse_parentheses() - elif s == '[': - expr = self.parse_list_expr() - elif s in ['-', '+', 'not', '~']: - # Unary operation. - expr = self.parse_unary_expr() - elif s == 'lambda': - expr = self.parse_lambda_expr() - elif s == '{': - expr = self.parse_dict_or_set_expr() - else: - if isinstance(self.current(), Name): - # Name expression. - expr = self.parse_name_expr() - elif isinstance(self.current(), IntLit): - expr = self.parse_int_expr() - elif isinstance(self.current(), StrLit): - expr = self.parse_str_expr() - elif isinstance(self.current(), BytesLit): - expr = self.parse_bytes_literal() - elif isinstance(self.current(), UnicodeLit): - expr = self.parse_unicode_literal() - elif isinstance(self.current(), FloatLit): - expr = self.parse_float_expr() - elif isinstance(t, Keyword) and s == "yield": - expr = self.parse_yield_from_expr() # The expression yield from and yield to assign - else: - # Invalid expression. - self.parse_error() - - # Set the line of the expression node, if not specified. This - # simplifies recording the line number as not every node type needs to - # deal with it separately. - if expr.line < 0: - expr.set_line(t) - - # Parse operations that require a left argument (stored in expr). - while True: - t = self.current() - s = self.current_str() - if s == '(': - # Call expression. - expr = self.parse_call_expr(expr) - elif s == '.': - # Member access expression. - expr = self.parse_member_expr(expr) - elif s == '[': - # Indexing expression. - expr = self.parse_index_expr(expr) - elif s == ',': - # The comma operator is used to build tuples. Comma also - # separates array items and function arguments, but in this - # case the precedence is too low to build a tuple. - if precedence[','] > prec: - expr = self.parse_tuple_expr(expr) - else: - break - elif s == 'for': - if precedence[''] > prec: - # List comprehension or generator expression. Parse as - # generator expression; it will be converted to list - # comprehension if needed elsewhere. - expr = self.parse_generator_expr(expr) - else: - break - elif s == 'if': - # Conditional expression. - if precedence[''] > prec: - expr = self.parse_conditional_expr(expr) - else: - break - else: - # Binary operation or a special case. - if isinstance(self.current(), Op): - op = self.current_str() - op_prec = precedence[op] - if op == 'not': - # Either "not in" or an error. - op_prec = precedence['in'] - if op_prec > prec: - if op in op_comp: - expr = self.parse_comparison_expr(expr, op_prec) - else: - expr = self.parse_bin_op_expr(expr, op_prec) - else: - # The operation cannot be associated with the - # current left operand due to the precedence - # context; let the caller handle it. - break - else: - # Not an operation that accepts a left argument; let the - # caller handle the rest. - break - - # Set the line of the expression node, if not specified. This - # simplifies recording the line number as not every node type - # needs to deal with it separately. - if expr.line < 0: - expr.set_line(t) - - return expr - - def parse_parentheses(self) -> Node: - lparen = self.skip() - if self.current_str() == ')': - # Empty tuple (). - expr = self.parse_empty_tuple_expr(lparen) # type: Node - else: - # Parenthesised expression. - expr = self.parse_expression(0) - rparen = self.expect(')') - expr = ParenExpr(expr) - self.set_repr(expr, noderepr.ParenExprRepr(lparen, rparen)) - return expr - - def parse_empty_tuple_expr(self, lparen: Any) -> TupleExpr: - rparen = self.expect(')') - node = TupleExpr([]) - self.set_repr(node, noderepr.TupleExprRepr(lparen, [], rparen)) - return node - - def parse_list_expr(self) -> Node: - """Parse list literal or list comprehension.""" - items = List[Node]() - lbracket = self.expect('[') - commas = List[Token]() - while self.current_str() != ']' and not self.eol(): - items.append(self.parse_expression(precedence[''])) - if self.current_str() != ',': - break - commas.append(self.expect(',')) - if self.current_str() == 'for' and len(items) == 1: - items[0] = self.parse_generator_expr(items[0]) - rbracket = self.expect(']') - if len(items) == 1 and isinstance(items[0], GeneratorExpr): - list_comp = ListComprehension(cast(GeneratorExpr, items[0])) - self.set_repr(list_comp, noderepr.ListComprehensionRepr(lbracket, - rbracket)) - return list_comp - else: - expr = ListExpr(items) - self.set_repr(expr, noderepr.ListSetExprRepr(lbracket, commas, - rbracket, none, none)) - return expr - - def parse_generator_expr(self, left_expr: Node) -> GeneratorExpr: - indices = List[List[NameExpr]]() - sequences = List[Node]() - types = List[List[Type]]() - for_toks = List[Token]() - in_toks = List[Token]() - if_toklists = List[List[Token]]() - condlists = List[List[Node]]() - while self.current_str() == 'for': - if_toks = List[Token]() - conds = List[Node]() - for_toks.append(self.expect('for')) - index, type, commas = self.parse_for_index_variables() - indices.append(index) - types.append(type) - in_toks.append(self.expect('in')) - sequence = self.parse_expression_list() - sequences.append(sequence) - while self.current_str() == 'if': - if_toks.append(self.skip()) - conds.append(self.parse_expression(precedence[''])) - if_toklists.append(if_toks) - condlists.append(conds) - - gen = GeneratorExpr(left_expr, indices, types, sequences, condlists) - gen.set_line(for_toks[0]) - self.set_repr(gen, noderepr.GeneratorExprRepr(for_toks, commas, in_toks, - if_toklists)) - return gen - - def parse_expression_list(self) -> Node: - prec = precedence[''] - expr = self.parse_expression(prec) - if self.current_str() != ',': - return expr - else: - t = self.current() - return self.parse_tuple_expr(expr, prec).set_line(t) - - def parse_conditional_expr(self, left_expr: Node) -> ConditionalExpr: - self.expect('if') - cond = self.parse_expression(precedence['']) - self.expect('else') - else_expr = self.parse_expression(precedence['']) - return ConditionalExpr(cond, left_expr, else_expr) - - def parse_dict_or_set_expr(self) -> Node: - items = List[Tuple[Node, Node]]() - lbrace = self.expect('{') - colons = List[Token]() - commas = List[Token]() - while self.current_str() != '}' and not self.eol(): - key = self.parse_expression(precedence[',']) - if self.current_str() in [',', '}'] and items == []: - return self.parse_set_expr(key, lbrace) - elif self.current_str() != ':': - self.parse_error() - colons.append(self.expect(':')) - value = self.parse_expression(precedence[',']) - items.append((key, value)) - if self.current_str() != ',': - break - commas.append(self.expect(',')) - rbrace = self.expect('}') - node = DictExpr(items) - self.set_repr(node, noderepr.DictExprRepr(lbrace, colons, commas, - rbrace, none, none, none)) - return node - - def parse_set_expr(self, first: Node, lbrace: Token) -> SetExpr: - items = [first] - commas = List[Token]() - while self.current_str() != '}' and not self.eol(): - commas.append(self.expect(',')) - if self.current_str() == '}': - break - items.append(self.parse_expression(precedence[','])) - rbrace = self.expect('}') - expr = SetExpr(items) - self.set_repr(expr, noderepr.ListSetExprRepr(lbrace, commas, - rbrace, none, none)) - return expr - - def parse_tuple_expr(self, expr: Node, - prec: int = precedence[',']) -> TupleExpr: - items = [expr] - commas = List[Token]() - while True: - commas.append(self.expect(',')) - if (self.current_str() in [')', ']', '='] or - isinstance(self.current(), Break)): - break - items.append(self.parse_expression(prec)) - if self.current_str() != ',': break - node = TupleExpr(items) - self.set_repr(node, noderepr.TupleExprRepr(none, commas, none)) - return node - - def parse_name_expr(self) -> NameExpr: - tok = self.expect_type(Name) - node = NameExpr(tok.string) - node.set_line(tok) - self.set_repr(node, noderepr.NameExprRepr(tok)) - return node - - def parse_int_expr(self) -> IntExpr: - tok = self.expect_type(IntLit) - s = tok.string - v = 0 - if len(s) > 2 and s[1] in 'xX': - v = int(s[2:], 16) - elif len(s) > 2 and s[1] in 'oO': - v = int(s[2:], 8) - else: - v = int(s) - node = IntExpr(v) - self.set_repr(node, noderepr.IntExprRepr(tok)) - return node - - def parse_str_expr(self) -> Node: - # XXX \uxxxx literals - tok = [self.expect_type(StrLit)] - value = (cast(StrLit, tok[0])).parsed() - while isinstance(self.current(), StrLit): - t = cast(StrLit, self.skip()) - tok.append(t) - value += t.parsed() - node = Undefined(Node) - if self.pyversion == 2 and 'unicode_literals' in self.future_options: - node = UnicodeExpr(value) - else: - node = StrExpr(value) - self.set_repr(node, noderepr.StrExprRepr(tok)) - return node - - def parse_bytes_literal(self) -> Node: - # XXX \uxxxx literals - tok = [self.expect_type(BytesLit)] - value = (cast(BytesLit, tok[0])).parsed() - while isinstance(self.current(), BytesLit): - t = cast(BytesLit, self.skip()) - tok.append(t) - value += t.parsed() - if self.pyversion >= 3: - node = BytesExpr(value) # type: Node - else: - node = StrExpr(value) - self.set_repr(node, noderepr.StrExprRepr(tok)) - return node - - def parse_unicode_literal(self) -> Node: - # XXX \uxxxx literals - tok = [self.expect_type(UnicodeLit)] - value = (cast(UnicodeLit, tok[0])).parsed() - while isinstance(self.current(), UnicodeLit): - t = cast(UnicodeLit, self.skip()) - tok.append(t) - value += t.parsed() - if self.pyversion >= 3: - # Python 3.3 supports u'...' as an alias of '...'. - node = StrExpr(value) # type: Node - else: - node = UnicodeExpr(value) - self.set_repr(node, noderepr.StrExprRepr(tok)) - return node - - def parse_float_expr(self) -> FloatExpr: - tok = self.expect_type(FloatLit) - node = FloatExpr(float(tok.string)) - self.set_repr(node, noderepr.FloatExprRepr(tok)) - return node - - def parse_call_expr(self, callee: Any) -> CallExpr: - lparen = self.expect('(') - (args, kinds, names, - commas, star, star2, assigns) = self.parse_arg_expr() - rparen = self.expect(')') - node = CallExpr(callee, args, kinds, names) - self.set_repr(node, noderepr.CallExprRepr(lparen, commas, star, star2, - assigns, rparen)) - return node - - def parse_arg_expr(self) -> Tuple[List[Node], List[int], List[str], - List[Token], Token, Token, - List[List[Token]]]: - """Parse arguments in a call expression (within '(' and ')'). - - Return a tuple with these items: - argument expressions - argument kinds - argument names (for named arguments; None for ordinary args) - comma tokens - * token (if any) - ** token (if any) - (assignment, name) tokens - """ -<<<<<<< HEAD - args = [] # type: List[Node] - kinds = [] # type: List[int] - names = [] # type: List[str] -======= - args = [] # type: List[Node] - kinds = [] # type: List[int] - names = [] # type: List[str] ->>>>>>> master - star = none - star2 = none - commas = [] # type: List[Token] - keywords = [] # type: List[List[Token]] - var_arg = False - dict_arg = False - named_args = False - while self.current_str() != ')' and not self.eol() and not dict_arg: - if isinstance(self.current(), Name) and self.peek().string == '=': - # Named argument - name = self.expect_type(Name) - assign = self.expect('=') - kinds.append(nodes.ARG_NAMED) - names.append(name.string) - keywords.append([name, assign]) - named_args = True - elif (self.current_str() == '*' and not var_arg and not dict_arg - and not named_args): - # *args - var_arg = True - star = self.expect('*') - kinds.append(nodes.ARG_STAR) - names.append(None) - elif self.current_str() == '**': - # **kwargs - star2 = self.expect('**') - dict_arg = True - kinds.append(nodes.ARG_STAR2) - names.append(None) - elif not var_arg and not named_args: - # Ordinary argument - kinds.append(nodes.ARG_POS) - names.append(None) - else: - self.parse_error() - args.append(self.parse_expression(precedence[','])) - if self.current_str() != ',': - break - commas.append(self.expect(',')) - return args, kinds, names, commas, star, star2, keywords - - def parse_member_expr(self, expr: Any) -> Node: - dot = self.expect('.') - name = self.expect_type(Name) - node = Undefined(Node) - if (isinstance(expr, CallExpr) and isinstance(expr.callee, NameExpr) - and expr.callee.name == 'super'): - # super() expression - node = SuperExpr(name.string) - self.set_repr(node, - noderepr.SuperExprRepr(expr.callee.repr.id, - expr.repr.lparen, - expr.repr.rparen, dot, name)) - else: - node = MemberExpr(expr, name.string) - self.set_repr(node, noderepr.MemberExprRepr(dot, name)) - return node - - def parse_index_expr(self, base: Any) -> IndexExpr: - lbracket = self.expect('[') - if self.current_str() != ':': - index = self.parse_expression(0) - else: - index = None - if self.current_str() == ':': - # Slice. - colon = self.expect(':') - if self.current_str() != ']' and self.current_str() != ':': - end_index = self.parse_expression(0) - else: - end_index = None - colon2 = none - stride = None # type: Node - if self.current_str() == ':': - colon2 = self.expect(':') - if self.current_str() != ']': - stride = self.parse_expression() - index = SliceExpr(index, end_index, stride).set_line(colon.line) - self.set_repr(index, noderepr.SliceExprRepr(colon, colon2)) - rbracket = self.expect(']') - node = IndexExpr(base, index) - self.set_repr(node, noderepr.IndexExprRepr(lbracket, rbracket)) - return node - - def parse_bin_op_expr(self, left: Node, prec: int) -> OpExpr: - op = self.expect_type(Op) - op_str = op.string - if op_str == '~': - self.ind -= 1 - self.parse_error() - right = self.parse_expression(prec) - node = OpExpr(op_str, left, right) - self.set_repr(node, noderepr.OpExprRepr(op)) - return node - - def parse_comparison_expr(self, left: Node, prec: int) -> ComparisonExpr: - operators = [] # type: List[Tuple[Token, Token]] - operators_str = [] # type: List[str] - operands = [left] - - while True: - op = self.expect_type(Op) - op2 = none - op_str = op.string - if op_str == 'not': - if self.current_str() == 'in': - op_str = 'not in' - op2 = self.skip() - else: - self.parse_error() - elif op_str == 'is' and self.current_str() == 'not': - op_str = 'is not' - op2 = self.skip() - - operators_str.append(op_str) - operators.append( (op, op2) ) - operand = self.parse_expression(prec) - operands.append(operand) - - # Continue if next token is a comparison operator - t = self.current() - s = self.current_str() - if s not in op_comp: - break - - node = ComparisonExpr(operators_str, operands) - self.set_repr(node, noderepr.ComparisonExprRepr(operators)) - return node - -<<<<<<< HEAD -======= - ->>>>>>> master - def parse_unary_expr(self) -> UnaryExpr: - op_tok = self.skip() - op = op_tok.string - if op == '-' or op == '+': - prec = precedence['-u'] - else: - prec = precedence[op] - expr = self.parse_expression(prec) - node = UnaryExpr(op, expr) - self.set_repr(node, noderepr.UnaryExprRepr(op_tok)) - return node - - def parse_lambda_expr(self) -> FuncExpr: - is_error = False - lambda_tok = self.expect('lambda') - - (args, init, kinds, has_inits, - arg_names, commas, asterisk, - assigns, arg_types) = self.parse_arg_list(allow_signature=False) - - names = List[str]() - for arg in args: - names.append(arg.name()) - - # Use 'object' as the placeholder return type; it will be inferred - # later. We can't use 'Any' since it could make type inference results - # less precise. - ret_type = UnboundType('__builtins__.object') - typ = self.build_func_annotation(ret_type, arg_types, kinds, names, - lambda_tok.line, is_default_ret=True) - - colon = self.expect(':') - - expr = self.parse_expression(precedence[',']) - - body = Block([ReturnStmt(expr).set_line(lambda_tok)]) - body.set_line(colon) - - node = FuncExpr(args, kinds, init, body, typ) - self.set_repr(node, - noderepr.FuncExprRepr( - lambda_tok, colon, - noderepr.FuncArgsRepr(none, none, arg_names, commas, - assigns, asterisk))) - return node - - # Helper methods - - def skip(self) -> Token: - self.ind += 1 - return self.tok[self.ind - 1] - - def expect(self, string: str) -> Token: - if self.current_str() == string: - self.ind += 1 - return self.tok[self.ind - 1] - else: - self.parse_error() - - def expect_indent(self) -> Token: - if isinstance(self.current(), Indent): - return self.expect_type(Indent) - else: - self.fail('Expected an indented block', self.current().line) - return none - - def fail(self, msg: str, line: int) -> None: - self.errors.report(line, msg) - - def expect_type(self, typ: type) -> Token: - if isinstance(self.current(), typ): - self.ind += 1 - return self.tok[self.ind - 1] - else: - self.parse_error() - - def expect_colon_and_break(self) -> Tuple[Token, Token]: - return self.expect_type(Colon), self.expect_type(Break) - - def expect_break(self) -> Token: - return self.expect_type(Break) - - def expect_end(self) -> Tuple[Token, Token]: - return self.expect('end'), self.expect_type(Break) - - def current(self) -> Token: - return self.tok[self.ind] - - def current_str(self) -> str: - return self.current().string - - def peek(self) -> Token: - return self.tok[self.ind + 1] - - def parse_error(self) -> None: - self.parse_error_at(self.current()) - raise ParseError() - - def parse_error_at(self, tok: Token, skip: bool = True) -> None: - msg = '' - if isinstance(tok, LexError): - msg = token_repr(tok) - msg = msg[0].upper() + msg[1:] - elif isinstance(tok, Indent) or isinstance(tok, Dedent): - msg = 'Inconsistent indentation' - else: - msg = 'Parse error before {}'.format(token_repr(tok)) - - self.errors.report(tok.line, msg) - - if skip: - self.skip_until_next_line() - - def skip_until_break(self) -> None: - n = 0 - while (not isinstance(self.current(), Break) - and not isinstance(self.current(), Eof)): - self.skip() - n += 1 - if isinstance(self.tok[self.ind - 1], Colon) and n > 1: - self.ind -= 1 - - def skip_until_next_line(self) -> None: - self.skip_until_break() - if isinstance(self.current(), Break): - self.skip() - - def eol(self) -> bool: - return isinstance(self.current(), Break) or self.eof() - - def eof(self) -> bool: - return isinstance(self.current(), Eof) - - # Type annotation related functionality - - def parse_type(self) -> Type: - line = self.current().line - try: - typ, self.ind = parse_type(self.tok, self.ind) - except TypeParseError as e: - self.parse_error_at(e.token) - raise ParseError() - return typ - - annotation_prefix_re = re.compile(r'#\s*type:') - - def parse_type_comment(self, token: Token, signature: bool) -> Type: - """Parse a '# type: ...' annotation. - - Return None if no annotation found. If signature is True, expect - a type signature of form (...) -> t. - """ - whitespace_or_comments = token.rep().strip() - if self.annotation_prefix_re.match(whitespace_or_comments): - type_as_str = whitespace_or_comments.split(':', 1)[1].strip() - tokens = lex.lex(type_as_str, token.line) - if len(tokens) < 2: - # Empty annotation (only Eof token) - self.errors.report(token.line, 'Empty type annotation') - return None - try: - if not signature: - type, index = parse_types(tokens, 0) - else: - type, index = parse_signature(tokens) - except TypeParseError as e: - self.parse_error_at(e.token, skip=False) - return None - if index < len(tokens) - 2: - self.parse_error_at(tokens[index], skip=False) - return None - return type - else: - return None - - # Representation management - - def set_repr(self, node: Node, repr: Any) -> None: - node.repr = repr - - def repr(self, node: Node) -> Any: - return node.repr - - def paren_repr(self, e: Node) -> Tuple[List[Token], List[Token]]: - """If e is a ParenExpr, return an array of left-paren tokens - (more that one if nested parens) and an array of corresponding - right-paren tokens. Otherwise, return [], []. - """ - if isinstance(e, ParenExpr): - lp, rp = self.paren_repr(e.expr) - lp.insert(0, self.repr(e).lparen) - rp.append(self.repr(e).rparen) - return lp, rp - else: - return [], [] - - -class ParseError(Exception): pass - - -def token_repr(tok: Token) -> str: - """Return a representation of a token for use in parse error messages.""" - if isinstance(tok, Break): - return 'end of line' - elif isinstance(tok, Eof): - return 'end of file' - elif isinstance(tok, Keyword) or isinstance(tok, Name): - return '"{}"'.format(tok.string) - elif isinstance(tok, IntLit) or isinstance(tok, FloatLit): - return 'numeric literal' - elif isinstance(tok, StrLit): - return 'string literal' - elif (isinstance(tok, Punct) or isinstance(tok, Op) - or isinstance(tok, Colon)): - return tok.string - elif isinstance(tok, Bom): - return 'byte order mark' - elif isinstance(tok, Indent): - return 'indent' - elif isinstance(tok, Dedent): - return 'dedent' - else: - if isinstance(tok, LexError): - t = tok.type - if t == lex.NUMERIC_LITERAL_ERROR: - return 'invalid numeric literal' - elif t == lex.UNTERMINATED_STRING_LITERAL: - return 'unterminated string literal' - elif t == lex.INVALID_CHARACTER: - msg = 'unrecognized character' - if ord(tok.string) in range(33, 127): - msg += ' ' + tok.string - return msg - elif t == lex.INVALID_UTF8_SEQUENCE: - return 'invalid UTF-8 sequence' - elif t == lex.NON_ASCII_CHARACTER_IN_COMMENT: - return 'non-ASCII character in comment' - elif t == lex.NON_ASCII_CHARACTER_IN_STRING: - return 'non-ASCII character in string' - elif t == lex.INVALID_DEDENT: - return 'inconsistent indentation' - raise ValueError('Unknown token {}'.format(repr(tok))) - - -def unwrap_parens(node: Node) -> Node: - """Unwrap any outer parentheses in node. - - If the node is a parenthesised expression, recursively find the first - non-parenthesised subexpression and return that. Otherwise, return node. - """ - if isinstance(node, ParenExpr): - return unwrap_parens(node.expr) - else: - return node - - -if __name__ == '__main__': - # Parse a file and dump the AST (or display errors). - import sys - if len(sys.argv) != 2: - print('Usage: parse.py FILE') - sys.exit(2) - fnam = sys.argv[1] - s = open(fnam).read() - errors = Errors() - try: - tree = parse(s, fnam) - print(tree) - except CompileError as e: - for msg in e.messages: - print(msg) diff --git a/mypy/pprinter.py.orig b/mypy/pprinter.py.orig deleted file mode 100644 index 0f22ad4cb3d4..000000000000 --- a/mypy/pprinter.py.orig +++ /dev/null @@ -1,342 +0,0 @@ -from typing import List, cast - -from mypy.output import TypeOutputVisitor -from mypy.nodes import ( - Node, VarDef, ClassDef, FuncDef, MypyFile, CoerceExpr, TypeExpr, CallExpr, - TypeVarExpr -) -from mypy.visitor import NodeVisitor -from mypy.types import Void, TypeVisitor, Callable, Instance, Type, UnboundType -from mypy.maptypevar import num_slots -from mypy.transutil import tvar_arg_name -from mypy import coerce -from mypy import nodes - - -class PrettyPrintVisitor(NodeVisitor): - """Convert transformed parse trees into formatted source code. - - Use automatic formatting (i.e. omit original formatting). - """ - - def __init__(self) -> None: - super().__init__() - self.result = [] # type: List[str] - self.indent = 0 - - def output(self) -> str: - return ''.join(self.result) - - # - # Definitions - # - - def visit_mypy_file(self, file: MypyFile) -> None: - for d in file.defs: - d.accept(self) - - def visit_class_def(self, tdef: ClassDef) -> None: - self.string('class ') - self.string(tdef.name) - if tdef.base_types: - b = [] # type: List[str] - for bt in tdef.base_types: - if not bt: - continue - elif isinstance(bt, UnboundType): - b.append(bt.name) - elif (cast(Instance, bt)).type.fullname() != 'builtins.object': - typestr = bt.accept(TypeErasedPrettyPrintVisitor()) - b.append(typestr) - if b: - self.string('({})'.format(', '.join(b))) - self.string(':\n') - for d in tdef.defs.body: - d.accept(self) - self.dedent() - - def visit_func_def(self, fdef: FuncDef) -> None: - # FIX varargs, default args, keyword args etc. - ftyp = cast(Callable, fdef.type) - self.string('def ') - self.string(fdef.name()) - self.string('(') - for i in range(len(fdef.args)): - a = fdef.args[i] - self.string(a.name()) - if i < len(ftyp.arg_types): - self.string(': ') - self.type(ftyp.arg_types[i]) - else: - self.string('xxx ') - if i < len(fdef.args) - 1: - self.string(', ') - self.string(') -> ') - self.type(ftyp.ret_type) - fdef.body.accept(self) - - def visit_var_def(self, vdef: VarDef) -> None: - if vdef.items[0].name() not in nodes.implicit_module_attrs: - self.string(vdef.items[0].name()) - self.string(': ') - self.type(vdef.items[0].type) - if vdef.init: - self.string(' = ') - self.node(vdef.init) - self.string('\n') - - # - # Statements - # - - def visit_block(self, b): - self.string(':\n') - for s in b.body: - s.accept(self) - self.dedent() - - def visit_pass_stmt(self, o): - self.string('pass\n') - - def visit_return_stmt(self, o): - self.string('return ') - if o.expr: - self.node(o.expr) - self.string('\n') - - def visit_expression_stmt(self, o): - self.node(o.expr) - self.string('\n') - - def visit_assignment_stmt(self, o): - if isinstance(o.rvalue, CallExpr) and isinstance(o.rvalue.analyzed, - TypeVarExpr): - # Skip type variable definition 'x = typevar(...)'. - return - self.node(o.lvalues[0]) # FIX multiple lvalues - if o.type: - self.string(': ') - self.type(o.type) - self.string(' = ') - self.node(o.rvalue) - self.string('\n') - - def visit_if_stmt(self, o): - self.string('if ') - self.node(o.expr[0]) - self.node(o.body[0]) - for e, b in zip(o.expr[1:], o.body[1:]): - self.string('elif ') - self.node(e) - self.node(b) - if o.else_body: - self.string('else') - self.node(o.else_body) - - def visit_while_stmt(self, o): - self.string('while ') - self.node(o.expr) - self.node(o.body) - if o.else_body: - self.string('else') - self.node(o.else_body) - - # - # Expressions - # - - def visit_call_expr(self, o): - if o.analyzed: - o.analyzed.accept(self) - return - self.node(o.callee) - self.string('(') - self.omit_next_space = True - for i in range(len(o.args)): - self.node(o.args[i]) - if i < len(o.args) - 1: - self.string(', ') - self.string(')') - -<<<<<<< HEAD - def visit_yield_from_expr(self, o): - if o.expr: - o.expr.accept(self) - -======= ->>>>>>> master - def visit_member_expr(self, o): - self.node(o.expr) - self.string('.' + o.name) - if o.direct: - self.string('!') - - def visit_name_expr(self, o): - self.string(o.name) - - def visit_coerce_expr(self, o: CoerceExpr) -> None: - self.string('{') - self.full_type(o.target_type) - if coerce.is_special_primitive(o.source_type): - self.string(' <= ') - self.type(o.source_type) - self.string(' ') - self.node(o.expr) - self.string('}') - - def visit_type_expr(self, o: TypeExpr) -> None: - # Type expressions are only generated during transformation, so we must - # use automatic formatting. - self.string('<') - self.full_type(o.type) - self.string('>') - - def visit_index_expr(self, o): - if o.analyzed: - o.analyzed.accept(self) - return - self.node(o.base) - self.string('[') - self.node(o.index) - self.string(']') - - def visit_int_expr(self, o): - self.string(str(o.value)) - - def visit_str_expr(self, o): - self.string(repr(o.value)) - - def visit_op_expr(self, o): - self.node(o.left) - self.string(' %s ' % o.op) - self.node(o.right) - - def visit_comparison_expr(self, o): - self.node(o.operands[0]) - for operator, operand in zip(o.operators, o.operands[1:]): - self.string(' %s ' % operator) - self.node(operand) - - def visit_unary_expr(self, o): - self.string(o.op) - if o.op == 'not': - self.string(' ') - self.node(o.expr) - - def visit_paren_expr(self, o): - self.string('(') - self.node(o.expr) - self.string(')') - - def visit_super_expr(self, o): - self.string('super().') - self.string(o.name) - - def visit_cast_expr(self, o): - self.string('cast(') - self.type(o.type) - self.string(', ') - self.node(o.expr) - self.string(')') - - def visit_type_application(self, o): - # Type arguments are erased in transformation. - self.node(o.expr) - - def visit_undefined_expr(self, o): - # Omit declared type as redundant. - self.string('Undefined') - - # - # Helpers - # - - def string(self, s: str) -> None: - if not s: - return - if self.last_output_char() == '\n': - self.result.append(' ' * self.indent) - self.result.append(s) - if s.endswith(':\n'): - self.indent += 4 - - def dedent(self) -> None: - self.indent -= 4 - - def node(self, n: Node) -> None: - n.accept(self) - - def last_output_char(self) -> str: - if self.result: - return self.result[-1][-1] - return '' - - def type(self, t): - """Pretty-print a type with erased type arguments.""" - if t: - v = TypeErasedPrettyPrintVisitor() - self.string(t.accept(v)) - - def full_type(self, t): - """Pretty-print a type, includingn type arguments.""" - if t: - v = TypePrettyPrintVisitor() - self.string(t.accept(v)) - - -class TypeErasedPrettyPrintVisitor(TypeVisitor[str]): - """Pretty-print types. - - Omit type variables (e.g. C instead of C[int]). - - Note that the translation does not preserve all information about the - types, but this is fine since this is only used in test case output. - """ - - def visit_any(self, t): - return 'Any' - - def visit_void(self, t): - return 'None' - - def visit_instance(self, t): - return t.type.name() - - def visit_type_var(self, t): - return 'Any*' - - def visit_runtime_type_var(self, t): - v = PrettyPrintVisitor() - t.node.accept(v) - return v.output() - - -class TypePrettyPrintVisitor(TypeVisitor[str]): - """Pretty-print types. - - Include type variables. - - Note that the translation does not preserve all information about the - types, but this is fine since this is only used in test case output. - """ - - def visit_any(self, t): - return 'Any' - - def visit_void(self, t): - return 'None' - - def visit_instance(self, t): - s = t.type.name() - if t.args: - argstr = ', '.join([a.accept(self) for a in t.args]) - s += '[%s]' % argstr - return s - - def visit_type_var(self, t): - return 'Any*' - - def visit_runtime_type_var(self, t): - v = PrettyPrintVisitor() - t.node.accept(v) - return v.output() diff --git a/mypy/semanal.py.orig b/mypy/semanal.py.orig deleted file mode 100644 index 552f6cdc46b5..000000000000 --- a/mypy/semanal.py.orig +++ /dev/null @@ -1,1870 +0,0 @@ -"""The semantic analyzer. - -Bind names to definitions and do various other simple consistency -checks. For example, consider this program: - - x = 1 - y = x - -Here semantic analysis would detect that the assignment 'x = 1' -defines a new variable, the type of which is to be inferred (in a -later pass; type inference or type checking is not part of semantic -analysis). Also, it would bind both references to 'x' to the same -module-level variable node. The second assignment would also be -analyzed, and the type of 'y' marked as being inferred. - -Semantic analysis is the first analysis pass after parsing, and it is -subdivided into three passes: - - * FirstPass looks up externally visible names defined in a module but - ignores imports and local definitions. It helps enable (some) - cyclic references between modules, such as module 'a' that imports - module 'b' and used names defined in b *and* vice versa. The first - pass can be performed before dependent modules have been processed. - - * SemanticAnalyzer is the second pass. It does the bulk of the work. - It assumes that dependent modules have been semantically analyzed, - up to the second pass, unless there is a import cycle. - - * ThirdPass checks that type argument counts are valid; for example, - it will reject Dict[int]. We don't do this in the second pass, - since we infer the type argument counts of classes during this - pass, and it is possible to refer to classes defined later in a - file, which would not have the type argument count set yet. - -Semantic analysis of types is implemented in module mypy.typeanal. - -TODO: Check if the third pass slows down type checking significantly. - We could probably get rid of it -- for example, we could collect all - analyzed types in a collection and check them without having to - traverse the entire AST. -""" - -from typing import ( - Undefined, List, Dict, Set, Tuple, cast, Any, overload, typevar -) - -from mypy.nodes import ( - MypyFile, TypeInfo, Node, AssignmentStmt, FuncDef, OverloadedFuncDef, - ClassDef, VarDef, Var, GDEF, MODULE_REF, FuncItem, Import, - ImportFrom, ImportAll, Block, LDEF, NameExpr, MemberExpr, - IndexExpr, ParenExpr, TupleExpr, ListExpr, ExpressionStmt, ReturnStmt, - RaiseStmt, YieldStmt, AssertStmt, OperatorAssignmentStmt, WhileStmt, - ForStmt, BreakStmt, ContinueStmt, IfStmt, TryStmt, WithStmt, DelStmt, - GlobalDecl, SuperExpr, DictExpr, CallExpr, RefExpr, OpExpr, UnaryExpr, - SliceExpr, CastExpr, TypeApplication, Context, SymbolTable, - SymbolTableNode, TVAR, UNBOUND_TVAR, ListComprehension, GeneratorExpr, - FuncExpr, MDEF, FuncBase, Decorator, SetExpr, UndefinedExpr, TypeVarExpr, - StrExpr, PrintStmt, ConditionalExpr, DucktypeExpr, DisjointclassExpr, -<<<<<<< HEAD - ARG_POS, ARG_NAMED, MroError, type_aliases, YieldFromStmt, YieldFromExpr -======= - ComparisonExpr, ARG_POS, ARG_NAMED, MroError, type_aliases ->>>>>>> master -) -from mypy.visitor import NodeVisitor -from mypy.traverser import TraverserVisitor -from mypy.errors import Errors -from mypy.types import ( - NoneTyp, Callable, Overloaded, Instance, Type, TypeVar, AnyType, - FunctionLike, UnboundType, TypeList, ErrorType, TypeVarDef, - replace_leading_arg_type, TupleType, UnionType -) -from mypy.nodes import function_type, implicit_module_attrs -from mypy.typeanal import TypeAnalyser, TypeAnalyserPass3, analyse_node -from mypy.parsetype import parse_str_as_type, TypeParseError - - -T = typevar('T') - - -# Inferred value of an expression. -ALWAYS_TRUE = 0 -ALWAYS_FALSE = 1 -TRUTH_VALUE_UNKNOWN = 2 - - -class TypeTranslationError(Exception): - """Exception raised when an expression is not valid as a type.""" - - -class SemanticAnalyzer(NodeVisitor): - """Semantically analyze parsed mypy files. - - The analyzer binds names and does various consistency checks for a - parse tree. Note that type checking is performed as a separate - pass. - - This is the second phase of semantic analysis. - """ - - # Library search paths - lib_path = Undefined(List[str]) - # Module name space - modules = Undefined(Dict[str, MypyFile]) - # Global name space for current module - globals = Undefined(SymbolTable) - # Names declared using "global" (separate set for each scope) - global_decls = Undefined(List[Set[str]]) - # Local names of function scopes; None for non-function scopes. - locals = Undefined(List[SymbolTable]) - # Nested block depths of scopes - block_depth = Undefined(List[int]) - # TypeInfo of directly enclosing class (or None) - type = Undefined(TypeInfo) - # Stack of outer classes (the second tuple item contains tvars). - type_stack = Undefined(List[Tuple[TypeInfo, List[SymbolTableNode]]]) - # Stack of functions being analyzed - function_stack = Undefined(List[FuncItem]) - - loop_depth = 0 # Depth of breakable loops - cur_mod_id = '' # Current module id (or None) (phase 2) - imports = Undefined(Set[str]) # Imported modules (during phase 2 analysis) - errors = Undefined(Errors) # Keep track of generated errors - - def __init__(self, lib_path: List[str], errors: Errors, - pyversion: int = 3) -> None: - """Construct semantic analyzer. - - Use lib_path to search for modules, and report analysis errors - using the Errors instance. - """ - self.locals = [None] - self.imports = set() - self.type = None - self.type_stack = [] - self.function_stack = [] - self.block_depth = [0] - self.loop_depth = 0 - self.lib_path = lib_path - self.errors = errors - self.modules = {} - self.pyversion = pyversion - self.stored_vars = Dict[Node, Type]() - - def visit_file(self, file_node: MypyFile, fnam: str) -> None: - self.errors.set_file(fnam) - self.globals = file_node.names - self.cur_mod_id = file_node.fullname() - - if 'builtins' in self.modules: - self.globals['__builtins__'] = SymbolTableNode( - MODULE_REF, self.modules['builtins'], self.cur_mod_id) - - defs = file_node.defs - for d in defs: - d.accept(self) - - if self.cur_mod_id == 'builtins': - remove_imported_names_from_symtable(self.globals, 'builtins') - - def visit_func_def(self, defn: FuncDef) -> None: - self.errors.push_function(defn.name()) - self.update_function_type_variables(defn) - self.errors.pop_function() - - if self.is_class_scope(): - # Method definition - defn.is_conditional = self.block_depth[-1] > 0 - defn.info = self.type - if not defn.is_decorated: - if not defn.is_overload: - if defn.name() in self.type.names: - n = self.type.names[defn.name()].node - if self.is_conditional_func(n, defn): - defn.original_def = cast(FuncDef, n) - else: - self.name_already_defined(defn.name(), defn) - self.type.names[defn.name()] = SymbolTableNode(MDEF, defn) - if not defn.is_static: - if not defn.args: - self.fail('Method must have at least one argument', defn) - elif defn.type: - sig = cast(FunctionLike, defn.type) - # TODO: A classmethod's first argument should be more - # precisely typed than Any. - leading_type = AnyType() if defn.is_class else self_type(self.type) - defn.type = replace_implicit_first_type(sig, leading_type) - - if self.is_func_scope() and (not defn.is_decorated and - not defn.is_overload): - self.add_local_func(defn, defn) - defn._fullname = defn.name() - - self.errors.push_function(defn.name()) - self.analyse_function(defn) - self.errors.pop_function() - - def is_conditional_func(self, n: Node, defn: FuncDef) -> bool: - return (isinstance(n, FuncDef) and cast(FuncDef, n).is_conditional and - defn.is_conditional) - - def update_function_type_variables(self, defn: FuncDef) -> None: - """Make any type variables in the signature of defn explicit. - - Update the signature of defn to contain type variable definitions - if defn is generic. - """ - if defn.type: - functype = cast(Callable, defn.type) - typevars = self.infer_type_variables(functype) - # Do not define a new type variable if already defined in scope. - typevars = [(tvar, values) for tvar, values in typevars - if not self.is_defined_type_var(tvar, defn)] - if typevars: - defs = [TypeVarDef(tvar[0], -i - 1, tvar[1]) - for i, tvar in enumerate(typevars)] - functype.variables = defs - - def infer_type_variables(self, - type: Callable) -> List[Tuple[str, List[Type]]]: - """Return list of unique type variables referred to in a callable.""" - names = List[str]() - values = List[List[Type]]() - for arg in type.arg_types + [type.ret_type]: - for tvar, vals in self.find_type_variables_in_type(arg): - if tvar not in names: - names.append(tvar) - values.append(vals) - return list(zip(names, values)) - - def find_type_variables_in_type( - self, type: Type) -> List[Tuple[str, List[Type]]]: - """Return a list of all unique type variable references in type.""" - result = List[Tuple[str, List[Type]]]() - if isinstance(type, UnboundType): - name = type.name - node = self.lookup_qualified(name, type) - if node and node.kind == UNBOUND_TVAR: - result.append((name, cast(TypeVarExpr, node.node).values)) - for arg in type.args: - result.extend(self.find_type_variables_in_type(arg)) - elif isinstance(type, TypeList): - for item in type.items: - result.extend(self.find_type_variables_in_type(item)) - elif isinstance(type, UnionType): - for item in type.items: - result.extend(self.find_type_variables_in_type(item)) - elif isinstance(type, AnyType): - pass - else: - assert False, 'Unsupported type %s' % type - return result - - def is_defined_type_var(self, tvar: str, context: Node) -> bool: - return self.lookup_qualified(tvar, context).kind == TVAR - - def visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None: - t = List[Callable]() - for item in defn.items: - # TODO support decorated overloaded functions properly - item.is_overload = True - item.func.is_overload = True - item.accept(self) - t.append(cast(Callable, function_type(item.func))) - if not [dec for dec in item.decorators - if refers_to_fullname(dec, 'typing.overload')]: - self.fail("'overload' decorator expected", item) - - defn.type = Overloaded(t) - defn.type.line = defn.line - - if self.is_class_scope(): - self.type.names[defn.name()] = SymbolTableNode(MDEF, defn, - typ=defn.type) - defn.info = self.type - elif self.is_func_scope(): - self.add_local_func(defn, defn) - - def analyse_function(self, defn: FuncItem) -> None: - is_method = self.is_class_scope() - tvarnodes = self.add_func_type_variables_to_symbol_table(defn) - if defn.type: - # Signature must be analyzed in the surrounding scope so that - # class-level imported names and type variables are in scope. - defn.type = self.anal_type(defn.type) - self.check_function_signature(defn) - if isinstance(defn, FuncDef): - defn.info = self.type - defn.type = set_callable_name(defn.type, defn) - self.function_stack.append(defn) - self.enter() - for init in defn.init: - if init: - init.rvalue.accept(self) - for v in defn.args: - self.add_local(v, defn) - for init_ in defn.init: - if init_: - init_.lvalues[0].accept(self) - - # The first argument of a non-static, non-class method is like 'self' - # (though the name could be different), having the enclosing class's - # instance type. - if is_method and not defn.is_static and not defn.is_class and defn.args: - defn.args[0].is_self = True - - defn.body.accept(self) - disable_typevars(tvarnodes) - self.leave() - self.function_stack.pop() - - def add_func_type_variables_to_symbol_table( - self, defn: FuncItem) -> List[SymbolTableNode]: - nodes = List[SymbolTableNode]() - if defn.type: - tt = defn.type - names = self.type_var_names() - items = cast(Callable, tt).variables - for i, item in enumerate(items): - name = item.name - if name in names: - self.name_already_defined(name, defn) - node = self.add_type_var(name, -i - 1, defn) - nodes.append(node) - names.add(name) - return nodes - - def type_var_names(self) -> Set[str]: - if not self.type: - return set() - else: - return set(self.type.type_vars) - - def add_type_var(self, fullname: str, id: int, - context: Context) -> SymbolTableNode: - node = self.lookup_qualified(fullname, context) - node.kind = TVAR - node.tvar_id = id - return node - - def check_function_signature(self, fdef: FuncItem) -> None: - sig = cast(Callable, fdef.type) - if len(sig.arg_types) < len(fdef.args): - self.fail('Type signature has too few arguments', fdef) - elif len(sig.arg_types) > len(fdef.args): - self.fail('Type signature has too many arguments', fdef) - - def visit_class_def(self, defn: ClassDef) -> None: - self.clean_up_bases_and_infer_type_variables(defn) - self.setup_class_def_analysis(defn) - self.analyze_base_classes(defn) - self.analyze_metaclass(defn) - - for decorator in defn.decorators: - self.analyze_class_decorator(defn, decorator) - - # Analyze class body. - defn.defs.accept(self) - - self.calculate_abstract_status(defn.info) - self.setup_ducktyping(defn) - - # Restore analyzer state. - self.block_depth.pop() - self.locals.pop() - self.type, tvarnodes = self.type_stack.pop() - disable_typevars(tvarnodes) - if self.type_stack: - # Enable type variables of the enclosing class again. - enable_typevars(self.type_stack[-1][1]) - - def analyze_class_decorator(self, defn: ClassDef, decorator: Node) -> None: - decorator.accept(self) - if refers_to_fullname(decorator, 'typing.builtinclass'): - defn.is_builtinclass = True - - def calculate_abstract_status(self, typ: TypeInfo) -> None: - """Calculate abstract status of a class. - - Set is_abstract of the type to True if the type has an unimplemented - abstract attribute. Also compute a list of abstract attributes. - """ - concrete = Set[str]() - abstract = List[str]() - for base in typ.mro: - for name, symnode in base.names.items(): - node = symnode.node - if isinstance(node, OverloadedFuncDef): - # Unwrap an overloaded function definition. We can just - # check arbitrarily the first overload item. If the - # different items have a different abstract status, there - # should be an error reported elsewhere. - func = node.items[0] # type: Node - else: - func = node - if isinstance(func, Decorator): - fdef = func.func - if fdef.is_abstract and name not in concrete: - typ.is_abstract = True - abstract.append(name) - concrete.add(name) - typ.abstract_attributes = sorted(abstract) - - def setup_ducktyping(self, defn: ClassDef) -> None: - for decorator in defn.decorators: - if isinstance(decorator, CallExpr): - analyzed = decorator.analyzed - if isinstance(analyzed, DucktypeExpr): - defn.info.ducktype = analyzed.type - elif isinstance(analyzed, DisjointclassExpr): - node = analyzed.cls.node - if isinstance(node, TypeInfo): - defn.info.disjoint_classes.append(node) - defn.info.disjointclass_decls.append(node) - node.disjoint_classes.append(defn.info) - else: - self.fail('Argument 1 to disjointclass does not refer ' - 'to a class', analyzed) - - def clean_up_bases_and_infer_type_variables(self, defn: ClassDef) -> None: - """Remove extra base classes such as Generic and infer type vars. - - For example, consider this class: - - . class Foo(Bar, Generic[t]): ... - - Now we will remove Generic[t] from bases of Foo and infer that the - type variable 't' is a type argument of Foo. - """ - removed = List[int]() - type_vars = List[TypeVarDef]() - for i, base in enumerate(defn.base_types): - tvars = self.analyze_typevar_declaration(base) - if tvars is not None: - if type_vars: - self.fail('Duplicate Generic or AbstractGeneric in bases', - defn) - removed.append(i) - for j, tvar in enumerate(tvars): - name, values = tvar - type_vars.append(TypeVarDef(name, j + 1, values)) - if type_vars: - defn.type_vars = type_vars - if defn.info: - defn.info.type_vars = [tv.name for tv in type_vars] - for i in reversed(removed): - del defn.base_types[i] - - def analyze_typevar_declaration(self, t: Type) -> List[Tuple[str, - List[Type]]]: - if not isinstance(t, UnboundType): - return None - unbound = cast(UnboundType, t) - sym = self.lookup_qualified(unbound.name, unbound) - if sym is None: - return None - if sym.node.fullname() in ('typing.Generic', - 'typing.AbstractGeneric'): - tvars = List[Tuple[str, List[Type]]]() - for arg in unbound.args: - tvar = self.analyze_unbound_tvar(arg) - if tvar: - tvars.append(tvar) - else: - self.fail('Free type variable expected in %s[...]' % - sym.node.name(), t) - return tvars - return None - - def analyze_unbound_tvar(self, t: Type) -> Tuple[str, List[Type]]: - if not isinstance(t, UnboundType): - return None - unbound = cast(UnboundType, t) - sym = self.lookup_qualified(unbound.name, unbound) - if sym is not None and sym.kind == UNBOUND_TVAR: - return unbound.name, cast(TypeVarExpr, sym.node).values[:] - return None - - def setup_class_def_analysis(self, defn: ClassDef) -> None: - """Prepare for the analysis of a class definition.""" - if not defn.info: - defn.info = TypeInfo(SymbolTable(), defn) - defn.info._fullname = defn.info.name() - if self.is_func_scope() or self.type: - kind = MDEF - if self.is_func_scope(): - kind = LDEF - self.add_symbol(defn.name, SymbolTableNode(kind, defn.info), defn) - if self.type_stack: - # Disable type variables of the enclosing class. - disable_typevars(self.type_stack[-1][1]) - tvarnodes = self.add_class_type_variables_to_symbol_table(defn.info) - # Remember previous active class and type vars of *this* class. - self.type_stack.append((self.type, tvarnodes)) - self.locals.append(None) # Add class scope - self.block_depth.append(-1) # The class body increments this to 0 - self.type = defn.info - - def analyze_base_classes(self, defn: ClassDef) -> None: - """Analyze and set up base classes.""" - bases = List[Instance]() - for i in range(len(defn.base_types)): - base = self.anal_type(defn.base_types[i]) - if isinstance(base, Instance): - defn.base_types[i] = base - bases.append(base) - # Add 'object' as implicit base if there is no other base class. - if (not bases and defn.fullname != 'builtins.object'): - obj = self.object_type() - defn.base_types.insert(0, obj) - bases.append(obj) - defn.info.bases = bases - if not self.verify_base_classes(defn): - return - try: - defn.info.calculate_mro() - except MroError: - self.fail("Cannot determine consistent method resolution order " - '(MRO) for "%s"' % defn.name, defn) - else: - # If there are cyclic imports, we may be missing 'object' in - # the MRO. Fix MRO if needed. - if defn.info.mro[-1].fullname() != 'builtins.object': - defn.info.mro.append(self.object_type().type) - - def verify_base_classes(self, defn: ClassDef) -> bool: - base_classes = List[str]() - info = defn.info - for base in info.bases: - baseinfo = base.type - if self.is_base_class(info, baseinfo): - self.fail('Cycle in inheritance hierarchy', defn) - # Clear bases to forcefully get rid of the cycle. - info.bases = [] - if baseinfo.fullname() == 'builtins.bool': - self.fail("'%s' is not a valid base class" % - baseinfo.name(), defn) - return False - dup = find_duplicate(info.direct_base_classes()) - if dup: - self.fail('Duplicate base class "%s"' % dup.name(), defn) - return False - return True - - def is_base_class(self, t: TypeInfo, s: TypeInfo) -> bool: - """Determine if t is a base class of s (but do not use mro).""" - # Search the base class graph for t, starting from s. - worklist = [s] - visited = {s} - while worklist: - nxt = worklist.pop() - if nxt == t: - return True - for base in nxt.bases: - if base.type not in visited: - worklist.append(base.type) - visited.add(base.type) - return False - - def analyze_metaclass(self, defn: ClassDef) -> None: - if defn.metaclass: - sym = self.lookup_qualified(defn.metaclass, defn) - if sym is not None and not isinstance(sym.node, TypeInfo): - self.fail("Invalid metaclass '%s'" % defn.metaclass, defn) - - def object_type(self) -> Instance: - return self.named_type('__builtins__.object') - - def named_type(self, qualified_name: str) -> Instance: - sym = self.lookup_qualified(qualified_name, None) - return Instance(cast(TypeInfo, sym.node), []) - - def is_instance_type(self, t: Type) -> bool: - return isinstance(t, Instance) - - def add_class_type_variables_to_symbol_table( - self, info: TypeInfo) -> List[SymbolTableNode]: - vars = info.type_vars - nodes = List[SymbolTableNode]() - if vars: - for i in range(len(vars)): - node = self.add_type_var(vars[i], i + 1, info) - nodes.append(node) - return nodes - - def visit_import(self, i: Import) -> None: - for id, as_id in i.ids: - if as_id != id: - self.add_module_symbol(id, as_id, i) - else: - base = id.split('.')[0] - self.add_module_symbol(base, base, i) - - def add_module_symbol(self, id: str, as_id: str, context: Context) -> None: - if id in self.modules: - m = self.modules[id] - self.add_symbol(as_id, SymbolTableNode(MODULE_REF, m, self.cur_mod_id), context) - else: - self.add_unknown_symbol(as_id, context) - - def visit_import_from(self, i: ImportFrom) -> None: - if i.id in self.modules: - m = self.modules[i.id] - for id, as_id in i.names: - node = m.names.get(id, None) - if node: - node = self.normalize_type_alias(node, i) - if not node: - return - self.add_symbol(as_id, SymbolTableNode(node.kind, node.node, - self.cur_mod_id), i) - else: - self.fail("Module has no attribute '{}'".format(id), i) - else: - for id, as_id in i.names: - self.add_unknown_symbol(as_id, i) - - def normalize_type_alias(self, node: SymbolTableNode, - ctx: Context) -> SymbolTableNode: - if node.fullname in type_aliases: - # Node refers to an aliased type such as typing.List; normalize. - node = self.lookup_qualified(type_aliases[node.fullname], ctx) - return node - - def visit_import_all(self, i: ImportAll) -> None: - if i.id in self.modules: - m = self.modules[i.id] - for name, node in m.names.items(): - node = self.normalize_type_alias(node, i) - if not name.startswith('_'): - self.add_symbol(name, SymbolTableNode(node.kind, node.node, - self.cur_mod_id), i) - else: - # Don't add any dummy symbols for 'from x import *' if 'x' is unknown. - pass - - def add_unknown_symbol(self, name: str, context: Context) -> None: - var = Var(name) - var._fullname = self.qualified_name(name) - var.is_ready = True - var.type = AnyType() - self.add_symbol(name, SymbolTableNode(GDEF, var, self.cur_mod_id), context) - - # - # Statements - # - - def visit_block(self, b: Block) -> None: - if b.is_unreachable: - return - self.block_depth[-1] += 1 - for s in b.body: - s.accept(self) - self.block_depth[-1] -= 1 - - def visit_block_maybe(self, b: Block) -> None: - if b: - self.visit_block(b) - - def visit_var_def(self, defn: VarDef) -> None: - for i in range(len(defn.items)): - defn.items[i].type = self.anal_type(defn.items[i].type) - - for v in defn.items: - if self.is_func_scope(): - defn.kind = LDEF - self.add_local(v, defn) - elif self.type: - v.info = self.type - v.is_initialized_in_class = defn.init is not None - self.type.names[v.name()] = SymbolTableNode(MDEF, v, - typ=v.type) - elif v.name not in self.globals: - defn.kind = GDEF - self.add_var(v, defn) - - if defn.init: - defn.init.accept(self) - - def anal_type(self, t: Type) -> Type: - if t: - a = TypeAnalyser(self.lookup_qualified, self.stored_vars, self.fail) - return t.accept(a) - else: - return None - - def visit_assignment_stmt(self, s: AssignmentStmt) -> None: - for lval in s.lvalues: - self.analyse_lvalue(lval, explicit_type=s.type is not None) - s.rvalue.accept(self) - if s.type: - s.type = self.anal_type(s.type) - else: - s.type = self.infer_type_from_undefined(s.rvalue) - # For simple assignments, allow binding type aliases - if (s.type is None and len(s.lvalues) == 1 and - isinstance(s.lvalues[0], NameExpr)): - res = analyse_node(self.lookup_qualified, s.rvalue, s) - if res: - # XXX Need to remove this later if reassigned - x = cast(NameExpr, s.lvalues[0]) - self.stored_vars[x.node] = res - - if s.type: - # Store type into nodes. - for lvalue in s.lvalues: - self.store_declared_types(lvalue, s.type) - self.check_and_set_up_type_alias(s) - self.process_typevar_declaration(s) - - def check_and_set_up_type_alias(self, s: AssignmentStmt) -> None: - """Check if assignment creates a type alias and set it up as needed.""" - # For now, type aliases only work at the top level of a module. - if (len(s.lvalues) == 1 and not self.is_func_scope() and not self.type - and not s.type): - lvalue = s.lvalues[0] - if isinstance(lvalue, NameExpr): - if not lvalue.is_def: - # Only a definition can create a type alias, not regular assignment. - return - rvalue = s.rvalue - if isinstance(rvalue, RefExpr): - node = rvalue.node - if isinstance(node, TypeInfo): - # TODO: We should record the fact that this is a variable - # that refers to a type, rather than making this - # just an alias for the type. - self.globals[lvalue.name].node = node - - def analyse_lvalue(self, lval: Node, nested: bool = False, - add_global: bool = False, - explicit_type: bool = False) -> None: - """Analyze an lvalue or assignment target. - - Only if add_global is True, add name to globals table. If nested - is true, the lvalue is within a tuple or list lvalue expression. - """ - if isinstance(lval, NameExpr): - nested_global = (not self.is_func_scope() and - self.block_depth[-1] > 0 and - not self.type) - if (add_global or nested_global) and lval.name not in self.globals: - # Define new global name. - v = Var(lval.name) - v._fullname = self.qualified_name(lval.name) - v.is_ready = False # Type not inferred yet - lval.node = v - lval.is_def = True - lval.kind = GDEF - lval.fullname = v._fullname - self.globals[lval.name] = SymbolTableNode(GDEF, v, - self.cur_mod_id) - elif isinstance(lval.node, Var) and lval.is_def: - # Since the is_def flag is set, this must have been analyzed - # already in the first pass and added to the symbol table. - v = cast(Var, lval.node) - assert v.name() in self.globals - elif (self.is_func_scope() and lval.name not in self.locals[-1] and - lval.name not in self.global_decls[-1]): - # Define new local name. - v = Var(lval.name) - lval.node = v - lval.is_def = True - lval.kind = LDEF - lval.fullname = lval.name - self.add_local(v, lval) - elif not self.is_func_scope() and (self.type and - lval.name not in self.type.names): - # Define a new attribute within class body. - v = Var(lval.name) - v.info = self.type - v.is_initialized_in_class = True - lval.node = v - lval.is_def = True - lval.kind = MDEF - lval.fullname = lval.name - self.type.names[lval.name] = SymbolTableNode(MDEF, v) - else: - # Bind to an existing name. - if explicit_type: - self.name_already_defined(lval.name, lval) - lval.accept(self) - self.check_lvalue_validity(lval.node, lval) - elif isinstance(lval, MemberExpr): - if not add_global: - self.analyse_member_lvalue(lval) - if explicit_type and not self.is_self_member_ref(lval): - self.fail('Type cannot be declared in assignment to non-self ' - 'attribute', lval) - elif isinstance(lval, IndexExpr): - if explicit_type: - self.fail('Unexpected type declaration', lval) - if not add_global: - lval.accept(self) - elif isinstance(lval, ParenExpr): - self.analyse_lvalue(lval.expr, nested, add_global, explicit_type) - elif (isinstance(lval, TupleExpr) or - isinstance(lval, ListExpr)) and not nested: - items = (Any(lval)).items - for i in items: - self.analyse_lvalue(i, nested=True, add_global=add_global, - explicit_type = explicit_type) - else: - self.fail('Invalid assignment target', lval) - - def analyse_member_lvalue(self, lval: MemberExpr) -> None: - lval.accept(self) - if (self.is_self_member_ref(lval) and - self.type.get(lval.name) is None): - # Implicit attribute definition in __init__. - lval.is_def = True - v = Var(lval.name) - v.info = self.type - v.is_ready = False - lval.def_var = v - lval.node = v - self.type.names[lval.name] = SymbolTableNode(MDEF, v) - self.check_lvalue_validity(lval.node, lval) - - def is_self_member_ref(self, memberexpr: MemberExpr) -> bool: - """Does memberexpr to refer to an attribute of self?""" - if not isinstance(memberexpr.expr, NameExpr): - return False - node = (cast(NameExpr, memberexpr.expr)).node - return isinstance(node, Var) and (cast(Var, node)).is_self - - def check_lvalue_validity(self, node: Node, ctx: Context) -> None: - if isinstance(node, (FuncDef, TypeInfo, TypeVarExpr)): - self.fail('Invalid assignment target', ctx) - - def infer_type_from_undefined(self, rvalue: Node) -> Type: - if isinstance(rvalue, CallExpr): - if isinstance(rvalue.analyzed, UndefinedExpr): - undef = cast(UndefinedExpr, rvalue.analyzed) - return undef.type - return None - - def store_declared_types(self, lvalue: Node, typ: Type) -> None: - if isinstance(lvalue, RefExpr): - lvalue.is_def = False - if isinstance(lvalue.node, Var): - var = cast(Var, lvalue.node) - var.type = typ - var.is_ready = True - # If node is not a variable, we'll catch it elsewhere. - elif isinstance(lvalue, TupleExpr): - if isinstance(typ, TupleType): - if len(lvalue.items) != len(typ.items): - self.fail('Incompatible number of tuple items', lvalue) - return - for item, itemtype in zip(lvalue.items, typ.items): - self.store_declared_types(item, itemtype) - else: - self.fail('Tuple type expected for multiple variables', - lvalue) - elif isinstance(lvalue, ParenExpr): - self.store_declared_types(lvalue.expr, typ) - else: - raise RuntimeError('Internal error (%s)' % type(lvalue)) - - def process_typevar_declaration(self, s: AssignmentStmt) -> None: - """Check if s declares a typevar; it yes, store it in symbol table.""" - if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): - return - if not isinstance(s.rvalue, CallExpr): - return - call = cast(CallExpr, s.rvalue) - if not isinstance(call.callee, RefExpr): - return - callee = cast(RefExpr, call.callee) - if callee.fullname != 'typing.typevar': - return - # TODO Share code with check_argument_count in checkexpr.py? - if len(call.args) < 1: - self.fail("Too few arguments for typevar()", s) - return - if len(call.args) > 2: - self.fail("Too many arguments for typevar()", s) - return - if call.arg_kinds not in ([ARG_POS], [ARG_POS, ARG_NAMED]): - self.fail("Unexpected arguments to typevar()", s) - return - if not isinstance(call.args[0], StrExpr): - self.fail("typevar() expects a string literal argument", s) - return - lvalue = cast(NameExpr, s.lvalues[0]) - name = lvalue.name - if cast(StrExpr, call.args[0]).value != name: - self.fail("Unexpected typevar() argument value", s) - return - if not lvalue.is_def: - if s.type: - self.fail("Cannot declare the type of a type variable", s) - else: - self.fail("Cannot redefine '%s' as a type variable" % name, s) - return - if len(call.args) == 2: - # Analyze values=(...) argument. - if call.arg_names[1] != 'values': - self.fail("Unexpected keyword argument '{}' to typevar()". - format(call.arg_names[1]), s) - return - if isinstance(call.args[1], ParenExpr): - expr = cast(ParenExpr, call.args[1]).expr - if isinstance(expr, TupleExpr): - values = self.analyze_types(expr.items) - else: - self.fail('The values argument must be a tuple literal', s) - return - else: - self.fail('The values argument must be in parentheses (...)', - s) - return - else: - values = [] - # Yes, it's a valid type variable definition! - node = self.lookup(name, s) - node.kind = UNBOUND_TVAR - typevar = TypeVarExpr(name, node.fullname, values) - typevar.line = call.line - call.analyzed = typevar - node.node = typevar - - def analyze_types(self, items: List[Node]) -> List[Type]: - result = List[Type]() - for node in items: - try: - result.append(self.anal_type(expr_to_unanalyzed_type(node))) - except TypeTranslationError: - self.fail('Type expected', node) - result.append(AnyType()) - return result - - def visit_decorator(self, dec: Decorator) -> None: - if not dec.is_overload: - if self.is_func_scope(): - self.add_symbol(dec.var.name(), SymbolTableNode(LDEF, dec), - dec) - elif self.type: - dec.var.info = self.type - dec.var.is_initialized_in_class = True - self.add_symbol(dec.var.name(), SymbolTableNode(MDEF, dec), - dec) - for d in dec.decorators: - d.accept(self) - removed = List[int]() - for i, d in enumerate(dec.decorators): - if refers_to_fullname(d, 'abc.abstractmethod'): - removed.append(i) - dec.func.is_abstract = True - self.check_decorated_function_is_method('abstractmethod', dec) - elif refers_to_fullname(d, 'asyncio.tasks.coroutine'): - removed.append(i) - dec.func.is_coroutine = True - elif refers_to_fullname(d, 'builtins.staticmethod'): - removed.append(i) - dec.func.is_static = True - dec.var.is_staticmethod = True - self.check_decorated_function_is_method('staticmethod', dec) - elif refers_to_fullname(d, 'builtins.classmethod'): - removed.append(i) - dec.func.is_class = True - dec.var.is_classmethod = True - self.check_decorated_function_is_method('classmethod', dec) - elif refers_to_fullname(d, 'builtins.property'): - removed.append(i) - dec.func.is_property = True - dec.var.is_property = True - if dec.is_overload: - self.fail('A property cannot be overloaded', dec) - self.check_decorated_function_is_method('property', dec) - if len(dec.func.args) > 1: - self.fail('Too many arguments', dec.func) - for i in reversed(removed): - del dec.decorators[i] - dec.func.accept(self) - if not dec.decorators and not dec.var.is_property: - # No non-special decorators left. We can trivially infer the type - # of the function here. - dec.var.type = dec.func.type - - def check_decorated_function_is_method(self, decorator: str, - context: Context) -> None: - if not self.type or self.is_func_scope(): - self.fail("'%s' used with a non-method" % decorator, context) - - def visit_expression_stmt(self, s: ExpressionStmt) -> None: - s.expr.accept(self) - - def visit_return_stmt(self, s: ReturnStmt) -> None: - if not self.is_func_scope(): - self.fail("'return' outside function", s) - if s.expr: - s.expr.accept(self) - - def visit_raise_stmt(self, s: RaiseStmt) -> None: - if s.expr: - s.expr.accept(self) - - def visit_yield_stmt(self, s: YieldStmt) -> None: - if not self.is_func_scope(): - self.fail("'yield' outside function", s) - else: - self.function_stack[-1].is_generator = True - if s.expr: - s.expr.accept(self) - -<<<<<<< HEAD - def visit_yield_from_stmt(self, s: YieldFromStmt) -> None: - if not self.is_func_scope(): - self.fail("'yield from' outside function", s) - if s.expr: - s.expr.accept(self) - -======= ->>>>>>> master - def visit_assert_stmt(self, s: AssertStmt) -> None: - if s.expr: - s.expr.accept(self) - - def visit_operator_assignment_stmt(self, - s: OperatorAssignmentStmt) -> None: - s.lvalue.accept(self) - s.rvalue.accept(self) - - def visit_while_stmt(self, s: WhileStmt) -> None: - s.expr.accept(self) - self.loop_depth += 1 - s.body.accept(self) - self.loop_depth -= 1 - self.visit_block_maybe(s.else_body) - - def visit_for_stmt(self, s: ForStmt) -> None: - s.expr.accept(self) - - # Bind index variables and check if they define new names. - for n in s.index: - self.analyse_lvalue(n) - - # Analyze index variable types. - for i in range(len(s.types)): - t = s.types[i] - if t: - s.types[i] = self.anal_type(t) - v = cast(Var, s.index[i].node) - # TODO check if redefinition - v.type = s.types[i] - - # Report error if only some of the loop variables have annotations. - if s.types != [None] * len(s.types) and None in s.types: - self.fail('Cannot mix unannotated and annotated loop variables', s) - - self.loop_depth += 1 - self.visit_block(s.body) - self.loop_depth -= 1 - - self.visit_block_maybe(s.else_body) - - def visit_break_stmt(self, s: BreakStmt) -> None: - if self.loop_depth == 0: - self.fail("'break' outside loop", s) - - def visit_continue_stmt(self, s: ContinueStmt) -> None: - if self.loop_depth == 0: - self.fail("'continue' outside loop", s) - - def visit_if_stmt(self, s: IfStmt) -> None: - infer_reachability_of_if_statement(s, pyversion=self.pyversion) - for i in range(len(s.expr)): - s.expr[i].accept(self) - self.visit_block(s.body[i]) - self.visit_block_maybe(s.else_body) - - def visit_try_stmt(self, s: TryStmt) -> None: - self.analyze_try_stmt(s, self) - - def analyze_try_stmt(self, s: TryStmt, visitor: NodeVisitor, - add_global: bool = False) -> None: - s.body.accept(visitor) - for type, var, handler in zip(s.types, s.vars, s.handlers): - if type: - type.accept(visitor) - if var: - self.analyse_lvalue(var, add_global=add_global) - handler.accept(visitor) - if s.else_body: - s.else_body.accept(visitor) - if s.finally_body: - s.finally_body.accept(visitor) - - def visit_with_stmt(self, s: WithStmt) -> None: - for e in s.expr: - e.accept(self) - for n in s.name: - if n: - self.analyse_lvalue(n) - self.visit_block(s.body) - - def visit_del_stmt(self, s: DelStmt) -> None: - s.expr.accept(self) - if not isinstance(s.expr, (IndexExpr, NameExpr, MemberExpr)): - self.fail('Invalid delete target', s) - - def visit_global_decl(self, g: GlobalDecl) -> None: - for n in g.names: - self.global_decls[-1].add(n) - - def visit_print_stmt(self, s: PrintStmt) -> None: - for arg in s.args: - arg.accept(self) - - # - # Expressions - # - - def visit_name_expr(self, expr: NameExpr) -> None: - n = self.lookup(expr.name, expr) - if n: - if n.kind == TVAR: - self.fail("'{}' is a type variable and only valid in type " - "context".format(expr.name), expr) - else: - expr.kind = n.kind - expr.node = (cast(Node, n.node)) - expr.fullname = n.fullname - - def visit_super_expr(self, expr: SuperExpr) -> None: - if not self.type: - self.fail('"super" used outside class', expr) - return - expr.info = self.type - - def visit_tuple_expr(self, expr: TupleExpr) -> None: - for item in expr.items: - item.accept(self) - - def visit_list_expr(self, expr: ListExpr) -> None: - for item in expr.items: - item.accept(self) - - def visit_set_expr(self, expr: SetExpr) -> None: - for item in expr.items: - item.accept(self) - - def visit_dict_expr(self, expr: DictExpr) -> None: - for key, value in expr.items: - key.accept(self) - value.accept(self) - - def visit_paren_expr(self, expr: ParenExpr) -> None: - expr.expr.accept(self) - -<<<<<<< HEAD - def visit_yield_from_expr(self, e: YieldFromExpr) -> None: - if not self.is_func_scope(): # not sure - self.fail("'yield from' outside function", s) - if e.expr: - e.expr.accept(self) - -======= ->>>>>>> master - def visit_call_expr(self, expr: CallExpr) -> None: - """Analyze a call expression. - - Some call expressions are recognized as special forms, including - cast(...), Undefined(...) and Any(...). - """ - expr.callee.accept(self) - if refers_to_fullname(expr.callee, 'typing.cast'): - # Special form cast(...). - if not self.check_fixed_args(expr, 2, 'cast'): - return - # Translate first argument to an unanalyzed type. - try: - target = expr_to_unanalyzed_type(expr.args[0]) - except TypeTranslationError: - self.fail('Cast target is not a type', expr) - return - # Pigguback CastExpr object to the CallExpr object; it takes - # precedence over the CallExpr semantics. - expr.analyzed = CastExpr(expr.args[1], target) - expr.analyzed.line = expr.line - expr.analyzed.accept(self) - elif refers_to_fullname(expr.callee, 'typing.Any'): - # Special form Any(...). - if not self.check_fixed_args(expr, 1, 'Any'): - return - expr.analyzed = CastExpr(expr.args[0], AnyType()) - expr.analyzed.line = expr.line - expr.analyzed.accept(self) - elif refers_to_fullname(expr.callee, 'typing.Undefined'): - # Special form Undefined(...). - if not self.check_fixed_args(expr, 1, 'Undefined'): - return - try: - type = expr_to_unanalyzed_type(expr.args[0]) - except TypeTranslationError: - self.fail('Argument to Undefined is not a type', expr) - return - expr.analyzed = UndefinedExpr(type) - expr.analyzed.line = expr.line - expr.analyzed.accept(self) - elif refers_to_fullname(expr.callee, 'typing.ducktype'): - # Special form ducktype(...). - if not self.check_fixed_args(expr, 1, 'ducktype'): - return - # Translate first argument to an unanalyzed type. - try: - target = expr_to_unanalyzed_type(expr.args[0]) - except TypeTranslationError: - self.fail('Argument 1 to ducktype is not a type', expr) - return - expr.analyzed = DucktypeExpr(target) - expr.analyzed.line = expr.line - expr.analyzed.accept(self) - elif refers_to_fullname(expr.callee, 'typing.disjointclass'): - # Special form disjointclass(...). - if not self.check_fixed_args(expr, 1, 'disjointclass'): - return - arg = expr.args[0] - if isinstance(arg, RefExpr): - expr.analyzed = DisjointclassExpr(arg) - expr.analyzed.line = expr.line - expr.analyzed.accept(self) - else: - self.fail('Argument 1 to disjointclass is not a class', expr) - else: - # Normal call expression. - for a in expr.args: - a.accept(self) - - def check_fixed_args(self, expr: CallExpr, numargs: int, - name: str) -> bool: - """Verify that expr has specified number of positional args. - - Return True if the arguments are valid. - """ - s = 's' - if numargs == 1: - s = '' - if len(expr.args) != numargs: - self.fail("'%s' expects %d argument%s" % (name, numargs, s), - expr) - return False - if expr.arg_kinds != [ARG_POS] * numargs: - self.fail("'%s' must be called with %s positional argument%s" % - (name, numargs, s), expr) - return False - return True - - def visit_member_expr(self, expr: MemberExpr) -> None: - base = expr.expr - base.accept(self) - # Bind references to module attributes. - if isinstance(base, RefExpr) and cast(RefExpr, - base).kind == MODULE_REF: - names = (cast(MypyFile, (cast(RefExpr, base)).node)).names - n = names.get(expr.name, None) - if n: - n = self.normalize_type_alias(n, expr) - if not n: - return - expr.kind = n.kind - expr.fullname = n.fullname - expr.node = n.node - - def visit_op_expr(self, expr: OpExpr) -> None: - expr.left.accept(self) - expr.right.accept(self) - -<<<<<<< HEAD -======= - def visit_comparison_expr(self, expr: ComparisonExpr) -> None: - for operand in expr.operands: - operand.accept(self) - ->>>>>>> master - def visit_unary_expr(self, expr: UnaryExpr) -> None: - expr.expr.accept(self) - - def visit_index_expr(self, expr: IndexExpr) -> None: - expr.base.accept(self) - if refers_to_class_or_function(expr.base): - # Special form -- type application. - # Translate index to an unanalyzed type. - types = List[Type]() - if isinstance(expr.index, TupleExpr): - items = (cast(TupleExpr, expr.index)).items - else: - items = [expr.index] - for item in items: - try: - typearg = expr_to_unanalyzed_type(item) - except TypeTranslationError: - self.fail('Type expected within [...]', expr) - return - typearg = self.anal_type(typearg) - types.append(typearg) - expr.analyzed = TypeApplication(expr.base, types) - expr.analyzed.line = expr.line - else: - expr.index.accept(self) - - def visit_slice_expr(self, expr: SliceExpr) -> None: - if expr.begin_index: - expr.begin_index.accept(self) - if expr.end_index: - expr.end_index.accept(self) - if expr.stride: - expr.stride.accept(self) - - def visit_cast_expr(self, expr: CastExpr) -> None: - expr.expr.accept(self) - expr.type = self.anal_type(expr.type) - - def visit_undefined_expr(self, expr: UndefinedExpr) -> None: - expr.type = self.anal_type(expr.type) - - def visit_type_application(self, expr: TypeApplication) -> None: - expr.expr.accept(self) - for i in range(len(expr.types)): - expr.types[i] = self.anal_type(expr.types[i]) - - def visit_list_comprehension(self, expr: ListComprehension) -> None: - expr.generator.accept(self) - - def visit_generator_expr(self, expr: GeneratorExpr) -> None: - self.enter() - - for index, sequence, conditions in zip(expr.indices, expr.sequences, - expr.condlists): - sequence.accept(self) - # Bind index variables. - for n in index: - self.analyse_lvalue(n) - for cond in conditions: - cond.accept(self) - - # TODO analyze variable types (see visit_for_stmt) - - expr.left_expr.accept(self) - self.leave() - - def visit_func_expr(self, expr: FuncExpr) -> None: - self.analyse_function(expr) - - def visit_conditional_expr(self, expr: ConditionalExpr) -> None: - expr.if_expr.accept(self) - expr.cond.accept(self) - expr.else_expr.accept(self) - - def visit_ducktype_expr(self, expr: DucktypeExpr) -> None: - expr.type = self.anal_type(expr.type) - - def visit_disjointclass_expr(self, expr: DisjointclassExpr) -> None: - expr.cls.accept(self) - - # - # Helpers - # - - def lookup(self, name: str, ctx: Context) -> SymbolTableNode: - """Look up an unqualified name in all active namespaces.""" - # 1. Name declared using 'global x' takes precedence - if name in self.global_decls[-1]: - if name in self.globals: - return self.globals[name] - else: - self.name_not_defined(name, ctx) - return None - # 2. Class attributes (if within class definition) - if self.is_class_scope() and name in self.type.names: - return self.type[name] - # 3. Local (function) scopes - for table in reversed(self.locals): - if table is not None and name in table: - return table[name] - # 4. Current file global scope - if name in self.globals: - return self.globals[name] - # 5. Builtins - b = self.globals.get('__builtins__', None) - if b: - table = cast(MypyFile, b.node).names - if name in table: - if name[0] == "_" and name[1] != "_": - self.name_not_defined(name, ctx) - return None - node = table[name] - # Only succeed if we are not using a type alias such List -- these must be - # be accessed via the typing module. - if node.node.name() == name: - return node - # Give up. - self.name_not_defined(name, ctx) - return None - - def lookup_qualified(self, name: str, ctx: Context) -> SymbolTableNode: - if '.' not in name: - return self.lookup(name, ctx) - else: - parts = name.split('.') - n = self.lookup(parts[0], ctx) # type: SymbolTableNode - if n: - for i in range(1, len(parts)): - if isinstance(n.node, TypeInfo): - n = (cast(TypeInfo, n.node)).get(parts[i]) - elif isinstance(n.node, MypyFile): - n = (cast(MypyFile, n.node)).names.get(parts[i], None) - if not n: - self.name_not_defined(name, ctx) - if n: - n = self.normalize_type_alias(n, ctx) - return n - - def qualified_name(self, n: str) -> str: - return self.cur_mod_id + '.' + n - - def enter(self) -> None: - self.locals.append(SymbolTable()) - self.global_decls.append(set()) - - def leave(self) -> None: - self.locals.pop() - self.global_decls.pop() - - def is_func_scope(self) -> bool: - return self.locals[-1] is not None - - def is_class_scope(self) -> bool: - return self.type is not None and not self.is_func_scope() - - def add_symbol(self, name: str, node: SymbolTableNode, - context: Context) -> None: - if self.is_func_scope(): - if name in self.locals[-1]: - # Flag redefinition unless this is a reimport of a module. - if not (node.kind == MODULE_REF and - self.locals[-1][name].node == node.node): - self.name_already_defined(name, context) - self.locals[-1][name] = node - elif self.type: - self.type.names[name] = node - else: - if name in self.globals and (not isinstance(node.node, MypyFile) or - self.globals[name].node != node.node): - # Modules can be imported multiple times to support import - # of multiple submodules of a package (e.g. a.x and a.y). - self.name_already_defined(name, context) - self.globals[name] = node - - def add_var(self, v: Var, ctx: Context) -> None: - if self.is_func_scope(): - self.add_local(v, ctx) - else: - self.globals[v.name()] = SymbolTableNode(GDEF, v, self.cur_mod_id) - v._fullname = self.qualified_name(v.name()) - - def add_local(self, v: Var, ctx: Context) -> None: - if v.name() in self.locals[-1]: - self.name_already_defined(v.name(), ctx) - v._fullname = v.name() - self.locals[-1][v.name()] = SymbolTableNode(LDEF, v) - - def add_local_func(self, defn: FuncBase, ctx: Context) -> None: - # TODO combine with above - if defn.name() in self.locals[-1]: - self.name_already_defined(defn.name(), ctx) - self.locals[-1][defn.name()] = SymbolTableNode(LDEF, defn) - - def check_no_global(self, n: str, ctx: Context, - is_func: bool = False) -> None: - if n in self.globals: - if is_func and isinstance(self.globals[n].node, FuncDef): - self.fail(("Name '{}' already defined (overload variants " - "must be next to each other)").format(n), ctx) - else: - self.name_already_defined(n, ctx) - - def name_not_defined(self, name: str, ctx: Context) -> None: - self.fail("Name '{}' is not defined".format(name), ctx) - - def name_already_defined(self, name: str, ctx: Context) -> None: - self.fail("Name '{}' already defined".format(name), ctx) - - def fail(self, msg: str, ctx: Context) -> None: - self.errors.report(ctx.get_line(), msg) - - -class FirstPass(NodeVisitor): - """First phase of semantic analysis""" - - def __init__(self, sem: SemanticAnalyzer) -> None: - self.sem = sem - self.pyversion = sem.pyversion - - def analyze(self, file: MypyFile, fnam: str, mod_id: str) -> None: - """Perform the first analysis pass. - - Resolve the full names of definitions not nested within functions and - construct type info structures, but do not resolve inter-definition - references such as base classes. - - Also add implicit definitions such as __name__. - """ - sem = self.sem - sem.cur_mod_id = mod_id - sem.errors.set_file(fnam) - sem.globals = SymbolTable() - sem.global_decls = [set()] - sem.block_depth = [0] - - defs = file.defs - - # Add implicit definitions of module '__name__' etc. - for n in implicit_module_attrs: - name_def = VarDef([Var(n, AnyType())], True) - defs.insert(0, name_def) - - for d in defs: - d.accept(self) - - # Add implicit definition of 'None' to builtins, as we cannot define a - # variable with a None type explicitly. - if mod_id == 'builtins': - none_def = VarDef([Var('None', NoneTyp())], True) - defs.append(none_def) - none_def.accept(self) - - def visit_block(self, b: Block) -> None: - if b.is_unreachable: - return - self.sem.block_depth[-1] += 1 - for node in b.body: - node.accept(self) - self.sem.block_depth[-1] -= 1 - - def visit_assignment_stmt(self, s: AssignmentStmt) -> None: - for lval in s.lvalues: - self.sem.analyse_lvalue(lval, add_global=True, - explicit_type=s.type is not None) - - def visit_func_def(self, d: FuncDef) -> None: - sem = self.sem - d.is_conditional = sem.block_depth[-1] > 0 - if d.name() in sem.globals: - n = sem.globals[d.name()].node - if sem.is_conditional_func(n, d): - # Conditional function definition -- multiple defs are ok. - d.original_def = cast(FuncDef, n) - else: - sem.check_no_global(d.name(), d, True) - d._fullname = sem.qualified_name(d.name()) - sem.globals[d.name()] = SymbolTableNode(GDEF, d, sem.cur_mod_id) - - def visit_overloaded_func_def(self, d: OverloadedFuncDef) -> None: - self.sem.check_no_global(d.name(), d) - d._fullname = self.sem.qualified_name(d.name()) - self.sem.globals[d.name()] = SymbolTableNode(GDEF, d, - self.sem.cur_mod_id) - - def visit_class_def(self, d: ClassDef) -> None: - self.sem.check_no_global(d.name, d) - d.fullname = self.sem.qualified_name(d.name) - info = TypeInfo(SymbolTable(), d) - info.set_line(d.line) - d.info = info - self.sem.globals[d.name] = SymbolTableNode(GDEF, info, - self.sem.cur_mod_id) - - def visit_var_def(self, d: VarDef) -> None: - for v in d.items: - self.sem.check_no_global(v.name(), d) - v._fullname = self.sem.qualified_name(v.name()) - self.sem.globals[v.name()] = SymbolTableNode(GDEF, v, - self.sem.cur_mod_id) - - def visit_for_stmt(self, s: ForStmt) -> None: - for n in s.index: - self.sem.analyse_lvalue(n, add_global=True) - - def visit_with_stmt(self, s: WithStmt) -> None: - for n in s.name: - if n: - self.sem.analyse_lvalue(n, add_global=True) - - def visit_decorator(self, d: Decorator) -> None: - d.var._fullname = self.sem.qualified_name(d.var.name()) - self.sem.add_symbol(d.var.name(), SymbolTableNode(GDEF, d.var), d) - - def visit_if_stmt(self, s: IfStmt) -> None: - infer_reachability_of_if_statement(s, pyversion=self.pyversion) - for node in s.body: - node.accept(self) - if s.else_body: - s.else_body.accept(self) - - def visit_try_stmt(self, s: TryStmt) -> None: - self.sem.analyze_try_stmt(s, self, add_global=True) - - -class ThirdPass(TraverserVisitor[None]): - """The third and final pass of semantic analysis. - - Check type argument counts and values of generic types. Also update - TypeInfo disjointclass information. - """ - - def __init__(self, errors: Errors) -> None: - self.errors = errors - - def visit_file(self, file_node: MypyFile, fnam: str) -> None: - self.errors.set_file(fnam) - file_node.accept(self) - - def visit_func_def(self, fdef: FuncDef) -> None: - self.errors.push_function(fdef.name()) - self.analyze(fdef.type) - super().visit_func_def(fdef) - self.errors.pop_function() - - def visit_class_def(self, tdef: ClassDef) -> None: - for type in tdef.info.bases: - self.analyze(type) - info = tdef.info - # Collect declared disjoint classes from all base classes. - for base in info.mro: - for disjoint in base.disjoint_classes: - if disjoint not in info.disjoint_classes: - info.disjoint_classes.append(disjoint) - for subtype in disjoint.all_subtypes(): - if info not in subtype.disjoint_classes: - subtype.disjoint_classes.append(info) - super().visit_class_def(tdef) - - def visit_assignment_stmt(self, s: AssignmentStmt) -> None: - self.analyze(s.type) - super().visit_assignment_stmt(s) - - def visit_undefined_expr(self, e: UndefinedExpr) -> None: - self.analyze(e.type) - - def visit_cast_expr(self, e: CastExpr) -> None: - self.analyze(e.type) - super().visit_cast_expr(e) - - def visit_type_application(self, e: TypeApplication) -> None: - for type in e.types: - self.analyze(type) - super().visit_type_application(e) - - def analyze(self, type: Type) -> None: - if type: - analyzer = TypeAnalyserPass3(self.fail) - type.accept(analyzer) - - def fail(self, msg: str, ctx: Context) -> None: - self.errors.report(ctx.get_line(), msg) - - -def self_type(typ: TypeInfo) -> Instance: - """For a non-generic type, return instance type representing the type. - For a generic G type with parameters T1, .., Tn, return G[T1, ..., Tn]. - """ - tv = List[Type]() - for i in range(len(typ.type_vars)): - tv.append(TypeVar(typ.type_vars[i], i + 1, - typ.defn.type_vars[i].values)) - return Instance(typ, tv) - - -@overload -def replace_implicit_first_type(sig: Callable, new: Type) -> Callable: - # We can detect implicit self type by it having no representation. - if not sig.arg_types[0].repr: - return replace_leading_arg_type(sig, new) - else: - return sig - - -@overload -def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike: - osig = cast(Overloaded, sig) - return Overloaded([replace_implicit_first_type(i, new) - for i in osig.items()]) - - -def set_callable_name(sig: Type, fdef: FuncDef) -> Type: - if isinstance(sig, FunctionLike): - if fdef.info: - return sig.with_name( - '"{}" of "{}"'.format(fdef.name(), fdef.info.name())) - else: - return sig.with_name('"{}"'.format(fdef.name())) - else: - return sig - - -def refers_to_fullname(node: Node, fullname: str) -> bool: - """Is node a name or member expression with the given full name?""" - return isinstance(node, - RefExpr) and cast(RefExpr, node).fullname == fullname - - -def refers_to_class_or_function(node: Node) -> bool: - """Does semantically analyzed node refer to a class?""" - return (isinstance(node, RefExpr) and - isinstance(cast(RefExpr, node).node, (TypeInfo, FuncDef, - OverloadedFuncDef))) - - -def expr_to_unanalyzed_type(expr: Node) -> Type: - """Translate an expression to the corresonding type. - - The result is not semantically analyzed. It can be UnboundType or ListType. - Raise TypeTranslationError if the expression cannot represent a type. - """ - if isinstance(expr, NameExpr): - name = expr.name - return UnboundType(name, line=expr.line) - elif isinstance(expr, MemberExpr): - fullname = get_member_expr_fullname(expr) - if fullname: - return UnboundType(fullname, line=expr.line) - else: - raise TypeTranslationError() - elif isinstance(expr, IndexExpr): - base = expr_to_unanalyzed_type(expr.base) - if isinstance(base, UnboundType): - if base.args: - raise TypeTranslationError() - if isinstance(expr.index, TupleExpr): - args = cast(TupleExpr, expr.index).items - else: - args = [expr.index] - base.args = [expr_to_unanalyzed_type(arg) for arg in args] - return base - else: - raise TypeTranslationError() - elif isinstance(expr, ListExpr): - return TypeList([expr_to_unanalyzed_type(t) for t in expr.items], - line=expr.line) - elif isinstance(expr, StrExpr): - # Parse string literal type. - try: - result = parse_str_as_type(expr.value, expr.line) - except TypeParseError: - raise TypeTranslationError() - return result - else: - raise TypeTranslationError() - - -def get_member_expr_fullname(expr: MemberExpr) -> str: - """Return the qualified name represention of a member expression. - - Return a string of form foo.bar, foo.bar.baz, or similar, or None if the - argument cannot be represented in this form. - """ - if isinstance(expr.expr, NameExpr): - initial = cast(NameExpr, expr.expr).name - elif isinstance(expr.expr, MemberExpr): - initial = get_member_expr_fullname(cast(MemberExpr, expr.expr)) - else: - return None - return '{}.{}'.format(initial, expr.name) - - -def find_duplicate(list: List[T]) -> T: - """If the list has duplicates, return one of the duplicates. - - Otherwise, return None. - """ - for i in range(1, len(list)): - if list[i] in list[:i]: - return list[i] - return None - - -def disable_typevars(nodes: List[SymbolTableNode]) -> None: - for node in nodes: - assert node.kind in (TVAR, UNBOUND_TVAR) - node.kind = UNBOUND_TVAR - - -def enable_typevars(nodes: List[SymbolTableNode]) -> None: - for node in nodes: - assert node.kind in (TVAR, UNBOUND_TVAR) - node.kind = TVAR - - -def remove_imported_names_from_symtable(names: SymbolTable, - module: str) -> None: - """Remove all imported names from the symbol table of a module.""" - removed = List[str]() - for name, node in names.items(): - fullname = node.node.fullname() - prefix = fullname[:fullname.rfind('.')] - if prefix != module: - removed.append(name) - for name in removed: - del names[name] - - -def infer_reachability_of_if_statement(s: IfStmt, pyversion: int) -> None: - always_true = False - for i in range(len(s.expr)): - result = infer_if_condition_value(s.expr[i], pyversion) - if result == ALWAYS_FALSE: - # The condition is always false, so we skip the if/elif body. - mark_block_unreachable(s.body[i]) - elif result == ALWAYS_TRUE: - # This condition is always true, so all of the remaining - # elif/else bodies will never be executed. - always_true = True - for body in s.body[i + 1:]: - mark_block_unreachable(s.body[i]) - if s.else_body: - mark_block_unreachable(s.else_body) - break - - -def infer_if_condition_value(expr: Node, pyversion: int) -> int: - """Infer whether if condition is always true/false. - - Return ALWAYS_TRUE if always true, ALWAYS_FALSE if always false, - and TRUTH_VALUE_UNKNOWN otherwise. - """ - name = '' - negated = False - alias = expr - if isinstance(alias, UnaryExpr): - if alias.op == 'not': - expr = alias.expr - negated = True - if isinstance(expr, NameExpr): - name = expr.name - elif isinstance(expr, MemberExpr): - name = expr.name - result = TRUTH_VALUE_UNKNOWN - if name == 'PY2': - result = ALWAYS_TRUE if pyversion == 2 else ALWAYS_FALSE - elif name == 'PY3': - result = ALWAYS_TRUE if pyversion == 3 else ALWAYS_FALSE - elif name == 'MYPY': - result = ALWAYS_TRUE - if negated: - if result == ALWAYS_TRUE: - result = ALWAYS_FALSE - elif result == ALWAYS_FALSE: - result = ALWAYS_TRUE - return result - - -def mark_block_unreachable(block: Block) -> None: - block.is_unreachable = True - block.accept(MarkImportsUnreachableVisitor()) - - -class MarkImportsUnreachableVisitor(TraverserVisitor): - """Visitor that flags all imports nested within a node as unreachable.""" - - def visit_import(self, node: Import) -> None: - node.is_unreachable = True - - def visit_import_from(self, node: ImportFrom) -> None: - node.is_unreachable = True - - def visit_import_all(self, node: ImportAll) -> None: - node.is_unreachable = True diff --git a/mypy/stats.py.orig b/mypy/stats.py.orig deleted file mode 100644 index 3ad237cc7979..000000000000 --- a/mypy/stats.py.orig +++ /dev/null @@ -1,358 +0,0 @@ -"""Utilities for calculating and reporting statistics about types.""" - -import cgi -import os.path -import re - -from typing import Any, Dict, List, cast, Tuple - -from mypy.traverser import TraverserVisitor -from mypy.types import ( - Type, AnyType, Instance, FunctionLike, TupleType, Void, TypeVar, - TypeQuery, ANY_TYPE_STRATEGY, Callable -) -from mypy import nodes -from mypy.nodes import ( - Node, FuncDef, TypeApplication, AssignmentStmt, NameExpr, CallExpr, -<<<<<<< HEAD - MemberExpr, OpExpr, IndexExpr, UnaryExpr, YieldFromExpr -======= - MemberExpr, OpExpr, ComparisonExpr, IndexExpr, UnaryExpr ->>>>>>> master -) - - -TYPE_PRECISE = 0 -TYPE_IMPRECISE = 1 -TYPE_ANY = 2 - - -class StatisticsVisitor(TraverserVisitor): - def __init__(self, inferred: bool, typemap: Dict[Node, Type] = None, - all_nodes: bool = False) -> None: - self.inferred = inferred - self.typemap = typemap - self.all_nodes = all_nodes - - self.num_precise = 0 - self.num_imprecise = 0 - self.num_any = 0 - - self.num_simple = 0 - self.num_generic = 0 - self.num_tuple = 0 - self.num_function = 0 - self.num_typevar = 0 - self.num_complex = 0 - - self.line = -1 - - self.line_map = Dict[int, int]() - - self.output = List[str]() - - TraverserVisitor.__init__(self) - - def visit_func_def(self, o: FuncDef) -> None: - self.line = o.line - if len(o.expanded) > 1: - for defn in o.expanded: - self.visit_func_def(cast(FuncDef, defn)) - else: - if o.type: - sig = cast(Callable, o.type) - arg_types = sig.arg_types - if (sig.arg_names and sig.arg_names[0] == 'self' and - not self.inferred): - arg_types = arg_types[1:] - for arg in arg_types: - self.type(arg) - self.type(sig.ret_type) - elif self.all_nodes: - self.record_line(self.line, TYPE_ANY) - super().visit_func_def(o) - - def visit_type_application(self, o: TypeApplication) -> None: - self.line = o.line - for t in o.types: - self.type(t) - super().visit_type_application(o) - - def visit_assignment_stmt(self, o: AssignmentStmt) -> None: - self.line = o.line - if (isinstance(o.rvalue, nodes.CallExpr) and - isinstance(cast(nodes.CallExpr, o.rvalue).analyzed, - nodes.TypeVarExpr)): - # Type variable definition -- not a real assignment. - return - if o.type: - self.type(o.type) - elif self.inferred: - for lvalue in o.lvalues: - lvalue_ref = lvalue - if isinstance(lvalue_ref, nodes.ParenExpr): - lvalue = lvalue_ref.expr - if isinstance(lvalue, nodes.TupleExpr): - items = lvalue.items - elif isinstance(lvalue, nodes.ListExpr): - items = lvalue.items - else: - items = [lvalue] - for item in items: - if hasattr(item, 'is_def') and Any(item).is_def: - t = self.typemap.get(item) - if t: - self.type(t) - else: - self.log(' !! No inferred type on line %d' % - self.line) - self.record_line(self.line, TYPE_ANY) - super().visit_assignment_stmt(o) - - def visit_name_expr(self, o: NameExpr) -> None: - self.process_node(o) - super().visit_name_expr(o) - - def visit_yield_from_expr(self, o: YieldFromExpr) -> None: - if o.expr: - o.expr.accept(self) - - def visit_call_expr(self, o: CallExpr) -> None: - self.process_node(o) - if o.analyzed: - o.analyzed.accept(self) - else: - o.callee.accept(self) - for a in o.args: - a.accept(self) - - def visit_member_expr(self, o: MemberExpr) -> None: - self.process_node(o) - super().visit_member_expr(o) - - def visit_op_expr(self, o: OpExpr) -> None: - self.process_node(o) - super().visit_op_expr(o) - - def visit_comparison_expr(self, o: ComparisonExpr) -> None: - self.process_node(o) - super().visit_comparison_expr(o) - - def visit_index_expr(self, o: IndexExpr) -> None: - self.process_node(o) - super().visit_index_expr(o) - - def visit_unary_expr(self, o: UnaryExpr) -> None: - self.process_node(o) - super().visit_unary_expr(o) - - def process_node(self, node: Node) -> None: - if self.all_nodes: - typ = self.typemap.get(node) - if typ: - self.line = node.line - self.type(typ) - - def type(self, t: Type) -> None: - if isinstance(t, AnyType): - self.log(' !! Any type around line %d' % self.line) - self.num_any += 1 - self.record_line(self.line, TYPE_ANY) - elif ((not self.all_nodes and is_imprecise(t)) or - (self.all_nodes and is_imprecise2(t))): - self.log(' !! Imprecise type around line %d' % self.line) - self.num_imprecise += 1 - self.record_line(self.line, TYPE_IMPRECISE) - else: - self.num_precise += 1 - self.record_line(self.line, TYPE_PRECISE) - - if isinstance(t, Instance): - if t.args: - if any(is_complex(arg) for arg in t.args): - self.num_complex += 1 - else: - self.num_generic += 1 - else: - self.num_simple += 1 - elif isinstance(t, Void): - self.num_simple += 1 - elif isinstance(t, FunctionLike): - self.num_function += 1 - elif isinstance(t, TupleType): - if any(is_complex(item) for item in t.items): - self.num_complex += 1 - else: - self.num_tuple += 1 - elif isinstance(t, TypeVar): - self.num_typevar += 1 - - def log(self, string: str) -> None: - self.output.append(string) - - def record_line(self, line: int, precision: int) -> None: - self.line_map[line] = max(precision, - self.line_map.get(line, TYPE_PRECISE)) - - -def dump_type_stats(tree: Node, path: str, inferred: bool = False, - typemap: Dict[Node, Type] = None) -> None: - if is_special_module(path): - return - print(path) - visitor = StatisticsVisitor(inferred, typemap) - tree.accept(visitor) - for line in visitor.output: - print(line) - print(' ** precision **') - print(' precise ', visitor.num_precise) - print(' imprecise', visitor.num_imprecise) - print(' any ', visitor.num_any) - print(' ** kinds **') - print(' simple ', visitor.num_simple) - print(' generic ', visitor.num_generic) - print(' function ', visitor.num_function) - print(' tuple ', visitor.num_tuple) - print(' typevar ', visitor.num_typevar) - print(' complex ', visitor.num_complex) - print(' any ', visitor.num_any) - - -def is_special_module(path: str) -> bool: - return os.path.basename(path) in ('abc.py', 'typing.py', 'builtins.py') - - -def is_imprecise(t: Type) -> bool: - return t.accept(HasAnyQuery()) - - -class HasAnyQuery(TypeQuery): - def __init__(self) -> None: - super().__init__(False, ANY_TYPE_STRATEGY) - - def visit_any(self, t: AnyType) -> bool: - return True - - def visit_instance(self, t: Instance) -> bool: - if t.type.fullname() == 'builtins.tuple': - return True - else: - return super().visit_instance(t) - - -def is_imprecise2(t: Type) -> bool: - return t.accept(HasAnyQuery2()) - - -class HasAnyQuery2(HasAnyQuery): - def visit_callable(self, t: Callable) -> bool: - # We don't want to flag references to functions with some Any - # argument types (etc.) since they generally don't mean trouble. - return False - - -def is_generic(t: Type) -> bool: - return isinstance(t, Instance) and bool(cast(Instance, t).args) - - -def is_complex(t: Type) -> bool: - return is_generic(t) or isinstance(t, (FunctionLike, TupleType, - TypeVar)) - - -html_files = [] # type: List[Tuple[str, str, int, int]] - - -def generate_html_report(tree: Node, path: str, type_map: Dict[Node, Type], - output_dir: str) -> None: - if is_special_module(path): - return - visitor = StatisticsVisitor(inferred=True, typemap=type_map, all_nodes=True) - tree.accept(visitor) - target_path = os.path.join(output_dir, 'html', path) - target_path = re.sub(r'\.py$', '.html', target_path) - ensure_dir_exists(os.path.dirname(target_path)) - output = [] # type: List[str] - append = output.append - append('''\ - - - - - -
''')
-    num_imprecise_lines = 0
-    num_lines = 0
-    with open(path) as input_file:
-        for i, line in enumerate(input_file):
-            lineno = i + 1
-            status = visitor.line_map.get(lineno, TYPE_PRECISE)
-            style_map = { TYPE_PRECISE: 'white',
-                          TYPE_IMPRECISE: 'yellow',
-                          TYPE_ANY: 'red' }
-            style = style_map[status]
-            append('%4d   ' % lineno +
-                   '%s' % (style,
-                                                   cgi.escape(line)))
-            if status != TYPE_PRECISE:
-                num_imprecise_lines += 1
-            if line.strip():
-                num_lines += 1
-    append('
') - append('') - with open(target_path, 'w') as output_file: - output_file.writelines(output) - target_path = target_path[len(output_dir) + 1:] - html_files.append((path, target_path, num_lines, num_imprecise_lines)) - - -def generate_html_index(output_dir: str) -> None: - path = os.path.join(output_dir, 'index.html') - output = [] # type: List[str] - append = output.append - append('''\ - - - - -''') - append('

Mypy Type Check Coverage Report

\n') - append('\n') - for source_path, target_path, num_lines, num_imprecise in sorted(html_files): - if num_lines == 0: - continue - source_path = os.path.normpath(source_path) - # TODO: Windows paths. - if (source_path.startswith('stubs/') or - '/stubs/' in source_path): - continue - percent = 100.0 * num_imprecise / num_lines - style = '' - if percent >= 20: - style = 'class="red"' - elif percent >= 5: - style = 'class="yellow"' - append('
%s%.1f%% imprecise%d LOC' % ( - style, target_path, source_path, percent, num_lines)) - append('
') - append('') - with open(path, 'w') as file: - file.writelines(output) - print('Generated HTML report: %s' % os.path.abspath(path)) - - -def ensure_dir_exists(dir: str) -> None: - if not os.path.exists(dir): - os.makedirs(dir) diff --git a/mypy/strconv.py.orig b/mypy/strconv.py.orig deleted file mode 100644 index 723c54842c62..000000000000 --- a/mypy/strconv.py.orig +++ /dev/null @@ -1,444 +0,0 @@ -"""Conversion of parse tree nodes to strings.""" - -import re -import os - -import typing - -from mypy.util import dump_tagged, short_type -import mypy.nodes -from mypy.visitor import NodeVisitor - - -class StrConv(NodeVisitor[str]): - """Visitor for converting a Node to a human-readable string. - - For example, an MypyFile node from program '1' is converted into - something like this: - - MypyFile:1( - fnam - ExpressionStmt:1( - IntExpr(1))) - """ - def dump(self, nodes, obj): - """Convert a list of items to a multiline pretty-printed string. - - The tag is produced from the type name of obj and its line - number. See mypy.util.dump_tagged for a description of the nodes - argument. - """ - return dump_tagged(nodes, short_type(obj) + ':' + str(obj.line)) - - def func_helper(self, o): - """Return a list in a format suitable for dump() that represents the - arguments and the body of a function. The caller can then decorate the - array with information specific to methods, global functions or - anonymous functions. - """ - args = [] - init = [] - extra = [] - for i, kind in enumerate(o.arg_kinds): - if kind == mypy.nodes.ARG_POS: - args.append(o.args[i]) - elif kind in (mypy.nodes.ARG_OPT, mypy.nodes.ARG_NAMED): - args.append(o.args[i]) - init.append(o.init[i]) - elif kind == mypy.nodes.ARG_STAR: - extra.append(('VarArg', [o.args[i]])) - elif kind == mypy.nodes.ARG_STAR2: - extra.append(('DictVarArg', [o.args[i]])) - a = [] - if args: - a.append(('Args', args)) - if o.type: - a.append(o.type) - if init: - a.append(('Init', init)) - if o.is_generator: - a.append('Generator') - a.extend(extra) - a.append(o.body) - return a - - # Top-level structures - - def visit_mypy_file(self, o): - # Skip implicit definitions. - defs = o.defs - while (defs and isinstance(defs[0], mypy.nodes.VarDef) and - not defs[0].repr): - defs = defs[1:] - a = [defs] - if o.is_bom: - a.insert(0, 'BOM') - # Omit path to special file with name "main". This is used to simplify - # test case descriptions; the file "main" is used by default in many - # test cases. - if o.path is not None and o.path != 'main': - # Insert path. Normalize directory separators to / to unify test - # case# output in all platforms. - a.insert(0, o.path.replace(os.sep, '/')) - return self.dump(a, o) - - def visit_import(self, o): - a = [] - for id, as_id in o.ids: - a.append('{} : {}'.format(id, as_id)) - return 'Import:{}({})'.format(o.line, ', '.join(a)) - - def visit_import_from(self, o): - a = [] - for name, as_name in o.names: - a.append('{} : {}'.format(name, as_name)) - return 'ImportFrom:{}({}, [{}])'.format(o.line, o.id, ', '.join(a)) - - def visit_import_all(self, o): - return 'ImportAll:{}({})'.format(o.line, o.id) - - # Definitions - - def visit_func_def(self, o): - a = self.func_helper(o) - a.insert(0, o.name()) - if mypy.nodes.ARG_NAMED in o.arg_kinds: - a.insert(1, 'MaxPos({})'.format(o.max_pos)) - if o.is_abstract: - a.insert(-1, 'Abstract') - if o.is_static: - a.insert(-1, 'Static') - if o.is_class: - a.insert(-1, 'Class') - if o.is_property: - a.insert(-1, 'Property') - return self.dump(a, o) - - def visit_overloaded_func_def(self, o): - a = o.items[:] - if o.type: - a.insert(0, o.type) - return self.dump(a, o) - - def visit_class_def(self, o): - a = [o.name, o.defs.body] - # Display base types unless they are implicitly just builtins.object - # (in this case there is no representation). - if len(o.base_types) > 1 or (len(o.base_types) == 1 - and o.base_types[0].repr): - a.insert(1, ('BaseType', o.base_types)) - if o.type_vars: - a.insert(1, ('TypeVars', o.type_vars)) - if o.metaclass: - a.insert(1, 'Metaclass({})'.format(o.metaclass)) - if o.decorators: - a.insert(1, ('Decorators', o.decorators)) - if o.is_builtinclass: - a.insert(1, 'Builtinclass') - if o.info and o.info.ducktype: - a.insert(1, 'Ducktype({})'.format(o.info.ducktype)) - if o.info and o.info.disjoint_classes: - a.insert(1, ('Disjointclasses', [info.fullname() for - info in o.info.disjoint_classes])) - return self.dump(a, o) - - def visit_var_def(self, o): - a = [] - for n in o.items: - a.append('Var({})'.format(n.name())) - a.append('Type({})'.format(n.type)) - if o.init: - a.append(o.init) - return self.dump(a, o) - - def visit_var(self, o): - l = '' - # Add :nil line number tag if no line number is specified to remain - # compatible with old test case descriptions that assume this. - if o.line < 0: - l = ':nil' - return 'Var' + l + '(' + o.name() + ')' - - def visit_global_decl(self, o): - return self.dump([o.names], o) - - def visit_decorator(self, o): - return self.dump([o.var, o.decorators, o.func], o) - - def visit_annotation(self, o): - return 'Type:{}({})'.format(o.line, o.type) - - # Statements - - def visit_block(self, o): - return self.dump(o.body, o) - - def visit_expression_stmt(self, o): - return self.dump([o.expr], o) - - def visit_assignment_stmt(self, o): - if len(o.lvalues) > 1: - a = [('Lvalues', o.lvalues)] - else: - a = [o.lvalues[0]] - a.append(o.rvalue) - if o.type: - a.append(o.type) - return self.dump(a, o) - - def visit_operator_assignment_stmt(self, o): - return self.dump([o.op, o.lvalue, o.rvalue], o) - - def visit_while_stmt(self, o): - a = [o.expr, o.body] - if o.else_body: - a.append(('Else', o.else_body.body)) - return self.dump(a, o) - - def visit_for_stmt(self, o): - a = [o.index] - if o.types != [None] * len(o.types): - a += o.types - a.extend([o.expr, o.body]) - if o.else_body: - a.append(('Else', o.else_body.body)) - return self.dump(a, o) - - def visit_return_stmt(self, o): - return self.dump([o.expr], o) - - def visit_if_stmt(self, o): - a = [] - for i in range(len(o.expr)): - a.append(('If', [o.expr[i]])) - a.append(('Then', o.body[i].body)) - - if not o.else_body: - return self.dump(a, o) - else: - return self.dump([a, ('Else', o.else_body.body)], o) - - def visit_break_stmt(self, o): - return self.dump([], o) - - def visit_continue_stmt(self, o): - return self.dump([], o) - - def visit_pass_stmt(self, o): - return self.dump([], o) - - def visit_raise_stmt(self, o): - return self.dump([o.expr, o.from_expr], o) - - def visit_assert_stmt(self, o): - return self.dump([o.expr], o) - - def visit_yield_stmt(self, o): - return self.dump([o.expr], o) - -<<<<<<< HEAD - def visit_yield_from_stmt(self, o): - return self.dump([o.expr], o) - -======= ->>>>>>> master - def visit_del_stmt(self, o): - return self.dump([o.expr], o) - - def visit_try_stmt(self, o): - a = [o.body] - - for i in range(len(o.vars)): - a.append(o.types[i]) - if o.vars[i]: - a.append(o.vars[i]) - a.append(o.handlers[i]) - - if o.else_body: - a.append(('Else', o.else_body.body)) - if o.finally_body: - a.append(('Finally', o.finally_body.body)) - - return self.dump(a, o) - - def visit_with_stmt(self, o): - a = [] - for i in range(len(o.expr)): - a.append(('Expr', [o.expr[i]])) - if o.name[i]: - a.append(('Name', [o.name[i]])) - return self.dump(a + [o.body], o) - - def visit_print_stmt(self, o): - a = o.args[:] - if o.newline: - a.append('Newline') - return self.dump(a, o) - - # Expressions - - # Simple expressions - - def visit_int_expr(self, o): - return 'IntExpr({})'.format(o.value) - - def visit_str_expr(self, o): - return 'StrExpr({})'.format(self.str_repr(o.value)) - - def visit_bytes_expr(self, o): - return 'BytesExpr({})'.format(self.str_repr(o.value)) - - def visit_unicode_expr(self, o): - return 'UnicodeExpr({})'.format(self.str_repr(o.value)) - - def str_repr(self, s): - s = re.sub(r'\\u[0-9a-fA-F]{4}', lambda m: '\\' + m.group(0), s) - return re.sub('[^\\x20-\\x7e]', - lambda m: r'\u%.4x' % ord(m.group(0)), s) - - def visit_float_expr(self, o): - return 'FloatExpr({})'.format(o.value) - - def visit_paren_expr(self, o): - return self.dump([o.expr], o) - - def visit_name_expr(self, o): - return (short_type(o) + '(' + self.pretty_name(o.name, o.kind, - o.fullname, o.is_def) - + ')') - - def pretty_name(self, name, kind, fullname, is_def): - n = name - if is_def: - n += '*' - if kind == mypy.nodes.GDEF or (fullname != name and - fullname is not None): - # Append fully qualified name for global references. - n += ' [{}]'.format(fullname) - elif kind == mypy.nodes.LDEF: - # Add tag to signify a local reference. - n += ' [l]' - elif kind == mypy.nodes.MDEF: - # Add tag to signify a member reference. - n += ' [m]' - return n - - def visit_member_expr(self, o): - return self.dump([o.expr, self.pretty_name(o.name, o.kind, o.fullname, - o.is_def)], o) - -<<<<<<< HEAD - def visit_yield_from_expr(self, o): - if o.expr: - return self.dump([o.expr.accept(self)], o) - else: - return self.dump([], o) - -======= ->>>>>>> master - def visit_call_expr(self, o): - if o.analyzed: - return o.analyzed.accept(self) - args = [] - extra = [] - for i, kind in enumerate(o.arg_kinds): - if kind in [mypy.nodes.ARG_POS, mypy.nodes.ARG_STAR]: - args.append(o.args[i]) - if kind == mypy.nodes.ARG_STAR: - extra.append('VarArg') - elif kind == mypy.nodes.ARG_NAMED: - extra.append(('KwArgs', [o.arg_names[i], o.args[i]])) - elif kind == mypy.nodes.ARG_STAR2: - extra.append(('DictVarArg', [o.args[i]])) - else: - raise RuntimeError('unknown kind %d' % kind) - - return self.dump([o.callee, ('Args', args)] + extra, o) - - def visit_op_expr(self, o): - return self.dump([o.op, o.left, o.right], o) - -<<<<<<< HEAD -======= - def visit_comparison_expr(self, o): - return self.dump([o.operators, o.operands], o) - ->>>>>>> master - def visit_cast_expr(self, o): - return self.dump([o.expr, o.type], o) - - def visit_unary_expr(self, o): - return self.dump([o.op, o.expr], o) - - def visit_list_expr(self, o): - return self.dump(o.items, o) - - def visit_dict_expr(self, o): - return self.dump([[k, v] for k, v in o.items], o) - - def visit_set_expr(self, o): - return self.dump(o.items, o) - - def visit_tuple_expr(self, o): - return self.dump(o.items, o) - - def visit_index_expr(self, o): - if o.analyzed: - return o.analyzed.accept(self) - return self.dump([o.base, o.index], o) - - def visit_super_expr(self, o): - return self.dump([o.name], o) - - def visit_undefined_expr(self, o): - return 'UndefinedExpr:{}({})'.format(o.line, o.type) - - def visit_type_application(self, o): - return self.dump([o.expr, ('Types', o.types)], o) - - def visit_type_var_expr(self, o): - if o.values: - return self.dump([('Values', o.values)], o) - else: - return 'TypeVarExpr:{}()'.format(o.line) - - def visit_ducktype_expr(self, o): - return 'DucktypeExpr:{}({})'.format(o.line, o.type) - - def visit_disjointclass_expr(self, o): - return 'DisjointclassExpr:{}({})'.format(o.line, o.cls.fullname) - - def visit_func_expr(self, o): - a = self.func_helper(o) - return self.dump(a, o) - - def visit_generator_expr(self, o): - # FIX types - condlists = o.condlists if any(o.condlists) else None - return self.dump([o.left_expr, o.indices, o.sequences, condlists], o) - - def visit_list_comprehension(self, o): - return self.dump([o.generator], o) - - def visit_conditional_expr(self, o): - return self.dump([('Condition', [o.cond]), o.if_expr, o.else_expr], o) - - def visit_slice_expr(self, o): - a = [o.begin_index, o.end_index, o.stride] - if not a[0]: - a[0] = '' - if not a[1]: - a[1] = '' - return self.dump(a, o) - - def visit_coerce_expr(self, o): - return self.dump([o.expr, ('Types', [o.target_type, o.source_type])], - o) - - def visit_type_expr(self, o): - return self.dump([str(o.type)], o) - - def visit_filter_node(self, o): - # These are for convenience. These node types are not defined in the - # parser module. - pass diff --git a/mypy/transform.py.orig b/mypy/transform.py.orig deleted file mode 100644 index fb45c6160488..000000000000 --- a/mypy/transform.py.orig +++ /dev/null @@ -1,449 +0,0 @@ -"""Transform program to include explicit coercions and wrappers. - -The transform performs these main changes: - - - add explicit coercions to/from any (or more generally, between different - levels of typing precision) - - add wrapper methods and functions for calling statically typed functions - in dynamically typed code - - add wrapper methods for overrides with a different signature - - add generic wrapper classes for coercions between generic types (e.g. - from List[Any] to List[str]) -""" - -from typing import Undefined, Dict, List, Tuple, cast - -from mypy.nodes import ( - Node, MypyFile, TypeInfo, ClassDef, VarDef, FuncDef, Var, - ReturnStmt, AssignmentStmt, IfStmt, WhileStmt, MemberExpr, NameExpr, MDEF, -<<<<<<< HEAD - CallExpr, SuperExpr, TypeExpr, CastExpr, OpExpr, CoerceExpr, GDEF, - SymbolTableNode, IndexExpr, function_type, YieldFromExpr -======= - CallExpr, SuperExpr, TypeExpr, CastExpr, OpExpr, CoerceExpr, ComparisonExpr, - GDEF, SymbolTableNode, IndexExpr, function_type ->>>>>>> master -) -from mypy.traverser import TraverserVisitor -from mypy.types import Type, AnyType, Callable, TypeVarDef, Instance -from mypy.lex import Token -from mypy.transformtype import TypeTransformer -from mypy.transutil import ( - prepend_arg_type, is_simple_override, tvar_arg_name, dynamic_suffix, - add_arg_type_after_self -) -from mypy.coerce import coerce -from mypy.rttypevars import translate_runtime_type_vars_in_context - - -class DyncheckTransformVisitor(TraverserVisitor): - """Translate a parse tree to use runtime representation of generics. - - Translate generic type variables to ordinary variables and all make - all non-trivial coercions explicit. Also generate generic wrapper classes - for coercions between generic types and wrapper methods for overrides - and for more efficient access from dynamically typed code. - - This visitor modifies the parse tree in-place. - """ - - type_map = Undefined(Dict[Node, Type]) - modules = Undefined(Dict[str, MypyFile]) - is_pretty = False - type_tf = Undefined(TypeTransformer) - - # Stack of function return types - return_types = Undefined(List[Type]) - # Stack of dynamically typed function flags - dynamic_funcs = Undefined(List[bool]) - - # Associate a Node with its start end line numbers. - line_map = Undefined(Dict[Node, Tuple[int, int]]) - - is_java = False - - # The current type context (or None if not within a type). -<<<<<<< HEAD - _type_context = None # type: TypeInfo -======= - _type_context = None # type: TypeInfo ->>>>>>> master - - def type_context(self) -> TypeInfo: - return self._type_context - - def __init__(self, type_map: Dict[Node, Type], - modules: Dict[str, MypyFile], is_pretty: bool, - is_java: bool = False) -> None: - self.type_tf = TypeTransformer(self) - self.return_types = [] - self.dynamic_funcs = [False] - self.line_map = {} - self.type_map = type_map - self.modules = modules - self.is_pretty = is_pretty - self.is_java = is_java - - # - # Transform definitions - # - - def visit_mypy_file(self, o: MypyFile) -> None: - """Transform an file.""" - res = [] # type: List[Node] - for d in o.defs: - if isinstance(d, ClassDef): - self._type_context = d.info - res.extend(self.type_tf.transform_class_def(d)) - self._type_context = None - else: - d.accept(self) - res.append(d) - o.defs = res - - def visit_var_def(self, o: VarDef) -> None: - """Transform a variable definition in-place. - - This is not suitable for member variable definitions; they are - transformed in TypeTransformer. - """ - super().visit_var_def(o) - - if o.init is not None: - if o.items[0].type: - t = o.items[0].type - else: - t = AnyType() - o.init = self.coerce(o.init, t, self.get_type(o.init), - self.type_context()) - - def visit_func_def(self, fdef: FuncDef) -> None: - """Transform a global function definition in-place. - - This is not suitable for methods; they are transformed in - FuncTransformer. - """ - self.prepend_generic_function_tvar_args(fdef) - self.transform_function_body(fdef) - - def transform_function_body(self, fdef: FuncDef) -> None: - """Transform the body of a function.""" - self.dynamic_funcs.append(fdef.is_implicit) - # FIX overloads - self.return_types.append(cast(Callable, function_type(fdef)).ret_type) - super().visit_func_def(fdef) - self.return_types.pop() - self.dynamic_funcs.pop() - - def prepend_generic_function_tvar_args(self, fdef: FuncDef) -> None: - """Add implicit function type variable arguments if fdef is generic.""" - sig = cast(Callable, function_type(fdef)) - tvars = sig.variables - if not fdef.type: - fdef.type = sig - -<<<<<<< HEAD - tv = [] # type: List[Var] -======= - tv = [] # type: List[Var] ->>>>>>> master - ntvars = len(tvars) - if fdef.is_method(): - # For methods, add type variable arguments after the self arg. - for n in range(ntvars): - tv.append(Var(tvar_arg_name(-1 - n))) - fdef.type = add_arg_type_after_self(cast(Callable, fdef.type), - AnyType()) - fdef.args = [fdef.args[0]] + tv + fdef.args[1:] - else: - # For ordinary functions, prepend type variable arguments. - for n in range(ntvars): - tv.append(Var(tvar_arg_name(-1 - n))) - fdef.type = prepend_arg_type(cast(Callable, fdef.type), - AnyType()) - fdef.args = tv + fdef.args - fdef.init = List[AssignmentStmt]([None]) * ntvars + fdef.init - - # - # Transform statements - # - - def transform_block(self, block: List[Node]) -> None: - for stmt in block: - stmt.accept(self) - - def visit_return_stmt(self, s: ReturnStmt) -> None: - super().visit_return_stmt(s) - s.expr = self.coerce(s.expr, self.return_types[-1], - self.get_type(s.expr), self.type_context()) - - def visit_assignment_stmt(self, s: AssignmentStmt) -> None: - super().visit_assignment_stmt(s) - if isinstance(s.lvalues[0], IndexExpr): - index = cast(IndexExpr, s.lvalues[0]) - method_type = index.method_type - if self.dynamic_funcs[-1] or isinstance(method_type, AnyType): - lvalue_type = AnyType() # type: Type - else: - method_callable = cast(Callable, method_type) - # TODO arg_types[1] may not be reliable - lvalue_type = method_callable.arg_types[1] - else: - lvalue_type = self.get_type(s.lvalues[0]) - - s.rvalue = self.coerce2(s.rvalue, lvalue_type, self.get_type(s.rvalue), - self.type_context()) - - # - # Transform expressions - # - - def visit_member_expr(self, e: MemberExpr) -> None: - super().visit_member_expr(e) - - typ = self.get_type(e.expr) - - if self.dynamic_funcs[-1]: - e.expr = self.coerce_to_dynamic(e.expr, typ, self.type_context()) - typ = AnyType() - - if isinstance(typ, Instance): - # Reference to a statically-typed method variant with the suffix - # derived from the base object type. - suffix = self.get_member_reference_suffix(e.name, typ.type) - else: - # Reference to a dynamically-typed method variant. - suffix = self.dynamic_suffix() - e.name += suffix - - def visit_name_expr(self, e: NameExpr) -> None: - super().visit_name_expr(e) - if e.kind == MDEF and isinstance(e.node, FuncDef): - # Translate reference to a method. - suffix = self.get_member_reference_suffix(e.name, e.info) - e.name += suffix - # Update representation to have the correct name. - prefix = e.repr.components[0].pre - - def get_member_reference_suffix(self, name: str, info: TypeInfo) -> str: - if info.has_method(name): - fdef = cast(FuncDef, info.get_method(name)) - return self.type_suffix(fdef) - else: - return '' - -<<<<<<< HEAD - def visit_yield_from_expr(self, e: YieldFromExpr) -> None: - if e.expr: - e.expr.accept(self) - -======= ->>>>>>> master - def visit_call_expr(self, e: CallExpr) -> None: - if e.analyzed: - # This is not an ordinary call. - e.analyzed.accept(self) - return - - super().visit_call_expr(e) - - # Do no coercions if this is a call to debugging facilities. - if self.is_debugging_call_expr(e): - return - - # Get the type of the callable (type variables in the context of the - # enclosing class). - ctype = self.get_type(e.callee) - - # Add coercions for the arguments. - for i in range(len(e.args)): - arg_type = AnyType() # type: Type - if isinstance(ctype, Callable): - arg_type = ctype.arg_types[i] - e.args[i] = self.coerce2(e.args[i], arg_type, - self.get_type(e.args[i]), - self.type_context()) - - # Prepend type argument values to the call as needed. - if isinstance(ctype, Callable) and cast(Callable, - ctype).bound_vars != []: - bound_vars = (cast(Callable, ctype)).bound_vars - - # If this is a constructor call (target is the constructor - # of a generic type or superclass __init__), include also - # instance type variables. Otherwise filter them away -- - # include only generic function type variables. - if (not (cast(Callable, ctype)).is_type_obj() and - not (isinstance(e.callee, SuperExpr) and - (cast(SuperExpr, e.callee)).name == '__init__')): - # Filter instance type variables; only include function tvars. - bound_vars = [(id, t) for id, t in bound_vars if id < 0] - -<<<<<<< HEAD - args = [] # type: List[Node] -======= - args = [] # type: List[Node] ->>>>>>> master - for i in range(len(bound_vars)): - # Compile type variables to runtime type variable expressions. - tv = translate_runtime_type_vars_in_context( - bound_vars[i][1], - self.type_context(), - self.is_java) - args.append(TypeExpr(tv)) - e.args = args + e.args - - def is_debugging_call_expr(self, e): - return isinstance(e.callee, NameExpr) and e.callee.name in ['__print'] - - def visit_cast_expr(self, e: CastExpr) -> None: - super().visit_cast_expr(e) - if isinstance(self.get_type(e), AnyType): - e.expr = self.coerce(e.expr, AnyType(), self.get_type(e.expr), - self.type_context()) - - def visit_op_expr(self, e: OpExpr) -> None: - super().visit_op_expr(e) - if e.op in ['and', 'or']: - target = self.get_type(e) - e.left = self.coerce(e.left, target, - self.get_type(e.left), self.type_context()) - e.right = self.coerce(e.right, target, - self.get_type(e.right), self.type_context()) - else: - method_type = e.method_type - if self.dynamic_funcs[-1] or isinstance(method_type, AnyType): - e.left = self.coerce_to_dynamic(e.left, self.get_type(e.left), - self.type_context()) - e.right = self.coerce(e.right, AnyType(), - self.get_type(e.right), - self.type_context()) - elif method_type: - method_callable = cast(Callable, method_type) - operand = e.right - # TODO arg_types[0] may not be reliable - operand = self.coerce(operand, method_callable.arg_types[0], - self.get_type(operand), - self.type_context()) - e.right = operand - - def visit_comparison_expr(self, e: ComparisonExpr) -> None: - super().visit_comparison_expr(e) - # Dummy - - - def visit_index_expr(self, e: IndexExpr) -> None: - if e.analyzed: - # Actually a type application, not indexing. - e.analyzed.accept(self) - return - super().visit_index_expr(e) - method_type = e.method_type - if self.dynamic_funcs[-1] or isinstance(method_type, AnyType): - e.base = self.coerce_to_dynamic(e.base, self.get_type(e.base), - self.type_context()) - e.index = self.coerce_to_dynamic(e.index, self.get_type(e.index), - self.type_context()) - else: - method_callable = cast(Callable, method_type) - e.index = self.coerce(e.index, method_callable.arg_types[0], - self.get_type(e.index), self.type_context()) - - # - # Helpers - # - - def get_type(self, node: Node) -> Type: - """Return the type of a node as reported by the type checker.""" - return self.type_map[node] - - def set_type(self, node: Node, typ: Type) -> None: - self.type_map[node] = typ - - def type_suffix(self, fdef: FuncDef, info: TypeInfo = None) -> str: - """Return the suffix for a mangled name. - - This includes an optional type suffix for a function or method. - """ - if not info: - info = fdef.info - # If info is None, we have a global function => no suffix. Also if the - # method is not an override, we need no suffix. - if not info or (not info.bases or - not info.bases[0].type.has_method(fdef.name())): - return '' - elif is_simple_override(fdef, info): - return self.type_suffix(fdef, info.bases[0].type) - elif self.is_pretty: - return '`' + info.name() - else: - return '__' + info.name() - - def dynamic_suffix(self) -> str: - """Return the suffix of the dynamic wrapper of a method or class.""" - return dynamic_suffix(self.is_pretty) - - def wrapper_class_suffix(self) -> str: - """Return the suffix of a generic wrapper class.""" - return '**' - - def coerce(self, expr: Node, target_type: Type, source_type: Type, - context: TypeInfo, is_wrapper_class: bool = False) -> Node: - return coerce(expr, target_type, source_type, context, - is_wrapper_class, self.is_java) - - def coerce2(self, expr: Node, target_type: Type, source_type: Type, - context: TypeInfo, is_wrapper_class: bool = False) -> Node: - """Create coercion from source_type to target_type. - - Also include middle coercion do 'Any' if transforming a dynamically - typed function. - """ - if self.dynamic_funcs[-1]: - return self.coerce(self.coerce(expr, AnyType(), source_type, - context, is_wrapper_class), - target_type, AnyType(), context, - is_wrapper_class) - else: - return self.coerce(expr, target_type, source_type, context, - is_wrapper_class) - - def coerce_to_dynamic(self, expr: Node, source_type: Type, - context: TypeInfo) -> Node: - if isinstance(source_type, AnyType): - return expr - source_type = translate_runtime_type_vars_in_context( - source_type, context, self.is_java) - return CoerceExpr(expr, AnyType(), source_type, False) - - def add_line_mapping(self, orig_node: Node, new_node: Node) -> None: - """Add a line mapping for a wrapper. - - The node new_node has logically the same line numbers as - orig_node. The nodes should be FuncDef/ClassDef nodes. - """ - if orig_node.repr: - start_line = orig_node.line - end_line = start_line # TODO use real end line - self.line_map[new_node] = (start_line, end_line) - - def named_type(self, name: str) -> Instance: - # TODO combine with checker - # Assume that the name refers to a type. - sym = self.lookup(name, GDEF) - return Instance(cast(TypeInfo, sym.node), []) - - def lookup(self, fullname: str, kind: int) -> SymbolTableNode: - # TODO combine with checker - # TODO remove kind argument - parts = fullname.split('.') - n = self.modules[parts[0]] - for i in range(1, len(parts) - 1): - n = cast(MypyFile, ((n.names.get(parts[i], None).node))) - return n.names[parts[-1]] - - def object_member_name(self) -> str: - if self.is_java: - return '__o_{}'.format(self.type_context().name()) - else: - return '__o' diff --git a/mypy/traverser.py.orig b/mypy/traverser.py.orig deleted file mode 100644 index 4596385411cb..000000000000 --- a/mypy/traverser.py.orig +++ /dev/null @@ -1,243 +0,0 @@ -"""Generic node traverser visitor""" - -from typing import typevar, Generic - -from mypy.visitor import NodeVisitor -from mypy.nodes import ( - Block, MypyFile, VarDef, FuncItem, CallExpr, ClassDef, Decorator, FuncDef, - ExpressionStmt, AssignmentStmt, OperatorAssignmentStmt, WhileStmt, - ForStmt, ReturnStmt, AssertStmt, YieldStmt, DelStmt, IfStmt, RaiseStmt, - TryStmt, WithStmt, ParenExpr, MemberExpr, OpExpr, SliceExpr, CastExpr, - UnaryExpr, ListExpr, TupleExpr, DictExpr, SetExpr, IndexExpr, - GeneratorExpr, ListComprehension, ConditionalExpr, TypeApplication, -<<<<<<< HEAD - FuncExpr, OverloadedFuncDef, YieldFromStmt, YieldFromExpr -======= - FuncExpr, ComparisonExpr, OverloadedFuncDef ->>>>>>> master -) - - -T = typevar('T') - - -class TraverserVisitor(NodeVisitor[T], Generic[T]): - """A parse tree visitor that traverses the parse tree during visiting. - - It does not peform any actions outside the travelsal. Subclasses - should override visit methods to perform actions during - travelsal. Calling the superclass method allows reusing the - travelsal implementation. - """ - - # Visit methods - - def visit_mypy_file(self, o: MypyFile) -> T: - for d in o.defs: - d.accept(self) - - def visit_block(self, block: Block) -> T: - for s in block.body: - s.accept(self) - - def visit_func(self, o: FuncItem) -> T: - for i in o.init: - if i is not None: - i.accept(self) - for v in o.args: - self.visit_var(v) - o.body.accept(self) - - def visit_func_def(self, o: FuncDef) -> T: - self.visit_func(o) - - def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> T: - for item in o.items: - item.accept(self) - - def visit_class_def(self, o: ClassDef) -> T: - o.defs.accept(self) - - def visit_decorator(self, o: Decorator) -> T: - o.func.accept(self) - o.var.accept(self) - for decorator in o.decorators: - decorator.accept(self) - - def visit_var_def(self, o: VarDef) -> T: - if o.init is not None: - o.init.accept(self) - for v in o.items: - self.visit_var(v) - - def visit_expression_stmt(self, o: ExpressionStmt) -> T: - o.expr.accept(self) - - def visit_assignment_stmt(self, o: AssignmentStmt) -> T: - o.rvalue.accept(self) - for l in o.lvalues: - l.accept(self) - - def visit_operator_assignment_stmt(self, o: OperatorAssignmentStmt) -> T: - o.rvalue.accept(self) - o.lvalue.accept(self) - - def visit_while_stmt(self, o: WhileStmt) -> T: - o.expr.accept(self) - o.body.accept(self) - if o.else_body: - o.else_body.accept(self) - - def visit_for_stmt(self, o: ForStmt) -> T: - for ind in o.index: - ind.accept(self) - o.expr.accept(self) - o.body.accept(self) - if o.else_body: - o.else_body.accept(self) - - def visit_return_stmt(self, o: ReturnStmt) -> T: - if o.expr is not None: - o.expr.accept(self) - - def visit_assert_stmt(self, o: AssertStmt) -> T: - if o.expr is not None: - o.expr.accept(self) - - def visit_yield_stmt(self, o: YieldStmt) -> T: - if o.expr is not None: - o.expr.accept(self) - -<<<<<<< HEAD - def visit_yield_from_stmt(self, o: YieldFromStmt) -> T: - if o.expr is not None: - o.expr.accept(self) - -======= ->>>>>>> master - def visit_del_stmt(self, o: DelStmt) -> T: - if o.expr is not None: - o.expr.accept(self) - - def visit_if_stmt(self, o: IfStmt) -> T: - for e in o.expr: - e.accept(self) - for b in o.body: - b.accept(self) - if o.else_body: - o.else_body.accept(self) - - def visit_raise_stmt(self, o: RaiseStmt) -> T: - if o.expr is not None: - o.expr.accept(self) - if o.from_expr is not None: - o.from_expr.accept(self) - - def visit_try_stmt(self, o: TryStmt) -> T: - o.body.accept(self) - for i in range(len(o.types)): - if o.types[i]: - o.types[i].accept(self) - o.handlers[i].accept(self) - if o.else_body is not None: - o.else_body.accept(self) - if o.finally_body is not None: - o.finally_body.accept(self) - - def visit_with_stmt(self, o: WithStmt) -> T: - for i in range(len(o.expr)): - o.expr[i].accept(self) - if o.name[i] is not None: - o.name[i].accept(self) - o.body.accept(self) - - def visit_paren_expr(self, o: ParenExpr) -> T: - o.expr.accept(self) - - def visit_member_expr(self, o: MemberExpr) -> T: - o.expr.accept(self) - -<<<<<<< HEAD - def visit_yield_from_expr(self, o: YieldFromExpr) -> T: - o.expr.accept(self) - -======= ->>>>>>> master - def visit_call_expr(self, o: CallExpr) -> T: - for a in o.args: - a.accept(self) - o.callee.accept(self) - if o.analyzed: - o.analyzed.accept(self) - - def visit_op_expr(self, o: OpExpr) -> T: - o.left.accept(self) - o.right.accept(self) - -<<<<<<< HEAD -======= - def visit_comparison_expr(self, o: ComparisonExpr) -> T: - for operand in o.operands: - operand.accept(self) - ->>>>>>> master - def visit_slice_expr(self, o: SliceExpr) -> T: - if o.begin_index is not None: - o.begin_index.accept(self) - if o.end_index is not None: - o.end_index.accept(self) - if o.stride is not None: - o.stride.accept(self) - - def visit_cast_expr(self, o: CastExpr) -> T: - o.expr.accept(self) - - def visit_unary_expr(self, o: UnaryExpr) -> T: - o.expr.accept(self) - - def visit_list_expr(self, o: ListExpr) -> T: - for item in o.items: - item.accept(self) - - def visit_tuple_expr(self, o: TupleExpr) -> T: - for item in o.items: - item.accept(self) - - def visit_dict_expr(self, o: DictExpr) -> T: - for k, v in o.items: - k.accept(self) - v.accept(self) - - def visit_set_expr(self, o: SetExpr) -> T: - for item in o.items: - item.accept(self) - - def visit_index_expr(self, o: IndexExpr) -> T: - o.base.accept(self) - o.index.accept(self) - if o.analyzed: - o.analyzed.accept(self) - - def visit_generator_expr(self, o: GeneratorExpr) -> T: - for index, sequence, conditions in zip(o.indices, o.sequences, - o.condlists): - sequence.accept(self) - for ind in index: - ind.accept(self) - for cond in conditions: - cond.accept(self) - o.left_expr.accept(self) - - def visit_list_comprehension(self, o: ListComprehension) -> T: - o.generator.accept(self) - - def visit_conditional_expr(self, o: ConditionalExpr) -> T: - o.cond.accept(self) - o.if_expr.accept(self) - o.else_expr.accept(self) - - def visit_type_application(self, o: TypeApplication) -> T: - o.expr.accept(self) - - def visit_func_expr(self, o: FuncExpr) -> T: - self.visit_func(o) diff --git a/mypy/treetransform.py.orig b/mypy/treetransform.py.orig deleted file mode 100644 index d435f0909a29..000000000000 --- a/mypy/treetransform.py.orig +++ /dev/null @@ -1,499 +0,0 @@ -"""Base visitor that implements an identity AST transform. - -Subclass TransformVisitor to perform non-trivial transformations. -""" - -from typing import List, Dict - -from mypy.nodes import ( - MypyFile, Import, Node, ImportAll, ImportFrom, FuncItem, FuncDef, - OverloadedFuncDef, ClassDef, Decorator, Block, Var, VarDef, - OperatorAssignmentStmt, ExpressionStmt, AssignmentStmt, ReturnStmt, - RaiseStmt, AssertStmt, YieldStmt, DelStmt, BreakStmt, ContinueStmt, - PassStmt, GlobalDecl, WhileStmt, ForStmt, IfStmt, TryStmt, WithStmt, - CastExpr, ParenExpr, TupleExpr, GeneratorExpr, ListComprehension, ListExpr, - ConditionalExpr, DictExpr, SetExpr, NameExpr, IntExpr, StrExpr, BytesExpr, - UnicodeExpr, FloatExpr, CallExpr, SuperExpr, MemberExpr, IndexExpr, - SliceExpr, OpExpr, UnaryExpr, FuncExpr, TypeApplication, PrintStmt, - SymbolTable, RefExpr, UndefinedExpr, TypeVarExpr, DucktypeExpr, -<<<<<<< HEAD - DisjointclassExpr, CoerceExpr, TypeExpr, JavaCast, TempNode, YieldFromStmt, - YieldFromExpr -======= - DisjointclassExpr, CoerceExpr, TypeExpr, ComparisonExpr, - JavaCast, TempNode ->>>>>>> master -) -from mypy.types import Type -from mypy.visitor import NodeVisitor - - -class TransformVisitor(NodeVisitor[Node]): - """Transform a semantically analyzed AST (or subtree) to an identical copy. - - Use the node() method to transform an AST node. - - Subclass to perform a non-identity transform. - - Notes: - - * Do not duplicate TypeInfo nodes. This would generally not be desirable. - * Only update some name binding cross-references, but only those that - refer to Var nodes, not those targeting ClassDef, TypeInfo or FuncDef - nodes. - * Types are not transformed, but you can override type() to also perform - type transformation. - - TODO nested classes and functions have not been tested well enough - """ - - def __init__(self) -> None: - # There may be multiple references to a Var node. Keep track of - # Var translations using a dictionary. - self.var_map = Dict[Var, Var]() - - def visit_mypy_file(self, node: MypyFile) -> Node: - # NOTE: The 'names' and 'imports' instance variables will be empty! - new = MypyFile(self.nodes(node.defs), [], node.is_bom) - new._name = node._name - new._fullname = node._fullname - new.path = node.path - new.names = SymbolTable() - return new - - def visit_import(self, node: Import) -> Node: - return Import(node.ids[:]) - - def visit_import_from(self, node: ImportFrom) -> Node: - return ImportFrom(node.id, node.names[:]) - - def visit_import_all(self, node: ImportAll) -> Node: - return ImportAll(node.id) - - def visit_func_def(self, node: FuncDef) -> FuncDef: - # Note that a FuncDef must be transformed to a FuncDef. - new = FuncDef(node.name(), - [self.visit_var(var) for var in node.args], - node.arg_kinds[:], - [None] * len(node.init), - self.block(node.body), - self.optional_type(node.type)) - - self.copy_function_attributes(new, node) - - new._fullname = node._fullname - new.is_decorated = node.is_decorated - new.is_conditional = node.is_conditional - new.is_abstract = node.is_abstract - new.is_static = node.is_static - new.is_class = node.is_class - new.is_property = node.is_property - new.original_def = node.original_def - return new - - def visit_func_expr(self, node: FuncExpr) -> Node: - new = FuncExpr([self.visit_var(var) for var in node.args], - node.arg_kinds[:], - [None] * len(node.init), - self.block(node.body), - self.optional_type(node.type)) - self.copy_function_attributes(new, node) - return new - - def copy_function_attributes(self, new: FuncItem, - original: FuncItem) -> None: - new.info = original.info - new.min_args = original.min_args - new.max_pos = original.max_pos - new.is_implicit = original.is_implicit - new.is_overload = original.is_overload - new.is_generator = original.is_generator - new.init = self.duplicate_inits(original.init) - - def duplicate_inits(self, - inits: List[AssignmentStmt]) -> List[AssignmentStmt]: - result = List[AssignmentStmt]() - for init in inits: - if init: - result.append(self.duplicate_assignment(init)) - else: - result.append(None) - return result - - def visit_overloaded_func_def(self, node: OverloadedFuncDef) -> Node: - items = [self.visit_decorator(decorator) - for decorator in node.items] - for newitem, olditem in zip(items, node.items): - newitem.line = olditem.line - new = OverloadedFuncDef(items) - new._fullname = node._fullname - new.type = self.type(node.type) - new.info = node.info - return new - - def visit_class_def(self, node: ClassDef) -> Node: - new = ClassDef(node.name, - self.block(node.defs), - node.type_vars, - self.types(node.base_types), - node.metaclass) - new.fullname = node.fullname - new.info = node.info - new.decorators = [decorator.accept(self) - for decorator in node.decorators] - new.is_builtinclass = node.is_builtinclass - return new - - def visit_var_def(self, node: VarDef) -> Node: - new = VarDef([self.visit_var(var) for var in node.items], - node.is_top_level, - self.optional_node(node.init)) - new.kind = node.kind - return new - - def visit_global_decl(self, node: GlobalDecl) -> Node: - return GlobalDecl(node.names[:]) - - def visit_block(self, node: Block) -> Block: - return Block(self.nodes(node.body)) - - def visit_decorator(self, node: Decorator) -> Decorator: - # Note that a Decorator must be transformed to a Decorator. - func = self.visit_func_def(node.func) - func.line = node.func.line - new = Decorator(func, self.nodes(node.decorators), - self.visit_var(node.var)) - new.is_overload = node.is_overload - return new - - def visit_var(self, node: Var) -> Var: - # Note that a Var must be transformed to a Var. - if node in self.var_map: - return self.var_map[node] - new = Var(node.name(), self.optional_type(node.type)) - new.line = node.line - new._fullname = node._fullname - new.info = node.info - new.is_self = node.is_self - new.is_ready = node.is_ready - new.is_initialized_in_class = node.is_initialized_in_class - new.is_staticmethod = node.is_staticmethod - new.is_classmethod = node.is_classmethod - new.is_property = node.is_property - new.set_line(node.line) - self.var_map[node] = new - return new - - def visit_expression_stmt(self, node: ExpressionStmt) -> Node: - return ExpressionStmt(self.node(node.expr)) - - def visit_assignment_stmt(self, node: AssignmentStmt) -> Node: - return self.duplicate_assignment(node) - - def duplicate_assignment(self, node: AssignmentStmt) -> AssignmentStmt: - new = AssignmentStmt(self.nodes(node.lvalues), - self.node(node.rvalue), - self.optional_type(node.type)) - new.line = node.line - return new - - def visit_operator_assignment_stmt(self, - node: OperatorAssignmentStmt) -> Node: - return OperatorAssignmentStmt(node.op, - self.node(node.lvalue), - self.node(node.rvalue)) - - def visit_while_stmt(self, node: WhileStmt) -> Node: - return WhileStmt(self.node(node.expr), - self.block(node.body), - self.optional_block(node.else_body)) - - def visit_for_stmt(self, node: ForStmt) -> Node: - return ForStmt(self.names(node.index), - self.node(node.expr), - self.block(node.body), - self.optional_block(node.else_body), - self.optional_types(node.types)) - - def visit_return_stmt(self, node: ReturnStmt) -> Node: - return ReturnStmt(self.optional_node(node.expr)) - - def visit_assert_stmt(self, node: AssertStmt) -> Node: - return AssertStmt(self.node(node.expr)) - - def visit_yield_stmt(self, node: YieldStmt) -> Node: - return YieldStmt(self.node(node.expr)) - -<<<<<<< HEAD - def visit_yield_from_stmt(self, node: YieldFromStmt) -> Node: - return YieldFromStmt(self.node(node.expr)) - -======= ->>>>>>> master - def visit_del_stmt(self, node: DelStmt) -> Node: - return DelStmt(self.node(node.expr)) - - def visit_if_stmt(self, node: IfStmt) -> Node: - return IfStmt(self.nodes(node.expr), - self.blocks(node.body), - self.optional_block(node.else_body)) - - def visit_break_stmt(self, node: BreakStmt) -> Node: - return BreakStmt() - - def visit_continue_stmt(self, node: ContinueStmt) -> Node: - return ContinueStmt() - - def visit_pass_stmt(self, node: PassStmt) -> Node: - return PassStmt() - - def visit_raise_stmt(self, node: RaiseStmt) -> Node: - return RaiseStmt(self.optional_node(node.expr), - self.optional_node(node.from_expr)) - - def visit_try_stmt(self, node: TryStmt) -> Node: - return TryStmt(self.block(node.body), - self.optional_names(node.vars), - self.optional_nodes(node.types), - self.blocks(node.handlers), - self.optional_block(node.else_body), - self.optional_block(node.finally_body)) - - def visit_with_stmt(self, node: WithStmt) -> Node: - return WithStmt(self.nodes(node.expr), - self.optional_names(node.name), - self.block(node.body)) - - def visit_print_stmt(self, node: PrintStmt) -> Node: - return PrintStmt(self.nodes(node.args), - node.newline) - - def visit_int_expr(self, node: IntExpr) -> Node: - return IntExpr(node.value) - - def visit_str_expr(self, node: StrExpr) -> Node: - return StrExpr(node.value) - - def visit_bytes_expr(self, node: BytesExpr) -> Node: - return BytesExpr(node.value) - - def visit_unicode_expr(self, node: UnicodeExpr) -> Node: - return UnicodeExpr(node.value) - - def visit_float_expr(self, node: FloatExpr) -> Node: - return FloatExpr(node.value) - - def visit_paren_expr(self, node: ParenExpr) -> Node: - return ParenExpr(self.node(node.expr)) - - def visit_name_expr(self, node: NameExpr) -> Node: - return self.duplicate_name(node) - - def duplicate_name(self, node: NameExpr) -> NameExpr: - # This method is used when the transform result must be a NameExpr. - # visit_name_expr() is used when there is no such restriction. - new = NameExpr(node.name) - new.info = node.info - self.copy_ref(new, node) - return new - - def visit_member_expr(self, node: MemberExpr) -> Node: - member = MemberExpr(self.node(node.expr), - node.name) - if node.def_var: - member.def_var = self.visit_var(node.def_var) - self.copy_ref(member, node) - return member - - def copy_ref(self, new: RefExpr, original: RefExpr) -> None: - new.kind = original.kind - new.fullname = original.fullname - target = original.node - if isinstance(target, Var): - target = self.visit_var(target) - new.node = target - new.is_def = original.is_def - -<<<<<<< HEAD - def visit_yield_from_expr(self, node: YieldFromExpr) -> Node: - return YieldFromExpr(self.node(node.expr)) - -======= ->>>>>>> master - def visit_call_expr(self, node: CallExpr) -> Node: - return CallExpr(self.node(node.callee), - self.nodes(node.args), - node.arg_kinds[:], - node.arg_names[:], - self.optional_node(node.analyzed)) - - def visit_op_expr(self, node: OpExpr) -> Node: - new = OpExpr(node.op, self.node(node.left), self.node(node.right)) - new.method_type = self.optional_type(node.method_type) - return new - -<<<<<<< HEAD -======= - def visit_comparison_expr(self, node: ComparisonExpr) -> Node: - new = ComparisonExpr(node.operators, self.nodes(node.operands)) - new.method_types = [self.optional_type(t) for t in node.method_types] - return new - ->>>>>>> master - def visit_cast_expr(self, node: CastExpr) -> Node: - return CastExpr(self.node(node.expr), - self.type(node.type)) - - def visit_super_expr(self, node: SuperExpr) -> Node: - new = SuperExpr(node.name) - new.info = node.info - return new - - def visit_unary_expr(self, node: UnaryExpr) -> Node: - new = UnaryExpr(node.op, self.node(node.expr)) - new.method_type = self.optional_type(node.method_type) - return new - - def visit_list_expr(self, node: ListExpr) -> Node: - return ListExpr(self.nodes(node.items)) - - def visit_dict_expr(self, node: DictExpr) -> Node: - return DictExpr([(self.node(key), self.node(value)) - for key, value in node.items]) - - def visit_tuple_expr(self, node: TupleExpr) -> Node: - return TupleExpr(self.nodes(node.items)) - - def visit_set_expr(self, node: SetExpr) -> Node: - return SetExpr(self.nodes(node.items)) - - def visit_index_expr(self, node: IndexExpr) -> Node: - new = IndexExpr(self.node(node.base), self.node(node.index)) - if node.method_type: - new.method_type = self.type(node.method_type) - if node.analyzed: - new.analyzed = self.visit_type_application(node.analyzed) - new.analyzed.set_line(node.analyzed.line) - return new - - def visit_undefined_expr(self, node: UndefinedExpr) -> Node: - return UndefinedExpr(self.type(node.type)) - - def visit_type_application(self, node: TypeApplication) -> TypeApplication: - return TypeApplication(self.node(node.expr), - self.types(node.types)) - - def visit_list_comprehension(self, node: ListComprehension) -> Node: - generator = self.duplicate_generator(node.generator) - generator.set_line(node.generator.line) - return ListComprehension(generator) - - def visit_generator_expr(self, node: GeneratorExpr) -> Node: - return self.duplicate_generator(node) - - def duplicate_generator(self, node: GeneratorExpr) -> GeneratorExpr: - return GeneratorExpr(self.node(node.left_expr), - [self.names(index) for index in node.indices], - [self.optional_types(t) for t in node.types], - [self.node(s) for s in node.sequences], - [[self.node(cond) for cond in conditions] -<<<<<<< HEAD - for conditions in node.condlists]) -======= - for conditions in node.condlists]) ->>>>>>> master - - def visit_slice_expr(self, node: SliceExpr) -> Node: - return SliceExpr(self.optional_node(node.begin_index), - self.optional_node(node.end_index), - self.optional_node(node.stride)) - - def visit_conditional_expr(self, node: ConditionalExpr) -> Node: - return ConditionalExpr(self.node(node.cond), - self.node(node.if_expr), - self.node(node.else_expr)) - - def visit_type_var_expr(self, node: TypeVarExpr) -> Node: - return TypeVarExpr(node.name(), node.fullname(), - self.types(node.values)) - - def visit_ducktype_expr(self, node: DucktypeExpr) -> Node: - return DucktypeExpr(node.type) - - def visit_disjointclass_expr(self, node: DisjointclassExpr) -> Node: - return DisjointclassExpr(node.cls) - - def visit_coerce_expr(self, node: CoerceExpr) -> Node: - raise RuntimeError('Not supported') - - def visit_type_expr(self, node: TypeExpr) -> Node: - raise RuntimeError('Not supported') - - def visit_java_cast(self, node: JavaCast) -> Node: - raise RuntimeError('Not supported') - - def visit_temp_node(self, node: TempNode) -> Node: - return TempNode(self.type(node.type)) - - def node(self, node: Node) -> Node: - new = node.accept(self) - new.set_line(node.line) - return new - - # Helpers - # - # All the node helpers also propagate line numbers. - - def optional_node(self, node: Node) -> Node: - if node: - return self.node(node) - else: - return None - - def block(self, block: Block) -> Block: - new = self.visit_block(block) - new.line = block.line - return new - - def optional_block(self, block: Block) -> Block: - if block: - return self.block(block) - else: - return None - - def nodes(self, nodes: List[Node]) -> List[Node]: - return [self.node(node) for node in nodes] - - def optional_nodes(self, nodes: List[Node]) -> List[Node]: - return [self.optional_node(node) for node in nodes] - - def blocks(self, blocks: List[Block]) -> List[Block]: - return [self.block(block) for block in blocks] - - def names(self, names: List[NameExpr]) -> List[NameExpr]: - return [self.duplicate_name(name) for name in names] - - def optional_names(self, names: List[NameExpr]) -> List[NameExpr]: - result = List[NameExpr]() - for name in names: - if name: - result.append(self.duplicate_name(name)) - else: - result.append(None) - return result - - def type(self, type: Type) -> Type: - # Override this method to transform types. - return type - - def optional_type(self, type: Type) -> Type: - if type: - return self.type(type) - else: - return None - - def types(self, types: List[Type]) -> List[Type]: - return [self.type(type) for type in types] - - def optional_types(self, types: List[Type]) -> List[Type]: - return [self.optional_type(type) for type in types] diff --git a/mypy/visitor.py.orig b/mypy/visitor.py.orig deleted file mode 100644 index e1091f549281..000000000000 --- a/mypy/visitor.py.orig +++ /dev/null @@ -1,225 +0,0 @@ -"""Generic abstract syntax tree node visitor""" - -from typing import typevar, Generic - -import mypy.nodes - - -T = typevar('T') - - -class NodeVisitor(Generic[T]): - """Empty base class for parse tree node visitors. - - The T type argument specifies the return type of the visit - methods. As all methods defined here return None by default, - subclasses do not always need to override all the methods. - - TODO make the default return value explicit - """ - - # Module structure - - def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T: - pass - - def visit_import(self, o: 'mypy.nodes.Import') -> T: - pass - - def visit_import_from(self, o: 'mypy.nodes.ImportFrom') -> T: - pass - - def visit_import_all(self, o: 'mypy.nodes.ImportAll') -> T: - pass - - # Definitions - - def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T: - pass - - def visit_overloaded_func_def(self, - o: 'mypy.nodes.OverloadedFuncDef') -> T: - pass - - def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> T: - pass - - def visit_var_def(self, o: 'mypy.nodes.VarDef') -> T: - pass - - def visit_global_decl(self, o: 'mypy.nodes.GlobalDecl') -> T: - pass - - def visit_decorator(self, o: 'mypy.nodes.Decorator') -> T: - pass - - def visit_var(self, o: 'mypy.nodes.Var') -> T: - pass - - # Statements - - def visit_block(self, o: 'mypy.nodes.Block') -> T: - pass - - def visit_expression_stmt(self, o: 'mypy.nodes.ExpressionStmt') -> T: - pass - - def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> T: - pass - - def visit_operator_assignment_stmt(self, - o: 'mypy.nodes.OperatorAssignmentStmt') -> T: - pass - - def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> T: - pass - - def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T: - pass - - def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> T: - pass - - def visit_assert_stmt(self, o: 'mypy.nodes.AssertStmt') -> T: - pass - - def visit_yield_stmt(self, o: 'mypy.nodes.YieldStmt') -> T: - pass -<<<<<<< HEAD - def visit_yield_from_stmt(self, o: 'mypy.nodes.YieldFromStmt') -> T: - pass -======= - ->>>>>>> master - def visit_del_stmt(self, o: 'mypy.nodes.DelStmt') -> T: - pass - - def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> T: - pass - - def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> T: - pass - - def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> T: - pass - - def visit_pass_stmt(self, o: 'mypy.nodes.PassStmt') -> T: - pass - - def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> T: - pass - - def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> T: - pass - - def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T: - pass - - def visit_print_stmt(self, o: 'mypy.nodes.PrintStmt') -> T: - pass - - # Expressions - - def visit_int_expr(self, o: 'mypy.nodes.IntExpr') -> T: - pass - - def visit_str_expr(self, o: 'mypy.nodes.StrExpr') -> T: - pass - - def visit_bytes_expr(self, o: 'mypy.nodes.BytesExpr') -> T: - pass - - def visit_unicode_expr(self, o: 'mypy.nodes.UnicodeExpr') -> T: - pass - - def visit_float_expr(self, o: 'mypy.nodes.FloatExpr') -> T: - pass - - def visit_paren_expr(self, o: 'mypy.nodes.ParenExpr') -> T: - pass - - def visit_name_expr(self, o: 'mypy.nodes.NameExpr') -> T: - pass - - def visit_member_expr(self, o: 'mypy.nodes.MemberExpr') -> T: - pass -<<<<<<< HEAD - def visit_yield_from_expr(self, o: 'mypy.nodes.YieldFromExpr') -> T: - pass -======= - ->>>>>>> master - def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T: - pass - - def visit_op_expr(self, o: 'mypy.nodes.OpExpr') -> T: - pass - - def visit_comparison_expr(self, o: 'mypy.nodes.ComparisonExpr') -> T: - pass - - def visit_cast_expr(self, o: 'mypy.nodes.CastExpr') -> T: - pass - - def visit_super_expr(self, o: 'mypy.nodes.SuperExpr') -> T: - pass - - def visit_unary_expr(self, o: 'mypy.nodes.UnaryExpr') -> T: - pass - - def visit_list_expr(self, o: 'mypy.nodes.ListExpr') -> T: - pass - - def visit_dict_expr(self, o: 'mypy.nodes.DictExpr') -> T: - pass - - def visit_tuple_expr(self, o: 'mypy.nodes.TupleExpr') -> T: - pass - - def visit_set_expr(self, o: 'mypy.nodes.SetExpr') -> T: - pass - - def visit_index_expr(self, o: 'mypy.nodes.IndexExpr') -> T: - pass - - def visit_undefined_expr(self, o: 'mypy.nodes.UndefinedExpr') -> T: - pass - - def visit_type_application(self, o: 'mypy.nodes.TypeApplication') -> T: - pass - - def visit_func_expr(self, o: 'mypy.nodes.FuncExpr') -> T: - pass - - def visit_list_comprehension(self, o: 'mypy.nodes.ListComprehension') -> T: - pass - - def visit_generator_expr(self, o: 'mypy.nodes.GeneratorExpr') -> T: - pass - - def visit_slice_expr(self, o: 'mypy.nodes.SliceExpr') -> T: - pass - - def visit_conditional_expr(self, o: 'mypy.nodes.ConditionalExpr') -> T: - pass - - def visit_type_var_expr(self, o: 'mypy.nodes.TypeVarExpr') -> T: - pass - - def visit_ducktype_expr(self, o: 'mypy.nodes.DucktypeExpr') -> T: - pass - - def visit_disjointclass_expr(self, o: 'mypy.nodes.DisjointclassExpr') -> T: - pass - - def visit_coerce_expr(self, o: 'mypy.nodes.CoerceExpr') -> T: - pass - - def visit_type_expr(self, o: 'mypy.nodes.TypeExpr') -> T: - pass - - def visit_java_cast(self, o: 'mypy.nodes.JavaCast') -> T: - pass - - def visit_temp_node(self, o: 'mypy.nodes.TempNode') -> T: - pass diff --git a/stubs/3.4/asyncio/futures.py.orig b/stubs/3.4/asyncio/futures.py.orig deleted file mode 100644 index 138f93feda35..000000000000 --- a/stubs/3.4/asyncio/futures.py.orig +++ /dev/null @@ -1,54 +0,0 @@ -from typing import Any, Function, typevar, List, Generic, Iterable, Iterator -from asyncio.events import AbstractEventLoop -# __all__ = ['CancelledError', 'TimeoutError', -# 'InvalidStateError', -# 'wrap_future', -# ] -__all__ = ['Future'] - -_T = typevar('_T') - -class _TracebackLogger: - __slots__ = [] # type: List[str] - exc = Any # Exception - tb = [] # type: List[str] - def __init__(self, exc: Any, loop: AbstractEventLoop) -> None: pass - def activate(self) -> None: pass - def clear(self) -> None: pass - def __del__(self) -> None: pass - -<<<<<<< HEAD -class Future(Iterator[T], Generic[T]): # (Iterable[T], Generic[T]) -======= -class Future(Generic[_T]): ->>>>>>> master - _state = '' - _exception = Any #Exception - _blocking = False - _log_traceback = False - _tb_logger = _TracebackLogger - def __init__(self, *, loop: AbstractEventLoop = None) -> None: pass - def __repr__(self) -> str: pass - def __del__(self) -> None: pass - def cancel(self) -> bool: pass - def _schedule_callbacks(self) -> None: pass - def cancelled(self) -> bool: pass - def done(self) -> bool: pass - def result(self) -> _T: pass - def exception(self) -> Any: pass -<<<<<<< HEAD - def add_done_callback(self, fn: Function[[Future[T]],Any]) -> None: pass - def remove_done_callback(self, fn: Function[[Future[T]], Any]) -> int: pass - def set_result(self, result: T) -> None: pass - def set_exception(self, exception: Any) -> None: pass - def _copy_state(self, other: Any) -> None: pass - def __iter__(self) -> 'Iterator[T]': pass - def __next__(self) -> 'T': pass -======= - def add_done_callback(self, fn: Function[[],Any]) -> None: pass - def remove_done_callback(self, fn: Function[[], Any]) -> int: pass - def set_result(self, result: _T) -> None: pass - def set_exception(self, exception: Any) -> None: pass - def _copy_state(self, other: Any) -> None: pass - def __iter__(self) -> Any: pass ->>>>>>> master diff --git a/stubs/3.4/asyncio/tasks.py.orig b/stubs/3.4/asyncio/tasks.py.orig deleted file mode 100644 index c32902f6c059..000000000000 --- a/stubs/3.4/asyncio/tasks.py.orig +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Any, Iterable, typevar, Set, Dict, List, TextIO, Union, Tuple, Generic, Function -from asyncio.events import AbstractEventLoop -from asyncio.futures import Future -# __all__ = ['iscoroutinefunction', 'iscoroutine', -# 'as_completed', 'async', -# 'gather', 'shield', -# ] - -__all__ = ['coroutine', 'Task', 'sleep', - 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', - 'wait', 'wait_for'] - -<<<<<<< HEAD -FIRST_EXCEPTION = 'FIRST_EXCEPTION' -FIRST_COMPLETED = 'FIRST_COMPLETED' -ALL_COMPLETED = 'ALL_COMPLETED' -T = typevar('T') -def coroutine(f: Any) -> Any: pass # Here comes and go a function -def sleep(delay: float, result: T=None, loop: AbstractEventLoop=None) -> Future[T]: pass -def wait(fs: List[Any], *, loop: AbstractEventLoop=None, - timeout: float=None, return_when: str=ALL_COMPLETED) -> Future[Tuple[Set[Future[T]], Set[Future[T]]]]: pass -def wait_for(fut: Future[T], timeout: float, *, loop: AbstractEventLoop=None) -> Future[T]: pass -# def wait(fs: Union[List[Iterable], List[Future[T]]], *, loop: AbstractEventLoop=None, -# timeout: int=None, return_when: str=ALL_COMPLETED) -> Future[Tuple[Set[Future[T]], Set[Future[T]]]]: pass - -class Task(Future[T], Generic[T]): - _all_tasks = None # type: Set[Task] - _current_tasks = {} # type: Dict[AbstractEventLoop, Task] - @classmethod - def current_task(cls, loop: AbstractEventLoop=None) -> Task: pass - @classmethod - def all_tasks(cls, loop: AbstractEventLoop=None) -> Set[Task]: pass - # def __init__(self, coro: Union[Iterable[T], Future[T]], *, loop: AbstractEventLoop=None) -> None: pass - def __init__(self, coro: Future[T], *, loop: AbstractEventLoop=None) -> None: pass - def __repr__(self) -> str: pass - def get_stack(self, *, limit: int=None) -> List[Any]: pass # return List[stackframe] - def print_stack(self, *, limit: int=None, file: TextIO=None) -> None: pass - def cancel(self) -> bool: pass - def _step(self, value: Any=None, exc: Exception=None) -> None: pass - def _wakeup(self, future: Future[Any]) -> None: pass - -======= -_T = typevar('_T') -def coroutine(f: Any) -> Any: pass -def sleep(delay: float, result: _T=None, loop: AbstractEventLoop=None) -> _T: pass ->>>>>>> master From 27abde5524c3c50de5f93d67f92f1ecd80d6eeff Mon Sep 17 00:00:00 2001 From: Rock Neurotiko Date: Mon, 15 Sep 2014 11:26:38 +0200 Subject: [PATCH 08/12] more merge --- mypy/checker.py | 315 ++++++++++++++++++++++------------------------ mypy/checkexpr.py | 309 +++++++++++++++++++++++++-------------------- mypy/messages.py | 139 ++++++++++---------- mypy/nodes.py | 28 ++--- 4 files changed, 409 insertions(+), 382 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 56b273319f04..e7e57d442cfa 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -15,8 +15,8 @@ BytesExpr, UnicodeExpr, FloatExpr, OpExpr, UnaryExpr, CastExpr, SuperExpr, TypeApplication, DictExpr, SliceExpr, FuncExpr, TempNode, SymbolTableNode, Context, ListComprehension, ConditionalExpr, GeneratorExpr, - Decorator, SetExpr, TypeVarExpr, UndefinedExpr, PrintStmt, - LITERAL_TYPE, BreakStmt, ContinueStmt, YieldFromExpr, YieldFromStmt + Decorator, SetExpr, PassStmt, TypeVarExpr, UndefinedExpr, PrintStmt, + LITERAL_TYPE, BreakStmt, ContinueStmt, ComparisonExpr, YieldFromExpr, YieldFromStmt ) from mypy.nodes import function_type, method_type from mypy import nodes @@ -180,7 +180,7 @@ def pop_frame(self, canskip=True, fallthrough=False) -> Tuple[bool, Frame]: """ result = self.frames.pop() - options = self.frames_on_escape.get(len(self.frames)-1, []) + options = self.frames_on_escape.get(len(self.frames) - 1, []) if canskip: options.append(self.frames[-1]) if fallthrough: @@ -251,20 +251,20 @@ def most_recent_enclosing_type(self, expr: Node, type: Type) -> Type: return self.get_declaration(expr) key = expr.literal_hash enclosers = ([self.get_declaration(expr)] + - [f[key] for f in self.frames - if key in f and is_subtype(type, f[key])]) + [f[key] for f in self.frames + if key in f and is_subtype(type, f[key])]) return enclosers[-1] def allow_jump(self, index: int) -> None: new_frame = Frame() - for f in self.frames[index+1:]: + for f in self.frames[index + 1:]: for k in f: new_frame[k] = f[k] self.frames_on_escape.setdefault(index, []).append(new_frame) def push_loop_frame(self): - self.loop_frames.append(len(self.frames)-1) + self.loop_frames.append(len(self.frames) - 1) def pop_loop_frame(self): self.loop_frames.pop() @@ -302,7 +302,7 @@ class TypeChecker(NodeVisitor[Type]): binder = Undefined(ConditionalTypeBinder) # Helper for type checking expressions expr_checker = Undefined('mypy.checkexpr.ExpressionChecker') - + # Stack of function return types return_types = Undefined(List[Type]) # Type context for type inference @@ -313,11 +313,11 @@ class TypeChecker(NodeVisitor[Type]): function_stack = Undefined(List[FuncItem]) # Set to True on return/break/raise, False on blocks that can block any of them breaking_out = False - + globals = Undefined(SymbolTable) locals = Undefined(SymbolTable) modules = Undefined(Dict[str, MypyFile]) - + def __init__(self, errors: Errors, modules: Dict[str, MypyFile], pyversion: int = 3) -> None: """Construct a type checker. @@ -338,16 +338,16 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile], self.type_context = [] self.dynamic_funcs = [] self.function_stack = [] - - def visit_file(self, file_node: MypyFile, path: str) -> None: + + def visit_file(self, file_node: MypyFile, path: str) -> None: """Type check a mypy file with the given path.""" self.errors.set_file(path) self.globals = file_node.names self.locals = None - + for d in file_node.defs: self.accept(d) - + def accept(self, node: Node, type_context: Type = None) -> Type: """Type check a node in the given type context.""" self.type_context.append(type_context) @@ -375,7 +375,7 @@ def accept_in_frame(self, node: Node, type_context: Type = None, # # Definitions # - + def visit_var_def(self, defn: VarDef) -> Type: """Type check a variable definition. @@ -409,11 +409,11 @@ def visit_var_def(self, defn: VarDef) -> Type: if (defn.kind == LDEF and not defn.items[0].type and not defn.is_top_level and not self.is_dynamic_function()): self.fail(messages.NEED_ANNOTATION_FOR_VAR, defn) - + def infer_local_variable_type(self, x, y, z): # TODO raise RuntimeError('Not implemented') - + def visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> Type: num_abstract = 0 for fdef in defn.items: @@ -436,7 +436,7 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: if is_unsafe_overlapping_signatures(sig1, sig2): self.msg.overloaded_signatures_overlap(i + 1, j + 2, item.func) - + def visit_func_def(self, defn: FuncDef) -> Type: """Type check a function definition.""" self.check_func_item(defn, name=defn.name()) @@ -447,7 +447,7 @@ def visit_func_def(self, defn: FuncDef) -> Type: if not is_same_type(function_type(defn), function_type(defn.original_def)): self.msg.incompatible_conditional_function_def(defn) - + def check_func_item(self, defn: FuncItem, type_override: Callable = None, name: str = None) -> Type: @@ -457,16 +457,16 @@ def check_func_item(self, defn: FuncItem, """ # We may be checking a function definition or an anonymous function. In # the first case, set up another reference with the precise type. - fdef = None # type: FuncDef + fdef = None # type: FuncDef if isinstance(defn, FuncDef): fdef = defn self.function_stack.append(defn) self.dynamic_funcs.append(defn.type is None and not type_override) - + if fdef: self.errors.push_function(fdef.name()) - + typ = function_type(defn) if type_override: typ = type_override @@ -474,13 +474,13 @@ def check_func_item(self, defn: FuncItem, self.check_func_def(defn, typ, name) else: raise RuntimeError('Not supported') - + if fdef: self.errors.pop_function() - + self.dynamic_funcs.pop() self.function_stack.pop() - + def check_func_def(self, defn: FuncItem, typ: Callable, name: str) -> None: """Type check a function definition.""" # Expand type variables with value restrictions to ordinary types. @@ -489,7 +489,7 @@ def check_func_def(self, defn: FuncItem, typ: Callable, name: str) -> None: self.binder = ConditionalTypeBinder(self.basic_types) self.binder.push_frame() defn.expanded.append(item) - + # We may be checking a function definition or an anonymous # function. In the first case, set up another reference with the # precise type. @@ -546,7 +546,7 @@ def check_func_def(self, defn: FuncItem, typ: Callable, name: str) -> None: def check_reverse_op_method(self, defn: FuncItem, typ: Callable, method: str) -> None: """Check a reverse operator method such as __radd__.""" - + # If the argument of a reverse operator method such as __radd__ # does not define the corresponding non-reverse method such as __add__ # the return type of __radd__ may not reliably represent the value of @@ -567,8 +567,8 @@ def check_reverse_op_method(self, defn: FuncItem, typ: Callable, if method in ('__eq__', '__ne__'): # These are defined for all objects => can't cause trouble. - return - + return + # With 'Any' or 'object' return type we are happy, since any possible # return value is valid. ret_type = typ.ret_type @@ -581,7 +581,7 @@ def check_reverse_op_method(self, defn: FuncItem, typ: Callable, # in an error elsewhere. if len(typ.arg_types) <= 2: # TODO check self argument kind - + # Check for the issue described above. arg_type = typ.arg_types[1] other_method = nodes.normal_from_reverse_op[method] @@ -663,7 +663,7 @@ def check_overlapping_op_methods(self, [None] * 2, forward_type.ret_type, is_type_obj=False, - name=forward_type.name) + name=forward_type.name) reverse_args = reverse_type.arg_types reverse_tweaked = Callable([reverse_args[1], reverse_args[0]], [nodes.ARG_POS] * 2, @@ -671,7 +671,7 @@ def check_overlapping_op_methods(self, reverse_type.ret_type, is_type_obj=False, name=reverse_type.name) - + if is_unsafe_overlapping_signatures(forward_tweaked, reverse_tweaked): self.msg.operator_method_signatures_overlap( @@ -733,13 +733,13 @@ def expand_typevars(self, defn: FuncItem, return result else: return [(defn, typ)] - + def check_method_override(self, defn: FuncBase) -> None: """Check if function definition is compatible with base classes.""" # Check against definitions in base classes. for base in defn.info.mro[1:]: self.check_method_or_accessor_override_for_base(defn, base) - + def check_method_or_accessor_override_for_base(self, defn: FuncBase, base: TypeInfo) -> None: """Check if method definition is compatible with a base class.""" @@ -789,7 +789,7 @@ def check_method_override_for_base_with_name( assert original_type is not None self.msg.signature_incompatible_with_supertype( defn.name(), name, base.name(), defn) - + def check_override(self, override: FunctionLike, original: FunctionLike, name: str, name_in_super: str, supertype: str, node: Context) -> None: @@ -805,9 +805,9 @@ def check_override(self, override: FunctionLike, original: FunctionLike, if (isinstance(override, Overloaded) or isinstance(original, Overloaded) or len(cast(Callable, override).arg_types) != - len(cast(Callable, original).arg_types) or + len(cast(Callable, original).arg_types) or cast(Callable, override).min_args != - cast(Callable, original).min_args): + cast(Callable, original).min_args): # Use boolean variable to clarify code. fail = False if not is_subtype(override, original): @@ -826,20 +826,20 @@ def check_override(self, override: FunctionLike, original: FunctionLike, # Give more detailed messages for the common case of both # signatures having the same number of arguments and no # overloads. - + coverride = cast(Callable, override) coriginal = cast(Callable, original) - + for i in range(len(coverride.arg_types)): - if not is_equivalent(coriginal.arg_types[i], - coverride.arg_types[i]): + if not is_subtype(coriginal.arg_types[i], + coverride.arg_types[i]): self.msg.argument_incompatible_with_supertype( i + 1, name, name_in_super, supertype, node) - + if not is_subtype(coverride.ret_type, coriginal.ret_type): self.msg.return_type_incompatible_with_supertype( name, name_in_super, supertype, node) - + def visit_class_def(self, defn: ClassDef) -> Type: """Type check a class definition.""" typ = defn.info @@ -902,11 +902,11 @@ def check_compatibility(self, name: str, base1: TypeInfo, if not ok: self.msg.base_class_definitions_incompatible(name, base1, base2, ctx) - + # # Statements # - + def visit_block(self, b: Block) -> Type: if b.is_unreachable: return None @@ -914,7 +914,7 @@ def visit_block(self, b: Block) -> Type: self.accept(s) if self.breaking_out: break - + def visit_assignment_stmt(self, s: AssignmentStmt) -> Type: """Type check an assignment statement. @@ -936,11 +936,11 @@ def check_assignments(self, lvalues: List[Node], # since we cannot typecheck them until we know the rvalue type. # For each lvalue, one of lvalue_types[i] or index_lvalues[i] is not # None. - lvalue_types = [] # type: List[Type] # Each may be None - index_lvalues = [] # type: List[IndexExpr] # Each may be None - inferred = [] # type: List[Var] + lvalue_types = [] # type: List[Type] # Each may be None + index_lvalues = [] # type: List[IndexExpr] # Each may be None + inferred = [] # type: List[Var] is_inferred = False - + for lv in lvalues: if self.is_definition(lv): is_inferred = True @@ -988,7 +988,7 @@ def check_assignments(self, lvalues: List[Node], if is_inferred: self.infer_variable_type(inferred, lvalues, self.accept(rvalue), rvalue) - + def is_definition(self, s: Node) -> bool: if isinstance(s, NameExpr): if s.is_def: @@ -1004,7 +1004,7 @@ def is_definition(self, s: Node) -> bool: elif isinstance(s, MemberExpr): return s.is_def return False - + def expand_lvalues(self, n: Node) -> List[Node]: if isinstance(n, TupleExpr): return self.expr_checker.unwrap_list(n.items) @@ -1014,7 +1014,7 @@ def expand_lvalues(self, n: Node) -> List[Node]: return self.expand_lvalues(n.expr) else: return [n] - + def infer_variable_type(self, names: List[Var], lvalues: List[Node], init_type: Type, context: Context) -> None: """Infer the type of initialized variables from initializer type.""" @@ -1026,10 +1026,10 @@ def infer_variable_type(self, names: List[Var], lvalues: List[Node], self.fail(messages.NEED_ANNOTATION_FOR_VAR, context) else: # Infer type of the target. - + # Make the type more general (strip away function names etc.). init_type = strip_type(init_type) - + if len(names) > 1: if isinstance(init_type, TupleType): # Initializer with a tuple type. @@ -1068,7 +1068,7 @@ def set_inferred_type(self, var: Var, lvalue: Node, type: Type) -> None: if var: var.type = type self.store_type(lvalue, type) - + def is_valid_inferred_type(self, typ: Type) -> bool: """Is an inferred type invalid? @@ -1106,7 +1106,7 @@ def check_multi_assignment(self, lvalue_types: List[Type], undefined_rvalue = True if not rvalue_type: # Infer the type of an ordinary rvalue expression. - rvalue_type = self.accept(rvalue) # TODO maybe elsewhere; redundant + rvalue_type = self.accept(rvalue) # TODO maybe elsewhere; redundant undefined_rvalue = False # Try to expand rvalue to lvalue(s). rvalue_types = None # type: List[Type] @@ -1114,7 +1114,7 @@ def check_multi_assignment(self, lvalue_types: List[Type], pass elif isinstance(rvalue_type, TupleType): # Rvalue with tuple type. - items = [] # type: List[Type] + items = [] # type: List[Type] for i in range(len(lvalue_types)): if lvalue_types[i]: items.append(lvalue_types[i]) @@ -1154,9 +1154,9 @@ def check_multi_assignment(self, lvalue_types: List[Type], return rvalue_types def check_single_assignment(self, - lvalue_type: Type, index_lvalue: IndexExpr, - rvalue: Node, context: Context, - msg: str = messages.INCOMPATIBLE_TYPES_IN_ASSIGNMENT) -> Type: + lvalue_type: Type, index_lvalue: IndexExpr, + rvalue: Node, context: Context, + msg: str = messages.INCOMPATIBLE_TYPES_IN_ASSIGNMENT) -> Type: """Type check an assignment. If lvalue_type is None, the index_lvalue argument must be the @@ -1175,7 +1175,7 @@ def check_single_assignment(self, return rvalue_type elif index_lvalue: self.check_indexed_assignment(index_lvalue, rvalue, context) - + def check_indexed_assignment(self, lvalue: IndexExpr, rvalue: Node, context: Context) -> None: """Type check indexed assignment base[index] = rvalue. @@ -1189,10 +1189,10 @@ def check_indexed_assignment(self, lvalue: IndexExpr, self.expr_checker.check_call(method_type, [lvalue.index, rvalue], [nodes.ARG_POS, nodes.ARG_POS], context) - + def visit_expression_stmt(self, s: ExpressionStmt) -> Type: self.accept(s.expr) - + def visit_return_stmt(self, s: ReturnStmt) -> Type: """Type check a return statement.""" self.breaking_out = True @@ -1204,7 +1204,9 @@ def visit_return_stmt(self, s: ReturnStmt) -> Type: if not isinstance(typ, AnyType): if isinstance(self.return_types[-1], Void): # FuncExpr (lambda) may have a Void return. - if not isinstance(self.function_stack[-1], FuncExpr): + # Function returning a value of type None may have a Void return. + if (not isinstance(self.function_stack[-1], FuncExpr) and + not isinstance(typ, NoneTyp)): self.fail(messages.NO_RETURN_VALUE_EXPECTED, s) else: if self.function_stack[-1].is_coroutine: # Something similar will be needed to mix return and yield @@ -1213,34 +1215,14 @@ def visit_return_stmt(self, s: ReturnStmt) -> Type: self.check_subtype( typ, self.return_types[-1], s, messages.INCOMPATIBLE_RETURN_VALUE_TYPE - + ": expected {}, got {}".format(self.return_types[-1], typ) - ) + + ": expected {}, got {}".format(self.return_types[-1], typ) + ) else: # Return without a value. It's valid in a generator and coroutine function. if not self.function_stack[-1].is_generator and not self.function_stack[-1].is_coroutine: if (not isinstance(self.return_types[-1], Void) and - not self.is_dynamic_function()): - self.fail(messages.RETURN_VALUE_EXPECTED, s) - - def wrap_generic_type(self, typ: Type, rtyp: Type, check_type: str, context: Context) -> Type: - n_diff = self.count_concatenated_types(rtyp, check_type) - self.count_concatenated_types(typ, check_type) - if n_diff >= 1: - return self.named_generic_type(check_type, [typ]) - elif n_diff == 0: - self.fail(messages.INCOMPATIBLE_RETURN_VALUE_TYPE - + ": expected {}, got {}".format(rtyp, typ), context) - return typ - return typ - - def count_concatenated_types(self, typ: Type, check_type: str) -> int: - c = 0 - while is_subtype(typ, self.named_type(check_type)): - c += 1 - if hasattr(typ, 'args') and typ.args: - typ = typ.args[0] - else: - return c - return c + not self.is_dynamic_function()): + self.fail(messages.RETURN_VALUE_EXPECTED, s) def visit_yield_stmt(self, s: YieldStmt) -> Type: return_type = self.return_types[-1] @@ -1255,7 +1237,7 @@ def visit_yield_stmt(self, s: YieldStmt) -> Type: self.fail(messages.INVALID_RETURN_TYPE_FOR_YIELD, s) return None if s.expr is None: - actual_item_type = Void() # type: Type + actual_item_type = Void() # type: Type else: actual_item_type = self.accept(s.expr, expected_item_type) self.check_subtype(actual_item_type, expected_item_type, s, @@ -1330,7 +1312,7 @@ def visit_if_stmt(self, s: IfStmt) -> Type: self.binder.push(var, type) self.accept(b) _, frame = self.binder.pop_frame() - self.binder.allow_jump(len(self.binder.frames)-1) + self.binder.allow_jump(len(self.binder.frames) - 1) if not self.breaking_out: broken = False ending_frames.append(meet_frames(self.basic_types(), clauses_frame, frame)) @@ -1386,17 +1368,17 @@ def visit_operator_assignment_stmt(self, method = infer_operator_assignment_method(lvalue_type, s.op) rvalue_type, method_type = self.expr_checker.check_op( method, lvalue_type, s.rvalue, s) - + if isinstance(s.lvalue, IndexExpr): lv = cast(IndexExpr, s.lvalue) self.check_single_assignment(None, lv, s.rvalue, s.rvalue) else: if not is_subtype(rvalue_type, lvalue_type): self.msg.incompatible_operator_assignment(s.op, s) - + def visit_assert_stmt(self, s: AssertStmt) -> Type: self.accept(s.expr) - + def visit_raise_stmt(self, s: RaiseStmt) -> Type: """Type check a raise statement.""" self.breaking_out = True @@ -1414,14 +1396,14 @@ def visit_raise_stmt(self, s: RaiseStmt) -> Type: self.check_subtype(typ, self.named_type('builtins.BaseException'), s, messages.INVALID_EXCEPTION) - + def visit_try_stmt(self, s: TryStmt) -> Type: """Type check a try statement.""" completed_frames = List[Frame]() self.binder.push_frame() - self.binder.try_frames.add(len(self.binder.frames)-2) + self.binder.try_frames.add(len(self.binder.frames) - 2) self.accept(s.body) - self.binder.try_frames.remove(len(self.binder.frames)-2) + self.binder.try_frames.remove(len(self.binder.frames) - 2) if s.else_body: self.accept(s.else_body) changed, frame_on_completion = self.binder.pop_frame() @@ -1453,7 +1435,7 @@ def exception_type(self, n: Node) -> Type: # Multiple exception types (...). unwrapped = self.expr_checker.unwrap(n) if isinstance(unwrapped, TupleExpr): - t = None # type: Type + t = None # type: Type for item in unwrapped.items: tt = self.exception_type(item) if t: @@ -1504,10 +1486,10 @@ def visit_for_stmt(self, s: ForStmt) -> Type: def analyse_iterable_item_type(self, expr: Node) -> Type: """Analyse iterable expression and return iterator item type.""" iterable = self.accept(expr) - + self.check_not_void(iterable, expr) if isinstance(iterable, TupleType): - joined = NoneTyp() # type: Type + joined = NoneTyp() # type: Type for item in iterable.items: joined = join_types(joined, item, self.basic_types()) if isinstance(joined, ErrorType): @@ -1540,7 +1522,7 @@ def analyse_index_variables(self, index: List[NameExpr], if not is_annotated: # Create a temporary copy of variables with Node item type. # TODO this is ugly - node_index = [] # type: List[Node] + node_index = [] # type: List[Node] for i in index: node_index.append(i) self.check_assignments(node_index, @@ -1549,10 +1531,10 @@ def analyse_index_variables(self, index: List[NameExpr], v = cast(Var, index[0].node) if v.type: self.check_single_assignment(v.type, None, - self.temp_node(item_type), context, - messages.INCOMPATIBLE_TYPES_IN_FOR) + self.temp_node(item_type), context, + messages.INCOMPATIBLE_TYPES_IN_FOR) else: - t = [] # type: List[Type] + t = [] # type: List[Type] for ii in index: v = cast(Var, ii.node) if v.type: @@ -1562,7 +1544,7 @@ def analyse_index_variables(self, index: List[NameExpr], self.check_multi_assignment(t, [None] * len(index), self.temp_node(item_type), context, messages.INCOMPATIBLE_TYPES_IN_FOR) - + def visit_del_stmt(self, s: DelStmt) -> Type: if isinstance(s.expr, IndexExpr): e = cast(IndexExpr, s.expr) # Cast @@ -1574,10 +1556,10 @@ def visit_del_stmt(self, s: DelStmt) -> Type: else: s.expr.accept(self) return None - + def visit_decorator(self, e: Decorator) -> Type: e.func.accept(self) - sig = function_type(e.func) # type: Type + sig = function_type(e.func) # type: Type # Process decorators from the inside out. for i in range(len(e.decorators)): n = len(e.decorators) - 1 - i @@ -1604,18 +1586,18 @@ def visit_with_stmt(self, s: WithStmt) -> Type: def visit_print_stmt(self, s: PrintStmt) -> Type: for arg in s.args: - self.accept(arg) - + self.accept(arg) + # # Expressions # - + def visit_name_expr(self, e: NameExpr) -> Type: return self.expr_checker.visit_name_expr(e) - + def visit_paren_expr(self, e: ParenExpr) -> Type: return self.expr_checker.visit_paren_expr(e) - + def visit_call_expr(self, e: CallExpr) -> Type: result = self.expr_checker.visit_call_expr(e) self.breaking_out = False @@ -1638,10 +1620,10 @@ def visit_yield_from_expr(self, e: YieldFromExpr) -> Type: def visit_member_expr(self, e: MemberExpr) -> Type: return self.expr_checker.visit_member_expr(e) - + def visit_break_stmt(self, s: BreakStmt) -> Type: self.breaking_out = True - self.binder.allow_jump(self.binder.loop_frames[-1]-1) + self.binder.allow_jump(self.binder.loop_frames[-1] - 1) return None def visit_continue_stmt(self, s: ContinueStmt) -> Type: @@ -1651,59 +1633,62 @@ def visit_continue_stmt(self, s: ContinueStmt) -> Type: def visit_int_expr(self, e: IntExpr) -> Type: return self.expr_checker.visit_int_expr(e) - + def visit_str_expr(self, e: StrExpr) -> Type: return self.expr_checker.visit_str_expr(e) - + def visit_bytes_expr(self, e: BytesExpr) -> Type: return self.expr_checker.visit_bytes_expr(e) - + def visit_unicode_expr(self, e: UnicodeExpr) -> Type: return self.expr_checker.visit_unicode_expr(e) - + def visit_float_expr(self, e: FloatExpr) -> Type: return self.expr_checker.visit_float_expr(e) - + def visit_op_expr(self, e: OpExpr) -> Type: return self.expr_checker.visit_op_expr(e) - + + def visit_comparison_expr(self, e: ComparisonExpr) -> Type: + return self.expr_checker.visit_comparison_expr(e) + def visit_unary_expr(self, e: UnaryExpr) -> Type: return self.expr_checker.visit_unary_expr(e) - + def visit_index_expr(self, e: IndexExpr) -> Type: return self.expr_checker.visit_index_expr(e) - + def visit_cast_expr(self, e: CastExpr) -> Type: return self.expr_checker.visit_cast_expr(e) - + def visit_super_expr(self, e: SuperExpr) -> Type: return self.expr_checker.visit_super_expr(e) - + def visit_type_application(self, e: TypeApplication) -> Type: return self.expr_checker.visit_type_application(e) def visit_type_var_expr(self, e: TypeVarExpr) -> Type: # TODO Perhaps return a special type used for type variables only? return AnyType() - + def visit_list_expr(self, e: ListExpr) -> Type: return self.expr_checker.visit_list_expr(e) - + def visit_set_expr(self, e: SetExpr) -> Type: return self.expr_checker.visit_set_expr(e) - + def visit_tuple_expr(self, e: TupleExpr) -> Type: return self.expr_checker.visit_tuple_expr(e) - + def visit_dict_expr(self, e: DictExpr) -> Type: return self.expr_checker.visit_dict_expr(e) - + def visit_slice_expr(self, e: SliceExpr) -> Type: return self.expr_checker.visit_slice_expr(e) - + def visit_func_expr(self, e: FuncExpr) -> Type: return self.expr_checker.visit_func_expr(e) - + def visit_list_comprehension(self, e: ListComprehension) -> Type: return self.expr_checker.visit_list_comprehension(e) @@ -1718,22 +1703,22 @@ def visit_temp_node(self, e: TempNode) -> Type: def visit_conditional_expr(self, e: ConditionalExpr) -> Type: return self.expr_checker.visit_conditional_expr(e) - + # # Helpers # - + def check_subtype(self, subtype: Type, supertype: Type, context: Context, - msg: str = messages.INCOMPATIBLE_TYPES, - subtype_label: str = None, - supertype_label: str = None) -> None: + msg: str = messages.INCOMPATIBLE_TYPES, + subtype_label: str = None, + supertype_label: str = None) -> None: """Generate an error if the subtype is not compatible with supertype.""" if not is_subtype(subtype, supertype): if isinstance(subtype, Void): self.msg.does_not_return_value(subtype, context) else: - extra_info = [] # type: List[str] + extra_info = [] # type: List[str] if subtype_label is not None: extra_info.append(subtype_label + ' ' + self.msg.format_simple(subtype)) if supertype_label is not None: @@ -1741,7 +1726,7 @@ def check_subtype(self, subtype: Type, supertype: Type, context: Context, if extra_info: msg += ' (' + ', '.join(extra_info) + ')' self.fail(msg, context) - + def named_type(self, name: str) -> Instance: """Return an instance type with type given by the name and no type arguments. For example, named_type('builtins.object') @@ -1750,11 +1735,11 @@ def named_type(self, name: str) -> Instance: # Assume that the name refers to a type. sym = self.lookup_qualified(name) return Instance(cast(TypeInfo, sym.node), []) - + def named_type_if_exists(self, name: str) -> Type: """Return named instance type, or UnboundType if the type was not defined. - + This is used to simplify test cases by avoiding the need to define basic types not needed in specific test cases (tuple etc.). @@ -1765,7 +1750,7 @@ def named_type_if_exists(self, name: str) -> Type: return Instance(cast(TypeInfo, sym.node), []) except KeyError: return UnboundType(name) - + def named_generic_type(self, name: str, args: List[Type]) -> Instance: """Return an instance with the given name and type arguments. @@ -1778,30 +1763,30 @@ def lookup_typeinfo(self, fullname: str) -> TypeInfo: # Assume that the name refers to a class. sym = self.lookup_qualified(fullname) return cast(TypeInfo, sym.node) - + def type_type(self) -> Instance: """Return instance type 'type'.""" return self.named_type('builtins.type') - + def object_type(self) -> Instance: """Return instance type 'object'.""" return self.named_type('builtins.object') - + def bool_type(self) -> Instance: """Return instance type 'bool'.""" return self.named_type('builtins.bool') - + def str_type(self) -> Instance: """Return instance type 'str'.""" return self.named_type('builtins.str') - + def tuple_type(self) -> Type: """Return instance type 'tuple'.""" # We need the tuple for analysing member access. We want to be able to # do this even if tuple type is not available (useful in test cases), # so we return an unbound type if there is no tuple type. return self.named_type_if_exists('builtins.tuple') - + def check_type_equivalency(self, t1: Type, t2: Type, node: Context, msg: str = messages.INCOMPATIBLE_TYPES) -> None: """Generate an error if the types are not equivalent. The @@ -1809,14 +1794,14 @@ def check_type_equivalency(self, t1: Type, t2: Type, node: Context, """ if not is_equivalent(t1, t2): self.fail(msg, node) - + def store_type(self, node: Node, typ: Type) -> None: """Store the type of a node in the type map.""" self.type_map[node] = typ - + def is_dynamic_function(self) -> bool: return len(self.dynamic_funcs) > 0 and self.dynamic_funcs[-1] - + def lookup(self, name: str, kind: int) -> SymbolTableNode: """Look up a definition from the symbol table with the given name. TODO remove kind argument @@ -1832,23 +1817,23 @@ def lookup(self, name: str, kind: int) -> SymbolTableNode: if name in table: return table[name] raise KeyError('Failed lookup: {}'.format(name)) - + def lookup_qualified(self, name: str) -> SymbolTableNode: if '.' not in name: - return self.lookup(name, GDEF) # FIX kind + return self.lookup(name, GDEF) # FIX kind else: parts = name.split('.') n = self.modules[parts[0]] for i in range(1, len(parts) - 1): n = cast(MypyFile, ((n.names.get(parts[i], None).node))) return n.names[parts[-1]] - + def enter(self) -> None: self.locals = SymbolTable() - + def leave(self) -> None: self.locals = None - + def basic_types(self) -> BasicTypes: """Return a BasicTypes instance that contains primitive types that are needed for certain type operations (joins, for example). @@ -1856,26 +1841,26 @@ def basic_types(self) -> BasicTypes: return BasicTypes(self.object_type(), self.named_type('builtins.type'), self.named_type_if_exists('builtins.tuple'), self.named_type_if_exists('builtins.function')) - + def is_within_function(self) -> bool: """Are we currently type checking within a function? I.e. not at class body or at the top level. """ return self.return_types != [] - + def check_not_void(self, typ: Type, context: Context) -> None: """Generate an error if the type is Void.""" if isinstance(typ, Void): self.msg.does_not_return_value(typ, context) - + def temp_node(self, t: Type, context: Context = None) -> Node: """Create a temporary node with the given, fixed type.""" temp = TempNode(t) if context: temp.set_line(context.get_line()) return temp - + def fail(self, msg: str, context: Context) -> None: """Produce an error message.""" self.msg.fail(msg, context) @@ -1892,12 +1877,12 @@ def map_type_from_supertype(typ: Type, sub_info: TypeInfo, """Map type variables in a type defined in a supertype context to be valid in the subtype context. Assume that the result is unique; if more than one type is possible, return one of the alternatives. - + For example, assume - + class D(Generic[S]) ... class C(D[E[T]], Generic[T]) ... - + Now S in the context of D would be mapped to E[T] in the context of C. """ # Create the type of self in subtype, of form t[a1, ...]. @@ -1991,7 +1976,7 @@ class TypeTransformVisitor(TransformVisitor): def __init__(self, map: Dict[int, Type]) -> None: super().__init__() self.map = map - + def type(self, type: Type) -> Type: return expand_type(type, self.map) @@ -2034,10 +2019,10 @@ def is_unsafe_overlapping_signatures(signature: Type, other: Type) -> bool: if is_same_type(signature.ret_type, other.ret_type): return False # If the first signature has more general argument types, the - # latter will never be called + # latter will never be called if is_more_general_arg_prefix(signature, other): return False - return not is_more_precise_signature(signature, other) + return not is_more_precise_signature(signature, other) return True diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index d20e3126dcc4..f4d15c6ae6c7 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -12,7 +12,7 @@ OpExpr, UnaryExpr, IndexExpr, CastExpr, TypeApplication, ListExpr, TupleExpr, DictExpr, FuncExpr, SuperExpr, ParenExpr, SliceExpr, Context, ListComprehension, GeneratorExpr, SetExpr, MypyFile, Decorator, - UndefinedExpr, ConditionalExpr, TempNode, LITERAL_TYPE, YieldFromExpr + UndefinedExpr, ConditionalExpr, ComparisonExpr, TempNode, LITERAL_TYPE, YieldFromExpr ) from mypy.errors import Errors from mypy.nodes import function_type, method_type @@ -38,19 +38,19 @@ class ExpressionChecker: This class works closely together with checker.TypeChecker. """ - + # Some services are provided by a TypeChecker instance. chk = Undefined('mypy.checker.TypeChecker') # This is shared with TypeChecker, but stored also here for convenience. msg = Undefined(MessageBuilder) - + def __init__(self, - chk: 'mypy.checker.TypeChecker', - msg: MessageBuilder) -> None: + chk: 'mypy.checker.TypeChecker', + msg: MessageBuilder) -> None: """Construct an expression type checker.""" self.chk = chk self.msg = msg - + def visit_name_expr(self, e: NameExpr) -> Type: """Type check a name expression. @@ -58,7 +58,7 @@ def visit_name_expr(self, e: NameExpr) -> Type: """ result = self.analyse_ref_expr(e) return self.chk.narrow_type_from_binder(e, result) - + def analyse_ref_expr(self, e: RefExpr) -> Type: result = Undefined(Type) node = e.node @@ -112,7 +112,7 @@ def visit_call_expr(self, e: CallExpr) -> Type: # way we get a more precise callee in dynamically typed functions. callee_type = self.chk.type_map[e.callee] return self.check_call_expr_with_callee_type(callee_type, e) - + def check_call_expr_with_callee_type(self, callee_type: Type, e: CallExpr) -> Type: """Type check call expression. @@ -122,7 +122,7 @@ def check_call_expr_with_callee_type(self, callee_type: Type, """ return self.check_call(callee_type, e.args, e.arg_kinds, e, e.arg_names, callable_node=e.callee)[0] - + def check_call(self, callee: Type, args: List[Node], arg_kinds: List[int], context: Context, arg_names: List[str] = None, @@ -154,24 +154,24 @@ def check_call(self, callee: Type, args: List[Node], self.msg.cannot_instantiate_abstract_class( callee.type_object().name(), type.abstract_attributes, context) - + formal_to_actual = map_actuals_to_formals( arg_kinds, arg_names, callee.arg_kinds, callee.arg_names, lambda i: self.accept(args[i])) - + if callee.is_generic(): callee = self.infer_function_type_arguments_using_context( callee, context) callee = self.infer_function_type_arguments( callee, args, arg_kinds, formal_to_actual, context) - + arg_types = self.infer_arg_types_in_context2( callee, args, arg_kinds, formal_to_actual) self.check_argument_count(callee, arg_types, arg_kinds, arg_names, formal_to_actual, context) - + self.check_argument_types(arg_types, arg_kinds, callee, formal_to_actual, context, messages=arg_messages) @@ -186,7 +186,7 @@ def check_call(self, callee: Type, args: List[Node], self.msg.disable_errors() arg_types = self.infer_arg_types_in_context(None, args) self.msg.enable_errors() - + target = self.overload_call_target(arg_types, is_var_arg, callee, context, messages=arg_messages) @@ -198,14 +198,14 @@ def check_call(self, callee: Type, args: List[Node], elif isinstance(callee, UnionType): self.msg.disable_type_names += 1 results = [self.check_call(subtype, args, arg_kinds, context, arg_names, - arg_messages=arg_messages) + arg_messages=arg_messages) for subtype in callee.items] self.msg.disable_type_names -= 1 return (UnionType.make_simplified_union([res[0] for res in results]), callee) else: return self.msg.not_callable(callee, context), AnyType() - + def infer_arg_types_in_context(self, callee: Callable, args: List[Node]) -> List[Type]: """Infer argument expression types using a callable type as context. @@ -214,14 +214,14 @@ def infer_arg_types_in_context(self, callee: Callable, argument expression with List[int] type context. """ # TODO Always called with callee as None, i.e. empty context. - res = [] # type: List[Type] - + res = [] # type: List[Type] + fixed = len(args) if callee: fixed = min(fixed, callee.max_fixed_args()) - arg_type = None # type: Type - ctx = None # type: Type + arg_type = None # type: Type + ctx = None # type: Type for i, arg in enumerate(args): if i < fixed: if callee and i < len(callee.arg_types): @@ -237,7 +237,7 @@ def infer_arg_types_in_context(self, callee: Callable, else: res.append(arg_type) return res - + def infer_arg_types_in_context2( self, callee: Callable, args: List[Node], arg_kinds: List[int], formal_to_actual: List[List[int]]) -> List[Type]: @@ -248,7 +248,7 @@ def infer_arg_types_in_context2( Returns the inferred types of *actual arguments*. """ - res = [None] * len(args) # type: List[Type] + res = [None] * len(args) # type: List[Type] for i, actuals in enumerate(formal_to_actual): for ai in actuals: @@ -260,7 +260,7 @@ def infer_arg_types_in_context2( if not t: res[i] = self.accept(args[i]) return res - + def infer_function_type_arguments_using_context( self, callable: Callable, error_context: Context) -> Callable: """Unify callable return type to type context to infer type vars. @@ -275,12 +275,20 @@ def infer_function_type_arguments_using_context( # The return type may have references to function type variables that # we are inferring right now. We must consider them as indeterminate # and they are not potential results; thus we replace them with the - # None type. On the other hand, class type variables are valid results. + # special ErasedType type. On the other hand, class type variables are + # valid results. erased_ctx = replace_func_type_vars(ctx, ErasedType()) - args = infer_type_arguments(callable.type_var_ids(), callable.ret_type, + ret_type = callable.ret_type + if isinstance(ret_type, TypeVar): + if ret_type.values: + # The return type is a type variable with values, but we can't easily restrict + # type inference to conform to the valid values. Give up and just use function + # arguments for type inference. + ret_type = NoneTyp() + args = infer_type_arguments(callable.type_var_ids(), ret_type, erased_ctx, self.chk.basic_types()) # Only substite non-None and non-erased types. - new_args = [] # type: List[Type] + new_args = [] # type: List[Type] for arg in args: if isinstance(arg, NoneTyp) or has_erased_component(arg): new_args.append(None) @@ -288,7 +296,7 @@ def infer_function_type_arguments_using_context( new_args.append(arg) return cast(Callable, self.apply_generic_arguments(callable, new_args, error_context)) - + def infer_function_type_arguments(self, callee_type: Callable, args: List[Node], arg_kinds: List[int], @@ -307,25 +315,25 @@ def infer_function_type_arguments(self, callee_type: Callable, # these errors can be safely ignored as the arguments will be # inferred again later. self.msg.disable_errors() - + arg_types = self.infer_arg_types_in_context2( callee_type, args, arg_kinds, formal_to_actual) - + self.msg.enable_errors() arg_pass_nums = self.get_arg_infer_passes( callee_type.arg_types, formal_to_actual, len(args)) - pass1_args = [] # type: List[Type] + pass1_args = [] # type: List[Type] for i, arg in enumerate(arg_types): if arg_pass_nums[i] > 1: pass1_args.append(None) else: pass1_args.append(arg) - + inferred_args = infer_function_type_arguments( callee_type, pass1_args, arg_kinds, formal_to_actual, - self.chk.basic_types()) # type: List[Type] + self.chk.basic_types()) # type: List[Type] if 2 in arg_pass_nums: # Second pass of type inference. @@ -341,12 +349,12 @@ def infer_function_type_arguments(self, callee_type: Callable, context) def infer_function_type_arguments_pass2( - self, callee_type: Callable, - args: List[Node], - arg_kinds: List[int], - formal_to_actual: List[List[int]], - inferred_args: List[Type], - context: Context) -> Tuple[Callable, List[Type]]: + self, callee_type: Callable, + args: List[Node], + arg_kinds: List[int], + formal_to_actual: List[List[int]], + inferred_args: List[Type], + context: Context) -> Tuple[Callable, List[Type]]: """Perform second pass of generic function type argument inference. The second pass is needed for arguments with types such as func, @@ -393,7 +401,7 @@ def get_arg_infer_passes(self, arg_types: List[Type], for j in formal_to_actual[i]: res[j] = 2 return res - + def apply_inferred_arguments(self, callee_type: Callable, inferred_args: List[Type], context: Context) -> Callable: @@ -414,10 +422,10 @@ def apply_inferred_arguments(self, callee_type: Callable, # return type must be Callable, since we give the right number of type # arguments. return cast(Callable, self.apply_generic_arguments(callee_type, - inferred_args, context)) + inferred_args, context)) def check_argument_count(self, callee: Callable, actual_types: List[Type], - actual_kinds: List[int], actual_names: List[str], + actual_kinds: List[int], actual_names: List[str], formal_to_actual: List[List[int]], context: Context) -> None: """Check that the number of arguments to a function are valid. @@ -427,11 +435,11 @@ def check_argument_count(self, callee: Callable, actual_types: List[Type], formal_kinds = callee.arg_kinds # Collect list of all actual arguments matched to formal arguments. - all_actuals = [] # type: List[int] + all_actuals = [] # type: List[int] for actuals in formal_to_actual: all_actuals.extend(actuals) - is_error = False # Keep track of errors to avoid duplicate errors. + is_error = False # Keep track of errors to avoid duplicate errors. for i, kind in enumerate(actual_kinds): if i not in all_actuals and ( kind != nodes.ARG_STAR or @@ -460,14 +468,13 @@ def check_argument_count(self, callee: Callable, actual_types: List[Type], self.msg.too_few_arguments(callee, context) elif kind in [nodes.ARG_POS, nodes.ARG_OPT, nodes.ARG_NAMED] and is_duplicate_mapping( - formal_to_actual[i], - actual_kinds): + formal_to_actual[i], actual_kinds): self.msg.duplicate_argument_value(callee, i, context) elif (kind == nodes.ARG_NAMED and formal_to_actual[i] and actual_kinds[formal_to_actual[i][0]] != nodes.ARG_NAMED): # Positional argument when expecting a keyword argument. self.msg.too_many_positional_arguments(callee, context) - + def check_argument_types(self, arg_types: List[Type], arg_kinds: List[int], callee: Callable, formal_to_actual: List[List[int]], @@ -497,7 +504,7 @@ def check_argument_types(self, arg_types: List[Type], arg_kinds: List[int], self.check_arg(actual_type, arg_type, callee.arg_types[i], actual + 1, callee, context, messages) - + # There may be some remaining tuple varargs items that haven't # been checked yet. Handle them. if (callee.arg_kinds[i] == nodes.ARG_STAR and @@ -511,7 +518,7 @@ def check_argument_types(self, arg_types: List[Type], arg_kinds: List[int], self.check_arg(actual_type, arg_type, callee.arg_types[i], actual + 1, callee, context, messages) - + def check_arg(self, caller_type: Type, original_caller_type: Type, callee_type: Type, n: int, callee: Callable, context: Context, messages: MessageBuilder) -> None: @@ -521,7 +528,7 @@ def check_arg(self, caller_type: Type, original_caller_type: Type, elif not is_subtype(caller_type, callee_type): messages.incompatible_argument(n, callee, original_caller_type, context) - + def overload_call_target(self, arg_types: List[Type], is_var_arg: bool, overload: Overloaded, context: Context, messages: MessageBuilder = None) -> Type: @@ -535,13 +542,13 @@ def overload_call_target(self, arg_types: List[Type], is_var_arg: bool, # TODO also consider argument names and kinds # TODO for overlapping signatures we should try to get a more precise # result than 'Any' - match = [] # type: List[Callable] + match = [] # type: List[Callable] for typ in overload.items(): if self.matches_signature_erased(arg_types, is_var_arg, typ): if (match and not is_same_type(match[-1].ret_type, typ.ret_type) and not mypy.checker.is_more_precise_signature( - match[-1], typ)): + match[-1], typ)): # Ambiguous return type. Either the function overload is # overlapping (which results in an error elsewhere) or the # caller has provided some Any argument types; in @@ -568,7 +575,7 @@ def overload_call_target(self, arg_types: List[Type], is_var_arg: bool, if self.match_signature_types(arg_types, is_var_arg, m): return m return match[0] - + def matches_signature_erased(self, arg_types: List[Type], is_var_arg: bool, callee: Callable) -> bool: """Determine whether arguments could match the signature at runtime. @@ -578,7 +585,7 @@ def matches_signature_erased(self, arg_types: List[Type], is_var_arg: bool, """ if not is_valid_argc(len(arg_types), False, callee): return False - + if is_var_arg: if not self.is_valid_var_arg(arg_types[-1]): return False @@ -599,7 +606,7 @@ def matches_signature_erased(self, arg_types: List[Type], is_var_arg: bool, self.erase(callee.arg_types[func_fixed])): return False return True - + def match_signature_types(self, arg_types: List[Type], is_var_arg: bool, callee: Callable) -> bool: """Determine whether arguments types match the signature. @@ -623,7 +630,7 @@ def match_signature_types(self, arg_types: List[Type], is_var_arg: bool, callee.arg_types[func_fixed]): return False return True - + def apply_generic_arguments(self, callable: Callable, types: List[Type], context: Context) -> Type: """Apply generic type arguments to a callable type. @@ -631,7 +638,7 @@ def apply_generic_arguments(self, callable: Callable, types: List[Type], For example, applying [int] to 'def [T] (T) -> T' results in 'def [-1:int] (int) -> int'. Here '[-1:int]' is an implicit bound type variable. - + Note that each type can be None; in this case, it will not be applied. """ tvars = callable.variables @@ -639,7 +646,7 @@ def apply_generic_arguments(self, callable: Callable, types: List[Type], self.msg.incompatible_type_application(len(tvars), len(types), context) return AnyType() - + # Check that inferred type variable values are compatible with allowed # values. Also, promote subtype values to allowed values. types = types[:] @@ -657,21 +664,21 @@ def apply_generic_arguments(self, callable: Callable, types: List[Type], callable, i + 1, type, context) # Create a map from type variable id to target type. - id_to_type = {} # type: Dict[int, Type] + id_to_type = {} # type: Dict[int, Type] for i, tv in enumerate(tvars): if types[i]: id_to_type[tv.id] = types[i] # Apply arguments to argument types. arg_types = [expand_type(at, id_to_type) for at in callable.arg_types] - + bound_vars = [(tv.id, id_to_type[tv.id]) for tv in tvars if tv.id in id_to_type] # The callable may retain some type vars if only some were applied. remaining_tvars = [tv for tv in tvars if tv.id not in id_to_type] - + return Callable(arg_types, callable.arg_kinds, callable.arg_names, @@ -681,10 +688,10 @@ def apply_generic_arguments(self, callable: Callable, types: List[Type], remaining_tvars, callable.bound_vars + bound_vars, callable.line, callable.repr) - + def apply_generic_arguments2(self, overload: Overloaded, types: List[Type], - context: Context) -> Type: - items = [] # type: List[Callable] + context: Context) -> Type: + items = [] # type: List[Callable] for item in overload.items(): applied = self.apply_generic_arguments(item, types, context) if isinstance(applied, Callable): @@ -693,7 +700,7 @@ def apply_generic_arguments2(self, overload: Overloaded, types: List[Type], # There was an error. return AnyType() return Overloaded(items) - + def visit_member_expr(self, e: MemberExpr) -> Type: """Visit member expression (of form e.id).""" result = self.analyse_ordinary_member_access(e, False) @@ -710,7 +717,7 @@ def analyse_ordinary_member_access(self, e: MemberExpr, return analyse_member_access(e.name, self.accept(e.expr), e, is_lvalue, False, self.chk.basic_types(), self.msg) - + def analyse_external_member_access(self, member: str, base_type: Type, context: Context) -> Type: """Analyse member access that is external, i.e. it cannot @@ -719,27 +726,27 @@ def analyse_external_member_access(self, member: str, base_type: Type, # TODO remove; no private definitions in mypy return analyse_member_access(member, base_type, context, False, False, self.chk.basic_types(), self.msg) - + def visit_int_expr(self, e: IntExpr) -> Type: """Type check an integer literal (trivial).""" return self.named_type('builtins.int') - + def visit_str_expr(self, e: StrExpr) -> Type: """Type check a string literal (trivial).""" return self.named_type('builtins.str') - + def visit_bytes_expr(self, e: BytesExpr) -> Type: """Type check a bytes literal (trivial).""" return self.named_type('builtins.bytes') - + def visit_unicode_expr(self, e: UnicodeExpr) -> Type: """Type check a unicode literal (trivial).""" return self.named_type('builtins.unicode') - + def visit_float_expr(self, e: FloatExpr) -> Type: """Type check a float literal (trivial).""" return self.named_type('builtins.float') - + def visit_op_expr(self, e: OpExpr) -> Type: """Type check a binary operator expression.""" if e.op == 'and' or e.op == 'or': @@ -748,47 +755,81 @@ def visit_op_expr(self, e: OpExpr) -> Type: # Expressions of form [...] * e get special type inference. return self.check_list_multiply(e) left_type = self.accept(e.left) - right_type = self.accept(e.right) # TODO only evaluate if needed - if e.op == 'in' or e.op == 'not in': - local_errors = self.msg.copy() - result, method_type = self.check_op_local('__contains__', right_type, - e.left, e, local_errors) - if (local_errors.is_errors() and - # is_valid_var_arg is True for any Iterable - self.is_valid_var_arg(right_type)): - itertype = self.chk.analyse_iterable_item_type(e.right) - method_type = Callable([left_type], [nodes.ARG_POS], [None], - self.chk.bool_type(), False) - result = self.chk.bool_type() - if not is_subtype(left_type, itertype): - self.msg.unsupported_operand_types('in', left_type, right_type, e) - else: - self.msg.add_errors(local_errors) - e.method_type = method_type - if e.op == 'in': - return result - else: - return self.chk.bool_type() - elif e.op in nodes.op_methods: + + if e.op in nodes.op_methods: method = self.get_operator_method(e.op) result, method_type = self.check_op(method, left_type, e.right, e, allow_reverse=True) e.method_type = method_type return result - elif e.op == 'is' or e.op == 'is not': - return self.chk.bool_type() else: raise RuntimeError('Unknown operator {}'.format(e.op)) + def visit_comparison_expr(self, e: ComparisonExpr) -> Type: + """Type check a comparison expression. + + Comparison expressions are type checked consecutive-pair-wise + That is, 'a < b > c == d' is check as 'a < b and b > c and c == d' + """ + result = None # type: mypy.types.Type + + # Check each consecutive operand pair and their operator + for left, right, operator in zip(e.operands, e.operands[1:], e.operators): + left_type = self.accept(left) + + method_type = None # type: mypy.types.Type + + if operator == 'in' or operator == 'not in': + right_type = self.accept(right) # TODO only evaluate if needed + + local_errors = self.msg.copy() + sub_result, method_type = self.check_op_local('__contains__', right_type, + left, e, local_errors) + if (local_errors.is_errors() and + # is_valid_var_arg is True for any Iterable + self.is_valid_var_arg(right_type)): + itertype = self.chk.analyse_iterable_item_type(right) + method_type = Callable([left_type], [nodes.ARG_POS], [None], + self.chk.bool_type(), False) + sub_result = self.chk.bool_type() + if not is_subtype(left_type, itertype): + self.msg.unsupported_operand_types('in', left_type, right_type, e) + else: + self.msg.add_errors(local_errors) + if operator == 'not in': + sub_result = self.chk.bool_type() + elif operator in nodes.op_methods: + method = self.get_operator_method(operator) + sub_result, method_type = self.check_op(method, left_type, right, e, + allow_reverse=True) + + elif operator == 'is' or operator == 'is not': + sub_result = self.chk.bool_type() + method_type = None + else: + raise RuntimeError('Unknown comparison operator {}'.format(operator)) + + e.method_types.append(method_type) + + # Determine type of boolean-and of result and sub_result + if result == None: + result = sub_result + else: + # TODO: check on void needed? + self.check_not_void(sub_result, e) + result = join.join_types(result, sub_result, self.chk.basic_types()) + + return result + def get_operator_method(self, op: str) -> str: if op == '/' and self.chk.pyversion == 2: # TODO also check for "from __future__ import division" return '__div__' else: return nodes.op_methods[op] - + def check_op_local(self, method: str, base_type: Type, arg: Node, - context: Context, local_errors: MessageBuilder) -> Tuple[Type, Type]: + context: Context, local_errors: MessageBuilder) -> Tuple[Type, Type]: """Type check a binary operation which maps to a method call. Return tuple (result type, inferred operator method type). @@ -853,12 +894,12 @@ def get_reverse_op_method(self, method: str) -> str: return '__rdiv__' else: return nodes.reverse_op_methods[method] - + def check_boolean_op(self, e: OpExpr, context: Context) -> Type: """Type check a boolean operation ('and' or 'or').""" # A boolean operation can evaluate to either of the operands. - + # We use the current type context to guide the type inference of of # the left operand. We also use the left operand type to guide the type # inference of the right operand so that expressions such as @@ -866,7 +907,7 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type: ctx = self.chk.type_context[-1] left_type = self.accept(e.left, ctx) right_type = self.accept(e.right, left_type) - + self.check_not_void(left_type, context) self.check_not_void(right_type, context) @@ -888,14 +929,14 @@ def check_list_multiply(self, e: OpExpr) -> Type: result, method_type = self.check_op('__mul__', left_type, e.right, e) e.method_type = method_type return result - + def visit_unary_expr(self, e: UnaryExpr) -> Type: """Type check an unary operation ('not', '-', '+' or '~').""" operand_type = self.accept(e.expr) op = e.op if op == 'not': self.check_not_void(operand_type, e) - result = self.chk.bool_type() # type: Type + result = self.chk.bool_type() # type: Type elif op == '-': method_type = self.analyse_external_member_access('__neg__', operand_type, e) @@ -963,13 +1004,13 @@ def visit_cast_expr(self, expr: CastExpr) -> Type: if not self.is_valid_cast(source_type, target_type): self.msg.invalid_cast(target_type, source_type, expr) return target_type - + def is_valid_cast(self, source_type: Type, target_type: Type) -> bool: """Is a cast from source_type to target_type meaningful?""" return (isinstance(target_type, AnyType) or (not isinstance(source_type, Void) and not isinstance(target_type, Void))) - + def visit_type_application(self, tapp: TypeApplication) -> Type: """Type check a type application (expr[type, ...]).""" expr_type = self.accept(tapp.expr) @@ -988,7 +1029,7 @@ def visit_type_application(self, tapp: TypeApplication) -> Type: new_type = AnyType() self.chk.type_map[tapp.expr] = new_type return new_type - + def visit_list_expr(self, e: ListExpr) -> Type: """Type check a list expression [...].""" return self.check_list_or_set_expr(e.items, 'builtins.list', '', @@ -1015,17 +1056,17 @@ def check_list_or_set_expr(self, items: List[Node], fullname: str, def visit_tuple_expr(self, e: TupleExpr) -> Type: """Type check a tuple expression.""" - ctx = None # type: TupleType + ctx = None # type: TupleType # Try to determine type context for type inference. if isinstance(self.chk.type_context[-1], TupleType): t = cast(TupleType, self.chk.type_context[-1]) if len(t.items) == len(e.items): ctx = t # Infer item types. - items = [] # type: List[Type] + items = [] # type: List[Type] for i in range(len(e.items)): item = e.items[i] - tt = Undefined # type: Type + tt = Undefined # type: Type if not ctx: tt = self.accept(item) else: @@ -1033,7 +1074,7 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type: self.check_not_void(tt, e) items.append(tt) return TupleType(items) - + def visit_dict_expr(self, e: DictExpr) -> Type: # Translate into type checking a generic function call. tv1 = TypeVar('KT', -1, []) @@ -1058,7 +1099,7 @@ def visit_dict_expr(self, e: DictExpr) -> Type: return self.check_call(constructor, args, [nodes.ARG_POS] * len(args), e)[0] - + def visit_func_expr(self, e: FuncExpr) -> Type: """Type check lambda expression.""" inferred_type = self.infer_lambda_type_using_context(e) @@ -1088,28 +1129,28 @@ def infer_lambda_type_using_context(self, e: FuncExpr) -> Callable: ctx = self.chk.type_context[-1] if not ctx or not isinstance(ctx, Callable): return None - + # The context may have function type variables in it. We replace them # since these are the type variables we are ultimately trying to infer; # they must be considered as indeterminate. We use ErasedType since it # does not affect type inference results (it is for purposes like this # only). ctx = replace_func_type_vars(ctx, ErasedType()) - + callable_ctx = cast(Callable, ctx) - + if callable_ctx.arg_kinds != e.arg_kinds: # Incompatible context; cannot use it to infer types. self.chk.fail(messages.CANNOT_INFER_LAMBDA_TYPE, e) return None - + return callable_ctx - + def visit_super_expr(self, e: SuperExpr) -> Type: """Type check a super expression (non-lvalue).""" t = self.analyse_super(e, False) return t - + def analyse_super(self, e: SuperExpr, is_lvalue: bool) -> Type: """Type check a super expression.""" if e.info and e.info.bases: @@ -1121,11 +1162,11 @@ def analyse_super(self, e: SuperExpr, is_lvalue: bool) -> Type: else: # Invalid super. This has been reported by the semantic analyser. return AnyType() - + def visit_paren_expr(self, e: ParenExpr) -> Type: """Type check a parenthesised expression.""" return self.accept(e.expr, self.chk.type_context[-1]) - + def visit_slice_expr(self, e: SliceExpr) -> Type: for index in [e.begin_index, e.end_index, e.stride]: if index: @@ -1141,12 +1182,12 @@ def visit_list_comprehension(self, e: ListComprehension) -> Type: def visit_generator_expr(self, e: GeneratorExpr) -> Type: return self.check_generator_or_comprehension(e, 'typing.Iterator', '') - + def check_generator_or_comprehension(self, gen: GeneratorExpr, type_name: str, id_for_messages: str) -> Type: """Type check a generator expression or a list comprehension.""" - + self.chk.binder.push_frame() for index, sequence, conditions in zip(gen.indices, gen.sequences, gen.condlists): @@ -1178,41 +1219,41 @@ def visit_conditional_expr(self, e: ConditionalExpr) -> Type: if_type = self.accept(e.if_expr) else_type = self.accept(e.else_expr, context=if_type) return join.join_types(if_type, else_type, self.chk.basic_types()) - + # # Helpers # - + def accept(self, node: Node, context: Type = None) -> Type: """Type check a node. Alias for TypeChecker.accept.""" return self.chk.accept(node, context) - + def check_not_void(self, typ: Type, context: Context) -> None: """Generate an error if type is Void.""" self.chk.check_not_void(typ, context) - + def is_boolean(self, typ: Type) -> bool: """Is type compatible with bool?""" return is_subtype(typ, self.chk.bool_type()) - + def named_type(self, name: str) -> Instance: """Return an instance type with type given by the name and no type arguments. Alias for TypeChecker.named_type. """ return self.chk.named_type(name) - + def is_valid_var_arg(self, typ: Type) -> bool: """Is a type valid as a *args argument?""" return (isinstance(typ, TupleType) or is_subtype(typ, self.chk.named_generic_type('typing.Iterable', [AnyType()])) or isinstance(typ, AnyType)) - - def is_valid_keyword_var_arg(self, typ: Type) -> bool: + + def is_valid_keyword_var_arg(self, typ: Type) -> bool: """Is a type valid as a **kwargs argument?""" return is_subtype(typ, self.chk.named_generic_type( 'builtins.dict', [self.named_type('builtins.str'), AnyType()])) - + def has_non_method(self, typ: Type, member: str) -> bool: """Does type have a member variable / property with the given name?""" if isinstance(typ, Instance): @@ -1220,7 +1261,7 @@ def has_non_method(self, typ: Type, member: str) -> bool: typ.type.has_readable_member(member)) else: return False - + def has_member(self, typ: Type, member: str) -> bool: """Does type have member with the given name?""" # TODO TupleType => also consider tuple attributes @@ -1233,14 +1274,14 @@ def has_member(self, typ: Type, member: str) -> bool: return result else: return False - + def unwrap(self, e: Node) -> Node: """Unwrap parentheses from an expression node.""" if isinstance(e, ParenExpr): return self.unwrap(e.expr) else: return e - + def unwrap_list(self, a: List[Node]) -> List[Node]: """Unwrap parentheses from a list of expression nodes.""" r = List[Node]() @@ -1285,7 +1326,7 @@ def map_actuals_to_formals(caller_kinds: List[int], argument type with the given index. """ ncallee = len(callee_kinds) - map = [None] * ncallee # type: List[List[int]] + map = [None] * ncallee # type: List[List[int]] for i in range(ncallee): map[i] = [] j = 0 @@ -1371,7 +1412,7 @@ class ArgInferSecondPassQuery(types.TypeQuery): The result is True if the type has a type variable in a callable return type anywhere. For example, the result for Function[[], T] is True if t is a type variable. - """ + """ def __init__(self) -> None: super().__init__(False, types.ANY_TYPE_STRATEGY) diff --git a/mypy/messages.py b/mypy/messages.py index 75c5fba89924..7e86461e177a 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -40,7 +40,7 @@ INCOMPATIBLE_TYPES_IN_YIELD_FROM = 'Incompatible types in "yield from"' INIT_MUST_NOT_HAVE_RETURN_TYPE = 'Cannot define return type for "__init__"' GETTER_TYPE_INCOMPATIBLE_WITH_SETTER = \ - 'Type of getter incompatible with setter' + 'Type of getter incompatible with setter' TUPLE_INDEX_MUST_BE_AN_INT_LITERAL = 'Tuple index must an integer literal' TUPLE_INDEX_OUT_OF_RANGE = 'Tuple index out of range' TYPE_CONSTANT_EXPECTED = 'Type "Constant" or initializer expected' @@ -60,31 +60,31 @@ CANNOT_ASSIGN_TO_METHOD = 'Cannot assign to a method' CANNOT_ASSIGN_TO_TYPE = 'Cannot assign to a type' INCONSISTENT_ABSTRACT_OVERLOAD = \ - 'Overloaded method has both abstract and non-abstract variants' + 'Overloaded method has both abstract and non-abstract variants' INSTANCE_LAYOUT_CONFLICT = 'Instance layout conflict in multiple inheritance' class MessageBuilder: """Helper class for reporting type checker error messages with parameters. - + The methods of this class need to be provided with the context within a file; the errors member manages the wider context. - + IDEA: Support a 'verbose mode' that includes full information about types in error messages and that may otherwise produce more detailed error messages. """ - + # Report errors using this instance. It knows about the current file and # import context. errors = Undefined(Errors) - + # Number of times errors have been disabled. disable_count = 0 # Hack to deduplicate error messages from union types disable_type_names = 0 - + def __init__(self, errors: Errors) -> None: self.errors = errors self.disable_count = 0 @@ -95,7 +95,9 @@ def __init__(self, errors: Errors) -> None: # def copy(self) -> 'MessageBuilder': - return MessageBuilder(self.errors.copy()) + new = MessageBuilder(self.errors.copy()) + new.disable_count = self.disable_count + return new def add_errors(self, messages: 'MessageBuilder') -> None: """Add errors in messages to this builder.""" @@ -109,12 +111,12 @@ def enable_errors(self) -> None: def is_errors(self) -> bool: return self.errors.is_errors() - + def fail(self, msg: str, context: Context) -> None: """Report an error message (unless disabled).""" if self.disable_count <= 0: self.errors.report(context.get_line(), msg.strip()) - + def format(self, typ: Type) -> str: """Convert a type to a relatively short string that is suitable for error messages. Mostly behave like format_simple @@ -134,8 +136,7 @@ def format(self, typ: Type) -> str: elif isinstance(func, Callable): arg_types = map(self.format, func.arg_types) return_type = self.format(func.ret_type) - return 'Function[[{}] -> {}]'.format(", ".join(arg_types), - return_type) + return 'Function[[{}] -> {}]'.format(", ".join(arg_types), return_type) else: # Use a simple representation for function types; proper # function types may result in long and difficult-to-read @@ -144,13 +145,13 @@ def format(self, typ: Type) -> str: else: # Default case; we simply have to return something meaningful here. return 'object' - + def format_simple(self, typ: Type) -> str: """Convert simple types to string that is suitable for error messages. - + Return "" for complex types. Try to keep the length of the result relatively short to avoid overly long error messages. - + Examples: builtins.int -> 'int' Any type -> 'Any' @@ -176,7 +177,7 @@ def format_simple(self, typ: Type) -> str: # (using format() instead of format_simple() to avoid empty # strings). If the result is too long, replace arguments # with [...]. - a = [] # type: List[str] + a = [] # type: List[str] for arg in itype.args: a.append(strip_quotes(self.format(arg))) s = ', '.join(a) @@ -222,11 +223,11 @@ def format_simple(self, typ: Type) -> str: # # Specific operations # - + # The following operations are for genering specific error messages. They # get some information as arguments, and they build an error message based # on them. - + def has_no_attr(self, typ: Type, member: str, context: Context) -> Type: """Report a missing or non-accessible member. @@ -272,38 +273,38 @@ def has_no_attr(self, typ: Type, member: str, context: Context) -> Type: member), context) else: self.fail('Some element of union has no attribute "{}"'.format( - member), context) + member), context) return AnyType() - + def unsupported_operand_types(self, op: str, left_type: Any, right_type: Any, context: Context) -> None: """Report unsupported operand types for a binary operation. - + Types can be Type objects or strings. """ if isinstance(left_type, Void) or isinstance(right_type, Void): self.check_void(left_type, context) self.check_void(right_type, context) - return + return left_str = '' if isinstance(left_type, str): left_str = left_type else: left_str = self.format(left_type) - + right_str = '' if isinstance(right_type, str): right_str = right_type else: right_str = self.format(right_type) - + if self.disable_type_names: msg = 'Unsupported operand types for {} (likely involving Union)'.format(op) else: msg = 'Unsupported operand types for {} ({} and {})'.format( op, left_str, right_str) self.fail(msg, context) - + def unsupported_left_operand(self, op: str, typ: Type, context: Context) -> None: if not self.check_void(typ, context): @@ -313,14 +314,14 @@ def unsupported_left_operand(self, op: str, typ: Type, msg = 'Unsupported left operand type for {} ({})'.format( op, self.format(typ)) self.fail(msg, context) - + def type_expected_as_right_operand_of_is(self, context: Context) -> None: self.fail('Type expected as right operand of "is"', context) - + def not_callable(self, typ: Type, context: Context) -> Type: self.fail('{} not callable'.format(self.format(typ)), context) return AnyType() - + def incompatible_argument(self, n: int, callee: Callable, arg_type: Type, context: Context) -> None: """Report an error about an incompatible argument type. @@ -334,7 +335,7 @@ def incompatible_argument(self, n: int, callee: Callable, arg_type: Type, if callee.name: name = callee.name base = extract_type(name) - + for op, method in op_methods.items(): for variant in method, '__r' + method[2:]: if name.startswith('"{}" of'.format(variant)): @@ -345,21 +346,21 @@ def incompatible_argument(self, n: int, callee: Callable, arg_type: Type, else: self.unsupported_operand_types(op, base, arg_type, context) - return - + return + if name.startswith('"__getitem__" of'): self.invalid_index_type(arg_type, base, context) - return - + return + if name.startswith('"__setitem__" of'): if n == 1: self.invalid_index_type(arg_type, base, context) else: self.fail(INCOMPATIBLE_TYPES_IN_ASSIGNMENT, context) - return - + return + target = 'to {} '.format(name) - + msg = '' if callee.name == '': name = callee.name[1:-1] @@ -367,43 +368,43 @@ def incompatible_argument(self, n: int, callee: Callable, arg_type: Type, name[0].upper() + name[1:], n, self.format_simple(arg_type)) elif callee.name == '': msg = 'List comprehension has incompatible type List[{}]'.format( - strip_quotes(self.format(arg_type))) + strip_quotes(self.format(arg_type))) elif callee.name == '': msg = 'Generator has incompatible item type {}'.format( - self.format_simple(arg_type)) + self.format_simple(arg_type)) else: try: - expected_type = callee.arg_types[n-1] + expected_type = callee.arg_types[n - 1] except IndexError: # Varargs callees expected_type = callee.arg_types[-1] msg = 'Argument {} {}has incompatible type {}; expected {}'.format( n, target, self.format(arg_type), self.format(expected_type)) self.fail(msg, context) - + def invalid_index_type(self, index_type: Type, base_str: str, context: Context) -> None: self.fail('Invalid index type {} for {}'.format( self.format(index_type), base_str), context) - + def invalid_argument_count(self, callee: Callable, num_args: int, context: Context) -> None: if num_args < len(callee.arg_types): self.too_few_arguments(callee, context) else: self.too_many_arguments(callee, context) - + def too_few_arguments(self, callee: Callable, context: Context) -> None: msg = 'Too few arguments' if callee.name: msg += ' for {}'.format(callee.name) self.fail(msg, context) - + def too_many_arguments(self, callee: Callable, context: Context) -> None: msg = 'Too many arguments' if callee.name: msg += ' for {}'.format(callee.name) self.fail(msg, context) - + def too_many_positional_arguments(self, callee: Callable, context: Context) -> None: msg = 'Too many positional arguments' @@ -416,14 +417,14 @@ def unexpected_keyword_argument(self, callee: Callable, name: str, msg = 'Unexpected keyword argument "{}"'.format(name) if callee.name: msg += ' for {}'.format(callee.name) - self.fail(msg, context) + self.fail(msg, context) def duplicate_argument_value(self, callee: Callable, index: int, context: Context) -> None: self.fail('{} gets multiple values for keyword argument "{}"'. format(capitalize(callable_name(callee)), callee.arg_names[index]), context) - + def does_not_return_value(self, void_type: Type, context: Context) -> None: """Report an error about a void type in a non-void context. @@ -436,31 +437,31 @@ def does_not_return_value(self, void_type: Type, context: Context) -> None: else: self.fail('{} does not return a value'.format( capitalize((cast(Void, void_type)).source)), context) - + def no_variant_matches_arguments(self, overload: Overloaded, - context: Context) -> None: + context: Context) -> None: if overload.name(): self.fail('No overload variant of {} matches argument types' .format(overload.name()), context) else: self.fail('No overload variant matches argument types', context) - + def function_variants_overlap(self, n1: int, n2: int, context: Context) -> None: self.fail('Function signature variants {} and {} overlap'.format( n1 + 1, n2 + 1), context) - + def invalid_cast(self, target_type: Type, source_type: Type, - context: Context) -> None: + context: Context) -> None: if not self.check_void(source_type, context): self.fail('Cannot cast from {} to {}'.format( self.format(source_type), self.format(target_type)), context) - + def incompatible_operator_assignment(self, op: str, context: Context) -> None: self.fail('Result type of {} incompatible in assignment'.format(op), context) - + def incompatible_value_count_in_assignment(self, lvalue_count: int, rvalue_count: int, context: Context) -> None: @@ -468,26 +469,26 @@ def incompatible_value_count_in_assignment(self, lvalue_count: int, self.fail('Need {} values to assign'.format(lvalue_count), context) elif rvalue_count > lvalue_count: self.fail('Too many values to assign', context) - + def type_incompatible_with_supertype(self, name: str, supertype: TypeInfo, - context: Context) -> None: + context: Context) -> None: self.fail('Type of "{}" incompatible with supertype "{}"'.format( name, supertype.name), context) - + def signature_incompatible_with_supertype( self, name: str, name_in_super: str, supertype: str, context: Context) -> None: target = self.override_target(name, name_in_super, supertype) self.fail('Signature of "{}" incompatible with {}'.format( name, target), context) - + def argument_incompatible_with_supertype( self, arg_num: int, name: str, name_in_supertype: str, supertype: str, context: Context) -> None: target = self.override_target(name, name_in_supertype, supertype) self.fail('Argument {} of "{}" incompatible with {}' .format(arg_num, name, target), context) - + def return_type_incompatible_with_supertype( self, name: str, name_in_supertype: str, supertype: str, context: Context) -> None: @@ -500,13 +501,13 @@ def override_target(self, name: str, name_in_super: str, target = 'supertype "{}"'.format(supertype) if name_in_super != name: target = '"{}" of {}'.format(name_in_super, target) - return target - + return target + def boolean_return_value_expected(self, method: str, context: Context) -> None: self.fail('Boolean return value expected for method "{}"'.format( method), context) - + def incompatible_type_application(self, expected_arg_count: int, actual_arg_count: int, context: Context) -> None: @@ -519,12 +520,12 @@ def incompatible_type_application(self, expected_arg_count: int, else: self.fail('Type application has too few types ({} expected)' .format(expected_arg_count), context) - + def incompatible_array_item_type(self, typ: Type, index: int, context: Context) -> None: self.fail('Array item {} has incompatible type {}'.format( index, self.format(typ)), context) - + def could_not_infer_type_arguments(self, callee_type: Callable, n: int, context: Context) -> None: if callee_type.name and n > 0: @@ -532,10 +533,10 @@ def could_not_infer_type_arguments(self, callee_type: Callable, n: int, n, callee_type.name), context) else: self.fail('Cannot infer function type argument', context) - + def invalid_var_arg(self, typ: Type, context: Context) -> None: self.fail('List or tuple expected as variable arguments', context) - + def invalid_keyword_var_arg(self, typ: Type, context: Context) -> None: if isinstance(typ, Instance) and ( (cast(Instance, typ)).type.fullname() == 'builtins.dict'): @@ -543,18 +544,18 @@ def invalid_keyword_var_arg(self, typ: Type, context: Context) -> None: else: self.fail('Argument after ** must be a dictionary', context) - + def incomplete_type_var_match(self, member: str, context: Context) -> None: self.fail('"{}" has incomplete match to supertype type variable' .format(member), context) - + def not_implemented(self, msg: str, context: Context) -> Type: self.fail('Feature not implemented yet ({})'.format(msg), context) return AnyType() - + def undefined_in_superclass(self, member: str, context: Context) -> None: self.fail('"{}" undefined in superclass'.format(member), context) - + def check_void(self, typ: Type, context: Context) -> bool: """If type is void, report an error such as '.. does not return a value' and return True. Otherwise, return False. diff --git a/mypy/nodes.py b/mypy/nodes.py index e121c7163b90..bce4a7674de8 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -28,16 +28,16 @@ def get_line(self) -> int: pass # Variable kind constants # TODO rename to use more descriptive names -LDEF = 0 # type: int -GDEF = 1 # type: int -MDEF = 2 # type: int -MODULE_REF = 3 # type: int +LDEF = 0 # type: int +GDEF = 1 # type: int +MDEF = 2 # type: int +MODULE_REF = 3 # type: int # Type variable declared using typevar(...) has kind UNBOUND_TVAR. It's not -# valid as a type. A type variable is valid as a type (kind TVAR) within +# valid as a type. A type variable is valid as a type (kind TVAR) within # (1) a generic class that uses the type variable as a type argument or # (2) a generic function that refers to the type variable in its signature. -UNBOUND_TVAR = 4 # type: 'int' -TVAR = 5 # type: int +UNBOUND_TVAR = 4 # type: 'int' +TVAR = 5 # type: int LITERAL_YES = 2 @@ -69,13 +69,13 @@ def get_line(self) -> int: pass class Node(Context): """Common base class for all non-type parse tree nodes.""" - + line = -1 # Textual representation - repr = None # type: Any + repr = None # type: Any literal = LITERAL_NO - literal_hash = None # type: Any + literal_hash = None # type: Any def __str__(self) -> str: ans = self.accept(mypy.strconv.StrConv()) @@ -87,7 +87,7 @@ def __str__(self) -> str: def set_line(self, tok: Token) -> 'Node': self.line = tok.line return self - + @overload def set_line(self, line: int) -> 'Node': self.line = line @@ -96,16 +96,16 @@ def set_line(self, line: int) -> 'Node': def get_line(self) -> int: # TODO this should be just 'line' return self.line - + def accept(self, visitor: NodeVisitor[T]) -> T: raise RuntimeError('Not implemented') class SymbolNode(Node): # Nodes that can be stored in a symbol table. - + # TODO do not use methods for these - + @abstractmethod def name(self) -> str: pass From 5a86acd19c5833867792b95b89775be8d8b56fca Mon Sep 17 00:00:00 2001 From: Rock Neurotiko Date: Mon, 15 Sep 2014 12:41:13 +0200 Subject: [PATCH 09/12] merge --- mypy/checker.py | 2 +- mypy/nodes.py | 571 ++++++++++++++++++----------------- mypy/output.py | 185 ++++++------ mypy/parse.py | 434 +++++++++++++------------- mypy/pprinter.py | 72 +++-- mypy/semanal.py | 228 +++++++------- mypy/stats.py | 24 +- mypy/strconv.py | 139 ++++----- mypy/transform.py | 131 ++++---- mypy/traverser.py | 72 ++--- mypy/treetransform.py | 141 ++++----- mypy/visitor.py | 85 +++++- stubs/3.4/asyncio/futures.py | 16 +- stubs/3.4/asyncio/tasks.py | 18 +- 14 files changed, 1130 insertions(+), 988 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index e7e57d442cfa..5eb9d4b08581 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1,7 +1,7 @@ """Mypy type checker.""" import itertools - + from typing import Undefined, Any, Dict, Set, List, cast, overload, Tuple, Function, typevar from mypy.errors import Errors diff --git a/mypy/nodes.py b/mypy/nodes.py index bce4a7674de8..e6e9bb25b453 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -115,15 +115,15 @@ def fullname(self) -> str: pass class MypyFile(SymbolNode): """The abstract syntax tree of a single source file.""" - - _name = None # type: str # Module name ('__main__' for initial file) - _fullname = None # type: str # Qualified module name - path = '' # Path to the file (None if not known) - defs = Undefined # type: List[Node] # Global definitions and statements - is_bom = False # Is there a UTF-8 BOM at the start? + + _name = None # type: str # Module name ('__main__' for initial file) + _fullname = None # type: str # Qualified module name + path = '' # Path to the file (None if not known) + defs = Undefined # type: List[Node] # Global definitions and statements + is_bom = False # Is there a UTF-8 BOM at the start? names = Undefined('SymbolTable') imports = Undefined(List['ImportBase']) # All import nodes within the file - + def __init__(self, defs: List[Node], imports: List['ImportBase'], is_bom: bool = False) -> None: self.defs = defs @@ -136,7 +136,7 @@ def name(self) -> str: def fullname(self) -> str: return self._fullname - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_mypy_file(self) @@ -144,57 +144,57 @@ def accept(self, visitor: NodeVisitor[T]) -> T: class ImportBase(Node): """Base class for all import statements.""" is_unreachable = False - + class Import(ImportBase): """import m [as n]""" - + ids = Undefined(List[Tuple[str, str]]) # (module id, as id) - + def __init__(self, ids: List[Tuple[str, str]]) -> None: self.ids = ids - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_import(self) class ImportFrom(ImportBase): """from m import x, ...""" - - names = Undefined(List[Tuple[str, str]]) # Tuples (name, as name) - + + names = Undefined(List[Tuple[str, str]]) # Tuples (name, as name) + def __init__(self, id: str, names: List[Tuple[str, str]]) -> None: self.id = id self.names = names - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_import_from(self) class ImportAll(ImportBase): """from m import *""" - + def __init__(self, id: str) -> None: self.id = id - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_import_all(self) class FuncBase(SymbolNode): """Abstract base class for function-like nodes""" - + # Type signature (Callable or Overloaded) - type = None # type: mypy.types.Type + type = None # type: mypy.types.Type # If method, reference to TypeInfo - info = None # type: TypeInfo + info = None # type: TypeInfo @abstractmethod def name(self) -> str: pass - + def fullname(self) -> str: return self.name() - + def is_method(self) -> bool: return bool(self.info) @@ -205,28 +205,28 @@ class OverloadedFuncDef(FuncBase): This node has no explicit representation in the source program. Overloaded variants must be consecutive in the source file. """ - + items = Undefined(List['Decorator']) - _fullname = None # type: str - + _fullname = None # type: str + def __init__(self, items: List['Decorator']) -> None: self.items = items self.set_line(items[0].line) - + def name(self) -> str: return self.items[1].func.name() def fullname(self) -> str: return self._fullname - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_overloaded_func_def(self) class FuncItem(FuncBase): - args = Undefined(List['Var']) # Argument names - arg_kinds = Undefined(List[int]) # Kinds of arguments (ARG_*) - + args = Undefined(List['Var']) # Argument names + arg_kinds = Undefined(List[int]) # Kinds of arguments (ARG_*) + # Initialization expessions for fixed args; None if no initialiser init = Undefined(List['AssignmentStmt']) min_args = 0 # Minimum number of arguments @@ -242,7 +242,7 @@ class FuncItem(FuncBase): is_class = False # Uses @classmethod? expanded = Undefined(List['FuncItem']) # Variants of function with type # variables with values expanded - + def __init__(self, args: List['Var'], arg_kinds: List[int], init: List[Node], body: 'Block', typ: 'mypy.types.Type' = None) -> None: @@ -252,7 +252,7 @@ def __init__(self, args: List['Var'], arg_kinds: List[int], self.body = body self.type = typ self.expanded = [] - + i2 = List[AssignmentStmt]() self.min_args = 0 for i in range(len(init)): @@ -267,24 +267,24 @@ def __init__(self, args: List['Var'], arg_kinds: List[int], if i < self.max_fixed_argc(): self.min_args = i + 1 self.init = i2 - + def max_fixed_argc(self) -> int: return self.max_pos - + @overload def set_line(self, tok: Token) -> Node: super().set_line(tok) for n in self.args: n.line = self.line return self - + @overload def set_line(self, tok: int) -> Node: super().set_line(tok) for n in self.args: n.line = self.line return self - + def init_expressions(self) -> List[Node]: res = List[Node]() for i in self.init: @@ -298,16 +298,16 @@ def init_expressions(self) -> List[Node]: class FuncDef(FuncItem): """Function definition. - This is a non-lambda function defined using 'def'. + This is a non-lambda function defined using 'def'. """ - - _fullname = None # type: str # Name with module prefix + + _fullname = None # type: str # Name with module prefix is_decorated = False is_conditional = False # Defined conditionally (within block)? is_abstract = False is_property = False - original_def = None # type: FuncDef # Original conditional definition - + original_def = None # type: FuncDef # Original conditional definition + def __init__(self, name: str, # Function name args: List['Var'], # Argument names @@ -320,13 +320,13 @@ def __init__(self, def name(self) -> str: return self._name - + def fullname(self) -> str: return self._fullname def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_func_def(self) - + def is_constructor(self) -> bool: return self.info is not None and self._name == '__init__' @@ -340,12 +340,12 @@ class Decorator(SymbolNode): A single Decorator object can include any number of function decorators. """ - - func = Undefined(FuncDef) # Decorated function - decorators = Undefined(List[Node]) # Decorators, at least one - var = Undefined('Var') # Represents the decorated function obj + + func = Undefined(FuncDef) # Decorated function + decorators = Undefined(List[Node]) # Decorators, at least one + var = Undefined('Var') # Represents the decorated function obj is_overload = False - + def __init__(self, func: FuncDef, decorators: List[Node], var: 'Var') -> None: self.func = func @@ -368,20 +368,20 @@ class Var(SymbolNode): It can refer to global/local variable or a data attribute. """ - - _name = None # type: str # Name without module prefix - _fullname = None # type: str # Name with module prefix + + _name = None # type: str # Name without module prefix + _fullname = None # type: str # Name with module prefix info = Undefined('TypeInfo') # Defining class (for member variables) - type = None # type: mypy.types.Type # Declared or inferred type, or None - is_self = False # Is this the first argument to an ordinary method - # (usually "self")? - is_ready = False # If inferred, is the inferred type available? + type = None # type: mypy.types.Type # Declared or inferred type, or None + is_self = False # Is this the first argument to an ordinary method + # (usually "self")? + is_ready = False # If inferred, is the inferred type available? # Is this initialized explicitly to a non-None value in class body? is_initialized_in_class = False is_staticmethod = False is_classmethod = False is_property = False - + def __init__(self, name: str, type: 'mypy.types.Type' = None) -> None: self._name = name self.type = type @@ -394,26 +394,26 @@ def name(self) -> str: def fullname(self) -> str: return self._fullname - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_var(self) class ClassDef(Node): """Class definition""" - - name = Undefined(str) # Name of the class without module prefix - fullname = None # type: str # Fully qualified name of the class + + name = Undefined(str) # Name of the class without module prefix + fullname = None # type: str # Fully qualified name of the class defs = Undefined('Block') type_vars = Undefined(List['mypy.types.TypeVarDef']) # Base classes (Instance or UnboundType). base_types = Undefined(List['mypy.types.Type']) - info = None # type: TypeInfo # Related TypeInfo + info = None # type: TypeInfo # Related TypeInfo metaclass = '' decorators = Undefined(List[Node]) # Built-in/extension class? (single implementation inheritance only) is_builtinclass = False - + def __init__(self, name: str, defs: 'Block', type_vars: List['mypy.types.TypeVarDef'] = None, base_types: List['mypy.types.Type'] = None, @@ -426,58 +426,58 @@ def __init__(self, name: str, defs: 'Block', self.base_types = base_types self.metaclass = metaclass self.decorators = [] - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_class_def(self) - + def is_generic(self) -> bool: return self.info.is_generic() class VarDef(Node): """Variable definition with explicit types""" - + items = Undefined(List[Var]) - kind = None # type: int # LDEF/GDEF/MDEF/... - init = Undefined(Node) # Expression or None - is_top_level = False # Is the definition at the top level (not within - # a function or a type)? - + kind = None # type: int # LDEF/GDEF/MDEF/... + init = Undefined(Node) # Expression or None + is_top_level = False # Is the definition at the top level (not within + # a function or a type)? + def __init__(self, items: List[Var], is_top_level: bool, init: Node = None) -> None: self.items = items self.is_top_level = is_top_level self.init = init - + def info(self) -> 'TypeInfo': return self.items[0].info - + @overload def set_line(self, tok: Token) -> Node: super().set_line(tok) for n in self.items: n.line = self.line return self - + @overload def set_line(self, tok: int) -> Node: super().set_line(tok) for n in self.items: n.line = self.line return self - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_var_def(self) class GlobalDecl(Node): """Declaration global x, y, ...""" - + names = Undefined(List[str]) - + def __init__(self, names: List[str]) -> None: self.names = names - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_global_decl(self) @@ -488,10 +488,10 @@ class Block(Node): # this applies to blocks that are protected by something like "if PY3:" # when using Python 2. is_unreachable = False - + def __init__(self, body: List[Node]) -> None: self.body = body - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_block(self) @@ -502,10 +502,10 @@ def accept(self, visitor: NodeVisitor[T]) -> T: class ExpressionStmt(Node): """An expression as a statament, such as print(s).""" expr = Undefined(Node) - + def __init__(self, expr: Node) -> None: self.expr = expr - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_expression_stmt(self) @@ -520,34 +520,34 @@ class AssignmentStmt(Node): An lvalue can be NameExpr, TupleExpr, ListExpr, MemberExpr, IndexExpr or ParenExpr. """ - + lvalues = Undefined(List[Node]) rvalue = Undefined(Node) - type = None # type: mypy.types.Type # Declared type in a comment, + type = None # type: mypy.types.Type # Declared type in a comment, # may be None. - + def __init__(self, lvalues: List[Node], rvalue: Node, type: 'mypy.types.Type' = None) -> None: self.lvalues = lvalues self.rvalue = rvalue self.type = type - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_assignment_stmt(self) class OperatorAssignmentStmt(Node): """Operator assignment statement such as x += 1""" - + op = '' lvalue = Undefined(Node) rvalue = Undefined(Node) - + def __init__(self, op: str, lvalue: Node, rvalue: Node) -> None: self.op = op self.lvalue = lvalue self.rvalue = rvalue - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_operator_assignment_stmt(self) @@ -556,12 +556,12 @@ class WhileStmt(Node): expr = Undefined(Node) body = Undefined(Block) else_body = Undefined(Block) - + def __init__(self, expr: Node, body: Block, else_body: Block) -> None: self.expr = expr self.body = body self.else_body = else_body - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_while_stmt(self) @@ -575,7 +575,7 @@ class ForStmt(Node): expr = Undefined(Node) body = Undefined(Block) else_body = Undefined(Block) - + def __init__(self, index: List['NameExpr'], expr: Node, body: Block, else_body: Block, types: List['mypy.types.Type'] = None) -> None: @@ -584,10 +584,10 @@ def __init__(self, index: List['NameExpr'], expr: Node, body: Block, self.body = body self.else_body = else_body self.types = types - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_for_stmt(self) - + def is_annotated(self) -> bool: ann = False for t in self.types: @@ -598,30 +598,30 @@ def is_annotated(self) -> bool: class ReturnStmt(Node): expr = Undefined(Node) # Expression or None - + def __init__(self, expr: Node) -> None: self.expr = expr - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_return_stmt(self) class AssertStmt(Node): expr = Undefined(Node) - + def __init__(self, expr: Node) -> None: self.expr = expr - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_assert_stmt(self) class YieldStmt(Node): expr = Undefined(Node) - + def __init__(self, expr: Node) -> None: self.expr = expr - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_yield_stmt(self) @@ -665,13 +665,13 @@ class IfStmt(Node): expr = Undefined(List[Node]) body = Undefined(List[Block]) else_body = Undefined(Block) - + def __init__(self, expr: List[Node], body: List[Block], else_body: Block) -> None: self.expr = expr self.body = body self.else_body = else_body - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_if_stmt(self) @@ -679,11 +679,11 @@ def accept(self, visitor: NodeVisitor[T]) -> T: class RaiseStmt(Node): expr = Undefined(Node) from_expr = Undefined(Node) - + def __init__(self, expr: Node, from_expr: Node = None) -> None: self.expr = expr self.from_expr = from_expr - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_raise_stmt(self) @@ -695,7 +695,7 @@ class TryStmt(Node): handlers = Undefined(List[Block]) # Except bodies else_body = Undefined(Block) finally_body = Undefined(Block) - + def __init__(self, body: Block, vars: List['NameExpr'], types: List[Node], handlers: List[Block], else_body: Block, finally_body: Block) -> None: @@ -705,7 +705,7 @@ def __init__(self, body: Block, vars: List['NameExpr'], types: List[Node], self.handlers = handlers self.else_body = else_body self.finally_body = finally_body - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_try_stmt(self) @@ -714,27 +714,27 @@ class WithStmt(Node): expr = Undefined(List[Node]) name = Undefined(List['NameExpr']) body = Undefined(Block) - + def __init__(self, expr: List[Node], name: List['NameExpr'], body: Block) -> None: self.expr = expr self.name = name self.body = body - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_with_stmt(self) class PrintStmt(Node): """Python 2 print statement""" - + args = Undefined(List[Node]) newline = False def __init__(self, args: List[Node], newline: bool) -> None: self.args = args self.newline = newline - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_print_stmt(self) @@ -744,95 +744,95 @@ def accept(self, visitor: NodeVisitor[T]) -> T: class IntExpr(Node): """Integer literal""" - + value = 0 literal = LITERAL_YES - + def __init__(self, value: int) -> None: self.value = value self.literal_hash = value - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_int_expr(self) class StrExpr(Node): """String literal""" - + value = '' literal = LITERAL_YES def __init__(self, value: str) -> None: self.value = value self.literal_hash = value - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_str_expr(self) class BytesExpr(Node): """Bytes literal""" - - value = '' # TODO use bytes + + value = '' # TODO use bytes literal = LITERAL_YES def __init__(self, value: str) -> None: self.value = value self.literal_hash = value - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_bytes_expr(self) class UnicodeExpr(Node): """Unicode literal (Python 2.x)""" - - value = '' # TODO use bytes + + value = '' # TODO use bytes literal = LITERAL_YES def __init__(self, value: str) -> None: self.value = value self.literal_hash = value - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_unicode_expr(self) class FloatExpr(Node): """Float literal""" - + value = 0.0 literal = LITERAL_YES - + def __init__(self, value: float) -> None: self.value = value self.literal_hash = value - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_float_expr(self) class ParenExpr(Node): """Parenthesised expression""" - + expr = Undefined(Node) - + def __init__(self, expr: Node) -> None: self.expr = expr self.literal = self.expr.literal self.literal_hash = ('Paren', expr.literal_hash,) - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_paren_expr(self) class RefExpr(Node): """Abstract base class for name-like constructs""" - - kind = None # type: int # LDEF/GDEF/MDEF/... (None if not available) - node = Undefined(Node) # Var, FuncDef or TypeInfo that describes this - fullname = None # type: str # Fully qualified name (or name if not global) - + + kind = None # type: int # LDEF/GDEF/MDEF/... (None if not available) + node = Undefined(Node) # Var, FuncDef or TypeInfo that describes this + fullname = None # type: str # Fully qualified name (or name if not global) + # Does this define a new name with inferred type? # # For members, after semantic analysis, this does not take base @@ -845,33 +845,33 @@ class NameExpr(RefExpr): This refers to a local name, global name or a module. """ - - name = None # type: str # Name referred to (may be qualified) - info = Undefined('TypeInfo') # TypeInfo of class surrounding expression - # (may be None) + + name = None # type: str # Name referred to (may be qualified) + info = Undefined('TypeInfo') # TypeInfo of class surrounding expression + # (may be None) literal = LITERAL_TYPE - + def __init__(self, name: str) -> None: self.name = name self.literal_hash = ('Var', name,) - + def type_node(self): return cast('TypeInfo', self.node) - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_name_expr(self) class MemberExpr(RefExpr): """Member access expression x.y""" - + expr = Undefined(Node) - name = None # type: str + name = None # type: str # The variable node related to a definition. - def_var = None # type: Var + def_var = None # type: Var # Is this direct assignment to a data member (bypassing accessors)? direct = False - + def __init__(self, expr: Node, name: str, direct: bool = False) -> None: self.expr = expr self.name = name @@ -886,15 +886,15 @@ def accept(self, visitor: NodeVisitor[T]) -> T: # Kinds of arguments # Positional argument -ARG_POS = 0 # type: int +ARG_POS = 0 # type: int # Positional, optional argument (functions only, not calls) -ARG_OPT = 1 # type: int +ARG_OPT = 1 # type: int # *arg argument -ARG_STAR = 2 # type: int +ARG_STAR = 2 # type: int # Keyword argument x=y in call, or keyword-only function arg -ARG_NAMED = 3 # type: int +ARG_NAMED = 3 # type: int # **arg argument -ARG_STAR2 = 4 # type: int +ARG_STAR2 = 4 # type: int class CallExpr(Node): @@ -903,16 +903,16 @@ class CallExpr(Node): This can also represent several special forms that are syntactically calls such as cast(...) and Undefined(...). """ - + callee = Undefined(Node) args = Undefined(List[Node]) - arg_kinds = Undefined(List[int]) # ARG_ constants - arg_names = Undefined(List[str]) # Each name can be None if not a keyword - # argument. - analyzed = Undefined(Node) # If not None, the node that represents - # the meaning of the CallExpr. For - # cast(...) this is a CastExpr. - + arg_kinds = Undefined(List[int]) # ARG_ constants + arg_names = Undefined(List[str]) # Each name can be None if not a keyword + # argument. + analyzed = Undefined(Node) # If not None, the node that represents + # the meaning of the CallExpr. For + # cast(...) this is a CastExpr. + def __init__(self, callee: Node, args: List[Node], arg_kinds: List[int], arg_names: List[str] = None, analyzed: Node = None) -> None: if not arg_names: @@ -922,7 +922,7 @@ def __init__(self, callee: Node, args: List[Node], arg_kinds: List[int], self.arg_kinds = arg_kinds self.arg_names = arg_names self.analyzed = analyzed - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_call_expr(self) @@ -941,15 +941,15 @@ class IndexExpr(Node): Also wraps type application as a special form. """ - + base = Undefined(Node) index = Undefined(Node) # Inferred __getitem__ method type - method_type = None # type: mypy.types.Type + method_type = None # type: mypy.types.Type # If not None, this is actually semantically a type application # Class[type, ...]. analyzed = Undefined('TypeApplication') - + def __init__(self, base: Node, index: Node) -> None: self.base = base self.index = index @@ -958,25 +958,25 @@ def __init__(self, base: Node, index: Node) -> None: self.literal = self.base.literal self.literal_hash = ('Member', base.literal_hash, index.literal_hash) - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_index_expr(self) class UnaryExpr(Node): """Unary operation""" - + op = '' expr = Undefined(Node) # Inferred operator method type - method_type = None # type: mypy.types.Type - + method_type = None # type: mypy.types.Type + def __init__(self, op: str, expr: Node) -> None: self.op = op self.expr = expr self.literal = self.expr.literal self.literal_hash = ('Unary', op, expr.literal_hash) - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_unary_expr(self) @@ -1036,69 +1036,89 @@ def accept(self, visitor: NodeVisitor[T]) -> T: class OpExpr(Node): - """Binary operation (other than . or [], which have specific nodes).""" - + """Binary operation (other than . or [] or comparison operators, + which have specific nodes).""" + op = '' left = Undefined(Node) right = Undefined(Node) - # Inferred type for the operator method type (when relevant; None for - # 'is'). - method_type = None # type: mypy.types.Type - + # Inferred type for the operator method type (when relevant). + method_type = None # type: mypy.types.Type + def __init__(self, op: str, left: Node, right: Node) -> None: self.op = op self.left = left self.right = right self.literal = min(self.left.literal, self.right.literal) self.literal_hash = ('Binary', op, left.literal_hash, right.literal_hash) - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_op_expr(self) +class ComparisonExpr(Node): + """Comparison expression (e.g. a < b > c < d).""" + + operators = Undefined(List[str]) + operands = Undefined(List[Node]) + # Inferred type for the operator methods (when relevant; None for 'is'). + method_types = Undefined(List["mypy.types.Type"]) + + def __init__(self, operators: List[str], operands: List[Node]) -> None: + self.operators = operators + self.operands = operands + self.method_types = [] + self.literal = min(o.literal for o in self.operands) + self.literal_hash = ( ('Comparison',) + tuple(operators) + + tuple(o.literal_hash for o in operands) ) + + def accept(self, visitor: NodeVisitor[T]) -> T: + return visitor.visit_comparison_expr(self) + + class SliceExpr(Node): """Slice expression (e.g. 'x:y', 'x:', '::2' or ':'). This is only valid as index in index expressions. """ - + begin_index = Undefined(Node) # May be None end_index = Undefined(Node) # May be None stride = Undefined(Node) # May be None - + def __init__(self, begin_index: Node, end_index: Node, stride: Node) -> None: self.begin_index = begin_index self.end_index = end_index self.stride = stride - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_slice_expr(self) class CastExpr(Node): """Cast expression cast(type, expr).""" - + expr = Undefined(Node) type = Undefined('mypy.types.Type') - + def __init__(self, expr: Node, typ: 'mypy.types.Type') -> None: self.expr = expr self.type = typ - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_cast_expr(self) class SuperExpr(Node): """Expression super().name""" - + name = '' - info = Undefined('TypeInfo') # Type that contains this super expression - + info = Undefined('TypeInfo') # Type that contains this super expression + def __init__(self, name: str) -> None: self.name = name - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_super_expr(self) @@ -1108,122 +1128,123 @@ class FuncExpr(FuncItem): def name(self) -> str: return '' - + def expr(self) -> Node: """Return the expression (the body) of the lambda.""" ret = cast(ReturnStmt, self.body.body[0]) return ret.expr - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_func_expr(self) + class ListExpr(Node): """List literal expression [...].""" - - items = Undefined(List[Node] ) - + + items = Undefined(List[Node]) + def __init__(self, items: List[Node]) -> None: self.items = items if all(x.literal == LITERAL_YES for x in items): self.literal = LITERAL_YES self.literal_hash = ('List',) + tuple(x.literal_hash for x in items) - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_list_expr(self) class DictExpr(Node): """Dictionary literal expression {key: value, ...}.""" - + items = Undefined(List[Tuple[Node, Node]]) - + def __init__(self, items: List[Tuple[Node, Node]]) -> None: self.items = items if all(x[0].literal == LITERAL_YES and x[1].literal == LITERAL_YES for x in items): self.literal = LITERAL_YES self.literal_hash = ('Dict',) + tuple((x[0].literal_hash, x[1].literal_hash) for x in items) - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_dict_expr(self) class TupleExpr(Node): """Tuple literal expression (..., ...)""" - + items = Undefined(List[Node]) - + def __init__(self, items: List[Node]) -> None: self.items = items if all(x.literal == LITERAL_YES for x in items): self.literal = LITERAL_YES self.literal_hash = ('Tuple',) + tuple(x.literal_hash for x in items) - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_tuple_expr(self) class SetExpr(Node): """Set literal expression {value, ...}.""" - + items = Undefined(List[Node]) - + def __init__(self, items: List[Node]) -> None: self.items = items if all(x.literal == LITERAL_YES for x in items): self.literal = LITERAL_YES self.literal_hash = ('Set',) + tuple(x.literal_hash for x in items) - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_set_expr(self) class GeneratorExpr(Node): """Generator expression ... for ... in ... [ for ... in ... ] [ if ... ].""" - + left_expr = Undefined(Node) sequences_expr = Undefined(List[Node]) condlists = Undefined(List[List[Node]]) indices = Undefined(List[List[NameExpr]]) types = Undefined(List[List['mypy.types.Type']]) - + def __init__(self, left_expr: Node, indices: List[List[NameExpr]], - types: List[List['mypy.types.Type']], sequences: List[Node], + types: List[List['mypy.types.Type']], sequences: List[Node], condlists: List[List[Node]]) -> None: self.left_expr = left_expr self.sequences = sequences self.condlists = condlists self.indices = indices self.types = types - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_generator_expr(self) class ListComprehension(Node): """List comprehension (e.g. [x + 1 for x in a])""" - + generator = Undefined(GeneratorExpr) - + def __init__(self, generator: GeneratorExpr) -> None: self.generator = generator - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_list_comprehension(self) class ConditionalExpr(Node): """Conditional expression (e.g. x if y else z)""" - + cond = Undefined(Node) if_expr = Undefined(Node) else_expr = Undefined(Node) - + def __init__(self, cond: Node, if_expr: Node, else_expr: Node) -> None: self.cond = cond self.if_expr = if_expr self.else_expr = else_expr - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_conditional_expr(self) @@ -1236,7 +1257,7 @@ class UndefinedExpr(Node): x = Undefined(List[int]) """ - + def __init__(self, type: 'mypy.types.Type') -> None: self.type = type @@ -1246,14 +1267,14 @@ def accept(self, visitor: NodeVisitor[T]) -> T: class TypeApplication(Node): """Type application expr[type, ...]""" - + expr = Undefined(Node) types = Undefined(List['mypy.types.Type']) - + def __init__(self, expr: Node, types: List['mypy.types.Type']) -> None: self.expr = expr self.types = types - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_type_application(self) @@ -1278,7 +1299,7 @@ def name(self) -> str: def fullname(self) -> str: return self._fullname - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_type_var_expr(self) @@ -1313,12 +1334,12 @@ class CoerceExpr(Node): This is used only when compiling/transforming. These are inserted after type checking. """ - + expr = Undefined(Node) target_type = Undefined('mypy.types.Type') source_type = Undefined('mypy.types.Type') is_wrapper_class = False - + def __init__(self, expr: Node, target_type: 'mypy.types.Type', source_type: 'mypy.types.Type', is_wrapper_class: bool) -> None: @@ -1326,7 +1347,7 @@ def __init__(self, expr: Node, target_type: 'mypy.types.Type', self.target_type = target_type self.source_type = source_type self.is_wrapper_class = is_wrapper_class - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_coerce_expr(self) @@ -1334,12 +1355,12 @@ def accept(self, visitor: NodeVisitor[T]) -> T: class JavaCast(Node): # TODO obsolete; remove expr = Undefined(Node) - target = Undefined('mypy.types.Type') - + target = Undefined('mypy.types.Type') + def __init__(self, expr: Node, target: 'mypy.types.Type') -> None: self.expr = expr self.target = target - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_java_cast(self) @@ -1350,12 +1371,12 @@ class TypeExpr(Node): This is used only for runtime type checking. This node is always generated only after type checking. """ - - type = Undefined('mypy.types.Type') - + + type = Undefined('mypy.types.Type') + def __init__(self, typ: 'mypy.types.Type') -> None: self.type = typ - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_type_expr(self) @@ -1367,12 +1388,12 @@ class TempNode(Node): of the type checker implementation. It only represents an opaque node with some fixed type. """ - + type = Undefined('mypy.types.Type') - + def __init__(self, typ: 'mypy.types.Type') -> None: self.type = typ - + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_temp_node(self) @@ -1383,14 +1404,14 @@ class TypeInfo(SymbolNode): The corresponding ClassDef instance represents the parse tree of the class. """ - - _fullname = None # type: str # Fully qualified name - defn = Undefined(ClassDef) # Corresponding ClassDef + + _fullname = None # type: str # Fully qualified name + defn = Undefined(ClassDef) # Corresponding ClassDef # Method Resolution Order: the order of looking up attributes. The first # value always to refers to self. mro = Undefined(List['TypeInfo']) - subtypes = Undefined(Set['TypeInfo']) # Direct subclasses - names = Undefined('SymbolTable') # Names defined directly in this type + subtypes = Undefined(Set['TypeInfo']) # Direct subclasses + names = Undefined('SymbolTable') # Names defined directly in this type is_abstract = False # Does the class have any abstract attributes? abstract_attributes = Undefined(List[str]) # All classes in this build unit that are disjoint with this class. @@ -1398,18 +1419,18 @@ class TypeInfo(SymbolNode): # Targets of disjointclass declarations present in this class only (for # generating error messages). disjointclass_decls = Undefined(List['TypeInfo']) - + # Information related to type annotations. - + # Generic type variable names type_vars = Undefined(List[str]) - + # Direct base classes. bases = Undefined(List['mypy.types.Instance']) # Duck type compatibility (ducktype decorator) - ducktype = None # type: mypy.types.Type - + ducktype = None # type: mypy.types.Type + def __init__(self, names: 'SymbolTable', defn: ClassDef) -> None: """Initialize a TypeInfo.""" self.names = names @@ -1426,18 +1447,18 @@ def __init__(self, names: 'SymbolTable', defn: ClassDef) -> None: if defn.type_vars: for vd in defn.type_vars: self.type_vars.append(vd.name) - + def name(self) -> str: """Short name.""" return self.defn.name def fullname(self) -> str: return self._fullname - + def is_generic(self) -> bool: """Is the type generic (i.e. does it have type variables)?""" return self.type_vars is not None and len(self.type_vars) > 0 - + def get(self, name: str) -> 'SymbolTableNode': for cls in self.mro: n = cls.names.get(name) @@ -1454,23 +1475,22 @@ def __getitem__(self, name: str) -> 'SymbolTableNode': def __repr__(self) -> str: return '' % self.fullname() - - + # IDEA: Refactor the has* methods to be more consistent and document # them. - + def has_readable_member(self, name: str) -> bool: return self.get(name) is not None - + def has_writable_member(self, name: str) -> bool: return self.has_var(name) - + def has_var(self, name: str) -> bool: return self.get_var(name) is not None - + def has_method(self, name: str) -> bool: return self.get_method(name) is not None - + def get_var(self, name: str) -> Var: for cls in self.mro: if name in cls.names: @@ -1480,15 +1500,15 @@ def get_var(self, name: str) -> Var: else: return None return None - + def get_var_or_getter(self, name: str) -> SymbolNode: # TODO getter return self.get_var(name) - + def get_var_or_setter(self, name: str) -> SymbolNode: # TODO setter return self.get_var(name) - + def get_method(self, name: str) -> FuncBase: for cls in self.mro: if name in cls.names: @@ -1505,7 +1525,7 @@ def calculate_mro(self) -> None: Raise MroError if cannot determine mro. """ self.mro = linearize_hierarchy(self) - + def has_base(self, fullname: str) -> bool: """Return True if type has a base type with the specified name. @@ -1515,7 +1535,7 @@ def has_base(self, fullname: str) -> bool: if cls.fullname() == fullname: return True return False - + def all_subtypes(self) -> 'Set[TypeInfo]': """Return TypeInfos of all subtypes, including this type, as a set.""" subtypes = set([self]) @@ -1523,24 +1543,24 @@ def all_subtypes(self) -> 'Set[TypeInfo]': for t in subt.all_subtypes(): subtypes.add(t) return subtypes - + def all_base_classes(self) -> 'List[TypeInfo]': """Return a list of base classes, including indirect bases.""" assert False - + def direct_base_classes(self) -> 'List[TypeInfo]': """Return a direct base classes. Omit base classes of other base classes. """ return [base.type for base in self.bases] - + def __str__(self) -> str: """Return a string representation of the type. This includes the most important information about the type. """ - base = None # type: str + base = None # type: str if self.bases: base = 'Bases({})'.format(', '.join(str(base) for base in self.bases)) @@ -1552,7 +1572,7 @@ def __str__(self) -> str: class SymbolTableNode: # LDEF/GDEF/MDEF/UNBOUND_TVAR/TVAR/... - kind = None # type: int + kind = None # type: int # AST node of definition (FuncDef/Var/TypeInfo/Decorator/TypeVarExpr, # or None for a bound type variable). node = Undefined(SymbolNode) @@ -1560,9 +1580,9 @@ class SymbolTableNode: tvar_id = 0 # Module id (e.g. "foo.bar") or None mod_id = '' - # If None, fall back to type of node + # If None, fall back to type of node type_override = Undefined('mypy.types.Type') - + def __init__(self, kind: int, node: SymbolNode, mod_id: str = None, typ: 'mypy.types.Type' = None, tvar_id: int = 0) -> None: self.kind = kind @@ -1581,7 +1601,7 @@ def fullname(self) -> str: @property def type(self) -> 'mypy.types.Type': # IDEA: Get rid of the Any type. - node = self.node # type: Any + node = self.node # type: Any if self.type_override is not None: return self.type_override elif ((isinstance(node, Var) or isinstance(node, FuncDef)) @@ -1591,7 +1611,7 @@ def type(self) -> 'mypy.types.Type': return (cast(Decorator, node)).var.type else: return None - + def __str__(self) -> str: s = '{}/{}'.format(node_kinds[self.kind], short_type(self.node)) if self.mod_id is not None: @@ -1610,7 +1630,7 @@ def __str__(self) -> str: if isinstance(value, SymbolTableNode): if (value.fullname != 'builtins' and value.fullname.split('.')[-1] not in - implicit_module_attrs): + implicit_module_attrs): a.append(' ' + str(key) + ' : ' + str(value)) else: a.append(' ') @@ -1623,7 +1643,7 @@ def __str__(self) -> str: def clean_up(s: str) -> str: # TODO remove return re.sub('.*::', '', s) - + def function_type(func: FuncBase) -> 'mypy.types.FunctionLike': if func.type: @@ -1632,19 +1652,19 @@ def function_type(func: FuncBase) -> 'mypy.types.FunctionLike': # Implicit type signature with dynamic types. # Overloaded functions always have a signature, so func must be an # ordinary function. - fdef = cast(FuncDef, func) + fdef = cast(FuncDef, func) name = func.name() if name: name = '"{}"'.format(name) - names = [] # type: List[str] + names = [] # type: List[str] for arg in fdef.args: names.append(arg.name()) return mypy.types.Callable([mypy.types.AnyType()] * len(fdef.args), - fdef.arg_kinds, - names, - mypy.types.AnyType(), - False, - name) + fdef.arg_kinds, + names, + mypy.types.AnyType(), + False, + name) @overload @@ -1652,6 +1672,7 @@ def method_type(func: FuncBase) -> 'mypy.types.FunctionLike': """Return the signature of a method (omit self).""" return method_type(function_type(func)) + @overload def method_type(sig: 'mypy.types.FunctionLike') -> 'mypy.types.FunctionLike': if isinstance(sig, mypy.types.Callable): @@ -1667,13 +1688,13 @@ def method_type(sig: 'mypy.types.FunctionLike') -> 'mypy.types.FunctionLike': def method_callable(c: 'mypy.types.Callable') -> 'mypy.types.Callable': return mypy.types.Callable(c.arg_types[1:], - c.arg_kinds[1:], - c.arg_names[1:], - c.ret_type, - c.is_type_obj(), - c.name, - c.variables, - c.bound_vars) + c.arg_kinds[1:], + c.arg_names[1:], + c.ret_type, + c.is_type_obj(), + c.name, + c.variables, + c.bound_vars) class MroError(Exception): diff --git a/mypy/output.py b/mypy/output.py index 43ced2ebebda..5a0820f194e0 100644 --- a/mypy/output.py +++ b/mypy/output.py @@ -26,15 +26,15 @@ def __init__(self): # break self.extra_indent = 0 self.block_depth = 0 - + def output(self): """Return a string representation of the output.""" return ''.join(self.result) - + def visit_mypy_file(self, o): self.nodes(o.defs) self.token(o.repr.eof) - + def visit_import(self, o): r = o.repr self.token(r.import_tok) @@ -45,13 +45,13 @@ def visit_import(self, o): if i < len(r.commas): self.token(r.commas[i]) self.token(r.br) - + def visit_import_from(self, o): self.output_import_from_or_all(o) - + def visit_import_all(self, o): self.output_import_from_or_all(o) - + def output_import_from_or_all(self, o): r = o.repr self.token(r.from_tok) @@ -63,7 +63,7 @@ def output_import_from_or_all(self, o): self.token(comma) self.token(r.rparen) self.token(r.br) - + def visit_class_def(self, o): r = o.repr self.tokens([r.class_tok, r.name]) @@ -76,7 +76,7 @@ def visit_class_def(self, o): self.token(r.commas[i]) self.token(r.rparen) self.node(o.defs) - + def type_vars(self, v): # IDEA: Combine this with type_vars in TypeOutputVisitor. if v and v.repr: @@ -91,38 +91,38 @@ def type_vars(self, v): if i < len(r.commas): self.token(r.commas[i]) self.token(r.rangle) - + def visit_func_def(self, o): r = o.repr - + if r.def_tok: self.token(r.def_tok) else: self.type(o.type.items()[0].ret_type) - + self.token(r.name) - + self.function_header(o, r.args, o.arg_kinds) - + self.node(o.body) - + def visit_overloaded_func_def(self, o): for f in o.items: f.accept(self) - + def function_header(self, o, arg_repr, arg_kinds, pre_args_func=None, erase_type=False, strip_space_before_first_arg=False): r = o.repr - + t = None if o.type and not erase_type: t = o.type - + init = o.init - + if t: self.type_vars(t.variables) - + self.token(arg_repr.lseparator) if pre_args_func: pre_args_func() @@ -151,7 +151,7 @@ def function_header(self, o, arg_repr, arg_kinds, pre_args_func=None, if i < len(arg_repr.commas): self.token(arg_repr.commas[i]) self.token(arg_repr.rseparator) - + def visit_var_def(self, o): r = o.repr if r: @@ -161,21 +161,21 @@ def visit_var_def(self, o): self.token(r.assign) self.node(o.init) self.token(r.br) - + def visit_var(self, o): r = o.repr self.token(r.name) self.token(r.comma) - + def visit_decorator(self, o): for at, br, dec in zip(o.repr.ats, o.repr.brs, o.decorators): self.token(at) self.node(dec) self.token(br) self.node(o.func) - + # Statements - + def visit_block(self, o): r = o.repr self.tokens([r.colon, r.br, r.indent]) @@ -186,7 +186,7 @@ def visit_block(self, o): self.token(r.dedent) self.indent = old_indent self.block_depth -= 1 - + def visit_global_decl(self, o): r = o.repr self.token(r.global_tok) @@ -195,11 +195,11 @@ def visit_global_decl(self, o): if i < len(r.commas): self.token(r.commas[i]) self.token(r.br) - + def visit_expression_stmt(self, o): self.node(o.expr) self.token(o.repr.br) - + def visit_assignment_stmt(self, o): r = o.repr i = 0 @@ -209,20 +209,20 @@ def visit_assignment_stmt(self, o): i += 1 self.node(o.rvalue) self.token(r.br) - + def visit_operator_assignment_stmt(self, o): r = o.repr self.node(o.lvalue) self.token(r.assign) self.node(o.rvalue) self.token(r.br) - + def visit_return_stmt(self, o): self.simple_stmt(o, o.expr) - + def visit_assert_stmt(self, o): self.simple_stmt(o, o.expr) - + def visit_yield_stmt(self, o): self.simple_stmt(o, o.expr) @@ -234,18 +234,18 @@ def visit_del_stmt(self, o): def visit_break_stmt(self, o): self.simple_stmt(o) - + def visit_continue_stmt(self, o): self.simple_stmt(o) - + def visit_pass_stmt(self, o): self.simple_stmt(o) - + def simple_stmt(self, o, expr=None): self.token(o.repr.keyword) self.node(expr) self.token(o.repr.br) - + def visit_raise_stmt(self, o): self.token(o.repr.raise_tok) self.node(o.expr) @@ -253,7 +253,7 @@ def visit_raise_stmt(self, o): self.token(o.repr.from_tok) self.node(o.from_expr) self.token(o.repr.br) - + def visit_while_stmt(self, o): self.token(o.repr.while_tok) self.node(o.expr) @@ -261,7 +261,7 @@ def visit_while_stmt(self, o): if o.else_body: self.token(o.repr.else_tok) self.node(o.else_body) - + def visit_for_stmt(self, o): r = o.repr self.token(r.for_tok) @@ -271,12 +271,12 @@ def visit_for_stmt(self, o): self.token(r.commas[i]) self.token(r.in_tok) self.node(o.expr) - + self.node(o.body) if o.else_body: self.token(r.else_tok) self.node(o.else_body) - + def visit_if_stmt(self, o): r = o.repr self.token(r.if_tok) @@ -289,7 +289,7 @@ def visit_if_stmt(self, o): self.token(r.else_tok) if o.else_body: self.node(o.else_body) - + def visit_try_stmt(self, o): r = o.repr self.token(r.try_tok) @@ -306,7 +306,7 @@ def visit_try_stmt(self, o): if o.finally_body: self.token(r.finally_tok) self.node(o.finally_body) - + def visit_with_stmt(self, o): self.token(o.repr.with_tok) for i in range(len(o.expr)): @@ -316,42 +316,42 @@ def visit_with_stmt(self, o): if i < len(o.repr.commas): self.token(o.repr.commas[i]) self.node(o.body) - + # Expressions - + def visit_int_expr(self, o): self.token(o.repr.int) - + def visit_str_expr(self, o): self.tokens(o.repr.string) - + def visit_bytes_expr(self, o): self.tokens(o.repr.string) - + def visit_float_expr(self, o): self.token(o.repr.float) - + def visit_paren_expr(self, o): self.token(o.repr.lparen) self.node(o.expr) self.token(o.repr.rparen) - + def visit_name_expr(self, o): # Supertype references may not have a representation. if o.repr: self.token(o.repr.id) - + def visit_member_expr(self, o): self.node(o.expr) self.token(o.repr.dot) self.token(o.repr.name) - + def visit_index_expr(self, o): self.node(o.base) self.token(o.repr.lbracket) self.node(o.index) self.token(o.repr.rbracket) - + def visit_slice_expr(self, o): self.node(o.begin_index) self.token(o.repr.colon) @@ -380,41 +380,48 @@ def visit_call_expr(self, o): if i < len(r.commas): self.token(r.commas[i]) self.token(r.rparen) - + def visit_op_expr(self, o): self.node(o.left) - self.tokens([o.repr.op, o.repr.op2]) + self.tokens([o.repr.op]) self.node(o.right) - + + def visit_comparison_expr(self, o): + self.node(o.operands[0]) + for ops, operand in zip(o.repr.operators, o.operands[1:]): + # ops = op, op2 + self.tokens(list(ops)) + self.node(operand) + def visit_cast_expr(self, o): self.token(o.repr.lparen) self.type(o.type) self.token(o.repr.rparen) self.node(o.expr) - + def visit_super_expr(self, o): r = o.repr self.tokens([r.super_tok, r.lparen, r.rparen, r.dot, r.name]) - + def visit_unary_expr(self, o): self.token(o.repr.op) self.node(o.expr) - + def visit_list_expr(self, o): r = o.repr self.token(r.lbracket) self.comma_list(o.items, r.commas) self.token(r.rbracket) - + def visit_set_expr(self, o): self.visit_list_expr(o) - + def visit_tuple_expr(self, o): r = o.repr self.token(r.lparen) self.comma_list(o.items, r.commas) self.token(r.rparen) - + def visit_dict_expr(self, o): r = o.repr self.token(r.lbrace) @@ -427,14 +434,14 @@ def visit_dict_expr(self, o): self.token(r.commas[i]) i += 1 self.token(r.rbrace) - + def visit_func_expr(self, o): r = o.repr self.token(r.lambda_tok) self.function_header(o, r.args, o.arg_kinds) self.token(r.colon) self.node(o.body.body[0].expr) - + def visit_type_application(self, o): self.node(o.expr) self.token(o.repr.langle) @@ -449,7 +456,7 @@ def visit_generator_expr(self, o): for j in range(len(o.indices[i])): self.node(o.types[i][j]) self.node(o.indices[i][j]) - if j < len(o.indices[i])-1: + if j < len(o.indices[i]) - 1: self.token(r.commas[0]) self.token(r.in_toks[i]) self.node(o.sequences[i]) @@ -461,12 +468,12 @@ def visit_list_comprehension(self, o): self.token(o.repr.lbracket) self.node(o.generator) self.token(o.repr.rbracket) - + # Helpers - + def line(self): return self.line_number - + def string(self, s): """Output a string.""" if self.omit_next_space: @@ -477,44 +484,44 @@ def string(self, s): if s != '': s = s.replace('\n', '\n' + ' ' * self.extra_indent) self.result.append(s) - + def token(self, t): """Output a token.""" self.string(t.rep()) - + def tokens(self, a): """Output an array of tokens.""" for t in a: self.token(t) - + def node(self, n): """Output a node.""" if n: n.accept(self) - + def nodes(self, a): """Output an array of nodes.""" for n in a: self.node(n) - + def comma_list(self, items, commas): for i in range(len(items)): self.node(items[i]) if i < len(commas): self.token(commas[i]) - + def type_list(self, items, commas): for i in range(len(items)): self.type(items[i]) if i < len(commas): self.token(commas[i]) - + def type(self, t): """Output a type.""" if t: v = TypeOutputVisitor() t.accept(v) self.string(v.output()) - + def last_output_char(self): if self.result and self.result[-1]: return self.result[-1][-1] @@ -527,22 +534,22 @@ class TypeOutputVisitor: """Type visitor that outputs source code.""" def __init__(self): self.result = [] # strings - + def output(self): """Return a string representation of the output.""" return ''.join(self.result) - + def visit_unbound_type(self, t): self.visit_instance(t) - + def visit_any(self, t): if t.repr: self.token(t.repr.any_tok) - + def visit_void(self, t): if t.repr: self.token(t.repr.void) - + def visit_instance(self, t): r = t.repr if isinstance(r, CommonTypeRepr): @@ -555,17 +562,17 @@ def visit_instance(self, t): assert len(t.args) == 1 self.comma_list(t.args, []) self.tokens([r.lbracket, r.rbracket]) - + def visit_type_var(self, t): self.token(t.repr.name) - + def visit_tuple_type(self, t): r = t.repr self.tokens(r.components) self.token(r.langle) self.comma_list(t.items, r.commas) self.token(r.rangle) - + def visit_callable(self, t): r = t.repr self.tokens([r.func, r.langle]) @@ -573,7 +580,7 @@ def visit_callable(self, t): self.token(r.lparen) self.comma_list(t.arg_types, r.commas) self.tokens([r.rparen, r.rangle]) - + def type_vars(self, v): if v and v.repr: r = v.repr @@ -587,26 +594,26 @@ def type_vars(self, v): if i < len(r.commas): self.token(r.commas[i]) self.token(r.rangle) - + # Helpers - + def string(self, s): """Output a string.""" self.result.append(s) - + def token(self, t): """Output a token.""" self.result.append(t.rep()) - + def tokens(self, a): """Output an array of tokens.""" for t in a: self.token(t) - + def type(self, n): """Output a type.""" if n: n.accept(self) - + def comma_list(self, items, commas): for i in range(len(items)): self.type(items[i]) diff --git a/mypy/parse.py b/mypy/parse.py index 137dba91c872..e086eca9b715 100755 --- a/mypy/parse.py +++ b/mypy/parse.py @@ -23,7 +23,7 @@ TupleExpr, GeneratorExpr, ListComprehension, ListExpr, ConditionalExpr, DictExpr, SetExpr, NameExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr, FloatExpr, CallExpr, SuperExpr, MemberExpr, IndexExpr, SliceExpr, OpExpr, - UnaryExpr, FuncExpr, TypeApplication, PrintStmt, ImportBase, YieldFromStmt, + UnaryExpr, FuncExpr, TypeApplication, PrintStmt, ImportBase, ComparisonExpr, YieldFromStmt, YieldFromExpr ) from mypy import nodes @@ -49,8 +49,8 @@ 'not': 6, 'and': 5, 'or': 4, - '': 3, # conditional expression - '': 2, # list comprehension + '': 3, # conditional expression + '': 2, # list comprehension ',': 1} @@ -58,8 +58,10 @@ '+=', '-=', '*=', '/=', '//=', '%=', '**=', '|=', '&=', '^=', '>>=', '<<=']) +op_comp = set([ + '>', '<', '==', '>=', '<=', '<>', '!=', 'is', 'is', 'in', 'not']) -none = Token('') # Empty token +none = Token('') # Empty token def parse(s: str, fnam: str = None, errors: Errors = None, @@ -83,16 +85,16 @@ class Parser: ind = 0 errors = Undefined(Errors) raise_on_error = False - + # Are we currently parsing the body of a class definition? is_class_body = False # All import nodes encountered so far in this parse unit. imports = Undefined(List[ImportBase]) # Names imported from __future__. future_options = Undefined(List[str]) - + def __init__(self, fnam: str, errors: Errors, pyversion: int, - custom_typing_module : str = None) -> None: + custom_typing_module: str = None) -> None: self.raise_on_error = errors is None self.pyversion = pyversion self.custom_typing_module = custom_typing_module @@ -104,7 +106,7 @@ def __init__(self, fnam: str, errors: Errors, pyversion: int, self.errors.set_file(fnam) else: self.errors.set_file('') - + def parse(self, s: str) -> MypyFile: self.tok = lex.lex(s) self.ind = 0 @@ -114,7 +116,7 @@ def parse(self, s: str) -> MypyFile: if self.raise_on_error and self.errors.is_errors(): self.errors.raise_error() return file - + def parse_file(self) -> MypyFile: """Parse a mypy source file.""" is_bom = self.parse_bom() @@ -123,9 +125,9 @@ def parse_file(self) -> MypyFile: node = MypyFile(defs, self.imports, is_bom) self.set_repr(node, noderepr.MypyFileRepr(eof)) return node - + # Parse the initial part - + def parse_bom(self) -> bool: """Parse the optional byte order mark at the beginning of a file.""" if isinstance(self.current(), Bom): @@ -135,7 +137,7 @@ def parse_bom(self) -> bool: return True else: return False - + def parse_import(self) -> Import: import_tok = self.expect('import') ids = List[Tuple[str, str]]() @@ -165,7 +167,7 @@ def parse_import(self) -> Import: self.set_repr(node, noderepr.ImportRepr(import_tok, id_toks, as_names, commas, br)) return node - + def parse_import_from(self) -> Node: from_tok = self.expect('from') name, components = self.parse_qualified_name() @@ -175,7 +177,7 @@ def parse_import_from(self) -> Node: name_toks = List[Tuple[List[Token], Token]]() lparen = none rparen = none - node = None # type: ImportBase + node = None # type: ImportBase if self.current_str() == '*': name_toks.append(([self.skip()], none)) node = ImportAll(name) @@ -209,11 +211,11 @@ def parse_import_from(self) -> Node: self.imports.append(node) # TODO: Fix representation if there is a custom typing module import. self.set_repr(node, noderepr.ImportFromRepr( - from_tok, components,import_tok, lparen, name_toks, rparen, br)) + from_tok, components, import_tok, lparen, name_toks, rparen, br)) if name == '__future__': self.future_options.extend(target[0] for target in targets) return node - + def parse_import_name(self) -> Tuple[str, str, List[Token]]: tok = self.expect_type(Name) name = tok.string @@ -225,7 +227,7 @@ def parse_import_name(self) -> Tuple[str, str, List[Token]]: return name, as_name.string, tokens else: return name, name, tokens - + def parse_qualified_name(self) -> Tuple[str, List[Token]]: """Parse a name with an optional module qualifier. @@ -242,9 +244,9 @@ def parse_qualified_name(self) -> Tuple[str, List[Token]]: n += '.' + tok.string components.append(tok) return n, components - + # Parsing global definitions - + def parse_defs(self) -> List[Node]: defs = List[Node]() while not self.eof(): @@ -256,24 +258,24 @@ def parse_defs(self) -> List[Node]: except ParseError: pass return defs - + def parse_class_def(self) -> ClassDef: old_is_class_body = self.is_class_body self.is_class_body = True - + type_tok = self.expect('class') lparen = none rparen = none - metaclass = None # type: str - + metaclass = None # type: str + try: commas, base_types = List[Token](), List[Type]() try: name_tok = self.expect_type(Name) name = name_tok.string - + self.errors.push_type(name) - + if self.current_str() == '(': lparen = self.skip() while True: @@ -287,9 +289,9 @@ def parse_class_def(self) -> ClassDef: rparen = self.expect(')') except ParseError: pass - + defs, _ = self.parse_block() - + node = ClassDef(name, defs, None, base_types, metaclass=metaclass) self.set_repr(node, noderepr.TypeDefRepr(type_tok, name_tok, lparen, commas, rparen)) @@ -297,7 +299,7 @@ def parse_class_def(self) -> ClassDef: finally: self.errors.pop_type() self.is_class_body = old_is_class_body - + def parse_super_type(self) -> Type: if (isinstance(self.current(), Name) and self.current_str() != 'void'): return self.parse_type() @@ -308,7 +310,7 @@ def parse_metaclass(self) -> str: self.expect('metaclass') self.expect('=') return self.parse_qualified_name()[0] - + def parse_decorated_function_or_class(self) -> Node: ats = List[Token]() brs = List[Token]() @@ -331,7 +333,7 @@ def parse_decorated_function_or_class(self) -> Node: cls = self.parse_class_def() cls.decorators = decorators return cls - + def parse_function(self) -> FuncDef: def_tok = self.expect('def') is_method = self.is_class_body @@ -339,7 +341,7 @@ def parse_function(self) -> FuncDef: try: (name, args, init, kinds, typ, is_error, toks) = self.parse_function_header() - + body, comment_type = self.parse_block(allow_type=True) if comment_type: # The function has a # type: ... signature. @@ -365,12 +367,12 @@ def parse_function(self) -> FuncDef: [arg.name() for arg in args], sig.ret_type, False) - + # If there was a serious error, we really cannot build a parse tree # node. if is_error: return None - + node = FuncDef(name, args, kinds, init, body, typ) name_tok, arg_reprs = toks node.set_line(name_tok) @@ -393,11 +395,11 @@ def check_argument_kinds(self, funckinds: List[int], sigkinds: List[int], (nodes.ARG_STAR2, '**')]: if ((kind in funckinds and sigkinds[funckinds.index(kind)] != kind) or - (funckinds.count(kind) != sigkinds.count(kind))): + (funckinds.count(kind) != sigkinds.count(kind))): self.fail( "Inconsistent use of '{}' in function " "signature".format(token), line) - + def parse_function_header(self) -> Tuple[str, List[Var], List[Node], List[int], Type, bool, Tuple[Token, Any]]: @@ -411,61 +413,60 @@ def parse_function_header(self) -> Tuple[str, List[Var], List[Node], signature (annotation) error flag (True if error) (name token, representation of arguments) - """ + """ name_tok = none - + try: name_tok = self.expect_type(Name) name = name_tok.string - + self.errors.push_function(name) - + (args, init, kinds, typ, arg_repr) = self.parse_args() except ParseError: if not isinstance(self.current(), Break): - self.ind -= 1 # Kludge: go back to the Break token + self.ind -= 1 # Kludge: go back to the Break token # Resynchronise parsing by going back over :, if present. if isinstance(self.tok[self.ind - 1], Colon): self.ind -= 1 return (name, [], [], [], None, True, (name_tok, None)) - + return (name, args, init, kinds, typ, False, (name_tok, arg_repr)) - + def parse_args(self) -> Tuple[List[Var], List[Node], List[int], Type, noderepr.FuncArgsRepr]: """Parse a function signature (...) [-> t].""" lparen = self.expect('(') - + # Parse the argument list (everything within '(' and ')'). (args, init, kinds, has_inits, arg_names, commas, asterisk, assigns, arg_types) = self.parse_arg_list() - + rparen = self.expect(')') - if self.current_str() == '-': + if self.current_str() == '->': self.skip() - self.expect('>') ret_type = self.parse_type() else: ret_type = None self.verify_argument_kinds(kinds, lparen.line) - - names = [] # type: List[str] + + names = [] # type: List[str] for arg in args: names.append(arg.name()) - + annotation = self.build_func_annotation( ret_type, arg_types, kinds, names, lparen.line) - + return (args, init, kinds, annotation, noderepr.FuncArgsRepr(lparen, rparen, arg_names, commas, assigns, asterisk)) - + def build_func_annotation(self, ret_type: Type, arg_types: List[Type], - kinds: List[int], names: List[str], + kinds: List[int], names: List[str], line: int, is_default_ret: bool = False) -> Type: # Are there any type annotations? if ((ret_type and not is_default_ret) @@ -475,7 +476,7 @@ def build_func_annotation(self, ret_type: Type, arg_types: List[Type], ret_type, line) else: return None - + def parse_arg_list( self, allow_signature: bool = True) -> Tuple[List[Var], List[Node], List[int], bool, @@ -490,21 +491,21 @@ def parse_arg_list( arguments, initializers, kinds, has inits, arg name tokens, comma tokens, asterisk tokens, assignment tokens, argument types """ - args = [] # type: List[Var] - kinds = [] # type: List[int] - names = [] # type: List[str] - init = [] # type: List[Node] + args = [] # type: List[Var] + kinds = [] # type: List[int] + names = [] # type: List[str] + init = [] # type: List[Node] has_inits = False - arg_types = [] # type: List[Type] - - arg_names = [] # type: List[Token] - commas = [] # type: List[Token] - asterisk = [] # type: List[Token] - assigns = [] # type: List[Token] - + arg_types = [] # type: List[Type] + + arg_names = [] # type: List[Token] + commas = [] # type: List[Token] + asterisk = [] # type: List[Token] + assigns = [] # type: List[Token] + require_named = False bare_asterisk_before = -1 - + if self.current_str() != ')' and self.current_str() != ':': while self.current_str() != ')': if self.current_str() == '*' and self.peek().string == ',': @@ -540,7 +541,7 @@ def parse_arg_list( arg_names.append(name) args.append(Var(name.string)) arg_types.append(self.parse_arg_type(allow_signature)) - + if self.current_str() == '=': assigns.append(self.expect('=')) init.append(self.parse_expression(precedence[','])) @@ -557,11 +558,11 @@ def parse_arg_list( kinds.append(nodes.ARG_NAMED) else: kinds.append(nodes.ARG_POS) - + if self.current().string != ',': break commas.append(self.expect(',')) - + return (args, init, kinds, has_inits, arg_names, commas, asterisk, assigns, arg_types) @@ -584,7 +585,7 @@ def verify_argument_kinds(self, kinds: List[int], line: int) -> None: elif kind == nodes.ARG_STAR2 and i != len(kinds) - 1: self.fail('Invalid argument list', line) found.add(kind) - + def construct_function_type(self, arg_types: List[Type], kinds: List[int], names: List[str], ret_type: Type, line: int) -> Callable: @@ -597,9 +598,9 @@ def construct_function_type(self, arg_types: List[Type], kinds: List[int], ret_type = AnyType() return Callable(arg_types, kinds, names, ret_type, False, None, None, [], line, None) - + # Parsing statements - + def parse_block(self, allow_type: bool = False) -> Tuple[Block, Type]: colon = self.expect(':') if not isinstance(self.current(), Break): @@ -612,7 +613,7 @@ def parse_block(self, allow_type: bool = False) -> Tuple[Block, Type]: br = self.expect_break() type = self.parse_type_comment(br, signature=True) indent = self.expect_indent() - stmt = [] # type: List[Node] + stmt = [] # type: List[Node] while (not isinstance(self.current(), Dedent) and not isinstance(self.current(), Eof)): try: @@ -625,7 +626,7 @@ def parse_block(self, allow_type: bool = False) -> Tuple[Block, Type]: dedent = none if isinstance(self.current(), Dedent): dedent = self.skip() - node = Block(stmt).set_line(colon) + node = Block(stmt).set_line(colon) self.set_repr(node, noderepr.BlockRepr(colon, br, indent, dedent)) return cast(Block, node), type @@ -638,13 +639,13 @@ def try_combine_overloads(self, s: Node, stmt: List[Node]) -> bool: stmt[-1] = OverloadedFuncDef([cast(Decorator, stmt[-1]), fdef]) return True elif (isinstance(stmt[-1], OverloadedFuncDef) and - (cast(OverloadedFuncDef, stmt[-1])).name() == n): + (cast(OverloadedFuncDef, stmt[-1])).name() == n): (cast(OverloadedFuncDef, stmt[-1])).items.append(fdef) return True return False - + def parse_statement(self) -> Node: - stmt = Undefined # type: Node + stmt = Undefined # type: Node t = self.current() ts = self.current_str() if ts == 'if': @@ -685,7 +686,7 @@ def parse_statement(self) -> Node: stmt = self.parse_with_stmt() elif ts == '@': stmt = self.parse_decorated_function_or_class() - elif ts == 'print' and (self.pyversion == 2 and + elif ts == 'print' and (self.pyversion == 2 and 'print_function' not in self.future_options): stmt = self.parse_print_stmt() else: @@ -693,7 +694,7 @@ def parse_statement(self) -> Node: if stmt is not None: stmt.set_line(t) return stmt - + def parse_expression_or_assignment(self) -> Node: e = self.parse_expression() if self.current_str() == '=': @@ -714,7 +715,7 @@ def parse_expression_or_assignment(self) -> Node: expr = ExpressionStmt(e) self.set_repr(expr, noderepr.ExpressionStmtRepr(br)) return expr - + def parse_assignment(self, lv: Any) -> Node: """Parse an assignment statement. @@ -723,7 +724,7 @@ def parse_assignment(self, lv: Any) -> Node: """ assigns = [self.expect('=')] lvalues = [lv] - + e = self.parse_expression() while self.current_str() == '=': lvalues.append(e) @@ -735,10 +736,10 @@ def parse_assignment(self, lv: Any) -> Node: assignment = AssignmentStmt(lvalues, e, type) self.set_repr(assignment, noderepr.AssignmentStmtRepr(assigns, br)) return assignment - + def parse_return_stmt(self) -> ReturnStmt: return_tok = self.expect('return') - expr = None # type: Node + expr = None # type: Node if not isinstance(self.current(), Break): expr = self.parse_expression() if isinstance(expr, YieldFromExpr): #cant go a yield from expr @@ -747,11 +748,11 @@ def parse_return_stmt(self) -> ReturnStmt: node = ReturnStmt(expr) self.set_repr(node, noderepr.SimpleStmtRepr(return_tok, br)) return node - + def parse_raise_stmt(self) -> RaiseStmt: raise_tok = self.expect('raise') - expr = None # type: Node - from_expr = None # type: Node + expr = None # type: Node + from_expr = None # type: Node from_tok = none if not isinstance(self.current(), Break): expr = self.parse_expression() @@ -762,7 +763,7 @@ def parse_raise_stmt(self) -> RaiseStmt: node = RaiseStmt(expr, from_expr) self.set_repr(node, noderepr.RaiseStmtRepr(raise_tok, from_tok, br)) return node - + def parse_assert_stmt(self) -> AssertStmt: assert_tok = self.expect('assert') expr = self.parse_expression() @@ -770,11 +771,10 @@ def parse_assert_stmt(self) -> AssertStmt: node = AssertStmt(expr) self.set_repr(node, noderepr.SimpleStmtRepr(assert_tok, br)) return node - + def parse_yield_stmt(self) -> YieldStmt: yield_tok = self.expect('yield') - expr = None # type: Node - node = YieldStmt(expr) + expr = None # type: Node if not isinstance(self.current(), Break): if isinstance(self.current(), Keyword) and self.current_str() == "from": # Not go if it's not from from_tok = self.expect("from") @@ -811,28 +811,28 @@ def parse_del_stmt(self) -> DelStmt: node = DelStmt(expr) self.set_repr(node, noderepr.SimpleStmtRepr(del_tok, br)) return node - + def parse_break_stmt(self) -> BreakStmt: break_tok = self.expect('break') br = self.expect_break() node = BreakStmt() self.set_repr(node, noderepr.SimpleStmtRepr(break_tok, br)) return node - + def parse_continue_stmt(self) -> ContinueStmt: continue_tok = self.expect('continue') br = self.expect_break() node = ContinueStmt() self.set_repr(node, noderepr.SimpleStmtRepr(continue_tok, br)) return node - + def parse_pass_stmt(self) -> PassStmt: pass_tok = self.expect('pass') br = self.expect_break() node = PassStmt() self.set_repr(node, noderepr.SimpleStmtRepr(pass_tok, br)) return node - + def parse_global_decl(self) -> GlobalDecl: global_tok = self.expect('global') names = List[str]() @@ -850,7 +850,7 @@ def parse_global_decl(self) -> GlobalDecl: self.set_repr(node, noderepr.GlobalDeclRepr(global_tok, name_toks, commas, br)) return node - + def parse_while_stmt(self) -> WhileStmt: is_error = False while_tok = self.expect('while') @@ -871,38 +871,38 @@ def parse_while_stmt(self) -> WhileStmt: return node else: return None - + def parse_for_stmt(self) -> ForStmt: for_tok = self.expect('for') index, types, commas = self.parse_for_index_variables() in_tok = self.expect('in') expr = self.parse_expression() - + body, _ = self.parse_block() - + if self.current_str() == 'else': else_tok = self.expect('else') else_body, _ = self.parse_block() else: else_body = None else_tok = none - + node = ForStmt(index, expr, body, else_body, types) self.set_repr(node, noderepr.ForStmtRepr(for_tok, commas, in_tok, else_tok)) return node - + def parse_for_index_variables(self) -> Tuple[List[NameExpr], List[Type], List[Token]]: # Parse index variables of a 'for' statement. index = List[NameExpr]() types = List[Type]() commas = List[Token]() - + is_paren = self.current_str() == '(' if is_paren: self.skip() - + while True: v = self.parse_name_expr() index.append(v) @@ -911,24 +911,24 @@ def parse_for_index_variables(self) -> Tuple[List[NameExpr], List[Type], commas.append(none) break commas.append(self.skip()) - + if is_paren: self.expect(')') - + return index, types, commas - + def parse_if_stmt(self) -> IfStmt: is_error = False - + if_tok = self.expect('if') expr = List[Node]() try: expr.append(self.parse_expression()) except ParseError: is_error = True - + body = [self.parse_block()[0]] - + elif_toks = List[Token]() while self.current_str() == 'elif': elif_toks.append(self.expect('elif')) @@ -937,14 +937,14 @@ def parse_if_stmt(self) -> IfStmt: except ParseError: is_error = True body.append(self.parse_block()[0]) - + if self.current_str() == 'else': else_tok = self.expect('else') else_body, _ = self.parse_block() else: else_tok = none else_body = None - + if not is_error: node = IfStmt(expr, body, else_body) self.set_repr(node, noderepr.IfStmtRepr(if_tok, elif_toks, @@ -952,7 +952,7 @@ def parse_if_stmt(self) -> IfStmt: return node else: return None - + def parse_try_stmt(self) -> Node: try_tok = self.expect('try') body, _ = self.parse_block() @@ -1005,7 +1005,7 @@ def parse_try_stmt(self) -> Node: return node else: return None - + def parse_with_stmt(self) -> WithStmt: with_tok = self.expect('with') as_toks = List[Token]() @@ -1043,14 +1043,14 @@ def parse_print_stmt(self) -> PrintStmt: break self.expect_break() return PrintStmt(args, newline=not comma) - + # Parsing expressions - + def parse_expression(self, prec: int = 0) -> Node: """Parse a subexpression within a specific precedence context.""" - expr = Undefined # type: Node - t = self.current() # Remember token for setting the line number. - + expr = Undefined # type: Node + t = self.current() # Remember token for setting the line number. + # Parse a "value" expression or unary operator expression and store # that in expr. s = self.current_str() @@ -1085,13 +1085,13 @@ def parse_expression(self, prec: int = 0) -> Node: else: # Invalid expression. self.parse_error() - + # Set the line of the expression node, if not specified. This # simplifies recording the line number as not every node type needs to # deal with it separately. if expr.line < 0: expr.set_line(t) - + # Parse operations that require a left argument (stored in expr). while True: t = self.current() @@ -1120,7 +1120,7 @@ def parse_expression(self, prec: int = 0) -> Node: # comprehension if needed elsewhere. expr = self.parse_generator_expr(expr) else: - break + break elif s == 'if': # Conditional expression. if precedence[''] > prec: @@ -1136,7 +1136,10 @@ def parse_expression(self, prec: int = 0) -> Node: # Either "not in" or an error. op_prec = precedence['in'] if op_prec > prec: - expr = self.parse_bin_op_expr(expr, op_prec) + if op in op_comp: + expr = self.parse_comparison_expr(expr, op_prec) + else: + expr = self.parse_bin_op_expr(expr, op_prec) else: # The operation cannot be associated with the # current left operand due to the precedence @@ -1146,20 +1149,20 @@ def parse_expression(self, prec: int = 0) -> Node: # Not an operation that accepts a left argument; let the # caller handle the rest. break - + # Set the line of the expression node, if not specified. This # simplifies recording the line number as not every node type # needs to deal with it separately. if expr.line < 0: expr.set_line(t) - + return expr - + def parse_parentheses(self) -> Node: lparen = self.skip() if self.current_str() == ')': # Empty tuple (). - expr = self.parse_empty_tuple_expr(lparen) # type: Node + expr = self.parse_empty_tuple_expr(lparen) # type: Node else: # Parenthesised expression. expr = self.parse_expression(0) @@ -1167,13 +1170,13 @@ def parse_parentheses(self) -> Node: expr = ParenExpr(expr) self.set_repr(expr, noderepr.ParenExprRepr(lparen, rparen)) return expr - + def parse_empty_tuple_expr(self, lparen: Any) -> TupleExpr: rparen = self.expect(')') node = TupleExpr([]) self.set_repr(node, noderepr.TupleExprRepr(lparen, [], rparen)) return node - + def parse_list_expr(self) -> Node: """Parse list literal or list comprehension.""" items = List[Node]() @@ -1185,19 +1188,19 @@ def parse_list_expr(self) -> Node: break commas.append(self.expect(',')) if self.current_str() == 'for' and len(items) == 1: - items[0] = self.parse_generator_expr(items[0]) + items[0] = self.parse_generator_expr(items[0]) rbracket = self.expect(']') if len(items) == 1 and isinstance(items[0], GeneratorExpr): - list_comp = ListComprehension(cast(GeneratorExpr, items[0])) - self.set_repr(list_comp, noderepr.ListComprehensionRepr(lbracket, - rbracket)) - return list_comp + list_comp = ListComprehension(cast(GeneratorExpr, items[0])) + self.set_repr(list_comp, noderepr.ListComprehensionRepr(lbracket, + rbracket)) + return list_comp else: expr = ListExpr(items) self.set_repr(expr, noderepr.ListSetExprRepr(lbracket, commas, rbracket, none, none)) return expr - + def parse_generator_expr(self, left_expr: Node) -> GeneratorExpr: indices = List[List[NameExpr]]() sequences = List[Node]() @@ -1227,7 +1230,7 @@ def parse_generator_expr(self, left_expr: Node) -> GeneratorExpr: self.set_repr(gen, noderepr.GeneratorExprRepr(for_toks, commas, in_toks, if_toklists)) return gen - + def parse_expression_list(self) -> Node: prec = precedence[''] expr = self.parse_expression(prec) @@ -1236,14 +1239,14 @@ def parse_expression_list(self) -> Node: else: t = self.current() return self.parse_tuple_expr(expr, prec).set_line(t) - + def parse_conditional_expr(self, left_expr: Node) -> ConditionalExpr: self.expect('if') cond = self.parse_expression(precedence['']) self.expect('else') else_expr = self.parse_expression(precedence['']) return ConditionalExpr(cond, left_expr, else_expr) - + def parse_dict_or_set_expr(self) -> Node: items = List[Tuple[Node, Node]]() lbrace = self.expect('{') @@ -1266,7 +1269,7 @@ def parse_dict_or_set_expr(self) -> Node: self.set_repr(node, noderepr.DictExprRepr(lbrace, colons, commas, rbrace, none, none, none)) return node - + def parse_set_expr(self, first: Node, lbrace: Token) -> SetExpr: items = [first] commas = List[Token]() @@ -1280,7 +1283,7 @@ def parse_set_expr(self, first: Node, lbrace: Token) -> SetExpr: self.set_repr(expr, noderepr.ListSetExprRepr(lbrace, commas, rbrace, none, none)) return expr - + def parse_tuple_expr(self, expr: Node, prec: int = precedence[',']) -> TupleExpr: items = [expr] @@ -1295,14 +1298,14 @@ def parse_tuple_expr(self, expr: Node, node = TupleExpr(items) self.set_repr(node, noderepr.TupleExprRepr(none, commas, none)) return node - + def parse_name_expr(self) -> NameExpr: tok = self.expect_type(Name) node = NameExpr(tok.string) node.set_line(tok) self.set_repr(node, noderepr.NameExprRepr(tok)) return node - + def parse_int_expr(self) -> IntExpr: tok = self.expect_type(IntLit) s = tok.string @@ -1316,7 +1319,7 @@ def parse_int_expr(self) -> IntExpr: node = IntExpr(v) self.set_repr(node, noderepr.IntExprRepr(tok)) return node - + def parse_str_expr(self) -> Node: # XXX \uxxxx literals tok = [self.expect_type(StrLit)] @@ -1332,7 +1335,7 @@ def parse_str_expr(self) -> Node: node = StrExpr(value) self.set_repr(node, noderepr.StrExprRepr(tok)) return node - + def parse_bytes_literal(self) -> Node: # XXX \uxxxx literals tok = [self.expect_type(BytesLit)] @@ -1342,12 +1345,12 @@ def parse_bytes_literal(self) -> Node: tok.append(t) value += t.parsed() if self.pyversion >= 3: - node = BytesExpr(value) # type: Node + node = BytesExpr(value) # type: Node else: node = StrExpr(value) self.set_repr(node, noderepr.StrExprRepr(tok)) return node - + def parse_unicode_literal(self) -> Node: # XXX \uxxxx literals tok = [self.expect_type(UnicodeLit)] @@ -1358,18 +1361,18 @@ def parse_unicode_literal(self) -> Node: value += t.parsed() if self.pyversion >= 3: # Python 3.3 supports u'...' as an alias of '...'. - node = StrExpr(value) # type: Node + node = StrExpr(value) # type: Node else: node = UnicodeExpr(value) self.set_repr(node, noderepr.StrExprRepr(tok)) return node - + def parse_float_expr(self) -> FloatExpr: tok = self.expect_type(FloatLit) node = FloatExpr(float(tok.string)) self.set_repr(node, noderepr.FloatExprRepr(tok)) return node - + def parse_call_expr(self, callee: Any) -> CallExpr: lparen = self.expect('(') (args, kinds, names, @@ -1379,7 +1382,7 @@ def parse_call_expr(self, callee: Any) -> CallExpr: self.set_repr(node, noderepr.CallExprRepr(lparen, commas, star, star2, assigns, rparen)) return node - + def parse_arg_expr(self) -> Tuple[List[Node], List[int], List[str], List[Token], Token, Token, List[List[Token]]]: @@ -1393,14 +1396,14 @@ def parse_arg_expr(self) -> Tuple[List[Node], List[int], List[str], * token (if any) ** token (if any) (assignment, name) tokens - """ - args = [] # type: List[Node] - kinds = [] # type: List[int] - names = [] # type: List[str] + """ + args = [] # type: List[Node] + kinds = [] # type: List[int] + names = [] # type: List[str] star = none star2 = none - commas = [] # type: List[Token] - keywords = [] # type: List[List[Token]] + commas = [] # type: List[Token] + keywords = [] # type: List[List[Token]] var_arg = False dict_arg = False named_args = False @@ -1437,7 +1440,7 @@ def parse_arg_expr(self) -> Tuple[List[Node], List[int], List[str], break commas.append(self.expect(',')) return args, kinds, names, commas, star, star2, keywords - + def parse_member_expr(self, expr: Any) -> Node: dot = self.expect('.') name = self.expect_type(Name) @@ -1454,7 +1457,7 @@ def parse_member_expr(self, expr: Any) -> Node: node = MemberExpr(expr, name.string) self.set_repr(node, noderepr.MemberExprRepr(dot, name)) return node - + def parse_index_expr(self, base: Any) -> IndexExpr: lbracket = self.expect('[') if self.current_str() != ':': @@ -1469,7 +1472,7 @@ def parse_index_expr(self, base: Any) -> IndexExpr: else: end_index = None colon2 = none - stride = None # type: Node + stride = None # type: Node if self.current_str() == ':': colon2 = self.expect(':') if self.current_str() != ']': @@ -1480,28 +1483,53 @@ def parse_index_expr(self, base: Any) -> IndexExpr: node = IndexExpr(base, index) self.set_repr(node, noderepr.IndexExprRepr(lbracket, rbracket)) return node - + def parse_bin_op_expr(self, left: Node, prec: int) -> OpExpr: op = self.expect_type(Op) - op2 = none op_str = op.string - if op_str == 'not': - if self.current_str() == 'in': - op_str = 'not in' - op2 = self.skip() - else: - self.parse_error() - elif op_str == 'is' and self.current_str() == 'not': - op_str = 'is not' - op2 = self.skip() - elif op_str == '~': + if op_str == '~': self.ind -= 1 self.parse_error() right = self.parse_expression(prec) node = OpExpr(op_str, left, right) - self.set_repr(node, noderepr.OpExprRepr(op, op2)) + self.set_repr(node, noderepr.OpExprRepr(op)) + return node + + def parse_comparison_expr(self, left: Node, prec: int) -> ComparisonExpr: + operators = [] # type: List[Tuple[Token, Token]] + operators_str = [] # type: List[str] + operands = [left] + + while True: + op = self.expect_type(Op) + op2 = none + op_str = op.string + if op_str == 'not': + if self.current_str() == 'in': + op_str = 'not in' + op2 = self.skip() + else: + self.parse_error() + elif op_str == 'is' and self.current_str() == 'not': + op_str = 'is not' + op2 = self.skip() + + operators_str.append(op_str) + operators.append( (op, op2) ) + operand = self.parse_expression(prec) + operands.append(operand) + + # Continue if next token is a comparison operator + t = self.current() + s = self.current_str() + if s not in op_comp: + break + + node = ComparisonExpr(operators_str, operands) + self.set_repr(node, noderepr.ComparisonExprRepr(operators)) return node - + + def parse_unary_expr(self) -> UnaryExpr: op_tok = self.skip() op = op_tok.string @@ -1513,11 +1541,11 @@ def parse_unary_expr(self) -> UnaryExpr: node = UnaryExpr(op, expr) self.set_repr(node, noderepr.UnaryExprRepr(op_tok)) return node - + def parse_lambda_expr(self) -> FuncExpr: is_error = False lambda_tok = self.expect('lambda') - + (args, init, kinds, has_inits, arg_names, commas, asterisk, assigns, arg_types) = self.parse_arg_list(allow_signature=False) @@ -1532,14 +1560,14 @@ def parse_lambda_expr(self) -> FuncExpr: ret_type = UnboundType('__builtins__.object') typ = self.build_func_annotation(ret_type, arg_types, kinds, names, lambda_tok.line, is_default_ret=True) - + colon = self.expect(':') - + expr = self.parse_expression(precedence[',']) - + body = Block([ReturnStmt(expr).set_line(lambda_tok)]) body.set_line(colon) - + node = FuncExpr(args, kinds, init, body, typ) self.set_repr(node, noderepr.FuncExprRepr( @@ -1547,59 +1575,59 @@ def parse_lambda_expr(self) -> FuncExpr: noderepr.FuncArgsRepr(none, none, arg_names, commas, assigns, asterisk))) return node - + # Helper methods - + def skip(self) -> Token: self.ind += 1 return self.tok[self.ind - 1] - + def expect(self, string: str) -> Token: if self.current_str() == string: self.ind += 1 return self.tok[self.ind - 1] else: self.parse_error() - + def expect_indent(self) -> Token: if isinstance(self.current(), Indent): return self.expect_type(Indent) else: self.fail('Expected an indented block', self.current().line) return none - + def fail(self, msg: str, line: int) -> None: self.errors.report(line, msg) - + def expect_type(self, typ: type) -> Token: if isinstance(self.current(), typ): self.ind += 1 return self.tok[self.ind - 1] else: self.parse_error() - + def expect_colon_and_break(self) -> Tuple[Token, Token]: return self.expect_type(Colon), self.expect_type(Break) - + def expect_break(self) -> Token: return self.expect_type(Break) - + def expect_end(self) -> Tuple[Token, Token]: return self.expect('end'), self.expect_type(Break) - + def current(self) -> Token: return self.tok[self.ind] - + def current_str(self) -> str: return self.current().string - + def peek(self) -> Token: return self.tok[self.ind + 1] - + def parse_error(self) -> None: self.parse_error_at(self.current()) raise ParseError() - + def parse_error_at(self, tok: Token, skip: bool = True) -> None: msg = '' if isinstance(tok, LexError): @@ -1609,12 +1637,12 @@ def parse_error_at(self, tok: Token, skip: bool = True) -> None: msg = 'Inconsistent indentation' else: msg = 'Parse error before {}'.format(token_repr(tok)) - + self.errors.report(tok.line, msg) - + if skip: self.skip_until_next_line() - + def skip_until_break(self) -> None: n = 0 while (not isinstance(self.current(), Break) @@ -1623,20 +1651,20 @@ def skip_until_break(self) -> None: n += 1 if isinstance(self.tok[self.ind - 1], Colon) and n > 1: self.ind -= 1 - + def skip_until_next_line(self) -> None: self.skip_until_break() if isinstance(self.current(), Break): self.skip() - + def eol(self) -> bool: return isinstance(self.current(), Break) or self.eof() - + def eof(self) -> bool: return isinstance(self.current(), Eof) - + # Type annotation related functionality - + def parse_type(self) -> Type: line = self.current().line try: @@ -1675,16 +1703,16 @@ def parse_type_comment(self, token: Token, signature: bool) -> Type: return None return type else: - return None - + return None + # Representation management - + def set_repr(self, node: Node, repr: Any) -> None: node.repr = repr - + def repr(self, node: Node) -> Any: return node.repr - + def paren_repr(self, e: Node) -> Tuple[List[Token], List[Token]]: """If e is a ParenExpr, return an array of left-paren tokens (more that one if nested parens) and an array of corresponding diff --git a/mypy/pprinter.py b/mypy/pprinter.py index 9bc3cb93655e..5e20a5d9fe74 100644 --- a/mypy/pprinter.py +++ b/mypy/pprinter.py @@ -21,12 +21,12 @@ class PrettyPrintVisitor(NodeVisitor): def __init__(self) -> None: super().__init__() - self.result = [] # type: List[str] + self.result = [] # type: List[str] self.indent = 0 def output(self) -> str: return ''.join(self.result) - + # # Definitions # @@ -34,12 +34,12 @@ def output(self) -> str: def visit_mypy_file(self, file: MypyFile) -> None: for d in file.defs: d.accept(self) - + def visit_class_def(self, tdef: ClassDef) -> None: self.string('class ') self.string(tdef.name) if tdef.base_types: - b = [] # type: List[str] + b = [] # type: List[str] for bt in tdef.base_types: if not bt: continue @@ -54,7 +54,7 @@ def visit_class_def(self, tdef: ClassDef) -> None: for d in tdef.defs.body: d.accept(self) self.dedent() - + def visit_func_def(self, fdef: FuncDef) -> None: # FIX varargs, default args, keyword args etc. ftyp = cast(Callable, fdef.type) @@ -74,7 +74,7 @@ def visit_func_def(self, fdef: FuncDef) -> None: self.string(') -> ') self.type(ftyp.ret_type) fdef.body.accept(self) - + def visit_var_def(self, vdef: VarDef) -> None: if vdef.items[0].name() not in nodes.implicit_module_attrs: self.string(vdef.items[0].name()) @@ -84,7 +84,7 @@ def visit_var_def(self, vdef: VarDef) -> None: self.string(' = ') self.node(vdef.init) self.string('\n') - + # # Statements # @@ -97,23 +97,23 @@ def visit_block(self, b): def visit_pass_stmt(self, o): self.string('pass\n') - + def visit_return_stmt(self, o): self.string('return ') if o.expr: self.node(o.expr) self.string('\n') - + def visit_expression_stmt(self, o): self.node(o.expr) self.string('\n') - + def visit_assignment_stmt(self, o): if isinstance(o.rvalue, CallExpr) and isinstance(o.rvalue.analyzed, TypeVarExpr): # Skip type variable definition 'x = typevar(...)'. return - self.node(o.lvalues[0]) # FIX multiple lvalues + self.node(o.lvalues[0]) # FIX multiple lvalues if o.type: self.string(': ') self.type(o.type) @@ -140,11 +140,11 @@ def visit_while_stmt(self, o): if o.else_body: self.string('else') self.node(o.else_body) - + # # Expressions # - + def visit_call_expr(self, o): if o.analyzed: o.analyzed.accept(self) @@ -167,10 +167,10 @@ def visit_member_expr(self, o): self.string('.' + o.name) if o.direct: self.string('!') - + def visit_name_expr(self, o): self.string(o.name) - + def visit_coerce_expr(self, o: CoerceExpr) -> None: self.string('{') self.full_type(o.target_type) @@ -180,14 +180,14 @@ def visit_coerce_expr(self, o: CoerceExpr) -> None: self.string(' ') self.node(o.expr) self.string('}') - + def visit_type_expr(self, o: TypeExpr) -> None: # Type expressions are only generated during transformation, so we must # use automatic formatting. self.string('<') self.full_type(o.type) self.string('>') - + def visit_index_expr(self, o): if o.analyzed: o.analyzed.accept(self) @@ -199,7 +199,7 @@ def visit_index_expr(self, o): def visit_int_expr(self, o): self.string(str(o.value)) - + def visit_str_expr(self, o): self.string(repr(o.value)) @@ -208,6 +208,12 @@ def visit_op_expr(self, o): self.string(' %s ' % o.op) self.node(o.right) + def visit_comparison_expr(self, o): + self.node(o.operands[0]) + for operator, operand in zip(o.operators, o.operands[1:]): + self.string(' %s ' % operator) + self.node(operand) + def visit_unary_expr(self, o): self.string(o.op) if o.op == 'not': @@ -218,7 +224,7 @@ def visit_paren_expr(self, o): self.string('(') self.node(o.expr) self.string(')') - + def visit_super_expr(self, o): self.string('super().') self.string(o.name) @@ -237,7 +243,7 @@ def visit_type_application(self, o): def visit_undefined_expr(self, o): # Omit declared type as redundant. self.string('Undefined') - + # # Helpers # @@ -261,13 +267,13 @@ def last_output_char(self) -> str: if self.result: return self.result[-1][-1] return '' - + def type(self, t): """Pretty-print a type with erased type arguments.""" if t: v = TypeErasedPrettyPrintVisitor() self.string(t.accept(v)) - + def full_type(self, t): """Pretty-print a type, includingn type arguments.""" if t: @@ -283,19 +289,19 @@ class TypeErasedPrettyPrintVisitor(TypeVisitor[str]): Note that the translation does not preserve all information about the types, but this is fine since this is only used in test case output. """ - + def visit_any(self, t): return 'Any' - + def visit_void(self, t): return 'None' - + def visit_instance(self, t): return t.type.name() - + def visit_type_var(self, t): return 'Any*' - + def visit_runtime_type_var(self, t): v = PrettyPrintVisitor() t.node.accept(v) @@ -306,27 +312,27 @@ class TypePrettyPrintVisitor(TypeVisitor[str]): """Pretty-print types. Include type variables. - + Note that the translation does not preserve all information about the types, but this is fine since this is only used in test case output. """ - + def visit_any(self, t): return 'Any' - + def visit_void(self, t): return 'None' - + def visit_instance(self, t): s = t.type.name() if t.args: argstr = ', '.join([a.accept(self) for a in t.args]) s += '[%s]' % argstr return s - + def visit_type_var(self, t): return 'Any*' - + def visit_runtime_type_var(self, t): v = PrettyPrintVisitor() t.node.accept(v) diff --git a/mypy/semanal.py b/mypy/semanal.py index a3ade28ddf61..2d1828349813 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -21,11 +21,11 @@ cyclic references between modules, such as module 'a' that imports module 'b' and used names defined in b *and* vice versa. The first pass can be performed before dependent modules have been processed. - + * SemanticAnalyzer is the second pass. It does the bulk of the work. It assumes that dependent modules have been semantically analyzed, up to the second pass, unless there is a import cycle. - + * ThirdPass checks that type argument counts are valid; for example, it will reject Dict[int]. We don't do this in the second pass, since we infer the type argument counts of classes during this @@ -56,7 +56,7 @@ SymbolTableNode, TVAR, UNBOUND_TVAR, ListComprehension, GeneratorExpr, FuncExpr, MDEF, FuncBase, Decorator, SetExpr, UndefinedExpr, TypeVarExpr, StrExpr, PrintStmt, ConditionalExpr, DucktypeExpr, DisjointclassExpr, - ARG_POS, ARG_NAMED, MroError, type_aliases, YieldFromStmt, YieldFromExpr + ComparisonExpr, ARG_POS, ARG_NAMED, MroError, type_aliases, YieldFromStmt, YieldFromExpr ) from mypy.visitor import NodeVisitor from mypy.traverser import TraverserVisitor @@ -93,7 +93,7 @@ class SemanticAnalyzer(NodeVisitor): This is the second phase of semantic analysis. """ - + # Library search paths lib_path = Undefined(List[str]) # Module name space @@ -117,7 +117,7 @@ class SemanticAnalyzer(NodeVisitor): cur_mod_id = '' # Current module id (or None) (phase 2) imports = Undefined(Set[str]) # Imported modules (during phase 2 analysis) errors = Undefined(Errors) # Keep track of generated errors - + def __init__(self, lib_path: List[str], errors: Errors, pyversion: int = 3) -> None: """Construct semantic analyzer. @@ -137,28 +137,28 @@ def __init__(self, lib_path: List[str], errors: Errors, self.modules = {} self.pyversion = pyversion self.stored_vars = Dict[Node, Type]() - + def visit_file(self, file_node: MypyFile, fnam: str) -> None: self.errors.set_file(fnam) self.globals = file_node.names self.cur_mod_id = file_node.fullname() - + if 'builtins' in self.modules: self.globals['__builtins__'] = SymbolTableNode( MODULE_REF, self.modules['builtins'], self.cur_mod_id) - + defs = file_node.defs for d in defs: d.accept(self) if self.cur_mod_id == 'builtins': remove_imported_names_from_symtable(self.globals, 'builtins') - + def visit_func_def(self, defn: FuncDef) -> None: self.errors.push_function(defn.name()) self.update_function_type_variables(defn) self.errors.pop_function() - + if self.is_class_scope(): # Method definition defn.is_conditional = self.block_depth[-1] > 0 @@ -186,7 +186,7 @@ def visit_func_def(self, defn: FuncDef) -> None: not defn.is_overload): self.add_local_func(defn, defn) defn._fullname = defn.name() - + self.errors.push_function(defn.name()) self.analyse_function(defn) self.errors.pop_function() @@ -249,7 +249,7 @@ def find_type_variables_in_type( def is_defined_type_var(self, tvar: str, context: Node) -> bool: return self.lookup_qualified(tvar, context).kind == TVAR - + def visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None: t = List[Callable]() for item in defn.items: @@ -261,17 +261,17 @@ def visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None: if not [dec for dec in item.decorators if refers_to_fullname(dec, 'typing.overload')]: self.fail("'overload' decorator expected", item) - + defn.type = Overloaded(t) defn.type.line = defn.line - + if self.is_class_scope(): self.type.names[defn.name()] = SymbolTableNode(MDEF, defn, typ=defn.type) defn.info = self.type elif self.is_func_scope(): self.add_local_func(defn, defn) - + def analyse_function(self, defn: FuncItem) -> None: is_method = self.is_class_scope() tvarnodes = self.add_func_type_variables_to_symbol_table(defn) @@ -293,18 +293,18 @@ def analyse_function(self, defn: FuncItem) -> None: for init_ in defn.init: if init_: init_.lvalues[0].accept(self) - + # The first argument of a non-static, non-class method is like 'self' # (though the name could be different), having the enclosing class's # instance type. if is_method and not defn.is_static and not defn.is_class and defn.args: defn.args[0].is_self = True - + defn.body.accept(self) disable_typevars(tvarnodes) self.leave() self.function_stack.pop() - + def add_func_type_variables_to_symbol_table( self, defn: FuncItem) -> List[SymbolTableNode]: nodes = List[SymbolTableNode]() @@ -320,13 +320,13 @@ def add_func_type_variables_to_symbol_table( nodes.append(node) names.add(name) return nodes - + def type_var_names(self) -> Set[str]: if not self.type: return set() else: return set(self.type.type_vars) - + def add_type_var(self, fullname: str, id: int, context: Context) -> SymbolTableNode: node = self.lookup_qualified(fullname, context) @@ -340,7 +340,7 @@ def check_function_signature(self, fdef: FuncItem) -> None: self.fail('Type signature has too few arguments', fdef) elif len(sig.arg_types) > len(fdef.args): self.fail('Type signature has too many arguments', fdef) - + def visit_class_def(self, defn: ClassDef) -> None: self.clean_up_bases_and_infer_type_variables(defn) self.setup_class_def_analysis(defn) @@ -355,7 +355,7 @@ def visit_class_def(self, defn: ClassDef) -> None: self.calculate_abstract_status(defn.info) self.setup_ducktyping(defn) - + # Restore analyzer state. self.block_depth.pop() self.locals.pop() @@ -369,7 +369,7 @@ def analyze_class_decorator(self, defn: ClassDef, decorator: Node) -> None: decorator.accept(self) if refers_to_fullname(decorator, 'typing.builtinclass'): defn.is_builtinclass = True - + def calculate_abstract_status(self, typ: TypeInfo) -> None: """Calculate abstract status of a class. @@ -386,7 +386,7 @@ def calculate_abstract_status(self, typ: TypeInfo) -> None: # check arbitrarily the first overload item. If the # different items have a different abstract status, there # should be an error reported elsewhere. - func = node.items[0] # type: Node + func = node.items[0] # type: Node else: func = node if isinstance(func, Decorator): @@ -419,7 +419,7 @@ def clean_up_bases_and_infer_type_variables(self, defn: ClassDef) -> None: For example, consider this class: . class Foo(Bar, Generic[t]): ... - + Now we will remove Generic[t] from bases of Foo and infer that the type variable 't' is a type argument of Foo. """ @@ -488,12 +488,12 @@ def setup_class_def_analysis(self, defn: ClassDef) -> None: tvarnodes = self.add_class_type_variables_to_symbol_table(defn.info) # Remember previous active class and type vars of *this* class. self.type_stack.append((self.type, tvarnodes)) - self.locals.append(None) # Add class scope - self.block_depth.append(-1) # The class body increments this to 0 + self.locals.append(None) # Add class scope + self.block_depth.append(-1) # The class body increments this to 0 self.type = defn.info def analyze_base_classes(self, defn: ClassDef) -> None: - """Analyze and set up base classes.""" + """Analyze and set up base classes.""" bases = List[Instance]() for i in range(len(defn.base_types)): base = self.anal_type(defn.base_types[i]) @@ -565,10 +565,10 @@ def object_type(self) -> Instance: def named_type(self, qualified_name: str) -> Instance: sym = self.lookup_qualified(qualified_name, None) return Instance(cast(TypeInfo, sym.node), []) - + def is_instance_type(self, t: Type) -> bool: return isinstance(t, Instance) - + def add_class_type_variables_to_symbol_table( self, info: TypeInfo) -> List[SymbolTableNode]: vars = info.type_vars @@ -578,7 +578,7 @@ def add_class_type_variables_to_symbol_table( node = self.add_type_var(vars[i], i + 1, info) nodes.append(node) return nodes - + def visit_import(self, i: Import) -> None: for id, as_id in i.ids: if as_id != id: @@ -593,7 +593,7 @@ def add_module_symbol(self, id: str, as_id: str, context: Context) -> None: self.add_symbol(as_id, SymbolTableNode(MODULE_REF, m, self.cur_mod_id), context) else: self.add_unknown_symbol(as_id, context) - + def visit_import_from(self, i: ImportFrom) -> None: if i.id in self.modules: m = self.modules[i.id] @@ -617,7 +617,7 @@ def normalize_type_alias(self, node: SymbolTableNode, # Node refers to an aliased type such as typing.List; normalize. node = self.lookup_qualified(type_aliases[node.fullname], ctx) return node - + def visit_import_all(self, i: ImportAll) -> None: if i.id in self.modules: m = self.modules[i.id] @@ -636,11 +636,11 @@ def add_unknown_symbol(self, name: str, context: Context) -> None: var.is_ready = True var.type = AnyType() self.add_symbol(name, SymbolTableNode(GDEF, var, self.cur_mod_id), context) - + # # Statements # - + def visit_block(self, b: Block) -> None: if b.is_unreachable: return @@ -648,15 +648,15 @@ def visit_block(self, b: Block) -> None: for s in b.body: s.accept(self) self.block_depth[-1] -= 1 - + def visit_block_maybe(self, b: Block) -> None: if b: self.visit_block(b) - + def visit_var_def(self, defn: VarDef) -> None: for i in range(len(defn.items)): defn.items[i].type = self.anal_type(defn.items[i].type) - + for v in defn.items: if self.is_func_scope(): defn.kind = LDEF @@ -669,17 +669,17 @@ def visit_var_def(self, defn: VarDef) -> None: elif v.name not in self.globals: defn.kind = GDEF self.add_var(v, defn) - + if defn.init: defn.init.accept(self) - + def anal_type(self, t: Type) -> Type: if t: a = TypeAnalyser(self.lookup_qualified, self.stored_vars, self.fail) return t.accept(a) else: return None - + def visit_assignment_stmt(self, s: AssignmentStmt) -> None: for lval in s.lvalues: self.analyse_lvalue(lval, explicit_type=s.type is not None) @@ -690,7 +690,7 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: s.type = self.infer_type_from_undefined(s.rvalue) # For simple assignments, allow binding type aliases if (s.type is None and len(s.lvalues) == 1 and - isinstance(s.lvalues[0], NameExpr)): + isinstance(s.lvalues[0], NameExpr)): res = analyse_node(self.lookup_qualified, s.rvalue, s) if res: # XXX Need to remove this later if reassigned @@ -722,7 +722,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> None: # that refers to a type, rather than making this # just an alias for the type. self.globals[lvalue.name].node = node - + def analyse_lvalue(self, lval: Node, nested: bool = False, add_global: bool = False, explicit_type: bool = False) -> None: @@ -739,7 +739,7 @@ def analyse_lvalue(self, lval: Node, nested: bool = False, # Define new global name. v = Var(lval.name) v._fullname = self.qualified_name(lval.name) - v.is_ready = False # Type not inferred yet + v.is_ready = False # Type not inferred yet lval.node = v lval.is_def = True lval.kind = GDEF @@ -761,7 +761,7 @@ def analyse_lvalue(self, lval: Node, nested: bool = False, lval.fullname = lval.name self.add_local(v, lval) elif not self.is_func_scope() and (self.type and - lval.name not in self.type.names): + lval.name not in self.type.names): # Define a new attribute within class body. v = Var(lval.name) v.info = self.type @@ -798,7 +798,7 @@ def analyse_lvalue(self, lval: Node, nested: bool = False, explicit_type = explicit_type) else: self.fail('Invalid assignment target', lval) - + def analyse_member_lvalue(self, lval: MemberExpr) -> None: lval.accept(self) if (self.is_self_member_ref(lval) and @@ -848,7 +848,7 @@ def store_declared_types(self, lvalue: Node, typ: Type) -> None: self.store_declared_types(item, itemtype) else: self.fail('Tuple type expected for multiple variables', - lvalue) + lvalue) elif isinstance(lvalue, ParenExpr): self.store_declared_types(lvalue.expr, typ) else: @@ -903,7 +903,7 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> None: else: self.fail('The values argument must be a tuple literal', s) return - else: + else: self.fail('The values argument must be in parentheses (...)', s) return @@ -979,20 +979,20 @@ def check_decorated_function_is_method(self, decorator: str, context: Context) -> None: if not self.type or self.is_func_scope(): self.fail("'%s' used with a non-method" % decorator, context) - + def visit_expression_stmt(self, s: ExpressionStmt) -> None: s.expr.accept(self) - + def visit_return_stmt(self, s: ReturnStmt) -> None: if not self.is_func_scope(): self.fail("'return' outside function", s) if s.expr: s.expr.accept(self) - + def visit_raise_stmt(self, s: RaiseStmt) -> None: if s.expr: s.expr.accept(self) - + def visit_yield_stmt(self, s: YieldStmt) -> None: if not self.is_func_scope(): self.fail("'yield' outside function", s) @@ -1015,21 +1015,21 @@ def visit_operator_assignment_stmt(self, s: OperatorAssignmentStmt) -> None: s.lvalue.accept(self) s.rvalue.accept(self) - + def visit_while_stmt(self, s: WhileStmt) -> None: s.expr.accept(self) self.loop_depth += 1 s.body.accept(self) self.loop_depth -= 1 self.visit_block_maybe(s.else_body) - + def visit_for_stmt(self, s: ForStmt) -> None: s.expr.accept(self) - + # Bind index variables and check if they define new names. for n in s.index: self.analyse_lvalue(n) - + # Analyze index variable types. for i in range(len(s.types)): t = s.types[i] @@ -1038,37 +1038,37 @@ def visit_for_stmt(self, s: ForStmt) -> None: v = cast(Var, s.index[i].node) # TODO check if redefinition v.type = s.types[i] - + # Report error if only some of the loop variables have annotations. if s.types != [None] * len(s.types) and None in s.types: self.fail('Cannot mix unannotated and annotated loop variables', s) - + self.loop_depth += 1 self.visit_block(s.body) self.loop_depth -= 1 - + self.visit_block_maybe(s.else_body) - + def visit_break_stmt(self, s: BreakStmt) -> None: if self.loop_depth == 0: self.fail("'break' outside loop", s) - + def visit_continue_stmt(self, s: ContinueStmt) -> None: if self.loop_depth == 0: self.fail("'continue' outside loop", s) - + def visit_if_stmt(self, s: IfStmt) -> None: infer_reachability_of_if_statement(s, pyversion=self.pyversion) for i in range(len(s.expr)): s.expr[i].accept(self) self.visit_block(s.body[i]) self.visit_block_maybe(s.else_body) - + def visit_try_stmt(self, s: TryStmt) -> None: self.analyze_try_stmt(s, self) def analyze_try_stmt(self, s: TryStmt, visitor: NodeVisitor, - add_global: bool = False) -> None: + add_global: bool = False) -> None: s.body.accept(visitor) for type, var, handler in zip(s.types, s.vars, s.handlers): if type: @@ -1080,7 +1080,7 @@ def analyze_try_stmt(self, s: TryStmt, visitor: NodeVisitor, s.else_body.accept(visitor) if s.finally_body: s.finally_body.accept(visitor) - + def visit_with_stmt(self, s: WithStmt) -> None: for e in s.expr: e.accept(self) @@ -1088,12 +1088,12 @@ def visit_with_stmt(self, s: WithStmt) -> None: if n: self.analyse_lvalue(n) self.visit_block(s.body) - + def visit_del_stmt(self, s: DelStmt) -> None: s.expr.accept(self) if not isinstance(s.expr, (IndexExpr, NameExpr, MemberExpr)): self.fail('Invalid delete target', s) - + def visit_global_decl(self, g: GlobalDecl) -> None: for n in g.names: self.global_decls[-1].add(n) @@ -1101,11 +1101,11 @@ def visit_global_decl(self, g: GlobalDecl) -> None: def visit_print_stmt(self, s: PrintStmt) -> None: for arg in s.args: arg.accept(self) - + # # Expressions # - + def visit_name_expr(self, expr: NameExpr) -> None: n = self.lookup(expr.name, expr) if n: @@ -1116,30 +1116,30 @@ def visit_name_expr(self, expr: NameExpr) -> None: expr.kind = n.kind expr.node = (cast(Node, n.node)) expr.fullname = n.fullname - + def visit_super_expr(self, expr: SuperExpr) -> None: if not self.type: self.fail('"super" used outside class', expr) - return + return expr.info = self.type - + def visit_tuple_expr(self, expr: TupleExpr) -> None: for item in expr.items: item.accept(self) - + def visit_list_expr(self, expr: ListExpr) -> None: for item in expr.items: item.accept(self) - + def visit_set_expr(self, expr: SetExpr) -> None: for item in expr.items: item.accept(self) - + def visit_dict_expr(self, expr: DictExpr) -> None: for key, value in expr.items: key.accept(self) value.accept(self) - + def visit_paren_expr(self, expr: ParenExpr) -> None: expr.expr.accept(self) @@ -1174,7 +1174,7 @@ def visit_call_expr(self, expr: CallExpr) -> None: elif refers_to_fullname(expr.callee, 'typing.Any'): # Special form Any(...). if not self.check_fixed_args(expr, 1, 'Any'): - return + return expr.analyzed = CastExpr(expr.args[0], AnyType()) expr.analyzed.line = expr.line expr.analyzed.accept(self) @@ -1237,7 +1237,7 @@ def check_fixed_args(self, expr: CallExpr, numargs: int, (name, numargs, s), expr) return False return True - + def visit_member_expr(self, expr: MemberExpr) -> None: base = expr.expr base.accept(self) @@ -1253,14 +1253,18 @@ def visit_member_expr(self, expr: MemberExpr) -> None: expr.kind = n.kind expr.fullname = n.fullname expr.node = n.node - + def visit_op_expr(self, expr: OpExpr) -> None: expr.left.accept(self) expr.right.accept(self) - + + def visit_comparison_expr(self, expr: ComparisonExpr) -> None: + for operand in expr.operands: + operand.accept(self) + def visit_unary_expr(self, expr: UnaryExpr) -> None: expr.expr.accept(self) - + def visit_index_expr(self, expr: IndexExpr) -> None: expr.base.accept(self) if refers_to_class_or_function(expr.base): @@ -1291,14 +1295,14 @@ def visit_slice_expr(self, expr: SliceExpr) -> None: expr.end_index.accept(self) if expr.stride: expr.stride.accept(self) - + def visit_cast_expr(self, expr: CastExpr) -> None: expr.expr.accept(self) expr.type = self.anal_type(expr.type) def visit_undefined_expr(self, expr: UndefinedExpr) -> None: expr.type = self.anal_type(expr.type) - + def visit_type_application(self, expr: TypeApplication) -> None: expr.expr.accept(self) for i in range(len(expr.types)): @@ -1337,11 +1341,11 @@ def visit_ducktype_expr(self, expr: DucktypeExpr) -> None: def visit_disjointclass_expr(self, expr: DisjointclassExpr) -> None: expr.cls.accept(self) - + # # Helpers # - + def lookup(self, name: str, ctx: Context) -> SymbolTableNode: """Look up an unqualified name in all active namespaces.""" # 1. Name declared using 'global x' takes precedence @@ -1366,6 +1370,9 @@ def lookup(self, name: str, ctx: Context) -> SymbolTableNode: if b: table = cast(MypyFile, b.node).names if name in table: + if name[0] == "_" and name[1] != "_": + self.name_not_defined(name, ctx) + return None node = table[name] # Only succeed if we are not using a type alias such List -- these must be # be accessed via the typing module. @@ -1374,13 +1381,13 @@ def lookup(self, name: str, ctx: Context) -> SymbolTableNode: # Give up. self.name_not_defined(name, ctx) return None - + def lookup_qualified(self, name: str, ctx: Context) -> SymbolTableNode: if '.' not in name: return self.lookup(name, ctx) else: parts = name.split('.') - n = self.lookup(parts[0], ctx) # type: SymbolTableNode + n = self.lookup(parts[0], ctx) # type: SymbolTableNode if n: for i in range(1, len(parts)): if isinstance(n.node, TypeInfo): @@ -1392,14 +1399,14 @@ def lookup_qualified(self, name: str, ctx: Context) -> SymbolTableNode: if n: n = self.normalize_type_alias(n, ctx) return n - + def qualified_name(self, n: str) -> str: return self.cur_mod_id + '.' + n - + def enter(self) -> None: self.locals.append(SymbolTable()) self.global_decls.append(set()) - + def leave(self) -> None: self.locals.pop() self.global_decls.pop() @@ -1428,14 +1435,14 @@ def add_symbol(self, name: str, node: SymbolTableNode, # of multiple submodules of a package (e.g. a.x and a.y). self.name_already_defined(name, context) self.globals[name] = node - + def add_var(self, v: Var, ctx: Context) -> None: if self.is_func_scope(): self.add_local(v, ctx) else: self.globals[v.name()] = SymbolTableNode(GDEF, v, self.cur_mod_id) v._fullname = self.qualified_name(v.name()) - + def add_local(self, v: Var, ctx: Context) -> None: if v.name() in self.locals[-1]: self.name_already_defined(v.name(), ctx) @@ -1447,7 +1454,7 @@ def add_local_func(self, defn: FuncBase, ctx: Context) -> None: if defn.name() in self.locals[-1]: self.name_already_defined(defn.name(), ctx) self.locals[-1][defn.name()] = SymbolTableNode(LDEF, defn) - + def check_no_global(self, n: str, ctx: Context, is_func: bool = False) -> None: if n in self.globals: @@ -1456,20 +1463,20 @@ def check_no_global(self, n: str, ctx: Context, "must be next to each other)").format(n), ctx) else: self.name_already_defined(n, ctx) - + def name_not_defined(self, name: str, ctx: Context) -> None: self.fail("Name '{}' is not defined".format(name), ctx) - + def name_already_defined(self, name: str, ctx: Context) -> None: self.fail("Name '{}' already defined".format(name), ctx) - + def fail(self, msg: str, ctx: Context) -> None: self.errors.report(ctx.get_line(), msg) class FirstPass(NodeVisitor): """First phase of semantic analysis""" - + def __init__(self, sem: SemanticAnalyzer) -> None: self.sem = sem self.pyversion = sem.pyversion @@ -1491,15 +1498,15 @@ def analyze(self, file: MypyFile, fnam: str, mod_id: str) -> None: sem.block_depth = [0] defs = file.defs - + # Add implicit definitions of module '__name__' etc. for n in implicit_module_attrs: name_def = VarDef([Var(n, AnyType())], True) defs.insert(0, name_def) - + for d in defs: d.accept(self) - + # Add implicit definition of 'None' to builtins, as we cannot define a # variable with a None type explicitly. if mod_id == 'builtins': @@ -1514,12 +1521,12 @@ def visit_block(self, b: Block) -> None: for node in b.body: node.accept(self) self.sem.block_depth[-1] -= 1 - + def visit_assignment_stmt(self, s: AssignmentStmt) -> None: for lval in s.lvalues: self.sem.analyse_lvalue(lval, add_global=True, explicit_type=s.type is not None) - + def visit_func_def(self, d: FuncDef) -> None: sem = self.sem d.is_conditional = sem.block_depth[-1] > 0 @@ -1532,13 +1539,13 @@ def visit_func_def(self, d: FuncDef) -> None: sem.check_no_global(d.name(), d, True) d._fullname = sem.qualified_name(d.name()) sem.globals[d.name()] = SymbolTableNode(GDEF, d, sem.cur_mod_id) - + def visit_overloaded_func_def(self, d: OverloadedFuncDef) -> None: self.sem.check_no_global(d.name(), d) d._fullname = self.sem.qualified_name(d.name()) self.sem.globals[d.name()] = SymbolTableNode(GDEF, d, self.sem.cur_mod_id) - + def visit_class_def(self, d: ClassDef) -> None: self.sem.check_no_global(d.name, d) d.fullname = self.sem.qualified_name(d.name) @@ -1547,7 +1554,7 @@ def visit_class_def(self, d: ClassDef) -> None: d.info = info self.sem.globals[d.name] = SymbolTableNode(GDEF, info, self.sem.cur_mod_id) - + def visit_var_def(self, d: VarDef) -> None: for v in d.items: self.sem.check_no_global(v.name(), d) @@ -1585,14 +1592,14 @@ class ThirdPass(TraverserVisitor[None]): Check type argument counts and values of generic types. Also update TypeInfo disjointclass information. """ - + def __init__(self, errors: Errors) -> None: self.errors = errors - + def visit_file(self, file_node: MypyFile, fnam: str) -> None: self.errors.set_file(fnam) file_node.accept(self) - + def visit_func_def(self, fdef: FuncDef) -> None: self.errors.push_function(fdef.name()) self.analyze(fdef.type) @@ -1633,7 +1640,7 @@ def analyze(self, type: Type) -> None: if type: analyzer = TypeAnalyserPass3(self.fail) type.accept(analyzer) - + def fail(self, msg: str, ctx: Context) -> None: self.errors.report(ctx.get_line(), msg) @@ -1657,6 +1664,7 @@ def replace_implicit_first_type(sig: Callable, new: Type) -> Callable: else: return sig + @overload def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike: osig = cast(Overloaded, sig) @@ -1838,7 +1846,7 @@ def mark_block_unreachable(block: Block) -> None: class MarkImportsUnreachableVisitor(TraverserVisitor): """Visitor that flags all imports nested within a node as unreachable.""" - + def visit_import(self, node: Import) -> None: node.is_unreachable = True diff --git a/mypy/stats.py b/mypy/stats.py index 282922a0c4ed..9fcbd2ae9a21 100644 --- a/mypy/stats.py +++ b/mypy/stats.py @@ -5,7 +5,7 @@ import re from typing import Any, Dict, List, cast, Tuple - + from mypy.traverser import TraverserVisitor from mypy.types import ( Type, AnyType, Instance, FunctionLike, TupleType, Void, TypeVar, @@ -14,7 +14,7 @@ from mypy import nodes from mypy.nodes import ( Node, FuncDef, TypeApplication, AssignmentStmt, NameExpr, CallExpr, - MemberExpr, OpExpr, IndexExpr, UnaryExpr, YieldFromExpr + MemberExpr, OpExpr, ComparisonExpr, IndexExpr, UnaryExpr, YieldFromExpr ) @@ -29,7 +29,7 @@ def __init__(self, inferred: bool, typemap: Dict[Node, Type] = None, self.inferred = inferred self.typemap = typemap self.all_nodes = all_nodes - + self.num_precise = 0 self.num_imprecise = 0 self.num_any = 0 @@ -46,9 +46,9 @@ def __init__(self, inferred: bool, typemap: Dict[Node, Type] = None, self.line_map = Dict[int, int]() self.output = List[str]() - + TraverserVisitor.__init__(self) - + def visit_func_def(self, o: FuncDef) -> None: self.line = o.line if len(o.expanded) > 1: @@ -59,7 +59,7 @@ def visit_func_def(self, o: FuncDef) -> None: sig = cast(Callable, o.type) arg_types = sig.arg_types if (sig.arg_names and sig.arg_names[0] == 'self' and - not self.inferred): + not self.inferred): arg_types = arg_types[1:] for arg in arg_types: self.type(arg) @@ -130,6 +130,10 @@ def visit_op_expr(self, o: OpExpr) -> None: self.process_node(o) super().visit_op_expr(o) + def visit_comparison_expr(self, o: ComparisonExpr) -> None: + self.process_node(o) + super().visit_comparison_expr(o) + def visit_index_expr(self, o: IndexExpr) -> None: self.process_node(o) super().visit_index_expr(o) @@ -252,7 +256,7 @@ def is_complex(t: Type) -> bool: TypeVar)) -html_files = [] # type: List[Tuple[str, str, int, int]] +html_files = [] # type: List[Tuple[str, str, int, int]] def generate_html_report(tree: Node, path: str, type_map: Dict[Node, Type], @@ -264,7 +268,7 @@ def generate_html_report(tree: Node, path: str, type_map: Dict[Node, Type], target_path = os.path.join(output_dir, 'html', path) target_path = re.sub(r'\.py$', '.html', target_path) ensure_dir_exists(os.path.dirname(target_path)) - output = [] # type: List[str] + output = [] # type: List[str] append = output.append append('''\ @@ -305,7 +309,7 @@ def generate_html_report(tree: Node, path: str, type_map: Dict[Node, Type], def generate_html_index(output_dir: str) -> None: path = os.path.join(output_dir, 'index.html') - output = [] # type: List[str] + output = [] # type: List[str] append = output.append append('''\ @@ -328,7 +332,7 @@ def generate_html_index(output_dir: str) -> None: source_path = os.path.normpath(source_path) # TODO: Windows paths. if (source_path.startswith('stubs/') or - '/stubs/' in source_path): + '/stubs/' in source_path): continue percent = 100.0 * num_imprecise / num_lines style = '' diff --git a/mypy/strconv.py b/mypy/strconv.py index 9536511e8660..937d3b57c5da 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -12,10 +12,10 @@ class StrConv(NodeVisitor[str]): """Visitor for converting a Node to a human-readable string. - + For example, an MypyFile node from program '1' is converted into something like this: - + MypyFile:1( fnam ExpressionStmt:1( @@ -29,7 +29,7 @@ def dump(self, nodes, obj): argument. """ return dump_tagged(nodes, short_type(obj) + ':' + str(obj.line)) - + def func_helper(self, o): """Return a list in a format suitable for dump() that represents the arguments and the body of a function. The caller can then decorate the @@ -60,10 +60,10 @@ def func_helper(self, o): a.append('Generator') a.extend(extra) a.append(o.body) - return a - + return a + # Top-level structures - + def visit_mypy_file(self, o): # Skip implicit definitions. defs = o.defs @@ -81,24 +81,24 @@ def visit_mypy_file(self, o): # case# output in all platforms. a.insert(0, o.path.replace(os.sep, '/')) return self.dump(a, o) - + def visit_import(self, o): a = [] for id, as_id in o.ids: a.append('{} : {}'.format(id, as_id)) return 'Import:{}({})'.format(o.line, ', '.join(a)) - + def visit_import_from(self, o): a = [] for name, as_name in o.names: a.append('{} : {}'.format(name, as_name)) return 'ImportFrom:{}({}, [{}])'.format(o.line, o.id, ', '.join(a)) - + def visit_import_all(self, o): return 'ImportAll:{}({})'.format(o.line, o.id) - + # Definitions - + def visit_func_def(self, o): a = self.func_helper(o) a.insert(0, o.name()) @@ -113,13 +113,13 @@ def visit_func_def(self, o): if o.is_property: a.insert(-1, 'Property') return self.dump(a, o) - + def visit_overloaded_func_def(self, o): a = o.items[:] if o.type: a.insert(0, o.type) return self.dump(a, o) - + def visit_class_def(self, o): a = [o.name, o.defs.body] # Display base types unless they are implicitly just builtins.object @@ -141,7 +141,7 @@ def visit_class_def(self, o): a.insert(1, ('Disjointclasses', [info.fullname() for info in o.info.disjoint_classes])) return self.dump(a, o) - + def visit_var_def(self, o): a = [] for n in o.items: @@ -150,7 +150,7 @@ def visit_var_def(self, o): if o.init: a.append(o.init) return self.dump(a, o) - + def visit_var(self, o): l = '' # Add :nil line number tag if no line number is specified to remain @@ -158,24 +158,24 @@ def visit_var(self, o): if o.line < 0: l = ':nil' return 'Var' + l + '(' + o.name() + ')' - + def visit_global_decl(self, o): return self.dump([o.names], o) - + def visit_decorator(self, o): return self.dump([o.var, o.decorators, o.func], o) - + def visit_annotation(self, o): return 'Type:{}({})'.format(o.line, o.type) - + # Statements - + def visit_block(self, o): return self.dump(o.body, o) - + def visit_expression_stmt(self, o): return self.dump([o.expr], o) - + def visit_assignment_stmt(self, o): if len(o.lvalues) > 1: a = [('Lvalues', o.lvalues)] @@ -185,16 +185,16 @@ def visit_assignment_stmt(self, o): if o.type: a.append(o.type) return self.dump(a, o) - + def visit_operator_assignment_stmt(self, o): return self.dump([o.op, o.lvalue, o.rvalue], o) - + def visit_while_stmt(self, o): a = [o.expr, o.body] if o.else_body: a.append(('Else', o.else_body.body)) return self.dump(a, o) - + def visit_for_stmt(self, o): a = [o.index] if o.types != [None] * len(o.types): @@ -203,36 +203,36 @@ def visit_for_stmt(self, o): if o.else_body: a.append(('Else', o.else_body.body)) return self.dump(a, o) - + def visit_return_stmt(self, o): return self.dump([o.expr], o) - + def visit_if_stmt(self, o): a = [] for i in range(len(o.expr)): a.append(('If', [o.expr[i]])) a.append(('Then', o.body[i].body)) - + if not o.else_body: return self.dump(a, o) else: return self.dump([a, ('Else', o.else_body.body)], o) - + def visit_break_stmt(self, o): return self.dump([], o) - + def visit_continue_stmt(self, o): return self.dump([], o) - + def visit_pass_stmt(self, o): return self.dump([], o) - + def visit_raise_stmt(self, o): return self.dump([o.expr, o.from_expr], o) - + def visit_assert_stmt(self, o): return self.dump([o.expr], o) - + def visit_yield_stmt(self, o): return self.dump([o.expr], o) @@ -244,20 +244,20 @@ def visit_del_stmt(self, o): def visit_try_stmt(self, o): a = [o.body] - + for i in range(len(o.vars)): a.append(o.types[i]) if o.vars[i]: a.append(o.vars[i]) a.append(o.handlers[i]) - + if o.else_body: a.append(('Else', o.else_body.body)) if o.finally_body: a.append(('Finally', o.finally_body.body)) - + return self.dump(a, o) - + def visit_with_stmt(self, o): a = [] for i in range(len(o.expr)): @@ -271,39 +271,39 @@ def visit_print_stmt(self, o): if o.newline: a.append('Newline') return self.dump(a, o) - + # Expressions - + # Simple expressions - + def visit_int_expr(self, o): return 'IntExpr({})'.format(o.value) - + def visit_str_expr(self, o): return 'StrExpr({})'.format(self.str_repr(o.value)) - + def visit_bytes_expr(self, o): return 'BytesExpr({})'.format(self.str_repr(o.value)) - + def visit_unicode_expr(self, o): return 'UnicodeExpr({})'.format(self.str_repr(o.value)) - + def str_repr(self, s): s = re.sub(r'\\u[0-9a-fA-F]{4}', lambda m: '\\' + m.group(0), s) return re.sub('[^\\x20-\\x7e]', lambda m: r'\u%.4x' % ord(m.group(0)), s) - + def visit_float_expr(self, o): return 'FloatExpr({})'.format(o.value) - + def visit_paren_expr(self, o): return self.dump([o.expr], o) - + def visit_name_expr(self, o): return (short_type(o) + '(' + self.pretty_name(o.name, o.kind, o.fullname, o.is_def) + ')') - + def pretty_name(self, name, kind, fullname, is_def): n = name if is_def: @@ -319,7 +319,7 @@ def pretty_name(self, name, kind, fullname, is_def): # Add tag to signify a member reference. n += ' [m]' return n - + def visit_member_expr(self, o): return self.dump([o.expr, self.pretty_name(o.name, o.kind, o.fullname, o.is_def)], o) @@ -348,39 +348,42 @@ def visit_call_expr(self, o): raise RuntimeError('unknown kind %d' % kind) return self.dump([o.callee, ('Args', args)] + extra, o) - + def visit_op_expr(self, o): return self.dump([o.op, o.left, o.right], o) - + + def visit_comparison_expr(self, o): + return self.dump([o.operators, o.operands], o) + def visit_cast_expr(self, o): return self.dump([o.expr, o.type], o) - + def visit_unary_expr(self, o): return self.dump([o.op, o.expr], o) - + def visit_list_expr(self, o): return self.dump(o.items, o) - + def visit_dict_expr(self, o): return self.dump([[k, v] for k, v in o.items], o) - + def visit_set_expr(self, o): return self.dump(o.items, o) - + def visit_tuple_expr(self, o): return self.dump(o.items, o) - + def visit_index_expr(self, o): if o.analyzed: return o.analyzed.accept(self) return self.dump([o.base, o.index], o) - + def visit_super_expr(self, o): return self.dump([o.name], o) def visit_undefined_expr(self, o): return 'UndefinedExpr:{}({})'.format(o.line, o.type) - + def visit_type_application(self, o): return self.dump([o.expr, ('Types', o.types)], o) @@ -395,22 +398,22 @@ def visit_ducktype_expr(self, o): def visit_disjointclass_expr(self, o): return 'DisjointclassExpr:{}({})'.format(o.line, o.cls.fullname) - + def visit_func_expr(self, o): a = self.func_helper(o) return self.dump(a, o) - + def visit_generator_expr(self, o): # FIX types condlists = o.condlists if any(o.condlists) else None return self.dump([o.left_expr, o.indices, o.sequences, condlists], o) - + def visit_list_comprehension(self, o): return self.dump([o.generator], o) - + def visit_conditional_expr(self, o): return self.dump([('Condition', [o.cond]), o.if_expr, o.else_expr], o) - + def visit_slice_expr(self, o): a = [o.begin_index, o.end_index, o.stride] if not a[0]: @@ -418,14 +421,14 @@ def visit_slice_expr(self, o): if not a[1]: a[1] = '' return self.dump(a, o) - + def visit_coerce_expr(self, o): return self.dump([o.expr, ('Types', [o.target_type, o.source_type])], o) - + def visit_type_expr(self, o): return self.dump([str(o.type)], o) - + def visit_filter_node(self, o): # These are for convenience. These node types are not defined in the # parser module. diff --git a/mypy/transform.py b/mypy/transform.py index 388607a6f10b..9a4c9a0cd96b 100644 --- a/mypy/transform.py +++ b/mypy/transform.py @@ -16,8 +16,8 @@ from mypy.nodes import ( Node, MypyFile, TypeInfo, ClassDef, VarDef, FuncDef, Var, ReturnStmt, AssignmentStmt, IfStmt, WhileStmt, MemberExpr, NameExpr, MDEF, - CallExpr, SuperExpr, TypeExpr, CastExpr, OpExpr, CoerceExpr, GDEF, - SymbolTableNode, IndexExpr, function_type, YieldFromExpr + CallExpr, SuperExpr, TypeExpr, CastExpr, OpExpr, CoerceExpr, ComparisonExpr, + GDEF, SymbolTableNode, IndexExpr, function_type, YieldFromExpr ) from mypy.traverser import TraverserVisitor from mypy.types import Type, AnyType, Callable, TypeVarDef, Instance @@ -38,7 +38,7 @@ class DyncheckTransformVisitor(TraverserVisitor): all non-trivial coercions explicit. Also generate generic wrapper classes for coercions between generic types and wrapper methods for overrides and for more efficient access from dynamically typed code. - + This visitor modifies the parse tree in-place. """ @@ -46,23 +46,23 @@ class DyncheckTransformVisitor(TraverserVisitor): modules = Undefined(Dict[str, MypyFile]) is_pretty = False type_tf = Undefined(TypeTransformer) - + # Stack of function return types return_types = Undefined(List[Type]) # Stack of dynamically typed function flags dynamic_funcs = Undefined(List[bool]) - + # Associate a Node with its start end line numbers. line_map = Undefined(Dict[Node, Tuple[int, int]]) - + is_java = False - + # The current type context (or None if not within a type). - _type_context = None # type: TypeInfo - + _type_context = None # type: TypeInfo + def type_context(self) -> TypeInfo: return self._type_context - + def __init__(self, type_map: Dict[Node, Type], modules: Dict[str, MypyFile], is_pretty: bool, is_java: bool = False) -> None: @@ -74,14 +74,14 @@ def __init__(self, type_map: Dict[Node, Type], self.modules = modules self.is_pretty = is_pretty self.is_java = is_java - + # # Transform definitions # - + def visit_mypy_file(self, o: MypyFile) -> None: """Transform an file.""" - res = [] # type: List[Node] + res = [] # type: List[Node] for d in o.defs: if isinstance(d, ClassDef): self._type_context = d.info @@ -91,7 +91,7 @@ def visit_mypy_file(self, o: MypyFile) -> None: d.accept(self) res.append(d) o.defs = res - + def visit_var_def(self, o: VarDef) -> None: """Transform a variable definition in-place. @@ -99,7 +99,7 @@ def visit_var_def(self, o: VarDef) -> None: transformed in TypeTransformer. """ super().visit_var_def(o) - + if o.init is not None: if o.items[0].type: t = o.items[0].type @@ -107,7 +107,7 @@ def visit_var_def(self, o: VarDef) -> None: t = AnyType() o.init = self.coerce(o.init, t, self.get_type(o.init), self.type_context()) - + def visit_func_def(self, fdef: FuncDef) -> None: """Transform a global function definition in-place. @@ -116,7 +116,7 @@ def visit_func_def(self, fdef: FuncDef) -> None: """ self.prepend_generic_function_tvar_args(fdef) self.transform_function_body(fdef) - + def transform_function_body(self, fdef: FuncDef) -> None: """Transform the body of a function.""" self.dynamic_funcs.append(fdef.is_implicit) @@ -125,15 +125,15 @@ def transform_function_body(self, fdef: FuncDef) -> None: super().visit_func_def(fdef) self.return_types.pop() self.dynamic_funcs.pop() - + def prepend_generic_function_tvar_args(self, fdef: FuncDef) -> None: """Add implicit function type variable arguments if fdef is generic.""" sig = cast(Callable, function_type(fdef)) tvars = sig.variables if not fdef.type: fdef.type = sig - - tv = [] # type: List[Var] + + tv = [] # type: List[Var] ntvars = len(tvars) if fdef.is_method(): # For methods, add type variable arguments after the self arg. @@ -150,50 +150,50 @@ def prepend_generic_function_tvar_args(self, fdef: FuncDef) -> None: AnyType()) fdef.args = tv + fdef.args fdef.init = List[AssignmentStmt]([None]) * ntvars + fdef.init - + # # Transform statements - # - + # + def transform_block(self, block: List[Node]) -> None: for stmt in block: stmt.accept(self) - + def visit_return_stmt(self, s: ReturnStmt) -> None: super().visit_return_stmt(s) s.expr = self.coerce(s.expr, self.return_types[-1], self.get_type(s.expr), self.type_context()) - + def visit_assignment_stmt(self, s: AssignmentStmt) -> None: super().visit_assignment_stmt(s) if isinstance(s.lvalues[0], IndexExpr): index = cast(IndexExpr, s.lvalues[0]) method_type = index.method_type if self.dynamic_funcs[-1] or isinstance(method_type, AnyType): - lvalue_type = AnyType() # type: Type + lvalue_type = AnyType() # type: Type else: method_callable = cast(Callable, method_type) # TODO arg_types[1] may not be reliable lvalue_type = method_callable.arg_types[1] else: lvalue_type = self.get_type(s.lvalues[0]) - + s.rvalue = self.coerce2(s.rvalue, lvalue_type, self.get_type(s.rvalue), self.type_context()) - + # # Transform expressions # - + def visit_member_expr(self, e: MemberExpr) -> None: super().visit_member_expr(e) - + typ = self.get_type(e.expr) - + if self.dynamic_funcs[-1]: e.expr = self.coerce_to_dynamic(e.expr, typ, self.type_context()) typ = AnyType() - + if isinstance(typ, Instance): # Reference to a statically-typed method variant with the suffix # derived from the base object type. @@ -202,7 +202,7 @@ def visit_member_expr(self, e: MemberExpr) -> None: # Reference to a dynamically-typed method variant. suffix = self.dynamic_suffix() e.name += suffix - + def visit_name_expr(self, e: NameExpr) -> None: super().visit_name_expr(e) if e.kind == MDEF and isinstance(e.node, FuncDef): @@ -211,7 +211,7 @@ def visit_name_expr(self, e: NameExpr) -> None: e.name += suffix # Update representation to have the correct name. prefix = e.repr.components[0].pre - + def get_member_reference_suffix(self, name: str, info: TypeInfo) -> str: if info.has_method(name): fdef = cast(FuncDef, info.get_method(name)) @@ -228,9 +228,9 @@ def visit_call_expr(self, e: CallExpr) -> None: # This is not an ordinary call. e.analyzed.accept(self) return - + super().visit_call_expr(e) - + # Do no coercions if this is a call to debugging facilities. if self.is_debugging_call_expr(e): return @@ -241,13 +241,13 @@ def visit_call_expr(self, e: CallExpr) -> None: # Add coercions for the arguments. for i in range(len(e.args)): - arg_type = AnyType() # type: Type + arg_type = AnyType() # type: Type if isinstance(ctype, Callable): arg_type = ctype.arg_types[i] e.args[i] = self.coerce2(e.args[i], arg_type, self.get_type(e.args[i]), self.type_context()) - + # Prepend type argument values to the call as needed. if isinstance(ctype, Callable) and cast(Callable, ctype).bound_vars != []: @@ -262,8 +262,8 @@ def visit_call_expr(self, e: CallExpr) -> None: (cast(SuperExpr, e.callee)).name == '__init__')): # Filter instance type variables; only include function tvars. bound_vars = [(id, t) for id, t in bound_vars if id < 0] - - args = [] # type: List[Node] + + args = [] # type: List[Node] for i in range(len(bound_vars)): # Compile type variables to runtime type variable expressions. tv = translate_runtime_type_vars_in_context( @@ -272,16 +272,16 @@ def visit_call_expr(self, e: CallExpr) -> None: self.is_java) args.append(TypeExpr(tv)) e.args = args + e.args - + def is_debugging_call_expr(self, e): return isinstance(e.callee, NameExpr) and e.callee.name in ['__print'] - + def visit_cast_expr(self, e: CastExpr) -> None: super().visit_cast_expr(e) if isinstance(self.get_type(e), AnyType): e.expr = self.coerce(e.expr, AnyType(), self.get_type(e.expr), self.type_context()) - + def visit_op_expr(self, e: OpExpr) -> None: super().visit_op_expr(e) if e.op in ['and', 'or']: @@ -301,17 +301,16 @@ def visit_op_expr(self, e: OpExpr) -> None: elif method_type: method_callable = cast(Callable, method_type) operand = e.right - # For 'in', the order of operands is reversed. - if e.op == 'in': - operand = e.left # TODO arg_types[0] may not be reliable operand = self.coerce(operand, method_callable.arg_types[0], self.get_type(operand), - self.type_context()) - if e.op == 'in': - e.left = operand - else: - e.right = operand + self.type_context()) + e.right = operand + + def visit_comparison_expr(self, e: ComparisonExpr) -> None: + super().visit_comparison_expr(e) + # Dummy + def visit_index_expr(self, e: IndexExpr) -> None: if e.analyzed: @@ -329,18 +328,18 @@ def visit_index_expr(self, e: IndexExpr) -> None: method_callable = cast(Callable, method_type) e.index = self.coerce(e.index, method_callable.arg_types[0], self.get_type(e.index), self.type_context()) - + # # Helpers - # - + # + def get_type(self, node: Node) -> Type: """Return the type of a node as reported by the type checker.""" return self.type_map[node] - + def set_type(self, node: Node, typ: Type) -> None: self.type_map[node] = typ - + def type_suffix(self, fdef: FuncDef, info: TypeInfo = None) -> str: """Return the suffix for a mangled name. @@ -359,20 +358,20 @@ def type_suffix(self, fdef: FuncDef, info: TypeInfo = None) -> str: return '`' + info.name() else: return '__' + info.name() - + def dynamic_suffix(self) -> str: """Return the suffix of the dynamic wrapper of a method or class.""" return dynamic_suffix(self.is_pretty) - + def wrapper_class_suffix(self) -> str: """Return the suffix of a generic wrapper class.""" return '**' - + def coerce(self, expr: Node, target_type: Type, source_type: Type, context: TypeInfo, is_wrapper_class: bool = False) -> Node: return coerce(expr, target_type, source_type, context, is_wrapper_class, self.is_java) - + def coerce2(self, expr: Node, target_type: Type, source_type: Type, context: TypeInfo, is_wrapper_class: bool = False) -> Node: """Create coercion from source_type to target_type. @@ -388,7 +387,7 @@ def coerce2(self, expr: Node, target_type: Type, source_type: Type, else: return self.coerce(expr, target_type, source_type, context, is_wrapper_class) - + def coerce_to_dynamic(self, expr: Node, source_type: Type, context: TypeInfo) -> Node: if isinstance(source_type, AnyType): @@ -396,7 +395,7 @@ def coerce_to_dynamic(self, expr: Node, source_type: Type, source_type = translate_runtime_type_vars_in_context( source_type, context, self.is_java) return CoerceExpr(expr, AnyType(), source_type, False) - + def add_line_mapping(self, orig_node: Node, new_node: Node) -> None: """Add a line mapping for a wrapper. @@ -405,15 +404,15 @@ def add_line_mapping(self, orig_node: Node, new_node: Node) -> None: """ if orig_node.repr: start_line = orig_node.line - end_line = start_line # TODO use real end line + end_line = start_line # TODO use real end line self.line_map[new_node] = (start_line, end_line) - + def named_type(self, name: str) -> Instance: # TODO combine with checker # Assume that the name refers to a type. sym = self.lookup(name, GDEF) return Instance(cast(TypeInfo, sym.node), []) - + def lookup(self, fullname: str, kind: int) -> SymbolTableNode: # TODO combine with checker # TODO remove kind argument @@ -422,7 +421,7 @@ def lookup(self, fullname: str, kind: int) -> SymbolTableNode: for i in range(1, len(parts) - 1): n = cast(MypyFile, ((n.names.get(parts[i], None).node))) return n.names[parts[-1]] - + def object_member_name(self) -> str: if self.is_java: return '__o_{}'.format(self.type_context().name()) diff --git a/mypy/traverser.py b/mypy/traverser.py index 4020f95e8513..c5ee2a690412 100644 --- a/mypy/traverser.py +++ b/mypy/traverser.py @@ -10,7 +10,7 @@ TryStmt, WithStmt, ParenExpr, MemberExpr, OpExpr, SliceExpr, CastExpr, UnaryExpr, ListExpr, TupleExpr, DictExpr, SetExpr, IndexExpr, GeneratorExpr, ListComprehension, ConditionalExpr, TypeApplication, - FuncExpr, OverloadedFuncDef, YieldFromStmt, YieldFromExpr + FuncExpr, ComparisonExpr, OverloadedFuncDef, YieldFromStmt, YieldFromExpr ) @@ -27,7 +27,7 @@ class TraverserVisitor(NodeVisitor[T], Generic[T]): """ # Visit methods - + def visit_mypy_file(self, o: MypyFile) -> T: for d in o.defs: d.accept(self) @@ -35,7 +35,7 @@ def visit_mypy_file(self, o: MypyFile) -> T: def visit_block(self, block: Block) -> T: for s in block.body: s.accept(self) - + def visit_func(self, o: FuncItem) -> T: for i in o.init: if i is not None: @@ -43,47 +43,47 @@ def visit_func(self, o: FuncItem) -> T: for v in o.args: self.visit_var(v) o.body.accept(self) - + def visit_func_def(self, o: FuncDef) -> T: self.visit_func(o) def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> T: for item in o.items: item.accept(self) - + def visit_class_def(self, o: ClassDef) -> T: o.defs.accept(self) - + def visit_decorator(self, o: Decorator) -> T: o.func.accept(self) o.var.accept(self) for decorator in o.decorators: decorator.accept(self) - + def visit_var_def(self, o: VarDef) -> T: if o.init is not None: o.init.accept(self) for v in o.items: self.visit_var(v) - + def visit_expression_stmt(self, o: ExpressionStmt) -> T: o.expr.accept(self) - + def visit_assignment_stmt(self, o: AssignmentStmt) -> T: o.rvalue.accept(self) for l in o.lvalues: l.accept(self) - + def visit_operator_assignment_stmt(self, o: OperatorAssignmentStmt) -> T: o.rvalue.accept(self) o.lvalue.accept(self) - + def visit_while_stmt(self, o: WhileStmt) -> T: o.expr.accept(self) o.body.accept(self) if o.else_body: o.else_body.accept(self) - + def visit_for_stmt(self, o: ForStmt) -> T: for ind in o.index: ind.accept(self) @@ -91,15 +91,15 @@ def visit_for_stmt(self, o: ForStmt) -> T: o.body.accept(self) if o.else_body: o.else_body.accept(self) - + def visit_return_stmt(self, o: ReturnStmt) -> T: if o.expr is not None: o.expr.accept(self) - + def visit_assert_stmt(self, o: AssertStmt) -> T: if o.expr is not None: o.expr.accept(self) - + def visit_yield_stmt(self, o: YieldStmt) -> T: if o.expr is not None: o.expr.accept(self) @@ -119,13 +119,13 @@ def visit_if_stmt(self, o: IfStmt) -> T: b.accept(self) if o.else_body: o.else_body.accept(self) - + def visit_raise_stmt(self, o: RaiseStmt) -> T: if o.expr is not None: o.expr.accept(self) if o.from_expr is not None: o.from_expr.accept(self) - + def visit_try_stmt(self, o: TryStmt) -> T: o.body.accept(self) for i in range(len(o.types)): @@ -136,17 +136,17 @@ def visit_try_stmt(self, o: TryStmt) -> T: o.else_body.accept(self) if o.finally_body is not None: o.finally_body.accept(self) - + def visit_with_stmt(self, o: WithStmt) -> T: for i in range(len(o.expr)): o.expr[i].accept(self) if o.name[i] is not None: o.name[i].accept(self) o.body.accept(self) - + def visit_paren_expr(self, o: ParenExpr) -> T: o.expr.accept(self) - + def visit_member_expr(self, o: MemberExpr) -> T: o.expr.accept(self) @@ -159,11 +159,15 @@ def visit_call_expr(self, o: CallExpr) -> T: o.callee.accept(self) if o.analyzed: o.analyzed.accept(self) - + def visit_op_expr(self, o: OpExpr) -> T: o.left.accept(self) o.right.accept(self) - + + def visit_comparison_expr(self, o: ComparisonExpr) -> T: + for operand in o.operands: + operand.accept(self) + def visit_slice_expr(self, o: SliceExpr) -> T: if o.begin_index is not None: o.begin_index.accept(self) @@ -171,36 +175,36 @@ def visit_slice_expr(self, o: SliceExpr) -> T: o.end_index.accept(self) if o.stride is not None: o.stride.accept(self) - + def visit_cast_expr(self, o: CastExpr) -> T: o.expr.accept(self) - + def visit_unary_expr(self, o: UnaryExpr) -> T: o.expr.accept(self) - + def visit_list_expr(self, o: ListExpr) -> T: for item in o.items: item.accept(self) - + def visit_tuple_expr(self, o: TupleExpr) -> T: for item in o.items: item.accept(self) - + def visit_dict_expr(self, o: DictExpr) -> T: for k, v in o.items: k.accept(self) v.accept(self) - + def visit_set_expr(self, o: SetExpr) -> T: for item in o.items: item.accept(self) - + def visit_index_expr(self, o: IndexExpr) -> T: o.base.accept(self) o.index.accept(self) if o.analyzed: o.analyzed.accept(self) - + def visit_generator_expr(self, o: GeneratorExpr) -> T: for index, sequence, conditions in zip(o.indices, o.sequences, o.condlists): @@ -210,17 +214,17 @@ def visit_generator_expr(self, o: GeneratorExpr) -> T: for cond in conditions: cond.accept(self) o.left_expr.accept(self) - + def visit_list_comprehension(self, o: ListComprehension) -> T: o.generator.accept(self) - + def visit_conditional_expr(self, o: ConditionalExpr) -> T: o.cond.accept(self) o.if_expr.accept(self) o.else_expr.accept(self) - + def visit_type_application(self, o: TypeApplication) -> T: o.expr.accept(self) - + def visit_func_expr(self, o: FuncExpr) -> T: self.visit_func(o) diff --git a/mypy/treetransform.py b/mypy/treetransform.py index 97dccd8ef08d..43bb24e1c920 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -16,8 +16,8 @@ UnicodeExpr, FloatExpr, CallExpr, SuperExpr, MemberExpr, IndexExpr, SliceExpr, OpExpr, UnaryExpr, FuncExpr, TypeApplication, PrintStmt, SymbolTable, RefExpr, UndefinedExpr, TypeVarExpr, DucktypeExpr, - DisjointclassExpr, CoerceExpr, TypeExpr, JavaCast, TempNode, YieldFromStmt, - YieldFromExpr + DisjointclassExpr, CoerceExpr, TypeExpr, ComparisonExpr, + JavaCast, TempNode, YieldFromStmt, YieldFromExpr ) from mypy.types import Type from mypy.visitor import NodeVisitor @@ -46,7 +46,7 @@ def __init__(self) -> None: # There may be multiple references to a Var node. Keep track of # Var translations using a dictionary. self.var_map = Dict[Var, Var]() - + def visit_mypy_file(self, node: MypyFile) -> Node: # NOTE: The 'names' and 'imports' instance variables will be empty! new = MypyFile(self.nodes(node.defs), [], node.is_bom) @@ -55,16 +55,16 @@ def visit_mypy_file(self, node: MypyFile) -> Node: new.path = node.path new.names = SymbolTable() return new - + def visit_import(self, node: Import) -> Node: return Import(node.ids[:]) - + def visit_import_from(self, node: ImportFrom) -> Node: return ImportFrom(node.id, node.names[:]) - + def visit_import_all(self, node: ImportAll) -> Node: return ImportAll(node.id) - + def visit_func_def(self, node: FuncDef) -> FuncDef: # Note that a FuncDef must be transformed to a FuncDef. new = FuncDef(node.name(), @@ -75,7 +75,7 @@ def visit_func_def(self, node: FuncDef) -> FuncDef: self.optional_type(node.type)) self.copy_function_attributes(new, node) - + new._fullname = node._fullname new.is_decorated = node.is_decorated new.is_conditional = node.is_conditional @@ -85,7 +85,7 @@ def visit_func_def(self, node: FuncDef) -> FuncDef: new.is_property = node.is_property new.original_def = node.original_def return new - + def visit_func_expr(self, node: FuncExpr) -> Node: new = FuncExpr([self.visit_var(var) for var in node.args], node.arg_kinds[:], @@ -114,7 +114,7 @@ def duplicate_inits(self, else: result.append(None) return result - + def visit_overloaded_func_def(self, node: OverloadedFuncDef) -> Node: items = [self.visit_decorator(decorator) for decorator in node.items] @@ -125,33 +125,33 @@ def visit_overloaded_func_def(self, node: OverloadedFuncDef) -> Node: new.type = self.type(node.type) new.info = node.info return new - + def visit_class_def(self, node: ClassDef) -> Node: new = ClassDef(node.name, - self.block(node.defs), - node.type_vars, - self.types(node.base_types), - node.metaclass) + self.block(node.defs), + node.type_vars, + self.types(node.base_types), + node.metaclass) new.fullname = node.fullname new.info = node.info new.decorators = [decorator.accept(self) for decorator in node.decorators] new.is_builtinclass = node.is_builtinclass return new - + def visit_var_def(self, node: VarDef) -> Node: new = VarDef([self.visit_var(var) for var in node.items], node.is_top_level, self.optional_node(node.init)) new.kind = node.kind return new - + def visit_global_decl(self, node: GlobalDecl) -> Node: return GlobalDecl(node.names[:]) - + def visit_block(self, node: Block) -> Block: return Block(self.nodes(node.body)) - + def visit_decorator(self, node: Decorator) -> Decorator: # Note that a Decorator must be transformed to a Decorator. func = self.visit_func_def(node.func) @@ -160,7 +160,7 @@ def visit_decorator(self, node: Decorator) -> Decorator: self.visit_var(node.var)) new.is_overload = node.is_overload return new - + def visit_var(self, node: Var) -> Var: # Note that a Var must be transformed to a Var. if node in self.var_map: @@ -178,44 +178,44 @@ def visit_var(self, node: Var) -> Var: new.set_line(node.line) self.var_map[node] = new return new - + def visit_expression_stmt(self, node: ExpressionStmt) -> Node: return ExpressionStmt(self.node(node.expr)) - + def visit_assignment_stmt(self, node: AssignmentStmt) -> Node: return self.duplicate_assignment(node) - + def duplicate_assignment(self, node: AssignmentStmt) -> AssignmentStmt: new = AssignmentStmt(self.nodes(node.lvalues), self.node(node.rvalue), self.optional_type(node.type)) new.line = node.line return new - + def visit_operator_assignment_stmt(self, node: OperatorAssignmentStmt) -> Node: return OperatorAssignmentStmt(node.op, self.node(node.lvalue), self.node(node.rvalue)) - + def visit_while_stmt(self, node: WhileStmt) -> Node: return WhileStmt(self.node(node.expr), self.block(node.body), self.optional_block(node.else_body)) - + def visit_for_stmt(self, node: ForStmt) -> Node: return ForStmt(self.names(node.index), self.node(node.expr), self.block(node.body), self.optional_block(node.else_body), self.optional_types(node.types)) - + def visit_return_stmt(self, node: ReturnStmt) -> Node: return ReturnStmt(self.optional_node(node.expr)) - + def visit_assert_stmt(self, node: AssertStmt) -> Node: return AssertStmt(self.node(node.expr)) - + def visit_yield_stmt(self, node: YieldStmt) -> Node: return YieldStmt(self.node(node.expr)) @@ -224,25 +224,25 @@ def visit_yield_from_stmt(self, node: YieldFromStmt) -> Node: def visit_del_stmt(self, node: DelStmt) -> Node: return DelStmt(self.node(node.expr)) - + def visit_if_stmt(self, node: IfStmt) -> Node: return IfStmt(self.nodes(node.expr), self.blocks(node.body), self.optional_block(node.else_body)) - + def visit_break_stmt(self, node: BreakStmt) -> Node: return BreakStmt() - + def visit_continue_stmt(self, node: ContinueStmt) -> Node: return ContinueStmt() - + def visit_pass_stmt(self, node: PassStmt) -> Node: return PassStmt() - + def visit_raise_stmt(self, node: RaiseStmt) -> Node: return RaiseStmt(self.optional_node(node.expr), self.optional_node(node.from_expr)) - + def visit_try_stmt(self, node: TryStmt) -> Node: return TryStmt(self.block(node.body), self.optional_names(node.vars), @@ -250,34 +250,34 @@ def visit_try_stmt(self, node: TryStmt) -> Node: self.blocks(node.handlers), self.optional_block(node.else_body), self.optional_block(node.finally_body)) - + def visit_with_stmt(self, node: WithStmt) -> Node: return WithStmt(self.nodes(node.expr), self.optional_names(node.name), self.block(node.body)) - + def visit_print_stmt(self, node: PrintStmt) -> Node: return PrintStmt(self.nodes(node.args), node.newline) - + def visit_int_expr(self, node: IntExpr) -> Node: return IntExpr(node.value) - + def visit_str_expr(self, node: StrExpr) -> Node: return StrExpr(node.value) - + def visit_bytes_expr(self, node: BytesExpr) -> Node: return BytesExpr(node.value) - + def visit_unicode_expr(self, node: UnicodeExpr) -> Node: return UnicodeExpr(node.value) - + def visit_float_expr(self, node: FloatExpr) -> Node: return FloatExpr(node.value) - + def visit_paren_expr(self, node: ParenExpr) -> Node: return ParenExpr(self.node(node.expr)) - + def visit_name_expr(self, node: NameExpr) -> Node: return self.duplicate_name(node) @@ -288,7 +288,7 @@ def duplicate_name(self, node: NameExpr) -> NameExpr: new.info = node.info self.copy_ref(new, node) return new - + def visit_member_expr(self, node: MemberExpr) -> Node: member = MemberExpr(self.node(node.expr), node.name) @@ -315,39 +315,44 @@ def visit_call_expr(self, node: CallExpr) -> Node: node.arg_kinds[:], node.arg_names[:], self.optional_node(node.analyzed)) - + def visit_op_expr(self, node: OpExpr) -> Node: new = OpExpr(node.op, self.node(node.left), self.node(node.right)) new.method_type = self.optional_type(node.method_type) return new - + + def visit_comparison_expr(self, node: ComparisonExpr) -> Node: + new = ComparisonExpr(node.operators, self.nodes(node.operands)) + new.method_types = [self.optional_type(t) for t in node.method_types] + return new + def visit_cast_expr(self, node: CastExpr) -> Node: return CastExpr(self.node(node.expr), self.type(node.type)) - + def visit_super_expr(self, node: SuperExpr) -> Node: new = SuperExpr(node.name) new.info = node.info return new - + def visit_unary_expr(self, node: UnaryExpr) -> Node: new = UnaryExpr(node.op, self.node(node.expr)) new.method_type = self.optional_type(node.method_type) return new - + def visit_list_expr(self, node: ListExpr) -> Node: return ListExpr(self.nodes(node.items)) - + def visit_dict_expr(self, node: DictExpr) -> Node: return DictExpr([(self.node(key), self.node(value)) for key, value in node.items]) - + def visit_tuple_expr(self, node: TupleExpr) -> Node: return TupleExpr(self.nodes(node.items)) - + def visit_set_expr(self, node: SetExpr) -> Node: return SetExpr(self.nodes(node.items)) - + def visit_index_expr(self, node: IndexExpr) -> Node: new = IndexExpr(self.node(node.base), self.node(node.index)) if node.method_type: @@ -356,19 +361,19 @@ def visit_index_expr(self, node: IndexExpr) -> Node: new.analyzed = self.visit_type_application(node.analyzed) new.analyzed.set_line(node.analyzed.line) return new - + def visit_undefined_expr(self, node: UndefinedExpr) -> Node: return UndefinedExpr(self.type(node.type)) - + def visit_type_application(self, node: TypeApplication) -> TypeApplication: return TypeApplication(self.node(node.expr), self.types(node.types)) - + def visit_list_comprehension(self, node: ListComprehension) -> Node: generator = self.duplicate_generator(node.generator) generator.set_line(node.generator.line) return ListComprehension(generator) - + def visit_generator_expr(self, node: GeneratorExpr) -> Node: return self.duplicate_generator(node) @@ -378,18 +383,18 @@ def duplicate_generator(self, node: GeneratorExpr) -> GeneratorExpr: [self.optional_types(t) for t in node.types], [self.node(s) for s in node.sequences], [[self.node(cond) for cond in conditions] - for conditions in node.condlists]) - + for conditions in node.condlists]) + def visit_slice_expr(self, node: SliceExpr) -> Node: return SliceExpr(self.optional_node(node.begin_index), self.optional_node(node.end_index), self.optional_node(node.stride)) - + def visit_conditional_expr(self, node: ConditionalExpr) -> Node: return ConditionalExpr(self.node(node.cond), self.node(node.if_expr), self.node(node.else_expr)) - + def visit_type_var_expr(self, node: TypeVarExpr) -> Node: return TypeVarExpr(node.name(), node.fullname(), self.types(node.values)) @@ -399,23 +404,23 @@ def visit_ducktype_expr(self, node: DucktypeExpr) -> Node: def visit_disjointclass_expr(self, node: DisjointclassExpr) -> Node: return DisjointclassExpr(node.cls) - + def visit_coerce_expr(self, node: CoerceExpr) -> Node: raise RuntimeError('Not supported') - + def visit_type_expr(self, node: TypeExpr) -> Node: raise RuntimeError('Not supported') - + def visit_java_cast(self, node: JavaCast) -> Node: raise RuntimeError('Not supported') - + def visit_temp_node(self, node: TempNode) -> Node: return TempNode(self.type(node.type)) def node(self, node: Node) -> Node: new = node.accept(self) new.set_line(node.line) - return new + return new # Helpers # @@ -458,7 +463,7 @@ def optional_names(self, names: List[NameExpr]) -> List[NameExpr]: else: result.append(None) return result - + def type(self, type: Type) -> Type: # Override this method to transform types. return type diff --git a/mypy/visitor.py b/mypy/visitor.py index bd6ac45ebefb..f6e06dd67e5d 100644 --- a/mypy/visitor.py +++ b/mypy/visitor.py @@ -17,148 +17,205 @@ class NodeVisitor(Generic[T]): TODO make the default return value explicit """ - + # Module structure - + def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T: pass - + def visit_import(self, o: 'mypy.nodes.Import') -> T: pass + def visit_import_from(self, o: 'mypy.nodes.ImportFrom') -> T: pass + def visit_import_all(self, o: 'mypy.nodes.ImportAll') -> T: pass - + # Definitions - + def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T: pass + def visit_overloaded_func_def(self, o: 'mypy.nodes.OverloadedFuncDef') -> T: pass + def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> T: pass + def visit_var_def(self, o: 'mypy.nodes.VarDef') -> T: pass + def visit_global_decl(self, o: 'mypy.nodes.GlobalDecl') -> T: pass + def visit_decorator(self, o: 'mypy.nodes.Decorator') -> T: pass - + def visit_var(self, o: 'mypy.nodes.Var') -> T: pass - + # Statements - + def visit_block(self, o: 'mypy.nodes.Block') -> T: pass - + def visit_expression_stmt(self, o: 'mypy.nodes.ExpressionStmt') -> T: pass + def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> T: pass + def visit_operator_assignment_stmt(self, - o: 'mypy.nodes.OperatorAssignmentStmt') -> T: + o: 'mypy.nodes.OperatorAssignmentStmt') -> T: pass + def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> T: pass + def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T: pass + def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> T: pass + def visit_assert_stmt(self, o: 'mypy.nodes.AssertStmt') -> T: pass + def visit_yield_stmt(self, o: 'mypy.nodes.YieldStmt') -> T: pass + def visit_yield_from_stmt(self, o: 'mypy.nodes.YieldFromStmt') -> T: pass + def visit_del_stmt(self, o: 'mypy.nodes.DelStmt') -> T: pass + def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> T: pass + def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> T: pass + def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> T: pass + def visit_pass_stmt(self, o: 'mypy.nodes.PassStmt') -> T: pass + def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> T: pass + def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> T: pass + def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T: pass + def visit_print_stmt(self, o: 'mypy.nodes.PrintStmt') -> T: pass - + # Expressions - + def visit_int_expr(self, o: 'mypy.nodes.IntExpr') -> T: pass + def visit_str_expr(self, o: 'mypy.nodes.StrExpr') -> T: pass + def visit_bytes_expr(self, o: 'mypy.nodes.BytesExpr') -> T: pass + def visit_unicode_expr(self, o: 'mypy.nodes.UnicodeExpr') -> T: pass + def visit_float_expr(self, o: 'mypy.nodes.FloatExpr') -> T: pass + def visit_paren_expr(self, o: 'mypy.nodes.ParenExpr') -> T: pass + def visit_name_expr(self, o: 'mypy.nodes.NameExpr') -> T: pass + def visit_member_expr(self, o: 'mypy.nodes.MemberExpr') -> T: pass + def visit_yield_from_expr(self, o: 'mypy.nodes.YieldFromExpr') -> T: pass + def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T: pass + def visit_op_expr(self, o: 'mypy.nodes.OpExpr') -> T: pass + + def visit_comparison_expr(self, o: 'mypy.nodes.ComparisonExpr') -> T: + pass + def visit_cast_expr(self, o: 'mypy.nodes.CastExpr') -> T: pass + def visit_super_expr(self, o: 'mypy.nodes.SuperExpr') -> T: pass + def visit_unary_expr(self, o: 'mypy.nodes.UnaryExpr') -> T: pass + def visit_list_expr(self, o: 'mypy.nodes.ListExpr') -> T: pass + def visit_dict_expr(self, o: 'mypy.nodes.DictExpr') -> T: pass + def visit_tuple_expr(self, o: 'mypy.nodes.TupleExpr') -> T: pass + def visit_set_expr(self, o: 'mypy.nodes.SetExpr') -> T: pass + def visit_index_expr(self, o: 'mypy.nodes.IndexExpr') -> T: pass + def visit_undefined_expr(self, o: 'mypy.nodes.UndefinedExpr') -> T: pass + def visit_type_application(self, o: 'mypy.nodes.TypeApplication') -> T: pass + def visit_func_expr(self, o: 'mypy.nodes.FuncExpr') -> T: pass + def visit_list_comprehension(self, o: 'mypy.nodes.ListComprehension') -> T: pass + def visit_generator_expr(self, o: 'mypy.nodes.GeneratorExpr') -> T: pass + def visit_slice_expr(self, o: 'mypy.nodes.SliceExpr') -> T: pass + def visit_conditional_expr(self, o: 'mypy.nodes.ConditionalExpr') -> T: pass + def visit_type_var_expr(self, o: 'mypy.nodes.TypeVarExpr') -> T: pass + def visit_ducktype_expr(self, o: 'mypy.nodes.DucktypeExpr') -> T: pass + def visit_disjointclass_expr(self, o: 'mypy.nodes.DisjointclassExpr') -> T: pass - + def visit_coerce_expr(self, o: 'mypy.nodes.CoerceExpr') -> T: pass + def visit_type_expr(self, o: 'mypy.nodes.TypeExpr') -> T: pass + def visit_java_cast(self, o: 'mypy.nodes.JavaCast') -> T: pass - + def visit_temp_node(self, o: 'mypy.nodes.TempNode') -> T: pass diff --git a/stubs/3.4/asyncio/futures.py b/stubs/3.4/asyncio/futures.py index b818102eb1b9..d14361788e0d 100644 --- a/stubs/3.4/asyncio/futures.py +++ b/stubs/3.4/asyncio/futures.py @@ -6,7 +6,7 @@ # ] __all__ = ['Future'] -T = typevar('T') +_T = typevar('_T') class _TracebackLogger: __slots__ = [] # type: List[str] @@ -17,7 +17,7 @@ def activate(self) -> None: pass def clear(self) -> None: pass def __del__(self) -> None: pass -class Future(Iterator[T], Generic[T]): # (Iterable[T], Generic[T]) +class Future(Iterator[_T], Generic[_T]): # (Iterable[_T], Generic[_T]) _state = '' _exception = Any #Exception _blocking = False @@ -30,12 +30,12 @@ def cancel(self) -> bool: pass def _schedule_callbacks(self) -> None: pass def cancelled(self) -> bool: pass def done(self) -> bool: pass - def result(self) -> T: pass + def result(self) -> _T: pass def exception(self) -> Any: pass - def add_done_callback(self, fn: Function[[Future[T]],Any]) -> None: pass - def remove_done_callback(self, fn: Function[[Future[T]], Any]) -> int: pass - def set_result(self, result: T) -> None: pass + def add_done_callback(self, fn: Function[[Future[_T]],Any]) -> None: pass + def remove_done_callback(self, fn: Function[[Future[_T]], Any]) -> int: pass + def set_result(self, result: _T) -> None: pass def set_exception(self, exception: Any) -> None: pass def _copy_state(self, other: Any) -> None: pass - def __iter__(self) -> 'Iterator[T]': pass - def __next__(self) -> 'T': pass + def __iter__(self) -> 'Iterator[_T]': pass + def __next__(self) -> '_T': pass diff --git a/stubs/3.4/asyncio/tasks.py b/stubs/3.4/asyncio/tasks.py index 94db71a814c5..ae39eb7aaf53 100644 --- a/stubs/3.4/asyncio/tasks.py +++ b/stubs/3.4/asyncio/tasks.py @@ -13,24 +13,24 @@ FIRST_EXCEPTION = 'FIRST_EXCEPTION' FIRST_COMPLETED = 'FIRST_COMPLETED' ALL_COMPLETED = 'ALL_COMPLETED' -T = typevar('T') +_T = typevar('_T') def coroutine(f: Any) -> Any: pass # Here comes and go a function -def sleep(delay: float, result: T=None, loop: AbstractEventLoop=None) -> Future[T]: pass +def sleep(delay: float, result: T=None, loop: AbstractEventLoop=None) -> Future[_T]: pass def wait(fs: List[Any], *, loop: AbstractEventLoop=None, - timeout: float=None, return_when: str=ALL_COMPLETED) -> Future[Tuple[Set[Future[T]], Set[Future[T]]]]: pass -def wait_for(fut: Future[T], timeout: float, *, loop: AbstractEventLoop=None) -> Future[T]: pass -# def wait(fs: Union[List[Iterable], List[Future[T]]], *, loop: AbstractEventLoop=None, -# timeout: int=None, return_when: str=ALL_COMPLETED) -> Future[Tuple[Set[Future[T]], Set[Future[T]]]]: pass + timeout: float=None, return_when: str=ALL_COMPLETED) -> Future[Tuple[Set[Future[_T]], Set[Future[_T]]]]: pass +def wait_for(fut: Future[_T], timeout: float, *, loop: AbstractEventLoop=None) -> Future[_T]: pass +# def wait(fs: Union[List[Iterable], List[Future[_T]]], *, loop: AbstractEventLoop=None, +# timeout: int=None, return_when: str=ALL_COMPLETED) -> Future[Tuple[Set[Future[_T]], Set[Future[_T]]]]: pass -class Task(Future[T], Generic[T]): +class Task(Future[_T], Generic[_T]): _all_tasks = None # type: Set[Task] _current_tasks = {} # type: Dict[AbstractEventLoop, Task] @classmethod def current_task(cls, loop: AbstractEventLoop=None) -> Task: pass @classmethod def all_tasks(cls, loop: AbstractEventLoop=None) -> Set[Task]: pass - # def __init__(self, coro: Union[Iterable[T], Future[T]], *, loop: AbstractEventLoop=None) -> None: pass - def __init__(self, coro: Future[T], *, loop: AbstractEventLoop=None) -> None: pass + # def __init__(self, coro: Union[Iterable[_T], Future[_T]], *, loop: AbstractEventLoop=None) -> None: pass + def __init__(self, coro: Future[_T], *, loop: AbstractEventLoop=None) -> None: pass def __repr__(self) -> str: pass def get_stack(self, *, limit: int=None) -> List[Any]: pass # return List[stackframe] def print_stack(self, *, limit: int=None, file: TextIO=None) -> None: pass From 08fb60b990c83b3e4b53d2b1b0c750c670538ecd Mon Sep 17 00:00:00 2001 From: Rock Neurotiko Date: Mon, 15 Sep 2014 12:56:47 +0200 Subject: [PATCH 10/12] merge --- mypy/checker.py | 20 ++++++++++++++++++++ mypy/parse.py | 2 +- stubs/3.4/asyncio/tasks.py | 2 +- 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 5eb9d4b08581..5127087ea71f 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1224,6 +1224,26 @@ def visit_return_stmt(self, s: ReturnStmt) -> Type: not self.is_dynamic_function()): self.fail(messages.RETURN_VALUE_EXPECTED, s) + def wrap_generic_type(self, typ: Type, rtyp: Type, check_type: str, context: Context) -> Type: + n_diff = self.count_concatenated_types(rtyp, check_type) - self.count_concatenated_types(typ, check_type) + if n_diff >= 1: + return self.named_generic_type(check_type, [typ]) + elif n_diff == 0: + self.fail(messages.INCOMPATIBLE_RETURN_VALUE_TYPE + + ": expected {}, got {}".format(rtyp, typ), context) + return typ + return typ + + def count_concatenated_types(self, typ: Type, check_type: str) -> int: + c = 0 + while is_subtype(typ, self.named_type(check_type)): + c += 1 + if hasattr(typ, 'args') and typ.args: + typ = typ.args[0] + else: + return c + return c + def visit_yield_stmt(self, s: YieldStmt) -> Type: return_type = self.return_types[-1] if isinstance(return_type, Instance): diff --git a/mypy/parse.py b/mypy/parse.py index e086eca9b715..b0df50efb4cc 100755 --- a/mypy/parse.py +++ b/mypy/parse.py @@ -775,6 +775,7 @@ def parse_assert_stmt(self) -> AssertStmt: def parse_yield_stmt(self) -> YieldStmt: yield_tok = self.expect('yield') expr = None # type: Node + node = YieldStmt(expr) if not isinstance(self.current(), Break): if isinstance(self.current(), Keyword) and self.current_str() == "from": # Not go if it's not from from_tok = self.expect("from") @@ -784,7 +785,6 @@ def parse_yield_stmt(self) -> YieldStmt: expr = self.parse_expression() node = YieldStmt(expr) br = self.expect_break() - node = YieldStmt(expr) self.set_repr(node, noderepr.SimpleStmtRepr(yield_tok, br)) return node diff --git a/stubs/3.4/asyncio/tasks.py b/stubs/3.4/asyncio/tasks.py index ae39eb7aaf53..d5ff03d5516a 100644 --- a/stubs/3.4/asyncio/tasks.py +++ b/stubs/3.4/asyncio/tasks.py @@ -15,7 +15,7 @@ ALL_COMPLETED = 'ALL_COMPLETED' _T = typevar('_T') def coroutine(f: Any) -> Any: pass # Here comes and go a function -def sleep(delay: float, result: T=None, loop: AbstractEventLoop=None) -> Future[_T]: pass +def sleep(delay: float, result: _T=None, loop: AbstractEventLoop=None) -> Future[_T]: pass def wait(fs: List[Any], *, loop: AbstractEventLoop=None, timeout: float=None, return_when: str=ALL_COMPLETED) -> Future[Tuple[Set[Future[_T]], Set[Future[_T]]]]: pass def wait_for(fut: Future[_T], timeout: float, *, loop: AbstractEventLoop=None) -> Future[_T]: pass From 39b70304577db862372817a657a0e43a88546193 Mon Sep 17 00:00:00 2001 From: Miguel Garcia Lafuente Date: Tue, 11 Nov 2014 11:01:27 +0100 Subject: [PATCH 11/12] corrected some issues & typos --- mypy/checker.py | 35 ++++++++++++++-------------- mypy/messages.py | 2 +- mypy/nodes.py | 1 + mypy/parse.py | 2 +- mypy/test/data/check-statements.test | 16 ++++++------- mypy/test/data/parse-errors.test | 2 +- stubs/3.4/asyncio/events.py | 34 +++++++++++++-------------- stubs/3.4/asyncio/futures.py | 8 +++---- stubs/3.4/asyncio/tasks.py | 26 ++++++++++----------- 9 files changed, 63 insertions(+), 63 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 5127087ea71f..f746318467ac 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1210,7 +1210,7 @@ def visit_return_stmt(self, s: ReturnStmt) -> Type: self.fail(messages.NO_RETURN_VALUE_EXPECTED, s) else: if self.function_stack[-1].is_coroutine: # Something similar will be needed to mix return and yield - #If the function is a coroutine, wrap the return type in a Future + # If the function is a coroutine, wrap the return type in a Future typ = self.wrap_generic_type(typ, self.return_types[-1], 'asyncio.futures.Future', s) self.check_subtype( typ, self.return_types[-1], s, @@ -1225,20 +1225,21 @@ def visit_return_stmt(self, s: ReturnStmt) -> Type: self.fail(messages.RETURN_VALUE_EXPECTED, s) def wrap_generic_type(self, typ: Type, rtyp: Type, check_type: str, context: Context) -> Type: - n_diff = self.count_concatenated_types(rtyp, check_type) - self.count_concatenated_types(typ, check_type) - if n_diff >= 1: + n_diff = self.count_nested_types(rtyp, check_type) - self.count_nested_types(typ, check_type) + if n_diff == 1: return self.named_generic_type(check_type, [typ]) - elif n_diff == 0: + elif n_diff == 0 or n_diff > 1: self.fail(messages.INCOMPATIBLE_RETURN_VALUE_TYPE + ": expected {}, got {}".format(rtyp, typ), context) return typ return typ - def count_concatenated_types(self, typ: Type, check_type: str) -> int: + def count_nested_types(self, typ: Type, check_type: str) -> int: c = 0 while is_subtype(typ, self.named_type(check_type)): c += 1 - if hasattr(typ, 'args') and typ.args: + typ = map_instance_to_supertype(typ, self.lookup_typeinfo(check_type)) + if typ.args: typ = typ.args[0] else: return c @@ -1268,7 +1269,7 @@ def visit_yield_from_stmt(self, s: YieldFromStmt) -> Type: return_type = self.return_types[-1] type_func = self.accept(s.expr, return_type) if isinstance(type_func, Instance): - if hasattr(type_func, 'type') and hasattr(type_func.type, 'fullname') and type_func.type.fullname() == 'asyncio.futures.Future': + if type_func.type.fullname() == 'asyncio.futures.Future': # if is a Future, in stmt don't need to do nothing # because the type Future[Some] jus matters to the main loop # that python executes, in statement we shouldn't get the Future, @@ -1277,15 +1278,15 @@ def visit_yield_from_stmt(self, s: YieldFromStmt) -> Type: elif is_subtype(type_func, self.named_type('typing.Iterable')): # If it's and Iterable-Like, let's check the types. # Maybe just check if have __iter__? (like in analyse_iterable) - self.check_iterable_yf(s) + self.check_iterable_yield_from(s) else: - self.msg.yield_from_not_valid_applied(type_func, s) + self.msg.yield_from_invalid_operand_type(type_func, s) elif isinstance(type_func, AnyType): - self.check_iterable_yf(s) + self.check_iterable_yield_from(s) else: - self.msg.yield_from_not_valid_applied(type_func, s) + self.msg.yield_from_invalid_operand_type(type_func, s) - def check_iterable_yf(self, s: YieldFromStmt) -> Type: + def check_iterable_yield_from(self, s: YieldFromStmt) -> Type: """ Check that return type is super type of Iterable (Maybe just check if have __iter__?) and compare it with the type of the expression @@ -1295,9 +1296,9 @@ def check_iterable_yf(self, s: YieldFromStmt) -> Type: if not is_subtype(expected_item_type, self.named_type('typing.Iterable')): self.fail(messages.INVALID_RETURN_TYPE_FOR_YIELD_FROM, s) return None - elif hasattr(expected_item_type, 'args') and expected_item_type.args: + elif expected_item_type.args: + expected_item_type = map_instance_to_supertype(expected_item_type, self.lookup_typeinfo('typing.Iterable')) expected_item_type = expected_item_type.args[0] # Take the item inside the iterator - # expected_item_type = expected_item_type elif isinstance(expected_item_type, AnyType): expected_item_type = AnyType() else: @@ -1308,6 +1309,7 @@ def check_iterable_yf(self, s: YieldFromStmt) -> Type: else: actual_item_type = self.accept(s.expr, expected_item_type) if hasattr(actual_item_type, 'args') and actual_item_type.args: + actual_item_type = map_instance_to_supertype(actual_item_type, self.lookup_typeinfo('typing.Iterable')) actual_item_type = actual_item_type.args[0] # Take the item inside the iterator self.check_subtype(actual_item_type, expected_item_type, s, messages.INCOMPATIBLE_TYPES_IN_YIELD_FROM, @@ -1625,7 +1627,7 @@ def visit_call_expr(self, e: CallExpr) -> Type: def visit_yield_from_expr(self, e: YieldFromExpr) -> Type: result = self.expr_checker.visit_yield_from_expr(e) - if hasattr(result, 'type') and result.type.fullname() == "asyncio.futures.Future": + if result.type.fullname() == "asyncio.futures.Future": self.function_stack[-1].is_coroutine = True # Set the function as coroutine result = result.args[0] # Set the return type as the type inside elif is_subtype(result, self.named_type('typing.Iterable')): @@ -1634,8 +1636,7 @@ def visit_yield_from_expr(self, e: YieldFromExpr) -> Type: # Maybe set result like in the Future pass else: - self.msg.yield_from_not_valid_applied(e.expr, e) - self.breaking_out = False + self.msg.yield_from_invalid_operand_type(e.expr, e) return result def visit_member_expr(self, e: MemberExpr) -> Type: diff --git a/mypy/messages.py b/mypy/messages.py index 7e86461e177a..8825c5df1d7f 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -637,7 +637,7 @@ def signatures_incompatible(self, method: str, other_method: str, self.fail('Signatures of "{}" and "{}" are incompatible'.format( method, other_method), context) - def yield_from_not_valid_applied(self, expr: Type, context: Context) -> Type: + def yield_from_invalid_operand_type(self, expr: Type, context: Context) -> Type: text = self.format(expr) if self.format(expr) != 'object' else expr self.fail('"yield from" can\'t be applied to {}'.format(text), context) return AnyType() diff --git a/mypy/nodes.py b/mypy/nodes.py index e6e9bb25b453..3b8e80b30bd6 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -936,6 +936,7 @@ def __init__(self, expr: Node) -> None: def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_yield_from_expr(self) + class IndexExpr(Node): """Index expression x[y]. diff --git a/mypy/parse.py b/mypy/parse.py index b0df50efb4cc..63f392dc82ca 100755 --- a/mypy/parse.py +++ b/mypy/parse.py @@ -742,7 +742,7 @@ def parse_return_stmt(self) -> ReturnStmt: expr = None # type: Node if not isinstance(self.current(), Break): expr = self.parse_expression() - if isinstance(expr, YieldFromExpr): #cant go a yield from expr + if isinstance(expr, YieldFromExpr): # "yield from" expressions can't be returned. return None br = self.expect_break() node = ReturnStmt(expr) diff --git a/mypy/test/data/check-statements.test b/mypy/test/data/check-statements.test index 97e8b7fbcc13..1ae73c697fb4 100644 --- a/mypy/test/data/check-statements.test +++ b/mypy/test/data/check-statements.test @@ -626,7 +626,7 @@ def f() -> Iterator[None]: -- Iterables -- ---------- -[case testSimpleYFIter] +[case testSimpleYieldFromWithIterator] from typing import Iterator def g() -> Iterator[str]: yield '42' @@ -638,7 +638,7 @@ def f() -> Iterator[str]: [out] main: In function "f": -[case testYFAppliedToAny] +[case testYieldFromAppliedToAny] from typing import Any def g() -> Any: yield object() @@ -646,7 +646,7 @@ def f() -> Any: yield from g() [out] -[case testYFInFunctionReturningFunction] +[case testYieldFromInFunctionReturningFunction] from typing import Iterator, Function def g() -> Iterator[int]: yield 42 @@ -655,7 +655,7 @@ def f() -> Function[[], None]: [out] main: In function "f": -[case testGoodYFNotIterableReturnType] +[case testYieldFromNotIterableReturnType] from typing import Iterator def g() -> Iterator[int]: yield 42 @@ -664,7 +664,7 @@ def f() -> int: [out] main: In function "f": -[case testYFNotAppliedIter] +[case testYieldFromNotAppliedIterator] from typing import Iterator def g() -> int: return 42 @@ -673,7 +673,7 @@ def f() -> Iterator[int]: [out] main: In function "f": -[case testYFCheckIncompatibleTypesTwoIterables] +[case testYieldFromCheckIncompatibleTypesTwoIterables] from typing import List, Iterator def g() -> Iterator[List[int]]: yield [2, 3, 4] @@ -684,13 +684,13 @@ def f() -> Iterator[List[int]]: [out] main: In function "f": -[case testYFNotAppliedToNothing] +[case testYieldFromNotAppliedToNothing] def h(): yield from # E: Parse error before end of line [out] main: In function "h": -[case testYFAndYieldTogether] +[case testYieldFromAndYieldTogether] from typing import Iterator def f() -> Iterator[str]: yield "g1 ham" diff --git a/mypy/test/data/parse-errors.test b/mypy/test/data/parse-errors.test index e24eea980545..2f0a8a431594 100644 --- a/mypy/test/data/parse-errors.test +++ b/mypy/test/data/parse-errors.test @@ -348,7 +348,7 @@ def f(): file: In function "f": file, line 2: Parse error before end of line -[case testYielFromAfterReturn] +[case testYieldFromAfterReturn] def f(): return yield from h() [out] diff --git a/stubs/3.4/asyncio/events.py b/stubs/3.4/asyncio/events.py index 650e986396ee..8a1be6cfdfac 100644 --- a/stubs/3.4/asyncio/events.py +++ b/stubs/3.4/asyncio/events.py @@ -61,36 +61,36 @@ def set_default_executor(self, executor: Any) -> None: pass # Network I/O methods returning Futures. @abstractmethod def getaddrinfo(self, host: str, port: int, *, - family: int=0, type: int=0, proto: int=0, flags: int=0) -> List[Tuple[int, int, int, str, tuple]]: pass + family: int = 0, type: int = 0, proto: int = 0, flags: int = 0) -> List[Tuple[int, int, int, str, tuple]]: pass @abstractmethod - def getnameinfo(self, sockaddr: tuple, flags: int=0) -> Tuple[str, int]: pass + def getnameinfo(self, sockaddr: tuple, flags: int = 0) -> Tuple[str, int]: pass @abstractmethod - def create_connection(self, protocol_factory: Any, host: str=None, port: int=None, *, - ssl: Any=None, family: int=0, proto: int=0, flags: int=0, sock: Any=None, - local_addr: str=None, server_hostname: str=None) -> tuple: pass + def create_connection(self, protocol_factory: Any, host: str = None, port: int = None, *, + ssl: Any = None, family: int = 0, proto: int = 0, flags: int = 0, sock: Any = None, + local_addr: str = None, server_hostname: str = None) -> tuple: pass # ?? check Any # return (Transport, Protocol) @abstractmethod - def create_server(self, protocol_factory: Any, host: str=None, port: int=None, *, - family: int=AF_UNSPEC, flags: int=AI_PASSIVE, - sock: Any=None, backlog: int=100, ssl: Any=None, reuse_address: Any=None) -> Any: pass + def create_server(self, protocol_factory: Any, host: str = None, port: int = None, *, + family: int = AF_UNSPEC, flags: int = AI_PASSIVE, + sock: Any = None, backlog: int = 100, ssl: Any = None, reuse_address: Any = None) -> Any: pass # ?? check Any # return Server @abstractmethod def create_unix_connection(self, protocol_factory: Any, path: str, *, - ssl: Any=None, sock: Any=None, - server_hostname: str=None) -> tuple: pass + ssl: Any = None, sock: Any = None, + server_hostname: str = None) -> tuple: pass # ?? check Any # return tuple(Transport, Protocol) @abstractmethod def create_unix_server(self, protocol_factory: Any, path: str, *, - sock: Any=None, backlog: int=100, ssl: Any=None) -> Any: pass + sock: Any = None, backlog: int = 100, ssl: Any = None) -> Any: pass # ?? check Any # return Server @abstractmethod def create_datagram_endpoint(self, protocol_factory: Any, - local_addr: str=None, remote_addr: str=None, *, - family: int=0, proto: int=0, flags: int=0) -> tuple: pass + local_addr: str = None, remote_addr: str = None, *, + family: int = 0, proto: int = 0, flags: int = 0) -> tuple: pass #?? check Any # return (Transport, Protocol) # Pipes and subprocesses. @@ -103,14 +103,14 @@ def connect_write_pipe(self, protocol_factory: Any, pipe: Any) -> tuple: pass #?? check Any # return (Transport, Protocol) @abstractmethod - def subprocess_shell(self, protocol_factory: Any, cmd: Union[bytes,str], *, stdin: Any=PIPE, - stdout: Any=PIPE, stderr: Any=PIPE, + def subprocess_shell(self, protocol_factory: Any, cmd: Union[bytes, str], *, stdin: Any = PIPE, + stdout: Any = PIPE, stderr: Any = PIPE, **kwargs: Dict[str, Any]) -> tuple: pass #?? check Any # return (Transport, Protocol) @abstractmethod - def subprocess_exec(self, protocol_factory: Any, *args: List[Any], stdin: Any=PIPE, - stdout: Any=PIPE, stderr: Any=PIPE, + def subprocess_exec(self, protocol_factory: Any, *args: List[Any], stdin: Any = PIPE, + stdout: Any = PIPE, stderr: Any = PIPE, **kwargs: Dict[str, Any]) -> tuple: pass #?? check Any # return (Transport, Protocol) diff --git a/stubs/3.4/asyncio/futures.py b/stubs/3.4/asyncio/futures.py index d14361788e0d..b08e3430bfd8 100644 --- a/stubs/3.4/asyncio/futures.py +++ b/stubs/3.4/asyncio/futures.py @@ -17,7 +17,7 @@ def activate(self) -> None: pass def clear(self) -> None: pass def __del__(self) -> None: pass -class Future(Iterator[_T], Generic[_T]): # (Iterable[_T], Generic[_T]) +class Future(Iterator[_T], Generic[_T]): _state = '' _exception = Any #Exception _blocking = False @@ -32,10 +32,10 @@ def cancelled(self) -> bool: pass def done(self) -> bool: pass def result(self) -> _T: pass def exception(self) -> Any: pass - def add_done_callback(self, fn: Function[[Future[_T]],Any]) -> None: pass + def add_done_callback(self, fn: Function[[Future[_T]], Any]) -> None: pass def remove_done_callback(self, fn: Function[[Future[_T]], Any]) -> int: pass def set_result(self, result: _T) -> None: pass def set_exception(self, exception: Any) -> None: pass def _copy_state(self, other: Any) -> None: pass - def __iter__(self) -> 'Iterator[_T]': pass - def __next__(self) -> '_T': pass + def __iter__(self) -> Iterator[_T]: pass + def __next__(self) -> _T: pass diff --git a/stubs/3.4/asyncio/tasks.py b/stubs/3.4/asyncio/tasks.py index d5ff03d5516a..dc69607e80e1 100644 --- a/stubs/3.4/asyncio/tasks.py +++ b/stubs/3.4/asyncio/tasks.py @@ -14,27 +14,25 @@ FIRST_COMPLETED = 'FIRST_COMPLETED' ALL_COMPLETED = 'ALL_COMPLETED' _T = typevar('_T') -def coroutine(f: Any) -> Any: pass # Here comes and go a function -def sleep(delay: float, result: _T=None, loop: AbstractEventLoop=None) -> Future[_T]: pass -def wait(fs: List[Any], *, loop: AbstractEventLoop=None, - timeout: float=None, return_when: str=ALL_COMPLETED) -> Future[Tuple[Set[Future[_T]], Set[Future[_T]]]]: pass -def wait_for(fut: Future[_T], timeout: float, *, loop: AbstractEventLoop=None) -> Future[_T]: pass -# def wait(fs: Union[List[Iterable], List[Future[_T]]], *, loop: AbstractEventLoop=None, -# timeout: int=None, return_when: str=ALL_COMPLETED) -> Future[Tuple[Set[Future[_T]], Set[Future[_T]]]]: pass +def coroutine(f: _T) -> _T: pass # Here comes and go a function +def sleep(delay: float, result: _T = None, loop: AbstractEventLoop = None) -> Future[_T]: pass +def wait(fs: List[Future[_T]], *, loop: AbstractEventLoop = None, + timeout: float = None, return_when: str = ALL_COMPLETED) -> Future[Tuple[Set[Future[_T]], Set[Future[_T]]]]: pass +def wait_for(fut: Future[_T], timeout: float, *, loop: AbstractEventLoop = None) -> Future[_T]: pass + class Task(Future[_T], Generic[_T]): _all_tasks = None # type: Set[Task] _current_tasks = {} # type: Dict[AbstractEventLoop, Task] @classmethod - def current_task(cls, loop: AbstractEventLoop=None) -> Task: pass + def current_task(cls, loop: AbstractEventLoop = None) -> Task: pass @classmethod - def all_tasks(cls, loop: AbstractEventLoop=None) -> Set[Task]: pass - # def __init__(self, coro: Union[Iterable[_T], Future[_T]], *, loop: AbstractEventLoop=None) -> None: pass - def __init__(self, coro: Future[_T], *, loop: AbstractEventLoop=None) -> None: pass + def all_tasks(cls, loop: AbstractEventLoop = None) -> Set[Task]: pass + def __init__(self, coro: Future[_T], *, loop: AbstractEventLoop = None) -> None: pass def __repr__(self) -> str: pass - def get_stack(self, *, limit: int=None) -> List[Any]: pass # return List[stackframe] - def print_stack(self, *, limit: int=None, file: TextIO=None) -> None: pass + def get_stack(self, *, limit: int = None) -> List[Any]: pass # return List[stackframe] + def print_stack(self, *, limit: int = None, file: TextIO = None) -> None: pass def cancel(self) -> bool: pass - def _step(self, value: Any=None, exc: Exception=None) -> None: pass + def _step(self, value: Any = None, exc: Exception = None) -> None: pass def _wakeup(self, future: Future[Any]) -> None: pass From 6e72952e92c4c751123ca3aa855bc0fe2e5f2cf2 Mon Sep 17 00:00:00 2001 From: Rock Neurotiko Date: Tue, 11 Nov 2014 22:04:35 +0100 Subject: [PATCH 12/12] Changed the asyncio full tests to pass only in python >= 3.4 --- mypy/test/testpythoneval.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mypy/test/testpythoneval.py b/mypy/test/testpythoneval.py index 227a67fc7dab..42cde428ab72 100644 --- a/mypy/test/testpythoneval.py +++ b/mypy/test/testpythoneval.py @@ -25,8 +25,9 @@ # Files which contain test case descriptions. python_eval_files = ['pythoneval.test', - 'python2eval.test', - 'pythoneval-asyncio.test'] + 'python2eval.test'] + +python_34_eval_files = ['pythoneval-asyncio.test'] # Path to Python 3 interpreter python3_path = 'python3' @@ -40,6 +41,10 @@ def cases(self): for f in python_eval_files: c += parse_test_cases(os.path.join(test_data_prefix, f), test_python_evaluation, test_temp_dir, True) + if sys.version_info.major == 3 and sys.version_info.minor >= 4: + for f in python_34_eval_files: + c += parse_test_cases(os.path.join(test_data_prefix, f), + test_python_evaluation, test_temp_dir, True) return c