46 changes: 44 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_output.py
@@ -1,5 +1,6 @@
 from __future__ import annotations as _annotations
 
+import copy
 import inspect
 import json
 import re
@@ -15,6 +16,7 @@
 from pydantic_ai._instrumentation import InstrumentationNames
 
 from . import _function_schema, _utils, messages as _messages
+from ._json_schema import JsonSchema
 from ._run_context import AgentDepsT, RunContext
 from .exceptions import ModelRetry, ToolRetryError, UserError
 from .output import (
@@ -226,6 +228,10 @@ def mode(self) -> OutputMode:
     def allows_text(self) -> bool:
         return self.text_processor is not None
 
+    @abstractmethod
+    def dump(self) -> JsonSchema:
+        raise NotImplementedError()
+
     @classmethod
     def build( # noqa: C901
         cls,
@@ -405,6 +411,18 @@ def __init__(
     def mode(self) -> OutputMode:
         return 'auto'
 
+    def dump(self) -> JsonSchema:
+        if self.toolset:
+            processors: list[ObjectOutputProcessor[OutputDataT]] = []
+            for tool_def in self.toolset._tool_defs:  # pyright: ignore [reportPrivateUsage]
+                processor = copy.copy(self.toolset.processors[tool_def.name])
+                processor.object_def.name = tool_def.name
+                processor.object_def.description = tool_def.description
+                processors.append(processor)
+            return UnionOutputProcessor(processors).object_def.json_schema
+
+        return self.processor.object_def.json_schema
+
 
 @dataclass(init=False)
 class TextOutputSchema(OutputSchema[OutputDataT]):
@@ -425,6 +443,9 @@ def __init__(
     def mode(self) -> OutputMode:
         return 'text'
 
+    def dump(self) -> JsonSchema:
+        return {'type': 'string'}
+
 
 class ImageOutputSchema(OutputSchema[OutputDataT]):
     def __init__(self, *, allows_deferred_tools: bool):
@@ -434,6 +455,9 @@ def __init__(self, *, allows_deferred_tools: bool):
     def mode(self) -> OutputMode:
         return 'image'
 
+    def dump(self) -> JsonSchema:
+        raise NotImplementedError()
+
 
 @dataclass(init=False)
 class StructuredTextOutputSchema(OutputSchema[OutputDataT], ABC):
@@ -450,6 +474,9 @@ def __init__(
         )
         self.processor = processor
 
+    def dump(self) -> JsonSchema:
+        return self.processor.object_def.json_schema
+
 
 class NativeOutputSchema(StructuredTextOutputSchema[OutputDataT]):
     @property
@@ -523,6 +550,18 @@ def __init__(
     def mode(self) -> OutputMode:
         return 'tool'
 
+    def dump(self) -> JsonSchema:
+        if self.toolset is None:
+            # need to check expected behavior
+            raise NotImplementedError()
+        processors: list[ObjectOutputProcessor[OutputDataT]] = []
+        for tool_def in self.toolset._tool_defs:  # pyright: ignore [reportPrivateUsage]
+            processor = copy.copy(self.toolset.processors[tool_def.name])
+            processor.object_def.name = tool_def.name
+            processor.object_def.description = tool_def.description
+            processors.append(processor)
+        return UnionOutputProcessor(processors).object_def.json_schema
+
 
 class BaseOutputProcessor(ABC, Generic[OutputDataT]):
     @abstractmethod
@@ -714,7 +753,7 @@ class UnionOutputProcessor(BaseObjectOutputProcessor[OutputDataT]):
 
     def __init__(
         self,
-        outputs: Sequence[OutputTypeOrFunction[OutputDataT]],
+        outputs: Sequence[OutputTypeOrFunction[OutputDataT] | ObjectOutputProcessor[OutputDataT]],
         *,
         name: str | None = None,
         description: str | None = None,
@@ -725,7 +764,10 @@ def __init__(
         json_schemas: list[ObjectJsonSchema] = []
         self._processors = {}
         for output in outputs:
-            processor = ObjectOutputProcessor(output=output, strict=strict)
+            if isinstance(output, ObjectOutputProcessor):
+                processor = output
+            else:
+                processor = ObjectOutputProcessor(output=output, strict=strict)
             object_def = processor.object_def
 
             object_key = object_def.name or output.__name__
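
The hunks above give every output schema class a dump() method that reports the JSON schema of the agent's expected output: plain-text schemas return {'type': 'string'}, structured schemas reuse the schema attached to their processor, and tool-based schemas merge the per-tool object schemas through UnionOutputProcessor. A minimal standalone sketch of that pattern (class names and the JsonSchema alias below are illustrative, not the library's internals):

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any

# Stand-in for pydantic_ai._json_schema.JsonSchema in this sketch.
JsonSchema = dict[str, Any]


class OutputSchemaSketch(ABC):
    """Base class: every concrete output schema must know how to serialize itself."""

    @abstractmethod
    def dump(self) -> JsonSchema:
        raise NotImplementedError()


class TextOutputSketch(OutputSchemaSketch):
    def dump(self) -> JsonSchema:
        # Plain-text output is simply a string at the JSON-schema level.
        return {'type': 'string'}


class StructuredOutputSketch(OutputSchemaSketch):
    def __init__(self, json_schema: JsonSchema):
        self.json_schema = json_schema

    def dump(self) -> JsonSchema:
        # Structured output reuses the schema already carried by its processor.
        return self.json_schema


if __name__ == '__main__':
    print(TextOutputSketch().dump())
    print(StructuredOutputSketch({'type': 'object', 'properties': {'x': {'type': 'integer'}}}).dump())
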
6 changes: 6 additions & 0 deletions pydantic_ai_slim/pydantic_ai/agent/__init__.py
@@ -34,6 +34,7 @@
     UserPromptNode,
     capture_run_messages,
 )
+from .._json_schema import JsonSchema
 from .._output import OutputToolset
 from .._tool_manager import ToolManager
 from ..builtin_tools import AbstractBuiltinTool
@@ -947,6 +948,11 @@ def decorator(
         self._system_prompt_functions.append(_system_prompt.SystemPromptRunner[AgentDepsT](func, dynamic=dynamic))
         return func
 
+    def output_json_schema(self, output_type: OutputSpec[RunOutputDataT] | None = None) -> JsonSchema:
+        """The output JSON schema."""
+        output_schema = self._prepare_output_schema(output_type)
+        return output_schema.dump()
+
     @overload
     def output_validator(
         self, func: Callable[[RunContext[AgentDepsT], OutputDataT], OutputDataT], /
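
With the public Agent.output_json_schema() added above, a caller can inspect the schema before ever running the agent. A rough usage sketch, assuming the default tool-output mode; the model identifier and output model are illustrative:

from pydantic import BaseModel

from pydantic_ai import Agent


class CityInfo(BaseModel):
    city: str
    population: int


# Illustrative model name; any model accepted by Agent works the same way here.
agent = Agent('openai:gpt-4o', output_type=CityInfo)

# New in this PR: the JSON schema the agent's output is expected to satisfy.
print(agent.output_json_schema())

# An output_type override can also be passed, mirroring the run() overloads;
# for a plain str output this should reduce to {'type': 'string'}.
print(agent.output_json_schema(output_type=str))
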
6 changes: 6 additions & 0 deletions pydantic_ai_slim/pydantic_ai/agent/abstract.py
@@ -23,6 +23,7 @@
     result,
     usage as _usage,
 )
+from .._json_schema import JsonSchema
 from .._tool_manager import ToolManager
 from ..builtin_tools import AbstractBuiltinTool
 from ..output import OutputDataT, OutputSpec
@@ -122,6 +123,11 @@ def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
         """
         raise NotImplementedError
 
+    @abstractmethod
+    def output_json_schema(self, output_type: OutputSpec[RunOutputDataT] | None = None) -> JsonSchema:
+        """The output JSON schema."""
+        raise NotImplementedError
+
     @overload
     async def run(
         self,
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/agent/wrapper.py
@@ -10,6 +10,7 @@
     models,
     usage as _usage,
 )
+from .._json_schema import JsonSchema
 from ..builtin_tools import AbstractBuiltinTool
 from ..output import OutputDataT, OutputSpec
 from ..run import AgentRun
@@ -67,6 +68,9 @@ async def __aenter__(self) -> AbstractAgent[AgentDepsT, OutputDataT]:
     async def __aexit__(self, *args: Any) -> bool | None:
         return await self.wrapped.__aexit__(*args)
 
+    def output_json_schema(self, output_type: OutputSpec[RunOutputDataT] | None = None) -> JsonSchema:
+        return self.wrapped.output_json_schema(output_type=output_type)
+
     @overload
     def iter(
         self,