diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index 82bbb1eaf..95eb23078 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -177,7 +177,8 @@ async def stream(
         async for event in response:
             _ = event
 
-        yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
+        if event.usage:
+            yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
 
         logger.debug("finished streaming response from model")
 
diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py
index 6374590b9..9a2a87f6a 100644
--- a/src/strands/models/openai.py
+++ b/src/strands/models/openai.py
@@ -394,7 +394,8 @@ async def stream(
         async for event in response:
             _ = event
 
-        yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
+        if event.usage:
+            yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
 
         logger.debug("finished streaming response from model")
 
diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py
index bddd44abb..44b6df63b 100644
--- a/tests/strands/models/test_litellm.py
+++ b/tests/strands/models/test_litellm.py
@@ -197,6 +197,42 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator,
     litellm_acompletion.assert_called_once_with(**expected_request)
 
 
+@pytest.mark.asyncio
+async def test_stream_empty(litellm_acompletion, api_key, model_id, model, agenerator, alist):
+    mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None)
+
+    mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)])
+    mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)])
+    mock_event_3 = unittest.mock.Mock()
+    mock_event_4 = unittest.mock.Mock(usage=None)
+
+    litellm_acompletion.side_effect = unittest.mock.AsyncMock(
+        return_value=agenerator([mock_event_1, mock_event_2, mock_event_3, mock_event_4])
+    )
+
+    messages = [{"role": "user", "content": []}]
+    response = model.stream(messages)
+
+    tru_events = await alist(response)
+    exp_events = [
+        {"messageStart": {"role": "assistant"}},
+        {"contentBlockStart": {"start": {}}},
+        {"contentBlockStop": {}},
+        {"messageStop": {"stopReason": "end_turn"}},
+    ]
+
+    assert len(tru_events) == len(exp_events)
+    expected_request = {
+        "api_key": api_key,
+        "model": model_id,
+        "messages": [],
+        "stream": True,
+        "stream_options": {"include_usage": True},
+        "tools": [],
+    }
+    litellm_acompletion.assert_called_once_with(**expected_request)
+
+
 @pytest.mark.asyncio
 async def test_structured_output(litellm_acompletion, model, test_output_model_cls, alist):
     messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py
index 0a095ab9d..a7c97701c 100644
--- a/tests/strands/models/test_openai.py
+++ b/tests/strands/models/test_openai.py
@@ -382,7 +382,7 @@ def test_format_chunk_unknown_type(model):
 
 
 @pytest.mark.asyncio
-async def test_stream(openai_client, model, agenerator, alist):
+async def test_stream(openai_client, model_id, model, agenerator, alist):
     mock_tool_call_1_part_1 = unittest.mock.Mock(index=0)
     mock_tool_call_2_part_1 = unittest.mock.Mock(index=1)
     mock_delta_1 = unittest.mock.Mock(
@@ -465,7 +465,7 @@ async def test_stream(openai_client, model, agenerator, alist):
     # Verify that format_request was called with the correct arguments
     expected_request = {
         "max_tokens": 1,
-        "model": "m1",
+        "model": model_id,
         "messages": [{"role": "user", "content": [{"text": "calculate 2+2", "type": "text"}]}],
         "stream": True,
         "stream_options": {"include_usage": True},
@@ -475,14 +475,13 @@ async def test_stream(openai_client, model, agenerator, alist):
 
 
 @pytest.mark.asyncio
-async def test_stream_empty(openai_client, model, agenerator, alist):
+async def test_stream_empty(openai_client, model_id, model, agenerator, alist):
     mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None)
-    mock_usage = unittest.mock.Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0)
 
     mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)])
     mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)])
     mock_event_3 = unittest.mock.Mock()
-    mock_event_4 = unittest.mock.Mock(usage=mock_usage)
+    mock_event_4 = unittest.mock.Mock(usage=None)
 
     openai_client.chat.completions.create = unittest.mock.AsyncMock(
         return_value=agenerator([mock_event_1, mock_event_2, mock_event_3, mock_event_4]),
@@ -497,13 +496,12 @@ async def test_stream_empty(openai_client, model, agenerator, alist):
         {"contentBlockStart": {"start": {}}},
         {"contentBlockStop": {}},
         {"messageStop": {"stopReason": "end_turn"}},
-        {"metadata": {"usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, "metrics": {"latencyMs": 0}}},
     ]
 
     assert len(tru_events) == len(exp_events)
     expected_request = {
         "max_tokens": 1,
-        "model": "m1",
+        "model": model_id,
         "messages": [],
         "stream": True,
         "stream_options": {"include_usage": True},