From 3ce4289f9a7e045befc0386e61fcbfa0b614f36d Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 13 Jun 2025 14:15:00 +0200 Subject: [PATCH 1/5] Proper check if callable is async --- .../pydantic_ai/_function_schema.py | 5 ++-- pydantic_ai_slim/pydantic_ai/_output.py | 2 +- .../pydantic_ai/_system_prompt.py | 2 +- pydantic_ai_slim/pydantic_ai/_utils.py | 29 +++++++++++++++++-- pydantic_ai_slim/pydantic_ai/tools.py | 3 +- 5 files changed, 32 insertions(+), 9 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_function_schema.py b/pydantic_ai_slim/pydantic_ai/_function_schema.py index 64201a3790..6065681eb7 100644 --- a/pydantic_ai_slim/pydantic_ai/_function_schema.py +++ b/pydantic_ai_slim/pydantic_ai/_function_schema.py @@ -5,7 +5,6 @@ from __future__ import annotations as _annotations -import inspect from collections.abc import Awaitable from dataclasses import dataclass, field from inspect import Parameter, signature @@ -23,7 +22,7 @@ from pydantic_ai.tools import RunContext from ._griffe import doc_descriptions -from ._utils import check_object_json_schema, is_model_like, run_in_executor +from ._utils import check_object_json_schema, is_async_callable, is_model_like, run_in_executor if TYPE_CHECKING: from .tools import DocstringFormat, ObjectJsonSchema @@ -214,7 +213,7 @@ def function_schema( # noqa: C901 positional_fields=positional_fields, var_positional_field=var_positional_field, takes_ctx=takes_ctx, - is_async=inspect.iscoroutinefunction(function), + is_async=is_async_callable(function), function=function, ) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 916ddb7e83..22da934a0d 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -60,7 +60,7 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]): def __post_init__(self): self._takes_ctx = len(inspect.signature(self.function).parameters) > 1 - self._is_async = inspect.iscoroutinefunction(self.function) + self._is_async = _utils.is_async_callable(self.function) async def validate( self, diff --git a/pydantic_ai_slim/pydantic_ai/_system_prompt.py b/pydantic_ai_slim/pydantic_ai/_system_prompt.py index aca3080015..df2b93e7e8 100644 --- a/pydantic_ai_slim/pydantic_ai/_system_prompt.py +++ b/pydantic_ai_slim/pydantic_ai/_system_prompt.py @@ -18,7 +18,7 @@ class SystemPromptRunner(Generic[AgentDepsT]): def __post_init__(self): self._takes_ctx = len(inspect.signature(self.function).parameters) > 0 - self._is_async = inspect.iscoroutinefunction(self.function) + self._is_async = _utils.is_async_callable(self.function) async def run(self, run_context: RunContext[AgentDepsT]) -> str: if self._takes_ctx: diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index 77c34fbacf..819427a1d6 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -1,15 +1,17 @@ from __future__ import annotations as _annotations import asyncio +import functools +import inspect import time import uuid -from collections.abc import AsyncIterable, AsyncIterator, Iterator +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator from contextlib import asynccontextmanager, suppress from dataclasses import dataclass, fields, is_dataclass from datetime import datetime, timezone from functools import partial from types import GenericAlias -from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, overload from anyio.to_thread import run_sync from pydantic import BaseModel, TypeAdapter @@ -302,3 +304,26 @@ def dataclasses_no_defaults_repr(self: Any) -> str: def number_to_datetime(x: int | float) -> datetime: return TypeAdapter(datetime).validate_python(x) + + +AwaitableCallable = Callable[..., Awaitable[T]] + + +@overload +def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ... + + +@overload +def is_async_callable(obj: Any) -> TypeGuard[AwaitableCallable[Any]]: ... + + +def is_async_callable(obj: Any) -> Any: + """Correctly check if a callable is async. + + This function was copied from Starlette: + https://github.com/encode/starlette/blob/78da9b9e218ab289117df7d62aee200ed4c59617/starlette/_utils.py#L36-L40 + """ + while isinstance(obj, functools.partial): + obj = obj.func + + return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__)) # type: ignore diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 6db5a6455b..bb34015191 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -1,6 +1,5 @@ from __future__ import annotations as _annotations -import asyncio import dataclasses import json from collections.abc import Awaitable, Sequence @@ -337,7 +336,7 @@ def from_schema( validator=SchemaValidator(schema=core_schema.any_schema()), json_schema=json_schema, takes_ctx=False, - is_async=asyncio.iscoroutinefunction(function), + is_async=_utils.is_async_callable(function), ) return cls( From f4e9fb18b02eef09bdc2fdbbd35be0a065759087 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 13 Jun 2025 14:30:11 +0200 Subject: [PATCH 2/5] Add tests --- tests/test_utils.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index e7d3ddcf34..caa75be890 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,6 +2,7 @@ import asyncio import contextvars +import functools import os from collections.abc import AsyncIterator from importlib.metadata import distributions @@ -10,7 +11,14 @@ from inline_snapshot import snapshot from pydantic_ai import UserError -from pydantic_ai._utils import UNSET, PeekableAsyncStream, check_object_json_schema, group_by_temporal, run_in_executor +from pydantic_ai._utils import ( + UNSET, + PeekableAsyncStream, + check_object_json_schema, + group_by_temporal, + is_async_callable, + run_in_executor, +) from .models.mock_async_stream import MockAsyncStream @@ -153,3 +161,22 @@ async def test_run_in_executor_with_contextvars() -> None: # show that the old version did not work old_result = asyncio.get_running_loop().run_in_executor(None, ctx_var.get) assert old_result != ctx_var.get() + + +def test_is_async_callable(): + def sync_func(): + return 1 + + assert is_async_callable(sync_func) is False + + async def async_func(): + return 1 + + assert is_async_callable(async_func) is True + + class AsyncCallable: + async def __call__(self): + return 42 + + partial_async_callable = functools.partial(AsyncCallable()) + assert is_async_callable(partial_async_callable) is True From 7fcbc8c1f870b7cffab67897975b46453b058254 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 13 Jun 2025 14:30:54 +0200 Subject: [PATCH 3/5] drop function body from tests --- tests/test_utils.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index caa75be890..c703e2b8f1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -164,19 +164,16 @@ async def test_run_in_executor_with_contextvars() -> None: def test_is_async_callable(): - def sync_func(): - return 1 + def sync_func(): ... assert is_async_callable(sync_func) is False - async def async_func(): - return 1 + async def async_func(): ... assert is_async_callable(async_func) is True class AsyncCallable: - async def __call__(self): - return 42 + async def __call__(self): ... partial_async_callable = functools.partial(AsyncCallable()) assert is_async_callable(partial_async_callable) is True From e79b14322b92ef25e17ba2527b9d1d7d931e9986 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 13 Jun 2025 14:38:09 +0200 Subject: [PATCH 4/5] Use TypeIs instead of TypeGuard --- pydantic_ai_slim/pydantic_ai/_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index 819427a1d6..1e958d2427 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -16,7 +16,7 @@ from anyio.to_thread import run_sync from pydantic import BaseModel, TypeAdapter from pydantic.json_schema import JsonSchemaValue -from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict +from typing_extensions import ParamSpec, TypeAlias, TypeGuard, TypeIs, is_typeddict from pydantic_graph._utils import AbstractSpan @@ -310,11 +310,11 @@ def number_to_datetime(x: int | float) -> datetime: @overload -def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ... +def is_async_callable(obj: AwaitableCallable[T]) -> TypeIs[AwaitableCallable[T]]: ... @overload -def is_async_callable(obj: Any) -> TypeGuard[AwaitableCallable[Any]]: ... +def is_async_callable(obj: Any) -> TypeIs[AwaitableCallable[Any]]: ... def is_async_callable(obj: Any) -> Any: From eac089f268e7d3ccd7c4a24fcdc43f4360e11f8d Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 13 Jun 2025 15:19:51 +0200 Subject: [PATCH 5/5] add more ignore --- tests/test_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index c703e2b8f1..e62e8b6040 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -164,16 +164,16 @@ async def test_run_in_executor_with_contextvars() -> None: def test_is_async_callable(): - def sync_func(): ... + def sync_func(): ... # pragma: no branch assert is_async_callable(sync_func) is False - async def async_func(): ... + async def async_func(): ... # pragma: no branch assert is_async_callable(async_func) is True class AsyncCallable: - async def __call__(self): ... + async def __call__(self): ... # pragma: no branch partial_async_callable = functools.partial(AsyncCallable()) assert is_async_callable(partial_async_callable) is True