diff --git a/stdlib/contextlib.pyi b/stdlib/contextlib.pyi index 1b6ee4298174..54a47e28ac48 100644 --- a/stdlib/contextlib.pyi +++ b/stdlib/contextlib.pyi @@ -163,7 +163,7 @@ class _RedirectStream(AbstractContextManager[_T_io]): class redirect_stdout(_RedirectStream[_T_io]): ... class redirect_stderr(_RedirectStream[_T_io]): ... -class ExitStack(AbstractContextManager[ExitStack]): +class ExitStack: def __init__(self) -> None: ... def enter_context(self, cm: AbstractContextManager[_T]) -> _T: ... def push(self, exit: _CM_EF) -> _CM_EF: ... @@ -179,7 +179,7 @@ if sys.version_info >= (3, 7): _ExitCoroFunc: TypeAlias = Callable[[type[BaseException] | None, BaseException | None, TracebackType | None], Awaitable[bool]] _ACM_EF = TypeVar("_ACM_EF", bound=AbstractAsyncContextManager[Any] | _ExitCoroFunc) - class AsyncExitStack(AbstractAsyncContextManager[AsyncExitStack]): + class AsyncExitStack: def __init__(self) -> None: ... def enter_context(self, cm: AbstractContextManager[_T]) -> _T: ... async def enter_async_context(self, cm: AbstractAsyncContextManager[_T]) -> _T: ... diff --git a/test_cases/stdlib/test_contextlib.py b/test_cases/stdlib/test_contextlib.py new file mode 100644 index 000000000000..26253326bfce --- /dev/null +++ b/test_cases/stdlib/test_contextlib.py @@ -0,0 +1,18 @@ +from contextlib import ExitStack +from typing_extensions import assert_type + + +# See issue #7961 +class Thing(ExitStack): + pass + + +stack = ExitStack() +thing = Thing() +assert_type(stack.enter_context(Thing()), Thing) +assert_type(thing.enter_context(ExitStack()), ExitStack) + +with stack as cm: + assert_type(cm, ExitStack) +with thing as cm2: + assert_type(cm2, Thing)