diff --git a/src/strands/__init__.py b/src/strands/__init__.py index f4b1228d2..eaedee351 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -3,6 +3,5 @@ from . import agent, event_loop, models, telemetry, types from .agent.agent import Agent from .tools.decorator import tool -from .tools.thread_pool_executor import ThreadPoolExecutorWrapper -__all__ = ["Agent", "ThreadPoolExecutorWrapper", "agent", "event_loop", "models", "tool", "types", "telemetry"] +__all__ = ["Agent", "agent", "event_loop", "models", "tool", "types", "telemetry"] diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index f9eba0015..590b4436c 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -27,7 +27,6 @@ from ..telemetry.metrics import EventLoopMetrics from ..telemetry.tracer import get_tracer from ..tools.registry import ToolRegistry -from ..tools.thread_pool_executor import ThreadPoolExecutorWrapper from ..tools.watcher import ToolWatcher from ..types.content import ContentBlock, Message, Messages from ..types.exceptions import ContextWindowOverflowException @@ -275,7 +274,6 @@ def __init__( self.thread_pool_wrapper = None if max_parallel_tools > 1: self.thread_pool = ThreadPoolExecutor(max_workers=max_parallel_tools) - self.thread_pool_wrapper = ThreadPoolExecutorWrapper(self.thread_pool) elif max_parallel_tools < 1: raise ValueError("max_parallel_tools must be greater than 0") @@ -358,8 +356,8 @@ def __del__(self) -> None: Ensures proper shutdown of the thread pool executor if one exists. """ - if self.thread_pool_wrapper and hasattr(self.thread_pool_wrapper, "shutdown"): - self.thread_pool_wrapper.shutdown(wait=False) + if self.thread_pool: + self.thread_pool.shutdown(wait=False) logger.debug("thread pool executor shutdown complete") def __call__(self, prompt: str, **kwargs: Any) -> AgentResult: @@ -528,7 +526,7 @@ def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> Generator[dict[st messages=self.messages, # will be modified by event_loop_cycle tool_config=self.tool_config, tool_handler=self.tool_handler, - tool_execution_handler=self.thread_pool_wrapper, + thread_pool=self.thread_pool, event_loop_metrics=self.event_loop_metrics, event_loop_parent_span=self.trace_span, kwargs=kwargs, diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 3ca04851b..61eb780c3 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -11,6 +11,7 @@ import logging import time import uuid +from concurrent.futures import ThreadPoolExecutor from functools import partial from typing import Any, Generator, Optional @@ -20,7 +21,6 @@ from ..telemetry.tracer import get_tracer from ..tools.executor import run_tools, validate_and_prepare_tools from ..types.content import Message, Messages -from ..types.event_loop import ParallelToolExecutorInterface from ..types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException from ..types.models import Model from ..types.streaming import Metrics, StopReason @@ -41,7 +41,7 @@ def event_loop_cycle( messages: Messages, tool_config: Optional[ToolConfig], tool_handler: Optional[ToolHandler], - tool_execution_handler: Optional[ParallelToolExecutorInterface], + thread_pool: Optional[ThreadPoolExecutor], event_loop_metrics: EventLoopMetrics, event_loop_parent_span: Optional[trace.Span], kwargs: dict[str, Any], @@ -65,7 +65,7 @@ def event_loop_cycle( messages: Conversation history messages. 
tool_config: Configuration for available tools. tool_handler: Handler for executing tools. - tool_execution_handler: Optional handler for parallel tool execution. + thread_pool: Optional thread pool for parallel tool execution. event_loop_metrics: Metrics tracking object for the event loop. event_loop_parent_span: Span for the parent of this event loop. kwargs: Additional arguments including: @@ -210,7 +210,7 @@ def event_loop_cycle( messages, tool_config, tool_handler, - tool_execution_handler, + thread_pool, event_loop_metrics, event_loop_parent_span, cycle_trace, @@ -256,7 +256,7 @@ def recurse_event_loop( messages: Messages, tool_config: Optional[ToolConfig], tool_handler: Optional[ToolHandler], - tool_execution_handler: Optional[ParallelToolExecutorInterface], + thread_pool: Optional[ThreadPoolExecutor], event_loop_metrics: EventLoopMetrics, event_loop_parent_span: Optional[trace.Span], kwargs: dict[str, Any], @@ -271,7 +271,7 @@ def recurse_event_loop( messages: Conversation history messages tool_config: Configuration for available tools tool_handler: Handler for tool execution - tool_execution_handler: Optional handler for parallel tool execution. + thread_pool: Optional thread pool for parallel tool execution. event_loop_metrics: Metrics tracking object for the event loop. event_loop_parent_span: Span for the parent of this event loop. kwargs: Arguments to pass through event_loop_cycle @@ -298,7 +298,7 @@ def recurse_event_loop( messages=messages, tool_config=tool_config, tool_handler=tool_handler, - tool_execution_handler=tool_execution_handler, + thread_pool=thread_pool, event_loop_metrics=event_loop_metrics, event_loop_parent_span=event_loop_parent_span, kwargs=kwargs, @@ -315,7 +315,7 @@ def _handle_tool_execution( messages: Messages, tool_config: ToolConfig, tool_handler: ToolHandler, - tool_execution_handler: Optional[ParallelToolExecutorInterface], + thread_pool: Optional[ThreadPoolExecutor], event_loop_metrics: EventLoopMetrics, event_loop_parent_span: Optional[trace.Span], cycle_trace: Trace, @@ -331,20 +331,20 @@ def _handle_tool_execution( Handles the execution of tools requested by the model during an event loop cycle. Args: - stop_reason (StopReason): The reason the model stopped generating. - message (Message): The message from the model that may contain tool use requests. - model (Model): The model provider instance. - system_prompt (Optional[str]): The system prompt instructions for the model. - messages (Messages): The conversation history messages. - tool_config (ToolConfig): Configuration for available tools. - tool_handler (ToolHandler): Handler for tool execution. - tool_execution_handler (Optional[ParallelToolExecutorInterface]): Optional handler for parallel tool execution. - event_loop_metrics (EventLoopMetrics): Metrics tracking object for the event loop. - event_loop_parent_span (Any): Span for the parent of this event loop. - cycle_trace (Trace): Trace object for the current event loop cycle. - cycle_span (Any): Span object for tracing the cycle (type may vary). - cycle_start_time (float): Start time of the current cycle. - kwargs (dict[str, Any]): Additional keyword arguments, including request state. + stop_reason: The reason the model stopped generating. + message: The message from the model that may contain tool use requests. + model: The model provider instance. + system_prompt: The system prompt instructions for the model. + messages: The conversation history messages. + tool_config: Configuration for available tools. 
+ tool_handler: Handler for tool execution. + thread_pool: Optional thread pool for parallel tool execution. + event_loop_metrics: Metrics tracking object for the event loop. + event_loop_parent_span: Span for the parent of this event loop. + cycle_trace: Trace object for the current event loop cycle. + cycle_span: Span object for tracing the cycle (type may vary). + cycle_start_time: Start time of the current cycle. + kwargs: Additional keyword arguments, including request state. Yields: Tool invocation events along with events yielded from a recursive call to the event loop. The last event is a @@ -377,7 +377,7 @@ def _handle_tool_execution( tool_results=tool_results, cycle_trace=cycle_trace, parent_span=cycle_span, - parallel_tool_executor=tool_execution_handler, + thread_pool=thread_pool, ) # Store parent cycle ID for the next cycle @@ -406,7 +406,7 @@ def _handle_tool_execution( messages=messages, tool_config=tool_config, tool_handler=tool_handler, - tool_execution_handler=tool_execution_handler, + thread_pool=thread_pool, event_loop_metrics=event_loop_metrics, event_loop_parent_span=event_loop_parent_span, kwargs=kwargs, diff --git a/src/strands/tools/__init__.py b/src/strands/tools/__init__.py index be4a2470e..c61f79748 100644 --- a/src/strands/tools/__init__.py +++ b/src/strands/tools/__init__.py @@ -5,7 +5,6 @@ from .decorator import tool from .structured_output import convert_pydantic_to_tool_spec -from .thread_pool_executor import ThreadPoolExecutorWrapper from .tools import InvalidToolUseNameException, PythonAgentTool, normalize_schema, normalize_tool_spec __all__ = [ @@ -14,6 +13,5 @@ "InvalidToolUseNameException", "normalize_schema", "normalize_tool_spec", - "ThreadPoolExecutorWrapper", "convert_pydantic_to_tool_spec", ] diff --git a/src/strands/tools/executor.py b/src/strands/tools/executor.py index 06feb4e8a..631d0727d 100644 --- a/src/strands/tools/executor.py +++ b/src/strands/tools/executor.py @@ -4,6 +4,7 @@ import queue import threading import time +from concurrent.futures import ThreadPoolExecutor from typing import Any, Callable, Generator, Optional, cast from opentelemetry import trace @@ -12,7 +13,6 @@ from ..telemetry.tracer import get_tracer from ..tools.tools import InvalidToolUseNameException, validate_tool_use from ..types.content import Message -from ..types.event_loop import ParallelToolExecutorInterface from ..types.tools import ToolGenerator, ToolResult, ToolUse logger = logging.getLogger(__name__) @@ -26,7 +26,7 @@ def run_tools( tool_results: list[ToolResult], cycle_trace: Trace, parent_span: Optional[trace.Span] = None, - parallel_tool_executor: Optional[ParallelToolExecutorInterface] = None, + thread_pool: Optional[ThreadPoolExecutor] = None, ) -> Generator[dict[str, Any], None, None]: """Execute tools either in parallel or sequentially. @@ -38,7 +38,7 @@ def run_tools( tool_results: List to populate with tool results. cycle_trace: Parent trace for the current cycle. parent_span: Parent span for the current cycle. - parallel_tool_executor: Optional executor for parallel processing. + thread_pool: Optional thread pool for parallel processing. Yields: Events of the tool invocations. Tool results are appended to `tool_results`. 
@@ -84,18 +84,14 @@ def work( tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] - if parallel_tool_executor: - logger.debug( - "tool_count=<%s>, tool_executor=<%s> | executing tools in parallel", - len(tool_uses), - type(parallel_tool_executor).__name__, - ) + if thread_pool: + logger.debug("tool_count=<%s> | executing tools in parallel", len(tool_uses)) worker_queue: queue.Queue[tuple[int, dict[str, Any]]] = queue.Queue() worker_events = [threading.Event() for _ in range(len(tool_uses))] workers = [ - parallel_tool_executor.submit(work, tool_use, worker_id, worker_queue, worker_events[worker_id]) + thread_pool.submit(work, tool_use, worker_id, worker_queue, worker_events[worker_id]) for worker_id, tool_use in enumerate(tool_uses) ] logger.debug("tool_count=<%s> | submitted tasks to parallel executor", len(tool_uses)) diff --git a/src/strands/tools/thread_pool_executor.py b/src/strands/tools/thread_pool_executor.py deleted file mode 100644 index cdb92d29f..000000000 --- a/src/strands/tools/thread_pool_executor.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Thread pool execution management for parallel tool calls.""" - -import concurrent.futures -from concurrent.futures import ThreadPoolExecutor -from typing import Any, Callable, Iterable, Iterator, Optional - -from ..types.event_loop import Future, ParallelToolExecutorInterface - - -class ThreadPoolExecutorWrapper(ParallelToolExecutorInterface): - """Wrapper around ThreadPoolExecutor to implement the strands.types.event_loop.ParallelToolExecutorInterface. - - This class adapts Python's standard ThreadPoolExecutor to conform to the SDK's ParallelToolExecutorInterface, - allowing it to be used for parallel tool execution within the agent event loop. It provides methods for submitting - tasks, monitoring their completion, and shutting down the executor. - - Attributes: - thread_pool: The underlying ThreadPoolExecutor instance. - """ - - def __init__(self, thread_pool: ThreadPoolExecutor): - """Initialize with a ThreadPoolExecutor instance. - - Args: - thread_pool: The ThreadPoolExecutor to wrap. - """ - self.thread_pool = thread_pool - - def submit(self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> Future: - """Submit a callable to be executed with the given arguments. - - This method schedules the callable to be executed as fn(*args, **kwargs) - and returns a Future instance representing the execution of the callable. - - Args: - fn: The callable to execute. - *args: Positional arguments for the callable. - **kwargs: Keyword arguments for the callable. - - Returns: - A Future instance representing the execution of the callable. - """ - return self.thread_pool.submit(fn, *args, **kwargs) - - def as_completed(self, futures: Iterable[Future], timeout: Optional[int] = None) -> Iterator[Future]: - """Return an iterator over the futures as they complete. - - The returned iterator yields futures as they complete (finished or cancelled). - - Args: - futures: The futures to iterate over. - timeout: The maximum number of seconds to wait. - None means no limit. - - Returns: - An iterator yielding futures as they complete. - - Raises: - concurrent.futures.TimeoutError: If the timeout is reached. - """ - return concurrent.futures.as_completed(futures, timeout=timeout) # type: ignore - - def shutdown(self, wait: bool = True) -> None: - """Shutdown the thread pool executor. - - Args: - wait: If True, waits until all running futures have finished executing. 
- """ - self.thread_pool.shutdown(wait=wait) diff --git a/src/strands/types/event_loop.py b/src/strands/types/event_loop.py index 08ad8dc0d..7be33b6fd 100644 --- a/src/strands/types/event_loop.py +++ b/src/strands/types/event_loop.py @@ -1,8 +1,8 @@ """Event loop-related type definitions for the SDK.""" -from typing import Any, Callable, Iterable, Iterator, Literal, Optional, Protocol +from typing import Literal -from typing_extensions import TypedDict, runtime_checkable +from typing_extensions import TypedDict class Usage(TypedDict): @@ -46,69 +46,3 @@ class Metrics(TypedDict): - "stop_sequence": Stop sequence encountered - "tool_use": Model requested to use a tool """ - - -@runtime_checkable -class Future(Protocol): - """Interface representing the result of an asynchronous computation.""" - - def result(self, timeout: Optional[int] = None) -> Any: - """Return the result of the call that the future represents. - - This method will block until the asynchronous operation completes or until the specified timeout is reached. - - Args: - timeout: The number of seconds to wait for the result. - If None, then there is no limit on the wait time. - - Returns: - Any: The result of the asynchronous operation. - """ - - def done(self) -> bool: - """Returns true if future is done executing.""" - - -@runtime_checkable -class ParallelToolExecutorInterface(Protocol): - """Interface for parallel tool execution. - - Attributes: - timeout: Default timeout in seconds for futures. - """ - - timeout: int = 900 # default 15 minute timeout for futures - - def submit(self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> Future: - """Submit a callable to be executed with the given arguments. - - Schedules the callable to be executed as fn(*args, **kwargs) and returns a Future instance representing the - execution of the callable. - - Args: - fn: The callable to execute. - *args: Positional arguments to pass to the callable. - **kwargs: Keyword arguments to pass to the callable. - - Returns: - Future: A Future representing the given call. - """ - - def as_completed(self, futures: Iterable[Future], timeout: Optional[int] = timeout) -> Iterator[Future]: - """Iterate over the given futures, yielding each as it completes. - - Args: - futures: The sequence of Futures to iterate over. - timeout: The maximum number of seconds to wait. - If None, then there is no limit on the wait time. - - Returns: - An iterator that yields the given Futures as they complete (finished or cancelled). - """ - - def shutdown(self, wait: bool = True) -> None: - """Shutdown the executor and free associated resources. - - Args: - wait: If True, shutdown will not return until all running futures have finished executing. 
- """ diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 599a71f52..5a8985fb9 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -377,7 +377,6 @@ def test_agent__call__passes_kwargs(mock_model, system_prompt, callback_handler, override_system_prompt = "Override system prompt" override_model = unittest.mock.Mock() - override_tool_execution_handler = unittest.mock.Mock() override_event_loop_metrics = unittest.mock.Mock() override_callback_handler = unittest.mock.Mock() override_tool_handler = unittest.mock.Mock() @@ -389,7 +388,6 @@ def check_kwargs(**kwargs): assert kwargs_kwargs["some_value"] == "a_value" assert kwargs_kwargs["system_prompt"] == override_system_prompt assert kwargs_kwargs["model"] == override_model - assert kwargs_kwargs["tool_execution_handler"] == override_tool_execution_handler assert kwargs_kwargs["event_loop_metrics"] == override_event_loop_metrics assert kwargs_kwargs["callback_handler"] == override_callback_handler assert kwargs_kwargs["tool_handler"] == override_tool_handler @@ -407,7 +405,6 @@ def check_kwargs(**kwargs): some_value="a_value", system_prompt=override_system_prompt, model=override_model, - tool_execution_handler=override_tool_execution_handler, event_loop_metrics=override_event_loop_metrics, callback_handler=override_callback_handler, tool_handler=override_tool_handler, diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 9a8435efc..f07f0d27a 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -49,9 +49,8 @@ def tool_handler(tool_registry): @pytest.fixture -def tool_execution_handler(): - pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) - return strands.tools.ThreadPoolExecutorWrapper(pool) +def thread_pool(): + return concurrent.futures.ThreadPoolExecutor(max_workers=1) @pytest.fixture @@ -106,7 +105,7 @@ def test_event_loop_cycle_text_response( messages, tool_config, tool_handler, - tool_execution_handler, + thread_pool, ): model.converse.return_value = [ {"contentBlockDelta": {"delta": {"text": "test text"}}}, @@ -119,7 +118,7 @@ def test_event_loop_cycle_text_response( messages=messages, tool_config=tool_config, tool_handler=tool_handler, - tool_execution_handler=tool_execution_handler, + thread_pool=thread_pool, event_loop_metrics=EventLoopMetrics(), event_loop_parent_span=None, kwargs={}, @@ -141,7 +140,7 @@ def test_event_loop_cycle_text_response_throttling( messages, tool_config, tool_handler, - tool_execution_handler, + thread_pool, ): model.converse.side_effect = [ ModelThrottledException("ThrottlingException | ConverseStream"), @@ -157,7 +156,7 @@ def test_event_loop_cycle_text_response_throttling( messages=messages, tool_config=tool_config, tool_handler=tool_handler, - tool_execution_handler=tool_execution_handler, + thread_pool=thread_pool, event_loop_metrics=EventLoopMetrics(), event_loop_parent_span=None, kwargs={}, @@ -181,7 +180,7 @@ def test_event_loop_cycle_exponential_backoff( messages, tool_config, tool_handler, - tool_execution_handler, + thread_pool, ): """Test that the exponential backoff works correctly with multiple retries.""" # Set up the model to raise throttling exceptions multiple times before succeeding @@ -201,7 +200,7 @@ def test_event_loop_cycle_exponential_backoff( messages=messages, tool_config=tool_config, tool_handler=tool_handler, - tool_execution_handler=tool_execution_handler, + thread_pool=thread_pool, 
event_loop_metrics=EventLoopMetrics(), event_loop_parent_span=None, kwargs={}, @@ -227,7 +226,7 @@ def test_event_loop_cycle_text_response_throttling_exceeded( messages, tool_config, tool_handler, - tool_execution_handler, + thread_pool, ): model.converse.side_effect = [ ModelThrottledException("ThrottlingException | ConverseStream"), @@ -245,7 +244,7 @@ def test_event_loop_cycle_text_response_throttling_exceeded( messages=messages, tool_config=tool_config, tool_handler=tool_handler, - tool_execution_handler=tool_execution_handler, + thread_pool=thread_pool, event_loop_metrics=EventLoopMetrics(), event_loop_parent_span=None, kwargs={}, @@ -269,7 +268,7 @@ def test_event_loop_cycle_text_response_error( messages, tool_config, tool_handler, - tool_execution_handler, + thread_pool, ): model.converse.side_effect = RuntimeError("Unhandled error") @@ -280,7 +279,7 @@ def test_event_loop_cycle_text_response_error( messages=messages, tool_config=tool_config, tool_handler=tool_handler, - tool_execution_handler=tool_execution_handler, + thread_pool=thread_pool, event_loop_metrics=EventLoopMetrics(), event_loop_parent_span=None, kwargs={}, @@ -294,7 +293,7 @@ def test_event_loop_cycle_tool_result( messages, tool_config, tool_handler, - tool_execution_handler, + thread_pool, tool_stream, ): model.converse.side_effect = [ @@ -311,7 +310,7 @@ def test_event_loop_cycle_tool_result( messages=messages, tool_config=tool_config, tool_handler=tool_handler, - tool_execution_handler=tool_execution_handler, + thread_pool=thread_pool, event_loop_metrics=EventLoopMetrics(), event_loop_parent_span=None, kwargs={}, @@ -365,7 +364,7 @@ def test_event_loop_cycle_tool_result_error( messages, tool_config, tool_handler, - tool_execution_handler, + thread_pool, tool_stream, ): model.converse.side_effect = [tool_stream] @@ -377,7 +376,7 @@ def test_event_loop_cycle_tool_result_error( messages=messages, tool_config=tool_config, tool_handler=tool_handler, - tool_execution_handler=tool_execution_handler, + thread_pool=thread_pool, event_loop_metrics=EventLoopMetrics(), event_loop_parent_span=None, kwargs={}, @@ -390,7 +389,7 @@ def test_event_loop_cycle_tool_result_no_tool_handler( system_prompt, messages, tool_config, - tool_execution_handler, + thread_pool, tool_stream, ): model.converse.side_effect = [tool_stream] @@ -402,7 +401,7 @@ def test_event_loop_cycle_tool_result_no_tool_handler( messages=messages, tool_config=tool_config, tool_handler=None, - tool_execution_handler=tool_execution_handler, + thread_pool=thread_pool, event_loop_metrics=EventLoopMetrics(), event_loop_parent_span=None, kwargs={}, @@ -415,7 +414,7 @@ def test_event_loop_cycle_tool_result_no_tool_config( system_prompt, messages, tool_handler, - tool_execution_handler, + thread_pool, tool_stream, ): model.converse.side_effect = [tool_stream] @@ -427,7 +426,7 @@ def test_event_loop_cycle_tool_result_no_tool_config( messages=messages, tool_config=None, tool_handler=tool_handler, - tool_execution_handler=tool_execution_handler, + thread_pool=thread_pool, event_loop_metrics=EventLoopMetrics(), event_loop_parent_span=None, kwargs={}, @@ -441,7 +440,7 @@ def test_event_loop_cycle_stop( messages, tool_config, tool_handler, - tool_execution_handler, + thread_pool, tool, ): model.converse.side_effect = [ @@ -467,7 +466,7 @@ def test_event_loop_cycle_stop( messages=messages, tool_config=tool_config, tool_handler=tool_handler, - tool_execution_handler=tool_execution_handler, + thread_pool=thread_pool, event_loop_metrics=EventLoopMetrics(), 
event_loop_parent_span=None, kwargs={"request_state": {"stop_event_loop": True}}, @@ -499,7 +498,7 @@ def test_cycle_exception( messages, tool_config, tool_handler, - tool_execution_handler, + thread_pool, tool_stream, ): model.converse.side_effect = [tool_stream, tool_stream, tool_stream, ValueError("Invalid error presented")] @@ -514,7 +513,7 @@ def test_cycle_exception( messages=messages, tool_config=tool_config, tool_handler=tool_handler, - tool_execution_handler=tool_execution_handler, + thread_pool=thread_pool, event_loop_metrics=EventLoopMetrics(), event_loop_parent_span=None, kwargs={}, @@ -533,7 +532,7 @@ def test_event_loop_cycle_creates_spans( messages, tool_config, tool_handler, - tool_execution_handler, + thread_pool, mock_tracer, ): # Setup @@ -555,7 +554,7 @@ def test_event_loop_cycle_creates_spans( messages=messages, tool_config=tool_config, tool_handler=tool_handler, - tool_execution_handler=tool_execution_handler, + thread_pool=thread_pool, event_loop_metrics=EventLoopMetrics(), event_loop_parent_span=None, kwargs={}, @@ -578,7 +577,7 @@ def test_event_loop_tracing_with_model_error( messages, tool_config, tool_handler, - tool_execution_handler, + thread_pool, mock_tracer, ): # Setup @@ -599,7 +598,7 @@ def test_event_loop_tracing_with_model_error( messages=messages, tool_config=tool_config, tool_handler=tool_handler, - tool_execution_handler=tool_execution_handler, + thread_pool=thread_pool, event_loop_metrics=EventLoopMetrics(), event_loop_parent_span=None, kwargs={}, @@ -618,7 +617,7 @@ def test_event_loop_tracing_with_tool_execution( messages, tool_config, tool_handler, - tool_execution_handler, + thread_pool, tool_stream, mock_tracer, ): @@ -645,7 +644,7 @@ def test_event_loop_tracing_with_tool_execution( messages=messages, tool_config=tool_config, tool_handler=tool_handler, - tool_execution_handler=tool_execution_handler, + thread_pool=thread_pool, event_loop_metrics=EventLoopMetrics(), event_loop_parent_span=None, kwargs={}, @@ -666,7 +665,7 @@ def test_event_loop_tracing_with_throttling_exception( messages, tool_config, tool_handler, - tool_execution_handler, + thread_pool, mock_tracer, ): # Setup @@ -693,7 +692,7 @@ def test_event_loop_tracing_with_throttling_exception( messages=messages, tool_config=tool_config, tool_handler=tool_handler, - tool_execution_handler=tool_execution_handler, + thread_pool=thread_pool, event_loop_metrics=EventLoopMetrics(), event_loop_parent_span=None, kwargs={}, @@ -715,7 +714,7 @@ def test_event_loop_cycle_with_parent_span( messages, tool_config, tool_handler, - tool_execution_handler, + thread_pool, mock_tracer, ): # Setup @@ -736,7 +735,7 @@ def test_event_loop_cycle_with_parent_span( messages=messages, tool_config=tool_config, tool_handler=tool_handler, - tool_execution_handler=tool_execution_handler, + thread_pool=thread_pool, event_loop_metrics=EventLoopMetrics(), event_loop_parent_span=parent_span, kwargs={}, @@ -757,7 +756,7 @@ def test_request_state_initialization(): messages=MagicMock(), tool_config=MagicMock(), tool_handler=MagicMock(), - tool_execution_handler=MagicMock(), + thread_pool=MagicMock(), event_loop_metrics=EventLoopMetrics(), event_loop_parent_span=None, kwargs={}, @@ -776,7 +775,7 @@ def test_request_state_initialization(): messages=MagicMock(), tool_config=MagicMock(), tool_handler=MagicMock(), - tool_execution_handler=MagicMock(), + thread_pool=MagicMock(), event_loop_metrics=EventLoopMetrics(), event_loop_parent_span=None, kwargs={"request_state": initial_request_state}, @@ -816,7 +815,7 @@ def 
test_prepare_next_cycle_in_tool_execution(model, tool_stream): messages=MagicMock(), tool_config=MagicMock(), tool_handler=MagicMock(), - tool_execution_handler=MagicMock(), + thread_pool=MagicMock(), event_loop_metrics=EventLoopMetrics(), event_loop_parent_span=None, kwargs={}, diff --git a/tests/strands/tools/test_executor.py b/tests/strands/tools/test_executor.py index 8a4c32ea1..d3e934acc 100644 --- a/tests/strands/tools/test_executor.py +++ b/tests/strands/tools/test_executor.py @@ -1,5 +1,4 @@ import concurrent -import functools import unittest.mock import uuid @@ -67,21 +66,8 @@ def cycle_trace(): @pytest.fixture -def parallel_tool_executor(request): - params = { - "max_workers": 1, - "timeout": None, - } - if hasattr(request, "param"): - params.update(request.param) - - as_completed = functools.partial(concurrent.futures.as_completed, timeout=params["timeout"]) - - pool = concurrent.futures.ThreadPoolExecutor(max_workers=params["max_workers"]) - wrapper = strands.tools.ThreadPoolExecutorWrapper(pool) - - with unittest.mock.patch.object(wrapper, "as_completed", side_effect=as_completed): - yield wrapper +def thread_pool(request): + return concurrent.futures.ThreadPoolExecutor(max_workers=1) def test_run_tools( @@ -90,7 +76,7 @@ def test_run_tools( event_loop_metrics, invalid_tool_use_ids, cycle_trace, - parallel_tool_executor, + thread_pool, ): tool_results = [] @@ -101,7 +87,7 @@ def test_run_tools( invalid_tool_use_ids, tool_results, cycle_trace, - parallel_tool_executor, + thread_pool, ) tru_events = list(stream) @@ -130,7 +116,7 @@ def test_run_tools_invalid_tool( event_loop_metrics, invalid_tool_use_ids, cycle_trace, - parallel_tool_executor, + thread_pool, ): tool_results = [] @@ -141,7 +127,7 @@ def test_run_tools_invalid_tool( invalid_tool_use_ids, tool_results, cycle_trace, - parallel_tool_executor, + thread_pool, ) list(stream) @@ -158,7 +144,7 @@ def test_run_tools_failed_tool( event_loop_metrics, invalid_tool_use_ids, cycle_trace, - parallel_tool_executor, + thread_pool, ): tool_results = [] @@ -169,7 +155,7 @@ def test_run_tools_failed_tool( invalid_tool_use_ids, tool_results, cycle_trace, - parallel_tool_executor, + thread_pool, ) list(stream) @@ -226,7 +212,7 @@ def test_run_tools_sequential( invalid_tool_use_ids, tool_results, cycle_trace, - None, # parallel_tool_executor + None, # tool_pool ) list(stream) @@ -303,7 +289,7 @@ def test_run_tools_creates_and_ends_span_on_success( event_loop_metrics, invalid_tool_use_ids, cycle_trace, - parallel_tool_executor, + thread_pool, ): """Test that run_tools creates and ends a span on successful execution.""" # Setup mock tracer and span @@ -326,7 +312,7 @@ def test_run_tools_creates_and_ends_span_on_success( tool_results, cycle_trace, parent_span, - parallel_tool_executor, + thread_pool, ) list(stream) @@ -350,7 +336,7 @@ def test_run_tools_creates_and_ends_span_on_failure( event_loop_metrics, invalid_tool_use_ids, cycle_trace, - parallel_tool_executor, + thread_pool, ): """Test that run_tools creates and ends a span on tool failure.""" # Setup mock tracer and span @@ -373,7 +359,7 @@ def test_run_tools_creates_and_ends_span_on_failure( tool_results, cycle_trace, parent_span, - parallel_tool_executor, + thread_pool, ) list(stream) @@ -416,7 +402,7 @@ def test_run_tools_parallel_execution_with_spans( event_loop_metrics, invalid_tool_use_ids, cycle_trace, - parallel_tool_executor, + thread_pool, ): """Test that spans are created and ended for each tool in parallel execution.""" # Setup mock tracer and spans @@ -440,7 +426,7 @@ 
def test_run_tools_parallel_execution_with_spans( tool_results, cycle_trace, parent_span, - parallel_tool_executor, + thread_pool, ) list(stream) diff --git a/tests/strands/tools/test_thread_pool_executor.py b/tests/strands/tools/test_thread_pool_executor.py deleted file mode 100644 index b5eb6b796..000000000 --- a/tests/strands/tools/test_thread_pool_executor.py +++ /dev/null @@ -1,46 +0,0 @@ -import concurrent - -import pytest - -import strands - - -@pytest.fixture -def thread_pool(): - return concurrent.futures.ThreadPoolExecutor(max_workers=1) - - -@pytest.fixture -def thread_pool_wrapper(thread_pool): - return strands.tools.ThreadPoolExecutorWrapper(thread_pool) - - -def test_submit(thread_pool_wrapper): - def fun(a, b): - return (a, b) - - future = thread_pool_wrapper.submit(fun, 1, b=2) - - tru_result = future.result() - exp_result = (1, 2) - - assert tru_result == exp_result - - -def test_as_completed(thread_pool_wrapper): - def fun(i): - return i - - futures = [thread_pool_wrapper.submit(fun, i) for i in range(2)] - - tru_results = sorted(future.result() for future in thread_pool_wrapper.as_completed(futures)) - exp_results = [0, 1] - - assert tru_results == exp_results - - -def test_shutdown(thread_pool_wrapper): - thread_pool_wrapper.shutdown() - - with pytest.raises(RuntimeError): - thread_pool_wrapper.submit(lambda: None)
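
Note (illustrative, not part of the patch): the deleted wrapper and protocol only re-exposed what concurrent.futures.ThreadPoolExecutor already provides (submit, as_completed, shutdown), which is why the event loop can now hold the executor directly. The sketch below is a minimal, standalone illustration of that pattern under stated assumptions — fake_tool, run_in_parallel, and the sample tool_uses are hypothetical stand-ins, not SDK APIs.

# Minimal sketch of the pattern this change standardizes on: the caller owns a
# plain ThreadPoolExecutor, worker functions are submitted to it directly, and
# shutdown(wait=False) mirrors the non-blocking cleanup in Agent.__del__.
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional


def fake_tool(tool_use: dict) -> dict:
    # Stand-in for a tool invocation; the real handler streams events instead.
    return {"toolUseId": tool_use["toolUseId"], "status": "success"}


def run_in_parallel(tool_uses: list[dict], thread_pool: Optional[ThreadPoolExecutor]) -> list[dict]:
    # Mirrors run_tools: fall back to sequential execution when no pool is given.
    if thread_pool is None:
        return [fake_tool(tool_use) for tool_use in tool_uses]

    futures = [thread_pool.submit(fake_tool, tool_use) for tool_use in tool_uses]
    return [future.result() for future in as_completed(futures)]


if __name__ == "__main__":
    tool_uses = [{"toolUseId": f"tool-{i}"} for i in range(3)]
    pool = ThreadPoolExecutor(max_workers=2)  # analogous to max_parallel_tools > 1
    try:
        print(run_in_parallel(tool_uses, pool))
    finally:
        pool.shutdown(wait=False)  # matches the non-blocking shutdown in __del__

Because the standard executor already satisfies this contract, callers that previously passed a custom ParallelToolExecutorInterface implementation now pass the ThreadPoolExecutor itself (or None for sequential execution).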