Skip to content

Commit 7ac4574

Browse files
committed
fix: fix methods signature to pass updated tests
1 parent b0c192c commit 7ac4574

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

src/strands/models/writer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import json
88
import logging
99
import mimetypes
10-
from typing import Any, Dict, Generator, Iterable, List, Optional, Type, TypedDict, TypeVar, Union, cast
10+
from typing import Any, AsyncGenerator, Dict, List, Optional, Type, TypedDict, TypeVar, Union, cast
1111

1212
import writerai
1313
from pydantic import BaseModel
@@ -349,7 +349,7 @@ def format_chunk(self, event: Any) -> StreamEvent:
349349
raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")
350350

351351
@override
352-
def stream(self, request: Any) -> Iterable[Any]:
352+
async def stream(self, request: Any) -> AsyncGenerator[Any, None]:
353353
"""Send the request to the model and get a streaming response.
354354
355355
Args:
@@ -405,9 +405,9 @@ def stream(self, request: Any) -> Iterable[Any]:
405405
yield {"chunk_type": "metadata", "data": chunk.usage}
406406

407407
@override
408-
def structured_output(
408+
async def structured_output(
409409
self, output_model: Type[T], prompt: Messages
410-
) -> Generator[dict[str, Union[T, Any]], None, None]:
410+
) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
411411
"""Get structured output from the model.
412412
413413
Args:

tests/strands/models/test_writer.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,8 @@ def test_format_request_with_unsupported_type(model, content, content_type):
264264
model.format_request(messages)
265265

266266

267-
def test_stream(writer_client, model, model_id):
267+
@pytest.mark.asyncio
268+
async def test_stream(writer_client, model, model_id):
268269
mock_tool_call_1_part_1 = unittest.mock.Mock(index=0)
269270
mock_tool_call_2_part_1 = unittest.mock.Mock(index=1)
270271
mock_delta_1 = unittest.mock.Mock(
@@ -292,7 +293,7 @@ def test_stream(writer_client, model, model_id):
292293
}
293294
response = model.stream(request)
294295

295-
events = list(response)
296+
events = [event async for event in response]
296297
exp_events = [
297298
{"chunk_type": "message_start"},
298299
{"chunk_type": "content_block_start", "data_type": "text"},
@@ -313,7 +314,8 @@ def test_stream(writer_client, model, model_id):
313314
writer_client.chat.chat(**request)
314315

315316

316-
def test_stream_empty(writer_client, model, model_id):
317+
@pytest.mark.asyncio
318+
async def test_stream_empty(writer_client, model, model_id):
317319
mock_delta = unittest.mock.Mock(content=None, tool_calls=None)
318320
mock_usage = unittest.mock.Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0)
319321

@@ -327,7 +329,7 @@ def test_stream_empty(writer_client, model, model_id):
327329
request = {"model": model_id, "messages": [{"role": "user", "content": []}]}
328330
response = model.stream(request)
329331

330-
events = list(response)
332+
events = [event async for event in response]
331333
exp_events = [
332334
{"chunk_type": "message_start"},
333335
{"chunk_type": "content_block_start", "data_type": "text"},
@@ -340,7 +342,8 @@ def test_stream_empty(writer_client, model, model_id):
340342
writer_client.chat.chat.assert_called_once_with(**request)
341343

342344

343-
def test_stream_with_empty_choices(writer_client, model, model_id):
345+
@pytest.mark.asyncio
346+
async def test_stream_with_empty_choices(writer_client, model, model_id):
344347
mock_delta = unittest.mock.Mock(content="content", tool_calls=None)
345348
mock_usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
346349

@@ -355,7 +358,7 @@ def test_stream_with_empty_choices(writer_client, model, model_id):
355358
request = {"model": model_id, "messages": [{"role": "user", "content": ["test"]}]}
356359
response = model.stream(request)
357360

358-
events = list(response)
361+
events = [event async for event in response]
359362
exp_events = [
360363
{"chunk_type": "message_start"},
361364
{"chunk_type": "content_block_start", "data_type": "text"},

0 commit comments

Comments
 (0)