Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/strands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@
from . import agent, event_loop, models, telemetry, types
from .agent.agent import Agent
from .tools.decorator import tool
from .tools.thread_pool_executor import ThreadPoolExecutorWrapper

__all__ = ["Agent", "ThreadPoolExecutorWrapper", "agent", "event_loop", "models", "tool", "types", "telemetry"]
__all__ = ["Agent", "agent", "event_loop", "models", "tool", "types", "telemetry"]
8 changes: 3 additions & 5 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from ..telemetry.metrics import EventLoopMetrics
from ..telemetry.tracer import get_tracer
from ..tools.registry import ToolRegistry
from ..tools.thread_pool_executor import ThreadPoolExecutorWrapper
from ..tools.watcher import ToolWatcher
from ..types.content import ContentBlock, Message, Messages
from ..types.exceptions import ContextWindowOverflowException
Expand Down Expand Up @@ -275,7 +274,6 @@ def __init__(
self.thread_pool_wrapper = None
if max_parallel_tools > 1:
self.thread_pool = ThreadPoolExecutor(max_workers=max_parallel_tools)
self.thread_pool_wrapper = ThreadPoolExecutorWrapper(self.thread_pool)
elif max_parallel_tools < 1:
raise ValueError("max_parallel_tools must be greater than 0")

Expand Down Expand Up @@ -358,8 +356,8 @@ def __del__(self) -> None:

Ensures proper shutdown of the thread pool executor if one exists.
"""
if self.thread_pool_wrapper and hasattr(self.thread_pool_wrapper, "shutdown"):
self.thread_pool_wrapper.shutdown(wait=False)
if self.thread_pool:
self.thread_pool.shutdown(wait=False)
logger.debug("thread pool executor shutdown complete")

def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
Expand Down Expand Up @@ -528,7 +526,7 @@ def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> Generator[dict[st
messages=self.messages, # will be modified by event_loop_cycle
tool_config=self.tool_config,
tool_handler=self.tool_handler,
tool_execution_handler=self.thread_pool_wrapper,
thread_pool=self.thread_pool,
event_loop_metrics=self.event_loop_metrics,
event_loop_parent_span=self.trace_span,
kwargs=kwargs,
Expand Down
48 changes: 24 additions & 24 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import logging
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any, Generator, Optional

Expand All @@ -20,7 +21,6 @@
from ..telemetry.tracer import get_tracer
from ..tools.executor import run_tools, validate_and_prepare_tools
from ..types.content import Message, Messages
from ..types.event_loop import ParallelToolExecutorInterface
from ..types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException
from ..types.models import Model
from ..types.streaming import Metrics, StopReason
Expand All @@ -41,7 +41,7 @@ def event_loop_cycle(
messages: Messages,
tool_config: Optional[ToolConfig],
tool_handler: Optional[ToolHandler],
tool_execution_handler: Optional[ParallelToolExecutorInterface],
thread_pool: Optional[ThreadPoolExecutor],
event_loop_metrics: EventLoopMetrics,
event_loop_parent_span: Optional[trace.Span],
kwargs: dict[str, Any],
Expand All @@ -65,7 +65,7 @@ def event_loop_cycle(
messages: Conversation history messages.
tool_config: Configuration for available tools.
tool_handler: Handler for executing tools.
tool_execution_handler: Optional handler for parallel tool execution.
thread_pool: Optional thread pool for parallel tool execution.
event_loop_metrics: Metrics tracking object for the event loop.
event_loop_parent_span: Span for the parent of this event loop.
kwargs: Additional arguments including:
Expand Down Expand Up @@ -210,7 +210,7 @@ def event_loop_cycle(
messages,
tool_config,
tool_handler,
tool_execution_handler,
thread_pool,
event_loop_metrics,
event_loop_parent_span,
cycle_trace,
Expand Down Expand Up @@ -256,7 +256,7 @@ def recurse_event_loop(
messages: Messages,
tool_config: Optional[ToolConfig],
tool_handler: Optional[ToolHandler],
tool_execution_handler: Optional[ParallelToolExecutorInterface],
thread_pool: Optional[ThreadPoolExecutor],
event_loop_metrics: EventLoopMetrics,
event_loop_parent_span: Optional[trace.Span],
kwargs: dict[str, Any],
Expand All @@ -271,7 +271,7 @@ def recurse_event_loop(
messages: Conversation history messages
tool_config: Configuration for available tools
tool_handler: Handler for tool execution
tool_execution_handler: Optional handler for parallel tool execution.
thread_pool: Optional thread pool for parallel tool execution.
event_loop_metrics: Metrics tracking object for the event loop.
event_loop_parent_span: Span for the parent of this event loop.
kwargs: Arguments to pass through event_loop_cycle
Expand All @@ -298,7 +298,7 @@ def recurse_event_loop(
messages=messages,
tool_config=tool_config,
tool_handler=tool_handler,
tool_execution_handler=tool_execution_handler,
thread_pool=thread_pool,
event_loop_metrics=event_loop_metrics,
event_loop_parent_span=event_loop_parent_span,
kwargs=kwargs,
Expand All @@ -315,7 +315,7 @@ def _handle_tool_execution(
messages: Messages,
tool_config: ToolConfig,
tool_handler: ToolHandler,
tool_execution_handler: Optional[ParallelToolExecutorInterface],
thread_pool: Optional[ThreadPoolExecutor],
event_loop_metrics: EventLoopMetrics,
event_loop_parent_span: Optional[trace.Span],
cycle_trace: Trace,
Expand All @@ -331,20 +331,20 @@ def _handle_tool_execution(
Handles the execution of tools requested by the model during an event loop cycle.

Args:
stop_reason (StopReason): The reason the model stopped generating.
message (Message): The message from the model that may contain tool use requests.
model (Model): The model provider instance.
system_prompt (Optional[str]): The system prompt instructions for the model.
messages (Messages): The conversation history messages.
tool_config (ToolConfig): Configuration for available tools.
tool_handler (ToolHandler): Handler for tool execution.
tool_execution_handler (Optional[ParallelToolExecutorInterface]): Optional handler for parallel tool execution.
event_loop_metrics (EventLoopMetrics): Metrics tracking object for the event loop.
event_loop_parent_span (Any): Span for the parent of this event loop.
cycle_trace (Trace): Trace object for the current event loop cycle.
cycle_span (Any): Span object for tracing the cycle (type may vary).
cycle_start_time (float): Start time of the current cycle.
kwargs (dict[str, Any]): Additional keyword arguments, including request state.
stop_reason: The reason the model stopped generating.
message: The message from the model that may contain tool use requests.
model: The model provider instance.
system_prompt: The system prompt instructions for the model.
messages: The conversation history messages.
tool_config: Configuration for available tools.
tool_handler: Handler for tool execution.
thread_pool: Optional thread pool for parallel tool execution.
event_loop_metrics: Metrics tracking object for the event loop.
event_loop_parent_span: Span for the parent of this event loop.
cycle_trace: Trace object for the current event loop cycle.
cycle_span: Span object for tracing the cycle (type may vary).
cycle_start_time: Start time of the current cycle.
kwargs: Additional keyword arguments, including request state.

Yields:
Tool invocation events along with events yielded from a recursive call to the event loop. The last event is a
Expand Down Expand Up @@ -377,7 +377,7 @@ def _handle_tool_execution(
tool_results=tool_results,
cycle_trace=cycle_trace,
parent_span=cycle_span,
parallel_tool_executor=tool_execution_handler,
thread_pool=thread_pool,
)

# Store parent cycle ID for the next cycle
Expand Down Expand Up @@ -406,7 +406,7 @@ def _handle_tool_execution(
messages=messages,
tool_config=tool_config,
tool_handler=tool_handler,
tool_execution_handler=tool_execution_handler,
thread_pool=thread_pool,
event_loop_metrics=event_loop_metrics,
event_loop_parent_span=event_loop_parent_span,
kwargs=kwargs,
Expand Down
2 changes: 0 additions & 2 deletions src/strands/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from .decorator import tool
from .structured_output import convert_pydantic_to_tool_spec
from .thread_pool_executor import ThreadPoolExecutorWrapper
from .tools import InvalidToolUseNameException, PythonAgentTool, normalize_schema, normalize_tool_spec

__all__ = [
Expand All @@ -14,6 +13,5 @@
"InvalidToolUseNameException",
"normalize_schema",
"normalize_tool_spec",
"ThreadPoolExecutorWrapper",
"convert_pydantic_to_tool_spec",
]
16 changes: 6 additions & 10 deletions src/strands/tools/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import queue
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Generator, Optional, cast

from opentelemetry import trace
Expand All @@ -12,7 +13,6 @@
from ..telemetry.tracer import get_tracer
from ..tools.tools import InvalidToolUseNameException, validate_tool_use
from ..types.content import Message
from ..types.event_loop import ParallelToolExecutorInterface
from ..types.tools import ToolGenerator, ToolResult, ToolUse

logger = logging.getLogger(__name__)
Expand All @@ -26,7 +26,7 @@ def run_tools(
tool_results: list[ToolResult],
cycle_trace: Trace,
parent_span: Optional[trace.Span] = None,
parallel_tool_executor: Optional[ParallelToolExecutorInterface] = None,
thread_pool: Optional[ThreadPoolExecutor] = None,
) -> Generator[dict[str, Any], None, None]:
"""Execute tools either in parallel or sequentially.

Expand All @@ -38,7 +38,7 @@ def run_tools(
tool_results: List to populate with tool results.
cycle_trace: Parent trace for the current cycle.
parent_span: Parent span for the current cycle.
parallel_tool_executor: Optional executor for parallel processing.
thread_pool: Optional thread pool for parallel processing.

Yields:
Events of the tool invocations. Tool results are appended to `tool_results`.
Expand Down Expand Up @@ -84,18 +84,14 @@ def work(

tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]

if parallel_tool_executor:
logger.debug(
"tool_count=<%s>, tool_executor=<%s> | executing tools in parallel",
len(tool_uses),
type(parallel_tool_executor).__name__,
)
if thread_pool:
logger.debug("tool_count=<%s> | executing tools in parallel", len(tool_uses))

worker_queue: queue.Queue[tuple[int, dict[str, Any]]] = queue.Queue()
worker_events = [threading.Event() for _ in range(len(tool_uses))]

workers = [
parallel_tool_executor.submit(work, tool_use, worker_id, worker_queue, worker_events[worker_id])
thread_pool.submit(work, tool_use, worker_id, worker_queue, worker_events[worker_id])
for worker_id, tool_use in enumerate(tool_uses)
]
logger.debug("tool_count=<%s> | submitted tasks to parallel executor", len(tool_uses))
Expand Down
69 changes: 0 additions & 69 deletions src/strands/tools/thread_pool_executor.py

This file was deleted.

70 changes: 2 additions & 68 deletions src/strands/types/event_loop.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Event loop-related type definitions for the SDK."""

from typing import Any, Callable, Iterable, Iterator, Literal, Optional, Protocol
from typing import Literal

from typing_extensions import TypedDict, runtime_checkable
from typing_extensions import TypedDict


class Usage(TypedDict):
Expand Down Expand Up @@ -46,69 +46,3 @@ class Metrics(TypedDict):
- "stop_sequence": Stop sequence encountered
- "tool_use": Model requested to use a tool
"""


@runtime_checkable
class Future(Protocol):
"""Interface representing the result of an asynchronous computation."""

def result(self, timeout: Optional[int] = None) -> Any:
"""Return the result of the call that the future represents.
This method will block until the asynchronous operation completes or until the specified timeout is reached.
Args:
timeout: The number of seconds to wait for the result.
If None, then there is no limit on the wait time.
Returns:
Any: The result of the asynchronous operation.
"""

def done(self) -> bool:
"""Returns true if future is done executing."""


@runtime_checkable
class ParallelToolExecutorInterface(Protocol):
"""Interface for parallel tool execution.
Attributes:
timeout: Default timeout in seconds for futures.
"""

timeout: int = 900 # default 15 minute timeout for futures

def submit(self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> Future:
"""Submit a callable to be executed with the given arguments.
Schedules the callable to be executed as fn(*args, **kwargs) and returns a Future instance representing the
execution of the callable.
Args:
fn: The callable to execute.
*args: Positional arguments to pass to the callable.
**kwargs: Keyword arguments to pass to the callable.
Returns:
Future: A Future representing the given call.
"""

def as_completed(self, futures: Iterable[Future], timeout: Optional[int] = timeout) -> Iterator[Future]:
"""Iterate over the given futures, yielding each as it completes.
Args:
futures: The sequence of Futures to iterate over.
timeout: The maximum number of seconds to wait.
If None, then there is no limit on the wait time.
Returns:
An iterator that yields the given Futures as they complete (finished or cancelled).
"""

def shutdown(self, wait: bool = True) -> None:
"""Shutdown the executor and free associated resources.
Args:
wait: If True, shutdown will not return until all running futures have finished executing.
"""
Loading