diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
index 91cda373a5..7341dbfc57 100644
--- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py
+++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -18,7 +18,7 @@
 
 from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx  # type: ignore
 from pydantic_ai._instrumentation import DEFAULT_INSTRUMENTATION_VERSION
-from pydantic_ai._tool_manager import ToolManager
+from pydantic_ai._tool_manager import ToolManager, build_validation_context
 from pydantic_ai._utils import dataclasses_no_defaults_repr, get_union_args, is_async_callable, run_in_executor
 from pydantic_ai.builtin_tools import AbstractBuiltinTool
 from pydantic_graph import BaseNode, GraphRunContext
@@ -590,7 +590,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:  # noqa
                     text = ''  # pragma: no cover
             if text:
                 try:
-                    self._next_node = await self._handle_text_response(ctx, text, text_processor)
+                    self._next_node = await self._handle_text_response(
+                        ctx, ctx.deps.tool_manager.validation_ctx, text, text_processor
+                    )
                     return
                 except ToolRetryError:
                     # If the text from the preview response was invalid, ignore it.
@@ -654,7 +656,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:  # noqa
 
             if text_processor := output_schema.text_processor:
                 if text:
-                    self._next_node = await self._handle_text_response(ctx, text, text_processor)
+                    self._next_node = await self._handle_text_response(
+                        ctx, ctx.deps.tool_manager.validation_ctx, text, text_processor
+                    )
                     return
 
             alternatives.insert(0, 'return text')
@@ -716,12 +720,14 @@ async def _handle_tool_calls(
     async def _handle_text_response(
         self,
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
+        validation_ctx: Any | Callable[[RunContext[DepsT]], Any],
         text: str,
         text_processor: _output.BaseOutputProcessor[NodeRunEndT],
     ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
         run_context = build_run_context(ctx)
+        validation_context = build_validation_context(validation_ctx, run_context)
 
-        result_data = await text_processor.process(text, run_context)
+        result_data = await text_processor.process(text, run_context, validation_context)
         for validator in ctx.deps.output_validators:
             result_data = await validator.validate(result_data, run_context)
 
diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py
index ebb737a1cf..77a073ffc9 100644
--- a/pydantic_ai_slim/pydantic_ai/_output.py
+++ b/pydantic_ai_slim/pydantic_ai/_output.py
@@ -530,6 +530,7 @@ async def process(
         self,
         data: str,
         run_context: RunContext[AgentDepsT],
+        validation_context: Any | None,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
@@ -554,13 +555,18 @@ async def process(
         self,
         data: str,
         run_context: RunContext[AgentDepsT],
+        validation_context: Any | None,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
         text = _utils.strip_markdown_fences(data)
 
         return await self.wrapped.process(
-            text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+            text,
+            run_context,
+            validation_context,
+            allow_partial=allow_partial,
+            wrap_validation_errors=wrap_validation_errors,
         )
 
 
@@ -639,6 +645,7 @@ async def process(
         self,
         data: str | dict[str, Any] | None,
         run_context: RunContext[AgentDepsT],
+        validation_context: Any | None,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
@@ -647,6 +654,7 @@ async def process(
         Args:
             data: The output data to validate.
             run_context: The current run context.
+            validation_context: Additional Pydantic validation context for the current run.
             allow_partial: If true, allow partial validation.
             wrap_validation_errors: If true, wrap the validation errors in a retry message.
 
@@ -654,7 +662,7 @@ async def process(
             Either the validated output data (left) or a retry message (right).
         """
         try:
-            output = self.validate(data, allow_partial)
+            output = self.validate(data, allow_partial, validation_context)
         except ValidationError as e:
             if wrap_validation_errors:
                 m = _messages.RetryPromptPart(
@@ -672,12 +680,17 @@ def validate(
         self,
         data: str | dict[str, Any] | None,
         allow_partial: bool = False,
+        validation_context: Any | None = None,
     ) -> dict[str, Any]:
         pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
         if isinstance(data, str):
-            return self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
+            return self.validator.validate_json(
+                data or '{}', allow_partial=pyd_allow_partial, context=validation_context
+            )
         else:
-            return self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
+            return self.validator.validate_python(
+                data or {}, allow_partial=pyd_allow_partial, context=validation_context
+            )
 
     async def call(
         self,
@@ -797,11 +810,16 @@ async def process(
         self,
         data: str,
         run_context: RunContext[AgentDepsT],
+        validation_context: Any | None = None,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
         union_object = await self._union_processor.process(
-            data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+            data,
+            run_context,
+            validation_context,
+            allow_partial=allow_partial,
+            wrap_validation_errors=wrap_validation_errors,
         )
 
         result = union_object.result
@@ -817,7 +835,11 @@ async def process(
                 raise
 
         return await processor.process(
-            inner_data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+            inner_data,
+            run_context,
+            validation_context,
+            allow_partial=allow_partial,
+            wrap_validation_errors=wrap_validation_errors,
         )
 
 
@@ -826,6 +848,7 @@ async def process(
         self,
         data: str,
         run_context: RunContext[AgentDepsT],
+        validation_context: Any | None = None,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
@@ -857,13 +880,14 @@ async def process(
         self,
         data: str,
         run_context: RunContext[AgentDepsT],
+        validation_context: Any | None = None,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
         args = {self._str_argument_name: data}
         data = await execute_traced_output_function(self._function_schema, run_context, args, wrap_validation_errors)
 
-        return await super().process(data, run_context, allow_partial, wrap_validation_errors)
+        return await super().process(data, run_context, validation_context, allow_partial, wrap_validation_errors)
 
 
 @dataclass(init=False)
diff --git a/pydantic_ai_slim/pydantic_ai/_tool_manager.py b/pydantic_ai_slim/pydantic_ai/_tool_manager.py
index fb7039e2cc..5dede17d6e 100644
--- a/pydantic_ai_slim/pydantic_ai/_tool_manager.py
+++ b/pydantic_ai_slim/pydantic_ai/_tool_manager.py
@@ -1,11 +1,11 @@
 from __future__ import annotations
 
 import json
-from collections.abc import Iterator
+from collections.abc import Callable, Iterator
 from contextlib import contextmanager
 from contextvars import ContextVar
 from dataclasses import dataclass, field, replace
-from typing import Any, Generic
+from typing import Any, Generic, cast
 
 from opentelemetry.trace import Tracer
 from pydantic import ValidationError
@@ -31,6 +33,8 @@ class ToolManager(Generic[AgentDepsT]):
     """The toolset that provides the tools for this run step."""
     ctx: RunContext[AgentDepsT] | None = None
    """The agent run context for a specific run step."""
+    validation_ctx: Any | Callable[[RunContext[AgentDepsT]], Any] = None
+    """Additional Pydantic validation context for the run."""
     tools: dict[str, ToolsetTool[AgentDepsT]] | None = None
     """The cached tools for this run step."""
     failed_tools: set[str] = field(default_factory=set)
@@ -61,6 +63,7 @@ async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDe
         return self.__class__(
             toolset=self.toolset,
             ctx=ctx,
+            validation_ctx=self.validation_ctx,
             tools=await self.toolset.get_tools(ctx),
         )
 
@@ -161,12 +164,18 @@ async def _call_tool(
                 partial_output=allow_partial,
             )
 
+        validation_ctx = build_validation_context(self.validation_ctx, self.ctx)
+
         pyd_allow_partial = 'trailing-strings' if allow_partial else 'off'
         validator = tool.args_validator
         if isinstance(call.args, str):
-            args_dict = validator.validate_json(call.args or '{}', allow_partial=pyd_allow_partial)
+            args_dict = validator.validate_json(
+                call.args or '{}', allow_partial=pyd_allow_partial, context=validation_ctx
+            )
         else:
-            args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial)
+            args_dict = validator.validate_python(
+                call.args or {}, allow_partial=pyd_allow_partial, context=validation_ctx
+            )
 
         result = await self.toolset.call_tool(name, args_dict, ctx, tool)
 
@@ -270,3 +279,14 @@ async def _call_function_tool(
     )
 
     return tool_result
+
+
+def build_validation_context(
+    validation_ctx: Any | Callable[[RunContext[AgentDepsT]], Any], run_context: RunContext[AgentDepsT]
+) -> Any:
+    """Build a Pydantic validation context, potentially from the current agent run context."""
+    if callable(validation_ctx):
+        fn = cast(Callable[[RunContext[AgentDepsT]], Any], validation_ctx)
+        return fn(run_context)
+    else:
+        return validation_ctx
diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py
index 4cd353b44a..3d16a64488 100644
--- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py
@@ -147,6 +147,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
     _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
     _max_result_retries: int = dataclasses.field(repr=False)
     _max_tool_retries: int = dataclasses.field(repr=False)
+    _validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = dataclasses.field(repr=False)
 
     _event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False)
 
@@ -166,6 +167,7 @@ def __init__(
         name: str | None = None,
         model_settings: ModelSettings | None = None,
         retries: int = 1,
+        validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
         output_retries: int | None = None,
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -192,6 +194,7 @@ def __init__(
         name: str | None = None,
         model_settings: ModelSettings | None = None,
         retries: int = 1,
+        validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
         output_retries: int | None = None,
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -216,6 +219,7 @@ def __init__(
         name: str | None = None,
         model_settings: ModelSettings | None = None,
         retries: int = 1,
+        validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
         output_retries: int | None = None,
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -249,6 +253,8 @@ def __init__(
             model_settings: Optional model request settings to use for this agent's runs, by default.
             retries: The default number of retries to allow for tool calls and output validation, before raising an error.
                 For model request retries, see the [HTTP Request Retries](../retries.md) documentation.
+            validation_context: Additional Pydantic validation context used to validate tool arguments and outputs;
+                either a static value or a callable that builds it from the current run context.
             output_retries: The maximum number of retries to allow for output validation, defaults to `retries`.
             tools: Tools to register with the agent, you can also register tools via the decorators
                 [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
@@ -314,6 +320,8 @@ def __init__(
         self._max_result_retries = output_retries if output_retries is not None else retries
         self._max_tool_retries = retries
 
+        self._validation_context = validation_context
+
         self._builtin_tools = builtin_tools
 
         self._prepare_tools = prepare_tools
@@ -562,7 +570,7 @@ async def main():
             output_toolset.max_retries = self._max_result_retries
             output_toolset.output_validators = output_validators
         toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)
-        tool_manager = ToolManager[AgentDepsT](toolset)
+        tool_manager = ToolManager[AgentDepsT](toolset, validation_ctx=self._validation_context)
 
         # Build the graph
         graph = _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_)
diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py
index c6b59ec796..140c8304b0 100644
--- a/pydantic_ai_slim/pydantic_ai/result.py
+++ b/pydantic_ai_slim/pydantic_ai/result.py
@@ -18,7 +18,7 @@
     TextOutputSchema,
 )
 from ._run_context import AgentDepsT, RunContext
-from ._tool_manager import ToolManager
+from ._tool_manager import ToolManager, build_validation_context
 from .messages import ModelResponseStreamEvent
 from .output import (
     DeferredToolRequests,
@@ -197,8 +197,10 @@ async def validate_response_output(
                 # not part of the final result output, so we reset the accumulated text
                 text = ''
 
+        validation_context = build_validation_context(self._tool_manager.validation_ctx, self._run_ctx)
+
         result_data = await text_processor.process(
-            text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
+            text, self._run_ctx, validation_context, allow_partial=allow_partial, wrap_validation_errors=False
         )
         for validator in self._output_validators:
             result_data = await validator.validate(
diff --git a/tests/test_validation_context.py b/tests/test_validation_context.py
new file mode 100644
index 0000000000..abd5de8012
--- /dev/null
+++ b/tests/test_validation_context.py
@@ -0,0 +1,124 @@
+from dataclasses import dataclass
+
+import pytest
+from inline_snapshot import snapshot
+from pydantic import BaseModel, ValidationInfo, field_validator
+
+from pydantic_ai import (
+    Agent,
+    ModelMessage,
+    ModelResponse,
+    NativeOutput,
+    PromptedOutput,
+    RunContext,
+    TextPart,
+    ToolCallPart,
+    ToolOutput,
+)
+from pydantic_ai._output import OutputSpec
+from pydantic_ai.models.function import AgentInfo, FunctionModel
+
+
+class Value(BaseModel):
+    x: int
+
+    @field_validator('x')
+    @classmethod
+    def increment_value(cls, value: int, info: ValidationInfo) -> int:
+        return value + (info.context or 0)
+
+
+@dataclass
+class Deps:
+    increment: int
+
+
+@pytest.mark.parametrize(
+    'output_type',
+    [
+        Value,
+        ToolOutput(Value),
+        NativeOutput(Value),
+        PromptedOutput(Value),
+    ],
+    ids=[
+        'Value',
+        'ToolOutput(Value)',
+        'NativeOutput(Value)',
+        'PromptedOutput(Value)',
+    ],
+)
+def test_agent_output_with_validation_context(output_type: OutputSpec[Value]):
+    """Test that the output is validated using the validation context."""
+
+    def mock_llm(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
+        if isinstance(output_type, ToolOutput):
+            return ModelResponse(parts=[ToolCallPart(tool_name='final_result', args={'x': 0})])
+        else:
+            text = Value(x=0).model_dump_json()
+            return ModelResponse(parts=[TextPart(content=text)])
+
+    agent = Agent(
+        FunctionModel(mock_llm),
+        output_type=output_type,
+        deps_type=Deps,
+        validation_context=lambda ctx: ctx.deps.increment,
+    )
+
+    result = agent.run_sync('', deps=Deps(increment=10))
+    assert result.output.x == snapshot(10)
+
+
+def test_agent_tool_call_with_validation_context():
+    """Test that the argument passed to the tool call is validated using the validation context."""
+
+    agent = Agent(
+        'test',
+        deps_type=Deps,
+        validation_context=lambda ctx: ctx.deps.increment,
+    )
+
+    @agent.tool
+    def get_value(ctx: RunContext[Deps], v: Value) -> int:
+        # NOTE: the test model calls this tool with Value(x=0); the `increment_value` field
+        # validator then applies the validation context, so `v.x` should equal the increment.
+        assert v.x == ctx.deps.increment
+        return v.x
+
+    result = agent.run_sync('', deps=Deps(increment=10))
+    assert result.output == snapshot('{"get_value":10}')
+
+
+def test_agent_output_function_with_validation_context():
+    """Test that the argument passed to the output function is validated using the validation context."""
+
+    def get_value(v: Value) -> int:
+        return v.x
+
+    agent = Agent(
+        'test',
+        output_type=get_value,
+        deps_type=Deps,
+        validation_context=lambda ctx: ctx.deps.increment,
+    )
+
+    result = agent.run_sync('', deps=Deps(increment=10))
+    assert result.output == snapshot(10)
+
+
+def test_agent_output_validator_with_validation_context():
+    """Test that the argument passed to the output validator is validated using the validation context."""
+
+    agent = Agent(
+        'test',
+        output_type=Value,
+        deps_type=Deps,
+        validation_context=lambda ctx: ctx.deps.increment,
+    )
+
+    @agent.output_validator
+    def identity(ctx: RunContext[Deps], v: Value) -> Value:
+        return v
+
+    result = agent.run_sync('', deps=Deps(increment=10))
+    assert result.output.x == snapshot(10)
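Usage sketch (illustrative, not part of the diff): the snippet below shows the pattern these changes enable, mirroring the tests above. The `Order` model, `Deps` dataclass, prompt, and model name are hypothetical; what the diff itself provides is the `validation_context` argument, whose value (or, for the callable form, the result of calling it with the current `RunContext`) is forwarded as Pydantic's `context=` whenever tool arguments and outputs are validated.

```python
from dataclasses import dataclass

from pydantic import BaseModel, ValidationInfo, field_validator

from pydantic_ai import Agent


class Order(BaseModel):
    total: float

    @field_validator('total')
    @classmethod
    def apply_discount(cls, value: float, info: ValidationInfo) -> float:
        # `info.context` is whatever `validation_context` resolved to for this run.
        rate = info.context.get('discount', 0.0) if info.context else 0.0
        return value * (1 - rate)


@dataclass
class Deps:
    discount: float


agent = Agent(
    'openai:gpt-4o',  # model name is illustrative
    output_type=Order,
    deps_type=Deps,
    # Callable form: the context is rebuilt from the run's deps at each validation site.
    validation_context=lambda ctx: {'discount': ctx.deps.discount},
)

result = agent.run_sync('Create an order totalling $100', deps=Deps(discount=0.2))
# If the model returns {"total": 100}, validation yields total == 80.0.
print(result.output.total)
```

Because the callable receives the run context, the same mechanism applies to tool-call arguments as well, as exercised by `test_agent_tool_call_with_validation_context` above.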