Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore
from pydantic_ai._instrumentation import DEFAULT_INSTRUMENTATION_VERSION
from pydantic_ai._tool_manager import ToolManager
from pydantic_ai._tool_manager import ToolManager, build_validation_context
from pydantic_ai._utils import dataclasses_no_defaults_repr, get_union_args, is_async_callable, run_in_executor
from pydantic_ai.builtin_tools import AbstractBuiltinTool
from pydantic_graph import BaseNode, GraphRunContext
Expand Down Expand Up @@ -590,7 +590,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa
text = '' # pragma: no cover
if text:
try:
self._next_node = await self._handle_text_response(ctx, text, text_processor)
self._next_node = await self._handle_text_response(
ctx, ctx.deps.tool_manager.validation_ctx, text, text_processor
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we already pass in ctx, we don't need a new arg do we?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to store the validation context directly on ctx.deps, instead of going through the tool manager?

)
return
except ToolRetryError:
# If the text from the preview response was invalid, ignore it.
Expand Down Expand Up @@ -654,7 +656,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa

if text_processor := output_schema.text_processor:
if text:
self._next_node = await self._handle_text_response(ctx, text, text_processor)
self._next_node = await self._handle_text_response(
ctx, ctx.deps.tool_manager.validation_ctx, text, text_processor
)
return
alternatives.insert(0, 'return text')

Expand Down Expand Up @@ -716,12 +720,14 @@ async def _handle_tool_calls(
async def _handle_text_response(
self,
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
validation_ctx: Any | Callable[[RunContext[DepsT]], Any],
text: str,
text_processor: _output.BaseOutputProcessor[NodeRunEndT],
) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
run_context = build_run_context(ctx)
validation_context = build_validation_context(validation_ctx, run_context)

result_data = await text_processor.process(text, run_context)
result_data = await text_processor.process(text, run_context, validation_context)

for validator in ctx.deps.output_validators:
result_data = await validator.validate(result_data, run_context)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally, output validators would have access to the validation context as well, as they could call model_validate themselves. They already have access to RunContext, so maybe it'd make sense to store the validation context on there? As the validation context callable itself needs RunContext, building run context could look like ctx = <run context>; validation_ctx = callable(ctx); ctx = replace(ctx, validation_ctx = validation_ctx)

I think that'd be a refactor worth exploring that could allow us to drop a lot of the new arguments.

Expand Down
38 changes: 31 additions & 7 deletions pydantic_ai_slim/pydantic_ai/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ async def process(
self,
data: str,
run_context: RunContext[AgentDepsT],
validation_context: Any | None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's have a default of None

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also while we're at it, let's require kwargs for everything after data, by inserting , *,

allow_partial: bool = False,
wrap_validation_errors: bool = True,
) -> OutputDataT:
Expand All @@ -554,13 +555,18 @@ async def process(
self,
data: str,
run_context: RunContext[AgentDepsT],
validation_context: Any | None,
allow_partial: bool = False,
wrap_validation_errors: bool = True,
) -> OutputDataT:
text = _utils.strip_markdown_fences(data)

return await self.wrapped.process(
text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
text,
run_context,
validation_context,
allow_partial=allow_partial,
wrap_validation_errors=wrap_validation_errors,
)


Expand Down Expand Up @@ -639,6 +645,7 @@ async def process(
self,
data: str | dict[str, Any] | None,
run_context: RunContext[AgentDepsT],
validation_context: Any | None,
allow_partial: bool = False,
wrap_validation_errors: bool = True,
) -> OutputDataT:
Expand All @@ -647,14 +654,15 @@ async def process(
Args:
data: The output data to validate.
run_context: The current run context.
validation_context: Additional Pydantic validation context for the current run.
allow_partial: If true, allow partial validation.
wrap_validation_errors: If true, wrap the validation errors in a retry message.

Returns:
Either the validated output data (left) or a retry message (right).
"""
try:
output = self.validate(data, allow_partial)
output = self.validate(data, allow_partial, validation_context)
except ValidationError as e:
if wrap_validation_errors:
m = _messages.RetryPromptPart(
Expand All @@ -672,12 +680,17 @@ def validate(
self,
data: str | dict[str, Any] | None,
allow_partial: bool = False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as up, let's require kwargs for everything but data

validation_context: Any | None = None,
) -> dict[str, Any]:
pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
if isinstance(data, str):
return self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
return self.validator.validate_json(
data or '{}', allow_partial=pyd_allow_partial, context=validation_context
)
else:
return self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
return self.validator.validate_python(
data or {}, allow_partial=pyd_allow_partial, context=validation_context
)

async def call(
self,
Expand Down Expand Up @@ -797,11 +810,16 @@ async def process(
self,
data: str,
run_context: RunContext[AgentDepsT],
validation_context: Any | None = None,
allow_partial: bool = False,
wrap_validation_errors: bool = True,
) -> OutputDataT:
union_object = await self._union_processor.process(
data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
data,
run_context,
validation_context,
allow_partial=allow_partial,
wrap_validation_errors=wrap_validation_errors,
)

result = union_object.result
Expand All @@ -817,7 +835,11 @@ async def process(
raise

return await processor.process(
inner_data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
inner_data,
run_context,
validation_context,
allow_partial=allow_partial,
wrap_validation_errors=wrap_validation_errors,
)


Expand All @@ -826,6 +848,7 @@ async def process(
self,
data: str,
run_context: RunContext[AgentDepsT],
validation_context: Any | None = None,
allow_partial: bool = False,
wrap_validation_errors: bool = True,
) -> OutputDataT:
Expand Down Expand Up @@ -857,13 +880,14 @@ async def process(
self,
data: str,
run_context: RunContext[AgentDepsT],
validation_context: Any | None = None,
allow_partial: bool = False,
wrap_validation_errors: bool = True,
) -> OutputDataT:
args = {self._str_argument_name: data}
data = await execute_traced_output_function(self._function_schema, run_context, args, wrap_validation_errors)

return await super().process(data, run_context, allow_partial, wrap_validation_errors)
return await super().process(data, run_context, validation_context, allow_partial, wrap_validation_errors)


@dataclass(init=False)
Expand Down
28 changes: 24 additions & 4 deletions pydantic_ai_slim/pydantic_ai/_tool_manager.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

import json
from collections.abc import Iterator
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field, replace
from typing import Any, Generic
from typing import Any, Generic, cast

from opentelemetry.trace import Tracer
from pydantic import ValidationError
Expand All @@ -31,6 +31,8 @@ class ToolManager(Generic[AgentDepsT]):
"""The toolset that provides the tools for this run step."""
ctx: RunContext[AgentDepsT] | None = None
"""The agent run context for a specific run step."""
validation_ctx: Any | Callable[[RunContext[AgentDepsT]], Any] = None
"""Additional Pydantic validation context for the run."""
tools: dict[str, ToolsetTool[AgentDepsT]] | None = None
"""The cached tools for this run step."""
failed_tools: set[str] = field(default_factory=set)
Expand Down Expand Up @@ -61,6 +63,7 @@ async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDe
return self.__class__(
toolset=self.toolset,
ctx=ctx,
validation_ctx=self.validation_ctx,
tools=await self.toolset.get_tools(ctx),
)

Expand Down Expand Up @@ -161,12 +164,18 @@ async def _call_tool(
partial_output=allow_partial,
)

validation_ctx = build_validation_context(self.validation_ctx, self.ctx)

pyd_allow_partial = 'trailing-strings' if allow_partial else 'off'
validator = tool.args_validator
if isinstance(call.args, str):
args_dict = validator.validate_json(call.args or '{}', allow_partial=pyd_allow_partial)
args_dict = validator.validate_json(
call.args or '{}', allow_partial=pyd_allow_partial, context=validation_ctx
)
else:
args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial)
args_dict = validator.validate_python(
call.args or {}, allow_partial=pyd_allow_partial, context=validation_ctx
)

result = await self.toolset.call_tool(name, args_dict, ctx, tool)

Expand Down Expand Up @@ -270,3 +279,14 @@ async def _call_function_tool(
)

return tool_result


def build_validation_context(
validation_ctx: Any | Callable[[RunContext[AgentDepsT]], Any], run_context: RunContext[AgentDepsT]
) -> Any:
"""Build a Pydantic validation context, potentially from the current agent run context."""
if callable(validation_ctx):
fn = cast(Callable[[RunContext[AgentDepsT]], Any], validation_ctx)
return fn(run_context)
else:
return validation_ctx
9 changes: 8 additions & 1 deletion pydantic_ai_slim/pydantic_ai/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
_prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
_max_result_retries: int = dataclasses.field(repr=False)
_max_tool_retries: int = dataclasses.field(repr=False)
_validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = dataclasses.field(repr=False)

_event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False)

Expand All @@ -166,6 +167,7 @@ def __init__(
name: str | None = None,
model_settings: ModelSettings | None = None,
retries: int = 1,
validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
output_retries: int | None = None,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
builtin_tools: Sequence[AbstractBuiltinTool] = (),
Expand All @@ -192,6 +194,7 @@ def __init__(
name: str | None = None,
model_settings: ModelSettings | None = None,
retries: int = 1,
validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
output_retries: int | None = None,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
builtin_tools: Sequence[AbstractBuiltinTool] = (),
Expand All @@ -216,6 +219,7 @@ def __init__(
name: str | None = None,
model_settings: ModelSettings | None = None,
retries: int = 1,
validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
output_retries: int | None = None,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
builtin_tools: Sequence[AbstractBuiltinTool] = (),
Expand Down Expand Up @@ -249,6 +253,7 @@ def __init__(
model_settings: Optional model request settings to use for this agent's runs, by default.
retries: The default number of retries to allow for tool calls and output validation, before raising an error.
For model request retries, see the [HTTP Request Retries](../retries.md) documentation.
validation_context: Additional validation context used to validate all outputs.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's link to the Pydantic doc

output_retries: The maximum number of retries to allow for output validation, defaults to `retries`.
tools: Tools to register with the agent, you can also register tools via the decorators
[`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
Expand Down Expand Up @@ -314,6 +319,8 @@ def __init__(
self._max_result_retries = output_retries if output_retries is not None else retries
self._max_tool_retries = retries

self._validation_context = validation_context

self._builtin_tools = builtin_tools

self._prepare_tools = prepare_tools
Expand Down Expand Up @@ -562,7 +569,7 @@ async def main():
output_toolset.max_retries = self._max_result_retries
output_toolset.output_validators = output_validators
toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)
tool_manager = ToolManager[AgentDepsT](toolset)
tool_manager = ToolManager[AgentDepsT](toolset, validation_ctx=self._validation_context)

# Build the graph
graph = _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_)
Expand Down
6 changes: 4 additions & 2 deletions pydantic_ai_slim/pydantic_ai/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
TextOutputSchema,
)
from ._run_context import AgentDepsT, RunContext
from ._tool_manager import ToolManager
from ._tool_manager import ToolManager, build_validation_context
from .messages import ModelResponseStreamEvent
from .output import (
DeferredToolRequests,
Expand Down Expand Up @@ -197,8 +197,10 @@ async def validate_response_output(
# not part of the final result output, so we reset the accumulated text
text = ''

validation_context = build_validation_context(self._tool_manager.validation_ctx, self._run_ctx)

result_data = await text_processor.process(
text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
text, self._run_ctx, validation_context, allow_partial=allow_partial, wrap_validation_errors=False
)
for validator in self._output_validators:
result_data = await validator.validate(
Expand Down
Loading