diff --git a/docs/agents.md b/docs/agents.md
index 83d2658cc8..89df03a14e 100644
--- a/docs/agents.md
+++ b/docs/agents.md
@@ -280,7 +280,6 @@ async def main():
     [
         UserPromptNode(
             user_prompt='What is the capital of France?',
-            instructions=None,
             instructions_functions=[],
             system_prompts=(),
             system_prompt_functions=[],
@@ -343,7 +342,6 @@ async def main():
     [
         UserPromptNode(
            user_prompt='What is the capital of France?',
-            instructions=None,
             instructions_functions=[],
             system_prompts=(),
             system_prompt_functions=[],
diff --git a/docs/deferred-tools.md b/docs/deferred-tools.md
index d3d7917ba2..58e275add3 100644
--- a/docs/deferred-tools.md
+++ b/docs/deferred-tools.md
@@ -132,17 +132,21 @@ print(result.all_messages())
     ),
     ModelRequest(
         parts=[
-            ToolReturnPart(
-                tool_name='delete_file',
-                content='Deleting files is not allowed',
-                tool_call_id='delete_file',
-                timestamp=datetime.datetime(...),
-            ),
             ToolReturnPart(
                 tool_name='update_file',
                 content="File 'README.md' updated: 'Hello, world!'",
                 tool_call_id='update_file_readme',
                 timestamp=datetime.datetime(...),
+            )
+        ]
+    ),
+    ModelRequest(
+        parts=[
+            ToolReturnPart(
+                tool_name='delete_file',
+                content='Deleting files is not allowed',
+                tool_call_id='delete_file',
+                timestamp=datetime.datetime(...),
             ),
             ToolReturnPart(
                 tool_name='update_file',
diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
index e81711e8e6..19bd5cd224 100644
--- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py
+++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -8,7 +8,7 @@
 from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
-from dataclasses import field
+from dataclasses import field, replace
 from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast
 
 from opentelemetry.trace import Tracer
@@ -16,7 +16,7 @@
 
 from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx  # type: ignore
 from pydantic_ai._tool_manager import ToolManager
-from pydantic_ai._utils import is_async_callable, run_in_executor
+from pydantic_ai._utils import dataclasses_no_defaults_repr, get_union_args, is_async_callable, run_in_executor
 from pydantic_ai.builtin_tools import AbstractBuiltinTool
 from pydantic_graph import BaseNode, Graph, GraphRunContext
 from pydantic_graph.nodes import End, NodeRunEndT
@@ -26,7 +26,9 @@
 from .output import OutputDataT, OutputSpec
 from .settings import ModelSettings
 from .tools import (
+    DeferredToolCallResult,
     DeferredToolResult,
+    DeferredToolResults,
     RunContext,
     ToolApproved,
     ToolDefinition,
@@ -123,7 +125,6 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
     builtin_tools: list[AbstractBuiltinTool] = dataclasses.field(repr=False)
     tool_manager: ToolManager[DepsT]
 
-    tool_call_results: dict[str, DeferredToolResult] | None
 
     tracer: Tracer
     instrumentation_settings: InstrumentationSettings | None
@@ -160,12 +161,16 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
 
     _: dataclasses.KW_ONLY
 
-    instructions: str | None
-    instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
+    deferred_tool_results: DeferredToolResults | None = None
 
-    system_prompts: tuple[str, ...]
-    system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
-    system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]
+    instructions: str | None = None
+    instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]] = dataclasses.field(default_factory=list)
+
+    system_prompts: tuple[str, ...] = dataclasses.field(default_factory=tuple)
+    system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]] = dataclasses.field(default_factory=list)
+    system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]] = dataclasses.field(
+        default_factory=dict
+    )
 
     async def run(  # noqa: C901
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -181,24 +186,15 @@
             messages = ctx_messages.messages
             ctx_messages.used = True
 
+        message_history = _clean_message_history(ctx.state.message_history)
         # Add message history to the `capture_run_messages` list, which will be empty at this point
-        messages.extend(ctx.state.message_history)
+        messages.extend(message_history)
         # Use the `capture_run_messages` list as the message history so that new messages are added to it
         ctx.state.message_history = messages
+        ctx.deps.new_message_index = len(messages)
 
-        if (tool_call_results := ctx.deps.tool_call_results) is not None:
-            if messages and (last_message := messages[-1]) and isinstance(last_message, _messages.ModelRequest):
-                # If tool call results were provided, that means the previous run ended on deferred tool calls.
-                # That run would typically have ended on a `ModelResponse`, but if it had a mix of deferred tool calls and ones that could already be executed,
-                # a `ModelRequest` would already have been added to the history with the preliminary results, even if it wouldn't have been sent to the model yet.
-                # So now that we have all of the deferred results, we roll back to the last `ModelResponse` and store the contents of the `ModelRequest` on `deferred_tool_results` to be handled by `CallToolsNode`.
-                ctx.deps.tool_call_results = self._update_tool_call_results_from_model_request(
-                    tool_call_results, last_message
-                )
-                messages.pop()
-
-            if not messages:
-                raise exceptions.UserError('Tool call results were provided, but the message history is empty.')
+        if self.deferred_tool_results is not None:
+            return await self._handle_deferred_tool_results(self.deferred_tool_results, messages, ctx)
 
         next_message: _messages.ModelRequest | None = None
 
@@ -222,9 +218,13 @@
                         combined_content.extend(part.content)
                 ctx.deps.prompt = combined_content
             elif isinstance(last_message, _messages.ModelResponse):
-                call_tools_node = await self._handle_message_history_model_response(ctx, last_message)
-                if call_tools_node is not None:
-                    return call_tools_node
+                if self.user_prompt is None:
+                    # Skip ModelRequestNode and go directly to CallToolsNode
+                    return CallToolsNode[DepsT, NodeRunEndT](last_message)
+                elif any(isinstance(part, _messages.ToolCallPart) for part in last_message.parts):
+                    raise exceptions.UserError(
+                        'Cannot provide a new user prompt when the message history contains unprocessed tool calls.'
+                    )
 
         # Build the run context after `ctx.deps.prompt` has been updated
         run_context = build_run_context(ctx)
@@ -249,73 +249,64 @@
 
         return ModelRequestNode[DepsT, NodeRunEndT](request=next_message)
 
-    async def _handle_message_history_model_response(
+    async def _handle_deferred_tool_results(  # noqa: C901
         self,
+        deferred_tool_results: DeferredToolResults,
+        messages: list[_messages.ModelMessage],
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
-        message: _messages.ModelResponse,
-    ) -> CallToolsNode[DepsT, NodeRunEndT] | None:
-        unprocessed_tool_calls = any(isinstance(part, _messages.ToolCallPart) for part in message.parts)
-        if unprocessed_tool_calls:
-            if self.user_prompt is not None:
-                raise exceptions.UserError(
-                    'Cannot provide a new user prompt when the message history contains unprocessed tool calls.'
-                )
-        else:
-            if ctx.deps.tool_call_results is not None:
-                raise exceptions.UserError(
-                    'Tool call results were provided, but the message history does not contain any unprocessed tool calls.'
-                )
-
-        if unprocessed_tool_calls or self.user_prompt is None:
-            # `CallToolsNode` requires the tool manager to be prepared for the run step
-            # This will raise errors for any tool name conflicts
-            run_context = build_run_context(ctx)
-            ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
-
-            # Skip ModelRequestNode and go directly to CallToolsNode
-            return CallToolsNode[DepsT, NodeRunEndT](model_response=message)
-
-    def _update_tool_call_results_from_model_request(
-        self, tool_call_results: dict[str, DeferredToolResult], message: _messages.ModelRequest
-    ) -> dict[str, DeferredToolResult]:
-        last_tool_return: _messages.ToolReturn | None = None
-        user_content: list[str | _messages.UserContent] = []
-        for part in message.parts:
-            if isinstance(part, _messages.ToolReturnPart):
-                if part.tool_call_id in tool_call_results:
-                    raise exceptions.UserError(
-                        f'Tool call {part.tool_call_id!r} was already executed and its result cannot be overridden.'
-                    )
-
-                last_tool_return = _messages.ToolReturn(return_value=part.content, metadata=part.metadata)
-                tool_call_results[part.tool_call_id] = last_tool_return
-            elif isinstance(part, _messages.RetryPromptPart):
-                if part.tool_call_id in tool_call_results:
-                    raise exceptions.UserError(
-                        f'Tool call {part.tool_call_id!r} was already executed and its result cannot be overridden.'
-                    )
-
-                tool_call_results[part.tool_call_id] = part
-            elif isinstance(part, _messages.UserPromptPart):
-                # Tools can return user parts via `ToolReturn.content` or by returning multi-modal content.
-                # These go together with a specific `ToolReturnPart`, but we don't have a way to know which,
-                # so (below) we just add them to the last one, matching the tool-results-before-user-parts order of the request.
-                if isinstance(part.content, str):
-                    user_content.append(part.content)
-                else:
-                    user_content.extend(part.content)
-            else:
-                raise exceptions.UserError(f'Unexpected message part type: {type(part)}')  # pragma: no cover
+    ) -> CallToolsNode[DepsT, NodeRunEndT]:
+        if not messages:
+            raise exceptions.UserError('Tool call results were provided, but the message history is empty.')
+
+        last_model_request: _messages.ModelRequest | None = None
+        last_model_response: _messages.ModelResponse | None = None
+        for message in reversed(messages):
+            if isinstance(message, _messages.ModelRequest):
+                last_model_request = message
+            elif isinstance(message, _messages.ModelResponse):  # pragma: no branch
+                last_model_response = message
+                break
+
+        if not last_model_response:
+            raise exceptions.UserError(
+                'Tool call results were provided, but the message history does not contain a `ModelResponse`.'
+            )
+        if not any(isinstance(part, _messages.ToolCallPart) for part in last_model_response.parts):
+            raise exceptions.UserError(
+                'Tool call results were provided, but the message history does not contain any unprocessed tool calls.'
+            )
+        if self.user_prompt is not None:
+            raise exceptions.UserError(
+                'Cannot provide a new user prompt when the message history contains unprocessed tool calls.'
+            )
 
-        if user_content:
-            if last_tool_return is None:
-                raise exceptions.UserError(
-                    'Tool call results were provided, but the last message in the history was a `ModelRequest` with user parts not tied to preliminary tool results.'
-                )
-            assert last_tool_return is not None
-            last_tool_return.content = user_content
+        tool_call_results: dict[str, DeferredToolResult | Literal['skip']] | None = None
+        tool_call_results = {}
+        for tool_call_id, approval in deferred_tool_results.approvals.items():
+            if approval is True:
+                approval = ToolApproved()
+            elif approval is False:
+                approval = ToolDenied()
+            tool_call_results[tool_call_id] = approval
+
+        if calls := deferred_tool_results.calls:
+            call_result_types = get_union_args(DeferredToolCallResult)
+            for tool_call_id, result in calls.items():
+                if not isinstance(result, call_result_types):
+                    result = _messages.ToolReturn(result)
+                tool_call_results[tool_call_id] = result
+
+        if last_model_request:
+            for part in last_model_request.parts:
+                if isinstance(part, _messages.ToolReturnPart | _messages.RetryPromptPart):
+                    if part.tool_call_id in tool_call_results:
+                        raise exceptions.UserError(
+                            f'Tool call {part.tool_call_id!r} was already executed and its result cannot be overridden.'
+                        )
+                    tool_call_results[part.tool_call_id] = 'skip'
 
-        return tool_call_results
+        # Skip ModelRequestNode and go directly to CallToolsNode
+        return CallToolsNode[DepsT, NodeRunEndT](last_model_response, tool_call_results=tool_call_results)
 
     async def _reevaluate_dynamic_prompts(
         self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
@@ -352,6 +343,8 @@ async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.Mod
             messages.append(_messages.SystemPromptPart(prompt))
         return messages
 
+    __repr__ = dataclasses_no_defaults_repr
+
 
 async def _prepare_request_parameters(
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
@@ -463,6 +456,7 @@ async def _prepare_request(
         ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
 
         message_history = await _process_message_history(ctx.state, ctx.deps.history_processors, run_context)
+        message_history = _clean_message_history(message_history)
 
         model_request_parameters = await _prepare_request_parameters(ctx)
         model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
@@ -498,12 +492,15 @@ def _finish_handling(
 
         return self._result
 
+    __repr__ = dataclasses_no_defaults_repr
+
 
 @dataclasses.dataclass
 class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
     """The node that processes a model response, and decides whether to end the run or make a new request."""
 
     model_response: _messages.ModelResponse
+    tool_call_results: dict[str, DeferredToolResult | Literal['skip']] | None = None
 
     _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, init=False, repr=False)
     _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field(
@@ -604,11 +601,20 @@ async def _handle_tool_calls(
     ) -> AsyncIterator[_messages.HandleResponseEvent]:
         run_context = build_run_context(ctx)
 
+        # This will raise errors for any tool name conflicts
+        ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
+
         output_parts: list[_messages.ModelRequestPart] = []
         output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1)
 
-        async for event in process_function_tools(
-            ctx.deps.tool_manager, tool_calls, None, ctx, output_parts, output_final_result
+        async for event in process_tool_calls(
+            tool_manager=ctx.deps.tool_manager,
+            tool_calls=tool_calls,
+            tool_call_results=self.tool_call_results,
+            final_result=None,
+            ctx=ctx,
+            output_parts=output_parts,
+            output_final_result=output_final_result,
         ):
             yield event
 
@@ -661,6 +667,8 @@ async def _handle_text_response(
         else:
             return self._handle_final_result(ctx, result.FinalResult(result_data), [])
 
+    __repr__ = dataclasses_no_defaults_repr
+
 
 def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
     """Build a `RunContext` object from the current agent graph run context."""
@@ -674,13 +682,14 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
         trace_include_content=ctx.deps.instrumentation_settings is not None
         and ctx.deps.instrumentation_settings.include_content,
         run_step=ctx.state.run_step,
-        tool_call_approved=ctx.state.run_step == 0 and ctx.deps.tool_call_results is not None,
+        tool_call_approved=ctx.state.run_step == 0,
     )
 
 
-async def process_function_tools(  # noqa: C901
+async def process_tool_calls(  # noqa: C901
     tool_manager: ToolManager[DepsT],
     tool_calls: list[_messages.ToolCallPart],
+    tool_call_results: dict[str, DeferredToolResult | Literal['skip']] | None,
     final_result: result.FinalResult[NodeRunEndT] | None,
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
     output_parts: list[_messages.ModelRequestPart],
@@ -761,14 +770,13 @@ async def process_tool_calls(  # noqa: C901
         ctx.state.increment_retries(ctx.deps.max_result_retries)
         calls_to_run.extend(tool_calls_by_kind['unknown'])
 
-    deferred_tool_results: dict[str, DeferredToolResult] = {}
-    if build_run_context(ctx).tool_call_approved and ctx.deps.tool_call_results is not None:
-        deferred_tool_results = ctx.deps.tool_call_results
+    calls_to_run_results: dict[str, DeferredToolResult] = {}
+    if tool_call_results is not None:
         # Deferred tool calls are "run" as well, by reading their value from the tool call results
         calls_to_run.extend(tool_calls_by_kind['external'])
         calls_to_run.extend(tool_calls_by_kind['unapproved'])
 
-        result_tool_call_ids = set(deferred_tool_results.keys())
+        result_tool_call_ids = set(tool_call_results.keys())
         tool_call_ids_to_run = {call.tool_call_id for call in calls_to_run}
         if tool_call_ids_to_run != result_tool_call_ids:
             raise exceptions.UserError(
@@ -776,24 +784,29 @@ async def process_tool_calls(  # noqa: C901
                 f'Expected: {tool_call_ids_to_run}, got: {result_tool_call_ids}'
             )
 
+        # Filter out calls that were already executed before and should now be skipped
+        calls_to_run_results = {call_id: result for call_id, result in tool_call_results.items() if result != 'skip'}
+        calls_to_run = [call for call in calls_to_run if call.tool_call_id in calls_to_run_results]
+
     deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]] = defaultdict(list)
     if calls_to_run:
         async for event in _call_tools(
-            tool_manager,
-            calls_to_run,
-            deferred_tool_results,
-            ctx.deps.tracer,
-            ctx.deps.usage_limits,
-            output_parts,
-            deferred_calls,
+            tool_manager=tool_manager,
+            tool_calls=calls_to_run,
+            tool_call_results=calls_to_run_results,
+            tracer=ctx.deps.tracer,
+            usage_limits=ctx.deps.usage_limits,
+            output_parts=output_parts,
+            output_deferred_calls=deferred_calls,
         ):
             yield event
 
     # Finally, we handle deferred tool calls (unless they were already included in the run because results were provided)
-    if not deferred_tool_results:
+    if tool_call_results is None:
+        calls = [*tool_calls_by_kind['external'], *tool_calls_by_kind['unapproved']]
         if final_result:
-            for call in [*tool_calls_by_kind['external'], *tool_calls_by_kind['unapproved']]:
+            for call in calls:
                 output_parts.append(
                     _messages.ToolReturnPart(
                         tool_name=call.tool_name,
@@ -801,13 +814,11 @@ async def process_tool_calls(  # noqa: C901
                         tool_call_id=call.tool_call_id,
                     )
                 )
-        else:
-            for call in tool_calls_by_kind['external']:
-                deferred_calls['external'].append(call)
-                yield _messages.FunctionToolCallEvent(call)
+        elif calls:
+            deferred_calls['external'].extend(tool_calls_by_kind['external'])
+            deferred_calls['unapproved'].extend(tool_calls_by_kind['unapproved'])
 
-            for call in tool_calls_by_kind['unapproved']:
-                deferred_calls['unapproved'].append(call)
+            for call in calls:
                 yield _messages.FunctionToolCallEvent(call)
 
     if not final_result and deferred_calls:
@@ -829,7 +840,7 @@ async def process_tool_calls(  # noqa: C901
 async def _call_tools(
     tool_manager: ToolManager[DepsT],
     tool_calls: list[_messages.ToolCallPart],
-    deferred_tool_results: dict[str, DeferredToolResult],
+    tool_call_results: dict[str, DeferredToolResult],
     tracer: Tracer,
     usage_limits: _usage.UsageLimits | None,
     output_parts: list[_messages.ModelRequestPart],
@@ -875,7 +886,7 @@ async def handle_call_or_result(
     if tool_manager.should_call_sequentially(tool_calls):
         for index, call in enumerate(tool_calls):
             if event := await handle_call_or_result(
-                _call_tool(tool_manager, call, deferred_tool_results.get(call.tool_call_id), usage_limits),
+                _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id), usage_limits),
                 index,
             ):
                 yield event
@@ -883,7 +894,7 @@ async def handle_call_or_result(
     else:
         tasks = [
             asyncio.create_task(
-                _call_tool(tool_manager, call, deferred_tool_results.get(call.tool_call_id), usage_limits),
+                _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id), usage_limits),
                 name=call.tool_name,
            )
            for call in tool_calls
@@ -1101,3 +1112,51 @@ async def _process_message_history(
     # Replaces the message history in the state with the processed messages
     state.message_history = messages
     return messages
+
+
+def _clean_message_history(messages: list[_messages.ModelMessage]) -> list[_messages.ModelMessage]:
+    """Clean the message history by merging consecutive messages of the same type."""
+    clean_messages: list[_messages.ModelMessage] = []
+    for message in messages:
+        last_message = clean_messages[-1] if len(clean_messages) > 0 else None
+
+        if isinstance(message, _messages.ModelRequest):
+            if (
+                last_message
+                and isinstance(last_message, _messages.ModelRequest)
+                # Requests can only be merged if they have the same instructions
+                and (
+                    not last_message.instructions
+                    or not message.instructions
+                    or last_message.instructions == message.instructions
+                )
+            ):
+                parts = [*last_message.parts, *message.parts]
+                parts.sort(
+                    # Tool return parts always need to be at the start
+                    key=lambda x: 0 if isinstance(x, _messages.ToolReturnPart | _messages.RetryPromptPart) else 1
+                )
+                merged_message = _messages.ModelRequest(
+                    parts=parts,
+                    instructions=last_message.instructions or message.instructions,
+                )
+                clean_messages[-1] = merged_message
+            else:
+                clean_messages.append(message)
+        elif isinstance(message, _messages.ModelResponse):  # pragma: no branch
+            if (
+                last_message
+                and isinstance(last_message, _messages.ModelResponse)
+                # Responses can only be merged if they didn't really come from an API
+                and last_message.provider_response_id is None
+                and last_message.provider_name is None
+                and last_message.model_name is None
+                and message.provider_response_id is None
+                and message.provider_name is None
+                and message.model_name is None
+            ):
+                merged_message = replace(last_message, parts=[*last_message.parts, *message.parts])
+                clean_messages[-1] = merged_message
+            else:
+                clean_messages.append(message)
+    return clean_messages
diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py
index dd66726603..45678a2d34 100644
--- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py
@@ -45,15 +45,11 @@
 from ..settings import ModelSettings, merge_model_settings
 from ..tools import (
     AgentDepsT,
-    DeferredToolCallResult,
-    DeferredToolResult,
     DeferredToolResults,
     DocstringFormat,
     GenerateToolJsonSchema,
     RunContext,
     Tool,
-    ToolApproved,
-    ToolDenied,
     ToolFuncContext,
     ToolFuncEither,
     ToolFuncPlain,
@@ -462,7 +458,7 @@ def iter(
     ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ...
 
     @asynccontextmanager
-    async def iter(  # noqa: C901
+    async def iter(
         self,
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
@@ -505,7 +501,6 @@ async def main():
     [
         UserPromptNode(
             user_prompt='What is the capital of France?',
-            instructions=None,
             instructions_functions=[],
             system_prompts=(),
             system_prompt_functions=[],
@@ -559,7 +554,6 @@ async def main():
         del model
 
         deps = self._get_deps(deps)
-        new_message_index = len(message_history) if message_history else 0
 
         output_schema = self._prepare_output_schema(output_type, model_used.profile)
         output_type_ = output_type or self.output_type
@@ -620,27 +614,12 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
             instrumentation_settings = None
             tracer = NoOpTracer()
 
-        tool_call_results: dict[str, DeferredToolResult] | None = None
-        if deferred_tool_results is not None:
-            tool_call_results = {}
-            for tool_call_id, approval in deferred_tool_results.approvals.items():
-                if approval is True:
-                    approval = ToolApproved()
-                elif approval is False:
-                    approval = ToolDenied()
-                tool_call_results[tool_call_id] = approval
-
-            if calls := deferred_tool_results.calls:
-                call_result_types = _utils.get_union_args(DeferredToolCallResult)
-                for tool_call_id, result in calls.items():
-                    if not isinstance(result, call_result_types):
-                        result = _messages.ToolReturn(result)
-                    tool_call_results[tool_call_id] = result
-
-        graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
+        graph_deps = _agent_graph.GraphAgentDeps[
+            AgentDepsT, RunOutputDataT
+        ](
             user_deps=deps,
             prompt=user_prompt,
-            new_message_index=new_message_index,
+            new_message_index=0,  # This will be set in `UserPromptNode` based on the length of the cleaned message history
             model=model_used,
             model_settings=model_settings,
             usage_limits=usage_limits,
@@ -651,13 +630,13 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
             history_processors=self.history_processors,
             builtin_tools=list(self._builtin_tools),
             tool_manager=tool_manager,
-            tool_call_results=tool_call_results,
             tracer=tracer,
             get_instructions=get_instructions,
             instrumentation_settings=instrumentation_settings,
         )
         start_node = _agent_graph.UserPromptNode[AgentDepsT](
             user_prompt=user_prompt,
+            deferred_tool_results=deferred_tool_results,
             instructions=self._instructions,
             instructions_functions=self._instructions_functions,
             system_prompts=self._system_prompts,
diff --git a/pydantic_ai_slim/pydantic_ai/agent/abstract.py b/pydantic_ai_slim/pydantic_ai/agent/abstract.py
index 2e18954360..8d6c9ff293 100644
--- a/pydantic_ai_slim/pydantic_ai/agent/abstract.py
+++ b/pydantic_ai_slim/pydantic_ai/agent/abstract.py
@@ -499,12 +499,13 @@ async def on_complete() -> None:
                 ]
 
                 parts: list[_messages.ModelRequestPart] = []
-                async for _event in _agent_graph.process_function_tools(
-                    graph_ctx.deps.tool_manager,
-                    tool_calls,
-                    final_result,
-                    graph_ctx,
-                    parts,
+                async for _event in _agent_graph.process_tool_calls(
+                    tool_manager=graph_ctx.deps.tool_manager,
+                    tool_calls=tool_calls,
+                    tool_call_results=None,
+                    final_result=final_result,
+                    ctx=graph_ctx,
+                    output_parts=parts,
                 ):
                     pass
                 if parts:
@@ -621,7 +622,6 @@ async def main():
     [
         UserPromptNode(
             user_prompt='What is the capital of France?',
-            instructions=None,
             instructions_functions=[],
             system_prompts=(),
             system_prompt_functions=[],
diff --git a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py
index e53ead8cef..36f7969323 100644
--- a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py
+++ b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py
@@ -144,7 +144,6 @@ async def main():
     [
         UserPromptNode(
             user_prompt='What is the capital of France?',
-            instructions=None,
             instructions_functions=[],
             system_prompts=(),
             system_prompt_functions=[],
diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py
index 41829c4d72..a12c9e70c0 100644
--- a/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py
+++ b/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py
@@ -627,7 +627,6 @@ async def main():
     [
         UserPromptNode(
             user_prompt='What is the capital of France?',
-            instructions=None,
             instructions_functions=[],
             system_prompts=(),
             system_prompt_functions=[],
diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py
index 0b487e39b5..cb284b6097 100644
--- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py
+++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py
@@ -660,7 +660,6 @@ async def main():
     [
         UserPromptNode(
             user_prompt='What is the capital of France?',
-            instructions=None,
             instructions_functions=[],
             system_prompts=(),
             system_prompt_functions=[],
diff --git a/pydantic_ai_slim/pydantic_ai/run.py b/pydantic_ai_slim/pydantic_ai/run.py
index 7ed6b848c0..0cc9481043 100644
--- a/pydantic_ai_slim/pydantic_ai/run.py
+++ b/pydantic_ai_slim/pydantic_ai/run.py
@@ -48,7 +48,6 @@ async def main():
     [
         UserPromptNode(
             user_prompt='What is the capital of France?',
-            instructions=None,
             instructions_functions=[],
             system_prompts=(),
             system_prompt_functions=[],
@@ -183,7 +182,6 @@ async def main():
     [
         UserPromptNode(
             user_prompt='What is the capital of France?',
-            instructions=None,
             instructions_functions=[],
             system_prompts=(),
             system_prompt_functions=[],
diff --git a/tests/test_a2a.py b/tests/test_a2a.py
index f72227f8bc..433ba1f111 100644
--- a/tests/test_a2a.py
+++ b/tests/test_a2a.py
@@ -623,10 +623,10 @@ def track_messages(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon
                     content='Final result processed.',
                     tool_call_id=IsStr(),
                     timestamp=IsDatetime(),
-                )
-            ]
+                ),
+                UserPromptPart(content='Second message', timestamp=IsDatetime()),
+            ],
         ),
-        ModelRequest(parts=[UserPromptPart(content='Second message', timestamp=IsDatetime())]),
     ]
 )
diff --git a/tests/test_agent.py b/tests/test_agent.py
index 6ee8a0225f..8c90db56b5 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -1797,7 +1797,7 @@ async def ret_a(x: str) -> str:
             ),
         ]
     )
-    assert result2._new_message_index == snapshot(4)  # pyright: ignore[reportPrivateUsage]
+    assert result2.new_messages() == result2.all_messages()[-2:]
     assert result2.output == snapshot('{"ret_a":"a-apple"}')
     assert result2._output_tool_name == snapshot(None)  # pyright: ignore[reportPrivateUsage]
     assert result2.usage() == snapshot(RunUsage(requests=1, input_tokens=55, output_tokens=13))
@@ -1854,7 +1854,7 @@ async def ret_a(x: str) -> str:
             ),
         ]
     )
-    assert result3._new_message_index == snapshot(4)  # pyright: ignore[reportPrivateUsage]
+    assert result3.new_messages() == result3.all_messages()[-2:]
     assert result3.output == snapshot('{"ret_a":"a-apple"}')
     assert result3._output_tool_name == snapshot(None)  # pyright: ignore[reportPrivateUsage]
     assert result3.usage() == snapshot(RunUsage(requests=1, input_tokens=55, output_tokens=13))
@@ -1983,7 +1983,7 @@ async def ret_a(x: str) -> str:
         ]
     )
     assert result2.output == snapshot(Response(a=0))
-    assert result2._new_message_index == snapshot(5)  # pyright: ignore[reportPrivateUsage]
+    assert result2.new_messages() == result2.all_messages()[-3:]
     assert result2._output_tool_name == snapshot('final_result')  # pyright: ignore[reportPrivateUsage]
     assert result2.usage() == snapshot(RunUsage(requests=1, input_tokens=59, output_tokens=13))
     new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()]
@@ -2056,6 +2056,8 @@ async def instructions(ctx: RunContext) -> str:
         ]
     )
 
+    assert result.new_messages() == result.all_messages()[-1:]
+
 
 def test_run_with_history_ending_on_model_response_with_tool_calls_and_no_user_prompt():
     """Test that an agent run with message_history ending on ModelResponse starts with CallToolsNode."""
@@ -2109,6 +2111,8 @@ def test_tool() -> str:
         ]
     )
 
+    assert result.new_messages() == result.all_messages()[-2:]
+
 
 def test_run_with_history_ending_on_model_response_with_tool_calls_and_user_prompt():
     """Test that an agent run raises error when message_history ends on ModelResponse with tool calls and there's a new prompt."""
@@ -2161,6 +2165,8 @@ def simple_response(_messages: list[ModelMessage], _info: AgentInfo) -> ModelRes
         ]
     )
 
+    assert result.new_messages() == []
+
 
 def test_empty_tool_calls():
     def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
@@ -2966,6 +2972,8 @@ async def func() -> str:
         ]
     )
 
+    assert res_two.new_messages() == res_two.all_messages()[-2:]
+
 
 def test_dynamic_true_reevaluate_system_prompt():
     """When dynamic is true, the system prompt is reevaluated
@@ -3048,6 +3056,8 @@ async def func():
         ]
     )
 
+    assert res_two.new_messages() == res_two.all_messages()[-2:]
+
 
 def test_dynamic_system_prompt_no_changes():
     """Test coverage for _reevaluate_dynamic_prompts branch where no parts are changed
@@ -3578,6 +3588,8 @@ def test_instructions_with_message_history():
         ]
     )
 
+    assert result.new_messages() == result.all_messages()[-2:]
+
 
 def test_instructions_parameter_with_sequence():
     def instructions() -> str:
@@ -4700,7 +4712,11 @@ def create_file(path: str, content: str) -> str:
                        content='File \'new_file.py\' created with content: print("Hello, world!")',
                        tool_call_id='create_file',
                        timestamp=IsDatetime(),
-                    ),
+                    )
+                ]
+            ),
+            ModelRequest(
+                parts=[
                     ToolReturnPart(
                         tool_name='delete_file',
                         content="File 'ok_to_delete.py' deleted",
@@ -4725,6 +4741,33 @@ def create_file(path: str, content: str) -> str:
     )
     assert result.output == snapshot('Done!')
 
+    assert result.new_messages() == snapshot(
+        [
+            ModelRequest(
+                parts=[
+                    ToolReturnPart(
+                        tool_name='delete_file',
+                        content="File 'ok_to_delete.py' deleted",
+                        tool_call_id='ok_to_delete',
+                        timestamp=IsDatetime(),
+                    ),
+                    ToolReturnPart(
+                        tool_name='delete_file',
+                        content='File cannot be deleted',
+                        tool_call_id='never_delete',
+                        timestamp=IsDatetime(),
+                    ),
+                ]
+            ),
+            ModelResponse(
+                parts=[TextPart(content='Done!')],
+                usage=RequestUsage(input_tokens=78, output_tokens=24),
+                model_name='function:model_function:',
+                timestamp=IsDatetime(),
+            ),
+        ]
+    )
+
 
 async def test_run_with_deferred_tool_results_errors():
     agent = Agent('test')
@@ -4733,7 +4776,7 @@ async def test_run_with_deferred_tool_results_errors():
 
     with pytest.raises(
         UserError,
-        match='Tool call results were provided, but the last message in the history was a `ModelRequest` with user parts not tied to preliminary tool results.',
+        match='Tool call results were provided, but the message history does not contain a `ModelResponse`.',
     ):
         await agent.run(
             'Hello again',
@@ -4778,6 +4821,15 @@ async def test_run_with_deferred_tool_results_errors():
             deferred_tool_results=DeferredToolResults(approvals={'create_file': True}),
         )
 
+    with pytest.raises(
+        UserError, match='Cannot provide a new user prompt when the message history contains unprocessed tool calls.'
+    ):
+        await agent.run(
+            'Hello again',
+            message_history=message_history,
+            deferred_tool_results=DeferredToolResults(approvals={'create_file': True}),
+        )
+
     message_history: list[ModelMessage] = [
         ModelRequest(parts=[UserPromptPart(content='Hello')]),
         ModelResponse(
@@ -4797,7 +4849,6 @@ async def test_run_with_deferred_tool_results_errors():
 
     with pytest.raises(UserError, match="Tool call 'run_me' was already executed and its result cannot be overridden."):
         await agent.run(
-            'Hello again',
             message_history=message_history,
             deferred_tool_results=DeferredToolResults(
                 calls={'run_me': 'Failure', 'defer_me': 'Failure'},
@@ -4808,7 +4859,6 @@ async def test_run_with_deferred_tool_results_errors():
         UserError, match="Tool call 'run_me_too' was already executed and its result cannot be overridden."
     ):
         await agent.run(
-            'Hello again',
             message_history=message_history,
             deferred_tool_results=DeferredToolResults(
                 calls={'run_me_too': 'Success', 'defer_me': 'Failure'},
@@ -4827,3 +4877,101 @@ def test_tool_requires_approval_error():
     @agent.tool_plain(requires_approval=True)
     def delete_file(path: str) -> None:
         pass
+
+
+async def test_consecutive_model_responses_in_history():
+    received_messages: list[ModelMessage] | None = None
+
+    def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        nonlocal received_messages
+        received_messages = messages
+        return ModelResponse(
+            parts=[
+                TextPart('All right then, goodbye!'),
+            ]
+        )
+
+    history: list[ModelMessage] = [
+        ModelRequest(parts=[UserPromptPart(content='Hello...')]),
+        ModelResponse(parts=[TextPart(content='...world!')]),
+        ModelResponse(parts=[TextPart(content='Anything else I can help with?')]),
+    ]
+
+    m = FunctionModel(llm)
+    agent = Agent(m)
+    result = await agent.run('No thanks', message_history=history)
+
+    assert result.all_messages() == snapshot(
+        [
+            ModelRequest(
+                parts=[
+                    UserPromptPart(
+                        content='Hello...',
+                        timestamp=IsDatetime(),
+                    )
+                ]
+            ),
+            ModelResponse(
+                parts=[TextPart(content='...world!'), TextPart(content='Anything else I can help with?')],
+                timestamp=IsDatetime(),
+            ),
+            ModelRequest(
+                parts=[
+                    UserPromptPart(
+                        content='No thanks',
+                        timestamp=IsDatetime(),
+                    )
+                ]
+            ),
+            ModelResponse(
+                parts=[TextPart(content='All right then, goodbye!')],
+                usage=RequestUsage(input_tokens=54, output_tokens=12),
+                model_name='function:llm:',
+                timestamp=IsDatetime(),
+            ),
+        ]
+    )
+
+    assert result.new_messages() == snapshot(
+        [
+            ModelRequest(
+                parts=[
+                    UserPromptPart(
+                        content='No thanks',
+                        timestamp=IsDatetime(),
+                    )
+                ]
+            ),
+            ModelResponse(
+                parts=[TextPart(content='All right then, goodbye!')],
+                usage=RequestUsage(input_tokens=54, output_tokens=12),
+                model_name='function:llm:',
+                timestamp=IsDatetime(),
+            ),
+        ]
+    )
+
+    assert received_messages == snapshot(
+        [
+            ModelRequest(
+                parts=[
+                    UserPromptPart(
+                        content='Hello...',
+                        timestamp=IsDatetime(),
+                    )
+                ]
+            ),
+            ModelResponse(
+                parts=[TextPart(content='...world!'), TextPart(content='Anything else I can help with?')],
+                timestamp=IsDatetime(),
+            ),
+            ModelRequest(
+                parts=[
+                    UserPromptPart(
+                        content='No thanks',
+                        timestamp=IsDatetime(),
+                    )
+                ]
+            ),
+        ]
+    )
diff --git a/tests/test_history_processor.py b/tests/test_history_processor.py
index 1aa138935e..14040a84bf 100644
--- a/tests/test_history_processor.py
+++ b/tests/test_history_processor.py
@@ -1,5 +1,5 @@
 from collections.abc import AsyncIterator
-from typing import Any, cast
+from typing import Any
 
 import pytest
 from inline_snapshot import snapshot
@@ -174,8 +174,18 @@ def capture_messages_processor(messages: list[ModelMessage]) -> list[ModelMessag
     )
     assert received_messages == snapshot(
         [
-            ModelRequest(parts=[UserPromptPart(content='Previous question', timestamp=IsDatetime())]),
-            ModelRequest(parts=[UserPromptPart(content='New question', timestamp=IsDatetime())]),
+            ModelRequest(
+                parts=[
+                    UserPromptPart(
+                        content='Previous question',
+                        timestamp=IsDatetime(),
+                    ),
+                    UserPromptPart(
+                        content='New question',
+                        timestamp=IsDatetime(),
+                    ),
+                ]
+            )
         ]
     )
@@ -244,8 +254,18 @@ async def async_processor(messages: list[ModelMessage]) -> list[ModelMessage]:
     await agent.run('Question 2', message_history=message_history)
     assert received_messages == snapshot(
         [
-            ModelRequest(parts=[UserPromptPart(content='Question 1', timestamp=IsDatetime())]),
-            ModelRequest(parts=[UserPromptPart(content='Question 2', timestamp=IsDatetime())]),
+            ModelRequest(
+                parts=[
+                    UserPromptPart(
+                        content='Question 1',
+                        timestamp=IsDatetime(),
+                    ),
+                    UserPromptPart(
+                        content='Question 2',
+                        timestamp=IsDatetime(),
+                    ),
+                ]
+            )
         ]
     )
@@ -271,8 +291,18 @@ async def async_processor(messages: list[ModelMessage]) -> list[ModelMessage]:
 
     assert received_messages == snapshot(
         [
-            ModelRequest(parts=[UserPromptPart(content='Question 1', timestamp=IsDatetime())]),
-            ModelRequest(parts=[UserPromptPart(content='Question 2', timestamp=IsDatetime())]),
+            ModelRequest(
+                parts=[
+                    UserPromptPart(
+                        content='Question 1',
+                        timestamp=IsDatetime(),
+                    ),
+                    UserPromptPart(
+                        content='Question 2',
+                        timestamp=IsDatetime(),
+                    ),
+                ]
+            )
         ]
     )
@@ -367,9 +397,19 @@ class Deps:
     await agent.run('Question 2', message_history=message_history, deps=Deps())
 
     # Should have filtered responses and added prefix
-    assert len(received_messages) == 2
-    for msg in received_messages:
-        assert isinstance(msg, ModelRequest)
-        user_part = msg.parts[0]
-        assert isinstance(user_part, UserPromptPart)
-        assert cast(str, user_part.content).startswith('TEST: ')
+    assert received_messages == snapshot(
+        [
+            ModelRequest(
+                parts=[
+                    UserPromptPart(
+                        content='TEST: Question 1',
+                        timestamp=IsDatetime(),
+                    ),
+                    UserPromptPart(
+                        content='TEST: Question 2',
+                        timestamp=IsDatetime(),
+                    ),
+                ]
+            )
+        ]
+    )
diff --git a/tests/test_tools.py b/tests/test_tools.py
index f626745c90..0fab7b47b7 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -1452,6 +1452,8 @@ def test_output_type_empty():
 
 
 def test_parallel_tool_return_with_deferred():
+    final_received_messages: list[ModelMessage] | None = None
+
     def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
         if len(messages) == 1:
             return ModelResponse(
@@ -1462,9 +1464,12 @@ def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
                     ToolCallPart('get_price', {'fruit': 'grape'}, tool_call_id='get_price_grape'),
                     ToolCallPart('buy', {'fruit': 'apple'}, tool_call_id='buy_apple'),
                     ToolCallPart('buy', {'fruit': 'banana'}, tool_call_id='buy_banana'),
+                    ToolCallPart('buy', {'fruit': 'pear'}, tool_call_id='buy_pear'),
                 ]
             )
         else:
+            nonlocal final_received_messages
+            final_received_messages = messages
             return ModelResponse(
                 parts=[
                     TextPart('Done!'),
@@ -1509,8 +1514,9 @@ def buy(fruit: str):
                     ToolCallPart(tool_name='get_price', args={'fruit': 'grape'}, tool_call_id='get_price_grape'),
                     ToolCallPart(tool_name='buy', args={'fruit': 'apple'}, tool_call_id='buy_apple'),
                     ToolCallPart(tool_name='buy', args={'fruit': 'banana'}, tool_call_id='buy_banana'),
+                    ToolCallPart(tool_name='buy', args={'fruit': 'pear'}, tool_call_id='buy_pear'),
                 ],
-                usage=RequestUsage(input_tokens=68, output_tokens=30),
+                usage=RequestUsage(input_tokens=68, output_tokens=35),
                 model_name='function:llm:',
                 timestamp=IsDatetime(),
             ),
@@ -1559,6 +1565,7 @@ def buy(fruit: str):
             calls=[
                 ToolCallPart(tool_name='buy', args={'fruit': 'apple'}, tool_call_id='buy_apple'),
                 ToolCallPart(tool_name='buy', args={'fruit': 'banana'}, tool_call_id='buy_banana'),
+                ToolCallPart(tool_name='buy', args={'fruit': 'pear'}, tool_call_id='buy_pear'),
             ],
         )
     )
@@ -1573,6 +1580,9 @@ def buy(fruit: str):
                 content='I bought a banana',
                 metadata={'fruit': 'banana', 'price': 100.0},
             ),
+            'buy_pear': RetryPromptPart(
+                content='The purchase of pears was denied.',
+            ),
         },
     ),
 )
@@ -1594,8 +1604,9 @@ def buy(fruit: str):
                 ToolCallPart(tool_name='get_price', args={'fruit': 'apple'}, tool_call_id='get_price_apple'),
                 ToolCallPart(tool_name='get_price', args={'fruit': 'banana'}, tool_call_id='get_price_banana'),
                 ToolCallPart(tool_name='get_price', args={'fruit': 'pear'}, tool_call_id='get_price_pear'),
                 ToolCallPart(tool_name='get_price', args={'fruit': 'grape'}, tool_call_id='get_price_grape'),
                 ToolCallPart(tool_name='buy', args={'fruit': 'apple'}, tool_call_id='buy_apple'),
                 ToolCallPart(tool_name='buy', args={'fruit': 'banana'}, tool_call_id='buy_banana'),
+                ToolCallPart(tool_name='buy', args={'fruit': 'pear'}, tool_call_id='buy_pear'),
             ],
-            usage=RequestUsage(input_tokens=68, output_tokens=30),
+            usage=RequestUsage(input_tokens=68, output_tokens=35),
             model_name='function:llm:',
             timestamp=IsDatetime(),
         ),
@@ -1627,6 +1638,18 @@ def buy(fruit: str):
                     tool_call_id='get_price_grape',
                     timestamp=IsDatetime(),
                 ),
+                UserPromptPart(
+                    content='The price of apple is 10.0.',
+                    timestamp=IsDatetime(),
+                ),
+                UserPromptPart(
+                    content='The price of pear is 10.0.',
+                    timestamp=IsDatetime(),
+                ),
+            ]
+        ),
+        ModelRequest(
+            parts=[
                 RetryPromptPart(
                     content='Apples are not available',
                     tool_name='buy',
@@ -1640,11 +1663,48 @@ def buy(fruit: str):
                     metadata={'fruit': 'banana', 'price': 100.0},
                     timestamp=IsDatetime(),
                 ),
+                RetryPromptPart(
+                    content='The purchase of pears was denied.',
+                    tool_name='buy',
+                    tool_call_id='buy_pear',
+                    timestamp=IsDatetime(),
+                ),
                 UserPromptPart(
-                    content=[
-                        'The price of apple is 10.0.',
-                        'The price of pear is 10.0.',
-                    ],
+                    content='I bought a banana',
+                    timestamp=IsDatetime(),
+                ),
+            ]
+        ),
+        ModelResponse(
+            parts=[TextPart(content='Done!')],
+            usage=RequestUsage(input_tokens=137, output_tokens=36),
+            model_name='function:llm:',
+            timestamp=IsDatetime(),
+        ),
+    ]
+)
+
+    assert result.new_messages() == snapshot(
+        [
+            ModelRequest(
+                parts=[
+                    RetryPromptPart(
+                        content='Apples are not available',
+                        tool_name='buy',
+                        tool_call_id='buy_apple',
+                        timestamp=IsDatetime(),
+                    ),
+                    ToolReturnPart(
+                        tool_name='buy',
+                        content=True,
+                        tool_call_id='buy_banana',
+                        metadata={'fruit': 'banana', 'price': 100.0},
+                        timestamp=IsDatetime(),
+                    ),
+                    RetryPromptPart(
+                        content='The purchase of pears was denied.',
+                        tool_name='buy',
+                        tool_call_id='buy_pear',
                     timestamp=IsDatetime(),
                 ),
                 UserPromptPart(
@@ -1655,13 +1715,100 @@ def buy(fruit: str):
                     content='I bought a banana',
                     timestamp=IsDatetime(),
                 ),
            ]
        ),
        ModelResponse(
            parts=[TextPart(content='Done!')],
-            usage=RequestUsage(input_tokens=124, output_tokens=31),
+            usage=RequestUsage(input_tokens=137, output_tokens=36),
            model_name='function:llm:',
            timestamp=IsDatetime(),
        ),
    ]
 )
-    assert result.output == snapshot('Done!')
+
+    assert final_received_messages == snapshot(
+        [
+            ModelRequest(
+                parts=[
+                    UserPromptPart(
+                        content='What do an apple, a banana, a pear and a grape cost? Also buy me a pear.',
+                        timestamp=IsDatetime(),
+                    )
+                ]
+            ),
+            ModelResponse(
+                parts=[
+                    ToolCallPart(tool_name='get_price', args={'fruit': 'apple'}, tool_call_id='get_price_apple'),
+                    ToolCallPart(tool_name='get_price', args={'fruit': 'banana'}, tool_call_id='get_price_banana'),
+                    ToolCallPart(tool_name='get_price', args={'fruit': 'pear'}, tool_call_id='get_price_pear'),
+                    ToolCallPart(tool_name='get_price', args={'fruit': 'grape'}, tool_call_id='get_price_grape'),
+                    ToolCallPart(tool_name='buy', args={'fruit': 'apple'}, tool_call_id='buy_apple'),
+                    ToolCallPart(tool_name='buy', args={'fruit': 'banana'}, tool_call_id='buy_banana'),
+                    ToolCallPart(tool_name='buy', args={'fruit': 'pear'}, tool_call_id='buy_pear'),
+                ],
+                usage=RequestUsage(input_tokens=68, output_tokens=35),
+                model_name='function:llm:',
+                timestamp=IsDatetime(),
+            ),
+            ModelRequest(
+                parts=[
+                    ToolReturnPart(
+                        tool_name='get_price',
+                        content=10.0,
+                        tool_call_id='get_price_apple',
+                        metadata={'fruit': 'apple', 'price': 10.0},
+                        timestamp=IsDatetime(),
+                    ),
+                    RetryPromptPart(
+                        content='Unknown fruit: banana',
+                        tool_name='get_price',
+                        tool_call_id='get_price_banana',
+                        timestamp=IsDatetime(),
+                    ),
+                    ToolReturnPart(
+                        tool_name='get_price',
+                        content=10.0,
+                        tool_call_id='get_price_pear',
+                        metadata={'fruit': 'pear', 'price': 10.0},
+                        timestamp=IsDatetime(),
+                    ),
+                    RetryPromptPart(
+                        content='Unknown fruit: grape',
+                        tool_name='get_price',
+                        tool_call_id='get_price_grape',
+                        timestamp=IsDatetime(),
+                    ),
+                    RetryPromptPart(
+                        content='Apples are not available',
+                        tool_name='buy',
+                        tool_call_id='buy_apple',
+                        timestamp=IsDatetime(),
+                    ),
+                    ToolReturnPart(
+                        tool_name='buy',
+                        content=True,
+                        tool_call_id='buy_banana',
+                        metadata={'fruit': 'banana', 'price': 100.0},
+                        timestamp=IsDatetime(),
+                    ),
+                    RetryPromptPart(
+                        content='The purchase of pears was denied.',
+                        tool_name='buy',
+                        tool_call_id='buy_pear',
+                        timestamp=IsDatetime(),
+                    ),
+                    UserPromptPart(
+                        content='The price of apple is 10.0.',
+                        timestamp=IsDatetime(),
+                    ),
+                    UserPromptPart(
+                        content='The price of pear is 10.0.',
+                        timestamp=IsDatetime(),
+                    ),
+                    UserPromptPart(
+                        content='I bought a banana',
+                        timestamp=IsDatetime(),
+                    ),
+                ]
+            ),
+        ]
+    )
 
 
 def test_deferred_tool_call_approved_fails():
@@ -1800,6 +1947,16 @@ def bar(x: int) -> int:
                 model_name='function:llm:',
                 timestamp=IsDatetime(),
             ),
+            ModelRequest(
+                parts=[
+                    ToolReturnPart(
+                        tool_name='bar',
+                        content=9,
+                        tool_call_id='bar',
+                        timestamp=IsDatetime(),
+                    )
+                ]
+            ),
             ModelRequest(
                 parts=[
                     ToolReturnPart(
@@ -1814,12 +1971,6 @@ def bar(x: int) -> int:
                         tool_call_id='foo2',
                         timestamp=IsDatetime(),
                     ),
-                    ToolReturnPart(
-                        tool_name='bar',
-                        content=9,
-                        tool_call_id='bar',
-                        timestamp=IsDatetime(),
-                    ),
                 ]
             ),
             ModelResponse(