From 67b06630e5ebebafe4441c96557865ee4be2a3cc Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Fri, 22 Nov 2024 09:34:53 -0500 Subject: [PATCH] fix: non optional model construct in streams --- pyproject.toml | 2 +- tests/models/hello_task.py | 10 ++++ workflowai/core/client/client_test.py | 39 +++++++++++++- workflowai/core/client/models.py | 2 +- workflowai/core/client/models_test.py | 78 ++++++++++++++++++++++++++- 5 files changed, 127 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7428143..bb79da5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "workflowai" -version = "0.4.0" +version = "0.4.1" description = "" authors = ["Guillaume Aquilina "] readme = "README.md" diff --git a/tests/models/hello_task.py b/tests/models/hello_task.py index 55e44ed..5c42bdd 100644 --- a/tests/models/hello_task.py +++ b/tests/models/hello_task.py @@ -14,3 +14,13 @@ class HelloTaskOutput(BaseModel): class HelloTask(Task[HelloTaskInput, HelloTaskOutput]): input_class: type[HelloTaskInput] = HelloTaskInput output_class: type[HelloTaskOutput] = HelloTaskOutput + + +class HelloTaskOutputNotOptional(HelloTaskOutput): + message: str + another_field: str + + +class HelloTaskNotOptional(Task[HelloTaskInput, HelloTaskOutputNotOptional]): + input_class: type[HelloTaskInput] = HelloTaskInput + output_class: type[HelloTaskOutputNotOptional] = HelloTaskOutputNotOptional diff --git a/workflowai/core/client/client_test.py b/workflowai/core/client/client_test.py index d5b1ca6..7407a67 100644 --- a/workflowai/core/client/client_test.py +++ b/workflowai/core/client/client_test.py @@ -4,7 +4,7 @@ import pytest from pytest_httpx import HTTPXMock, IteratorStream -from tests.models.hello_task import HelloTask, HelloTaskInput, HelloTaskOutput +from tests.models.hello_task import HelloTask, HelloTaskInput, HelloTaskNotOptional, HelloTaskOutput from tests.utils import fixtures_json from workflowai.core.client import Client from workflowai.core.client.client import WorkflowAIClient @@ -70,6 +70,43 @@ async def test_stream(self, httpx_mock: HTTPXMock, client: Client): assert last_message.cost_usd == 0.01 assert last_message.duration_seconds == 10.1 + async def test_stream_not_optional(self, httpx_mock: HTTPXMock, client: Client): + # Checking that streaming works even with non optional fields + # The first two chunks are missing a required key but the last one has it + httpx_mock.add_response( + stream=IteratorStream( + [ + b'data: {"id":"1","task_output":{"message":""}}\n\n', + b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n', # noqa: E501 + b'data: {"id":"1","task_output":{"message":"hello", "another_field": "test"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n', # noqa: E501 + ], + ), + ) + task = HelloTaskNotOptional(id="123", schema_id=1) + + streamed = await client.run( + task, + task_input=HelloTaskInput(name="Alice"), + stream=True, + ) + chunks = [chunk async for chunk in streamed] + + messages = [chunk.task_output.message for chunk in chunks] + assert messages == ["", "hel", "hello", "hello"] + + for chunk in chunks[:-1]: + with pytest.raises(AttributeError): + # Since the field is not optional, it will raise an attribute error + assert chunk.task_output.another_field + assert chunks[-1].task_output.another_field == "test" + + last_message = chunks[-1] + assert isinstance(last_message, Run) + assert last_message.version.properties.model == "gpt-4o" + assert last_message.version.properties.temperature == 0.5 + assert last_message.cost_usd == 0.01 + assert last_message.duration_seconds == 10.1 + async def test_run_with_env(self, httpx_mock: HTTPXMock, client: Client): httpx_mock.add_response(json=fixtures_json("task_run.json")) task = HelloTask(id="123", schema_id=1) diff --git a/workflowai/core/client/models.py b/workflowai/core/client/models.py index 0be02d3..509a8fe 100644 --- a/workflowai/core/client/models.py +++ b/workflowai/core/client/models.py @@ -71,7 +71,7 @@ def to_domain(self, task: Task[TaskInput, TaskOutput]) -> Union[Run[TaskOutput], if self.version is None: return RunChunk( id=self.id, - task_output=task.output_class.model_validate(self.task_output), + task_output=task.output_class.model_construct(None, **self.task_output), ) return Run( diff --git a/workflowai/core/client/models_test.py b/workflowai/core/client/models_test.py index 192d0eb..330b5bf 100644 --- a/workflowai/core/client/models_test.py +++ b/workflowai/core/client/models_test.py @@ -1,7 +1,12 @@ +from typing import Optional + import pytest +from pydantic import BaseModel, ValidationError from tests.utils import fixture_text -from workflowai.core.client.models import RunResponse +from workflowai.core.client.models import RunResponse, RunStreamChunk +from workflowai.core.domain.task import Task +from workflowai.core.domain.task_run import Run, RunChunk @pytest.mark.parametrize( @@ -14,3 +19,74 @@ def test_task_run_response(fixture: str): txt = fixture_text(fixture) task_run = RunResponse.model_validate_json(txt) assert task_run + + +class _TaskOutput(BaseModel): + a: int + b: str + + +class _TaskOutputOpt(BaseModel): + a: Optional[int] = None + b: Optional[str] = None + + +class _Task(Task[_TaskOutput, _TaskOutput]): + id: str = "test-task" + schema_id: int = 1 + input_class: type[_TaskOutput] = _TaskOutput + output_class: type[_TaskOutput] = _TaskOutput + + +class _TaskOpt(Task[_TaskOutputOpt, _TaskOutputOpt]): + id: str = "test-task" + schema_id: int = 1 + input_class: type[_TaskOutputOpt] = _TaskOutputOpt + output_class: type[_TaskOutputOpt] = _TaskOutputOpt + + +class TestRunStreamChunkToDomain: + def test_no_version_not_optional(self): + # Check that partial model is ok + chunk = RunStreamChunk.model_validate_json('{"id": "1", "task_output": {"a": 1}}') + assert chunk.task_output == {"a": 1} + + with pytest.raises(ValidationError): # sanity + _TaskOutput.model_validate({"a": 1}) + + parsed = chunk.to_domain(_Task()) + assert isinstance(parsed, RunChunk) + assert parsed.task_output.a == 1 + # b is not defined + with pytest.raises(AttributeError): + assert parsed.task_output.b + + def test_no_version_optional(self): + chunk = RunStreamChunk.model_validate_json('{"id": "1", "task_output": {"a": 1}}') + assert chunk + + parsed = chunk.to_domain(_TaskOpt()) + assert isinstance(parsed, RunChunk) + assert parsed.task_output.a == 1 + assert parsed.task_output.b is None + + def test_with_version(self): + chunk = RunStreamChunk.model_validate_json( + '{"id": "1", "task_output": {"a": 1, "b": "test"}, "cost_usd": 0.1, "duration_seconds": 1, "version": {"properties": {"a": 1, "b": "test"}}}', # noqa: E501 + ) + assert chunk + + parsed = chunk.to_domain(_Task()) + assert isinstance(parsed, Run) + assert parsed.task_output.a == 1 + assert parsed.task_output.b == "test" + + assert parsed.cost_usd == 0.1 + assert parsed.duration_seconds == 1 + + def test_with_version_validation_fails(self): + chunk = RunStreamChunk.model_validate_json( + '{"id": "1", "task_output": {"a": 1}, "version": {"properties": {"a": 1, "b": "test"}}}', + ) + with pytest.raises(ValidationError): + chunk.to_domain(_Task())