diff --git a/mypy/checker.py b/mypy/checker.py index 6d0d6e5bc86b..06667bcb1e30 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -430,6 +430,7 @@ def visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None: def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None: num_abstract = 0 + num_awaitable_coroutine = 0 if not defn.items: # In this case we have already complained about none of these being # valid overloads. @@ -445,8 +446,22 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None: self.check_func_item(fdef.func, name=fdef.func.name()) if fdef.func.is_abstract: num_abstract += 1 + if fdef.func.is_awaitable_coroutine: + num_awaitable_coroutine += 1 if num_abstract not in (0, len(defn.items)): self.fail(message_registry.INCONSISTENT_ABSTRACT_OVERLOAD, defn) + if num_awaitable_coroutine not in (0, len(defn.items)): + self.fail(message_registry.INCONSISTENT_COROUTINE_OVERLOAD, defn) + # If items contains coroutines and check_func_item fixed their type, + # also fix the overload type. + if num_awaitable_coroutine: + assert isinstance(defn.type, Overloaded) + types = [] + for fdef, typ in zip(defn.items, defn.type.items()): + assert isinstance(fdef, Decorator) + types.append( + self.get_awaitable_coroutine_return_type(fdef.func, typ)) + defn.type = Overloaded(types) if defn.impl: defn.impl.accept(self) if defn.info: @@ -853,17 +868,7 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str]) if defn.is_awaitable_coroutine: # Update the return type to AwaitableGenerator. # (This doesn't exist in typing.py, only in typing.pyi.) - t = typ.ret_type - c = defn.is_coroutine - ty = self.get_generator_yield_type(t, c) - tc = self.get_generator_receive_type(t, c) - if c: - tr = self.get_coroutine_return_type(t) - else: - tr = self.get_generator_return_type(t, c) - ret_type = self.named_generic_type('typing.AwaitableGenerator', - [ty, tc, tr, t]) - typ = typ.copy_modified(ret_type=ret_type) + typ = self.get_awaitable_coroutine_return_type(defn, typ) defn.type = typ # Push return type. @@ -963,6 +968,21 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str]) self.binder = old_binder + def get_awaitable_coroutine_return_type(self, + defn: FuncItem, + typ: CallableType) -> CallableType: + t = typ.ret_type + c = defn.is_coroutine + ty = self.get_generator_yield_type(t, c) + tc = self.get_generator_receive_type(t, c) + if c: + tr = self.get_coroutine_return_type(t) + else: + tr = self.get_generator_return_type(t, c) + ret_type = self.named_generic_type('typing.AwaitableGenerator', + [ty, tc, tr, t]) + return typ.copy_modified(ret_type=ret_type) + def is_forward_op_method(self, method_name: str) -> bool: if self.options.python_version[0] == 2 and method_name == '__div__': return True diff --git a/mypy/message_registry.py b/mypy/message_registry.py index 311e06e2a3ae..16e4da3a56a7 100644 --- a/mypy/message_registry.py +++ b/mypy/message_registry.py @@ -62,6 +62,8 @@ CANNOT_ASSIGN_TO_TYPE = 'Cannot assign to a type' # type: Final INCONSISTENT_ABSTRACT_OVERLOAD = \ 'Overloaded method has both abstract and non-abstract variants' # type: Final +INCONSISTENT_COROUTINE_OVERLOAD = \ + 'Overloaded method has both coroutine and non-coroutine variants' # type: Final MULTIPLE_OVERLOADS_REQUIRED = 'Single overload definition, multiple required' # type: Final READ_ONLY_PROPERTY_OVERRIDES_READ_WRITE = \ 'Read-only property cannot override read-write property' # type: Final diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 22c45f75a0a2..800017b47f60 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -5001,3 +5001,18 @@ def f(x): reveal_type(f(g([]))) # E: Revealed type is 'builtins.list[builtins.int]' [builtins fixtures/list.pyi] + +[case testTypeCheckOverloadCoroutine] +from types import coroutine +from typing import overload +@overload +@coroutine +def f(x: int) -> None: ... +@overload +@coroutine +def f(x: str) -> None: ... +def f(x): pass + +reveal_type(f) # E: Revealed type is 'Overload(def (x: builtins.int) -> typing.AwaitableGenerator[Any, Any, Any, None], def (x: builtins.str) -> typing.AwaitableGenerator[Any, Any, Any, None])' +[builtins fixtures/dict.pyi] +[typing fixtures/typing-full.pyi]