Skip to content

Commit e75cca1

Browse files
authored
Merge pull request #89 from WorkflowAI/guillaume/fix-completions
Fix completion object parsing
2 parents 70ae66a + 1bf2107 commit e75cca1

File tree

2 files changed

+178
-9
lines changed

2 files changed

+178
-9
lines changed

workflowai/core/domain/completion.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Annotated, Any, Literal, Optional, Union
22

33
from pydantic import BaseModel, Field
44

@@ -18,22 +18,67 @@ class CompletionUsage(BaseModel):
1818
model_context_window_size: Optional[int] = None
1919

2020

21+
class TextContent(BaseModel):
22+
type: Literal["text"] = "text"
23+
text: str
24+
25+
26+
class DocumentURL(BaseModel):
27+
url: str
28+
29+
30+
class DocumentContent(BaseModel):
31+
type: Literal["document_url"] = "document_url"
32+
source: DocumentURL
33+
34+
35+
class ImageURL(BaseModel):
36+
url: str
37+
38+
39+
class AudioURL(BaseModel):
40+
url: str
41+
42+
43+
class ImageContent(BaseModel):
44+
type: Literal["image_url"] = "image_url"
45+
image_url: ImageURL
46+
47+
48+
class AudioContent(BaseModel):
49+
type: Literal["audio_url"] = "audio_url"
50+
audio_url: AudioURL
51+
52+
53+
class ToolCallRequest(BaseModel):
54+
type: Literal["tool_call_request"] = "tool_call_request"
55+
id: Union[str, None] = None
56+
tool_name: str
57+
tool_input_dict: Union[dict[str, Any], None] = None
58+
59+
60+
class ToolCallResult(BaseModel):
61+
type: Literal["tool_call_result"] = "tool_call_result"
62+
id: Union[str, None] = None
63+
tool_name: Union[str, None] = None
64+
tool_input_dict: Union[dict[str, Any], None] = None
65+
result: Union[Any, None] = None
66+
error: Union[str, None] = None
67+
68+
69+
MessageContent = Annotated[Union[TextContent, DocumentContent, ImageContent, AudioContent], Field(discriminator="type")]
70+
71+
2172
class Message(BaseModel):
2273
"""A message in a completion."""
2374

2475
role: str = ""
25-
content: str = ""
76+
content: Union[str, MessageContent] = Field(default="")
2677

2778

2879
class Completion(BaseModel):
2980
"""A completion from the model."""
3081

3182
messages: list[Message] = Field(default_factory=list)
32-
response: Optional[str] = None
83+
response: Optional[str] = Field(default=None)
3384
usage: CompletionUsage = Field(default_factory=CompletionUsage)
34-
35-
36-
class CompletionsResponse(BaseModel):
37-
"""Response from the completions API endpoint."""
38-
39-
completions: list[Completion]
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import pytest
2+
from pydantic import ValidationError
3+
4+
from workflowai.core.domain.completion import AudioContent, DocumentContent, ImageContent, Message, TextContent
5+
6+
7+
class TestMessage:
8+
def test_basic_text(self):
9+
# Test basic text message validation
10+
json_str = '{"role": "user", "content": "Hello, world!"}'
11+
message = Message.model_validate_json(json_str)
12+
assert message.role == "user"
13+
assert message.content == "Hello, world!"
14+
15+
def test_with_text_content(self):
16+
# Test message with TextContent
17+
json_str = """
18+
{
19+
"role": "assistant",
20+
"content": {
21+
"type": "text",
22+
"text": "This is a test message"
23+
}
24+
}
25+
"""
26+
message = Message.model_validate_json(json_str)
27+
assert message.role == "assistant"
28+
assert isinstance(message.content, TextContent)
29+
assert message.content.text == "This is a test message"
30+
31+
def test_with_document_content(self):
32+
# Test message with DocumentContent
33+
json_str = """
34+
{
35+
"role": "user",
36+
"content": {
37+
"type": "document_url",
38+
"source": {
39+
"url": "https://example.com/doc.pdf"
40+
}
41+
}
42+
}
43+
"""
44+
message = Message.model_validate_json(json_str)
45+
assert message.role == "user"
46+
assert isinstance(message.content, DocumentContent)
47+
assert message.content.source.url == "https://example.com/doc.pdf"
48+
49+
def test_with_image_content(self):
50+
# Test message with ImageContent
51+
json_str = """
52+
{
53+
"role": "user",
54+
"content": {
55+
"type": "image_url",
56+
"image_url": {
57+
"url": "https://example.com/image.jpg"
58+
}
59+
}
60+
}
61+
"""
62+
message = Message.model_validate_json(json_str)
63+
assert message.role == "user"
64+
assert isinstance(message.content, ImageContent)
65+
assert message.content.image_url.url == "https://example.com/image.jpg"
66+
67+
def test_with_audio_content(self):
68+
# Test message with AudioContent
69+
json_str = """
70+
{
71+
"role": "user",
72+
"content": {
73+
"type": "audio_url",
74+
"audio_url": {
75+
"url": "https://example.com/audio.mp3"
76+
}
77+
}
78+
}
79+
"""
80+
message = Message.model_validate_json(json_str)
81+
assert message.role == "user"
82+
assert isinstance(message.content, AudioContent)
83+
assert message.content.audio_url.url == "https://example.com/audio.mp3"
84+
85+
def test_empty_role(self):
86+
# Test message with empty role
87+
json_str = '{"role": "", "content": "Test message"}'
88+
message = Message.model_validate_json(json_str)
89+
assert message.role == ""
90+
assert message.content == "Test message"
91+
92+
def test_missing_role(self):
93+
# Test message with missing role
94+
json_str = '{"content": "Test message"}'
95+
message = Message.model_validate_json(json_str)
96+
assert message.role == "" # Default value
97+
assert message.content == "Test message"
98+
99+
def test_invalid_content_type(self):
100+
# Test message with invalid content type
101+
json_str = """
102+
{
103+
"role": "user",
104+
"content": {
105+
"type": "invalid_type",
106+
"text": "This should fail"
107+
}
108+
}
109+
"""
110+
with pytest.raises(ValidationError):
111+
Message.model_validate_json(json_str)
112+
113+
def test_missing_content(self):
114+
# Test message with missing content
115+
json_str = '{"role": "user"}'
116+
message = Message.model_validate_json(json_str)
117+
assert message.role == "user"
118+
assert message.content == "" # Default value
119+
120+
def test_invalid_json(self):
121+
# Test with invalid JSON string
122+
json_str = '{"role": "user", "content": "Test message"' # Missing closing brace
123+
with pytest.raises(ValidationError):
124+
Message.model_validate_json(json_str)

0 commit comments

Comments
 (0)