-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Pass pydantic validation context #3448
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,7 +18,7 @@ | |
|
|
||
| from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore | ||
| from pydantic_ai._instrumentation import DEFAULT_INSTRUMENTATION_VERSION | ||
| from pydantic_ai._tool_manager import ToolManager | ||
| from pydantic_ai._tool_manager import ToolManager, build_validation_context | ||
| from pydantic_ai._utils import dataclasses_no_defaults_repr, get_union_args, is_async_callable, run_in_executor | ||
| from pydantic_ai.builtin_tools import AbstractBuiltinTool | ||
| from pydantic_graph import BaseNode, GraphRunContext | ||
|
|
@@ -590,7 +590,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa | |
| text = '' # pragma: no cover | ||
| if text: | ||
| try: | ||
| self._next_node = await self._handle_text_response(ctx, text, text_processor) | ||
| self._next_node = await self._handle_text_response( | ||
| ctx, ctx.deps.tool_manager.validation_ctx, text, text_processor | ||
| ) | ||
| return | ||
| except ToolRetryError: | ||
| # If the text from the preview response was invalid, ignore it. | ||
|
|
@@ -654,7 +656,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa | |
|
|
||
| if text_processor := output_schema.text_processor: | ||
| if text: | ||
| self._next_node = await self._handle_text_response(ctx, text, text_processor) | ||
| self._next_node = await self._handle_text_response( | ||
| ctx, ctx.deps.tool_manager.validation_ctx, text, text_processor | ||
| ) | ||
| return | ||
| alternatives.insert(0, 'return text') | ||
|
|
||
|
|
@@ -716,12 +720,14 @@ async def _handle_tool_calls( | |
| async def _handle_text_response( | ||
| self, | ||
| ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], | ||
| validation_ctx: Any | Callable[[RunContext[DepsT]], Any], | ||
| text: str, | ||
| text_processor: _output.BaseOutputProcessor[NodeRunEndT], | ||
| ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]: | ||
| run_context = build_run_context(ctx) | ||
| validation_context = build_validation_context(validation_ctx, run_context) | ||
|
|
||
| result_data = await text_processor.process(text, run_context) | ||
| result_data = await text_processor.process(text, run_context, validation_context) | ||
|
|
||
| for validator in ctx.deps.output_validators: | ||
| result_data = await validator.validate(result_data, run_context) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ideally, output validators would have access to the validation context as well, as they could call it too. I think that'd be a refactor worth exploring that could allow us to drop a lot of the new arguments. |
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -530,6 +530,7 @@ async def process( | |
| self, | ||
| data: str, | ||
| run_context: RunContext[AgentDepsT], | ||
| validation_context: Any | None, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's have a default of `None`.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Also while we're at it, let's require kwargs for everything after `run_context`. |
||
| allow_partial: bool = False, | ||
| wrap_validation_errors: bool = True, | ||
| ) -> OutputDataT: | ||
|
|
@@ -554,13 +555,18 @@ async def process( | |
| self, | ||
| data: str, | ||
| run_context: RunContext[AgentDepsT], | ||
| validation_context: Any | None, | ||
| allow_partial: bool = False, | ||
| wrap_validation_errors: bool = True, | ||
| ) -> OutputDataT: | ||
| text = _utils.strip_markdown_fences(data) | ||
|
|
||
| return await self.wrapped.process( | ||
| text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors | ||
| text, | ||
| run_context, | ||
| validation_context, | ||
| allow_partial=allow_partial, | ||
| wrap_validation_errors=wrap_validation_errors, | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -639,6 +645,7 @@ async def process( | |
| self, | ||
| data: str | dict[str, Any] | None, | ||
| run_context: RunContext[AgentDepsT], | ||
| validation_context: Any | None, | ||
| allow_partial: bool = False, | ||
| wrap_validation_errors: bool = True, | ||
| ) -> OutputDataT: | ||
|
|
@@ -647,14 +654,15 @@ async def process( | |
| Args: | ||
| data: The output data to validate. | ||
| run_context: The current run context. | ||
| validation_context: Additional Pydantic validation context for the current run. | ||
| allow_partial: If true, allow partial validation. | ||
| wrap_validation_errors: If true, wrap the validation errors in a retry message. | ||
|
|
||
| Returns: | ||
| Either the validated output data (left) or a retry message (right). | ||
| """ | ||
| try: | ||
| output = self.validate(data, allow_partial) | ||
| output = self.validate(data, allow_partial, validation_context) | ||
| except ValidationError as e: | ||
| if wrap_validation_errors: | ||
| m = _messages.RetryPromptPart( | ||
|
|
@@ -672,12 +680,17 @@ def validate( | |
| self, | ||
| data: str | dict[str, Any] | None, | ||
| allow_partial: bool = False, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same as above, let's require kwargs for everything but `data`. |
||
| validation_context: Any | None = None, | ||
| ) -> dict[str, Any]: | ||
| pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' | ||
| if isinstance(data, str): | ||
| return self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial) | ||
| return self.validator.validate_json( | ||
| data or '{}', allow_partial=pyd_allow_partial, context=validation_context | ||
| ) | ||
| else: | ||
| return self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial) | ||
| return self.validator.validate_python( | ||
| data or {}, allow_partial=pyd_allow_partial, context=validation_context | ||
| ) | ||
|
|
||
| async def call( | ||
| self, | ||
|
|
@@ -797,11 +810,16 @@ async def process( | |
| self, | ||
| data: str, | ||
| run_context: RunContext[AgentDepsT], | ||
| validation_context: Any | None = None, | ||
| allow_partial: bool = False, | ||
| wrap_validation_errors: bool = True, | ||
| ) -> OutputDataT: | ||
| union_object = await self._union_processor.process( | ||
| data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors | ||
| data, | ||
| run_context, | ||
| validation_context, | ||
| allow_partial=allow_partial, | ||
| wrap_validation_errors=wrap_validation_errors, | ||
| ) | ||
|
|
||
| result = union_object.result | ||
|
|
@@ -817,7 +835,11 @@ async def process( | |
| raise | ||
|
|
||
| return await processor.process( | ||
| inner_data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors | ||
| inner_data, | ||
| run_context, | ||
| validation_context, | ||
| allow_partial=allow_partial, | ||
| wrap_validation_errors=wrap_validation_errors, | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -826,6 +848,7 @@ async def process( | |
| self, | ||
| data: str, | ||
| run_context: RunContext[AgentDepsT], | ||
| validation_context: Any | None = None, | ||
| allow_partial: bool = False, | ||
| wrap_validation_errors: bool = True, | ||
| ) -> OutputDataT: | ||
|
|
@@ -857,13 +880,14 @@ async def process( | |
| self, | ||
| data: str, | ||
| run_context: RunContext[AgentDepsT], | ||
| validation_context: Any | None = None, | ||
| allow_partial: bool = False, | ||
| wrap_validation_errors: bool = True, | ||
| ) -> OutputDataT: | ||
| args = {self._str_argument_name: data} | ||
| data = await execute_traced_output_function(self._function_schema, run_context, args, wrap_validation_errors) | ||
|
|
||
| return await super().process(data, run_context, allow_partial, wrap_validation_errors) | ||
| return await super().process(data, run_context, validation_context, allow_partial, wrap_validation_errors) | ||
|
|
||
|
|
||
| @dataclass(init=False) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -147,6 +147,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]): | |
| _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) | ||
| _max_result_retries: int = dataclasses.field(repr=False) | ||
| _max_tool_retries: int = dataclasses.field(repr=False) | ||
| _validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = dataclasses.field(repr=False) | ||
|
|
||
| _event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False) | ||
|
|
||
|
|
@@ -166,6 +167,7 @@ def __init__( | |
| name: str | None = None, | ||
| model_settings: ModelSettings | None = None, | ||
| retries: int = 1, | ||
| validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None, | ||
| output_retries: int | None = None, | ||
| tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), | ||
| builtin_tools: Sequence[AbstractBuiltinTool] = (), | ||
|
|
@@ -192,6 +194,7 @@ def __init__( | |
| name: str | None = None, | ||
| model_settings: ModelSettings | None = None, | ||
| retries: int = 1, | ||
| validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None, | ||
| output_retries: int | None = None, | ||
| tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), | ||
| builtin_tools: Sequence[AbstractBuiltinTool] = (), | ||
|
|
@@ -216,6 +219,7 @@ def __init__( | |
| name: str | None = None, | ||
| model_settings: ModelSettings | None = None, | ||
| retries: int = 1, | ||
| validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None, | ||
| output_retries: int | None = None, | ||
| tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), | ||
| builtin_tools: Sequence[AbstractBuiltinTool] = (), | ||
|
|
@@ -249,6 +253,7 @@ def __init__( | |
| model_settings: Optional model request settings to use for this agent's runs, by default. | ||
| retries: The default number of retries to allow for tool calls and output validation, before raising an error. | ||
| For model request retries, see the [HTTP Request Retries](../retries.md) documentation. | ||
| validation_context: Additional validation context used to validate all outputs. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's link to the Pydantic documentation. |
||
| output_retries: The maximum number of retries to allow for output validation, defaults to `retries`. | ||
| tools: Tools to register with the agent, you can also register tools via the decorators | ||
| [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain]. | ||
|
|
@@ -314,6 +319,8 @@ def __init__( | |
| self._max_result_retries = output_retries if output_retries is not None else retries | ||
| self._max_tool_retries = retries | ||
|
|
||
| self._validation_context = validation_context | ||
|
|
||
| self._builtin_tools = builtin_tools | ||
|
|
||
| self._prepare_tools = prepare_tools | ||
|
|
@@ -562,7 +569,7 @@ async def main(): | |
| output_toolset.max_retries = self._max_result_retries | ||
| output_toolset.output_validators = output_validators | ||
| toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets) | ||
| tool_manager = ToolManager[AgentDepsT](toolset) | ||
| tool_manager = ToolManager[AgentDepsT](toolset, validation_ctx=self._validation_context) | ||
|
|
||
| # Build the graph | ||
| graph = _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Since we already pass in `ctx`, we don't need a new arg, do we?
There was a problem hiding this comment.
Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Would it make sense to store the validation context directly on `ctx.deps`, instead of going through the tool manager?