From 8404e202c3577111b752d9a960a60e8016c3030a Mon Sep 17 00:00:00 2001 From: NicolasPllr1 Date: Sun, 16 Nov 2025 17:28:14 +0100 Subject: [PATCH 1/7] Pass Pydantic validation context to agents (#3381) --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 14 +++++-- pydantic_ai_slim/pydantic_ai/_output.py | 38 +++++++++++++++---- pydantic_ai_slim/pydantic_ai/_tool_manager.py | 28 ++++++++++++-- .../pydantic_ai/agent/__init__.py | 9 ++++- pydantic_ai_slim/pydantic_ai/result.py | 6 ++- 5 files changed, 77 insertions(+), 18 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 91cda373a5..7341dbfc57 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -18,7 +18,7 @@ from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore from pydantic_ai._instrumentation import DEFAULT_INSTRUMENTATION_VERSION -from pydantic_ai._tool_manager import ToolManager +from pydantic_ai._tool_manager import ToolManager, build_validation_context from pydantic_ai._utils import dataclasses_no_defaults_repr, get_union_args, is_async_callable, run_in_executor from pydantic_ai.builtin_tools import AbstractBuiltinTool from pydantic_graph import BaseNode, GraphRunContext @@ -590,7 +590,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa text = '' # pragma: no cover if text: try: - self._next_node = await self._handle_text_response(ctx, text, text_processor) + self._next_node = await self._handle_text_response( + ctx, ctx.deps.tool_manager.validation_ctx, text, text_processor + ) return except ToolRetryError: # If the text from the preview response was invalid, ignore it. @@ -654,7 +656,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa if text_processor := output_schema.text_processor: if text: - self._next_node = await self._handle_text_response(ctx, text, text_processor) + self._next_node = await self._handle_text_response( + ctx, ctx.deps.tool_manager.validation_ctx, text, text_processor + ) return alternatives.insert(0, 'return text') @@ -716,12 +720,14 @@ async def _handle_tool_calls( async def _handle_text_response( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], + validation_ctx: Any | Callable[[RunContext[DepsT]], Any], text: str, text_processor: _output.BaseOutputProcessor[NodeRunEndT], ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]: run_context = build_run_context(ctx) + validation_context = build_validation_context(validation_ctx, run_context) - result_data = await text_processor.process(text, run_context) + result_data = await text_processor.process(text, run_context, validation_context) for validator in ctx.deps.output_validators: result_data = await validator.validate(result_data, run_context) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index ebb737a1cf..77a073ffc9 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -530,6 +530,7 @@ async def process( self, data: str, run_context: RunContext[AgentDepsT], + validation_context: Any | None, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -554,13 +555,18 @@ async def process( self, data: str, run_context: RunContext[AgentDepsT], + validation_context: Any | None, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: text = _utils.strip_markdown_fences(data) return await self.wrapped.process( - text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + text, + run_context, + validation_context, + allow_partial=allow_partial, + wrap_validation_errors=wrap_validation_errors, ) @@ -639,6 +645,7 @@ async def process( self, data: str | dict[str, Any] | None, run_context: RunContext[AgentDepsT], + validation_context: Any | None, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -647,6 +654,7 @@ async def process( Args: data: The output data to validate. run_context: The current run context. + validation_context: Additional Pydantic validation context for the current run. allow_partial: If true, allow partial validation. wrap_validation_errors: If true, wrap the validation errors in a retry message. @@ -654,7 +662,7 @@ async def process( Either the validated output data (left) or a retry message (right). """ try: - output = self.validate(data, allow_partial) + output = self.validate(data, allow_partial, validation_context) except ValidationError as e: if wrap_validation_errors: m = _messages.RetryPromptPart( @@ -672,12 +680,17 @@ def validate( self, data: str | dict[str, Any] | None, allow_partial: bool = False, + validation_context: Any | None = None, ) -> dict[str, Any]: pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' if isinstance(data, str): - return self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial) + return self.validator.validate_json( + data or '{}', allow_partial=pyd_allow_partial, context=validation_context + ) else: - return self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial) + return self.validator.validate_python( + data or {}, allow_partial=pyd_allow_partial, context=validation_context + ) async def call( self, @@ -797,11 +810,16 @@ async def process( self, data: str, run_context: RunContext[AgentDepsT], + validation_context: Any | None = None, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: union_object = await self._union_processor.process( - data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + data, + run_context, + validation_context, + allow_partial=allow_partial, + wrap_validation_errors=wrap_validation_errors, ) result = union_object.result @@ -817,7 +835,11 @@ async def process( raise return await processor.process( - inner_data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + inner_data, + run_context, + validation_context, + allow_partial=allow_partial, + wrap_validation_errors=wrap_validation_errors, ) @@ -826,6 +848,7 @@ async def process( self, data: str, run_context: RunContext[AgentDepsT], + validation_context: Any | None = None, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -857,13 +880,14 @@ async def process( self, data: str, run_context: RunContext[AgentDepsT], + validation_context: Any | None = None, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: args = {self._str_argument_name: data} data = await execute_traced_output_function(self._function_schema, run_context, args, wrap_validation_errors) - return await super().process(data, run_context, allow_partial, wrap_validation_errors) + return await super().process(data, run_context, validation_context, allow_partial, wrap_validation_errors) @dataclass(init=False) diff --git a/pydantic_ai_slim/pydantic_ai/_tool_manager.py b/pydantic_ai_slim/pydantic_ai/_tool_manager.py index fb7039e2cc..5dede17d6e 100644 --- a/pydantic_ai_slim/pydantic_ai/_tool_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_tool_manager.py @@ -1,11 +1,11 @@ from __future__ import annotations import json -from collections.abc import Iterator +from collections.abc import Callable, Iterator from contextlib import contextmanager from contextvars import ContextVar from dataclasses import dataclass, field, replace -from typing import Any, Generic +from typing import Any, Generic, cast from opentelemetry.trace import Tracer from pydantic import ValidationError @@ -31,6 +31,8 @@ class ToolManager(Generic[AgentDepsT]): """The toolset that provides the tools for this run step.""" ctx: RunContext[AgentDepsT] | None = None """The agent run context for a specific run step.""" + validation_ctx: Any | Callable[[RunContext[AgentDepsT]], Any] = None + """Additional Pydantic validation context for the run.""" tools: dict[str, ToolsetTool[AgentDepsT]] | None = None """The cached tools for this run step.""" failed_tools: set[str] = field(default_factory=set) @@ -61,6 +63,7 @@ async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDe return self.__class__( toolset=self.toolset, ctx=ctx, + validation_ctx=self.validation_ctx, tools=await self.toolset.get_tools(ctx), ) @@ -161,12 +164,18 @@ async def _call_tool( partial_output=allow_partial, ) + validation_ctx = build_validation_context(self.validation_ctx, self.ctx) + pyd_allow_partial = 'trailing-strings' if allow_partial else 'off' validator = tool.args_validator if isinstance(call.args, str): - args_dict = validator.validate_json(call.args or '{}', allow_partial=pyd_allow_partial) + args_dict = validator.validate_json( + call.args or '{}', allow_partial=pyd_allow_partial, context=validation_ctx + ) else: - args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial) + args_dict = validator.validate_python( + call.args or {}, allow_partial=pyd_allow_partial, context=validation_ctx + ) result = await self.toolset.call_tool(name, args_dict, ctx, tool) @@ -270,3 +279,14 @@ async def _call_function_tool( ) return tool_result + + +def build_validation_context( + validation_ctx: Any | Callable[[RunContext[AgentDepsT]], Any], run_context: RunContext[AgentDepsT] +) -> Any: + """Build a Pydantic validation context, potentially from the current agent run context.""" + if callable(validation_ctx): + fn = cast(Callable[[RunContext[AgentDepsT]], Any], validation_ctx) + return fn(run_context) + else: + return validation_ctx diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 4cd353b44a..3d16a64488 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -147,6 +147,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]): _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _max_result_retries: int = dataclasses.field(repr=False) _max_tool_retries: int = dataclasses.field(repr=False) + _validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = dataclasses.field(repr=False) _event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False) @@ -166,6 +167,7 @@ def __init__( name: str | None = None, model_settings: ModelSettings | None = None, retries: int = 1, + validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None, output_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), builtin_tools: Sequence[AbstractBuiltinTool] = (), @@ -192,6 +194,7 @@ def __init__( name: str | None = None, model_settings: ModelSettings | None = None, retries: int = 1, + validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None, output_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), builtin_tools: Sequence[AbstractBuiltinTool] = (), @@ -216,6 +219,7 @@ def __init__( name: str | None = None, model_settings: ModelSettings | None = None, retries: int = 1, + validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None, output_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), builtin_tools: Sequence[AbstractBuiltinTool] = (), @@ -249,6 +253,7 @@ def __init__( model_settings: Optional model request settings to use for this agent's runs, by default. retries: The default number of retries to allow for tool calls and output validation, before raising an error. For model request retries, see the [HTTP Request Retries](../retries.md) documentation. + validation_context: Additional validation context used to validate all outputs. output_retries: The maximum number of retries to allow for output validation, defaults to `retries`. tools: Tools to register with the agent, you can also register tools via the decorators [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain]. @@ -314,6 +319,8 @@ def __init__( self._max_result_retries = output_retries if output_retries is not None else retries self._max_tool_retries = retries + self._validation_context = validation_context + self._builtin_tools = builtin_tools self._prepare_tools = prepare_tools @@ -562,7 +569,7 @@ async def main(): output_toolset.max_retries = self._max_result_retries output_toolset.output_validators = output_validators toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets) - tool_manager = ToolManager[AgentDepsT](toolset) + tool_manager = ToolManager[AgentDepsT](toolset, validation_ctx=self._validation_context) # Build the graph graph = _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_) diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index c6b59ec796..140c8304b0 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -18,7 +18,7 @@ TextOutputSchema, ) from ._run_context import AgentDepsT, RunContext -from ._tool_manager import ToolManager +from ._tool_manager import ToolManager, build_validation_context from .messages import ModelResponseStreamEvent from .output import ( DeferredToolRequests, @@ -197,8 +197,10 @@ async def validate_response_output( # not part of the final result output, so we reset the accumulated text text = '' + validation_context = build_validation_context(self._tool_manager.validation_ctx, self._run_ctx) + result_data = await text_processor.process( - text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + text, self._run_ctx, validation_context, allow_partial=allow_partial, wrap_validation_errors=False ) for validator in self._output_validators: result_data = await validator.validate( From 6e7c2a52d6e533be03f53213cb2aea2ec6e0fd6c Mon Sep 17 00:00:00 2001 From: NicolasPllr1 Date: Sun, 16 Nov 2025 17:28:51 +0100 Subject: [PATCH 2/7] Test agent validation context Add tests involving the new 'validation context' for: - Pydantic model as the output type - Tool, native and prompted output - Tool calling - Output function - Output validator --- tests/test_validation_context.py | 122 +++++++++++++++++++++++++++++++ 1 file changed, 122 insertions(+) create mode 100644 tests/test_validation_context.py diff --git a/tests/test_validation_context.py b/tests/test_validation_context.py new file mode 100644 index 0000000000..abd5de8012 --- /dev/null +++ b/tests/test_validation_context.py @@ -0,0 +1,122 @@ +from dataclasses import dataclass + +import pytest +from inline_snapshot import snapshot +from pydantic import BaseModel, ValidationInfo, field_validator + +from pydantic_ai import ( + Agent, + ModelMessage, + ModelResponse, + NativeOutput, + PromptedOutput, + RunContext, + TextPart, + ToolCallPart, + ToolOutput, +) +from pydantic_ai._output import OutputSpec +from pydantic_ai.models.function import AgentInfo, FunctionModel + + +class Value(BaseModel): + x: int + + @field_validator('x') + def increment_value(cls, value: int, info: ValidationInfo): + return value + (info.context or 0) + + +@dataclass +class Deps: + increment: int + + +@pytest.mark.parametrize( + 'output_type', + [ + Value, + ToolOutput(Value), + NativeOutput(Value), + PromptedOutput(Value), + ], + ids=[ + 'Value', + 'ToolOutput(Value)', + 'NativeOutput(Value)', + 'PromptedOutput(Value)', + ], +) +def test_agent_output_with_validation_context(output_type: OutputSpec[Value]): + """Test that the output is validated using the validation context""" + + def mock_llm(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: + if isinstance(output_type, ToolOutput): + return ModelResponse(parts=[ToolCallPart(tool_name='final_result', args={'x': 0})]) + else: + text = Value(x=0).model_dump_json() + return ModelResponse(parts=[TextPart(content=text)]) + + agent = Agent( + FunctionModel(mock_llm), + output_type=output_type, + deps_type=Deps, + validation_context=lambda ctx: ctx.deps.increment, + ) + + result = agent.run_sync('', deps=Deps(increment=10)) + assert result.output.x == snapshot(10) + + +def test_agent_tool_call_with_validation_context(): + """Test that the argument passed to the tool call is validated using the validation context.""" + + agent = Agent( + 'test', + deps_type=Deps, + validation_context=lambda ctx: ctx.deps.increment, + ) + + @agent.tool + def get_value(ctx: RunContext[Deps], v: Value) -> int: + # NOTE: The test agent calls this tool with Value(x=0) which should then have been influenced by the validation context through the `increment_value` field validator + assert v.x == ctx.deps.increment + return v.x + + result = agent.run_sync('', deps=Deps(increment=10)) + assert result.output == snapshot('{"get_value":10}') + + +def test_agent_output_function_with_validation_context(): + """Test that the argument passed to the output function is validated using the validation context.""" + + def get_value(v: Value) -> int: + return v.x + + agent = Agent( + 'test', + output_type=get_value, + deps_type=Deps, + validation_context=lambda ctx: ctx.deps.increment, + ) + + result = agent.run_sync('', deps=Deps(increment=10)) + assert result.output == snapshot(10) + + +def test_agent_output_validator_with_validation_context(): + """Test that the argument passed to the output validator is validated using the validation context.""" + + agent = Agent( + 'test', + output_type=Value, + deps_type=Deps, + validation_context=lambda ctx: ctx.deps.increment, + ) + + @agent.output_validator + def identity(ctx: RunContext[Deps], v: Value) -> Value: + return v + + result = agent.run_sync('', deps=Deps(increment=10)) + assert result.output.x == snapshot(10) From 8bb9f77b66f9ae8d97d1a3a4ee602e7460ef748e Mon Sep 17 00:00:00 2001 From: NicolasPllr1 Date: Tue, 18 Nov 2025 22:02:41 +0100 Subject: [PATCH 3/7] feedback: Link Pydantic doc --- pydantic_ai_slim/pydantic_ai/agent/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 3d16a64488..e813860b5c 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -253,7 +253,7 @@ def __init__( model_settings: Optional model request settings to use for this agent's runs, by default. retries: The default number of retries to allow for tool calls and output validation, before raising an error. For model request retries, see the [HTTP Request Retries](../retries.md) documentation. - validation_context: Additional validation context used to validate all outputs. + validation_context: Additional [validation context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) used to validate all outputs. output_retries: The maximum number of retries to allow for output validation, defaults to `retries`. tools: Tools to register with the agent, you can also register tools via the decorators [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain]. From bff548e73513de8495ede2bdc18193e8192b7718 Mon Sep 17 00:00:00 2001 From: NicolasPllr1 Date: Tue, 18 Nov 2025 22:03:49 +0100 Subject: [PATCH 4/7] feedback: Remove unnecessary arg in _handle_text_response --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 7341dbfc57..f34b98cfb0 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -590,9 +590,7 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa text = '' # pragma: no cover if text: try: - self._next_node = await self._handle_text_response( - ctx, ctx.deps.tool_manager.validation_ctx, text, text_processor - ) + self._next_node = await self._handle_text_response(ctx, text, text_processor) return except ToolRetryError: # If the text from the preview response was invalid, ignore it. @@ -656,9 +654,7 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa if text_processor := output_schema.text_processor: if text: - self._next_node = await self._handle_text_response( - ctx, ctx.deps.tool_manager.validation_ctx, text, text_processor - ) + self._next_node = await self._handle_text_response(ctx, text, text_processor) return alternatives.insert(0, 'return text') @@ -720,14 +716,13 @@ async def _handle_tool_calls( async def _handle_text_response( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], - validation_ctx: Any | Callable[[RunContext[DepsT]], Any], text: str, text_processor: _output.BaseOutputProcessor[NodeRunEndT], ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]: run_context = build_run_context(ctx) - validation_context = build_validation_context(validation_ctx, run_context) + validation_context = build_validation_context(ctx.deps.tool_manager.validation_ctx, run_context) - result_data = await text_processor.process(text, run_context, validation_context) + result_data = await text_processor.process(text, run_context=run_context, validation_context=validation_context) for validator in ctx.deps.output_validators: result_data = await validator.validate(result_data, run_context) From 6b0658109f25c21fa5d532e7fdfbbd2331f7a5a1 Mon Sep 17 00:00:00 2001 From: NicolasPllr1 Date: Tue, 18 Nov 2025 22:05:16 +0100 Subject: [PATCH 5/7] feedback: `None` default for the val ctx and require kwargs after `data` in process functions --- pydantic_ai_slim/pydantic_ai/_output.py | 32 +++++++++++++++++-------- pydantic_ai_slim/pydantic_ai/result.py | 6 ++++- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 77a073ffc9..1ad9c57d57 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -529,8 +529,9 @@ class BaseOutputProcessor(ABC, Generic[OutputDataT]): async def process( self, data: str, + *, run_context: RunContext[AgentDepsT], - validation_context: Any | None, + validation_context: Any | None = None, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -554,8 +555,9 @@ def __init__(self, wrapped: BaseObjectOutputProcessor[OutputDataT]): async def process( self, data: str, + *, run_context: RunContext[AgentDepsT], - validation_context: Any | None, + validation_context: Any | None = None, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -563,8 +565,8 @@ async def process( return await self.wrapped.process( text, - run_context, - validation_context, + run_context=run_context, + validation_context=validation_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors, ) @@ -644,8 +646,9 @@ def __init__( async def process( self, data: str | dict[str, Any] | None, + *, run_context: RunContext[AgentDepsT], - validation_context: Any | None, + validation_context: Any | None = None, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -809,6 +812,7 @@ def __init__( async def process( self, data: str, + *, run_context: RunContext[AgentDepsT], validation_context: Any | None = None, allow_partial: bool = False, @@ -816,8 +820,8 @@ async def process( ) -> OutputDataT: union_object = await self._union_processor.process( data, - run_context, - validation_context, + run_context=run_context, + validation_context=validation_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors, ) @@ -836,8 +840,8 @@ async def process( return await processor.process( inner_data, - run_context, - validation_context, + run_context=run_context, + validation_context=validation_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors, ) @@ -847,6 +851,7 @@ class TextOutputProcessor(BaseOutputProcessor[OutputDataT]): async def process( self, data: str, + *, run_context: RunContext[AgentDepsT], validation_context: Any | None = None, allow_partial: bool = False, @@ -879,6 +884,7 @@ def __init__( async def process( self, data: str, + *, run_context: RunContext[AgentDepsT], validation_context: Any | None = None, allow_partial: bool = False, @@ -887,7 +893,13 @@ async def process( args = {self._str_argument_name: data} data = await execute_traced_output_function(self._function_schema, run_context, args, wrap_validation_errors) - return await super().process(data, run_context, validation_context, allow_partial, wrap_validation_errors) + return await super().process( + data, + run_context=run_context, + validation_context=validation_context, + allow_partial=allow_partial, + wrap_validation_errors=wrap_validation_errors, + ) @dataclass(init=False) diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 140c8304b0..6d91d1bc39 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -200,7 +200,11 @@ async def validate_response_output( validation_context = build_validation_context(self._tool_manager.validation_ctx, self._run_ctx) result_data = await text_processor.process( - text, self._run_ctx, validation_context, allow_partial=allow_partial, wrap_validation_errors=False + text, + run_context=self._run_ctx, + validation_context=validation_context, + allow_partial=allow_partial, + wrap_validation_errors=False, ) for validator in self._output_validators: result_data = await validator.validate( From 61ab2f5ace5ce48f170f85b3a8f1a8d3bb111473 Mon Sep 17 00:00:00 2001 From: NicolasPllr1 Date: Wed, 19 Nov 2025 20:19:03 +0100 Subject: [PATCH 6/7] feedback: Shove the validation context inside the RunContext --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 24 ++++++++++++++++--- pydantic_ai_slim/pydantic_ai/_output.py | 8 +------ pydantic_ai_slim/pydantic_ai/_run_context.py | 4 +++- pydantic_ai_slim/pydantic_ai/_tool_manager.py | 24 ++++--------------- .../pydantic_ai/agent/__init__.py | 3 ++- pydantic_ai_slim/pydantic_ai/result.py | 5 +--- 6 files changed, 32 insertions(+), 36 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index f34b98cfb0..0001f8af26 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -18,7 +18,7 @@ from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore from pydantic_ai._instrumentation import DEFAULT_INSTRUMENTATION_VERSION -from pydantic_ai._tool_manager import ToolManager, build_validation_context +from pydantic_ai._tool_manager import ToolManager from pydantic_ai._utils import dataclasses_no_defaults_repr, get_union_args, is_async_callable, run_in_executor from pydantic_ai.builtin_tools import AbstractBuiltinTool from pydantic_graph import BaseNode, GraphRunContext @@ -144,6 +144,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]): output_schema: _output.OutputSchema[OutputDataT] output_validators: list[_output.OutputValidator[DepsT, OutputDataT]] + validation_context: Any | Callable[[RunContext[DepsT]], Any] history_processors: Sequence[HistoryProcessor[DepsT]] @@ -477,6 +478,8 @@ async def _prepare_request( ctx.state.run_step += 1 run_context = build_run_context(ctx) + validation_context = build_validation_context(ctx.deps.validation_context, run_context) + run_context = replace(run_context, validation_context=validation_context) # This will raise errors for any tool name conflicts ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context) @@ -720,9 +723,11 @@ async def _handle_text_response( text_processor: _output.BaseOutputProcessor[NodeRunEndT], ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]: run_context = build_run_context(ctx) - validation_context = build_validation_context(ctx.deps.tool_manager.validation_ctx, run_context) + validation_context = build_validation_context(ctx.deps.validation_context, run_context) - result_data = await text_processor.process(text, run_context=run_context, validation_context=validation_context) + run_context = replace(run_context, validation_context=validation_context) + + result_data = await text_processor.process(text, run_context=run_context) for validator in ctx.deps.output_validators: result_data = await validator.validate(result_data, run_context) @@ -773,6 +778,7 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT usage=ctx.state.usage, prompt=ctx.deps.prompt, messages=ctx.state.message_history, + validation_context=None, tracer=ctx.deps.tracer, trace_include_content=ctx.deps.instrumentation_settings is not None and ctx.deps.instrumentation_settings.include_content, @@ -784,6 +790,18 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT ) +def build_validation_context( + validation_ctx: Any | Callable[[RunContext[DepsT]], Any], + run_context: RunContext[DepsT], +) -> Any: + """Build a Pydantic validation context, potentially from the current agent run context.""" + if callable(validation_ctx): + fn = cast(Callable[[RunContext[DepsT]], Any], validation_ctx) + return fn(run_context) + else: + return validation_ctx + + async def process_tool_calls( # noqa: C901 tool_manager: ToolManager[DepsT], tool_calls: list[_messages.ToolCallPart], diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 1ad9c57d57..5207a33329 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -531,7 +531,6 @@ async def process( data: str, *, run_context: RunContext[AgentDepsT], - validation_context: Any | None = None, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -566,7 +565,6 @@ async def process( return await self.wrapped.process( text, run_context=run_context, - validation_context=validation_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors, ) @@ -648,7 +646,6 @@ async def process( data: str | dict[str, Any] | None, *, run_context: RunContext[AgentDepsT], - validation_context: Any | None = None, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -665,7 +662,7 @@ async def process( Either the validated output data (left) or a retry message (right). """ try: - output = self.validate(data, allow_partial, validation_context) + output = self.validate(data, allow_partial, run_context.validation_context) except ValidationError as e: if wrap_validation_errors: m = _messages.RetryPromptPart( @@ -814,14 +811,12 @@ async def process( data: str, *, run_context: RunContext[AgentDepsT], - validation_context: Any | None = None, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: union_object = await self._union_processor.process( data, run_context=run_context, - validation_context=validation_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors, ) @@ -841,7 +836,6 @@ async def process( return await processor.process( inner_data, run_context=run_context, - validation_context=validation_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors, ) diff --git a/pydantic_ai_slim/pydantic_ai/_run_context.py b/pydantic_ai_slim/pydantic_ai/_run_context.py index 4f9b253767..edee68f975 100644 --- a/pydantic_ai_slim/pydantic_ai/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/_run_context.py @@ -3,7 +3,7 @@ import dataclasses from collections.abc import Sequence from dataclasses import field -from typing import TYPE_CHECKING, Generic +from typing import TYPE_CHECKING, Any, Generic from opentelemetry.trace import NoOpTracer, Tracer from typing_extensions import TypeVar @@ -38,6 +38,8 @@ class RunContext(Generic[RunContextAgentDepsT]): """The original user prompt passed to the run.""" messages: list[_messages.ModelMessage] = field(default_factory=list) """Messages exchanged in the conversation so far.""" + validation_context: Any = None + """Additional Pydantic validation context for the run outputs.""" tracer: Tracer = field(default_factory=NoOpTracer) """The tracer to use for tracing the run.""" trace_include_content: bool = False diff --git a/pydantic_ai_slim/pydantic_ai/_tool_manager.py b/pydantic_ai_slim/pydantic_ai/_tool_manager.py index 5dede17d6e..9a9f93e1ff 100644 --- a/pydantic_ai_slim/pydantic_ai/_tool_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_tool_manager.py @@ -1,11 +1,11 @@ from __future__ import annotations import json -from collections.abc import Callable, Iterator +from collections.abc import Iterator from contextlib import contextmanager from contextvars import ContextVar from dataclasses import dataclass, field, replace -from typing import Any, Generic, cast +from typing import Any, Generic from opentelemetry.trace import Tracer from pydantic import ValidationError @@ -31,8 +31,6 @@ class ToolManager(Generic[AgentDepsT]): """The toolset that provides the tools for this run step.""" ctx: RunContext[AgentDepsT] | None = None """The agent run context for a specific run step.""" - validation_ctx: Any | Callable[[RunContext[AgentDepsT]], Any] = None - """Additional Pydantic validation context for the run.""" tools: dict[str, ToolsetTool[AgentDepsT]] | None = None """The cached tools for this run step.""" failed_tools: set[str] = field(default_factory=set) @@ -63,7 +61,6 @@ async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDe return self.__class__( toolset=self.toolset, ctx=ctx, - validation_ctx=self.validation_ctx, tools=await self.toolset.get_tools(ctx), ) @@ -164,17 +161,15 @@ async def _call_tool( partial_output=allow_partial, ) - validation_ctx = build_validation_context(self.validation_ctx, self.ctx) - pyd_allow_partial = 'trailing-strings' if allow_partial else 'off' validator = tool.args_validator if isinstance(call.args, str): args_dict = validator.validate_json( - call.args or '{}', allow_partial=pyd_allow_partial, context=validation_ctx + call.args or '{}', allow_partial=pyd_allow_partial, context=ctx.validation_context ) else: args_dict = validator.validate_python( - call.args or {}, allow_partial=pyd_allow_partial, context=validation_ctx + call.args or {}, allow_partial=pyd_allow_partial, context=ctx.validation_context ) result = await self.toolset.call_tool(name, args_dict, ctx, tool) @@ -279,14 +274,3 @@ async def _call_function_tool( ) return tool_result - - -def build_validation_context( - validation_ctx: Any | Callable[[RunContext[AgentDepsT]], Any], run_context: RunContext[AgentDepsT] -) -> Any: - """Build a Pydantic validation context, potentially from the current agent run context.""" - if callable(validation_ctx): - fn = cast(Callable[[RunContext[AgentDepsT]], Any], validation_ctx) - return fn(run_context) - else: - return validation_ctx diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index e813860b5c..80bd223b1b 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -569,7 +569,7 @@ async def main(): output_toolset.max_retries = self._max_result_retries output_toolset.output_validators = output_validators toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets) - tool_manager = ToolManager[AgentDepsT](toolset, validation_ctx=self._validation_context) + tool_manager = ToolManager[AgentDepsT](toolset) # Build the graph graph = _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_) @@ -619,6 +619,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: end_strategy=self.end_strategy, output_schema=output_schema, output_validators=output_validators, + validation_context=self._validation_context, history_processors=self.history_processors, builtin_tools=[*self._builtin_tools, *(builtin_tools or [])], tool_manager=tool_manager, diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 6d91d1bc39..88bfe407fa 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -18,7 +18,7 @@ TextOutputSchema, ) from ._run_context import AgentDepsT, RunContext -from ._tool_manager import ToolManager, build_validation_context +from ._tool_manager import ToolManager from .messages import ModelResponseStreamEvent from .output import ( DeferredToolRequests, @@ -197,12 +197,9 @@ async def validate_response_output( # not part of the final result output, so we reset the accumulated text text = '' - validation_context = build_validation_context(self._tool_manager.validation_ctx, self._run_ctx) - result_data = await text_processor.process( text, run_context=self._run_ctx, - validation_context=validation_context, allow_partial=allow_partial, wrap_validation_errors=False, ) From 375a09101200767dcf708069a50304ce325974d8 Mon Sep 17 00:00:00 2001 From: NicolasPllr1 Date: Wed, 19 Nov 2025 21:51:33 +0100 Subject: [PATCH 7/7] Test that the validation context is updated as the deps are mutated --- tests/test_validation_context.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/test_validation_context.py b/tests/test_validation_context.py index abd5de8012..ae475859ee 100644 --- a/tests/test_validation_context.py +++ b/tests/test_validation_context.py @@ -120,3 +120,28 @@ def identity(ctx: RunContext[Deps], v: Value) -> Value: result = agent.run_sync('', deps=Deps(increment=10)) assert result.output.x == snapshot(10) + + +def test_agent_output_validator_with_intermediary_deps_change_and_validation_context(): + """Test that the validation context is updated as run dependencies are mutated.""" + + agent = Agent( + 'test', + output_type=Value, + deps_type=Deps, + validation_context=lambda ctx: ctx.deps.increment, + ) + + @agent.tool + def bump_increment(ctx: RunContext[Deps]): + assert ctx.validation_context == snapshot(10) # validation ctx was first computed using the original deps + ctx.deps.increment += 5 # update the deps + + @agent.output_validator + def identity(ctx: RunContext[Deps], v: Value) -> Value: + assert ctx.validation_context == snapshot(15) # validation ctx was re-computed after deps update from tool call + + return v + + result = agent.run_sync('', deps=Deps(increment=10)) + assert result.output.x == snapshot(15)