diff --git a/pydantic_ai_slim/pydantic_ai/_function_schema.py b/pydantic_ai_slim/pydantic_ai/_function_schema.py index 64201a3790..6065681eb7 100644 --- a/pydantic_ai_slim/pydantic_ai/_function_schema.py +++ b/pydantic_ai_slim/pydantic_ai/_function_schema.py @@ -5,7 +5,6 @@ from __future__ import annotations as _annotations -import inspect from collections.abc import Awaitable from dataclasses import dataclass, field from inspect import Parameter, signature @@ -23,7 +22,7 @@ from pydantic_ai.tools import RunContext from ._griffe import doc_descriptions -from ._utils import check_object_json_schema, is_model_like, run_in_executor +from ._utils import check_object_json_schema, is_async_callable, is_model_like, run_in_executor if TYPE_CHECKING: from .tools import DocstringFormat, ObjectJsonSchema @@ -214,7 +213,7 @@ def function_schema( # noqa: C901 positional_fields=positional_fields, var_positional_field=var_positional_field, takes_ctx=takes_ctx, - is_async=inspect.iscoroutinefunction(function), + is_async=is_async_callable(function), function=function, ) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 916ddb7e83..22da934a0d 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -60,7 +60,7 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]): def __post_init__(self): self._takes_ctx = len(inspect.signature(self.function).parameters) > 1 - self._is_async = inspect.iscoroutinefunction(self.function) + self._is_async = _utils.is_async_callable(self.function) async def validate( self, diff --git a/pydantic_ai_slim/pydantic_ai/_system_prompt.py b/pydantic_ai_slim/pydantic_ai/_system_prompt.py index aca3080015..df2b93e7e8 100644 --- a/pydantic_ai_slim/pydantic_ai/_system_prompt.py +++ b/pydantic_ai_slim/pydantic_ai/_system_prompt.py @@ -18,7 +18,7 @@ class SystemPromptRunner(Generic[AgentDepsT]): def __post_init__(self): self._takes_ctx = len(inspect.signature(self.function).parameters) > 0 - self._is_async = inspect.iscoroutinefunction(self.function) + self._is_async = _utils.is_async_callable(self.function) async def run(self, run_context: RunContext[AgentDepsT]) -> str: if self._takes_ctx: diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index 77c34fbacf..1e958d2427 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -1,20 +1,22 @@ from __future__ import annotations as _annotations import asyncio +import functools +import inspect import time import uuid -from collections.abc import AsyncIterable, AsyncIterator, Iterator +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator from contextlib import asynccontextmanager, suppress from dataclasses import dataclass, fields, is_dataclass from datetime import datetime, timezone from functools import partial from types import GenericAlias -from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, overload from anyio.to_thread import run_sync from pydantic import BaseModel, TypeAdapter from pydantic.json_schema import JsonSchemaValue -from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict +from typing_extensions import ParamSpec, TypeAlias, TypeGuard, TypeIs, is_typeddict from pydantic_graph._utils import AbstractSpan @@ -302,3 +304,26 @@ def dataclasses_no_defaults_repr(self: Any) -> str: def number_to_datetime(x: int | float) -> datetime: return TypeAdapter(datetime).validate_python(x) + + +AwaitableCallable = Callable[..., Awaitable[T]] + + +@overload +def is_async_callable(obj: AwaitableCallable[T]) -> TypeIs[AwaitableCallable[T]]: ... + + +@overload +def is_async_callable(obj: Any) -> TypeIs[AwaitableCallable[Any]]: ... + + +def is_async_callable(obj: Any) -> Any: + """Correctly check if a callable is async. + + This function was copied from Starlette: + https://github.com/encode/starlette/blob/78da9b9e218ab289117df7d62aee200ed4c59617/starlette/_utils.py#L36-L40 + """ + while isinstance(obj, functools.partial): + obj = obj.func + + return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__)) # type: ignore diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 6db5a6455b..bb34015191 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -1,6 +1,5 @@ from __future__ import annotations as _annotations -import asyncio import dataclasses import json from collections.abc import Awaitable, Sequence @@ -337,7 +336,7 @@ def from_schema( validator=SchemaValidator(schema=core_schema.any_schema()), json_schema=json_schema, takes_ctx=False, - is_async=asyncio.iscoroutinefunction(function), + is_async=_utils.is_async_callable(function), ) return cls( diff --git a/tests/test_utils.py b/tests/test_utils.py index e7d3ddcf34..e62e8b6040 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,6 +2,7 @@ import asyncio import contextvars +import functools import os from collections.abc import AsyncIterator from importlib.metadata import distributions @@ -10,7 +11,14 @@ from inline_snapshot import snapshot from pydantic_ai import UserError -from pydantic_ai._utils import UNSET, PeekableAsyncStream, check_object_json_schema, group_by_temporal, run_in_executor +from pydantic_ai._utils import ( + UNSET, + PeekableAsyncStream, + check_object_json_schema, + group_by_temporal, + is_async_callable, + run_in_executor, +) from .models.mock_async_stream import MockAsyncStream @@ -153,3 +161,19 @@ async def test_run_in_executor_with_contextvars() -> None: # show that the old version did not work old_result = asyncio.get_running_loop().run_in_executor(None, ctx_var.get) assert old_result != ctx_var.get() + + +def test_is_async_callable(): + def sync_func(): ... # pragma: no branch + + assert is_async_callable(sync_func) is False + + async def async_func(): ... # pragma: no branch + + assert is_async_callable(async_func) is True + + class AsyncCallable: + async def __call__(self): ... # pragma: no branch + + partial_async_callable = functools.partial(AsyncCallable()) + assert is_async_callable(partial_async_callable) is True