
Commit d8f01f0

NicolasPllr1 and DouweM authored
Pass pydantic validation context (#3448)
Co-authored-by: Douwe Maan <[email protected]>
1 parent 5eb7778 commit d8f01f0

10 files changed: +277 -14 lines changed


docs/output.md

Lines changed: 57 additions & 1 deletion
@@ -121,7 +121,7 @@ Instead of plain text or structured data, you may want the output of your agent
 
 Output functions are similar to [function tools](tools.md), but the model is forced to call one of them, the call ends the agent run, and the result is not passed back to the model.
 
-As with tool functions, output function arguments provided by the model are validated using Pydantic, they can optionally take [`RunContext`][pydantic_ai.tools.RunContext] as the first argument, and they can raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to ask the model to try again with modified arguments (or with a different output type).
+As with tool functions, output function arguments provided by the model are validated using Pydantic (with optional [validation context](#validation-context)), can optionally take [`RunContext`][pydantic_ai.tools.RunContext] as the first argument, and can raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to ask the model to try again with modified arguments (or with a different output type).
 
 To specify output functions, you set the agent's `output_type` to either a single function (or bound instance method), or a list of functions. The list can also contain other output types like simple scalars or entire Pydantic models.
 You typically do not want to also register your output function as a tool (using the `@agent.tool` decorator or `tools` argument), as this could confuse the model about which it should be calling.
@@ -416,6 +416,62 @@
 #> {'name': 'John Doe', 'age': 30}
 ```
 
+### Validation context {#validation-context}
+
+Some validation relies on an extra Pydantic [context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) object. You can pass such an object to an `Agent` at definition-time via its [`validation_context`][pydantic_ai.Agent.__init__] parameter. It will be used in the validation of both structured outputs and [tool arguments](tools-advanced.md#tool-retries).
+
+This validation context can be either:
+
+- the context object itself (`Any`), used as-is to validate outputs, or
+- a function that takes the [`RunContext`][pydantic_ai.tools.RunContext] and returns a context object (`Any`). This function will be called automatically before each validation, allowing you to build a dynamic validation context.
+
+!!! warning "Don't confuse this _validation_ context with the _LLM_ context"
+    This Pydantic validation context object is only used internally by Pydantic AI for tool arg and output validation. In particular, it is **not** included in the prompts or messages sent to the language model.
+
+```python {title="validation_context.py"}
+from dataclasses import dataclass
+
+from pydantic import BaseModel, ValidationInfo, field_validator
+
+from pydantic_ai import Agent
+
+
+class Value(BaseModel):
+    x: int
+
+    @field_validator('x')
+    def increment_value(cls, value: int, info: ValidationInfo):
+        return value + (info.context or 0)
+
+
+agent = Agent(
+    'google-gla:gemini-2.5-flash',
+    output_type=Value,
+    validation_context=10,
+)
+result = agent.run_sync('Give me a value of 5.')
+print(repr(result.output))  # 5 from the model + 10 from the validation context
+#> Value(x=15)
+
+
+@dataclass
+class Deps:
+    increment: int
+
+
+agent = Agent(
+    'google-gla:gemini-2.5-flash',
+    output_type=Value,
+    deps_type=Deps,
+    validation_context=lambda ctx: ctx.deps.increment,
+)
+result = agent.run_sync('Give me a value of 5.', deps=Deps(increment=10))
+print(repr(result.output))  # 5 from the model + 10 from the validation context
+#> Value(x=15)
+```
+
+_(This example is complete, it can be run "as is")_
+
 ### Output validators {#output-validator-functions}
 
 Some validation is inconvenient or impossible to do in Pydantic validators, in particular when the validation requires IO and is asynchronous. Pydantic AI provides a way to add validation functions via the [`agent.output_validator`][pydantic_ai.Agent.output_validator] decorator.

docs/tools-advanced.md

Lines changed: 1 addition & 1 deletion
@@ -353,7 +353,7 @@ If both per-tool `prepare` and agent-wide `prepare_tools` are used, the per-tool
 
 ## Tool Execution and Retries {#tool-retries}
 
-When a tool is executed, its arguments (provided by the LLM) are first validated against the function's signature using Pydantic. If validation fails (e.g., due to incorrect types or missing required arguments), a `ValidationError` is raised, and the framework automatically generates a [`RetryPromptPart`][pydantic_ai.messages.RetryPromptPart] containing the validation details. This prompt is sent back to the LLM, informing it of the error and allowing it to correct the parameters and retry the tool call.
+When a tool is executed, its arguments (provided by the LLM) are first validated against the function's signature using Pydantic (with optional [validation context](output.md#validation-context)). If validation fails (e.g., due to incorrect types or missing required arguments), a `ValidationError` is raised, and the framework automatically generates a [`RetryPromptPart`][pydantic_ai.messages.RetryPromptPart] containing the validation details. This prompt is sent back to the LLM, informing it of the error and allowing it to correct the parameters and retry the tool call.
 
 Beyond automatic validation errors, the tool's own internal logic can also explicitly request a retry by raising the [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] exception. This is useful for situations where the parameters were technically valid, but an issue occurred during execution (like a transient network error, or the tool determining the initial attempt needs modification).
 
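The validate-then-retry flow this hunk documents can be sketched with plain Pydantic. This is an illustrative stand-in, not the real pydantic-ai internals: `ToolArgs` and `call_tool` are hypothetical names, and the real framework wraps the error details in a `RetryPromptPart` rather than returning a dict.

```python
from pydantic import TypeAdapter, ValidationError
from typing_extensions import TypedDict


class ToolArgs(TypedDict):
    city: str


adapter = TypeAdapter(ToolArgs)


def call_tool(raw_args: str) -> dict:
    """Validate LLM-provided args; on failure, surface the error details
    that would be sent back to the model so it can retry the call."""
    try:
        args = adapter.validate_json(raw_args)
    except ValidationError as e:
        return {'retry': e.errors(include_url=False)}
    return {'ok': args}


print(call_tool('{"city": "Paris"}'))  # {'ok': {'city': 'Paris'}}
print(call_tool('{}'))  # a 'retry' payload describing the missing 'city' field
```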

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 19 additions & 2 deletions
@@ -144,6 +144,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
 
     output_schema: _output.OutputSchema[OutputDataT]
     output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
+    validation_context: Any | Callable[[RunContext[DepsT]], Any]
 
     history_processors: Sequence[HistoryProcessor[DepsT]]
 
@@ -745,7 +746,7 @@ async def _handle_text_response(
 ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
     run_context = build_run_context(ctx)
 
-    result_data = await text_processor.process(text, run_context)
+    result_data = await text_processor.process(text, run_context=run_context)
 
     for validator in ctx.deps.output_validators:
         result_data = await validator.validate(result_data, run_context)
@@ -790,12 +791,13 @@ async def run(
 
 def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
     """Build a `RunContext` object from the current agent graph run context."""
-    return RunContext[DepsT](
+    run_context = RunContext[DepsT](
         deps=ctx.deps.user_deps,
         model=ctx.deps.model,
         usage=ctx.state.usage,
         prompt=ctx.deps.prompt,
         messages=ctx.state.message_history,
+        validation_context=None,
         tracer=ctx.deps.tracer,
         trace_include_content=ctx.deps.instrumentation_settings is not None
         and ctx.deps.instrumentation_settings.include_content,
@@ -805,6 +807,21 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
         run_step=ctx.state.run_step,
         run_id=ctx.state.run_id,
     )
+    validation_context = build_validation_context(ctx.deps.validation_context, run_context)
+    run_context = replace(run_context, validation_context=validation_context)
+    return run_context
+
+
+def build_validation_context(
+    validation_ctx: Any | Callable[[RunContext[DepsT]], Any],
+    run_context: RunContext[DepsT],
+) -> Any:
+    """Build a Pydantic validation context, potentially from the current agent run context."""
+    if callable(validation_ctx):
+        fn = cast(Callable[[RunContext[DepsT]], Any], validation_ctx)
+        return fn(run_context)
+    else:
+        return validation_ctx
 
 
 async def process_tool_calls(  # noqa: C901

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 31 additions & 6 deletions
@@ -522,6 +522,7 @@ class BaseOutputProcessor(ABC, Generic[OutputDataT]):
     async def process(
         self,
         data: str,
+        *,
         run_context: RunContext[AgentDepsT],
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
@@ -609,6 +610,7 @@ def __init__(
     async def process(
         self,
         data: str | dict[str, Any] | None,
+        *,
         run_context: RunContext[AgentDepsT],
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
@@ -628,7 +630,7 @@ async def process(
         data = _utils.strip_markdown_fences(data)
 
         try:
-            output = self.validate(data, allow_partial)
+            output = self.validate(data, allow_partial=allow_partial, validation_context=run_context.validation_context)
         except ValidationError as e:
             if wrap_validation_errors:
                 m = _messages.RetryPromptPart(
@@ -645,13 +647,19 @@
     def validate(
         self,
         data: str | dict[str, Any] | None,
+        *,
         allow_partial: bool = False,
+        validation_context: Any | None = None,
     ) -> dict[str, Any]:
         pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
         if isinstance(data, str):
-            return self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
+            return self.validator.validate_json(
+                data or '{}', allow_partial=pyd_allow_partial, context=validation_context
+            )
         else:
-            return self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
+            return self.validator.validate_python(
+                data or {}, allow_partial=pyd_allow_partial, context=validation_context
+            )
 
     async def call(
         self,
@@ -770,12 +778,16 @@ def __init__(
     async def process(
         self,
         data: str,
+        *,
         run_context: RunContext[AgentDepsT],
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
         union_object = await self._union_processor.process(
-            data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+            data,
+            run_context=run_context,
+            allow_partial=allow_partial,
+            wrap_validation_errors=wrap_validation_errors,
         )
 
         result = union_object.result
@@ -791,15 +803,20 @@
                 raise
 
         return await processor.process(
-            inner_data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+            inner_data,
+            run_context=run_context,
+            allow_partial=allow_partial,
+            wrap_validation_errors=wrap_validation_errors,
        )
 
 
 class TextOutputProcessor(BaseOutputProcessor[OutputDataT]):
     async def process(
         self,
         data: str,
+        *,
         run_context: RunContext[AgentDepsT],
+        validation_context: Any | None = None,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
@@ -830,14 +847,22 @@ def __init__(
     async def process(
         self,
         data: str,
+        *,
         run_context: RunContext[AgentDepsT],
+        validation_context: Any | None = None,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
         args = {self._str_argument_name: data}
         data = await execute_traced_output_function(self._function_schema, run_context, args, wrap_validation_errors)
 
-        return await super().process(data, run_context, allow_partial, wrap_validation_errors)
+        return await super().process(
+            data,
+            run_context=run_context,
+            validation_context=validation_context,
+            allow_partial=allow_partial,
+            wrap_validation_errors=wrap_validation_errors,
+        )
 
 
 @dataclass(init=False)
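The `context=` plumbing these hunks add bottoms out in Pydantic's own validation-context feature, which can be seen in isolation. This sketch uses the public `TypeAdapter` rather than the internal `SchemaValidator` the diff touches, and `Value`/`add_offset` are illustrative names.

```python
from pydantic import BaseModel, TypeAdapter, ValidationInfo, field_validator


class Value(BaseModel):
    x: int

    @field_validator('x')
    @classmethod
    def add_offset(cls, v: int, info: ValidationInfo) -> int:
        # The object passed as `context=` to validate_json/validate_python
        # arrives here as info.context (None when no context was given).
        return v + (info.context or 0)


adapter = TypeAdapter(Value)
print(adapter.validate_json('{"x": 5}', context=10))  # x=15
print(adapter.validate_python({'x': 5}, context=10))  # x=15
```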

pydantic_ai_slim/pydantic_ai/_run_context.py

Lines changed: 3 additions & 1 deletion
@@ -3,7 +3,7 @@
 import dataclasses
 from collections.abc import Sequence
 from dataclasses import field
-from typing import TYPE_CHECKING, Generic
+from typing import TYPE_CHECKING, Any, Generic
 
 from opentelemetry.trace import NoOpTracer, Tracer
 from typing_extensions import TypeVar
@@ -38,6 +38,8 @@ class RunContext(Generic[RunContextAgentDepsT]):
     """The original user prompt passed to the run."""
     messages: list[_messages.ModelMessage] = field(default_factory=list)
     """Messages exchanged in the conversation so far."""
+    validation_context: Any = None
+    """Pydantic [validation context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) for tool args and run outputs."""
     tracer: Tracer = field(default_factory=NoOpTracer)
     """The tracer to use for tracing the run."""
     trace_include_content: bool = False
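Because `RunContext` is a dataclass, the new field defaults to `None` and is populated on an already-built instance with `dataclasses.replace`, as `build_run_context` does above. A minimal stand-in (`MiniRunContext` is illustrative, not the real class):

```python
from dataclasses import dataclass, field, replace
from typing import Any


@dataclass
class MiniRunContext:
    """Illustrative subset of RunContext."""
    deps: Any
    messages: list = field(default_factory=list)
    validation_context: Any = None


ctx = MiniRunContext(deps={'increment': 10})
# replace() returns a copy with the one field swapped in, leaving the rest intact.
ctx = replace(ctx, validation_context=ctx.deps['increment'])
print(ctx.validation_context)  # 10
```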

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 6 additions & 2 deletions
@@ -164,9 +164,13 @@ async def _call_tool(
         pyd_allow_partial = 'trailing-strings' if allow_partial else 'off'
         validator = tool.args_validator
         if isinstance(call.args, str):
-            args_dict = validator.validate_json(call.args or '{}', allow_partial=pyd_allow_partial)
+            args_dict = validator.validate_json(
+                call.args or '{}', allow_partial=pyd_allow_partial, context=ctx.validation_context
+            )
         else:
-            args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial)
+            args_dict = validator.validate_python(
+                call.args or {}, allow_partial=pyd_allow_partial, context=ctx.validation_context
+            )
 
         result = await self.toolset.call_tool(name, args_dict, ctx, tool)
 
172176

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -147,6 +147,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
     _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
     _max_result_retries: int = dataclasses.field(repr=False)
     _max_tool_retries: int = dataclasses.field(repr=False)
+    _validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = dataclasses.field(repr=False)
 
     _event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False)
 
@@ -166,6 +167,7 @@ def __init__(
         name: str | None = None,
         model_settings: ModelSettings | None = None,
         retries: int = 1,
+        validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
         output_retries: int | None = None,
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -192,6 +194,7 @@ def __init__(
         name: str | None = None,
         model_settings: ModelSettings | None = None,
         retries: int = 1,
+        validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
         output_retries: int | None = None,
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -216,6 +219,7 @@ def __init__(
         name: str | None = None,
         model_settings: ModelSettings | None = None,
         retries: int = 1,
+        validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
         output_retries: int | None = None,
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -249,6 +253,7 @@ def __init__(
             model_settings: Optional model request settings to use for this agent's runs, by default.
             retries: The default number of retries to allow for tool calls and output validation, before raising an error.
                 For model request retries, see the [HTTP Request Retries](../retries.md) documentation.
+            validation_context: Pydantic [validation context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) used to validate tool arguments and outputs.
             output_retries: The maximum number of retries to allow for output validation, defaults to `retries`.
             tools: Tools to register with the agent, you can also register tools via the decorators
                 [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
@@ -314,6 +319,8 @@ def __init__(
         self._max_result_retries = output_retries if output_retries is not None else retries
         self._max_tool_retries = retries
 
+        self._validation_context = validation_context
+
         self._builtin_tools = builtin_tools
 
         self._prepare_tools = prepare_tools
@@ -612,6 +619,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
             end_strategy=self.end_strategy,
             output_schema=output_schema,
             output_validators=output_validators,
+            validation_context=self._validation_context,
             history_processors=self.history_processors,
             builtin_tools=[*self._builtin_tools, *(builtin_tools or [])],
             tool_manager=tool_manager,

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 4 additions & 1 deletion
@@ -198,7 +198,10 @@ async def validate_response_output(
             text = ''
 
         result_data = await text_processor.process(
-            text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
+            text,
+            run_context=self._run_ctx,
+            allow_partial=allow_partial,
+            wrap_validation_errors=False,
         )
         for validator in self._output_validators:
             result_data = await validator.validate(

tests/test_examples.py

Lines changed: 1 addition & 0 deletions
@@ -512,6 +512,7 @@ async def call_tool(
     'What is a banana?': ToolCallPart(tool_name='return_fruit', args={'name': 'banana', 'color': 'yellow'}),
     'What is a Ford Explorer?': '{"result": {"kind": "Vehicle", "data": {"name": "Ford Explorer", "wheels": 4}}}',
     'What is a MacBook?': '{"result": {"kind": "Device", "data": {"name": "MacBook", "kind": "laptop"}}}',
+    'Give me a value of 5.': ToolCallPart(tool_name='final_result', args={'x': 5}),
     'Write a creative story about space exploration': 'In the year 2157, Captain Maya Chen piloted her spacecraft through the vast expanse of the Andromeda Galaxy. As she discovered a planet with crystalline mountains that sang in harmony with the cosmic winds, she realized that space exploration was not just about finding new worlds, but about finding new ways to understand the universe and our place within it.',
     'Create a person': ToolCallPart(
         tool_name='final_result',

0 commit comments
