Skip to content

Commit fbb490c

Browse files
authored
Merge pull request #21 from workflowai/guillaume/fix-optional-fields
Non optional model construct in streams
2 parents a56afde + 67b0663 commit fbb490c

File tree

5 files changed

+127
-4
lines changed

5 files changed

+127
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "workflowai"
3-
version = "0.4.0"
3+
version = "0.4.1"
44
description = ""
55
authors = ["Guillaume Aquilina <[email protected]>"]
66
readme = "README.md"

tests/models/hello_task.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,13 @@ class HelloTaskOutput(BaseModel):
1414
class HelloTask(Task[HelloTaskInput, HelloTaskOutput]):
1515
input_class: type[HelloTaskInput] = HelloTaskInput
1616
output_class: type[HelloTaskOutput] = HelloTaskOutput
17+
18+
19+
class HelloTaskOutputNotOptional(HelloTaskOutput):
20+
message: str
21+
another_field: str
22+
23+
24+
class HelloTaskNotOptional(Task[HelloTaskInput, HelloTaskOutputNotOptional]):
25+
input_class: type[HelloTaskInput] = HelloTaskInput
26+
output_class: type[HelloTaskOutputNotOptional] = HelloTaskOutputNotOptional

workflowai/core/client/client_test.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55
from pytest_httpx import HTTPXMock, IteratorStream
66

7-
from tests.models.hello_task import HelloTask, HelloTaskInput, HelloTaskOutput
7+
from tests.models.hello_task import HelloTask, HelloTaskInput, HelloTaskNotOptional, HelloTaskOutput
88
from tests.utils import fixtures_json
99
from workflowai.core.client import Client
1010
from workflowai.core.client.client import WorkflowAIClient
@@ -70,6 +70,43 @@ async def test_stream(self, httpx_mock: HTTPXMock, client: Client):
7070
assert last_message.cost_usd == 0.01
7171
assert last_message.duration_seconds == 10.1
7272

73+
async def test_stream_not_optional(self, httpx_mock: HTTPXMock, client: Client):
74+
# Checking that streaming works even with non optional fields
75+
# The first two chunks are missing a required key but the last one has it
76+
httpx_mock.add_response(
77+
stream=IteratorStream(
78+
[
79+
b'data: {"id":"1","task_output":{"message":""}}\n\n',
80+
b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n', # noqa: E501
81+
b'data: {"id":"1","task_output":{"message":"hello", "another_field": "test"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n', # noqa: E501
82+
],
83+
),
84+
)
85+
task = HelloTaskNotOptional(id="123", schema_id=1)
86+
87+
streamed = await client.run(
88+
task,
89+
task_input=HelloTaskInput(name="Alice"),
90+
stream=True,
91+
)
92+
chunks = [chunk async for chunk in streamed]
93+
94+
messages = [chunk.task_output.message for chunk in chunks]
95+
assert messages == ["", "hel", "hello", "hello"]
96+
97+
for chunk in chunks[:-1]:
98+
with pytest.raises(AttributeError):
99+
# Since the field is not optional, it will raise an attribute error
100+
assert chunk.task_output.another_field
101+
assert chunks[-1].task_output.another_field == "test"
102+
103+
last_message = chunks[-1]
104+
assert isinstance(last_message, Run)
105+
assert last_message.version.properties.model == "gpt-4o"
106+
assert last_message.version.properties.temperature == 0.5
107+
assert last_message.cost_usd == 0.01
108+
assert last_message.duration_seconds == 10.1
109+
73110
async def test_run_with_env(self, httpx_mock: HTTPXMock, client: Client):
74111
httpx_mock.add_response(json=fixtures_json("task_run.json"))
75112
task = HelloTask(id="123", schema_id=1)

workflowai/core/client/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def to_domain(self, task: Task[TaskInput, TaskOutput]) -> Union[Run[TaskOutput],
7171
if self.version is None:
7272
return RunChunk(
7373
id=self.id,
74-
task_output=task.output_class.model_validate(self.task_output),
74+
task_output=task.output_class.model_construct(None, **self.task_output),
7575
)
7676

7777
return Run(

workflowai/core/client/models_test.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1+
from typing import Optional
2+
13
import pytest
4+
from pydantic import BaseModel, ValidationError
25

36
from tests.utils import fixture_text
4-
from workflowai.core.client.models import RunResponse
7+
from workflowai.core.client.models import RunResponse, RunStreamChunk
8+
from workflowai.core.domain.task import Task
9+
from workflowai.core.domain.task_run import Run, RunChunk
510

611

712
@pytest.mark.parametrize(
@@ -14,3 +19,74 @@ def test_task_run_response(fixture: str):
1419
txt = fixture_text(fixture)
1520
task_run = RunResponse.model_validate_json(txt)
1621
assert task_run
22+
23+
24+
class _TaskOutput(BaseModel):
25+
a: int
26+
b: str
27+
28+
29+
class _TaskOutputOpt(BaseModel):
30+
a: Optional[int] = None
31+
b: Optional[str] = None
32+
33+
34+
class _Task(Task[_TaskOutput, _TaskOutput]):
35+
id: str = "test-task"
36+
schema_id: int = 1
37+
input_class: type[_TaskOutput] = _TaskOutput
38+
output_class: type[_TaskOutput] = _TaskOutput
39+
40+
41+
class _TaskOpt(Task[_TaskOutputOpt, _TaskOutputOpt]):
42+
id: str = "test-task"
43+
schema_id: int = 1
44+
input_class: type[_TaskOutputOpt] = _TaskOutputOpt
45+
output_class: type[_TaskOutputOpt] = _TaskOutputOpt
46+
47+
48+
class TestRunStreamChunkToDomain:
49+
def test_no_version_not_optional(self):
50+
# Check that partial model is ok
51+
chunk = RunStreamChunk.model_validate_json('{"id": "1", "task_output": {"a": 1}}')
52+
assert chunk.task_output == {"a": 1}
53+
54+
with pytest.raises(ValidationError): # sanity
55+
_TaskOutput.model_validate({"a": 1})
56+
57+
parsed = chunk.to_domain(_Task())
58+
assert isinstance(parsed, RunChunk)
59+
assert parsed.task_output.a == 1
60+
# b is not defined
61+
with pytest.raises(AttributeError):
62+
assert parsed.task_output.b
63+
64+
def test_no_version_optional(self):
65+
chunk = RunStreamChunk.model_validate_json('{"id": "1", "task_output": {"a": 1}}')
66+
assert chunk
67+
68+
parsed = chunk.to_domain(_TaskOpt())
69+
assert isinstance(parsed, RunChunk)
70+
assert parsed.task_output.a == 1
71+
assert parsed.task_output.b is None
72+
73+
def test_with_version(self):
74+
chunk = RunStreamChunk.model_validate_json(
75+
'{"id": "1", "task_output": {"a": 1, "b": "test"}, "cost_usd": 0.1, "duration_seconds": 1, "version": {"properties": {"a": 1, "b": "test"}}}', # noqa: E501
76+
)
77+
assert chunk
78+
79+
parsed = chunk.to_domain(_Task())
80+
assert isinstance(parsed, Run)
81+
assert parsed.task_output.a == 1
82+
assert parsed.task_output.b == "test"
83+
84+
assert parsed.cost_usd == 0.1
85+
assert parsed.duration_seconds == 1
86+
87+
def test_with_version_validation_fails(self):
88+
chunk = RunStreamChunk.model_validate_json(
89+
'{"id": "1", "task_output": {"a": 1}, "version": {"properties": {"a": 1, "b": "test"}}}',
90+
)
91+
with pytest.raises(ValidationError):
92+
chunk.to_domain(_Task())

0 commit comments

Comments
 (0)