Skip to content

Commit 4a5356c

Browse files
JelleZijlstragvanrossum
authored andcommitted
Fix support for async for comprehensions (#2926)
Fixes #2924
1 parent 311d690 commit 4a5356c

File tree

8 files changed

+121
-18
lines changed

8 files changed

+121
-18
lines changed

mypy/checkexpr.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1888,8 +1888,13 @@ def visit_set_comprehension(self, e: SetComprehension) -> Type:
18881888
e.generator, 'builtins.set', '<set-comprehension>')
18891889

18901890
def visit_generator_expr(self, e: GeneratorExpr) -> Type:
1891-
return self.check_generator_or_comprehension(e, 'typing.Iterator',
1892-
'<generator>')
1891+
# If any of the comprehensions use async for, the expression will return an async generator
1892+
# object
1893+
if any(e.is_async):
1894+
typ = 'typing.AsyncIterator'
1895+
else:
1896+
typ = 'typing.Iterator'
1897+
return self.check_generator_or_comprehension(e, typ, '<generator>')
18931898

18941899
def check_generator_or_comprehension(self, gen: GeneratorExpr,
18951900
type_name: str,
@@ -1941,9 +1946,12 @@ def check_for_comp(self, e: Union[GeneratorExpr, DictionaryComprehension]) -> No
19411946
19421947
Note: This adds the type information derived from the condlists to the current binder.
19431948
"""
1944-
for index, sequence, conditions in zip(e.indices, e.sequences,
1945-
e.condlists):
1946-
sequence_type = self.chk.analyze_iterable_item_type(sequence)
1949+
for index, sequence, conditions, is_async in zip(e.indices, e.sequences,
1950+
e.condlists, e.is_async):
1951+
if is_async:
1952+
sequence_type = self.chk.analyze_async_iterable_item_type(sequence)
1953+
else:
1954+
sequence_type = self.chk.analyze_iterable_item_type(sequence)
19471955
self.chk.analyze_index_variables(index, sequence_type, True, e)
19481956
for condition in conditions:
19491957
self.accept(condition)

mypy/fastparse.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -774,22 +774,26 @@ def visit_DictComp(self, n: ast3.DictComp) -> DictionaryComprehension:
774774
targets = [self.visit(c.target) for c in n.generators]
775775
iters = [self.visit(c.iter) for c in n.generators]
776776
ifs_list = [self.translate_expr_list(c.ifs) for c in n.generators]
777+
is_async = [bool(c.is_async) for c in n.generators]
777778
return DictionaryComprehension(self.visit(n.key),
778779
self.visit(n.value),
779780
targets,
780781
iters,
781-
ifs_list)
782+
ifs_list,
783+
is_async)
782784

783785
# GeneratorExp(expr elt, comprehension* generators)
784786
@with_line
785787
def visit_GeneratorExp(self, n: ast3.GeneratorExp) -> GeneratorExpr:
786788
targets = [self.visit(c.target) for c in n.generators]
787789
iters = [self.visit(c.iter) for c in n.generators]
788790
ifs_list = [self.translate_expr_list(c.ifs) for c in n.generators]
791+
is_async = [bool(c.is_async) for c in n.generators]
789792
return GeneratorExpr(self.visit(n.elt),
790793
targets,
791794
iters,
792-
ifs_list)
795+
ifs_list,
796+
is_async)
793797

794798
# Await(expr value)
795799
@with_line

mypy/fastparse2.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,8 @@ def visit_DictComp(self, n: ast27.DictComp) -> DictionaryComprehension:
776776
self.visit(n.value),
777777
targets,
778778
iters,
779-
ifs_list)
779+
ifs_list,
780+
[False for _ in n.generators])
780781

781782
# GeneratorExp(expr elt, comprehension* generators)
782783
@with_line
@@ -787,7 +788,8 @@ def visit_GeneratorExp(self, n: ast27.GeneratorExp) -> GeneratorExpr:
787788
return GeneratorExpr(self.visit(n.elt),
788789
targets,
789790
iters,
790-
ifs_list)
791+
ifs_list,
792+
[False for _ in n.generators])
791793

792794
# Yield(expr? value)
793795
@with_line

mypy/nodes.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1585,14 +1585,17 @@ class GeneratorExpr(Expression):
15851585
left_expr = None # type: Expression
15861586
sequences = None # type: List[Expression]
15871587
condlists = None # type: List[List[Expression]]
1588+
is_async = None # type: List[bool]
15881589
indices = None # type: List[Lvalue]
15891590

15901591
def __init__(self, left_expr: Expression, indices: List[Lvalue],
1591-
sequences: List[Expression], condlists: List[List[Expression]]) -> None:
1592+
sequences: List[Expression], condlists: List[List[Expression]],
1593+
is_async: List[bool]) -> None:
15921594
self.left_expr = left_expr
15931595
self.sequences = sequences
15941596
self.condlists = condlists
15951597
self.indices = indices
1598+
self.is_async = is_async
15961599

15971600
def accept(self, visitor: ExpressionVisitor[T]) -> T:
15981601
return visitor.visit_generator_expr(self)
@@ -1629,15 +1632,18 @@ class DictionaryComprehension(Expression):
16291632
value = None # type: Expression
16301633
sequences = None # type: List[Expression]
16311634
condlists = None # type: List[List[Expression]]
1635+
is_async = None # type: List[bool]
16321636
indices = None # type: List[Lvalue]
16331637

16341638
def __init__(self, key: Expression, value: Expression, indices: List[Lvalue],
1635-
sequences: List[Expression], condlists: List[List[Expression]]) -> None:
1639+
sequences: List[Expression], condlists: List[List[Expression]],
1640+
is_async: List[bool]) -> None:
16361641
self.key = key
16371642
self.value = value
16381643
self.sequences = sequences
16391644
self.condlists = condlists
16401645
self.indices = indices
1646+
self.is_async = is_async
16411647

16421648
def accept(self, visitor: ExpressionVisitor[T]) -> T:
16431649
return visitor.visit_dictionary_comprehension(self)

mypy/parse.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1472,7 +1472,7 @@ def parse_generator_expr(self, left_expr: Expression) -> GeneratorExpr:
14721472
tok = self.current()
14731473
indices, sequences, condlists = self.parse_comp_for()
14741474

1475-
gen = GeneratorExpr(left_expr, indices, sequences, condlists)
1475+
gen = GeneratorExpr(left_expr, indices, sequences, condlists, [False for _ in indices])
14761476
gen.set_line(tok)
14771477
return gen
14781478

@@ -1560,7 +1560,8 @@ def parse_set_comprehension(self, expr: Expression) -> SetComprehension:
15601560
def parse_dict_comprehension(self, key: Expression, value: Expression,
15611561
colon: Token) -> DictionaryComprehension:
15621562
indices, sequences, condlists = self.parse_comp_for()
1563-
dic = DictionaryComprehension(key, value, indices, sequences, condlists)
1563+
dic = DictionaryComprehension(key, value, indices, sequences, condlists,
1564+
[False for _ in indices])
15641565
dic.set_line(colon)
15651566
self.expect('}')
15661567
return dic

mypy/treetransform.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -439,10 +439,11 @@ def visit_set_comprehension(self, node: SetComprehension) -> SetComprehension:
439439
def visit_dictionary_comprehension(self, node: DictionaryComprehension
440440
) -> DictionaryComprehension:
441441
return DictionaryComprehension(self.expr(node.key), self.expr(node.value),
442-
[self.expr(index) for index in node.indices],
443-
[self.expr(s) for s in node.sequences],
444-
[[self.expr(cond) for cond in conditions]
445-
for conditions in node.condlists])
442+
[self.expr(index) for index in node.indices],
443+
[self.expr(s) for s in node.sequences],
444+
[[self.expr(cond) for cond in conditions]
445+
for conditions in node.condlists],
446+
node.is_async)
446447

447448
def visit_generator_expr(self, node: GeneratorExpr) -> GeneratorExpr:
448449
return self.duplicate_generator(node)
@@ -452,7 +453,8 @@ def duplicate_generator(self, node: GeneratorExpr) -> GeneratorExpr:
452453
[self.expr(index) for index in node.indices],
453454
[self.expr(s) for s in node.sequences],
454455
[[self.expr(cond) for cond in conditions]
455-
for conditions in node.condlists])
456+
for conditions in node.condlists],
457+
node.is_async)
456458

457459
def visit_slice_expr(self, node: SliceExpr) -> SliceExpr:
458460
return SliceExpr(self.optional_expr(node.begin_index),

test-data/unit/check-async-await.test

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,79 @@ async def f() -> None:
168168
reveal_type(z) # E: Revealed type is 'Union[builtins.int, builtins.str]'
169169
[builtins fixtures/async_await.pyi]
170170

171+
[case testAsyncForComprehension]
172+
# flags: --fast-parser --python-version 3.6
173+
from typing import Generic, Iterable, TypeVar, AsyncIterator, Tuple
174+
175+
T = TypeVar('T')
176+
177+
class asyncify(Generic[T], AsyncIterator[T]):
178+
def __init__(self, iterable: Iterable[T]) -> None:
179+
self.iterable = iter(iterable)
180+
def __aiter__(self) -> AsyncIterator[T]:
181+
return self
182+
async def __anext__(self) -> T:
183+
try:
184+
return next(self.iterable)
185+
except StopIteration:
186+
raise StopAsyncIteration
187+
188+
async def listcomp(obj: Iterable[int]):
189+
lst = [i async for i in asyncify(obj)]
190+
reveal_type(lst) # E: Revealed type is 'builtins.list[builtins.int*]'
191+
lst2 = [i async for i in asyncify(obj) for j in obj]
192+
reveal_type(lst2) # E: Revealed type is 'builtins.list[builtins.int*]'
193+
194+
async def setcomp(obj: Iterable[int]):
195+
lst = {i async for i in asyncify(obj)}
196+
reveal_type(lst) # E: Revealed type is 'builtins.set[builtins.int*]'
197+
198+
async def dictcomp(obj: Iterable[Tuple[int, str]]):
199+
lst = {i: j async for i, j in asyncify(obj)}
200+
reveal_type(lst) # E: Revealed type is 'builtins.dict[builtins.int*, builtins.str*]'
201+
202+
async def generatorexp(obj: Iterable[int]):
203+
lst = (i async for i in asyncify(obj))
204+
reveal_type(lst) # E: Revealed type is 'typing.AsyncIterator[builtins.int*]'
205+
lst2 = (i async for i in asyncify(obj) for i in obj)
206+
reveal_type(lst2) # E: Revealed type is 'typing.AsyncIterator[builtins.int*]'
207+
208+
[builtins fixtures/async_await.pyi]
209+
210+
[case testAsyncForComprehensionErrors]
211+
# flags: --fast-parser --python-version 3.6
212+
from typing import Generic, Iterable, TypeVar, AsyncIterator, Tuple
213+
214+
T = TypeVar('T')
215+
216+
class asyncify(Generic[T], AsyncIterator[T]):
217+
def __init__(self, iterable: Iterable[T]) -> None:
218+
self.iterable = iter(iterable)
219+
def __aiter__(self) -> AsyncIterator[T]:
220+
return self
221+
async def __anext__(self) -> T:
222+
try:
223+
return next(self.iterable)
224+
except StopIteration:
225+
raise StopAsyncIteration
226+
227+
async def wrong_iterable(obj: Iterable[int]):
228+
[i async for i in obj]
229+
[i for i in asyncify(obj)]
230+
{i: i async for i in obj}
231+
{i: i for i in asyncify(obj)}
232+
233+
[out]
234+
main:18: error: AsyncIterable expected
235+
main:18: error: Iterable[int] has no attribute "__aiter__"; maybe "__iter__"?
236+
main:19: error: Iterable expected
237+
main:19: error: asyncify[int] has no attribute "__iter__"; maybe "__aiter__"?
238+
main:20: error: AsyncIterable expected
239+
main:20: error: Iterable[int] has no attribute "__aiter__"; maybe "__iter__"?
240+
main:21: error: Iterable expected
241+
main:21: error: asyncify[int] has no attribute "__iter__"; maybe "__aiter__"?
242+
[builtins fixtures/async_await.pyi]
243+
171244
[case testAsyncWith]
172245
# flags: --fast-parser
173246
class C:

test-data/unit/fixtures/async_await.pyi

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,12 @@ class type: pass
55
class function: pass
66
class int: pass
77
class str: pass
8+
class dict: pass
89
class list: pass
10+
class set: pass
911
class tuple: pass
12+
class BaseException: pass
13+
class StopIteration(BaseException): pass
14+
class StopAsyncIteration(BaseException): pass
15+
def iter(obj: typing.Any) -> typing.Any: pass
16+
def next(obj: typing.Any) -> typing.Any: pass

0 commit comments

Comments
 (0)