Skip to content

Commit ca678a2

Browse files
cjohnhansonDouweM
andauthored
Let metadata be passed to CallDeferred and ApprovalRequired exceptions and end up on DeferredToolRequests (#3345)
Co-authored-by: Douwe Maan <[email protected]>
1 parent 335555e commit ca678a2

File tree

9 files changed

+253
-35
lines changed

9 files changed

+253
-35
lines changed

docs/deferred-tools.md

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ PROTECTED_FILES = {'.env'}
4747
@agent.tool
4848
def update_file(ctx: RunContext, path: str, content: str) -> str:
4949
if path in PROTECTED_FILES and not ctx.tool_call_approved:
50-
raise ApprovalRequired
50+
raise ApprovalRequired(metadata={'reason': 'protected'}) # (1)!
5151
return f'File {path!r} updated: {content!r}'
5252

5353

@@ -77,6 +77,7 @@ DeferredToolRequests(
7777
tool_call_id='delete_file',
7878
),
7979
],
80+
metadata={'update_file_dotenv': {'reason': 'protected'}},
8081
)
8182
"""
8283

@@ -175,6 +176,8 @@ print(result.all_messages())
175176
"""
176177
```
177178

179+
1. The optional `metadata` parameter can attach arbitrary context to deferred tool calls, accessible in `DeferredToolRequests.metadata` keyed by `tool_call_id`.
180+
178181
_(This example is complete, it can be run "as is")_
179182

180183
## External Tool Execution
@@ -209,13 +212,13 @@ from pydantic_ai import (
209212

210213
@dataclass
211214
class TaskResult:
212-
tool_call_id: str
215+
task_id: str
213216
result: Any
214217

215218

216-
async def calculate_answer_task(tool_call_id: str, question: str) -> TaskResult:
219+
async def calculate_answer_task(task_id: str, question: str) -> TaskResult:
217220
await asyncio.sleep(1)
218-
return TaskResult(tool_call_id=tool_call_id, result=42)
221+
return TaskResult(task_id=task_id, result=42)
219222

220223

221224
agent = Agent('openai:gpt-5', output_type=[str, DeferredToolRequests])
@@ -225,12 +228,11 @@ tasks: list[asyncio.Task[TaskResult]] = []
225228

226229
@agent.tool
227230
async def calculate_answer(ctx: RunContext, question: str) -> str:
228-
assert ctx.tool_call_id is not None
229-
230-
task = asyncio.create_task(calculate_answer_task(ctx.tool_call_id, question)) # (1)!
231+
task_id = f'task_{len(tasks)}' # (1)!
232+
task = asyncio.create_task(calculate_answer_task(task_id, question))
231233
tasks.append(task)
232234

233-
raise CallDeferred
235+
raise CallDeferred(metadata={'task_id': task_id}) # (2)!
234236

235237

236238
async def main():
@@ -252,17 +254,19 @@ async def main():
252254
)
253255
],
254256
approvals=[],
257+
metadata={'pyd_ai_tool_call_id': {'task_id': 'task_0'}},
255258
)
256259
"""
257260

258-
done, _ = await asyncio.wait(tasks) # (2)!
261+
done, _ = await asyncio.wait(tasks) # (3)!
259262
task_results = [task.result() for task in done]
260-
task_results_by_tool_call_id = {result.tool_call_id: result.result for result in task_results}
263+
task_results_by_task_id = {result.task_id: result.result for result in task_results}
261264

262265
results = DeferredToolResults()
263266
for call in requests.calls:
264267
try:
265-
result = task_results_by_tool_call_id[call.tool_call_id]
268+
task_id = requests.metadata[call.tool_call_id]['task_id']
269+
result = task_results_by_task_id[task_id]
266270
except KeyError:
267271
result = ModelRetry('No result for this tool call was found.')
268272

@@ -324,8 +328,9 @@ async def main():
324328
"""
325329
```
326330

327-
1. In reality, you'd likely use Celery or a similar task queue to run the task in the background.
328-
2. In reality, this would typically happen in a separate process that polls for the task status or is notified when all pending tasks are complete.
331+
1. Generate a task ID that can be tracked independently of the tool call ID.
332+
2. The optional `metadata` parameter passes the `task_id` so it can be matched with results later, accessible in `DeferredToolRequests.metadata` keyed by `tool_call_id`.
333+
3. In reality, this would typically happen in a separate process that polls for the task status or is notified when all pending tasks are complete.
329334

330335
_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_
331336

docs/toolsets.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,7 @@ DeferredToolRequests(
362362
tool_call_id='pyd_ai_tool_call_id__temperature_fahrenheit',
363363
),
364364
],
365+
metadata={},
365366
)
366367
"""
367368

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,7 @@ async def process_tool_calls( # noqa: C901
888888
calls_to_run = [call for call in calls_to_run if call.tool_call_id in calls_to_run_results]
889889

890890
deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]] = defaultdict(list)
891+
deferred_metadata: dict[str, dict[str, Any]] = {}
891892

892893
if calls_to_run:
893894
async for event in _call_tools(
@@ -899,6 +900,7 @@ async def process_tool_calls( # noqa: C901
899900
usage_limits=ctx.deps.usage_limits,
900901
output_parts=output_parts,
901902
output_deferred_calls=deferred_calls,
903+
output_deferred_metadata=deferred_metadata,
902904
):
903905
yield event
904906

@@ -932,6 +934,7 @@ async def process_tool_calls( # noqa: C901
932934
deferred_tool_requests = _output.DeferredToolRequests(
933935
calls=deferred_calls['external'],
934936
approvals=deferred_calls['unapproved'],
937+
metadata=deferred_metadata,
935938
)
936939

937940
final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_requests), None, None)
@@ -949,10 +952,12 @@ async def _call_tools(
949952
usage_limits: _usage.UsageLimits,
950953
output_parts: list[_messages.ModelRequestPart],
951954
output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
955+
output_deferred_metadata: dict[str, dict[str, Any]],
952956
) -> AsyncIterator[_messages.HandleResponseEvent]:
953957
tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {}
954958
user_parts_by_index: dict[int, _messages.UserPromptPart] = {}
955959
deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {}
960+
deferred_metadata_by_index: dict[int, dict[str, Any] | None] = {}
956961

957962
if usage_limits.tool_calls_limit is not None:
958963
projected_usage = deepcopy(usage)
@@ -987,10 +992,12 @@ async def handle_call_or_result(
987992
tool_part, tool_user_content = (
988993
(await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result()
989994
)
990-
except exceptions.CallDeferred:
995+
except exceptions.CallDeferred as e:
991996
deferred_calls_by_index[index] = 'external'
992-
except exceptions.ApprovalRequired:
997+
deferred_metadata_by_index[index] = e.metadata
998+
except exceptions.ApprovalRequired as e:
993999
deferred_calls_by_index[index] = 'unapproved'
1000+
deferred_metadata_by_index[index] = e.metadata
9941001
else:
9951002
tool_parts_by_index[index] = tool_part
9961003
if tool_user_content:
@@ -1028,8 +1035,25 @@ async def handle_call_or_result(
10281035
output_parts.extend([tool_parts_by_index[k] for k in sorted(tool_parts_by_index)])
10291036
output_parts.extend([user_parts_by_index[k] for k in sorted(user_parts_by_index)])
10301037

1038+
_populate_deferred_calls(
1039+
tool_calls, deferred_calls_by_index, deferred_metadata_by_index, output_deferred_calls, output_deferred_metadata
1040+
)
1041+
1042+
1043+
def _populate_deferred_calls(
1044+
tool_calls: list[_messages.ToolCallPart],
1045+
deferred_calls_by_index: dict[int, Literal['external', 'unapproved']],
1046+
deferred_metadata_by_index: dict[int, dict[str, Any] | None],
1047+
output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
1048+
output_deferred_metadata: dict[str, dict[str, Any]],
1049+
) -> None:
1050+
"""Populate deferred calls and metadata from indexed mappings."""
10311051
for k in sorted(deferred_calls_by_index):
1032-
output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k])
1052+
call = tool_calls[k]
1053+
output_deferred_calls[deferred_calls_by_index[k]].append(call)
1054+
metadata = deferred_metadata_by_index[k]
1055+
if metadata is not None:
1056+
output_deferred_metadata[call.tool_call_id] = metadata
10331057

10341058

10351059
async def _call_tool(

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,13 @@ class CallToolParams:
2727

2828
@dataclass
2929
class _ApprovalRequired:
30+
metadata: dict[str, Any] | None = None
3031
kind: Literal['approval_required'] = 'approval_required'
3132

3233

3334
@dataclass
3435
class _CallDeferred:
36+
metadata: dict[str, Any] | None = None
3537
kind: Literal['call_deferred'] = 'call_deferred'
3638

3739

@@ -75,20 +77,20 @@ async def _wrap_call_tool_result(self, coro: Awaitable[Any]) -> CallToolResult:
7577
try:
7678
result = await coro
7779
return _ToolReturn(result=result)
78-
except ApprovalRequired:
79-
return _ApprovalRequired()
80-
except CallDeferred:
81-
return _CallDeferred()
80+
except ApprovalRequired as e:
81+
return _ApprovalRequired(metadata=e.metadata)
82+
except CallDeferred as e:
83+
return _CallDeferred(metadata=e.metadata)
8284
except ModelRetry as e:
8385
return _ModelRetry(message=e.message)
8486

8587
def _unwrap_call_tool_result(self, result: CallToolResult) -> Any:
8688
if isinstance(result, _ToolReturn):
8789
return result.result
8890
elif isinstance(result, _ApprovalRequired):
89-
raise ApprovalRequired()
91+
raise ApprovalRequired(metadata=result.metadata)
9092
elif isinstance(result, _CallDeferred):
91-
raise CallDeferred()
93+
raise CallDeferred(metadata=result.metadata)
9294
elif isinstance(result, _ModelRetry):
9395
raise ModelRetry(result.message)
9496
else:

pydantic_ai_slim/pydantic_ai/exceptions.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,18 +70,30 @@ class CallDeferred(Exception):
7070
"""Exception to raise when a tool call should be deferred.
7171
7272
See [tools docs](../deferred-tools.md#deferred-tools) for more information.
73+
74+
Args:
75+
metadata: Optional dictionary of metadata to attach to the deferred tool call.
76+
This metadata will be available in `DeferredToolRequests.metadata` keyed by `tool_call_id`.
7377
"""
7478

75-
pass
79+
def __init__(self, metadata: dict[str, Any] | None = None):
80+
self.metadata = metadata
81+
super().__init__()
7682

7783

7884
class ApprovalRequired(Exception):
7985
"""Exception to raise when a tool call requires human-in-the-loop approval.
8086
8187
See [tools docs](../deferred-tools.md#human-in-the-loop-tool-approval) for more information.
88+
89+
Args:
90+
metadata: Optional dictionary of metadata to attach to the deferred tool call.
91+
This metadata will be available in `DeferredToolRequests.metadata` keyed by `tool_call_id`.
8292
"""
8393

84-
pass
94+
def __init__(self, metadata: dict[str, Any] | None = None):
95+
self.metadata = metadata
96+
super().__init__()
8597

8698

8799
class UserError(RuntimeError):

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ class DeferredToolRequests:
147147
"""Tool calls that require external execution."""
148148
approvals: list[ToolCallPart] = field(default_factory=list)
149149
"""Tool calls that require human-in-the-loop approval."""
150+
metadata: dict[str, dict[str, Any]] = field(default_factory=dict)
151+
"""Metadata for deferred tool calls, keyed by `tool_call_id`."""
150152

151153

152154
@dataclass(kw_only=True)

tests/test_streaming.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1712,9 +1712,7 @@ def my_tool(x: int) -> int:
17121712
[DeferredToolRequests(calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())])]
17131713
)
17141714
assert await result.get_output() == snapshot(
1715-
DeferredToolRequests(
1716-
calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())],
1717-
)
1715+
DeferredToolRequests(calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())])
17181716
)
17191717
responses = [c async for c, _is_last in result.stream_responses(debounce_by=None)]
17201718
assert responses == snapshot(
@@ -1757,9 +1755,7 @@ def my_tool(ctx: RunContext[None], x: int) -> int:
17571755
messages = result.all_messages()
17581756
output = await result.get_output()
17591757
assert output == snapshot(
1760-
DeferredToolRequests(
1761-
approvals=[ToolCallPart(tool_name='my_tool', args='{"x": 1}', tool_call_id=IsStr())],
1762-
)
1758+
DeferredToolRequests(approvals=[ToolCallPart(tool_name='my_tool', args='{"x": 1}', tool_call_id=IsStr())])
17631759
)
17641760
assert result.is_complete
17651761

0 commit comments

Comments
 (0)