Skip to content

Non-optional model construct in streams #21

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "workflowai"
version = "0.4.0"
version = "0.4.1"
description = ""
authors = ["Guillaume Aquilina <[email protected]>"]
readme = "README.md"
Expand Down
10 changes: 10 additions & 0 deletions tests/models/hello_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,13 @@ class HelloTaskOutput(BaseModel):
class HelloTask(Task[HelloTaskInput, HelloTaskOutput]):
input_class: type[HelloTaskInput] = HelloTaskInput
output_class: type[HelloTaskOutput] = HelloTaskOutput


class HelloTaskOutputNotOptional(HelloTaskOutput):
    """Output fixture whose fields are all required (non-optional).

    Used to verify that streaming can surface partial payloads even when the
    output model would reject them under strict validation.
    """

    # Both fields are required — a payload missing either one cannot be
    # validated into this model (presumably unlike HelloTaskOutput, whose
    # definition is outside this diff — TODO confirm its fields are optional).
    message: str
    another_field: str


class HelloTaskNotOptional(Task[HelloTaskInput, HelloTaskOutputNotOptional]):
    """Task fixture pairing HelloTaskInput with a fully-required output model."""

    input_class: type[HelloTaskInput] = HelloTaskInput
    output_class: type[HelloTaskOutputNotOptional] = HelloTaskOutputNotOptional
39 changes: 38 additions & 1 deletion workflowai/core/client/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from pytest_httpx import HTTPXMock, IteratorStream

from tests.models.hello_task import HelloTask, HelloTaskInput, HelloTaskOutput
from tests.models.hello_task import HelloTask, HelloTaskInput, HelloTaskNotOptional, HelloTaskOutput
from tests.utils import fixtures_json
from workflowai.core.client import Client
from workflowai.core.client.client import WorkflowAIClient
Expand Down Expand Up @@ -70,6 +70,43 @@ async def test_stream(self, httpx_mock: HTTPXMock, client: Client):
assert last_message.cost_usd == 0.01
assert last_message.duration_seconds == 10.1

async def test_stream_not_optional(self, httpx_mock: HTTPXMock, client: Client):
    """Streaming must yield partial chunks even when the output model has required fields.

    The mocked SSE stream emits four events: the first three are missing the
    required ``another_field`` key; only the final event carries the complete
    payload plus version/cost metadata.
    """
    # Note: the second bytes element deliberately packs TWO SSE events into
    # one network chunk, exercising event-boundary splitting in the client.
    httpx_mock.add_response(
        stream=IteratorStream(
            [
                b'data: {"id":"1","task_output":{"message":""}}\n\n',
                b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n',  # noqa: E501
                b'data: {"id":"1","task_output":{"message":"hello", "another_field": "test"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n',  # noqa: E501
            ],
        ),
    )
    task = HelloTaskNotOptional(id="123", schema_id=1)

    streamed = await client.run(
        task,
        task_input=HelloTaskInput(name="Alice"),
        stream=True,
    )
    chunks = [chunk async for chunk in streamed]

    # Partial outputs are surfaced as they arrive rather than rejected.
    messages = [chunk.task_output.message for chunk in chunks]
    assert messages == ["", "hel", "hello", "hello"]

    for chunk in chunks[:-1]:
        with pytest.raises(AttributeError):
            # Since the field is not optional, it will raise an attribute error
            assert chunk.task_output.another_field
    assert chunks[-1].task_output.another_field == "test"

    # The final event is promoted to a full Run carrying version + usage data.
    last_message = chunks[-1]
    assert isinstance(last_message, Run)
    assert last_message.version.properties.model == "gpt-4o"
    assert last_message.version.properties.temperature == 0.5
    assert last_message.cost_usd == 0.01
    assert last_message.duration_seconds == 10.1

async def test_run_with_env(self, httpx_mock: HTTPXMock, client: Client):
httpx_mock.add_response(json=fixtures_json("task_run.json"))
task = HelloTask(id="123", schema_id=1)
Expand Down
2 changes: 1 addition & 1 deletion workflowai/core/client/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def to_domain(self, task: Task[TaskInput, TaskOutput]) -> Union[Run[TaskOutput],
if self.version is None:
return RunChunk(
id=self.id,
task_output=task.output_class.model_validate(self.task_output),
task_output=task.output_class.model_construct(None, **self.task_output),
)

return Run(
Expand Down
78 changes: 77 additions & 1 deletion workflowai/core/client/models_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from typing import Optional

import pytest
from pydantic import BaseModel, ValidationError

from tests.utils import fixture_text
from workflowai.core.client.models import RunResponse
from workflowai.core.client.models import RunResponse, RunStreamChunk
from workflowai.core.domain.task import Task
from workflowai.core.domain.task_run import Run, RunChunk


@pytest.mark.parametrize(
Expand All @@ -14,3 +19,74 @@ def test_task_run_response(fixture: str):
txt = fixture_text(fixture)
task_run = RunResponse.model_validate_json(txt)
assert task_run


class _TaskOutput(BaseModel):
    """Output model with required fields — partial payloads fail strict validation."""

    a: int
    b: str


class _TaskOutputOpt(BaseModel):
    """Fully-optional counterpart of _TaskOutput — partial payloads validate cleanly."""

    a: Optional[int] = None
    b: Optional[str] = None


class _Task(Task[_TaskOutput, _TaskOutput]):
    """Task fixture whose output model has only required fields."""

    id: str = "test-task"
    schema_id: int = 1
    input_class: type[_TaskOutput] = _TaskOutput
    output_class: type[_TaskOutput] = _TaskOutput


class _TaskOpt(Task[_TaskOutputOpt, _TaskOutputOpt]):
    """Task fixture whose output model has only optional fields."""

    id: str = "test-task"
    schema_id: int = 1
    input_class: type[_TaskOutputOpt] = _TaskOutputOpt
    output_class: type[_TaskOutputOpt] = _TaskOutputOpt


class TestRunStreamChunkToDomain:
    """Behaviour of RunStreamChunk.to_domain for partial and final stream payloads."""

    def test_no_version_not_optional(self):
        # A version-less chunk is a partial payload: it must become a RunChunk
        # built without validation, even though the required "b" is missing.
        raw = '{"id": "1", "task_output": {"a": 1}}'
        chunk = RunStreamChunk.model_validate_json(raw)
        assert chunk.task_output == {"a": 1}

        # Sanity check: the same payload does NOT pass strict validation.
        with pytest.raises(ValidationError):
            _TaskOutput.model_validate({"a": 1})

        domain = chunk.to_domain(_Task())
        assert isinstance(domain, RunChunk)
        assert domain.task_output.a == 1
        # "b" was never set, so attribute access raises rather than returning None.
        with pytest.raises(AttributeError):
            assert domain.task_output.b

    def test_no_version_optional(self):
        raw = '{"id": "1", "task_output": {"a": 1}}'
        chunk = RunStreamChunk.model_validate_json(raw)
        assert chunk

        # With an all-optional output model, the missing field defaults to None.
        domain = chunk.to_domain(_TaskOpt())
        assert isinstance(domain, RunChunk)
        assert domain.task_output.a == 1
        assert domain.task_output.b is None

    def test_with_version(self):
        # The presence of a version marks the final payload -> a full Run.
        raw = '{"id": "1", "task_output": {"a": 1, "b": "test"}, "cost_usd": 0.1, "duration_seconds": 1, "version": {"properties": {"a": 1, "b": "test"}}}'  # noqa: E501
        chunk = RunStreamChunk.model_validate_json(raw)
        assert chunk

        domain = chunk.to_domain(_Task())
        assert isinstance(domain, Run)
        assert domain.task_output.a == 1
        assert domain.task_output.b == "test"
        assert domain.cost_usd == 0.1
        assert domain.duration_seconds == 1

    def test_with_version_validation_fails(self):
        # Final payloads ARE validated, so a missing required field must raise.
        raw = '{"id": "1", "task_output": {"a": 1}, "version": {"properties": {"a": 1, "b": "test"}}}'
        chunk = RunStreamChunk.model_validate_json(raw)
        with pytest.raises(ValidationError):
            chunk.to_domain(_Task())
Loading