Skip to content

Commit 4b759b3

Browse files
CopilotTomeHirata
andcommitted
Use ParamSpec and TypeVar to preserve function signatures in asyncify
Co-authored-by: TomeHirata <[email protected]>
1 parent 9495fd8 commit 4b759b3

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

dspy/utils/asyncify.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1-
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Union, overload
1+
from typing import TYPE_CHECKING, Awaitable, Callable, ParamSpec, TypeVar, Union, overload
22

33
import asyncer
44
from anyio import CapacityLimiter
55

66
if TYPE_CHECKING:
77
from dspy.primitives.module import Module
88

9+
P = ParamSpec("P")
10+
T = TypeVar("T")
11+
912
_limiter = None
1013

1114

@@ -28,14 +31,14 @@ def get_limiter():
2831

2932

3033
@overload
31-
def asyncify(program: "Module") -> Callable[..., Awaitable[Any]]: ...
34+
def asyncify(program: Callable[P, T]) -> Callable[P, Awaitable[T]]: ...
3235

3336

3437
@overload
35-
def asyncify(program: Callable[..., Any]) -> Callable[..., Awaitable[Any]]: ...
38+
def asyncify(program: "Module") -> Callable[..., Awaitable[T]]: ...
3639

3740

38-
def asyncify(program: Union["Module", Callable[..., Any]]) -> Callable[..., Awaitable[Any]]:
41+
def asyncify(program: Union[Callable[P, T], "Module"]) -> Callable[P, Awaitable[T]] | Callable[..., Awaitable[T]]:
3942
"""
4043
Wraps a DSPy program or callable so that it can be called asynchronously. This is useful for running a
4144
program in parallel with another task (e.g., another DSPy program).
@@ -50,7 +53,7 @@ def asyncify(program: Union["Module", Callable[..., Any]]) -> Callable[..., Awai
5053
The current thread's configuration context is inherited for each call.
5154
"""
5255

53-
async def async_program(*args: Any, **kwargs: Any) -> Any:
56+
async def async_program(*args: P.args, **kwargs: P.kwargs) -> T:
5457
# Capture the current overrides at call-time.
5558
from dspy.dsp.utils.settings import thread_local_overrides
5659

@@ -70,4 +73,4 @@ def wrapped_program(*a, **kw):
7073
call_async = asyncer.asyncify(wrapped_program, abandon_on_cancel=True, limiter=get_limiter())
7174
return await call_async(*args, **kwargs)
7275

73-
return async_program
76+
return async_program # type: ignore[return-value]

0 commit comments

Comments
 (0)