diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index ebb737a1cf..50b6fe38eb 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import copy import inspect import json import re @@ -15,6 +16,7 @@ from pydantic_ai._instrumentation import InstrumentationNames from . import _function_schema, _utils, messages as _messages +from ._json_schema import JsonSchema from ._run_context import AgentDepsT, RunContext from .exceptions import ModelRetry, ToolRetryError, UserError from .output import ( @@ -226,6 +228,10 @@ def mode(self) -> OutputMode: def allows_text(self) -> bool: return self.text_processor is not None + @abstractmethod + def dump(self) -> JsonSchema: + raise NotImplementedError() + @classmethod def build( # noqa: C901 cls, @@ -405,6 +411,18 @@ def __init__( def mode(self) -> OutputMode: return 'auto' + def dump(self) -> JsonSchema: + if self.toolset: + processors: list[ObjectOutputProcessor[OutputDataT]] = [] + for tool_def in self.toolset._tool_defs: # pyright: ignore [reportPrivateUsage] + processor = copy.copy(self.toolset.processors[tool_def.name]) + processor.object_def.name = tool_def.name + processor.object_def.description = tool_def.description + processors.append(processor) + return UnionOutputProcessor(processors).object_def.json_schema + + return self.processor.object_def.json_schema + @dataclass(init=False) class TextOutputSchema(OutputSchema[OutputDataT]): @@ -425,6 +443,9 @@ def __init__( def mode(self) -> OutputMode: return 'text' + def dump(self) -> JsonSchema: + return {'type': 'string'} + class ImageOutputSchema(OutputSchema[OutputDataT]): def __init__(self, *, allows_deferred_tools: bool): @@ -434,6 +455,9 @@ def __init__(self, *, allows_deferred_tools: bool): def mode(self) -> OutputMode: return 'image' + def dump(self) -> JsonSchema: + raise NotImplementedError() + @dataclass(init=False) class StructuredTextOutputSchema(OutputSchema[OutputDataT], ABC): @@ -450,6 +474,9 @@ def __init__( ) self.processor = processor + def dump(self) -> JsonSchema: + return self.processor.object_def.json_schema + class NativeOutputSchema(StructuredTextOutputSchema[OutputDataT]): @property @@ -523,6 +550,18 @@ def __init__( def mode(self) -> OutputMode: return 'tool' + def dump(self) -> JsonSchema: + if self.toolset is None: + # need to check expected behavior + raise NotImplementedError() + processors: list[ObjectOutputProcessor[OutputDataT]] = [] + for tool_def in self.toolset._tool_defs: # pyright: ignore [reportPrivateUsage] + processor = copy.copy(self.toolset.processors[tool_def.name]) + processor.object_def.name = tool_def.name + processor.object_def.description = tool_def.description + processors.append(processor) + return UnionOutputProcessor(processors).object_def.json_schema + class BaseOutputProcessor(ABC, Generic[OutputDataT]): @abstractmethod @@ -714,7 +753,7 @@ class UnionOutputProcessor(BaseObjectOutputProcessor[OutputDataT]): def __init__( self, - outputs: Sequence[OutputTypeOrFunction[OutputDataT]], + outputs: Sequence[OutputTypeOrFunction[OutputDataT] | ObjectOutputProcessor[OutputDataT]], *, name: str | None = None, description: str | None = None, @@ -725,7 +764,10 @@ def __init__( json_schemas: list[ObjectJsonSchema] = [] self._processors = {} for output in outputs: - processor = ObjectOutputProcessor(output=output, strict=strict) + if isinstance(output, ObjectOutputProcessor): + processor = output + else: + processor = ObjectOutputProcessor(output=output, strict=strict) object_def = processor.object_def object_key = object_def.name or output.__name__ diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 4cd353b44a..54c98e87ab 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -34,6 +34,7 @@ UserPromptNode, capture_run_messages, ) +from .._json_schema import JsonSchema from .._output import OutputToolset from .._tool_manager import ToolManager from ..builtin_tools import AbstractBuiltinTool @@ -947,6 +948,11 @@ def decorator( self._system_prompt_functions.append(_system_prompt.SystemPromptRunner[AgentDepsT](func, dynamic=dynamic)) return func + def output_json_schema(self, output_type: OutputSpec[RunOutputDataT] | None = None) -> JsonSchema: + """The output JSON schema.""" + output_schema = self._prepare_output_schema(output_type) + return output_schema.dump() + @overload def output_validator( self, func: Callable[[RunContext[AgentDepsT], OutputDataT], OutputDataT], / diff --git a/pydantic_ai_slim/pydantic_ai/agent/abstract.py b/pydantic_ai_slim/pydantic_ai/agent/abstract.py index c7c1cb2b5c..8c7c81ee41 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/abstract.py +++ b/pydantic_ai_slim/pydantic_ai/agent/abstract.py @@ -23,6 +23,7 @@ result, usage as _usage, ) +from .._json_schema import JsonSchema from .._tool_manager import ToolManager from ..builtin_tools import AbstractBuiltinTool from ..output import OutputDataT, OutputSpec @@ -122,6 +123,11 @@ def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]: """ raise NotImplementedError + @abstractmethod + def output_json_schema(self, output_type: OutputSpec[RunOutputDataT] | None = None) -> JsonSchema: + """The output JSON schema.""" + raise NotImplementedError + @overload async def run( self, diff --git a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py index 38e832fa2b..6431be7297 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py @@ -10,6 +10,7 @@ models, usage as _usage, ) +from .._json_schema import JsonSchema from ..builtin_tools import AbstractBuiltinTool from ..output import OutputDataT, OutputSpec from ..run import AgentRun @@ -67,6 +68,9 @@ async def __aenter__(self) -> AbstractAgent[AgentDepsT, OutputDataT]: async def __aexit__(self, *args: Any) -> bool | None: return await self.wrapped.__aexit__(*args) + def output_json_schema(self, output_type: OutputSpec[RunOutputDataT] | None = None) -> JsonSchema: + return self.wrapped.output_json_schema(output_type=output_type) + @overload def iter( self, diff --git a/tests/test_agent.py b/tests/test_agent.py index c2a513af47..fda72265cb 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -5155,6 +5155,35 @@ def foo() -> str: assert wrapper_agent.name == 'wrapped' assert wrapper_agent.output_type == agent.output_type assert wrapper_agent.event_stream_handler == agent.event_stream_handler + assert wrapper_agent.output_json_schema() == snapshot( + { + 'type': 'object', + 'properties': { + 'result': { + 'anyOf': [ + { + 'type': 'object', + 'properties': { + 'kind': {'type': 'string', 'const': 'final_result'}, + 'data': { + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'string'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + }, + 'required': ['kind', 'data'], + 'additionalProperties': False, + 'title': 'final_result', + 'description': 'The final response which ends this conversation', + } + ] + } + }, + 'required': ['result'], + 'additionalProperties': False, + } + ) + assert wrapper_agent.output_json_schema(output_type=str) == snapshot({'type': 'string'}) bar_toolset = FunctionToolset() @@ -6148,3 +6177,218 @@ def test_message_history_cannot_start_with_model_response(): match='Message history cannot start with a `ModelResponse`.', ): agent.run_sync('hello', message_history=invalid_history) + + +async def test_text_output_json_schema(): + agent = Agent('test') + assert agent.output_json_schema() == snapshot({'type': 'string'}) + + +async def test_auto_output_json_schema(): + agent = Agent('test', output_type=bool) + assert agent.output_json_schema() == snapshot( + { + 'type': 'object', + 'properties': { + 'result': { + 'anyOf': [ + { + 'type': 'object', + 'properties': { + 'kind': {'type': 'string', 'const': 'final_result'}, + 'data': { + 'properties': {'response': {'type': 'boolean'}}, + 'required': ['response'], + 'type': 'object', + }, + }, + 'required': ['kind', 'data'], + 'additionalProperties': False, + 'title': 'final_result', + 'description': 'The final response which ends this conversation', + } + ] + } + }, + 'required': ['result'], + 'additionalProperties': False, + } + ) + + +async def test_tool_output_json_schema(): + agent = Agent( + 'test', + output_type=[ToolOutput(bool, name='alice', description='Dreaming...')], + ) + assert agent.output_json_schema() == snapshot( + { + 'type': 'object', + 'properties': { + 'result': { + 'anyOf': [ + { + 'type': 'object', + 'properties': { + 'kind': {'type': 'string', 'const': 'alice'}, + 'data': { + 'properties': {'response': {'type': 'boolean'}}, + 'required': ['response'], + 'type': 'object', + }, + }, + 'required': ['kind', 'data'], + 'additionalProperties': False, + 'title': 'alice', + 'description': 'Dreaming...', + } + ] + } + }, + 'required': ['result'], + 'additionalProperties': False, + } + ) + + agent = Agent( + 'test', + output_type=[ToolOutput(bool, name='alice'), ToolOutput(bool, name='bob')], + ) + assert agent.output_json_schema() == snapshot( + { + 'type': 'object', + 'properties': { + 'result': { + 'anyOf': [ + { + 'type': 'object', + 'properties': { + 'kind': {'type': 'string', 'const': 'alice'}, + 'data': { + 'properties': {'response': {'type': 'boolean'}}, + 'required': ['response'], + 'type': 'object', + }, + }, + 'required': ['kind', 'data'], + 'additionalProperties': False, + 'title': 'alice', + 'description': 'bool: The final response which ends this conversation', + }, + { + 'type': 'object', + 'properties': { + 'kind': {'type': 'string', 'const': 'bob'}, + 'data': { + 'properties': {'response': {'type': 'boolean'}}, + 'required': ['response'], + 'type': 'object', + }, + }, + 'required': ['kind', 'data'], + 'additionalProperties': False, + 'title': 'bob', + 'description': 'bool: The final response which ends this conversation', + }, + ] + } + }, + 'required': ['result'], + 'additionalProperties': False, + } + ) + + +async def test_native_output_json_schema(): + agent = Agent( + 'test', + output_type=NativeOutput([bool], name='native_output_name', description='native_output_description'), + ) + assert agent.output_json_schema() == snapshot( + {'properties': {'response': {'type': 'boolean'}}, 'required': ['response'], 'type': 'object'} + ) + + +async def test_prompted_output_json_schema(): + agent = Agent( + 'test', + output_type=PromptedOutput([bool], name='prompted_output_name', description='prompted_output_description'), + ) + assert agent.output_json_schema() == snapshot( + {'properties': {'response': {'type': 'boolean'}}, 'required': ['response'], 'type': 'object'} + ) + + +async def test_custom_output_json_schema(): + HumanDict = StructuredDict( + { + 'type': 'object', + 'properties': {'name': {'type': 'string'}, 'age': {'type': 'integer'}}, + 'required': ['name', 'age'], + }, + name='Human', + description='A human with a name and age', + ) + agent = Agent('test', output_type=HumanDict) + assert agent.output_json_schema() == snapshot( + { + 'type': 'object', + 'properties': { + 'result': { + 'anyOf': [ + { + 'type': 'object', + 'properties': { + 'kind': {'type': 'string', 'const': 'final_result'}, + 'data': { + 'properties': {'name': {'type': 'string'}, 'age': {'type': 'integer'}}, + 'required': ['name', 'age'], + 'type': 'object', + }, + }, + 'required': ['kind', 'data'], + 'additionalProperties': False, + 'title': 'final_result', + 'description': 'A human with a name and age', + } + ] + } + }, + 'required': ['result'], + 'additionalProperties': False, + } + ) + + +async def test_override_output_json_schema(): + agent = Agent('test') + assert agent.output_json_schema() == snapshot({'type': 'string'}) + output_type = [ToolOutput(bool, name='alice', description='Dreaming...')] + assert agent.output_json_schema(output_type=output_type) == snapshot( + { + 'type': 'object', + 'properties': { + 'result': { + 'anyOf': [ + { + 'type': 'object', + 'properties': { + 'kind': {'type': 'string', 'const': 'alice'}, + 'data': { + 'properties': {'response': {'type': 'boolean'}}, + 'required': ['response'], + 'type': 'object', + }, + }, + 'required': ['kind', 'data'], + 'additionalProperties': False, + 'title': 'alice', + 'description': 'Dreaming...', + } + ] + } + }, + 'required': ['result'], + 'additionalProperties': False, + } + )