Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
57e568b
Add `builtin_tools` to `Agent`
Kludex May 14, 2025
97ab44b
make AbstractBuiltinTool serializable
Kludex May 14, 2025
e3dda9d
Add more work on it
Kludex May 14, 2025
3ad6d38
Merge remote-tracking branch 'origin/main' into add-builtin-tools
Kludex May 23, 2025
0b43f65
Add builtin tools
Kludex May 23, 2025
fa7fd11
merge
Kludex May 26, 2025
32324fa
add more built-in-tools
Kludex May 27, 2025
f33e568
Fix test
Kludex May 27, 2025
13d7433
Add support on Groq
Kludex May 27, 2025
ac85205
Add support for Google
Kludex May 28, 2025
c93633f
Add support for MCP's Streamable HTTP transport (#1716)
BrandonShar May 26, 2025
3a8b640
Timeout for initializing MCP client (#1833)
alexmojaki May 27, 2025
360de87
Require mcp 1.9.0+ (#1840)
DouweM May 27, 2025
cb4e539
Don't send empty messages to Anthropic (#1027)
oscar-broman May 27, 2025
4e3769a
Add `vendor_id` and `finish_reason` to Gemini/Google model responses …
davide-andreoli May 27, 2025
ebb536f
Fix units of `sse_read_timeout` `timedelta` (#1843)
alexmojaki May 27, 2025
c8bb611
Support functions as output_type, as well as lists of functions and o…
DouweM May 27, 2025
6bcc1a8
Enhance Gemini usage tracking to collect comprehensive token data (#1…
kiqaps May 28, 2025
97ff651
more
Kludex May 30, 2025
1d47e1e
merge
Kludex May 30, 2025
5f89444
merge
Kludex Jun 1, 2025
fe08bbf
updates
mattbrandman Jun 5, 2025
06a8f5b
updates
mattbrandman Jun 5, 2025
fb896fe
updates
mattbrandman Jun 6, 2025
1dbff81
Merge pull request #1 from mattbrandman/new-tests
mattbrandman Jun 8, 2025
a434fe5
Merge branch 'main' into add-builtin-tools
mattbrandman Jun 8, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,6 @@ repos:
rev: v2.3.0
hooks:
- id: codespell
args: ['--skip', 'tests/models/cassettes/*,docs/a2a/fasta2a.md']
args: ['--skip', 'tests/models/cassettes/*,docs/a2a/fasta2a.md,tests/models/test_groq.py']
additional_dependencies:
- tomli
9 changes: 8 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from opentelemetry.trace import Tracer
from typing_extensions import TypeGuard, TypeVar, assert_never

from pydantic_ai.builtin_tools import AbstractBuiltinTool
from pydantic_graph import BaseNode, Graph, GraphRunContext
from pydantic_graph.nodes import End, NodeRunEndT

Expand Down Expand Up @@ -94,6 +95,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]

function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
builtin_tools: list[AbstractBuiltinTool] = dataclasses.field(repr=False)
mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
default_retries: int

Expand Down Expand Up @@ -262,6 +264,7 @@ async def add_mcp_server_tools(server: MCPServer) -> None:
output_schema = ctx.deps.output_schema
return models.ModelRequestParameters(
function_tools=function_tool_defs,
builtin_tools=ctx.deps.builtin_tools,
allow_text_output=_output.allow_text_output(output_schema),
output_tools=output_schema.tool_defs() if output_schema is not None else [],
)
Expand Down Expand Up @@ -414,7 +417,7 @@ async def stream(
async for _event in stream:
pass

async def _run_stream(
async def _run_stream( # noqa C901
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
) -> AsyncIterator[_messages.HandleResponseEvent]:
if self._events_iterator is None:
Expand All @@ -430,6 +433,10 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
texts.append(part.content)
elif isinstance(part, _messages.ToolCallPart):
tool_calls.append(part)
elif isinstance(part, _messages.ServerToolCallPart):
yield _messages.ServerToolCallEvent(part)
elif isinstance(part, _messages.ServerToolReturnPart):
yield _messages.ServerToolResultEvent(part)
else:
assert_never(part)

Expand Down
98 changes: 93 additions & 5 deletions pydantic_ai_slim/pydantic_ai/_parts_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
ModelResponseStreamEvent,
PartDeltaEvent,
PartStartEvent,
ServerToolCallPart,
ServerToolCallPartDelta,
TextPart,
TextPartDelta,
ToolCallPart,
Expand All @@ -36,11 +38,11 @@
Type alias for a vendor identifier, which can be any hashable type (e.g., a string, UUID, etc.)
"""

ManagedPart = Union[ModelResponsePart, ToolCallPartDelta]
ManagedPart = Union[ModelResponsePart, ToolCallPartDelta, ServerToolCallPartDelta]
"""
A union of types that are managed by the ModelResponsePartsManager.
Because many vendors have streaming APIs that may produce not-fully-formed tool calls,
this includes ToolCallPartDelta's in addition to the more fully-formed ModelResponsePart's.
this includes ToolCallPartDelta's and ServerToolCallPartDelta's in addition to the more fully-formed ModelResponsePart's.
"""


Expand All @@ -57,12 +59,12 @@ class ModelResponsePartsManager:
"""Maps a vendor's "part" ID (if provided) to the index in `_parts` where that part resides."""

def get_parts(self) -> list[ModelResponsePart]:
"""Return only model response parts that are complete (i.e., not ToolCallPartDelta's).
"""Return only model response parts that are complete (i.e., not ToolCallPartDelta's or ServerToolCallPartDelta's).

Returns:
A list of ModelResponsePart objects. ToolCallPartDelta objects are excluded.
A list of ModelResponsePart objects. ToolCallPartDelta and ServerToolCallPartDelta objects are excluded.
"""
return [p for p in self._parts if not isinstance(p, ToolCallPartDelta)]
return [p for p in self._parts if not isinstance(p, (ToolCallPartDelta, ServerToolCallPartDelta))]

def handle_text_delta(
self,
Expand Down Expand Up @@ -245,3 +247,89 @@ def handle_tool_call_part(
self._parts.append(new_part)
self._vendor_id_to_part_index[vendor_part_id] = new_part_index
return PartStartEvent(index=new_part_index, part=new_part)

def handle_server_tool_call_delta(
    self,
    *,
    vendor_part_id: Hashable | None,
    tool_name: str | None,
    args: str | dict[str, Any] | None,
    tool_call_id: str | None,
    model_name: str | None,
) -> ModelResponseStreamEvent | None:
    """Merge a streamed server tool call fragment into the managed parts.

    The fragment is wrapped in a `ServerToolCallPartDelta`. It is either applied to an existing
    `ServerToolCallPart` / `ServerToolCallPartDelta`, or stored as a new entry. A new entry is stored
    as a full `ServerToolCallPart` when `as_part()` can already produce one, and as the bare delta
    otherwise.

    How the target is chosen:
        - With a `vendor_part_id`, the part registered under that ID (if any) is the target.
        - Without one, only a fragment that has no `tool_name` is merged, and only into the most
          recent managed part when that part is a server tool call (or delta). A fragment that
          carries a `tool_name` always starts a new entry.

    Args:
        vendor_part_id: The ID the vendor uses for this server tool call, or None.
        tool_name: The (partial) name of the server tool, or None.
        args: The (partial) arguments, as a string, a dictionary, or None.
        tool_call_id: An optional identifier for this server tool call.
        model_name: An optional name of the model that generated this server tool call.

    Returns:
        - A `PartStartEvent` when a full `ServerToolCallPart` first becomes available, either
          because the new entry is already complete or because a stored delta was just upgraded.
        - A `PartDeltaEvent` when an already complete `ServerToolCallPart` was updated.
        - `None` when the managed entry is still only a `ServerToolCallPartDelta`.

    Raises:
        UnexpectedModelBehavior: If `vendor_part_id` refers to a managed part that is neither a
            `ServerToolCallPart` nor a `ServerToolCallPartDelta`.
    """
    matched_part: ServerToolCallPartDelta | ServerToolCallPart | None = None
    matched_index: int | None = None

    if vendor_part_id is not None:
        # The vendor named the part: if it is already known it has to be a server tool call (or delta).
        known_index = self._vendor_id_to_part_index.get(vendor_part_id)
        if known_index is not None:
            existing_part = self._parts[known_index]
            if not isinstance(existing_part, (ServerToolCallPartDelta, ServerToolCallPart)):
                raise UnexpectedModelBehavior(f'Cannot apply a server tool call delta to {existing_part=}')
            matched_part, matched_index = existing_part, known_index
    elif tool_name is None and self._parts:
        # No vendor ID: a fragment with a tool name is assumed to open a new part, so only nameless
        # fragments are candidates for merging, and only into the latest part.
        last_index = len(self._parts) - 1
        latest_part = self._parts[last_index]
        if isinstance(latest_part, (ServerToolCallPart, ServerToolCallPartDelta)):  # pragma: no branch
            matched_part, matched_index = latest_part, last_index

    delta = ServerToolCallPartDelta(
        tool_name_delta=tool_name, args_delta=args, tool_call_id=tool_call_id, model_name=model_name
    )

    if matched_part is None or matched_index is None:
        # Nothing to merge into: store a full part if the fragment is already complete, else the delta.
        part = delta.as_part() or delta
        new_part_index = len(self._parts)
        if vendor_part_id is not None:
            self._vendor_id_to_part_index[vendor_part_id] = new_part_index
        self._parts.append(part)
        # A start event is only meaningful once there is a full ServerToolCallPart to announce.
        if isinstance(part, ServerToolCallPart):
            return PartStartEvent(index=new_part_index, part=part)
        return None

    updated_part = delta.apply(matched_part)
    self._parts[matched_index] = updated_part
    if not isinstance(updated_part, ServerToolCallPart):
        # Still incomplete, so there is nothing to report yet.
        return None
    if isinstance(matched_part, ServerToolCallPartDelta):
        # The stored delta has just become a full part: announce it as a new part.
        return PartStartEvent(index=matched_index, part=updated_part)
    # An already announced part changed; make sure the emitted delta carries the part's tool call ID.
    if updated_part.tool_call_id and not delta.tool_call_id:
        delta = replace(delta, tool_call_id=updated_part.tool_call_id)
    return PartDeltaEvent(index=matched_index, delta=delta)
8 changes: 7 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,13 @@ def now_utc() -> datetime:
return datetime.now(tz=timezone.utc)


def guard_tool_call_id(t: _messages.ToolCallPart | _messages.ToolReturnPart | _messages.RetryPromptPart) -> str:
def guard_tool_call_id(
    t: _messages.ToolCallPart
    | _messages.ToolReturnPart
    | _messages.RetryPromptPart
    | _messages.ServerToolCallPart
    | _messages.ServerToolReturnPart,
) -> str:
    """Return the part's tool call ID, generating a new one when the part has none (None or empty)."""
    current_id = t.tool_call_id
    if current_id:
        return current_id
    return generate_tool_call_id()

Expand Down
17 changes: 16 additions & 1 deletion pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pydantic.json_schema import GenerateJsonSchema
from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated

from pydantic_ai.builtin_tools import AbstractBuiltinTool, WebSearchTool
from pydantic_graph import End, Graph, GraphRun, GraphRunContext
from pydantic_graph._utils import get_event_loop

Expand Down Expand Up @@ -174,6 +175,7 @@ def __init__(
retries: int = 1,
output_retries: int | None = None,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
builtin_tools: Sequence[Literal['web-search'] | AbstractBuiltinTool] = (),
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
mcp_servers: Sequence[MCPServer] = (),
defer_model_check: bool = False,
Expand Down Expand Up @@ -203,6 +205,7 @@ def __init__(
result_tool_description: str | None = None,
result_retries: int | None = None,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
builtin_tools: Sequence[Literal['web-search'] | AbstractBuiltinTool] = (),
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
mcp_servers: Sequence[MCPServer] = (),
defer_model_check: bool = False,
Expand All @@ -227,6 +230,7 @@ def __init__(
retries: int = 1,
output_retries: int | None = None,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
builtin_tools: Sequence[Literal['web-search'] | AbstractBuiltinTool] = (),
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
mcp_servers: Sequence[MCPServer] = (),
defer_model_check: bool = False,
Expand Down Expand Up @@ -256,6 +260,8 @@ def __init__(
output_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
tools: Tools to register with the agent, you can also register tools via the decorators
[`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
builtin_tools: The builtin tools that the agent will use. This depends on the model, as some models may not
support certain tools. On models that don't support certain tools, the tool will be ignored.
prepare_tools: custom method to prepare the tool definition of all tools for each step.
This is useful if you want to customize the definition of multiple tools or you want to register
a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc]
Expand Down Expand Up @@ -342,6 +348,14 @@ def __init__(
self._default_retries = retries
self._max_result_retries = output_retries if output_retries is not None else retries
self._mcp_servers = mcp_servers
self._builtin_tools: list[AbstractBuiltinTool] = []

for tool in builtin_tools:
if tool == 'web-search':
self._builtin_tools.append(WebSearchTool())
else:
self._builtin_tools.append(tool)

self._prepare_tools = prepare_tools
for tool in tools:
if isinstance(tool, Tool):
Expand Down Expand Up @@ -689,7 +703,8 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
end_strategy=self.end_strategy,
output_schema=output_schema,
output_validators=output_validators,
function_tools=run_function_tools,
function_tools=self._function_tools,
builtin_tools=self._builtin_tools,
mcp_servers=self._mcp_servers,
default_retries=self._default_retries,
tracer=tracer,
Expand Down
95 changes: 95 additions & 0 deletions pydantic_ai_slim/pydantic_ai/builtin_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from __future__ import annotations as _annotations

from abc import ABC
from dataclasses import dataclass
from typing import Any, Literal

from typing_extensions import TypedDict

# Public API of this module. `CodeExecutionTool` is listed alongside the other builtin tool classes so
# that star-imports and API tooling see every public class defined here.
__all__ = ('AbstractBuiltinTool', 'WebSearchTool', 'UserLocation', 'CodeExecutionTool')


@dataclass
class AbstractBuiltinTool(ABC):
    """Base class for the builtin tools that can be used by an agent.

    This class is meant to be subclassed rather than instantiated directly. Note that it declares no
    abstract methods, so Python itself does not prevent instantiation.

    The builtin tools are passed to the model as part of the `ModelRequestParameters`.
    """

    # Hook that receives a model name. The base implementation has an empty body and returns `None`.
    # NOTE(review): presumably subclasses override this to build a model-specific tool definition — confirm.
    def handle_custom_tool_definition(self, model: str) -> Any: ...


@dataclass
class WebSearchTool(AbstractBuiltinTool):
    """A builtin tool that allows your agent to search the web for information.

    The parameters that PydanticAI passes depend on the model, as some parameters may not be supported by certain models.

    Every field has a default, so `WebSearchTool()` is a valid configuration.
    """

    search_context_size: Literal['low', 'medium', 'high'] = 'medium'
    """The `search_context_size` parameter controls how much context is retrieved from the web to help the tool formulate a response.

    Supported by:
    * OpenAI
    """

    user_location: UserLocation | None = None
    """The `user_location` parameter allows you to localize search results based on a user's location.

    Supported by:
    * Anthropic
    * OpenAI
    """

    blocked_domains: list[str] | None = None
    """If provided, these domains will never appear in results.

    With Anthropic, you can only use one of `blocked_domains` or `allowed_domains`, not both.

    Supported by:
    * Anthropic (https://docs.anthropic.com/en/docs/build-with-claude/tool-use/web-search-tool#domain-filtering)
    * Groq (https://console.groq.com/docs/agentic-tooling#search-settings)
    * MistralAI
    """

    allowed_domains: list[str] | None = None
    """If provided, only these domains will be included in results.

    With Anthropic, you can only use one of `blocked_domains` or `allowed_domains`, not both.

    Supported by:
    * Anthropic (https://docs.anthropic.com/en/docs/build-with-claude/tool-use/web-search-tool#domain-filtering)
    * Groq (https://console.groq.com/docs/agentic-tooling#search-settings)
    """

    max_uses: int | None = None
    """If provided, the tool will stop searching the web after the given number of uses.

    Supported by:
    * Anthropic
    """


class UserLocation(TypedDict, total=False):
    """Allows you to localize search results based on a user's location.

    All keys are optional (`total=False`), so any subset of them may be supplied.

    Supported by:
    * Anthropic
    * OpenAI
    """

    # NOTE(review): the expected value formats (e.g. country codes, timezone names) are not shown
    # here — confirm against each provider's documentation before relying on a specific format.
    city: str
    country: str
    region: str
    timezone: str


@dataclass
class CodeExecutionTool(AbstractBuiltinTool):
    """A builtin tool that allows your agent to execute code.

    The tool currently declares no configuration fields. It is decorated with `@dataclass`, like the
    other builtin tools, so that any field added later becomes part of `__init__`, `__eq__` and
    `__repr__` instead of being silently ignored.

    Supported by:
    * Anthropic
    * OpenAI
    """
Loading