diff --git a/pyproject.toml b/pyproject.toml
index 1495254e3..a391180ef 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -82,6 +82,10 @@ openai = [
 otel = [
     "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0",
 ]
+writer = [
+    "writer-sdk>=2.2.0,<3.0.0"
+]
+
 a2a = [
     "a2a-sdk>=0.2.6",
     "uvicorn>=0.34.2",
@@ -95,7 +99,7 @@ a2a = [
 source = "vcs"
 
 [tool.hatch.envs.hatch-static-analysis]
-features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel","mistral"]
+features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer"]
 dependencies = [
     "mypy>=1.15.0,<2.0.0",
     "ruff>=0.11.6,<0.12.0",
@@ -119,7 +123,7 @@ lint-fix = [
 ]
 
 [tool.hatch.envs.hatch-test]
-features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel","mistral"]
+features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer"]
 extra-dependencies = [
     "moto>=5.1.0,<6.0.0",
     "pytest>=8.0.0,<9.0.0",
@@ -135,7 +139,7 @@ extra-args = [
 
 [tool.hatch.envs.dev]
 dev-mode = true
-features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel","mistral"]
+features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer"]
 
 [tool.hatch.envs.a2a]
 dev-mode = true
diff --git a/src/strands/models/writer.py b/src/strands/models/writer.py
new file mode 100644
index 000000000..0a5ca4a95
--- /dev/null
+++ b/src/strands/models/writer.py
@@ -0,0 +1,431 @@
+"""Writer model provider.
+
+- Docs: https://dev.writer.com/home/introduction
+"""
+
+import base64
+import json
+import logging
+import mimetypes
+from typing import Any, AsyncGenerator, Dict, List, Optional, Type, TypedDict, TypeVar, Union, cast
+
+import writerai
+from pydantic import BaseModel
+from typing_extensions import Unpack, override
+
+from ..types.content import ContentBlock, Messages
+from ..types.exceptions import ModelThrottledException
+from ..types.models import Model
+from ..types.streaming import StreamEvent
+from ..types.tools import ToolResult, ToolSpec, ToolUse
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class WriterModel(Model):
+    """Writer API model provider implementation."""
+
+    class WriterConfig(TypedDict, total=False):
+        """Configuration options for the Writer API.
+
+        Attributes:
+            model_id: Model name to use (e.g., palmyra-x5, palmyra-x4).
+            max_tokens: Maximum number of tokens to generate.
+            stop: Default stop sequences.
+            stream_options: Additional options for streaming.
+            temperature: Sampling temperature to use.
+            top_p: Nucleus sampling threshold.
+        """
+
+        model_id: str
+        max_tokens: Optional[int]
+        stop: Optional[Union[str, List[str]]]
+        stream_options: Dict[str, Any]
+        temperature: Optional[float]
+        top_p: Optional[float]
+
+    def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[WriterConfig]):
+        """Initialize provider instance.
+
+        Args:
+            client_args: Arguments for the Writer client (e.g., api_key, base_url, timeout).
+            **model_config: Configuration options for the Writer model.
+        """
+        self.config = WriterModel.WriterConfig(**model_config)
+
+        logger.debug("config=<%s> | initializing", self.config)
+
+        client_args = client_args or {}
+        self.client = writerai.AsyncClient(**client_args)
+
+    @override
+    def update_config(self, **model_config: Unpack[WriterConfig]) -> None:  # type: ignore[override]
+        """Update the Writer model configuration with the provided arguments.
+
+        Args:
+            **model_config: Configuration overrides.
+ """ + self.config.update(model_config) + + @override + def get_config(self) -> WriterConfig: + """Get the Writer model configuration. + + Returns: + The Writer model configuration. + """ + return self.config + + def _format_request_message_contents_vision(self, contents: list[ContentBlock]) -> list[dict[str, Any]]: + def _format_content_vision(content: ContentBlock) -> dict[str, Any]: + """Format a Writer content block for Palmyra V5 request. + + - NOTE: "reasoningContent", "document" and "video" are not supported currently. + + Args: + content: Message content. + + Returns: + Writer formatted content block for models, which support vision content format. + + Raises: + TypeError: If the content block type cannot be converted to a Writer-compatible format. + """ + if "text" in content: + return {"text": content["text"], "type": "text"} + + if "image" in content: + mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") + image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") + + return { + "image_url": { + "url": f"data:{mime_type};base64,{image_data}", + }, + "type": "image_url", + } + + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") + + return [ + _format_content_vision(content) + for content in contents + if not any(block_type in content for block_type in ["toolResult", "toolUse"]) + ] + + def _format_request_message_contents(self, contents: list[ContentBlock]) -> str: + def _format_content(content: ContentBlock) -> str: + """Format a Writer content block for Palmyra models (except V5) request. + + - NOTE: "reasoningContent", "document", "video" and "image" are not supported currently. + + Args: + content: Message content. + + Returns: + Writer formatted content block. + + Raises: + TypeError: If the content block type cannot be converted to a Writer-compatible format. + """ + if "text" in content: + return content["text"] + + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") + + content_blocks = list( + filter( + lambda content: content.get("text") + and not any(block_type in content for block_type in ["toolResult", "toolUse"]), + contents, + ) + ) + + if len(content_blocks) > 1: + raise ValueError( + f"Model with name {self.get_config().get('model_id', 'N/A')} doesn't support multiple contents" + ) + elif len(content_blocks) == 1: + return _format_content(content_blocks[0]) + else: + return "" + + def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]: + """Format a Writer tool call. + + Args: + tool_use: Tool use requested by the model. + + Returns: + Writer formatted tool call. + """ + return { + "function": { + "arguments": json.dumps(tool_use["input"]), + "name": tool_use["name"], + }, + "id": tool_use["toolUseId"], + "type": "function", + } + + def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]: + """Format a Writer tool message. + + Args: + tool_result: Tool result collected from a tool execution. + + Returns: + Writer formatted tool message. 
+ """ + contents = cast( + list[ContentBlock], + [ + {"text": json.dumps(content["json"])} if "json" in content else content + for content in tool_result["content"] + ], + ) + + if self.get_config().get("model_id", "") == "palmyra-x5": + formatted_contents = self._format_request_message_contents_vision(contents) + else: + formatted_contents = self._format_request_message_contents(contents) # type: ignore [assignment] + + return { + "role": "tool", + "tool_call_id": tool_result["toolUseId"], + "content": formatted_contents, + } + + def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format a Writer compatible messages array. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model. + + Returns: + Writer compatible messages array. + """ + formatted_messages: list[dict[str, Any]] + formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else [] + + for message in messages: + contents = message["content"] + + # Only palmyra V5 support multiple content. Other models support only '{"content": "text_content"}' + if self.get_config().get("model_id", "") == "palmyra-x5": + formatted_contents: str | list[dict[str, Any]] = self._format_request_message_contents_vision(contents) + else: + formatted_contents = self._format_request_message_contents(contents) + + formatted_tool_calls = [ + self._format_request_message_tool_call(content["toolUse"]) + for content in contents + if "toolUse" in content + ] + formatted_tool_messages = [ + self._format_request_tool_message(content["toolResult"]) + for content in contents + if "toolResult" in content + ] + + formatted_message = { + "role": message["role"], + "content": formatted_contents if len(formatted_contents) > 0 else "", + **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}), + } + formatted_messages.append(formatted_message) + formatted_messages.extend(formatted_tool_messages) + + return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + + @override + def format_request( + self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + ) -> Any: + """Format a streaming request to the underlying model. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + + Returns: + The formatted request. + """ + request = { + **{k: v for k, v in self.config.items()}, + "messages": self._format_request_messages(messages, system_prompt), + "stream": True, + } + try: + request["model"] = request.pop( + "model_id" + ) # To be consisted with other models WriterConfig use 'model_id' arg, but Writer API wait for 'model' arg + except KeyError as e: + raise KeyError("Please specify a model ID. Use 'model_id' keyword argument.") from e + + # Writer don't support empty tools attribute + if tool_specs: + request["tools"] = [ + { + "type": "function", + "function": { + "name": tool_spec["name"], + "description": tool_spec["description"], + "parameters": tool_spec["inputSchema"]["json"], + }, + } + for tool_spec in tool_specs + ] + + return request + + @override + def format_chunk(self, event: Any) -> StreamEvent: + """Format the model response events into standardized message chunks. + + Args: + event: A response event from the model. 
+
+        Returns:
+            The formatted chunk.
+        """
+        match event.get("chunk_type", ""):
+            case "message_start":
+                return {"messageStart": {"role": "assistant"}}
+
+            case "content_block_start":
+                if event["data_type"] == "text":
+                    return {"contentBlockStart": {"start": {}}}
+
+                return {
+                    "contentBlockStart": {
+                        "start": {
+                            "toolUse": {
+                                "name": event["data"].function.name,
+                                "toolUseId": event["data"].id,
+                            }
+                        }
+                    }
+                }
+
+            case "content_block_delta":
+                if event["data_type"] == "text":
+                    return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
+
+                return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments}}}}
+
+            case "content_block_stop":
+                return {"contentBlockStop": {}}
+
+            case "message_stop":
+                match event["data"]:
+                    case "tool_calls":
+                        return {"messageStop": {"stopReason": "tool_use"}}
+                    case "length":
+                        return {"messageStop": {"stopReason": "max_tokens"}}
+                    case _:
+                        return {"messageStop": {"stopReason": "end_turn"}}
+
+            case "metadata":
+                # If the 'stream_options' param is unset, no usage metadata is provided,
+                # so default the expected fields to zero to avoid errors.
+                return {
+                    "metadata": {
+                        "usage": {
+                            "inputTokens": event["data"].prompt_tokens if event["data"] else 0,
+                            "outputTokens": event["data"].completion_tokens if event["data"] else 0,
+                            "totalTokens": event["data"].total_tokens if event["data"] else 0,
+                        },
+                        "metrics": {
+                            "latencyMs": 0,  # Palmyra models don't provide latency metadata
+                        },
+                    },
+                }
+
+            case _:
+                raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type")
+
+    @override
+    async def stream(self, request: Any) -> AsyncGenerator[Any, None]:
+        """Send the request to the model and get a streaming response.
+
+        Args:
+            request: The formatted request to send to the model.
+
+        Returns:
+            The model's response.
+
+        Raises:
+            ModelThrottledException: When the model service is throttling requests from the client.
+        """
+        try:
+            response = await self.client.chat.chat(**request)
+        except writerai.RateLimitError as e:
+            raise ModelThrottledException(str(e)) from e
+
+        yield {"chunk_type": "message_start"}
+        yield {"chunk_type": "content_block_start", "data_type": "text"}
+
+        tool_calls: dict[int, list[Any]] = {}
+
+        async for chunk in response:
+            if not getattr(chunk, "choices", None):
+                continue
+            choice = chunk.choices[0]
+
+            if choice.delta.content:
+                yield {"chunk_type": "content_block_delta", "data_type": "text", "data": choice.delta.content}
+
+            for tool_call in choice.delta.tool_calls or []:
+                tool_calls.setdefault(tool_call.index, []).append(tool_call)
+
+            if choice.finish_reason:
+                break
+
+        yield {"chunk_type": "content_block_stop", "data_type": "text"}
+
+        for tool_deltas in tool_calls.values():
+            tool_start, tool_deltas = tool_deltas[0], tool_deltas[1:]
+            yield {"chunk_type": "content_block_start", "data_type": "tool", "data": tool_start}
+
+            for tool_delta in tool_deltas:
+                yield {"chunk_type": "content_block_delta", "data_type": "tool", "data": tool_delta}
+
+            yield {"chunk_type": "content_block_stop", "data_type": "tool"}
+
+        yield {"chunk_type": "message_stop", "data": choice.finish_reason}
+
+        # Iterate to the end of the stream so the final chunk, which carries usage metadata, is available.
+        async for chunk in response:
+            _ = chunk
+
+        yield {"chunk_type": "metadata", "data": chunk.usage}
+
+    @override
+    async def structured_output(
+        self, output_model: Type[T], prompt: Messages
+    ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
+        """Get structured output from the model.
+
+        Args:
+            output_model: The output model to use for the agent.
+            prompt: The prompt messages to use for the agent.
+        """
+        formatted_request = self.format_request(messages=prompt)
+        formatted_request["response_format"] = {
+            "type": "json_schema",
+            "json_schema": {"schema": output_model.model_json_schema()},
+        }
+        formatted_request["stream"] = False
+        formatted_request.pop("stream_options", None)
+
+        response = await self.client.chat.chat(**formatted_request)
+
+        try:
+            content = response.choices[0].message.content.strip()
+            yield {"output": output_model.model_validate_json(content)}
+        except Exception as e:
+            raise ValueError(f"Failed to parse or load content into model: {e}") from e
diff --git a/tests-integ/test_model_writer.py b/tests-integ/test_model_writer.py
new file mode 100644
index 000000000..3469d64ef
--- /dev/null
+++ b/tests-integ/test_model_writer.py
@@ -0,0 +1,97 @@
+import os
+
+import pytest
+from pydantic import BaseModel
+
+import strands
+from strands import Agent
+from strands.models.writer import WriterModel
+
+
+@pytest.fixture
+def model():
+    return WriterModel(
+        model_id="palmyra-x4",
+        client_args={"api_key": os.getenv("WRITER_API_KEY", "")},
+        stream_options={"include_usage": True},
+    )
+
+
+@pytest.fixture
+def system_prompt():
+    return "You are a smart assistant, that uses @ instead of all punctuation marks"
+
+
+@pytest.fixture
+def tools():
+    @strands.tool
+    def tool_time() -> str:
+        return "12:00"
+
+    @strands.tool
+    def tool_weather() -> str:
+        return "sunny"
+
+    return [tool_time, tool_weather]
+
+
+@pytest.fixture
+def agent(model, tools, system_prompt):
+    return Agent(model=model, tools=tools, system_prompt=system_prompt, load_tools_from_directory=False)
+
+
+@pytest.mark.skipif("WRITER_API_KEY" not in os.environ, reason="WRITER_API_KEY environment variable missing")
+def test_agent(agent):
+    result = agent("What is the time and weather in New York?")
+    text = result.message["content"][0]["text"].lower()
+
+    assert all(string in text for string in ["12:00", "sunny"])
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif("WRITER_API_KEY" not in os.environ, reason="WRITER_API_KEY environment variable missing")
+async def test_agent_async(agent):
+    result = await agent.invoke_async("What is the time and weather in New York?")
+    text = result.message["content"][0]["text"].lower()
+
+    assert all(string in text for string in ["12:00", "sunny"])
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif("WRITER_API_KEY" not in os.environ, reason="WRITER_API_KEY environment variable missing")
+async def test_agent_stream_async(agent):
+    stream = agent.stream_async("What is the time and weather in New York?")
+    async for event in stream:
+        _ = event
+
+    result = event["result"]
+    text = result.message["content"][0]["text"].lower()
+
+    assert all(string in text for string in ["12:00", "sunny"])
+
+
+@pytest.mark.skipif("WRITER_API_KEY" not in os.environ, reason="WRITER_API_KEY environment variable missing")
+def test_structured_output(agent):
+    class Weather(BaseModel):
+        time: str
+        weather: str
+
+    result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny")
+
+    assert isinstance(result, Weather)
+    assert result.time == "12:00"
+    assert result.weather == "sunny"
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif("WRITER_API_KEY" not in os.environ, reason="WRITER_API_KEY environment variable missing")
+async def test_structured_output_async(agent):
+    class Weather(BaseModel):
+        time: str
+        weather: str
+
+    result = await agent.structured_output_async(Weather, "The time is 12:00 and the weather is sunny")
+
+    assert isinstance(result, Weather)
+    assert result.time == "12:00"
+    assert result.weather == "sunny"
diff --git a/tests/strands/models/test_writer.py b/tests/strands/models/test_writer.py
new file mode 100644
index 000000000..09aa033c5
--- /dev/null
+++ b/tests/strands/models/test_writer.py
@@ -0,0 +1,396 @@
+import re
+import unittest.mock
+from typing import Any, List
+
+import pytest
+
+import strands
+from strands.models.writer import WriterModel
+
+
+@pytest.fixture
+def writer_client_cls():
+    with unittest.mock.patch.object(strands.models.writer.writerai, "AsyncClient") as mock_client_cls:
+        yield mock_client_cls
+
+
+@pytest.fixture
+def writer_client(writer_client_cls):
+    return writer_client_cls.return_value
+
+
+@pytest.fixture
+def client_args():
+    return {"api_key": "writer_api_key"}
+
+
+@pytest.fixture
+def model_id():
+    return "palmyra-x5"
+
+
+@pytest.fixture
+def stream_options():
+    return {"include_usage": True}
+
+
+@pytest.fixture
+def model(writer_client, model_id, stream_options, client_args):
+    _ = writer_client
+
+    return WriterModel(client_args, model_id=model_id, stream_options=stream_options)
+
+
+@pytest.fixture
+def messages():
+    return [{"role": "user", "content": [{"text": "test"}]}]
+
+
+@pytest.fixture
+def system_prompt():
+    return "System prompt"
+
+
+def test__init__(writer_client_cls, model_id, stream_options, client_args):
+    model = WriterModel(client_args=client_args, model_id=model_id, stream_options=stream_options)
+
+    config = model.get_config()
+    exp_config = {"stream_options": stream_options, "model_id": model_id}
+
+    assert config == exp_config
+
+    writer_client_cls.assert_called_once_with(api_key=client_args.get("api_key", ""))
+
+
+def test_update_config(model):
+    model.update_config(model_id="palmyra-x4")
+
+    model_id = model.get_config().get("model_id")
+
+    assert model_id == "palmyra-x4"
+
+
+def test_format_request_basic(model, messages, model_id, stream_options):
+    request = model.format_request(messages)
+
+    exp_request = {
+        "stream": True,
+        "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}],
+        "model": model_id,
+        "stream_options": stream_options,
+    }
+
+    assert request == exp_request
+
+
+def test_format_request_with_params(model, messages, model_id, stream_options):
+    model.update_config(temperature=0.19)
+
+    request = model.format_request(messages)
+    exp_request = {
+        "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}],
+        "model": model_id,
+        "stream_options": stream_options,
+        "temperature": 0.19,
+        "stream": True,
+    }
+
+    assert request == exp_request
+
+
+def test_format_request_with_system_prompt(model, messages, model_id, stream_options, system_prompt):
+    request = model.format_request(messages, system_prompt=system_prompt)
+
+    exp_request = {
+        "messages": [
+            {"content": "System prompt", "role": "system"},
+            {"content": [{"text": "test", "type": "text"}], "role": "user"},
+        ],
+        "model": model_id,
+        "stream_options": stream_options,
+        "stream": True,
+    }
+
+    assert request == exp_request
+
+
+def test_format_request_with_tool_use(model, model_id, stream_options):
+    messages = [
+        {
+            "role": "assistant",
+            "content": [
+                {
+                    "toolUse": {
+                        "toolUseId": "c1",
+                        "name": "calculator",
+                        "input": {"expression": "2+2"},
+                    },
+                },
+            ],
+        },
+    ]
+
+    request = model.format_request(messages)
+    exp_request = {
+        "messages": [
+            {
+                "role": "assistant",
+                "content": "",
+                "tool_calls": [
+                    {
{"arguments": '{"expression": "2+2"}', "name": "calculator"}, + "id": "c1", + "type": "function", + } + ], + }, + ], + "model": model_id, + "stream_options": stream_options, + "stream": True, + } + + assert request == exp_request + + +def test_format_request_with_tool_results(model, model_id, stream_options): + messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "c1", + "status": "success", + "content": [ + {"text": "answer is 4"}, + ], + } + } + ], + } + ] + + request = model.format_request(messages) + exp_request = { + "messages": [ + { + "role": "tool", + "content": [{"text": "answer is 4", "type": "text"}], + "tool_call_id": "c1", + }, + ], + "model": model_id, + "stream_options": stream_options, + "stream": True, + } + + assert request == exp_request + + +def test_format_request_with_image(model, model_id, stream_options): + messages = [ + { + "role": "user", + "content": [ + { + "image": { + "format": "png", + "source": {"bytes": b"lovely sunny day"}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + exp_request = { + "messages": [ + { + "role": "user", + "content": [ + { + "image_url": { + "url": "data:image/png;base64,bG92ZWx5IHN1bm55IGRheQ==", + }, + "type": "image_url", + }, + ], + }, + ], + "model": model_id, + "stream": True, + "stream_options": stream_options, + } + + assert request == exp_request + + +def test_format_request_with_empty_content(model, model_id, stream_options): + messages = [ + { + "role": "user", + "content": [], + }, + ] + + tru_request = model.format_request(messages) + exp_request = { + "messages": [], + "model": model_id, + "stream_options": stream_options, + "stream": True, + } + + assert tru_request == exp_request + + +@pytest.mark.parametrize( + ("content", "content_type"), + [ + ({"video": {}}, "video"), + ({"document": {}}, "document"), + ({"reasoningContent": {}}, "reasoningContent"), + ({"other": {}}, "other"), + ], +) +def test_format_request_with_unsupported_type(model, content, content_type): + messages = [ + { + "role": "user", + "content": [content], + }, + ] + + with pytest.raises(TypeError, match=f"content_type=<{content_type}> | unsupported type"): + model.format_request(messages) + + +class AsyncStreamWrapper: + def __init__(self, items: List[Any]): + self.items = items + + def __aiter__(self): + return self._generator() + + async def _generator(self): + for item in self.items: + yield item + + +async def mock_streaming_response(items: List[Any]): + return AsyncStreamWrapper(items) + + +@pytest.mark.asyncio +async def test_stream(writer_client, model, model_id): + mock_tool_call_1_part_1 = unittest.mock.Mock(index=0) + mock_tool_call_2_part_1 = unittest.mock.Mock(index=1) + mock_delta_1 = unittest.mock.Mock( + content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1] + ) + + mock_tool_call_1_part_2 = unittest.mock.Mock(index=0) + mock_tool_call_2_part_2 = unittest.mock.Mock(index=1) + mock_delta_2 = unittest.mock.Mock( + content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2] + ) + + mock_delta_3 = unittest.mock.Mock(content="", tool_calls=None) + + mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)]) + mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)]) + mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_3)]) + mock_event_4 = unittest.mock.Mock() + + 
+    writer_client.chat.chat.return_value = mock_streaming_response(
+        [mock_event_1, mock_event_2, mock_event_3, mock_event_4]
+    )
+
+    request = {
+        "model": model_id,
+        "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}],
+    }
+    response = model.stream(request)
+
+    events = [event async for event in response]
+    exp_events = [
+        {"chunk_type": "message_start"},
+        {"chunk_type": "content_block_start", "data_type": "text"},
+        {"chunk_type": "content_block_delta", "data_type": "text", "data": "I'll calculate"},
+        {"chunk_type": "content_block_delta", "data_type": "text", "data": "that for you"},
+        {"chunk_type": "content_block_stop", "data_type": "text"},
+        {"chunk_type": "content_block_start", "data_type": "tool", "data": mock_tool_call_1_part_1},
+        {"chunk_type": "content_block_delta", "data_type": "tool", "data": mock_tool_call_1_part_2},
+        {"chunk_type": "content_block_stop", "data_type": "tool"},
+        {"chunk_type": "content_block_start", "data_type": "tool", "data": mock_tool_call_2_part_1},
+        {"chunk_type": "content_block_delta", "data_type": "tool", "data": mock_tool_call_2_part_2},
+        {"chunk_type": "content_block_stop", "data_type": "tool"},
+        {"chunk_type": "message_stop", "data": "tool_calls"},
+        {"chunk_type": "metadata", "data": mock_event_4.usage},
+    ]
+
+    assert events == exp_events
+    writer_client.chat.chat.assert_called_once_with(**request)
+
+
+@pytest.mark.asyncio
+async def test_stream_empty(writer_client, model, model_id):
+    mock_delta = unittest.mock.Mock(content=None, tool_calls=None)
+    mock_usage = unittest.mock.Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0)
+
+    mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)])
+    mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)])
+    mock_event_3 = unittest.mock.Mock()
+    mock_event_4 = unittest.mock.Mock(usage=mock_usage)
+
+    writer_client.chat.chat.return_value = mock_streaming_response(
+        [mock_event_1, mock_event_2, mock_event_3, mock_event_4]
+    )
+
+    request = {"model": model_id, "messages": [{"role": "user", "content": []}]}
+    response = model.stream(request)
+
+    events = [event async for event in response]
+    exp_events = [
+        {"chunk_type": "message_start"},
+        {"chunk_type": "content_block_start", "data_type": "text"},
+        {"chunk_type": "content_block_stop", "data_type": "text"},
+        {"chunk_type": "message_stop", "data": "stop"},
+        {"chunk_type": "metadata", "data": mock_usage},
+    ]
+
+    assert events == exp_events
+    writer_client.chat.chat.assert_called_once_with(**request)
+
+
+@pytest.mark.asyncio
+async def test_stream_with_empty_choices(writer_client, model, model_id):
+    mock_delta = unittest.mock.Mock(content="content", tool_calls=None)
+    mock_usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
+
+    mock_event_1 = unittest.mock.Mock(spec=[])
+    mock_event_2 = unittest.mock.Mock(choices=[])
+    mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)])
+    mock_event_4 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)])
+    mock_event_5 = unittest.mock.Mock(usage=mock_usage)
+
+    writer_client.chat.chat.return_value = mock_streaming_response(
+        [mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5]
+    )
+
+    request = {"model": model_id, "messages": [{"role": "user", "content": ["test"]}]}
+    response = model.stream(request)
+
+    events = [event async for event in response]
+    exp_events = [
{"chunk_type": "message_start"}, + {"chunk_type": "content_block_start", "data_type": "text"}, + {"chunk_type": "content_block_delta", "data_type": "text", "data": "content"}, + {"chunk_type": "content_block_delta", "data_type": "text", "data": "content"}, + {"chunk_type": "content_block_stop", "data_type": "text"}, + {"chunk_type": "message_stop", "data": "stop"}, + {"chunk_type": "metadata", "data": mock_usage}, + ] + + assert events == exp_events + writer_client.chat.chat.assert_called_once_with(**request)