diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 49769aabe..9580ea358 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -130,14 +130,19 @@ def event_loop_cycle( ) try: - stop_reason, message, usage, metrics, kwargs["request_state"] = stream_messages( - model, - system_prompt, - messages, - tool_config, - callback_handler, - **kwargs, - ) + # TODO: As part of the migration to async-iterator, we will continue moving callback_handler calls up the + # call stack. At this point, we converted all events that were previously passed to the handler in + # `stream_messages` into yielded events that now have the "callback" key. To maintain backwards + # compatibility, we need to combine the event with kwargs before passing to the handler. This we will + # revisit when migrating to strongly typed events. + for event in stream_messages(model, system_prompt, messages, tool_config): + if "callback" in event: + inputs = {**event["callback"], **(kwargs if "delta" in event["callback"] else {})} + callback_handler(**inputs) + else: + stop_reason, message, usage, metrics = event["stop"] + kwargs.setdefault("request_state", {}) + if model_invoke_span: tracer.end_model_invoke_span(model_invoke_span, message, usage) break # Success! Break out of retry loop @@ -334,7 +339,7 @@ def _handle_tool_execution( kwargs (Dict[str, Any]): Additional keyword arguments, including request state. 
Returns: - Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]: + Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]: - The stop reason, - The updated message, - The updated event loop metrics, diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 6e8a806fd..0e9d472bd 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -2,7 +2,7 @@ import json import logging -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Generator, Iterable, Optional from ..types.content import ContentBlock, Message, Messages from ..types.models import Model @@ -80,7 +80,7 @@ def handle_message_start(event: MessageStartEvent, message: Message) -> Message: return message -def handle_content_block_start(event: ContentBlockStartEvent) -> Dict[str, Any]: +def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]: """Handles the start of a content block by extracting tool usage information if any. Args: @@ -102,31 +102,31 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> Dict[str, Any]: def handle_content_block_delta( - event: ContentBlockDeltaEvent, state: Dict[str, Any], callback_handler: Any, **kwargs: Any -) -> Dict[str, Any]: + event: ContentBlockDeltaEvent, state: dict[str, Any] +) -> tuple[dict[str, Any], dict[str, Any]]: """Handles content block delta updates by appending text, tool input, or reasoning content to the state. Args: event: Delta event. state: The current state of message processing. - callback_handler: Callback for processing events as they happen. - **kwargs: Additional keyword arguments to pass to the callback handler. Returns: Updated state with appended text or tool input. 
""" delta_content = event["delta"] + callback_event = {} + if "toolUse" in delta_content: if "input" not in state["current_tool_use"]: state["current_tool_use"]["input"] = "" state["current_tool_use"]["input"] += delta_content["toolUse"]["input"] - callback_handler(delta=delta_content, current_tool_use=state["current_tool_use"], **kwargs) + callback_event["callback"] = {"delta": delta_content, "current_tool_use": state["current_tool_use"]} elif "text" in delta_content: state["text"] += delta_content["text"] - callback_handler(data=delta_content["text"], delta=delta_content, **kwargs) + callback_event["callback"] = {"data": delta_content["text"], "delta": delta_content} elif "reasoningContent" in delta_content: if "text" in delta_content["reasoningContent"]: @@ -134,29 +134,27 @@ def handle_content_block_delta( state["reasoningText"] = "" state["reasoningText"] += delta_content["reasoningContent"]["text"] - callback_handler( - reasoningText=delta_content["reasoningContent"]["text"], - delta=delta_content, - reasoning=True, - **kwargs, - ) + callback_event["callback"] = { + "reasoningText": delta_content["reasoningContent"]["text"], + "delta": delta_content, + "reasoning": True, + } elif "signature" in delta_content["reasoningContent"]: if "signature" not in state: state["signature"] = "" state["signature"] += delta_content["reasoningContent"]["signature"] - callback_handler( - reasoning_signature=delta_content["reasoningContent"]["signature"], - delta=delta_content, - reasoning=True, - **kwargs, - ) + callback_event["callback"] = { + "reasoning_signature": delta_content["reasoningContent"]["signature"], + "delta": delta_content, + "reasoning": True, + } - return state + return state, callback_event -def handle_content_block_stop(state: Dict[str, Any]) -> Dict[str, Any]: +def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: """Handles the end of a content block by finalizing tool usage, text content, or reasoning content. 
Args: @@ -165,7 +163,7 @@ def handle_content_block_stop(state: Dict[str, Any]) -> Dict[str, Any]: Returns: Updated state with finalized content block. """ - content: List[ContentBlock] = state["content"] + content: list[ContentBlock] = state["content"] current_tool_use = state["current_tool_use"] text = state["text"] @@ -223,7 +221,7 @@ def handle_message_stop(event: MessageStopEvent) -> StopReason: return event["stopReason"] -def handle_redact_content(event: RedactContentEvent, messages: Messages, state: Dict[str, Any]) -> None: +def handle_redact_content(event: RedactContentEvent, messages: Messages, state: dict[str, Any]) -> None: """Handles redacting content from the input or output. Args: @@ -238,7 +236,7 @@ def handle_redact_content(event: RedactContentEvent, messages: Messages, state: state["message"]["content"] = [{"text": event["redactAssistantContentMessage"]}] -def extract_usage_metrics(event: MetadataEvent) -> Tuple[Usage, Metrics]: +def extract_usage_metrics(event: MetadataEvent) -> tuple[Usage, Metrics]: """Extracts usage metrics from the metadata chunk. Args: @@ -255,25 +253,20 @@ def extract_usage_metrics(event: MetadataEvent) -> Tuple[Usage, Metrics]: def process_stream( chunks: Iterable[StreamEvent], - callback_handler: Any, messages: Messages, - **kwargs: Any, -) -> Tuple[StopReason, Message, Usage, Metrics, Any]: +) -> Generator[dict[str, Any], None, None]: """Processes the response stream from the API, constructing the final message and extracting usage metrics. Args: chunks: The chunks of the response stream from the model. - callback_handler: Callback for processing events as they happen. messages: The agents messages. - **kwargs: Additional keyword arguments that will be passed to the callback handler. - And also returned in the request_state. Returns: - The reason for stopping, the constructed message, the usage metrics, and the updated request state. + The reason for stopping, the constructed message, and the usage metrics. 
""" stop_reason: StopReason = "end_turn" - state: Dict[str, Any] = { + state: dict[str, Any] = { "message": {"role": "assistant", "content": []}, "text": "", "current_tool_use": {}, @@ -285,18 +278,16 @@ def process_stream( usage: Usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) metrics: Metrics = Metrics(latencyMs=0) - kwargs.setdefault("request_state", {}) - for chunk in chunks: - # Callback handler call here allows each event to be visible to the caller - callback_handler(event=chunk) + yield {"callback": {"event": chunk}} if "messageStart" in chunk: state["message"] = handle_message_start(chunk["messageStart"], state["message"]) elif "contentBlockStart" in chunk: state["current_tool_use"] = handle_content_block_start(chunk["contentBlockStart"]) elif "contentBlockDelta" in chunk: - state = handle_content_block_delta(chunk["contentBlockDelta"], state, callback_handler, **kwargs) + state, callback_event = handle_content_block_delta(chunk["contentBlockDelta"], state) + yield callback_event elif "contentBlockStop" in chunk: state = handle_content_block_stop(state) elif "messageStop" in chunk: @@ -306,7 +297,7 @@ def process_stream( elif "redactContent" in chunk: handle_redact_content(chunk["redactContent"], messages, state) - return stop_reason, state["message"], usage, metrics, kwargs["request_state"] + yield {"stop": (stop_reason, state["message"], usage, metrics)} def stream_messages( @@ -314,9 +305,7 @@ def stream_messages( system_prompt: Optional[str], messages: Messages, tool_config: Optional[ToolConfig], - callback_handler: Any, - **kwargs: Any, -) -> Tuple[StopReason, Message, Usage, Metrics, Any]: +) -> Generator[dict[str, Any], None, None]: """Streams messages to the model and processes the response. Args: @@ -324,12 +313,9 @@ def stream_messages( system_prompt: The system prompt to send. messages: List of messages to send. tool_config: Configuration for the tools to use. - callback_handler: Callback for processing events as they happen. 
- **kwargs: Additional keyword arguments that will be passed to the callback handler. - And also returned in the request_state. Returns: - The reason for stopping, the final message, the usage metrics, and updated request state. + The reason for stopping, the final message, and the usage metrics """ logger.debug("model=<%s> | streaming messages", model) @@ -337,4 +323,4 @@ def stream_messages( tool_specs = [tool["toolSpec"] for tool in tool_config.get("tools", [])] or None if tool_config else None chunks = model.converse(messages, tool_specs, system_prompt) - return process_stream(chunks, callback_handler, messages, **kwargs) + yield from process_stream(chunks, messages) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index ab427e53d..51089d47e 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -387,15 +387,15 @@ def structured_output( prompt(Messages): The prompt messages to use for the agent. callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. 
""" + callback_handler = callback_handler or PrintingCallbackHandler() tool_spec = convert_pydantic_to_tool_spec(output_model) response = self.converse(messages=prompt, tool_specs=[tool_spec]) - # process the stream and get the tool use input - results = process_stream( - response, callback_handler=callback_handler or PrintingCallbackHandler(), messages=prompt - ) - - stop_reason, messages, _, _, _ = results + for event in process_stream(response, prompt): + if "callback" in event: + callback_handler(**event["callback"]) + else: + stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": raise ValueError("No valid tool use or tool use input was found in the Anthropic response.") diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 3de41198a..a5ffb539d 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -504,15 +504,15 @@ def structured_output( prompt(Messages): The prompt messages to use for the agent. callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. 
""" + callback_handler = callback_handler or PrintingCallbackHandler() tool_spec = convert_pydantic_to_tool_spec(output_model) response = self.converse(messages=prompt, tool_specs=[tool_spec]) - # process the stream and get the tool use input - results = process_stream( - response, callback_handler=callback_handler or PrintingCallbackHandler(), messages=prompt - ) - - stop_reason, messages, _, _, _ = results + for event in process_stream(response, prompt): + if "callback" in event: + callback_handler(**event["callback"]) + else: + stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") diff --git a/tests-integ/test_model_anthropic.py b/tests-integ/test_model_anthropic.py index 95bfceb56..50033f8f7 100644 --- a/tests-integ/test_model_anthropic.py +++ b/tests-integ/test_model_anthropic.py @@ -34,7 +34,7 @@ def tool_weather() -> str: @pytest.fixture def system_prompt(): - return "You are an AI assistant that uses & instead of ." + return "You are an AI assistant." 
@pytest.fixture @@ -47,7 +47,7 @@ def test_agent(agent): result = agent("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() - assert all(string in text for string in ["12:00", "sunny", "&"]) + assert all(string in text for string in ["12:00", "sunny"]) @pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index efdf7af8c..11f145033 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -701,6 +701,106 @@ def test_event_loop_cycle_with_parent_span( ) +def test_event_loop_cycle_callback( + model, + model_id, + system_prompt, + messages, + tool_config, + callback_handler, + tool_handler, + tool_execution_handler, +): + model.converse.return_value = [ + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "value"}}}, + {"contentBlockStop": {}}, + ] + + strands.event_loop.event_loop.event_loop_cycle( + model=model, + model_id=model_id, + system_prompt=system_prompt, + messages=messages, + tool_config=tool_config, + callback_handler=callback_handler, + tool_handler=tool_handler, + tool_execution_handler=tool_execution_handler, + ) + + callback_handler.assert_has_calls( + [ + call(start=True), + call(start_event_loop=True), + call(event={"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}}), + call(event={"contentBlockDelta": {"delta": {"toolUse": 
{"input": '{"value"}'}}}}), + call( + delta={"toolUse": {"input": '{"value"}'}}, + current_tool_use={"toolUseId": "123", "name": "test", "input": {}}, + model_id="m1", + event_loop_cycle_id=unittest.mock.ANY, + request_state={}, + event_loop_cycle_trace=unittest.mock.ANY, + event_loop_cycle_span=None, + ), + call(event={"contentBlockStop": {}}), + call(event={"contentBlockStart": {"start": {}}}), + call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}), + call( + reasoningText="value", + delta={"reasoningContent": {"text": "value"}}, + reasoning=True, + model_id="m1", + event_loop_cycle_id=unittest.mock.ANY, + request_state={}, + event_loop_cycle_trace=unittest.mock.ANY, + event_loop_cycle_span=None, + ), + call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}), + call( + reasoning_signature="value", + delta={"reasoningContent": {"signature": "value"}}, + reasoning=True, + model_id="m1", + event_loop_cycle_id=unittest.mock.ANY, + request_state={}, + event_loop_cycle_trace=unittest.mock.ANY, + event_loop_cycle_span=None, + ), + call(event={"contentBlockStop": {}}), + call(event={"contentBlockStart": {"start": {}}}), + call(event={"contentBlockDelta": {"delta": {"text": "value"}}}), + call( + data="value", + delta={"text": "value"}, + model_id="m1", + event_loop_cycle_id=unittest.mock.ANY, + request_state={}, + event_loop_cycle_trace=unittest.mock.ANY, + event_loop_cycle_span=None, + ), + call(event={"contentBlockStop": {}}), + call( + message={ + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "123", "name": "test", "input": {}}}, + {"reasoningContent": {"reasoningText": {"text": "value", "signature": "value"}}}, + {"text": "value"}, + ], + }, + ), + ], + ) + + def test_request_state_initialization(): # Call without providing request_state tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( diff --git 
a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index c24e7e48a..e91f49867 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -3,6 +3,7 @@ import pytest import strands +import strands.event_loop from strands.types.streaming import ( ContentBlockDeltaEvent, ContentBlockStartEvent, @@ -17,13 +18,6 @@ def moto_autouse(moto_env, moto_mock_aws): _ = moto_mock_aws -@pytest.fixture -def agent(): - mock = unittest.mock.Mock() - - return mock - - @pytest.mark.parametrize( ("messages", "exp_result"), [ @@ -81,7 +75,7 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) @pytest.mark.parametrize( - ("event", "state", "exp_updated_state", "exp_handler_args"), + ("event", "state", "exp_updated_state", "callback_args"), [ # Tool Use - Existing input ( @@ -148,21 +142,13 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) ), ], ) -def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_updated_state, exp_handler_args): - if exp_handler_args: - exp_handler_args.update({"delta": event["delta"], "extra_arg": 1}) +def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_updated_state, callback_args): + exp_callback_event = {"callback": {**callback_args, "delta": event["delta"]}} if callback_args else {} - tru_handler_args = {} - - def callback_handler(**kwargs): - tru_handler_args.update(kwargs) - - tru_updated_state = strands.event_loop.streaming.handle_content_block_delta( - event, state, callback_handler, extra_arg=1 - ) + tru_updated_state, tru_callback_event = strands.event_loop.streaming.handle_content_block_delta(event, state) assert tru_updated_state == exp_updated_state - assert tru_handler_args == exp_handler_args + assert tru_callback_event == exp_callback_event @pytest.mark.parametrize( @@ -275,8 +261,9 @@ def test_extract_usage_metrics(): @pytest.mark.parametrize( - 
("response", "exp_stop_reason", "exp_message", "exp_usage", "exp_metrics", "exp_request_state", "exp_messages"), + ("response", "exp_events"), [ + # Standard Message ( [ {"messageStart": {"role": "assistant"}}, @@ -297,28 +284,127 @@ def test_extract_usage_metrics(): } }, ], - "tool_use", - { - "role": "assistant", - "content": [{"toolUse": {"toolUseId": "123", "name": "test", "input": {"key": "value"}}}], - }, - {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, - {"latencyMs": 1}, - {"calls": 1}, - [{"role": "user", "content": [{"text": "Some input!"}]}], + [ + { + "callback": { + "event": { + "messageStart": { + "role": "assistant", + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStart": { + "start": { + "toolUse": { + "name": "test", + "toolUseId": "123", + }, + }, + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockDelta": { + "delta": { + "toolUse": { + "input": '{"key": "value"}', + }, + }, + }, + }, + }, + }, + { + "callback": { + "current_tool_use": { + "input": { + "key": "value", + }, + "name": "test", + "toolUseId": "123", + }, + "delta": { + "toolUse": { + "input": '{"key": "value"}', + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStop": {}, + }, + }, + }, + { + "callback": { + "event": { + "messageStop": { + "stopReason": "tool_use", + }, + }, + }, + }, + { + "callback": { + "event": { + "metadata": { + "metrics": { + "latencyMs": 1, + }, + "usage": { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, + }, + }, + }, + }, + }, + { + "stop": ( + "tool_use", + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "123", "name": "test", "input": {"key": "value"}}}], + }, + {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, + {"latencyMs": 1}, + ) + }, + ], ), + # Empty Message ( [{}], - "end_turn", - { - "role": "assistant", - "content": [], - }, - {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, - {"latencyMs": 0}, - {}, - [{"role": "user", "content": 
[{"text": "Some input!"}]}], + [ + { + "callback": { + "event": {}, + }, + }, + { + "stop": ( + "end_turn", + { + "role": "assistant", + "content": [], + }, + {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + {"latencyMs": 0}, + ), + }, + ], ), + # Redacted Message ( [ {"messageStart": {"role": "assistant"}}, @@ -345,77 +431,161 @@ def test_extract_usage_metrics(): } }, ], - "guardrail_intervened", - { - "role": "assistant", - "content": [{"text": "REDACTED."}], - }, - {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, - {"latencyMs": 1}, - {"calls": 1}, - [{"role": "user", "content": [{"text": "REDACTED"}]}], + [ + { + "callback": { + "event": { + "messageStart": { + "role": "assistant", + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStart": { + "start": {}, + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockDelta": { + "delta": { + "text": "Hello!", + }, + }, + }, + }, + }, + { + "callback": { + "data": "Hello!", + "delta": { + "text": "Hello!", + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStop": {}, + }, + }, + }, + { + "callback": { + "event": { + "messageStop": { + "stopReason": "guardrail_intervened", + }, + }, + }, + }, + { + "callback": { + "event": { + "redactContent": { + "redactAssistantContentMessage": "REDACTED.", + "redactUserContentMessage": "REDACTED", + }, + }, + }, + }, + { + "callback": { + "event": { + "metadata": { + "metrics": { + "latencyMs": 1, + }, + "usage": { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, + }, + }, + }, + }, + }, + { + "stop": ( + "guardrail_intervened", + { + "role": "assistant", + "content": [{"text": "REDACTED."}], + }, + {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, + {"latencyMs": 1}, + ), + }, + ], ), ], ) -def test_process_stream( - response, exp_stop_reason, exp_message, exp_usage, exp_metrics, exp_request_state, exp_messages -): - def callback_handler(**kwargs): - if "request_state" in kwargs: - 
kwargs["request_state"].setdefault("calls", 0) - kwargs["request_state"]["calls"] += 1 - - tru_messages = [{"role": "user", "content": [{"text": "Some input!"}]}] - - tru_stop_reason, tru_message, tru_usage, tru_metrics, tru_request_state = ( - strands.event_loop.streaming.process_stream(response, callback_handler, tru_messages) - ) - - assert tru_stop_reason == exp_stop_reason - assert tru_message == exp_message - assert tru_usage == exp_usage - assert tru_metrics == exp_metrics - assert tru_request_state == exp_request_state - assert tru_messages == exp_messages +def test_process_stream(response, exp_events): + messages = [{"role": "user", "content": [{"text": "Some input!"}]}] + stream = strands.event_loop.streaming.process_stream(response, messages) + tru_events = list(stream) + assert tru_events == exp_events -def test_stream_messages(agent): - def callback_handler(**kwargs): - if "request_state" in kwargs: - kwargs["request_state"].setdefault("calls", 0) - kwargs["request_state"]["calls"] += 1 +def test_stream_messages(): mock_model = unittest.mock.MagicMock() mock_model.converse.return_value = [ {"contentBlockDelta": {"delta": {"text": "test"}}}, {"contentBlockStop": {}}, ] - tru_stop_reason, tru_message, tru_usage, tru_metrics, tru_request_state = ( - strands.event_loop.streaming.stream_messages( - mock_model, - model_id="test_model", - system_prompt="test prompt", - messages=[{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]}], - tool_config=None, - callback_handler=callback_handler, - agent=agent, - ) + stream = strands.event_loop.streaming.stream_messages( + mock_model, + system_prompt="test prompt", + messages=[{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]}], + tool_config=None, ) - exp_stop_reason = "end_turn" - exp_message = {"role": "assistant", "content": [{"text": "test"}]} - exp_usage = {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0} - exp_metrics = {"latencyMs": 0} - exp_request_state = {"calls": 1} - - 
assert ( - tru_stop_reason == exp_stop_reason - and tru_message == exp_message - and tru_usage == exp_usage - and tru_metrics == exp_metrics - and tru_request_state == exp_request_state - ) + tru_events = list(stream) + exp_events = [ + { + "callback": { + "event": { + "contentBlockDelta": { + "delta": { + "text": "test", + }, + }, + }, + }, + }, + { + "callback": { + "data": "test", + "delta": { + "text": "test", + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStop": {}, + }, + }, + }, + { + "stop": ( + "end_turn", + {"role": "assistant", "content": [{"text": "test"}]}, + {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + {"latencyMs": 0}, + ) + }, + ] + assert tru_events == exp_events mock_model.converse.assert_called_with( [{"role": "assistant", "content": [{"text": "a"}, {"text": "[blank text]"}]}],