diff --git a/mypy/checker.py b/mypy/checker.py index 5bacc9d2bb62..df622c47a4bd 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -11,7 +11,7 @@ from mypy.errors import Errors, report_internal_error from mypy.nodes import ( - SymbolTable, Node, MypyFile, Var, + SymbolTable, Node, MypyFile, Var, Expression, OverloadedFuncDef, FuncDef, FuncItem, FuncBase, TypeInfo, ClassDef, GDEF, Block, AssignmentStmt, NameExpr, MemberExpr, IndexExpr, TupleExpr, ListExpr, ExpressionStmt, ReturnStmt, IfStmt, @@ -25,6 +25,7 @@ YieldFromExpr, NamedTupleExpr, SetComprehension, DictionaryComprehension, ComplexExpr, EllipsisExpr, TypeAliasExpr, RefExpr, YieldExpr, BackquoteExpr, ImportFrom, ImportAll, ImportBase, + AwaitExpr, CONTRAVARIANT, COVARIANT ) from mypy.nodes import function_type, method_type, method_type_with_fallback @@ -257,21 +258,61 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: self.msg.overloaded_signatures_overlap(i + 1, i + j + 2, item.func) - def is_generator_return_type(self, typ: Type) -> bool: - return is_subtype(self.named_generic_type('typing.Generator', - [AnyType(), AnyType(), AnyType()]), - typ) + # Here's the scoop about generators and coroutines. + # + # There are two kinds of generators: classic generators (functions + # with `yield` or `yield from` in the body) and coroutines + # (functions declared with `async def`). The latter are specified + # in PEP 492 and only available in Python >= 3.5. + # + # Classic generators can be parameterized with three types: + # - ty is the yield type (the type of y in `yield y`) + # - ts is the type received by yield (the type of s in `s = yield`) + # (it's named `ts` after `send()`, since `tr` is `return`). + # - tr is the return type (the type of r in `return r`) + # + # A classic generator must define a return type that's either + # `Generator[ty, ts, tr]`, Iterator[ty], or Iterable[ty] (or + # object or Any). If ts/tr are not given, both are Void. + # + # A coroutine must define a return type corresponding to tr; the + # other two are unconstrained. The "external" return type (seen + # by the caller) is Awaitable[tr]. + # + # There are several useful methods, each taking a type t and a + # flag c indicating whether it's for a generator or coroutine: + # + # - is_generator_return_type(t, c) returns whether t is a Generator, + # Iterator, Iterable (if not c), or Awaitable (if c). + # - get_generator_yield_type(t, c) returns ty. + # - get_generator_receive_type(t, c) returns ts. + # - get_generator_return_type(t, c) returns tr. - def get_generator_yield_type(self, return_type: Type) -> Type: + def is_generator_return_type(self, typ: Type, is_coroutine: bool) -> bool: + """Is `typ` a valid type for a generator/coroutine? + + True if either Generator or Awaitable is a supertype of `typ`. + """ + if is_coroutine: + at = self.named_generic_type('typing.Awaitable', [AnyType()]) + return is_subtype(at, typ) + else: + gt = self.named_generic_type('typing.Generator', [AnyType(), AnyType(), AnyType()]) + return is_subtype(gt, typ) + + def get_generator_yield_type(self, return_type: Type, is_coroutine: bool) -> Type: + """Given the declared return type of a generator (t), return the type it yields (ty).""" if isinstance(return_type, AnyType): return AnyType() - elif not self.is_generator_return_type(return_type): + elif not self.is_generator_return_type(return_type, is_coroutine): # If the function doesn't have a proper Generator (or superclass) return type, anything # is permissible. return AnyType() elif not isinstance(return_type, Instance): # Same as above, but written as a separate branch so the typechecker can understand. return AnyType() + elif return_type.type.fullname() == 'typing.Awaitable': + return AnyType() elif return_type.args: return return_type.args[0] else: @@ -280,10 +321,11 @@ def get_generator_yield_type(self, return_type: Type) -> Type: # be accessed so any type is acceptable. return AnyType() - def get_generator_receive_type(self, return_type: Type) -> Type: + def get_generator_receive_type(self, return_type: Type, is_coroutine: bool) -> Type: + """Given a declared generator return type (t), return the type its yield receives (ts).""" if isinstance(return_type, AnyType): return AnyType() - elif not self.is_generator_return_type(return_type): + elif not self.is_generator_return_type(return_type, is_coroutine): # If the function doesn't have a proper Generator (or superclass) return type, anything # is permissible. return AnyType() @@ -291,17 +333,25 @@ def get_generator_receive_type(self, return_type: Type) -> Type: # Same as above, but written as a separate branch so the typechecker can understand. return AnyType() elif return_type.type.fullname() == 'typing.Generator': - # Generator is the only type which specifies the type of values it can receive. - return return_type.args[1] + # Generator is one of the two types which specify the type of values it can receive. + if len(return_type.args) == 3: + return return_type.args[1] + else: + return AnyType() + elif return_type.type.fullname() == 'typing.Awaitable': + # Awaitable is one of the two types which specify the type of values it can receive. + # According to the stub this is always `Any`. + return AnyType() else: # `return_type` is a supertype of Generator, so callers won't be able to send it # values. return Void() - def get_generator_return_type(self, return_type: Type) -> Type: + def get_generator_return_type(self, return_type: Type, is_coroutine: bool) -> Type: + """Given the declared return type of a generator (t), return the type it returns (tr).""" if isinstance(return_type, AnyType): return AnyType() - elif not self.is_generator_return_type(return_type): + elif not self.is_generator_return_type(return_type, is_coroutine): # If the function doesn't have a proper Generator (or superclass) return type, anything # is permissible. return AnyType() @@ -309,14 +359,38 @@ def get_generator_return_type(self, return_type: Type) -> Type: # Same as above, but written as a separate branch so the typechecker can understand. return AnyType() elif return_type.type.fullname() == 'typing.Generator': - # Generator is the only type which specifies the type of values it returns into - # `yield from` expressions. - return return_type.args[2] + # Generator is one of the two types which specify the type of values it returns into + # `yield from` expressions (using a `return` statement). + if len(return_type.args) == 3: + return return_type.args[2] + else: + return AnyType() + elif return_type.type.fullname() == 'typing.Awaitable': + # Awaitable is the other type which specifies the type of values it returns into + # `yield from` expressions (using `return`). + if len(return_type.args) == 1: + return return_type.args[0] + else: + return AnyType() else: # `return_type` is supertype of Generator, so callers won't be able to see the return # type when used in a `yield from` expression. return AnyType() + def check_awaitable_expr(self, t: Type, ctx: Context, msg: str) -> Type: + """Check the argument to `await` and extract the type of value. + + Also used by `async for` and `async with`. + """ + if not self.check_subtype(t, self.named_type('typing.Awaitable'), ctx, + msg, 'actual type', 'expected type'): + return AnyType() + else: + echk = self.expr_checker + method = echk.analyze_external_member_access('__await__', t, ctx) + generator = echk.check_call(method, [], [], ctx)[0] + return self.get_generator_return_type(generator, False) + def visit_func_def(self, defn: FuncDef) -> Type: """Type check a function definition.""" self.check_func_item(defn, name=defn.name()) @@ -447,7 +521,7 @@ def is_implicit_any(t: Type) -> bool: # Check that Generator functions have the appropriate return type. if defn.is_generator: - if not self.is_generator_return_type(typ.ret_type): + if not self.is_generator_return_type(typ.ret_type, defn.is_coroutine): self.fail(messages.INVALID_RETURN_TYPE_FOR_GENERATOR, typ) # Python 2 generators aren't allowed to return values. @@ -1336,8 +1410,10 @@ def visit_return_stmt(self, s: ReturnStmt) -> Type: """Type check a return statement.""" self.binder.breaking_out = True if self.is_within_function(): - if self.function_stack[-1].is_generator: - return_type = self.get_generator_return_type(self.return_types[-1]) + defn = self.function_stack[-1] + if defn.is_generator: + return_type = self.get_generator_return_type(self.return_types[-1], + defn.is_coroutine) else: return_type = self.return_types[-1] @@ -1604,10 +1680,32 @@ def visit_except_handler_test(self, n: Node) -> Type: def visit_for_stmt(self, s: ForStmt) -> Type: """Type check a for statement.""" - item_type = self.analyze_iterable_item_type(s.expr) + if s.is_async: + item_type = self.analyze_async_iterable_item_type(s.expr) + else: + item_type = self.analyze_iterable_item_type(s.expr) self.analyze_index_variables(s.index, item_type, s) self.accept_loop(s.body, s.else_body) + def analyze_async_iterable_item_type(self, expr: Node) -> Type: + """Analyse async iterable expression and return iterator item type.""" + iterable = self.accept(expr) + + self.check_not_void(iterable, expr) + + self.check_subtype(iterable, + self.named_generic_type('typing.AsyncIterable', + [AnyType()]), + expr, messages.ASYNC_ITERABLE_EXPECTED) + + echk = self.expr_checker + method = echk.analyze_external_member_access('__aiter__', iterable, expr) + iterator = echk.check_call(method, [], [], expr)[0] + method = echk.analyze_external_member_access('__anext__', iterator, expr) + awaitable = echk.check_call(method, [], [], expr)[0] + return self.check_awaitable_expr(awaitable, expr, + messages.INCOMPATIBLE_TYPES_IN_ASYNC_FOR) + def analyze_iterable_item_type(self, expr: Node) -> Type: """Analyse iterable expression and return iterator item type.""" iterable = self.accept(expr) @@ -1714,18 +1812,39 @@ def check_incompatible_property_override(self, e: Decorator) -> None: self.fail(messages.READ_ONLY_PROPERTY_OVERRIDES_READ_WRITE, e) def visit_with_stmt(self, s: WithStmt) -> Type: - echk = self.expr_checker for expr, target in zip(s.expr, s.target): - ctx = self.accept(expr) - enter = echk.analyze_external_member_access('__enter__', ctx, expr) - obj = echk.check_call(enter, [], [], expr)[0] - if target: - self.check_assignment(target, self.temp_node(obj, expr)) - exit = echk.analyze_external_member_access('__exit__', ctx, expr) - arg = self.temp_node(AnyType(), expr) - echk.check_call(exit, [arg] * 3, [nodes.ARG_POS] * 3, expr) + if s.is_async: + self.check_async_with_item(expr, target) + else: + self.check_with_item(expr, target) self.accept(s.body) + def check_async_with_item(self, expr: Expression, target: Expression) -> None: + echk = self.expr_checker + ctx = self.accept(expr) + enter = echk.analyze_external_member_access('__aenter__', ctx, expr) + obj = echk.check_call(enter, [], [], expr)[0] + obj = self.check_awaitable_expr( + obj, expr, messages.INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AENTER) + if target: + self.check_assignment(target, self.temp_node(obj, expr)) + exit = echk.analyze_external_member_access('__aexit__', ctx, expr) + arg = self.temp_node(AnyType(), expr) + res = echk.check_call(exit, [arg] * 3, [nodes.ARG_POS] * 3, expr)[0] + self.check_awaitable_expr( + res, expr, messages.INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AEXIT) + + def check_with_item(self, expr: Expression, target: Expression) -> None: + echk = self.expr_checker + ctx = self.accept(expr) + enter = echk.analyze_external_member_access('__enter__', ctx, expr) + obj = echk.check_call(enter, [], [], expr)[0] + if target: + self.check_assignment(target, self.temp_node(obj, expr)) + exit = echk.analyze_external_member_access('__exit__', ctx, expr) + arg = self.temp_node(AnyType(), expr) + echk.check_call(exit, [arg] * 3, [nodes.ARG_POS] * 3, expr) + def visit_print_stmt(self, s: PrintStmt) -> Type: for arg in s.args: self.accept(arg) @@ -1771,8 +1890,8 @@ def visit_yield_from_expr(self, e: YieldFromExpr) -> Type: # Check that the iterator's item type matches the type yielded by the Generator function # containing this `yield from` expression. - expected_item_type = self.get_generator_yield_type(return_type) - actual_item_type = self.get_generator_yield_type(iter_type) + expected_item_type = self.get_generator_yield_type(return_type, False) + actual_item_type = self.get_generator_yield_type(iter_type, False) self.check_subtype(actual_item_type, expected_item_type, e, messages.INCOMPATIBLE_TYPES_IN_YIELD_FROM, @@ -1781,10 +1900,14 @@ def visit_yield_from_expr(self, e: YieldFromExpr) -> Type: # Determine the type of the entire yield from expression. if (isinstance(iter_type, Instance) and iter_type.type.fullname() == 'typing.Generator'): - return self.get_generator_return_type(iter_type) + return self.get_generator_return_type(iter_type, False) else: # Non-Generators don't return anything from `yield from` expressions. - return Void() + # However special-case Any (which might be produced by an error). + if isinstance(actual_item_type, AnyType): + return AnyType() + else: + return Void() def visit_member_expr(self, e: MemberExpr) -> Type: return self.expr_checker.visit_member_expr(e) @@ -1896,7 +2019,7 @@ def visit_backquote_expr(self, e: BackquoteExpr) -> Type: def visit_yield_expr(self, e: YieldExpr) -> Type: return_type = self.return_types[-1] - expected_item_type = self.get_generator_yield_type(return_type) + expected_item_type = self.get_generator_yield_type(return_type, False) if e.expr is None: if (not (isinstance(expected_item_type, Void) or isinstance(expected_item_type, AnyType)) @@ -1907,7 +2030,16 @@ def visit_yield_expr(self, e: YieldExpr) -> Type: self.check_subtype(actual_item_type, expected_item_type, e, messages.INCOMPATIBLE_TYPES_IN_YIELD, 'actual type', 'expected type') - return self.get_generator_receive_type(return_type) + return self.get_generator_receive_type(return_type, False) + + def visit_await_expr(self, e: AwaitExpr) -> Type: + expected_type = self.type_context[-1] + if expected_type is not None: + expected_type = self.named_generic_type('typing.Awaitable', [expected_type]) + actual_type = self.accept(e.expr, expected_type) + if isinstance(actual_type, AnyType): + return AnyType() + return self.check_awaitable_expr(actual_type, e, messages.INCOMPATIBLE_TYPES_IN_AWAIT) # # Helpers @@ -1916,10 +2048,12 @@ def visit_yield_expr(self, e: YieldExpr) -> Type: def check_subtype(self, subtype: Type, supertype: Type, context: Context, msg: str = messages.INCOMPATIBLE_TYPES, subtype_label: str = None, - supertype_label: str = None) -> None: + supertype_label: str = None) -> bool: """Generate an error if the subtype is not compatible with supertype.""" - if not is_subtype(subtype, supertype): + if is_subtype(subtype, supertype): + return True + else: if isinstance(subtype, Void): self.msg.does_not_return_value(subtype, context) else: @@ -1933,6 +2067,7 @@ def check_subtype(self, subtype: Type, supertype: Type, context: Context, if extra_info: msg += ' (' + ', '.join(extra_info) + ')' self.fail(msg, context) + return False def named_type(self, name: str) -> Instance: """Return an instance type with type given by the name and no diff --git a/mypy/fastparse.py b/mypy/fastparse.py index c76e8b98f583..1db7004f6713 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -14,9 +14,12 @@ UnaryExpr, FuncExpr, ComparisonExpr, StarExpr, YieldFromExpr, NonlocalDecl, DictionaryComprehension, SetComprehension, ComplexExpr, EllipsisExpr, YieldExpr, Argument, + AwaitExpr, ARG_POS, ARG_OPT, ARG_STAR, ARG_NAMED, ARG_STAR2 ) -from mypy.types import Type, CallableType, AnyType, UnboundType, TupleType, TypeList, EllipsisType +from mypy.types import ( + Type, CallableType, FunctionLike, AnyType, UnboundType, TupleType, TypeList, EllipsisType, +) from mypy import defaults from mypy import experiments from mypy.errors import Errors @@ -242,6 +245,17 @@ def visit_Module(self, mod: ast35.Module) -> Node: # arg? kwarg, expr* defaults) @with_line def visit_FunctionDef(self, n: ast35.FunctionDef) -> Node: + return self.do_func_def(n) + + # AsyncFunctionDef(identifier name, arguments args, + # stmt* body, expr* decorator_list, expr? returns, string? type_comment) + @with_line + def visit_AsyncFunctionDef(self, n: ast35.AsyncFunctionDef) -> Node: + return self.do_func_def(n, is_coroutine=True) + + def do_func_def(self, n: Union[ast35.FunctionDef, ast35.AsyncFunctionDef], + is_coroutine: bool = False) -> Node: + """Helper shared between visit_FunctionDef and visit_AsyncFunctionDef.""" args = self.transform_args(n.args, n.lineno) arg_kinds = [arg.kind for arg in args] @@ -285,6 +299,9 @@ def visit_FunctionDef(self, n: ast35.FunctionDef) -> Node: args, self.as_block(n.body, n.lineno), func_type) + if is_coroutine: + # A coroutine is also a generator, mostly for internal reasons. + func_def.is_generator = func_def.is_coroutine = True if func_type is not None: func_type.definition = func_def func_type.line = n.lineno @@ -345,9 +362,6 @@ def make_argument(arg: ast35.arg, default: Optional[ast35.expr], kind: int) -> A return new_args - # TODO: AsyncFunctionDef(identifier name, arguments args, - # stmt* body, expr* decorator_list, expr? returns, string? type_comment) - def stringify_name(self, n: ast35.AST) -> str: if isinstance(n, ast35.Name): return n.id @@ -419,7 +433,16 @@ def visit_For(self, n: ast35.For) -> Node: self.as_block(n.body, n.lineno), self.as_block(n.orelse, n.lineno)) - # TODO: AsyncFor(expr target, expr iter, stmt* body, stmt* orelse) + # AsyncFor(expr target, expr iter, stmt* body, stmt* orelse) + @with_line + def visit_AsyncFor(self, n: ast35.AsyncFor) -> Node: + r = ForStmt(self.visit(n.target), + self.visit(n.iter), + self.as_block(n.body, n.lineno), + self.as_block(n.orelse, n.lineno)) + r.is_async = True + return r + # While(expr test, stmt* body, stmt* orelse) @with_line def visit_While(self, n: ast35.While) -> Node: @@ -441,7 +464,14 @@ def visit_With(self, n: ast35.With) -> Node: [self.visit(i.optional_vars) for i in n.items], self.as_block(n.body, n.lineno)) - # TODO: AsyncWith(withitem* items, stmt* body) + # AsyncWith(withitem* items, stmt* body) + @with_line + def visit_AsyncWith(self, n: ast35.AsyncWith) -> Node: + r = WithStmt([self.visit(i.context_expr) for i in n.items], + [self.visit(i.optional_vars) for i in n.items], + self.as_block(n.body, n.lineno)) + r.is_async = True + return r # Raise(expr? exc, expr? cause) @with_line @@ -628,7 +658,11 @@ def visit_GeneratorExp(self, n: ast35.GeneratorExp) -> GeneratorExpr: iters, ifs_list) - # TODO: Await(expr value) + # Await(expr value) + @with_line + def visit_Await(self, n: ast35.Await) -> Node: + v = self.visit(n.value) + return AwaitExpr(v) # Yield(expr? value) @with_line diff --git a/mypy/messages.py b/mypy/messages.py index d37dcdb2be77..b4240f006dc1 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -40,6 +40,11 @@ INCOMPATIBLE_TYPES = 'Incompatible types' INCOMPATIBLE_TYPES_IN_ASSIGNMENT = 'Incompatible types in assignment' INCOMPATIBLE_REDEFINITION = 'Incompatible redefinition' +INCOMPATIBLE_TYPES_IN_AWAIT = 'Incompatible types in await' +INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AENTER = 'Incompatible types in "async with" for __aenter__' +INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AEXIT = 'Incompatible types in "async with" for __aexit__' +INCOMPATIBLE_TYPES_IN_ASYNC_FOR = 'Incompatible types in "async for"' + INCOMPATIBLE_TYPES_IN_YIELD = 'Incompatible types in yield' INCOMPATIBLE_TYPES_IN_YIELD_FROM = 'Incompatible types in "yield from"' INCOMPATIBLE_TYPES_IN_STR_INTERPOLATION = 'Incompatible types in string interpolation' @@ -57,6 +62,7 @@ INCOMPATIBLE_VALUE_TYPE = 'Incompatible dictionary value type' NEED_ANNOTATION_FOR_VAR = 'Need type annotation for variable' ITERABLE_EXPECTED = 'Iterable expected' +ASYNC_ITERABLE_EXPECTED = 'AsyncIterable expected' INCOMPATIBLE_TYPES_IN_FOR = 'Incompatible types in for statement' INCOMPATIBLE_ARRAY_VAR_ARGS = 'Incompatible variable arguments in call' INVALID_SLICE_INDEX = 'Slice index must be an integer or None' diff --git a/mypy/nodes.py b/mypy/nodes.py index fe4da0d227d4..cc77c8b82c57 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -416,6 +416,7 @@ class FuncItem(FuncBase): # Is this an overload variant of function with more than one overload variant? is_overload = False is_generator = False # Contains a yield statement? + is_coroutine = False # Defined using 'async def' syntax? is_static = False # Uses @staticmethod? is_class = False # Uses @classmethod? # Variants of function with type variables with values expanded @@ -486,6 +487,7 @@ def serialize(self) -> JsonDict: 'is_property': self.is_property, 'is_overload': self.is_overload, 'is_generator': self.is_generator, + 'is_coroutine': self.is_coroutine, 'is_static': self.is_static, 'is_class': self.is_class, 'is_decorated': self.is_decorated, @@ -507,6 +509,7 @@ def deserialize(cls, data: JsonDict) -> 'FuncDef': ret.is_property = data['is_property'] ret.is_overload = data['is_overload'] ret.is_generator = data['is_generator'] + ret.is_coroutine = data['is_coroutine'] ret.is_static = data['is_static'] ret.is_class = data['is_class'] ret.is_decorated = data['is_decorated'] @@ -798,6 +801,7 @@ class ForStmt(Statement): expr = None # type: Expression body = None # type: Block else_body = None # type: Block + is_async = False # True if `async for ...` (PEP 492, Python 3.5) def __init__(self, index: Expression, expr: Expression, body: Block, else_body: Block) -> None: @@ -908,6 +912,7 @@ class WithStmt(Statement): expr = None # type: List[Expression] target = None # type: List[Expression] body = None # type: Block + is_async = False # True if `async with ...` (PEP 492, Python 3.5) def __init__(self, expr: List[Expression], target: List[Expression], body: Block) -> None: @@ -1705,6 +1710,18 @@ def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit__promote_expr(self) +class AwaitExpr(Node): + """Await expression (await ...).""" + + expr = None # type: Node + + def __init__(self, expr: Node) -> None: + self.expr = expr + + def accept(self, visitor: NodeVisitor[T]) -> T: + return visitor.visit_await_expr(self) + + # Constants diff --git a/mypy/parse.py b/mypy/parse.py index 3007902fde6f..d76ca55c75a3 100644 --- a/mypy/parse.py +++ b/mypy/parse.py @@ -957,6 +957,10 @@ def parse_statement(self) -> Tuple[Node, bool]: stmt = self.parse_exec_stmt() else: stmt = self.parse_expression_or_assignment() + if ts == 'async' and self.current_str() == 'def': + self.parse_error_at(self.current(), + reason='Use --fast-parser to parse code using "async def"') + raise ParseError() if stmt is not None: stmt.set_line(t) return stmt, is_simple diff --git a/mypy/semanal.py b/mypy/semanal.py index 99cba07d0e4f..ce9281de02fe 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -62,16 +62,16 @@ ComparisonExpr, StarExpr, ARG_POS, ARG_NAMED, MroError, type_aliases, YieldFromExpr, NamedTupleExpr, NonlocalDecl, SetComprehension, DictionaryComprehension, TYPE_ALIAS, TypeAliasExpr, - YieldExpr, ExecStmt, Argument, BackquoteExpr, ImportBase, COVARIANT, CONTRAVARIANT, + YieldExpr, ExecStmt, Argument, BackquoteExpr, ImportBase, AwaitExpr, IntExpr, FloatExpr, UnicodeExpr, - INVARIANT, UNBOUND_IMPORTED + COVARIANT, CONTRAVARIANT, INVARIANT, UNBOUND_IMPORTED, ) from mypy.visitor import NodeVisitor from mypy.traverser import TraverserVisitor from mypy.errors import Errors, report_internal_error from mypy.types import ( NoneTyp, CallableType, Overloaded, Instance, Type, TypeVarType, AnyType, - FunctionLike, UnboundType, TypeList, ErrorType, TypeVarDef, + FunctionLike, UnboundType, TypeList, ErrorType, TypeVarDef, Void, replace_leading_arg_type, TupleType, UnionType, StarType, EllipsisType ) from mypy.nodes import function_type, implicit_module_attrs @@ -314,6 +314,13 @@ def visit_func_def(self, defn: FuncDef) -> None: # Second phase of analysis for function. self.errors.push_function(defn.name()) self.analyze_function(defn) + if defn.is_coroutine and isinstance(defn.type, CallableType): + # A coroutine defined as `async def foo(...) -> T: ...` + # has external return type `Awaitable[T]`. + defn.type = defn.type.copy_modified( + ret_type=Instance( + self.named_type_or_none('typing.Awaitable').type, + [defn.type.ret_type])) self.errors.pop_function() def prepare_method_signature(self, func: FuncDef) -> None: @@ -1815,7 +1822,10 @@ def visit_yield_from_expr(self, e: YieldFromExpr) -> None: if not self.is_func_scope(): # not sure self.fail("'yield from' outside function", e, True, blocker=True) else: - self.function_stack[-1].is_generator = True + if self.function_stack[-1].is_coroutine: + self.fail("'yield from' in async function", e, True, blocker=True) + else: + self.function_stack[-1].is_generator = True if e.expr: e.expr.accept(self) @@ -2068,10 +2078,20 @@ def visit_yield_expr(self, expr: YieldExpr) -> None: if not self.is_func_scope(): self.fail("'yield' outside function", expr, True, blocker=True) else: - self.function_stack[-1].is_generator = True + if self.function_stack[-1].is_coroutine: + self.fail("'yield' in async function", expr, True, blocker=True) + else: + self.function_stack[-1].is_generator = True if expr.expr: expr.expr.accept(self) + def visit_await_expr(self, expr: AwaitExpr) -> None: + if not self.is_func_scope(): + self.fail("'await' outside function", expr) + elif not self.function_stack[-1].is_coroutine: + self.fail("'await' outside coroutine ('async def')", expr) + expr.expr.accept(self) + # # Helpers # diff --git a/mypy/strconv.py b/mypy/strconv.py index 8d2c0845d70d..cb48f8ec045c 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -199,8 +199,10 @@ def visit_while_stmt(self, o): return self.dump(a, o) def visit_for_stmt(self, o): - a = [o.index] - a.extend([o.expr, o.body]) + a = [] + if o.is_async: + a.append(('Async', '')) + a.extend([o.index, o.expr, o.body]) if o.else_body: a.append(('Else', o.else_body.body)) return self.dump(a, o) @@ -243,6 +245,9 @@ def visit_yield_from_stmt(self, o): def visit_yield_expr(self, o): return self.dump([o.expr], o) + def visit_await_expr(self, o): + return self.dump([o.expr], o) + def visit_del_stmt(self, o): return self.dump([o.expr], o) @@ -264,6 +269,8 @@ def visit_try_stmt(self, o): def visit_with_stmt(self, o): a = [] + if o.is_async: + a.append(('Async', '')) for i in range(len(o.expr)): a.append(('Expr', [o.expr[i]])) if o.target[i]: diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py index 0136e5a9f147..ec054264c322 100644 --- a/mypy/test/testcheck.py +++ b/mypy/test/testcheck.py @@ -63,6 +63,7 @@ 'check-optional.test', 'check-fastparse.test', 'check-warnings.test', + 'check-async-await.test', ] diff --git a/mypy/treetransform.py b/mypy/treetransform.py index 829b86dc4793..f05232586b14 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -19,7 +19,7 @@ ComparisonExpr, TempNode, StarExpr, YieldFromExpr, NamedTupleExpr, NonlocalDecl, SetComprehension, DictionaryComprehension, ComplexExpr, TypeAliasExpr, EllipsisExpr, - YieldExpr, ExecStmt, Argument, BackquoteExpr + YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr, ) from mypy.types import Type, FunctionLike, Instance from mypy.visitor import NodeVisitor @@ -339,6 +339,9 @@ def visit_yield_from_expr(self, node: YieldFromExpr) -> Node: def visit_yield_expr(self, node: YieldExpr) -> Node: return YieldExpr(self.node(node.expr)) + def visit_await_expr(self, node: AwaitExpr) -> Node: + return AwaitExpr(self.node(node.expr)) + def visit_call_expr(self, node: CallExpr) -> Node: return CallExpr(self.node(node.callee), self.nodes(node.args), diff --git a/mypy/visitor.py b/mypy/visitor.py index b1e1b883a109..43e7c161ea6d 100644 --- a/mypy/visitor.py +++ b/mypy/visitor.py @@ -228,5 +228,8 @@ def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> T: def visit__promote_expr(self, o: 'mypy.nodes.PromoteExpr') -> T: pass + def visit_await_expr(self, o: 'mypy.nodes.AwaitExpr') -> T: + pass + def visit_temp_node(self, o: 'mypy.nodes.TempNode') -> T: pass diff --git a/test-data/unit/check-async-await.test b/test-data/unit/check-async-await.test new file mode 100644 index 000000000000..581e35d99957 --- /dev/null +++ b/test-data/unit/check-async-await.test @@ -0,0 +1,295 @@ +-- Tests for async def and await (PEP 492) +-- --------------------------------------- + +[case testAsyncDefPass] +# options: fast_parser +async def f() -> int: + pass +[builtins fixtures/async_await.py] + +[case testAsyncDefReturn] +# options: fast_parser +async def f() -> int: + return 0 +reveal_type(f()) # E: Revealed type is 'typing.Awaitable[builtins.int]' +[builtins fixtures/async_await.py] + +[case testAwaitCoroutine] +# options: fast_parser +async def f() -> int: + x = await f() + reveal_type(x) # E: Revealed type is 'builtins.int*' + return x +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": + +[case testAwaitDefaultContext] +# options: fast_parser +from typing import TypeVar +T = TypeVar('T') +async def f(x: T) -> T: + y = await f(x) + reveal_type(y) + return y +[out] +main: note: In function "f": +main:6: error: Revealed type is 'T`-1' + +[case testAwaitAnyContext] +# options: fast_parser +from typing import Any, TypeVar +T = TypeVar('T') +async def f(x: T) -> T: + y = await f(x) # type: Any + reveal_type(y) + return y +[out] +main: note: In function "f": +main:6: error: Revealed type is 'Any' + +[case testAwaitExplicitContext] +# options: fast_parser +from typing import TypeVar +T = TypeVar('T') +async def f(x: T) -> T: + y = await f(x) # type: int + reveal_type(y) +[out] +main: note: In function "f": +main:5: error: Argument 1 to "f" has incompatible type "T"; expected "int" +main:6: error: Revealed type is 'builtins.int' + +[case testAwaitGeneratorError] +# options: fast_parser +from typing import Any, Generator +def g() -> Generator[int, None, str]: + yield 0 + return '' +async def f() -> int: + x = await g() + return x +[out] +main: note: In function "f": +main:7: error: Incompatible types in await (actual type Generator[int, None, str], expected type "Awaitable") + +[case testAwaitIteratorError] +# options: fast_parser +from typing import Any, Iterator +def g() -> Iterator[Any]: + yield +async def f() -> int: + x = await g() + return x +[out] +main: note: In function "f": +main:6: error: Incompatible types in await (actual type Iterator[Any], expected type "Awaitable") + +[case testAwaitArgumentError] +# options: fast_parser +def g() -> int: + return 0 +async def f() -> int: + x = await g() + return x +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": +main:5: error: Incompatible types in await (actual type "int", expected type "Awaitable") + +[case testAwaitResultError] +# options: fast_parser +async def g() -> int: + return 0 +async def f() -> str: + x = await g() # type: str +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": +main:5: error: Incompatible types in assignment (expression has type "int", variable has type "str") + +[case testAwaitReturnError] +# options: fast_parser +async def g() -> int: + return 0 +async def f() -> str: + x = await g() + return x +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": +main:6: error: Incompatible return value type (got "int", expected "str") + +[case testAsyncFor] +# options: fast_parser +from typing import AsyncIterator +class C(AsyncIterator[int]): + async def __anext__(self) -> int: return 0 +async def f() -> None: + async for x in C(): + reveal_type(x) # E: Revealed type is 'builtins.int*' +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": + +[case testAsyncForError] +# options: fast_parser +from typing import AsyncIterator +async def f() -> None: + async for x in [1]: + pass +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": +main:4: error: AsyncIterable expected +main:4: error: List[int] has no attribute "__aiter__" + +[case testAsyncWith] +# options: fast_parser +class C: + async def __aenter__(self) -> int: pass + async def __aexit__(self, x, y, z) -> None: pass +async def f() -> None: + async with C() as x: + reveal_type(x) # E: Revealed type is 'builtins.int*' +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": + +[case testAsyncWithError] +# options: fast_parser +class C: + def __enter__(self) -> int: pass + def __exit__(self, x, y, z) -> None: pass +async def f() -> None: + async with C() as x: + pass +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": +main:6: error: "C" has no attribute "__aenter__"; maybe "__enter__"? +main:6: error: "C" has no attribute "__aexit__"; maybe "__exit__"? + +[case testAsyncWithErrorBadAenter] +# options: fast_parser +class C: + def __aenter__(self) -> int: pass + async def __aexit__(self, x, y, z) -> None: pass +async def f() -> None: + async with C() as x: # E: Incompatible types in "async with" for __aenter__ (actual type "int", expected type "Awaitable") + pass +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": + +[case testAsyncWithErrorBadAenter2] +# options: fast_parser +class C: + def __aenter__(self) -> None: pass + async def __aexit__(self, x, y, z) -> None: pass +async def f() -> None: + async with C() as x: # E: "__aenter__" of "C" does not return a value + pass +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": + +[case testAsyncWithErrorBadAexit] +# options: fast_parser +class C: + async def __aenter__(self) -> int: pass + def __aexit__(self, x, y, z) -> int: pass +async def f() -> None: + async with C() as x: # E: Incompatible types in "async with" for __aexit__ (actual type "int", expected type "Awaitable") + pass +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": + +[case testAsyncWithErrorBadAexit2] +# options: fast_parser +class C: + async def __aenter__(self) -> int: pass + def __aexit__(self, x, y, z) -> None: pass +async def f() -> None: + async with C() as x: # E: "__aexit__" of "C" does not return a value + pass +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": + +[case testNoYieldInAsyncDef] +# options: fast_parser +async def f(): + yield None +async def g(): + yield +async def h(): + x = yield +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": +main:3: error: 'yield' in async function +main: note: In function "g": +main:5: error: 'yield' in async function +main: note: In function "h": +main:7: error: 'yield' in async function + +[case testNoYieldFromInAsyncDef] +# options: fast_parser +async def f(): + yield from [] +async def g(): + x = yield from [] +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": +main:3: error: 'yield from' in async function +main: note: In function "g": +main:5: error: 'yield from' in async function + +[case testNoAsyncDefInPY2_python2] +# options: fast_parser +async def f(): # E: invalid syntax + pass + +[case testYieldFromNoAwaitable] +# options: fast_parser +from typing import Any, Generator +async def f() -> str: + return '' +def g() -> Generator[Any, None, str]: + x = yield from f() + return x +[builtins fixtures/async_await.py] +[out] +main: note: In function "g": +main:6: error: "yield from" can't be applied to Awaitable[str] + +[case testAwaitableSubclass] +# options: fast_parser +from typing import Any, AsyncIterator, Awaitable, Generator +class A(Awaitable[int]): + def __await__(self) -> Generator[Any, None, int]: + yield + return 0 +class C: + def __aenter__(self) -> A: + return A() + def __aexit__(self, *a) -> A: + return A() +class I(AsyncIterator[int]): + def __aiter__(self) -> 'I': + return self + def __anext__(self) -> A: + return A() +async def main() -> None: + x = await A() + reveal_type(x) # E: Revealed type is 'builtins.int' + async with C() as y: + reveal_type(y) # E: Revealed type is 'builtins.int' + async for z in I(): + reveal_type(z) # E: Revealed type is 'builtins.int' +[builtins fixtures/async_await.py] +[out] +main: note: In function "main": diff --git a/test-data/unit/fixtures/async_await.py b/test-data/unit/fixtures/async_await.py new file mode 100644 index 000000000000..7a166a07294c --- /dev/null +++ b/test-data/unit/fixtures/async_await.py @@ -0,0 +1,9 @@ +import typing +class object: + def __init__(self): pass +class type: pass +class function: pass +class int: pass +class str: pass +class list: pass +class tuple: pass diff --git a/test-data/unit/lib-stub/typing.py b/test-data/unit/lib-stub/typing.py index 09c76a7eb1bf..3e539f1f5e02 100644 --- a/test-data/unit/lib-stub/typing.py +++ b/test-data/unit/lib-stub/typing.py @@ -57,6 +57,19 @@ def close(self) -> None: pass @abstractmethod def __iter__(self) -> 'Generator[T, U, V]': pass +class Awaitable(Generic[T]): + @abstractmethod + def __await__(self) -> Generator[Any, Any, T]: pass + +class AsyncIterable(Generic[T]): + @abstractmethod + def __aiter__(self) -> 'AsyncIterator[T]': pass + +class AsyncIterator(AsyncIterable[T], Generic[T]): + def __aiter__(self) -> 'AsyncIterator[T]': return self + @abstractmethod + def __anext__(self) -> Awaitable[T]: pass + class Sequence(Generic[T]): @abstractmethod def __getitem__(self, n: Any) -> T: pass diff --git a/test-data/unit/python2eval.test b/test-data/unit/python2eval.test index 4f9b633efc4d..944ce614eefa 100644 --- a/test-data/unit/python2eval.test +++ b/test-data/unit/python2eval.test @@ -440,3 +440,11 @@ re.subn(upat, u'', u'')[0] + u'' re.subn(ure, lambda m: u'', u'')[0] + u'' re.subn(upat, lambda m: u'', u'')[0] + u'' [out] + +[case testYieldRegressionTypingAwaitable_python2] +# Make sure we don't reference typing.Awaitable in Python 2 mode. +def g() -> int: + yield +[out] +_program.py: note: In function "g": +_program.py:2: error: The return type of a generator function should be "Generator" or one of its supertypes