diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 2f1233145..0df5090b8 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -3,10 +3,12 @@ - Docs: https://aws.amazon.com/bedrock/ """ +import asyncio import json import logging import os -from typing import Any, AsyncGenerator, Iterable, List, Literal, Optional, Type, TypeVar, Union, cast +import threading +from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union import boto3 from botocore.config import Config as BotocoreConfig @@ -245,17 +247,6 @@ def format_request( ), } - def format_chunk(self, event: dict[str, Any]) -> StreamEvent: - """Format the Bedrock response events into standardized message chunks. - - Args: - event: A response event from the Bedrock model. - - Returns: - The formatted chunk. - """ - return cast(StreamEvent, event) - def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool: """Check if guardrail data contains any blocked policies. @@ -284,7 +275,7 @@ def _generate_redaction_events(self) -> list[StreamEvent]: Returns: List of redaction events to yield. """ - events: List[StreamEvent] = [] + events: list[StreamEvent] = [] if self.config.get("guardrail_redact_input", True): logger.debug("Redacting user input due to guardrail.") @@ -327,7 +318,55 @@ async def stream( system_prompt: System prompt to provide context to the model. Yields: - Formatted message chunks from the model. + Model events. + + Raises: + ContextWindowOverflowException: If the input exceeds the model's context window. + ModelThrottledException: If the model service is throttling requests. + """ + + def callback(event: Optional[StreamEvent] = None) -> None: + loop.call_soon_threadsafe(queue.put_nowait, event) + if event is None: + return + + signal.wait() + signal.clear() + + loop = asyncio.get_event_loop() + queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue() + signal = threading.Event() + + thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt) + task = asyncio.create_task(thread) + + while True: + event = await queue.get() + if event is None: + break + + yield event + signal.set() + + await task + + def _stream( + self, + callback: Callable[..., None], + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + ) -> None: + """Stream conversation with the Bedrock model. + + This method operates in a separate thread to avoid blocking the async event loop with the call to + Bedrock's converse_stream. + + Args: + callback: Function to send events to the main thread. + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. Raises: ContextWindowOverflowException: If the input exceeds the model's context window. @@ -343,7 +382,6 @@ async def stream( try: logger.debug("got response from model") if streaming: - # Streaming implementation response = self.client.converse_stream(**request) for chunk in response["stream"]: if ( @@ -354,33 +392,29 @@ async def stream( guardrail_data = chunk["metadata"]["trace"]["guardrail"] if self._has_blocked_guardrail(guardrail_data): for event in self._generate_redaction_events(): - yield event - yield self.format_chunk(chunk) + callback(event) + + callback(chunk) + else: - # Non-streaming implementation response = self.client.converse(**request) - - # Convert and yield from the response for event in self._convert_non_streaming_to_streaming(response): - yield event + callback(event) - # Check for guardrail triggers after yielding any events (same as streaming path) if ( "trace" in response and "guardrail" in response["trace"] and self._has_blocked_guardrail(response["trace"]["guardrail"]) ): for event in self._generate_redaction_events(): - yield event + callback(event) except ClientError as e: error_message = str(e) - # Handle throttling error if e.response["Error"]["Code"] == "ThrottlingException": raise ModelThrottledException(error_message) from e - # Handle context window overflow if any(overflow_message in error_message for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES): logger.warning("bedrock threw context window overflow error") raise ContextWindowOverflowException(e) from e @@ -411,10 +445,11 @@ async def stream( "https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#on-demand-throughput-isnt-supported" ) - # Otherwise raise the error raise e - logger.debug("finished streaming response from model") + finally: + callback() + logger.debug("finished streaming response from model") def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]: """Convert a non-streaming response to the streaming format. diff --git a/src/strands/tools/executor.py b/src/strands/tools/executor.py index 5c17f2be6..1214fa608 100644 --- a/src/strands/tools/executor.py +++ b/src/strands/tools/executor.py @@ -58,6 +58,7 @@ async def work( async for event in handler(tool_use): worker_queue.put_nowait((worker_id, event)) await worker_event.wait() + worker_event.clear() result = cast(ToolResult, event) finally: diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 2eb0679fb..f62fce7e4 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -398,13 +398,6 @@ def test_format_request_cache(model, messages, model_id, tool_spec, cache_type): assert tru_request == exp_request -def test_format_chunk(model): - tru_chunk = model.format_chunk("event") - exp_chunk = "event" - - assert tru_chunk == exp_chunk - - @pytest.mark.asyncio async def test_stream_throttling_exception_from_event_stream_error(bedrock_client, model, messages, alist): error_message = "Rate exceeded" diff --git a/tests_integ/models/test_model_anthropic.py b/tests_integ/models/test_model_anthropic.py index bd0f2bc9c..2ee5e7f23 100644 --- a/tests_integ/models/test_model_anthropic.py +++ b/tests_integ/models/test_model_anthropic.py @@ -1,7 +1,7 @@ import os +import pydantic import pytest -from pydantic import BaseModel import strands from strands import Agent @@ -48,7 +48,7 @@ def agent(model, tools, system_prompt): @pytest.fixture def weather(): - class Weather(BaseModel): + class Weather(pydantic.BaseModel): """Extracts the time and weather from the user's message with the exact strings.""" time: str @@ -59,11 +59,16 @@ class Weather(BaseModel): @pytest.fixture def yellow_color(): - class Color(BaseModel): + class Color(pydantic.BaseModel): """Describes a color.""" name: str + @pydantic.field_validator("name", mode="after") + @classmethod + def lower(_, value): + return value.lower() + return Color(name="yellow") diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index 71c0bc05b..eed0e783f 100644 --- a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -1,5 +1,5 @@ +import pydantic import pytest -from pydantic import BaseModel import strands from strands import Agent @@ -39,11 +39,16 @@ def non_streaming_agent(non_streaming_model, system_prompt): @pytest.fixture def yellow_color(): - class Color(BaseModel): + class Color(pydantic.BaseModel): """Describes a color.""" name: str + @pydantic.field_validator("name", mode="after") + @classmethod + def lower(_, value): + return value.lower() + return Color(name="yellow") @@ -136,7 +141,7 @@ def calculator(expression: str) -> float: def test_structured_output_streaming(streaming_model): """Test structured output with streaming model.""" - class Weather(BaseModel): + class Weather(pydantic.BaseModel): time: str weather: str @@ -151,7 +156,7 @@ class Weather(BaseModel): def test_structured_output_non_streaming(non_streaming_model): """Test structured output with non-streaming model.""" - class Weather(BaseModel): + class Weather(pydantic.BaseModel): time: str weather: str diff --git a/tests_integ/models/test_model_litellm.py b/tests_integ/models/test_model_litellm.py index 6abd83b55..382f75194 100644 --- a/tests_integ/models/test_model_litellm.py +++ b/tests_integ/models/test_model_litellm.py @@ -1,5 +1,5 @@ +import pydantic import pytest -from pydantic import BaseModel import strands from strands import Agent @@ -31,11 +31,16 @@ def agent(model, tools): @pytest.fixture def yellow_color(): - class Color(BaseModel): + class Color(pydantic.BaseModel): """Describes a color.""" name: str + @pydantic.field_validator("name", mode="after") + @classmethod + def lower(_, value): + return value.lower() + return Color(name="yellow") @@ -47,7 +52,7 @@ def test_agent(agent): def test_structured_output(model): - class Weather(BaseModel): + class Weather(pydantic.BaseModel): time: str weather: str diff --git a/tests_integ/models/test_model_openai.py b/tests_integ/models/test_model_openai.py index 4d81d880b..7054b222a 100644 --- a/tests_integ/models/test_model_openai.py +++ b/tests_integ/models/test_model_openai.py @@ -1,7 +1,7 @@ import os +import pydantic import pytest -from pydantic import BaseModel import strands from strands import Agent, tool @@ -42,7 +42,7 @@ def agent(model, tools): @pytest.fixture def weather(): - class Weather(BaseModel): + class Weather(pydantic.BaseModel): """Extracts the time and weather from the user's message with the exact strings.""" time: str @@ -53,11 +53,16 @@ class Weather(BaseModel): @pytest.fixture def yellow_color(): - class Color(BaseModel): + class Color(pydantic.BaseModel): """Describes a color.""" name: str + @pydantic.field_validator("name", mode="after") + @classmethod + def lower(_, value): + return value.lower() + return Color(name="yellow")