42 changes: 11 additions & 31 deletions src/strands/agent/agent.py
@@ -12,7 +12,6 @@
import asyncio
import json
import logging
import os
import random
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast
@@ -128,14 +127,18 @@ def caller(
"input": kwargs.copy(),
}

# Execute the tool
events = run_tool(self._agent, tool_use, kwargs)
async def acall() -> ToolResult:
async for event in run_tool(self._agent, tool_use, kwargs):
_ = event

try:
while True:
next(events)
except StopIteration as stop:
tool_result = cast(ToolResult, stop.value)
return cast(ToolResult, event)

def tcall() -> ToolResult:
return asyncio.run(acall())

with ThreadPoolExecutor() as executor:
future = executor.submit(tcall)
tool_result = future.result()

if record_direct_tool_call is not None:
should_record_direct_tool_call = record_direct_tool_call
@@ -186,7 +189,6 @@ def __init__(
Union[Callable[..., Any], _DefaultCallbackHandlerSentinel]
] = _DEFAULT_CALLBACK_HANDLER,
conversation_manager: Optional[ConversationManager] = None,
max_parallel_tools: int = os.cpu_count() or 1,
record_direct_tool_call: bool = True,
load_tools_from_directory: bool = True,
trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
@@ -219,8 +221,6 @@ def __init__(
If explicitly set to None, null_callback_handler is used.
conversation_manager: Manager for conversation history and context window.
Defaults to strands.agent.conversation_manager.SlidingWindowConversationManager if None.
max_parallel_tools: Maximum number of tools to run in parallel when the model returns multiple tool calls.
Defaults to os.cpu_count() or 1.
record_direct_tool_call: Whether to record direct tool calls in message history.
Defaults to True.
load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory.
@@ -232,9 +232,6 @@
Defaults to None.
state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict.
Defaults to an empty AgentState object.

Raises:
ValueError: If max_parallel_tools is less than 1.
"""
self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model
self.messages = messages if messages is not None else []
@@ -263,14 +260,6 @@ def __init__(
):
self.trace_attributes[k] = v

# If max_parallel_tools is 1, we execute tools sequentially
self.thread_pool = None
self.thread_pool_wrapper = None
if max_parallel_tools > 1:
self.thread_pool = ThreadPoolExecutor(max_workers=max_parallel_tools)
elif max_parallel_tools < 1:
raise ValueError("max_parallel_tools must be greater than 0")

self.record_direct_tool_call = record_direct_tool_call
self.load_tools_from_directory = load_tools_from_directory

@@ -335,15 +324,6 @@ def tool_names(self) -> list[str]:
all_tools = self.tool_registry.get_all_tools_config()
return list(all_tools.keys())

def __del__(self) -> None:
"""Clean up resources when Agent is garbage collected.

Ensures proper shutdown of the thread pool executor if one exists.
"""
if self.thread_pool:
self.thread_pool.shutdown(wait=False)
logger.debug("thread pool executor shutdown complete")

def __call__(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AgentResult:
"""Process a natural language prompt through the agent's event loop.

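The agent.py change above removes the `max_parallel_tools` thread pool and has the synchronous direct tool-call path drain the now-async `run_tool` stream on a dedicated thread. A minimal, self-contained sketch of that sync-to-async bridge, using illustrative names rather than the SDK's actual classes:

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncGenerator


async def tool_stream() -> AsyncGenerator[dict[str, Any], None]:
    # Stand-in for run_tool(): yields events, the last one being the result.
    yield {"event": "progress"}
    yield {"toolUseId": "t1", "status": "success", "content": [{"text": "done"}]}


def call_tool_sync() -> dict[str, Any]:
    async def acall() -> dict[str, Any]:
        event: dict[str, Any] = {}
        async for event in tool_stream():
            pass  # discard intermediate events; the last one is the result
        return event

    # asyncio.run() on a separate worker thread avoids clashing with an event
    # loop that may already be running on the calling thread.
    with ThreadPoolExecutor() as executor:
        return executor.submit(lambda: asyncio.run(acall())).result()


print(call_tool_sync())

Running `asyncio.run` on a fresh thread matters because `asyncio.run` raises a RuntimeError if it is invoked from a thread that already has a running event loop.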
23 changes: 12 additions & 11 deletions src/strands/event_loop/event_loop.py
@@ -252,7 +252,7 @@ async def recurse_event_loop(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGen
recursive_trace.end()


def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolGenerator:
async def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolGenerator:
"""Process a tool invocation.

Looks up the tool in the registry and streams it with the provided parameters.
@@ -263,10 +263,7 @@ def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolG
kwargs: Additional keyword arguments passed to the tool.

Yields:
Events of the tool stream.

Returns:
The final tool result or an error response if the tool fails or is not found.
Tool events with the last being the tool result.
"""
logger.debug("tool_use=<%s> | streaming", tool_use)
tool_name = tool_use["name"]
@@ -331,9 +328,14 @@ def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolG
result=result,
)
)
return after_event.result
yield after_event.result
return

async for event in selected_tool.stream(tool_use, kwargs):
yield event

result = event

result = yield from selected_tool.stream(tool_use, **kwargs)
after_event = get_registry(agent).invoke_callbacks(
AfterToolInvocationEvent(
agent=agent,
@@ -343,7 +345,7 @@ def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolG
result=result,
)
)
return after_event.result
yield after_event.result

except Exception as e:
logger.exception("tool_name=<%s> | failed to process tool", tool_name)
@@ -362,7 +364,7 @@ def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolG
exception=e,
)
)
return after_event.result
yield after_event.result


async def _handle_tool_execution(
Expand Down Expand Up @@ -416,9 +418,8 @@ def tool_handler(tool_use: ToolUse) -> ToolGenerator:
tool_results=tool_results,
cycle_trace=cycle_trace,
parent_span=cycle_span,
thread_pool=agent.thread_pool,
)
for tool_event in tool_events:
async for tool_event in tool_events:
yield tool_event

# Store parent cycle ID for the next cycle
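With the event_loop.py change, `run_tool` is an async generator, so its result is delivered as the final yielded event rather than through a generator's `StopIteration.value`. A hedged sketch of how a caller consumes that convention (the event shapes here are simplified, not the SDK's exact types):

import asyncio
from typing import Any, AsyncGenerator


async def run_tool_sketch(tool_use: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
    # Intermediate events stream out first...
    yield {"progress": f"running {tool_use['name']}"}
    # ...and the final yielded value is the tool result itself.
    yield {"toolUseId": tool_use["toolUseId"], "status": "success", "content": [{"text": "42"}]}


async def main() -> None:
    result: dict[str, Any] = {}
    async for event in run_tool_sketch({"toolUseId": "t1", "name": "calculator"}):
        result = event  # after the loop, result holds the last event, i.e. the tool result
    print(result)


asyncio.run(main())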
30 changes: 14 additions & 16 deletions src/strands/tools/decorator.py
@@ -40,6 +40,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
```
"""

import asyncio
import functools
import inspect
import logging
@@ -52,7 +53,6 @@ def my_tool(param1: str, param2: int = 42) -> dict:
Type,
TypeVar,
Union,
cast,
get_type_hints,
overload,
)
@@ -61,7 +61,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
from pydantic import BaseModel, Field, create_model
from typing_extensions import override

from ..types.tools import AgentTool, JSONSchema, ToolGenerator, ToolResult, ToolSpec, ToolUse
from ..types.tools import AgentTool, JSONSchema, ToolGenerator, ToolSpec, ToolUse

logger = logging.getLogger(__name__)

@@ -372,7 +372,7 @@ def tool_type(self) -> str:
return "function"

@override
def stream(self, tool_use: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolGenerator:
async def stream(self, tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolGenerator:
"""Stream the tool with a tool use specification.

This method handles tool use streams from a Strands Agent. It validates the input,
@@ -388,14 +388,10 @@ def stream(self, tool_use: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> Too

Args:
tool_use: The tool use specification from the Agent.
*args: Additional positional arguments (not typically used).
**kwargs: Additional keyword arguments, may include 'agent' reference.
kwargs: Additional keyword arguments, may include 'agent' reference.

Yields:
Events of the tool stream.

Returns:
A standardized tool result dictionary with status and content.
Tool events with the last being the tool result.
"""
# This is a tool use call - process accordingly
tool_use_id = tool_use.get("toolUseId", "unknown")
@@ -409,19 +405,21 @@ def stream(self, tool_use: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> Too
if "agent" in kwargs and "agent" in self._metadata.signature.parameters:
validated_input["agent"] = kwargs.get("agent")

result = self._tool_func(**validated_input) # type: ignore # "Too few arguments" expected
if inspect.isgenerator(result):
result = yield from result
# "Too few arguments" expected, hence the type ignore
if inspect.iscoroutinefunction(self._tool_func):
result = await self._tool_func(**validated_input) # type: ignore
else:
result = await asyncio.to_thread(self._tool_func, **validated_input) # type: ignore

# FORMAT THE RESULT for Strands Agent
if isinstance(result, dict) and "status" in result and "content" in result:
# Result is already in the expected format, just add toolUseId
result["toolUseId"] = tool_use_id
return cast(ToolResult, result)
yield result
else:
# Wrap any other return value in the standard format
# Always include at least one content item for consistency
return {
yield {
"toolUseId": tool_use_id,
"status": "success",
"content": [{"text": str(result)}],
@@ -430,7 +428,7 @@ def stream(self, tool_use: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> Too
except ValueError as e:
# Special handling for validation errors
error_msg = str(e)
return {
yield {
"toolUseId": tool_use_id,
"status": "error",
"content": [{"text": f"Error: {error_msg}"}],
@@ -439,7 +437,7 @@ def stream(self, tool_use: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> Too
# Return error result with exception details for any other error
error_type = type(e).__name__
error_msg = str(e)
return {
yield {
"toolUseId": tool_use_id,
"status": "error",
"content": [{"text": f"Error: {error_type} - {error_msg}"}],
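In decorator.py, `DecoratedFunctionTool.stream` becomes async and dispatches on the wrapped function's type: coroutine functions are awaited directly, while plain functions are pushed onto a worker thread with `asyncio.to_thread` so a blocking tool body no longer stalls the event loop. A rough, self-contained sketch of that dispatch (the function names below are made up for illustration):

import asyncio
import inspect
from typing import Any, Callable


async def invoke_tool_func(func: Callable[..., Any], **kwargs: Any) -> Any:
    # Coroutine functions are awaited in place; plain functions run on a
    # worker thread so they do not block the running event loop.
    if inspect.iscoroutinefunction(func):
        return await func(**kwargs)
    return await asyncio.to_thread(func, **kwargs)


def slow_add(a: int, b: int) -> int:  # hypothetical sync tool body
    return a + b


async def fast_mul(a: int, b: int) -> int:  # hypothetical async tool body
    return a * b


async def main() -> None:
    print(await invoke_tool_func(slow_add, a=2, b=3))  # 5, via asyncio.to_thread
    print(await invoke_tool_func(fast_mul, a=2, b=3))  # 6, awaited directly


asyncio.run(main())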
96 changes: 40 additions & 56 deletions src/strands/tools/executor.py
@@ -1,11 +1,9 @@
"""Tool execution functionality for the event loop."""

import asyncio
import logging
import queue
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Generator, Optional, cast
from typing import Any, Optional, cast

from opentelemetry import trace

@@ -18,17 +16,16 @@
logger = logging.getLogger(__name__)


def run_tools(
async def run_tools(
handler: RunToolHandler,
tool_uses: list[ToolUse],
event_loop_metrics: EventLoopMetrics,
invalid_tool_use_ids: list[str],
tool_results: list[ToolResult],
cycle_trace: Trace,
parent_span: Optional[trace.Span] = None,
thread_pool: Optional[ThreadPoolExecutor] = None,
) -> Generator[dict[str, Any], None, None]:
"""Execute tools either in parallel or sequentially.
) -> ToolGenerator:
"""Execute tools concurrently.

Args:
handler: Tool handler processing function.
@@ -38,21 +35,33 @@ def run_tools(
tool_results: List to populate with tool results.
cycle_trace: Parent trace for the current cycle.
parent_span: Parent span for the current cycle.
thread_pool: Optional thread pool for parallel processing.

Yields:
Events of the tool stream. Tool results are appended to `tool_results`.
"""

def handle(tool_use: ToolUse) -> ToolGenerator:
async def work(
tool_use: ToolUse,
worker_id: int,
worker_queue: asyncio.Queue,
worker_event: asyncio.Event,
stop_event: object,
) -> ToolResult:
tracer = get_tracer()
tool_call_span = tracer.start_tool_call_span(tool_use, parent_span)

tool_name = tool_use["name"]
tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name)
tool_start_time = time.time()

result = yield from handler(tool_use)
try:
async for event in handler(tool_use):
worker_queue.put_nowait((worker_id, event))
await worker_event.wait()

result = cast(ToolResult, event)
finally:
worker_queue.put_nowait((worker_id, stop_event))

tool_success = result.get("status") == "success"
tool_duration = time.time() - tool_start_time
@@ -65,52 +74,27 @@ def handle(tool_use: ToolUse) -> ToolGenerator:

return result

def work(
tool_use: ToolUse,
worker_id: int,
worker_queue: queue.Queue,
worker_event: threading.Event,
) -> ToolResult:
events = handle(tool_use)

try:
while True:
event = next(events)
worker_queue.put((worker_id, event))
worker_event.wait()

except StopIteration as stop:
return cast(ToolResult, stop.value)

tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]

if thread_pool:
logger.debug("tool_count=<%s> | executing tools in parallel", len(tool_uses))

worker_queue: queue.Queue[tuple[int, dict[str, Any]]] = queue.Queue()
worker_events = [threading.Event() for _ in range(len(tool_uses))]

workers = [
thread_pool.submit(work, tool_use, worker_id, worker_queue, worker_events[worker_id])
for worker_id, tool_use in enumerate(tool_uses)
]
logger.debug("tool_count=<%s> | submitted tasks to parallel executor", len(tool_uses))

while not all(worker.done() for worker in workers):
if not worker_queue.empty():
worker_id, event = worker_queue.get()
yield event
worker_events[worker_id].set()

time.sleep(0.001)

tool_results.extend([worker.result() for worker in workers])

else:
# Sequential execution fallback
for tool_use in tool_uses:
result = yield from handle(tool_use)
tool_results.append(result)
worker_queue: asyncio.Queue[tuple[int, Any]] = asyncio.Queue()
worker_events = [asyncio.Event() for _ in tool_uses]
stop_event = object()

workers = [
asyncio.create_task(work(tool_use, worker_id, worker_queue, worker_events[worker_id], stop_event))
for worker_id, tool_use in enumerate(tool_uses)
]

worker_count = len(workers)
while worker_count:
worker_id, event = await worker_queue.get()
if event is stop_event:
worker_count -= 1
continue

yield event
worker_events[worker_id].set()

tool_results.extend([worker.result() for worker in workers])


def validate_and_prepare_tools(
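The executor.py change replaces the optional thread pool with one asyncio task per tool and a shared queue that fans their event streams back into a single generator, using a sentinel object to learn when each worker has finished. A condensed sketch of that fan-out/fan-in pattern (the per-worker `asyncio.Event` back-pressure handshake from the diff is omitted, and `asyncio.gather` is used here to collect results):

import asyncio
from typing import Any, AsyncGenerator


async def fake_tool(name: str) -> AsyncGenerator[dict[str, Any], None]:
    yield {"tool": name, "event": "started"}
    await asyncio.sleep(0)  # simulate work
    yield {"tool": name, "status": "success"}


async def run_all(names: list[str]) -> AsyncGenerator[dict[str, Any], None]:
    queue: asyncio.Queue[Any] = asyncio.Queue()
    stop = object()  # sentinel marking the end of one worker's stream

    async def work(name: str) -> dict[str, Any]:
        last: dict[str, Any] = {}
        try:
            async for event in fake_tool(name):
                queue.put_nowait(event)
                last = event
        finally:
            queue.put_nowait(stop)
        return last  # the final event doubles as the tool result

    workers = [asyncio.create_task(work(name)) for name in names]

    remaining = len(workers)
    while remaining:
        event = await queue.get()
        if event is stop:
            remaining -= 1
            continue
        yield event

    # All workers have signalled the sentinel; gather their results.
    yield {"tool_results": await asyncio.gather(*workers)}


async def main() -> None:
    async for event in run_all(["calculator", "weather"]):
        print(event)


asyncio.run(main())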