Skip to content

Fix support for async for comprehensions #2926

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 12, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1885,8 +1885,13 @@ def visit_set_comprehension(self, e: SetComprehension) -> Type:
e.generator, 'builtins.set', '<set-comprehension>')

def visit_generator_expr(self, e: GeneratorExpr) -> Type:
return self.check_generator_or_comprehension(e, 'typing.Iterator',
'<generator>')
# If any of the comprehensions use async for, the expression will return an async generator
# object
if any(e.is_async):
typ = 'typing.AsyncIterator'
else:
typ = 'typing.Iterator'
return self.check_generator_or_comprehension(e, typ, '<generator>')

def check_generator_or_comprehension(self, gen: GeneratorExpr,
type_name: str,
Expand Down Expand Up @@ -1938,9 +1943,12 @@ def check_for_comp(self, e: Union[GeneratorExpr, DictionaryComprehension]) -> No

Note: This adds the type information derived from the condlists to the current binder.
"""
for index, sequence, conditions in zip(e.indices, e.sequences,
e.condlists):
sequence_type = self.chk.analyze_iterable_item_type(sequence)
for index, sequence, conditions, is_async in zip(e.indices, e.sequences,
e.condlists, e.is_async):
if is_async:
sequence_type = self.chk.analyze_async_iterable_item_type(sequence)
else:
sequence_type = self.chk.analyze_iterable_item_type(sequence)
self.chk.analyze_index_variables(index, sequence_type, True, e)
for condition in conditions:
self.accept(condition)
Expand Down
8 changes: 6 additions & 2 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,22 +774,26 @@ def visit_DictComp(self, n: ast3.DictComp) -> DictionaryComprehension:
targets = [self.visit(c.target) for c in n.generators]
iters = [self.visit(c.iter) for c in n.generators]
ifs_list = [self.translate_expr_list(c.ifs) for c in n.generators]
is_async = [bool(c.is_async) for c in n.generators]
return DictionaryComprehension(self.visit(n.key),
self.visit(n.value),
targets,
iters,
ifs_list)
ifs_list,
is_async)

# GeneratorExp(expr elt, comprehension* generators)
@with_line
def visit_GeneratorExp(self, n: ast3.GeneratorExp) -> GeneratorExpr:
targets = [self.visit(c.target) for c in n.generators]
iters = [self.visit(c.iter) for c in n.generators]
ifs_list = [self.translate_expr_list(c.ifs) for c in n.generators]
is_async = [bool(c.is_async) for c in n.generators]
return GeneratorExpr(self.visit(n.elt),
targets,
iters,
ifs_list)
ifs_list,
is_async)

# Await(expr value)
@with_line
Expand Down
6 changes: 4 additions & 2 deletions mypy/fastparse2.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,8 @@ def visit_DictComp(self, n: ast27.DictComp) -> DictionaryComprehension:
self.visit(n.value),
targets,
iters,
ifs_list)
ifs_list,
[False for _ in n.generators])

# GeneratorExp(expr elt, comprehension* generators)
@with_line
Expand All @@ -787,7 +788,8 @@ def visit_GeneratorExp(self, n: ast27.GeneratorExp) -> GeneratorExpr:
return GeneratorExpr(self.visit(n.elt),
targets,
iters,
ifs_list)
ifs_list,
[False for _ in n.generators])

# Yield(expr? value)
@with_line
Expand Down
10 changes: 8 additions & 2 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1577,14 +1577,17 @@ class GeneratorExpr(Expression):
left_expr = None # type: Expression
sequences = None # type: List[Expression]
condlists = None # type: List[List[Expression]]
is_async = None # type: List[bool]
indices = None # type: List[Lvalue]

def __init__(self, left_expr: Expression, indices: List[Lvalue],
sequences: List[Expression], condlists: List[List[Expression]]) -> None:
sequences: List[Expression], condlists: List[List[Expression]],
is_async: List[bool]) -> None:
self.left_expr = left_expr
self.sequences = sequences
self.condlists = condlists
self.indices = indices
self.is_async = is_async

def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_generator_expr(self)
Expand Down Expand Up @@ -1621,15 +1624,18 @@ class DictionaryComprehension(Expression):
value = None # type: Expression
sequences = None # type: List[Expression]
condlists = None # type: List[List[Expression]]
is_async = None # type: List[bool]
indices = None # type: List[Lvalue]

def __init__(self, key: Expression, value: Expression, indices: List[Lvalue],
sequences: List[Expression], condlists: List[List[Expression]]) -> None:
sequences: List[Expression], condlists: List[List[Expression]],
is_async: List[bool]) -> None:
self.key = key
self.value = value
self.sequences = sequences
self.condlists = condlists
self.indices = indices
self.is_async = is_async

def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_dictionary_comprehension(self)
Expand Down
5 changes: 3 additions & 2 deletions mypy/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1472,7 +1472,7 @@ def parse_generator_expr(self, left_expr: Expression) -> GeneratorExpr:
tok = self.current()
indices, sequences, condlists = self.parse_comp_for()

gen = GeneratorExpr(left_expr, indices, sequences, condlists)
gen = GeneratorExpr(left_expr, indices, sequences, condlists, [False for _ in indices])
gen.set_line(tok)
return gen

Expand Down Expand Up @@ -1560,7 +1560,8 @@ def parse_set_comprehension(self, expr: Expression) -> SetComprehension:
def parse_dict_comprehension(self, key: Expression, value: Expression,
colon: Token) -> DictionaryComprehension:
indices, sequences, condlists = self.parse_comp_for()
dic = DictionaryComprehension(key, value, indices, sequences, condlists)
dic = DictionaryComprehension(key, value, indices, sequences, condlists,
[False for _ in indices])
dic.set_line(colon)
self.expect('}')
return dic
Expand Down
12 changes: 7 additions & 5 deletions mypy/treetransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,10 +439,11 @@ def visit_set_comprehension(self, node: SetComprehension) -> SetComprehension:
def visit_dictionary_comprehension(self, node: DictionaryComprehension
) -> DictionaryComprehension:
return DictionaryComprehension(self.expr(node.key), self.expr(node.value),
[self.expr(index) for index in node.indices],
[self.expr(s) for s in node.sequences],
[[self.expr(cond) for cond in conditions]
for conditions in node.condlists])
[self.expr(index) for index in node.indices],
[self.expr(s) for s in node.sequences],
[[self.expr(cond) for cond in conditions]
for conditions in node.condlists],
node.is_async)

def visit_generator_expr(self, node: GeneratorExpr) -> GeneratorExpr:
return self.duplicate_generator(node)
Expand All @@ -452,7 +453,8 @@ def duplicate_generator(self, node: GeneratorExpr) -> GeneratorExpr:
[self.expr(index) for index in node.indices],
[self.expr(s) for s in node.sequences],
[[self.expr(cond) for cond in conditions]
for conditions in node.condlists])
for conditions in node.condlists],
node.is_async)

def visit_slice_expr(self, node: SliceExpr) -> SliceExpr:
return SliceExpr(self.optional_expr(node.begin_index),
Expand Down
73 changes: 73 additions & 0 deletions test-data/unit/check-async-await.test
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,79 @@ async def f() -> None:
reveal_type(z) # E: Revealed type is 'Union[builtins.int, builtins.str]'
[builtins fixtures/async_await.pyi]

[case testAsyncForComprehension]
# flags: --fast-parser --python-version 3.6
from typing import Generic, Iterable, TypeVar, AsyncIterator, Tuple

T = TypeVar('T')

class asyncify(Generic[T], AsyncIterator[T]):
def __init__(self, iterable: Iterable[T]) -> None:
self.iterable = iter(iterable)
def __aiter__(self) -> AsyncIterator[T]:
return self
async def __anext__(self) -> T:
try:
return next(self.iterable)
except StopIteration:
raise StopAsyncIteration

async def listcomp(obj: Iterable[int]):
lst = [i async for i in asyncify(obj)]
reveal_type(lst) # E: Revealed type is 'builtins.list[builtins.int*]'
lst2 = [i async for i in asyncify(obj) for j in obj]
reveal_type(lst2) # E: Revealed type is 'builtins.list[builtins.int*]'

async def setcomp(obj: Iterable[int]):
lst = {i async for i in asyncify(obj)}
reveal_type(lst) # E: Revealed type is 'builtins.set[builtins.int*]'

async def dictcomp(obj: Iterable[Tuple[int, str]]):
lst = {i: j async for i, j in asyncify(obj)}
reveal_type(lst) # E: Revealed type is 'builtins.dict[builtins.int*, builtins.str*]'

async def generatorexp(obj: Iterable[int]):
lst = (i async for i in asyncify(obj))
reveal_type(lst) # E: Revealed type is 'typing.AsyncIterator[builtins.int*]'
lst2 = (i async for i in asyncify(obj) for i in obj)
reveal_type(lst2) # E: Revealed type is 'typing.AsyncIterator[builtins.int*]'

[builtins fixtures/async_await.pyi]

[case testAsyncForComprehensionErrors]
# flags: --fast-parser --python-version 3.6
from typing import Generic, Iterable, TypeVar, AsyncIterator, Tuple

T = TypeVar('T')

class asyncify(Generic[T], AsyncIterator[T]):
def __init__(self, iterable: Iterable[T]) -> None:
self.iterable = iter(iterable)
def __aiter__(self) -> AsyncIterator[T]:
return self
async def __anext__(self) -> T:
try:
return next(self.iterable)
except StopIteration:
raise StopAsyncIteration

async def wrong_iterable(obj: Iterable[int]):
[i async for i in obj]
[i for i in asyncify(obj)]
{i: i async for i in obj}
{i: i for i in asyncify(obj)}

[out]
main:18: error: AsyncIterable expected
main:18: error: Iterable[int] has no attribute "__aiter__"; maybe "__iter__"?
main:19: error: Iterable expected
main:19: error: asyncify[int] has no attribute "__iter__"; maybe "__aiter__"?
main:20: error: AsyncIterable expected
main:20: error: Iterable[int] has no attribute "__aiter__"; maybe "__iter__"?
main:21: error: Iterable expected
main:21: error: asyncify[int] has no attribute "__iter__"; maybe "__aiter__"?
[builtins fixtures/async_await.pyi]

[case testAsyncWith]
# flags: --fast-parser
class C:
Expand Down
7 changes: 7 additions & 0 deletions test-data/unit/fixtures/async_await.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,12 @@ class type: pass
class function: pass
class int: pass
class str: pass
class dict: pass
class list: pass
class set: pass
class tuple: pass
class BaseException: pass
class StopIteration(BaseException): pass
class StopAsyncIteration(BaseException): pass
def iter(obj: typing.Any) -> typing.Any: pass
def next(obj: typing.Any) -> typing.Any: pass