From bf423b38704ad93d39a1b4b45625f43d74c62d89 Mon Sep 17 00:00:00 2001 From: sobolevn Date: Sat, 23 Sep 2023 14:10:16 +0300 Subject: [PATCH 1/5] Fix `contextlib.asynccontextmanager` to work with coroutine functions --- stdlib/contextlib.pyi | 11 +++++++---- test_cases/stdlib/check_contextlib.py | 17 ++++++++++++++++- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/stdlib/contextlib.pyi b/stdlib/contextlib.pyi index dc2101dc01f7..bce8f5e91a13 100644 --- a/stdlib/contextlib.pyi +++ b/stdlib/contextlib.pyi @@ -2,7 +2,7 @@ import abc import sys from _typeshed import FileDescriptorOrPath, Unused from abc import abstractmethod -from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable, Generator, Iterator +from collections.abc import AsyncGenerator, AsyncGenerator, Awaitable, Callable, Coroutine, Generator, Iterator from types import TracebackType from typing import IO, Any, Generic, Protocol, TypeVar, overload, runtime_checkable from typing_extensions import ParamSpec, Self, TypeAlias @@ -84,7 +84,7 @@ if sys.version_info >= (3, 10): class _AsyncGeneratorContextManager(AbstractAsyncContextManager[_T_co], AsyncContextDecorator, Generic[_T_co]): # __init__ and these attributes are actually defined in the base class _GeneratorContextManagerBase, # which is more trouble than it's worth to include in the stub (see #6676) - def __init__(self, func: Callable[..., AsyncIterator[_T_co]], args: tuple[Any, ...], kwds: dict[str, Any]) -> None: ... + def __init__(self, func: Callable[..., AsyncGenerator[_T_co, Any]], args: tuple[Any, ...], kwds: dict[str, Any]) -> None: ... gen: AsyncGenerator[_T_co, Any] func: Callable[..., AsyncGenerator[_T_co, Any]] args: tuple[Any, ...] @@ -95,7 +95,7 @@ if sys.version_info >= (3, 10): else: class _AsyncGeneratorContextManager(AbstractAsyncContextManager[_T_co], Generic[_T_co]): - def __init__(self, func: Callable[..., AsyncIterator[_T_co]], args: tuple[Any, ...], kwds: dict[str, Any]) -> None: ... + def __init__(self, func: Callable[..., AsyncGenerator[_T_co, Any]], args: tuple[Any, ...], kwds: dict[str, Any]) -> None: ... gen: AsyncGenerator[_T_co, Any] func: Callable[..., AsyncGenerator[_T_co, Any]] args: tuple[Any, ...] @@ -104,7 +104,10 @@ else: self, typ: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None ) -> bool | None: ... -def asynccontextmanager(func: Callable[_P, AsyncIterator[_T_co]]) -> Callable[_P, _AsyncGeneratorContextManager[_T_co]]: ... +@overload +def asynccontextmanager(func: Callable[_P, Coroutine[Any, Any, AsyncGenerator[_T_co, Any]]]) -> Callable[_P, _AsyncGeneratorContextManager[_T_co]]: ... +@overload +def asynccontextmanager(func: Callable[_P, AsyncGenerator[_T_co, Any]]) -> Callable[_P, _AsyncGeneratorContextManager[_T_co]]: ... class _SupportsClose(Protocol): def close(self) -> object: ... diff --git a/test_cases/stdlib/check_contextlib.py b/test_cases/stdlib/check_contextlib.py index 648661bca856..845fab59adfa 100644 --- a/test_cases/stdlib/check_contextlib.py +++ b/test_cases/stdlib/check_contextlib.py @@ -1,6 +1,7 @@ from __future__ import annotations -from contextlib import ExitStack +from contextlib import ExitStack, asynccontextmanager +from typing import AsyncIterator, Awaitable from typing_extensions import assert_type @@ -18,3 +19,17 @@ class Thing(ExitStack): assert_type(cm, ExitStack) with thing as cm2: assert_type(cm2, Thing) + + +@asynccontextmanager +async def async_context() -> AsyncGenerator[str, None]: + yield 'example' + + +async def async_gen() -> AsyncGenerator[str, None]: + yield 'async gen' + + +@asynccontextmanager +def async_cm_func() -> AsyncGenerator[str, None]: + return async_gen() From 1de2736856176ab3c3ab2b71e366063c3bc65391 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 23 Sep 2023 11:40:49 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks --- stdlib/contextlib.pyi | 14 ++++++++++---- test_cases/stdlib/check_contextlib.py | 4 ++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/stdlib/contextlib.pyi b/stdlib/contextlib.pyi index bce8f5e91a13..99186814ad53 100644 --- a/stdlib/contextlib.pyi +++ b/stdlib/contextlib.pyi @@ -2,7 +2,7 @@ import abc import sys from _typeshed import FileDescriptorOrPath, Unused from abc import abstractmethod -from collections.abc import AsyncGenerator, AsyncGenerator, Awaitable, Callable, Coroutine, Generator, Iterator +from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Generator, Iterator from types import TracebackType from typing import IO, Any, Generic, Protocol, TypeVar, overload, runtime_checkable from typing_extensions import ParamSpec, Self, TypeAlias @@ -84,7 +84,9 @@ if sys.version_info >= (3, 10): class _AsyncGeneratorContextManager(AbstractAsyncContextManager[_T_co], AsyncContextDecorator, Generic[_T_co]): # __init__ and these attributes are actually defined in the base class _GeneratorContextManagerBase, # which is more trouble than it's worth to include in the stub (see #6676) - def __init__(self, func: Callable[..., AsyncGenerator[_T_co, Any]], args: tuple[Any, ...], kwds: dict[str, Any]) -> None: ... + def __init__( + self, func: Callable[..., AsyncGenerator[_T_co, Any]], args: tuple[Any, ...], kwds: dict[str, Any] + ) -> None: ... gen: AsyncGenerator[_T_co, Any] func: Callable[..., AsyncGenerator[_T_co, Any]] args: tuple[Any, ...] @@ -95,7 +97,9 @@ if sys.version_info >= (3, 10): else: class _AsyncGeneratorContextManager(AbstractAsyncContextManager[_T_co], Generic[_T_co]): - def __init__(self, func: Callable[..., AsyncGenerator[_T_co, Any]], args: tuple[Any, ...], kwds: dict[str, Any]) -> None: ... + def __init__( + self, func: Callable[..., AsyncGenerator[_T_co, Any]], args: tuple[Any, ...], kwds: dict[str, Any] + ) -> None: ... gen: AsyncGenerator[_T_co, Any] func: Callable[..., AsyncGenerator[_T_co, Any]] args: tuple[Any, ...] @@ -105,7 +109,9 @@ else: ) -> bool | None: ... @overload -def asynccontextmanager(func: Callable[_P, Coroutine[Any, Any, AsyncGenerator[_T_co, Any]]]) -> Callable[_P, _AsyncGeneratorContextManager[_T_co]]: ... +def asynccontextmanager( + func: Callable[_P, Coroutine[Any, Any, AsyncGenerator[_T_co, Any]]] +) -> Callable[_P, _AsyncGeneratorContextManager[_T_co]]: ... @overload def asynccontextmanager(func: Callable[_P, AsyncGenerator[_T_co, Any]]) -> Callable[_P, _AsyncGeneratorContextManager[_T_co]]: ... diff --git a/test_cases/stdlib/check_contextlib.py b/test_cases/stdlib/check_contextlib.py index 845fab59adfa..d9dbf059f355 100644 --- a/test_cases/stdlib/check_contextlib.py +++ b/test_cases/stdlib/check_contextlib.py @@ -23,11 +23,11 @@ class Thing(ExitStack): @asynccontextmanager async def async_context() -> AsyncGenerator[str, None]: - yield 'example' + yield "example" async def async_gen() -> AsyncGenerator[str, None]: - yield 'async gen' + yield "async gen" @asynccontextmanager From 14e457677447041bfff4a1569af76fe6f0436792 Mon Sep 17 00:00:00 2001 From: sobolevn Date: Sat, 23 Sep 2023 14:48:24 +0300 Subject: [PATCH 3/5] Fix CI --- test_cases/stdlib/check_contextlib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_cases/stdlib/check_contextlib.py b/test_cases/stdlib/check_contextlib.py index 845fab59adfa..280bb6bac6ec 100644 --- a/test_cases/stdlib/check_contextlib.py +++ b/test_cases/stdlib/check_contextlib.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextlib import ExitStack, asynccontextmanager -from typing import AsyncIterator, Awaitable +from typing import AsyncGenerator from typing_extensions import assert_type From 205713d2c08bd9773d4672d306e978ea30b114da Mon Sep 17 00:00:00 2001 From: sobolevn Date: Sat, 23 Sep 2023 15:11:18 +0300 Subject: [PATCH 4/5] Use `AsyncIterator` for now --- stdlib/contextlib.pyi | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/stdlib/contextlib.pyi b/stdlib/contextlib.pyi index 99186814ad53..6c377f7353a2 100644 --- a/stdlib/contextlib.pyi +++ b/stdlib/contextlib.pyi @@ -2,7 +2,7 @@ import abc import sys from _typeshed import FileDescriptorOrPath, Unused from abc import abstractmethod -from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Generator, Iterator +from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable, Coroutine, Generator, Iterator from types import TracebackType from typing import IO, Any, Generic, Protocol, TypeVar, overload, runtime_checkable from typing_extensions import ParamSpec, Self, TypeAlias @@ -85,7 +85,7 @@ if sys.version_info >= (3, 10): # __init__ and these attributes are actually defined in the base class _GeneratorContextManagerBase, # which is more trouble than it's worth to include in the stub (see #6676) def __init__( - self, func: Callable[..., AsyncGenerator[_T_co, Any]], args: tuple[Any, ...], kwds: dict[str, Any] + self, func: Callable[..., AsyncIterator[_T_co]], args: tuple[Any, ...], kwds: dict[str, Any] ) -> None: ... gen: AsyncGenerator[_T_co, Any] func: Callable[..., AsyncGenerator[_T_co, Any]] @@ -98,7 +98,7 @@ if sys.version_info >= (3, 10): else: class _AsyncGeneratorContextManager(AbstractAsyncContextManager[_T_co], Generic[_T_co]): def __init__( - self, func: Callable[..., AsyncGenerator[_T_co, Any]], args: tuple[Any, ...], kwds: dict[str, Any] + self, func: Callable[..., AsyncIterator[_T_co]], args: tuple[Any, ...], kwds: dict[str, Any] ) -> None: ... gen: AsyncGenerator[_T_co, Any] func: Callable[..., AsyncGenerator[_T_co, Any]] @@ -110,10 +110,10 @@ else: @overload def asynccontextmanager( - func: Callable[_P, Coroutine[Any, Any, AsyncGenerator[_T_co, Any]]] + func: Callable[_P, Coroutine[Any, Any, AsyncIterator[_T_co]]] ) -> Callable[_P, _AsyncGeneratorContextManager[_T_co]]: ... @overload -def asynccontextmanager(func: Callable[_P, AsyncGenerator[_T_co, Any]]) -> Callable[_P, _AsyncGeneratorContextManager[_T_co]]: ... +def asynccontextmanager(func: Callable[_P, AsyncIterator[_T_co]]) -> Callable[_P, _AsyncGeneratorContextManager[_T_co]]: ... class _SupportsClose(Protocol): def close(self) -> object: ... From d7077aad5bf3527f1b104d630845da60cd53d9fe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 23 Sep 2023 12:11:59 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks --- stdlib/contextlib.pyi | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/stdlib/contextlib.pyi b/stdlib/contextlib.pyi index 6c377f7353a2..6683955d9afe 100644 --- a/stdlib/contextlib.pyi +++ b/stdlib/contextlib.pyi @@ -84,9 +84,7 @@ if sys.version_info >= (3, 10): class _AsyncGeneratorContextManager(AbstractAsyncContextManager[_T_co], AsyncContextDecorator, Generic[_T_co]): # __init__ and these attributes are actually defined in the base class _GeneratorContextManagerBase, # which is more trouble than it's worth to include in the stub (see #6676) - def __init__( - self, func: Callable[..., AsyncIterator[_T_co]], args: tuple[Any, ...], kwds: dict[str, Any] - ) -> None: ... + def __init__(self, func: Callable[..., AsyncIterator[_T_co]], args: tuple[Any, ...], kwds: dict[str, Any]) -> None: ... gen: AsyncGenerator[_T_co, Any] func: Callable[..., AsyncGenerator[_T_co, Any]] args: tuple[Any, ...] @@ -97,9 +95,7 @@ if sys.version_info >= (3, 10): else: class _AsyncGeneratorContextManager(AbstractAsyncContextManager[_T_co], Generic[_T_co]): - def __init__( - self, func: Callable[..., AsyncIterator[_T_co]], args: tuple[Any, ...], kwds: dict[str, Any] - ) -> None: ... + def __init__(self, func: Callable[..., AsyncIterator[_T_co]], args: tuple[Any, ...], kwds: dict[str, Any]) -> None: ... gen: AsyncGenerator[_T_co, Any] func: Callable[..., AsyncGenerator[_T_co, Any]] args: tuple[Any, ...]