Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 62 additions & 27 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
- Docs: https://aws.amazon.com/bedrock/
"""

import asyncio
import json
import logging
import os
from typing import Any, AsyncGenerator, Iterable, List, Literal, Optional, Type, TypeVar, Union, cast
import threading
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union

import boto3
from botocore.config import Config as BotocoreConfig
Expand Down Expand Up @@ -245,17 +247,6 @@ def format_request(
),
}

def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
    """Convert a raw Bedrock response event into a standardized stream event.

    Bedrock's streaming events already match the expected chunk shape, so no
    data transformation happens here — this is purely a type-level cast.

    Args:
        event: A response event emitted by the Bedrock model.

    Returns:
        The same event, typed as a StreamEvent.
    """
    chunk = cast(StreamEvent, event)
    return chunk

def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
"""Check if guardrail data contains any blocked policies.

Expand Down Expand Up @@ -284,7 +275,7 @@ def _generate_redaction_events(self) -> list[StreamEvent]:
Returns:
List of redaction events to yield.
"""
events: List[StreamEvent] = []
events: list[StreamEvent] = []

if self.config.get("guardrail_redact_input", True):
logger.debug("Redacting user input due to guardrail.")
Expand Down Expand Up @@ -327,7 +318,55 @@ async def stream(
system_prompt: System prompt to provide context to the model.

Yields:
Formatted message chunks from the model.
Model events.

Raises:
ContextWindowOverflowException: If the input exceeds the model's context window.
ModelThrottledException: If the model service is throttling requests.
"""

def callback(event: Optional[StreamEvent] = None) -> None:
loop.call_soon_threadsafe(queue.put_nowait, event)
if event is None:
return

signal.wait()
signal.clear()

loop = asyncio.get_event_loop()
queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue()
signal = threading.Event()

thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt)
task = asyncio.create_task(thread)

while True:
event = await queue.get()
if event is None:
break

yield event
signal.set()

await task

def _stream(
self,
callback: Callable[..., None],
messages: Messages,
tool_specs: Optional[list[ToolSpec]] = None,
system_prompt: Optional[str] = None,
) -> None:
"""Stream conversation with the Bedrock model.

This method operates in a separate thread to avoid blocking the async event loop with the call to
Bedrock's converse_stream.

Args:
callback: Function to send events to the main thread.
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.

Raises:
ContextWindowOverflowException: If the input exceeds the model's context window.
Expand All @@ -343,7 +382,6 @@ async def stream(
try:
logger.debug("got response from model")
if streaming:
# Streaming implementation
response = self.client.converse_stream(**request)
for chunk in response["stream"]:
if (
Expand All @@ -354,33 +392,29 @@ async def stream(
guardrail_data = chunk["metadata"]["trace"]["guardrail"]
if self._has_blocked_guardrail(guardrail_data):
for event in self._generate_redaction_events():
yield event
yield self.format_chunk(chunk)
callback(event)

callback(chunk)

else:
# Non-streaming implementation
response = self.client.converse(**request)

# Convert and yield from the response
for event in self._convert_non_streaming_to_streaming(response):
yield event
callback(event)

# Check for guardrail triggers after yielding any events (same as streaming path)
if (
"trace" in response
and "guardrail" in response["trace"]
and self._has_blocked_guardrail(response["trace"]["guardrail"])
):
for event in self._generate_redaction_events():
yield event
callback(event)

except ClientError as e:
error_message = str(e)

# Handle throttling error
if e.response["Error"]["Code"] == "ThrottlingException":
raise ModelThrottledException(error_message) from e

# Handle context window overflow
if any(overflow_message in error_message for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES):
logger.warning("bedrock threw context window overflow error")
raise ContextWindowOverflowException(e) from e
Expand Down Expand Up @@ -411,10 +445,11 @@ async def stream(
"https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#on-demand-throughput-isnt-supported"
)

# Otherwise raise the error
raise e

logger.debug("finished streaming response from model")
finally:
callback()
logger.debug("finished streaming response from model")

def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
"""Convert a non-streaming response to the streaming format.
Expand Down
1 change: 1 addition & 0 deletions src/strands/tools/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ async def work(
async for event in handler(tool_use):
worker_queue.put_nowait((worker_id, event))
await worker_event.wait()
worker_event.clear()

result = cast(ToolResult, event)
finally:
Expand Down
7 changes: 0 additions & 7 deletions tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,13 +398,6 @@ def test_format_request_cache(model, messages, model_id, tool_spec, cache_type):
assert tru_request == exp_request


def test_format_chunk(model):
    # format_chunk is a pass-through cast: the event comes back unchanged.
    event = "event"

    result = model.format_chunk(event)

    assert result == event


@pytest.mark.asyncio
async def test_stream_throttling_exception_from_event_stream_error(bedrock_client, model, messages, alist):
error_message = "Rate exceeded"
Expand Down
11 changes: 8 additions & 3 deletions tests_integ/models/test_model_anthropic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os

import pydantic
import pytest
from pydantic import BaseModel

import strands
from strands import Agent
Expand Down Expand Up @@ -48,7 +48,7 @@ def agent(model, tools, system_prompt):

@pytest.fixture
def weather():
class Weather(BaseModel):
class Weather(pydantic.BaseModel):
"""Extracts the time and weather from the user's message with the exact strings."""

time: str
Expand All @@ -59,11 +59,16 @@ class Weather(BaseModel):

@pytest.fixture
def yellow_color():
class Color(BaseModel):
class Color(pydantic.BaseModel):
"""Describes a color."""

name: str

@pydantic.field_validator("name", mode="after")
@classmethod
def lower(_, value):
return value.lower()

return Color(name="yellow")


Expand Down
13 changes: 9 additions & 4 deletions tests_integ/models/test_model_bedrock.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pydantic
import pytest
from pydantic import BaseModel

import strands
from strands import Agent
Expand Down Expand Up @@ -39,11 +39,16 @@ def non_streaming_agent(non_streaming_model, system_prompt):

@pytest.fixture
def yellow_color():
class Color(BaseModel):
class Color(pydantic.BaseModel):
"""Describes a color."""

name: str

@pydantic.field_validator("name", mode="after")
@classmethod
def lower(_, value):
return value.lower()

return Color(name="yellow")


Expand Down Expand Up @@ -136,7 +141,7 @@ def calculator(expression: str) -> float:
def test_structured_output_streaming(streaming_model):
"""Test structured output with streaming model."""

class Weather(BaseModel):
class Weather(pydantic.BaseModel):
time: str
weather: str

Expand All @@ -151,7 +156,7 @@ class Weather(BaseModel):
def test_structured_output_non_streaming(non_streaming_model):
"""Test structured output with non-streaming model."""

class Weather(BaseModel):
class Weather(pydantic.BaseModel):
time: str
weather: str

Expand Down
11 changes: 8 additions & 3 deletions tests_integ/models/test_model_litellm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pydantic
import pytest
from pydantic import BaseModel

import strands
from strands import Agent
Expand Down Expand Up @@ -31,11 +31,16 @@ def agent(model, tools):

@pytest.fixture
def yellow_color():
class Color(BaseModel):
class Color(pydantic.BaseModel):
"""Describes a color."""

name: str

@pydantic.field_validator("name", mode="after")
@classmethod
def lower(_, value):
return value.lower()

return Color(name="yellow")


Expand All @@ -47,7 +52,7 @@ def test_agent(agent):


def test_structured_output(model):
class Weather(BaseModel):
class Weather(pydantic.BaseModel):
time: str
weather: str

Expand Down
11 changes: 8 additions & 3 deletions tests_integ/models/test_model_openai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os

import pydantic
import pytest
from pydantic import BaseModel

import strands
from strands import Agent, tool
Expand Down Expand Up @@ -42,7 +42,7 @@ def agent(model, tools):

@pytest.fixture
def weather():
class Weather(BaseModel):
class Weather(pydantic.BaseModel):
"""Extracts the time and weather from the user's message with the exact strings."""

time: str
Expand All @@ -53,11 +53,16 @@ class Weather(BaseModel):

@pytest.fixture
def yellow_color():
class Color(BaseModel):
class Color(pydantic.BaseModel):
"""Describes a color."""

name: str

@pydantic.field_validator("name", mode="after")
@classmethod
def lower(_, value):
return value.lower()

return Color(name="yellow")


Expand Down
Loading