Skip to content

Commit ccdbeab

Browse files
committed
Merge pull request #1128 from sakana/master
Generalize yield statement function type
2 parents fa98a4e + fe81772 commit ccdbeab

File tree

5 files changed

+46
-6
lines changed

5 files changed

+46
-6
lines changed

mypy/checker.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,10 +1458,18 @@ def count_nested_types(self, typ: Instance, check_type: str) -> int:
14581458
def visit_yield_stmt(self, s: YieldStmt) -> Type:
14591459
return_type = self.return_types[-1]
14601460
if isinstance(return_type, Instance):
1461-
if return_type.type.fullname() != 'typing.Iterator':
1461+
if not is_subtype(self.named_generic_type('typing.Generator',
1462+
[AnyType(), AnyType(), AnyType()]),
1463+
return_type):
14621464
self.fail(messages.INVALID_RETURN_TYPE_FOR_YIELD, s)
14631465
return None
1464-
expected_item_type = return_type.args[0]
1466+
if return_type.args:
1467+
expected_item_type = return_type.args[0]
1468+
else:
1469+
# if the declared supertype of `Generator` has no type
1470+
# parameters (i.e. is `object`), then the yielded values can't
1471+
# be accessed so any type is acceptable.
1472+
expected_item_type = AnyType()
14651473
elif isinstance(return_type, AnyType):
14661474
expected_item_type = AnyType()
14671475
else:

mypy/messages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
INVALID_EXCEPTION = 'Exception must be derived from BaseException'
3131
INVALID_EXCEPTION_TYPE = 'Exception type must be derived from BaseException'
3232
INVALID_RETURN_TYPE_FOR_YIELD = \
33-
'Iterator function return type expected for "yield"'
33+
'The return type of a generator function should be "Generator" or one of its supertypes'
3434
INVALID_RETURN_TYPE_FOR_YIELD_FROM = \
3535
'Iterable function return type expected for "yield from"'
3636
INCOMPATIBLE_TYPES = 'Incompatible types'

mypy/test/data/check-statements.test

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,26 @@ def f() -> Iterator[int]:
647647
[out]
648648
main: note: In function "f":
649649

650+
[case testYieldInFunctionReturningGenerator]
651+
from typing import Generator
652+
def f() -> Generator[int, None, None]:
653+
yield 1
654+
[builtins fixtures/for.py]
655+
[out]
656+
657+
[case testYieldInFunctionReturningIterable]
658+
from typing import Iterable
659+
def f() -> Iterable[int]:
660+
yield 1
661+
[builtins fixtures/for.py]
662+
[out]
663+
664+
[case testYieldInFunctionReturningObject]
665+
def f() -> object:
666+
yield 1
667+
[builtins fixtures/for.py]
668+
[out]
669+
650670
[case testYieldInFunctionReturningAny]
651671
from typing import Any
652672
def f() -> Any:
@@ -656,7 +676,7 @@ def f() -> Any:
656676
[case testYieldInFunctionReturningFunction]
657677
from typing import Callable
658678
def f() -> Callable[[], None]:
659-
yield object() # E: Iterator function return type expected for "yield"
679+
yield object() # E: The return type of a generator function should be "Generator" or one of its supertypes
660680
[out]
661681
main: note: In function "f":
662682

@@ -668,7 +688,7 @@ def f():
668688
[case testWithInvalidInstanceReturnType]
669689
import typing
670690
def f() -> int:
671-
yield 1 # E: Iterator function return type expected for "yield"
691+
yield 1 # E: The return type of a generator function should be "Generator" or one of its supertypes
672692
[builtins fixtures/for.py]
673693
[out]
674694
main: note: In function "f":

mypy/test/data/fixtures/for.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# builtins stub used in for statement test cases
22

3-
from typing import TypeVar, Generic, Iterable, Iterator
3+
from typing import TypeVar, Generic, Iterable, Iterator, Generator
44
from abc import abstractmethod, ABCMeta
55

66
t = TypeVar('t')

mypy/test/data/lib-stub/typing.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
Set = 0
2424

2525
T = TypeVar('T')
26+
U = TypeVar('U')
27+
V = TypeVar('V')
2628

2729
class Container(Generic[T]):
2830
@abstractmethod
@@ -41,6 +43,16 @@ class Iterator(Iterable[T], Generic[T]):
4143
@abstractmethod
4244
def __next__(self) -> T: pass
4345

46+
class Generator(Iterator[T], Generic[T, U, V]):
47+
@abstractmethod
48+
def send(self, value: U) -> T: pass
49+
50+
@abstractmethod
51+
def throw(self, typ: Any, val: Any=None, tb=None) -> None: pass
52+
53+
@abstractmethod
54+
def close(self) -> None: pass
55+
4456
class Sequence(Generic[T]):
4557
@abstractmethod
4658
def __getitem__(self, n: Any) -> T: pass

0 commit comments

Comments
 (0)