Skip to content

Commit 9d3a432

Browse files
Fix providers.Resource missing overloads for AbstractContextManager and AbstractAsyncContextManager (#927)
1 parent 5acddac commit 9d3a432

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

src/dependency_injector/providers.pyi

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from contextlib import AbstractContextManager, AbstractAsyncContextManager
34
from pathlib import Path
45
from typing import (
56
Awaitable,
@@ -465,6 +466,20 @@ class Resource(Provider[T]):
465466
**kwargs: Injection,
466467
) -> None: ...
467468
@overload
469+
def __init__(
470+
self,
471+
provides: Optional[_Callable[..., AbstractContextManager[T]]] = None,
472+
*args: Injection,
473+
**kwargs: Injection,
474+
) -> None: ...
475+
@overload
476+
def __init__(
477+
self,
478+
provides: Optional[_Callable[..., AbstractAsyncContextManager[T]]] = None,
479+
*args: Injection,
480+
**kwargs: Injection,
481+
) -> None: ...
482+
@overload
468483
def __init__(
469484
self,
470485
provides: Optional[_Callable[..., _Iterator[T]]] = None,

tests/typing/resource.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from contextlib import contextmanager, asynccontextmanager
12
from typing import (
23
Any,
34
AsyncGenerator,
@@ -7,6 +8,7 @@
78
Iterator,
89
List,
910
Optional,
11+
Self,
1012
)
1113

1214
from dependency_injector import providers, resources
@@ -109,3 +111,62 @@ async def _provide8() -> None:
109111
# Test 9: to check string imports
110112
provider9: providers.Resource[Dict[Any, Any]] = providers.Resource("builtins.dict")
111113
provider9.set_provides("builtins.dict")
114+
115+
116+
# Test 10: to check the return type with classes implementing AbstractContextManager protocol
117+
class MyResource10:
118+
def __init__(self) -> None:
119+
pass
120+
121+
def __enter__(self) -> Self:
122+
return self
123+
124+
def __exit__(self, *args: Any, **kwargs: Any) -> None:
125+
return None
126+
127+
128+
provider10 = providers.Resource(MyResource10)
129+
var10: MyResource10 = provider10()
130+
131+
132+
# Test 11: to check the return type with functions decorated with contextlib.contextmanager
133+
@contextmanager
134+
def init11() -> Iterator[int]:
135+
yield 1
136+
137+
138+
provider11 = providers.Resource(init11)
139+
var11: int = provider11()
140+
141+
142+
# Test 12: to check the return type with classes implementing AbstractAsyncContextManager protocol
143+
class MyResource12:
144+
def __init__(self) -> None:
145+
pass
146+
147+
async def __aenter__(self) -> Self:
148+
return self
149+
150+
async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
151+
return None
152+
153+
154+
provider12 = providers.Resource(MyResource12)
155+
156+
157+
async def _provide12() -> None:
158+
var1: MyResource12 = await provider12() # type: ignore
159+
var2: MyResource12 = await provider12.async_()
160+
161+
162+
# Test 13: to check the return type with functions decorated with contextlib.asynccontextmanager
163+
@asynccontextmanager
164+
async def init13() -> AsyncIterator[int]:
165+
yield 1
166+
167+
168+
provider13 = providers.Resource(init13)
169+
170+
async def _provide13() -> None:
171+
var1: int = await provider13() # type: ignore
172+
var2: int = await provider13.async_()

0 commit comments

Comments
 (0)