From a8fbc10576cd783f6e29ec20ce986025a7c0c6f9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 11 May 2019 20:39:22 +0200 Subject: [PATCH 1/4] Add support for overloading coroutines. When applying @overload to @coroutine, update the return type of the overload to AwaitableGenerator like the underlying coroutines. Fix #6802. --- mypy/checker.py | 38 ++++++++++++++++++--------- mypy/message_registry.py | 2 ++ test-data/unit/check-overloading.test | 15 +++++++++++ 3 files changed, 43 insertions(+), 12 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 6d0d6e5bc86b..5c29113a1021 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -430,6 +430,7 @@ def visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None: def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None: num_abstract = 0 + num_awaitable_coroutine = 0 if not defn.items: # In this case we have already complained about none of these being # valid overloads. @@ -445,8 +446,19 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None: self.check_func_item(fdef.func, name=fdef.func.name()) if fdef.func.is_abstract: num_abstract += 1 + if fdef.func.is_awaitable_coroutine: + num_awaitable_coroutine += 1 if num_abstract not in (0, len(defn.items)): self.fail(message_registry.INCONSISTENT_ABSTRACT_OVERLOAD, defn) + if num_awaitable_coroutine not in (0, len(defn.items)): + self.fail(message_registry.INCONSISTENT_COROUTINE_OVERLOAD, defn) + # If items contains coroutines and check_func_item fixed their type, + # also fix the overload type. + if num_awaitable_coroutine: + defn.type = Overloaded([ + self.get_awaitable_coroutine_return_type(fdef.func, typ) + for fdef, typ in zip(defn.items, defn.type.items()) + ]) if defn.impl: defn.impl.accept(self) if defn.info: @@ -853,18 +865,7 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str]) if defn.is_awaitable_coroutine: # Update the return type to AwaitableGenerator. # (This doesn't exist in typing.py, only in typing.pyi.) - t = typ.ret_type - c = defn.is_coroutine - ty = self.get_generator_yield_type(t, c) - tc = self.get_generator_receive_type(t, c) - if c: - tr = self.get_coroutine_return_type(t) - else: - tr = self.get_generator_return_type(t, c) - ret_type = self.named_generic_type('typing.AwaitableGenerator', - [ty, tc, tr, t]) - typ = typ.copy_modified(ret_type=ret_type) - defn.type = typ + defn.type = self.get_awaitable_coroutine_return_type(defn, typ) # Push return type. self.return_types.append(typ.ret_type) @@ -963,6 +964,19 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str]) self.binder = old_binder + def get_awaitable_coroutine_return_type(self, defn: FuncItem, typ: CallableType): + t = typ.ret_type + c = defn.is_coroutine + ty = self.get_generator_yield_type(t, c) + tc = self.get_generator_receive_type(t, c) + if c: + tr = self.get_coroutine_return_type(t) + else: + tr = self.get_generator_return_type(t, c) + ret_type = self.named_generic_type('typing.AwaitableGenerator', + [ty, tc, tr, t]) + return typ.copy_modified(ret_type=ret_type) + def is_forward_op_method(self, method_name: str) -> bool: if self.options.python_version[0] == 2 and method_name == '__div__': return True diff --git a/mypy/message_registry.py b/mypy/message_registry.py index 311e06e2a3ae..16e4da3a56a7 100644 --- a/mypy/message_registry.py +++ b/mypy/message_registry.py @@ -62,6 +62,8 @@ CANNOT_ASSIGN_TO_TYPE = 'Cannot assign to a type' # type: Final INCONSISTENT_ABSTRACT_OVERLOAD = \ 'Overloaded method has both abstract and non-abstract variants' # type: Final +INCONSISTENT_COROUTINE_OVERLOAD = \ + 'Overloaded method has both coroutine and non-coroutine variants' # type: Final MULTIPLE_OVERLOADS_REQUIRED = 'Single overload definition, multiple required' # type: Final READ_ONLY_PROPERTY_OVERRIDES_READ_WRITE = \ 'Read-only property cannot override read-write property' # type: Final diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 22c45f75a0a2..364e47d535b3 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -5001,3 +5001,18 @@ def f(x): reveal_type(f(g([]))) # E: Revealed type is 'builtins.list[builtins.int]' [builtins fixtures/list.pyi] + +[case testTypeCheckOverloadCoroutine] +from asyncio import coroutine +from typing import overload +@overload +@coroutine +def f(x: int) -> None: ... +@overload +@coroutine +def f(x: str) -> None: ... +def f(x): pass + +reveal_type(f) # E: Revealed type is 'Overload(def (x: builtins.int) -> typing.AwaitableGenerator[Any, Any, Any, None], def (x: builtins.str) -> typing.AwaitableGenerator[Any, Any, Any, None])' +[builtins fixtures/dict.pyi] +[typing fixtures/typing-full.pyi] From 75bf1a1a627155b74ccda2ab86e275c81a84314b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 11 May 2019 21:10:40 +0200 Subject: [PATCH 2/4] Fix typing. --- mypy/checker.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 5c29113a1021..4f936d978d0d 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -455,10 +455,13 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None: # If items contains coroutines and check_func_item fixed their type, # also fix the overload type. if num_awaitable_coroutine: - defn.type = Overloaded([ - self.get_awaitable_coroutine_return_type(fdef.func, typ) - for fdef, typ in zip(defn.items, defn.type.items()) - ]) + assert isinstance(defn.type, Overloaded) + types = [] + for fdef, typ in zip(defn.items, defn.type.items()): + assert isinstance(fdef, Decorator) + types.append( + self.get_awaitable_coroutine_return_type(fdef.func, typ)) + defn.type = Overloaded(types) if defn.impl: defn.impl.accept(self) if defn.info: @@ -964,7 +967,9 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str]) self.binder = old_binder - def get_awaitable_coroutine_return_type(self, defn: FuncItem, typ: CallableType): + def get_awaitable_coroutine_return_type(self, + defn: FuncItem, + typ: CallableType) -> CallableType: t = typ.ret_type c = defn.is_coroutine ty = self.get_generator_yield_type(t, c) From f780fe66d968e7a7e7f70c9a4cf4263934dd3645 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 11 May 2019 21:23:06 +0200 Subject: [PATCH 3/4] Use the same imports as other tests. --- test-data/unit/check-overloading.test | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 364e47d535b3..800017b47f60 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -5003,7 +5003,7 @@ reveal_type(f(g([]))) # E: Revealed type is 'builtins.list[builtins.int]' [builtins fixtures/list.pyi] [case testTypeCheckOverloadCoroutine] -from asyncio import coroutine +from types import coroutine from typing import overload @overload @coroutine From 9a3790e9de8e5b181ddc5f231ecd07bad3b1d980 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 11 May 2019 21:27:37 +0200 Subject: [PATCH 4/4] Fix refactoring error. --- mypy/checker.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index 4f936d978d0d..06667bcb1e30 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -868,7 +868,8 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str]) if defn.is_awaitable_coroutine: # Update the return type to AwaitableGenerator. # (This doesn't exist in typing.py, only in typing.pyi.) - defn.type = self.get_awaitable_coroutine_return_type(defn, typ) + typ = self.get_awaitable_coroutine_return_type(defn, typ) + defn.type = typ # Push return type. self.return_types.append(typ.ret_type)