23 changes: 14 additions & 9 deletions src/strands/event_loop/event_loop.py
@@ -130,14 +130,19 @@ def event_loop_cycle(
)

try:
stop_reason, message, usage, metrics, kwargs["request_state"] = stream_messages(
model,
system_prompt,
messages,
tool_config,
callback_handler,
**kwargs,
)
# TODO: As part of the migration to async-iterator, we will continue moving callback_handler calls up the
# call stack. At this point, we converted all events that were previously passed to the handler in
# `stream_messages` into yielded events that now have the "callback" key. To maintain backwards
# compatibility, we need to combine the event with kwargs before passing to the handler. We will
# revisit this when migrating to strongly typed events.
for event in stream_messages(model, system_prompt, messages, tool_config):
if "callback" in event:
inputs = {**event["callback"], **(kwargs if "delta" in event["callback"] else {})}
callback_handler(**inputs)
else:
stop_reason, message, usage, metrics = event["stop"]
kwargs.setdefault("request_state", {})

if model_invoke_span:
tracer.end_model_invoke_span(model_invoke_span, message, usage)
break # Success! Break out of retry loop
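For illustration, here is a minimal standalone sketch of driving the new generator API to collect streamed text. The names model, messages, and tool_config are placeholders assumed to be constructed elsewhere; only the {"callback"}/{"stop"} event shapes come from this change:

# Hedged sketch (not part of the diff): collect streamed text from the
# generator-based stream_messages.
pieces: list[str] = []
for event in stream_messages(model, system_prompt=None, messages=messages, tool_config=None):
    if "callback" in event:
        # Text deltas carry their content under the "data" key.
        if "data" in event["callback"]:
            pieces.append(event["callback"]["data"])
    else:
        # The terminal event carries the aggregated results as a 4-tuple.
        stop_reason, message, usage, metrics = event["stop"]
print("".join(pieces), stop_reason)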
@@ -334,7 +339,7 @@ def _handle_tool_execution(
kwargs (Dict[str, Any]): Additional keyword arguments, including request state.

Returns:
Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]:
- The stop reason,
- The updated message,
- The updated event loop metrics,
80 changes: 33 additions & 47 deletions src/strands/event_loop/streaming.py
@@ -2,7 +2,7 @@

import json
import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Generator, Iterable, Optional

from ..types.content import ContentBlock, Message, Messages
from ..types.models import Model
@@ -80,7 +80,7 @@ def handle_message_start(event: MessageStartEvent, message: Message) -> Message:
return message


def handle_content_block_start(event: ContentBlockStartEvent) -> Dict[str, Any]:
def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]:
"""Handles the start of a content block by extracting tool usage information if any.

Args:
@@ -102,61 +102,59 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> Dict[str, Any]:


def handle_content_block_delta(
event: ContentBlockDeltaEvent, state: Dict[str, Any], callback_handler: Any, **kwargs: Any
) -> Dict[str, Any]:
event: ContentBlockDeltaEvent, state: dict[str, Any]
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Handles content block delta updates by appending text, tool input, or reasoning content to the state.

Args:
event: Delta event.
state: The current state of message processing.
callback_handler: Callback for processing events as they happen.
**kwargs: Additional keyword arguments to pass to the callback handler.

Returns:
Updated state with appended text or tool input, and the callback event to yield.
"""
delta_content = event["delta"]

callback_event = {}

if "toolUse" in delta_content:
if "input" not in state["current_tool_use"]:
state["current_tool_use"]["input"] = ""

state["current_tool_use"]["input"] += delta_content["toolUse"]["input"]
callback_handler(delta=delta_content, current_tool_use=state["current_tool_use"], **kwargs)
callback_event["callback"] = {"delta": delta_content, "current_tool_use": state["current_tool_use"]}

elif "text" in delta_content:
state["text"] += delta_content["text"]
callback_handler(data=delta_content["text"], delta=delta_content, **kwargs)
callback_event["callback"] = {"data": delta_content["text"], "delta": delta_content}

elif "reasoningContent" in delta_content:
if "text" in delta_content["reasoningContent"]:
if "reasoningText" not in state:
state["reasoningText"] = ""

state["reasoningText"] += delta_content["reasoningContent"]["text"]
callback_handler(
reasoningText=delta_content["reasoningContent"]["text"],
delta=delta_content,
reasoning=True,
**kwargs,
)
callback_event["callback"] = {
"reasoningText": delta_content["reasoningContent"]["text"],
"delta": delta_content,
"reasoning": True,
}

elif "signature" in delta_content["reasoningContent"]:
if "signature" not in state:
state["signature"] = ""

state["signature"] += delta_content["reasoningContent"]["signature"]
callback_handler(
reasoning_signature=delta_content["reasoningContent"]["signature"],
delta=delta_content,
reasoning=True,
**kwargs,
)
callback_event["callback"] = {
"reasoning_signature": delta_content["reasoningContent"]["signature"],
"delta": delta_content,
"reasoning": True,
}

return state
return state, callback_event
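To make the new contract concrete, here is a hedged example of a text delta under a minimal state dict (its keys mirror those initialized in process_stream; not part of the diff):

# The function now returns the updated state plus the callback event to yield.
state = {"message": {"role": "assistant", "content": []}, "text": "", "current_tool_use": {}}
event = {"delta": {"text": "Hello"}}
state, callback_event = handle_content_block_delta(event, state)
assert state["text"] == "Hello"
assert callback_event == {"callback": {"data": "Hello", "delta": {"text": "Hello"}}}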


def handle_content_block_stop(state: Dict[str, Any]) -> Dict[str, Any]:
def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]:
"""Handles the end of a content block by finalizing tool usage, text content, or reasoning content.

Args:
@@ -165,7 +163,7 @@ def handle_content_block_stop(state: Dict[str, Any]) -> Dict[str, Any]:
Returns:
Updated state with finalized content block.
"""
content: List[ContentBlock] = state["content"]
content: list[ContentBlock] = state["content"]

current_tool_use = state["current_tool_use"]
text = state["text"]
@@ -223,7 +221,7 @@ def handle_message_stop(event: MessageStopEvent) -> StopReason:
return event["stopReason"]


def handle_redact_content(event: RedactContentEvent, messages: Messages, state: Dict[str, Any]) -> None:
def handle_redact_content(event: RedactContentEvent, messages: Messages, state: dict[str, Any]) -> None:
"""Handles redacting content from the input or output.

Args:
@@ -238,7 +236,7 @@ def handle_redact_content(event: RedactContentEvent, messages: Messages, state:
state["message"]["content"] = [{"text": event["redactAssistantContentMessage"]}]


def extract_usage_metrics(event: MetadataEvent) -> Tuple[Usage, Metrics]:
def extract_usage_metrics(event: MetadataEvent) -> tuple[Usage, Metrics]:
"""Extracts usage metrics from the metadata chunk.

Args:
@@ -255,25 +253,20 @@ def extract_usage_metrics(event: MetadataEvent) -> Tuple[Usage, Metrics]:

def process_stream(
chunks: Iterable[StreamEvent],
callback_handler: Any,
messages: Messages,
**kwargs: Any,
) -> Tuple[StopReason, Message, Usage, Metrics, Any]:
) -> Generator[dict[str, Any], None, None]:
"""Processes the response stream from the API, constructing the final message and extracting usage metrics.

Args:
chunks: The chunks of the response stream from the model.
callback_handler: Callback for processing events as they happen.
messages: The agent's messages.
**kwargs: Additional keyword arguments that will be passed to the callback handler.
And also returned in the request_state.

Returns:
The reason for stopping, the constructed message, the usage metrics, and the updated request state.
The reason for stopping, the constructed message, and the usage metrics.
"""
stop_reason: StopReason = "end_turn"

state: Dict[str, Any] = {
state: dict[str, Any] = {
"message": {"role": "assistant", "content": []},
"text": "",
"current_tool_use": {},
@@ -285,18 +278,16 @@ def process_stream(
usage: Usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0)
metrics: Metrics = Metrics(latencyMs=0)

kwargs.setdefault("request_state", {})

for chunk in chunks:
# Callback handler call here allows each event to be visible to the caller
callback_handler(event=chunk)
yield {"callback": {"event": chunk}}

if "messageStart" in chunk:
state["message"] = handle_message_start(chunk["messageStart"], state["message"])
elif "contentBlockStart" in chunk:
state["current_tool_use"] = handle_content_block_start(chunk["contentBlockStart"])
elif "contentBlockDelta" in chunk:
state = handle_content_block_delta(chunk["contentBlockDelta"], state, callback_handler, **kwargs)
state, callback_event = handle_content_block_delta(chunk["contentBlockDelta"], state)
yield callback_event
elif "contentBlockStop" in chunk:
state = handle_content_block_stop(state)
elif "messageStop" in chunk:
@@ -306,35 +297,30 @@ def process_stream(
elif "redactContent" in chunk:
handle_redact_content(chunk["redactContent"], messages, state)

return stop_reason, state["message"], usage, metrics, kwargs["request_state"]
yield {"stop": (stop_reason, state["message"], usage, metrics)}


def stream_messages(
model: Model,
system_prompt: Optional[str],
messages: Messages,
tool_config: Optional[ToolConfig],
callback_handler: Any,
**kwargs: Any,
) -> Tuple[StopReason, Message, Usage, Metrics, Any]:
) -> Generator[dict[str, Any], None, None]:
"""Streams messages to the model and processes the response.

Args:
model: Model provider.
system_prompt: The system prompt to send.
messages: List of messages to send.
tool_config: Configuration for the tools to use.
callback_handler: Callback for processing events as they happen.
**kwargs: Additional keyword arguments that will be passed to the callback handler.
And also returned in the request_state.

Returns:
The reason for stopping, the final message, the usage metrics, and updated request state.
The reason for stopping, the final message, and the usage metrics.
"""
logger.debug("model=<%s> | streaming messages", model)

messages = remove_blank_messages_content_text(messages)
tool_specs = [tool["toolSpec"] for tool in tool_config.get("tools", [])] or None if tool_config else None

chunks = model.converse(messages, tool_specs, system_prompt)
return process_stream(chunks, callback_handler, messages, **kwargs)
yield from process_stream(chunks, messages)
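Callers migrating from the old five-value return now drive the generator and manage request_state themselves. A hedged before/after sketch (kwargs and callback_handler are assumed to exist in the caller, as in event_loop_cycle above):

# Before (removed in this PR):
#   stop_reason, message, usage, metrics, request_state = stream_messages(
#       model, system_prompt, messages, tool_config, callback_handler, **kwargs)
# After:
kwargs.setdefault("request_state", {})
for event in stream_messages(model, system_prompt, messages, tool_config):
    if "callback" in event:
        callback_handler(**event["callback"])
    else:
        stop_reason, message, usage, metrics = event["stop"]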
12 changes: 6 additions & 6 deletions src/strands/models/anthropic.py
@@ -387,15 +387,15 @@ def structured_output(
prompt(Messages): The prompt messages to use for the agent.
callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
"""
callback_handler = callback_handler or PrintingCallbackHandler()
tool_spec = convert_pydantic_to_tool_spec(output_model)

response = self.converse(messages=prompt, tool_specs=[tool_spec])
# process the stream and get the tool use input
results = process_stream(
response, callback_handler=callback_handler or PrintingCallbackHandler(), messages=prompt
)

stop_reason, messages, _, _, _ = results
for event in process_stream(response, prompt):
if "callback" in event:
callback_handler(**event["callback"])
else:
stop_reason, messages, _, _ = event["stop"]

if stop_reason != "tool_use":
raise ValueError("No valid tool use or tool use input was found in the Anthropic response.")
12 changes: 6 additions & 6 deletions src/strands/models/bedrock.py
@@ -504,15 +504,15 @@ def structured_output(
prompt(Messages): The prompt messages to use for the agent.
callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
"""
callback_handler = callback_handler or PrintingCallbackHandler()
tool_spec = convert_pydantic_to_tool_spec(output_model)

response = self.converse(messages=prompt, tool_specs=[tool_spec])
# process the stream and get the tool use input
results = process_stream(
response, callback_handler=callback_handler or PrintingCallbackHandler(), messages=prompt
)

stop_reason, messages, _, _, _ = results
for event in process_stream(response, prompt):
if "callback" in event:
callback_handler(**event["callback"])
else:
stop_reason, messages, _, _ = event["stop"]

if stop_reason != "tool_use":
raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")
4 changes: 2 additions & 2 deletions tests-integ/test_model_anthropic.py
@@ -34,7 +34,7 @@ def tool_weather() -> str:

@pytest.fixture
def system_prompt():
return "You are an AI assistant that uses & instead of ."
return "You are an AI assistant."


@pytest.fixture
@@ -47,7 +47,7 @@ def test_agent(agent):
result = agent("What is the time and weather in New York?")
text = result.message["content"][0]["text"].lower()

assert all(string in text for string in ["12:00", "sunny", "&"])
assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")