
Commit b7d8b42: async tools

1 parent 471a6c1

File tree: 13 files changed (+402, -346 lines)


src/strands/agent/agent.py

Lines changed: 13 additions & 26 deletions
@@ -128,14 +128,18 @@ def caller(
                 "input": kwargs.copy(),
             }

-            # Execute the tool
-            events = run_tool(self._agent, tool_use, kwargs)
+            async def acall() -> ToolResult:
+                async for event in run_tool(self._agent, tool_use, kwargs):
+                    _ = event

-            try:
-                while True:
-                    next(events)
-            except StopIteration as stop:
-                tool_result = cast(ToolResult, stop.value)
+                return cast(ToolResult, event)
+
+            def tcall() -> ToolResult:
+                return asyncio.run(acall())
+
+            with ThreadPoolExecutor() as executor:
+                future = executor.submit(tcall)
+                tool_result = future.result()

             if record_direct_tool_call is not None:
                 should_record_direct_tool_call = record_direct_tool_call
@@ -219,8 +223,8 @@ def __init__(
                 If explicitly set to None, null_callback_handler is used.
             conversation_manager: Manager for conversation history and context window.
                 Defaults to strands.agent.conversation_manager.SlidingWindowConversationManager if None.
-            max_parallel_tools: Maximum number of tools to run in parallel when the model returns multiple tool calls.
-                Defaults to os.cpu_count() or 1.
+            max_parallel_tools: [DEPRECATED] Maximum number of tools to run in parallel when the model returns multiple
+                tool calls. Defaults to os.cpu_count() or 1.
             record_direct_tool_call: Whether to record direct tool calls in message history.
                 Defaults to True.
             load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory.
@@ -263,14 +267,6 @@ def __init__(
            ):
                self.trace_attributes[k] = v

-        # If max_parallel_tools is 1, we execute tools sequentially
-        self.thread_pool = None
-        self.thread_pool_wrapper = None
-        if max_parallel_tools > 1:
-            self.thread_pool = ThreadPoolExecutor(max_workers=max_parallel_tools)
-        elif max_parallel_tools < 1:
-            raise ValueError("max_parallel_tools must be greater than 0")
-
         self.record_direct_tool_call = record_direct_tool_call
         self.load_tools_from_directory = load_tools_from_directory
@@ -344,15 +340,6 @@ def tool_config(self) -> ToolConfig:
         """
         return self.tool_registry.initialize_tool_config()

-    def __del__(self) -> None:
-        """Clean up resources when Agent is garbage collected.
-
-        Ensures proper shutdown of the thread pool executor if one exists.
-        """
-        if self.thread_pool:
-            self.thread_pool.shutdown(wait=False)
-            logger.debug("thread pool executor shutdown complete")
-
     def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
         """Process a natural language prompt through the agent's event loop.
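
The new direct tool-call path drives the async run_tool generator from synchronous code by draining it with asyncio.run on a worker thread. Below is a minimal, self-contained sketch of that pattern; fake_tool_stream and call_tool_sync are illustrative stand-ins, not part of the library:

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncGenerator


async def fake_tool_stream() -> AsyncGenerator[dict[str, Any], None]:
    """Stand-in for run_tool: yield intermediate events, then the result as the last event."""
    yield {"event": "progress"}
    yield {"toolUseId": "t1", "status": "success", "content": [{"text": "done"}]}


def call_tool_sync() -> dict[str, Any]:
    """Drain the async generator and keep its last event as the tool result."""

    async def acall() -> dict[str, Any]:
        async for event in fake_tool_stream():
            last = event
        return last

    # asyncio.run refuses to start inside a running event loop, so run it on a
    # short-lived worker thread, mirroring the commit's executor.submit approach.
    with ThreadPoolExecutor() as executor:
        return executor.submit(lambda: asyncio.run(acall())).result()


print(call_tool_sync()["status"])  # -> success

Submitting asyncio.run to a ThreadPoolExecutor sidesteps the "asyncio.run() cannot be called from a running event loop" error when the caller is itself already inside an event loop.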

src/strands/event_loop/event_loop.py

Lines changed: 12 additions & 11 deletions
@@ -256,7 +256,7 @@ async def recurse_event_loop(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGen
     recursive_trace.end()


-def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolGenerator:
+async def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolGenerator:
     """Process a tool invocation.

     Looks up the tool in the registry and streams it with the provided parameters.
@@ -267,10 +267,7 @@ def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolG
         kwargs: Additional keyword arguments passed to the tool.

     Yields:
-        Events of the tool stream.
-
-    Returns:
-        The final tool result or an error response if the tool fails or is not found.
+        Tool events with the last being the tool result.
     """
     logger.debug("tool_use=<%s> | streaming", tool_use)
     tool_name = tool_use["name"]
@@ -332,9 +329,14 @@ def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolG
                    result=result,
                )
            )
-            return after_event.result
+            yield after_event.result
+            return
+
+        async for event in selected_tool.stream(tool_use, kwargs):
+            yield event
+
+        result = event

-        result = yield from selected_tool.stream(tool_use, **kwargs)
         after_event = get_registry(agent).invoke_callbacks(
             AfterToolInvocationEvent(
                 agent=agent,
@@ -344,7 +346,7 @@ def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolG
                result=result,
            )
        )
-        return after_event.result
+        yield after_event.result

     except Exception as e:
         logger.exception("tool_name=<%s> | failed to process tool", tool_name)
@@ -363,7 +365,7 @@ def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolG
                exception=e,
            )
        )
-        return after_event.result
+        yield after_event.result


 async def _handle_tool_execution(
@@ -417,9 +419,8 @@ def tool_handler(tool_use: ToolUse) -> ToolGenerator:
        tool_results=tool_results,
        cycle_trace=cycle_trace,
        parent_span=cycle_span,
-        thread_pool=agent.thread_pool,
    )
-    for tool_event in tool_events:
+    async for tool_event in tool_events:
        yield tool_event

    # Store parent cycle ID for the next cycle
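
Async generators cannot return a value the way the previous yield from-based generator could, so run_tool now yields the ToolResult as its final event and callers keep the last item of the stream. A reduced sketch of that contract, with run_tool_like as a hypothetical stand-in for the real function:

import asyncio
from typing import Any, AsyncGenerator


async def run_tool_like(ok: bool) -> AsyncGenerator[dict[str, Any], None]:
    """Stand-in for the new run_tool shape: stream events, then yield the result last."""
    yield {"event": "started"}
    if not ok:
        # async generators cannot `return value`, so error results are yielded too
        yield {"status": "error", "content": [{"text": "Error: tool failed"}]}
        return
    yield {"status": "success", "content": [{"text": "42"}]}


async def main() -> None:
    async for event in run_tool_like(ok=True):
        pass  # forward events upstream; the loop variable ends up holding the result
    print(event["status"])  # -> success


asyncio.run(main())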

src/strands/tools/decorator.py

Lines changed: 14 additions & 16 deletions
@@ -40,6 +40,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
    ```
 """

+import asyncio
 import functools
 import inspect
 import logging
@@ -52,7 +53,6 @@ def my_tool(param1: str, param2: int = 42) -> dict:
     Type,
     TypeVar,
     Union,
-    cast,
     get_type_hints,
     overload,
 )
@@ -61,7 +61,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
 from pydantic import BaseModel, Field, create_model
 from typing_extensions import override

-from ..types.tools import AgentTool, JSONSchema, ToolGenerator, ToolResult, ToolSpec, ToolUse
+from ..types.tools import AgentTool, JSONSchema, ToolGenerator, ToolSpec, ToolUse

 logger = logging.getLogger(__name__)

@@ -372,7 +372,7 @@ def tool_type(self) -> str:
         return "function"

     @override
-    def stream(self, tool_use: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolGenerator:
+    async def stream(self, tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolGenerator:
         """Stream the tool with a tool use specification.

         This method handles tool use streams from a Strands Agent. It validates the input,
@@ -388,14 +388,10 @@ def stream(self, tool_use: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> Too

         Args:
             tool_use: The tool use specification from the Agent.
-            *args: Additional positional arguments (not typically used).
-            **kwargs: Additional keyword arguments, may include 'agent' reference.
+            kwargs: Additional keyword arguments, may include 'agent' reference.

         Yields:
-            Events of the tool stream.
-
-        Returns:
-            A standardized tool result dictionary with status and content.
+            Tool events with the last being the tool result.
         """
         # This is a tool use call - process accordingly
         tool_use_id = tool_use.get("toolUseId", "unknown")
@@ -409,19 +405,21 @@ def stream(self, tool_use: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> Too
            if "agent" in kwargs and "agent" in self._metadata.signature.parameters:
                validated_input["agent"] = kwargs.get("agent")

-            result = self._tool_func(**validated_input)  # type: ignore  # "Too few arguments" expected
-            if inspect.isgenerator(result):
-                result = yield from result
+            # "Too few arguments" expected, hence the type ignore
+            if inspect.iscoroutinefunction(self._tool_func):
+                result = await self._tool_func(**validated_input)  # type: ignore
+            else:
+                result = await asyncio.to_thread(self._tool_func, **validated_input)  # type: ignore

            # FORMAT THE RESULT for Strands Agent
            if isinstance(result, dict) and "status" in result and "content" in result:
                # Result is already in the expected format, just add toolUseId
                result["toolUseId"] = tool_use_id
-                return cast(ToolResult, result)
+                yield result
            else:
                # Wrap any other return value in the standard format
                # Always include at least one content item for consistency
-                return {
+                yield {
                    "toolUseId": tool_use_id,
                    "status": "success",
                    "content": [{"text": str(result)}],
@@ -430,7 +428,7 @@ def stream(self, tool_use: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> Too
        except ValueError as e:
            # Special handling for validation errors
            error_msg = str(e)
-            return {
+            yield {
                "toolUseId": tool_use_id,
                "status": "error",
                "content": [{"text": f"Error: {error_msg}"}],
@@ -439,7 +437,7 @@ def stream(self, tool_use: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> Too
            # Return error result with exception details for any other error
            error_type = type(e).__name__
            error_msg = str(e)
-            return {
+            yield {
                "toolUseId": tool_use_id,
                "status": "error",
                "content": [{"text": f"Error: {error_type} - {error_msg}"}],

src/strands/tools/executor.py

Lines changed: 40 additions & 56 deletions
@@ -1,11 +1,9 @@
 """Tool execution functionality for the event loop."""

+import asyncio
 import logging
-import queue
-import threading
 import time
-from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Generator, Optional, cast
+from typing import Any, Optional, cast

 from opentelemetry import trace

@@ -18,17 +16,16 @@
 logger = logging.getLogger(__name__)


-def run_tools(
+async def run_tools(
     handler: RunToolHandler,
     tool_uses: list[ToolUse],
     event_loop_metrics: EventLoopMetrics,
     invalid_tool_use_ids: list[str],
     tool_results: list[ToolResult],
     cycle_trace: Trace,
     parent_span: Optional[trace.Span] = None,
-    thread_pool: Optional[ThreadPoolExecutor] = None,
-) -> Generator[dict[str, Any], None, None]:
-    """Execute tools either in parallel or sequentially.
+) -> ToolGenerator:
+    """Execute tools concurrently.

     Args:
         handler: Tool handler processing function.
@@ -38,21 +35,33 @@ def run_tools(
         tool_results: List to populate with tool results.
         cycle_trace: Parent trace for the current cycle.
         parent_span: Parent span for the current cycle.
-        thread_pool: Optional thread pool for parallel processing.

     Yields:
         Events of the tool stream. Tool results are appended to `tool_results`.
     """

-    def handle(tool_use: ToolUse) -> ToolGenerator:
+    async def work(
+        tool_use: ToolUse,
+        worker_id: int,
+        worker_queue: asyncio.Queue,
+        worker_event: asyncio.Event,
+        stop_event: object,
+    ) -> ToolResult:
         tracer = get_tracer()
         tool_call_span = tracer.start_tool_call_span(tool_use, parent_span)

         tool_name = tool_use["name"]
         tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name)
         tool_start_time = time.time()

-        result = yield from handler(tool_use)
+        try:
+            async for event in handler(tool_use):
+                worker_queue.put_nowait((worker_id, event))
+                await worker_event.wait()
+
+            result = cast(ToolResult, event)
+        finally:
+            worker_queue.put_nowait((worker_id, stop_event))

         tool_success = result.get("status") == "success"
         tool_duration = time.time() - tool_start_time
@@ -65,52 +74,27 @@ def handle(tool_use: ToolUse) -> ToolGenerator:

         return result

-    def work(
-        tool_use: ToolUse,
-        worker_id: int,
-        worker_queue: queue.Queue,
-        worker_event: threading.Event,
-    ) -> ToolResult:
-        events = handle(tool_use)
-
-        try:
-            while True:
-                event = next(events)
-                worker_queue.put((worker_id, event))
-                worker_event.wait()
-
-        except StopIteration as stop:
-            return cast(ToolResult, stop.value)
-
     tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]
-
-    if thread_pool:
-        logger.debug("tool_count=<%s> | executing tools in parallel", len(tool_uses))
-
-        worker_queue: queue.Queue[tuple[int, dict[str, Any]]] = queue.Queue()
-        worker_events = [threading.Event() for _ in range(len(tool_uses))]
-
-        workers = [
-            thread_pool.submit(work, tool_use, worker_id, worker_queue, worker_events[worker_id])
-            for worker_id, tool_use in enumerate(tool_uses)
-        ]
-        logger.debug("tool_count=<%s> | submitted tasks to parallel executor", len(tool_uses))
-
-        while not all(worker.done() for worker in workers):
-            if not worker_queue.empty():
-                worker_id, event = worker_queue.get()
-                yield event
-                worker_events[worker_id].set()
-
-            time.sleep(0.001)
-
-        tool_results.extend([worker.result() for worker in workers])
-
-    else:
-        # Sequential execution fallback
-        for tool_use in tool_uses:
-            result = yield from handle(tool_use)
-            tool_results.append(result)
+    worker_queue: asyncio.Queue[tuple[int, Any]] = asyncio.Queue()
+    worker_events = [asyncio.Event() for _ in tool_uses]
+    stop_event = object()
+
+    workers = [
+        asyncio.create_task(work(tool_use, worker_id, worker_queue, worker_events[worker_id], stop_event))
+        for worker_id, tool_use in enumerate(tool_uses)
+    ]
+
+    worker_count = len(workers)
+    while worker_count:
+        worker_id, event = await worker_queue.get()
+        if event is stop_event:
+            worker_count -= 1
+            continue
+
+        yield event
+        worker_events[worker_id].set()
+
+    tool_results.extend([worker.result() for worker in workers])


 def validate_and_prepare_tools(
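
The rewritten run_tools fans in several tool streams: each worker task tags its events with a worker id, pushes them onto a shared asyncio.Queue, then waits on its own asyncio.Event until the consumer has relayed the event, and a shared sentinel object marks a finished worker. A reduced, hypothetical sketch of that pattern (here the per-worker gate is also cleared after each step so pacing applies to every event, not just the first):

import asyncio
from typing import Any


async def worker(name: str, queue: asyncio.Queue, gate: asyncio.Event, stop: object) -> str:
    """Push events tagged with this worker's name, pausing until the consumer relays each one."""
    try:
        for i in range(2):
            queue.put_nowait((name, f"{name}-event-{i}"))
            await gate.wait()   # resume only after the consumer has forwarded the event
            gate.clear()
    finally:
        queue.put_nowait((name, stop))  # sentinel: this worker is done
    return f"{name}-result"


async def main() -> None:
    queue: asyncio.Queue[tuple[str, Any]] = asyncio.Queue()
    gates = {name: asyncio.Event() for name in ("alpha", "beta")}
    stop = object()

    tasks = [asyncio.create_task(worker(name, queue, gates[name], stop)) for name in gates]

    remaining = len(tasks)
    while remaining:
        name, event = await queue.get()
        if event is stop:
            remaining -= 1
            continue
        print(event)        # relay the streamed event to the outer consumer
        gates[name].set()   # release that worker to produce its next event

    print([task.result() for task in tasks])  # final results, in task order


asyncio.run(main())

Printing stands in for yielding the relayed event; in the library the consumer loop is itself an async generator that yields each event upstream.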
