
Fix model construct #71


Merged: 8 commits, Mar 5, 2025
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "workflowai"
version = "0.6.0.dev19"
version = "0.6.0.dev20"
description = ""
authors = ["Guillaume Aquilina <[email protected]>"]
readme = "README.md"
31 changes: 30 additions & 1 deletion tests/e2e/no_schema_test.py
@@ -1,8 +1,9 @@
from typing import Optional

-from pydantic import BaseModel
+from pydantic import BaseModel, Field

import workflowai
+from workflowai.core.client.agent import Agent


class SummarizeTaskInput(BaseModel):
@@ -28,3 +29,31 @@ async def test_summarize():
        use_cache="never",
    )
    assert summarized.summary_points


+async def test_same_schema():
+    class InputWithNullableList(BaseModel):
+        opt_list: Optional[list[str]] = None
+
+    class InputWithNonNullableList(BaseModel):
+        opt_list: list[str] = Field(default_factory=list)
+
+    agent1 = Agent(
+        agent_id="summarize",
+        input_cls=InputWithNullableList,
+        output_cls=SummarizeTaskOutput,
+        api=lambda: workflowai.shared_client.api,
+    )
+
+    schema_id1 = await agent1.register()
+
+    agent2 = Agent(
+        agent_id="summarize",
+        input_cls=InputWithNonNullableList,
+        output_cls=SummarizeTaskOutput,
+        api=lambda: workflowai.shared_client.api,
+    )
+
+    schema_id2 = await agent2.register()
+
+    assert schema_id1 == schema_id2
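Note (not part of the diff): locally the two pydantic models do not emit identical JSON schemas, since Optional wraps the array type in an anyOf with null. The equal schema ids asserted above therefore imply the registration endpoint normalizes nullability server-side; that normalization is an assumption here, not shown in this PR. A quick sketch of the local difference, assuming pydantic v2:

from typing import Optional

from pydantic import BaseModel, Field


class WithNullableList(BaseModel):
    opt_list: Optional[list[str]] = None


class WithNonNullableList(BaseModel):
    opt_list: list[str] = Field(default_factory=list)


a = WithNullableList.model_json_schema()["properties"]["opt_list"]
b = WithNonNullableList.model_json_schema()["properties"]["opt_list"]
# The Optional variant carries anyOf [array, null]; the other is a plain array.
assert a != b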
20 changes: 10 additions & 10 deletions tests/integration/run_test.py
@@ -77,11 +77,11 @@ def city_to_capital(task_input: CityToCapitalTaskInput) -> AsyncIterator[CityToC
    task_input = CityToCapitalTaskInput(city="Hello")
    chunks = [chunk async for chunk in city_to_capital(task_input)]

-    assert chunks == [
-        CityToCapitalTaskOutput(capital=""),
-        CityToCapitalTaskOutput(capital="Tok"),
-        CityToCapitalTaskOutput(capital="Tokyo"),
-        CityToCapitalTaskOutput(capital="Tokyo"),
+    assert [chunk.capital for chunk in chunks] == [
+        "",
+        "Tok",
+        "Tokyo",
+        "Tokyo",
    ]


@@ -94,11 +94,11 @@ def city_to_capital(task_input: CityToCapitalTaskInput) -> AsyncIterator[CityToC
    task_input = CityToCapitalTaskInput(city="Hello")
    chunks = [chunk async for chunk in city_to_capital(task_input)]

-    assert chunks == [
-        CityToCapitalTaskOutput(capital=""),
-        CityToCapitalTaskOutput(capital="Tok"),
-        CityToCapitalTaskOutput(capital="Tokyo"),
-        CityToCapitalTaskOutput(capital="Tokyo"),
+    assert [chunk.capital for chunk in chunks] == [
+        "",
+        "Tok",
+        "Tokyo",
+        "Tokyo",
    ]


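A plausible reason for this test change (an inference, not stated in the diff): with partial validation, streamed chunks may be instances of a derived partial model class, and pydantic v2 equality requires matching classes, so comparing the capital values directly is more robust than comparing model instances. A minimal illustration under that assumption:

from pydantic import BaseModel


class Out(BaseModel):
    capital: str = ""


class PartialOut(Out):
    """Stand-in for a derived 'partial' model class."""


# Same field values, different classes: pydantic v2 equality checks the
# class as well as the fields, so these compare unequal.
assert PartialOut(capital="Tokyo") != Out(capital="Tokyo")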
13 changes: 12 additions & 1 deletion workflowai/core/_common_types.py
@@ -19,7 +19,18 @@


class OutputValidator(Protocol, Generic[AgentOutputCov]):
-    def __call__(self, data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutputCov: ...
+    def __call__(self, data: dict[str, Any], partial: bool) -> AgentOutputCov:
+        """A way to convert a json object into an AgentOutput
+
+        Args:
+            data (dict[str, Any]): The json object to convert
+            partial (bool): Whether the json is partial, meaning that
+                it may not contain all the fields required by the AgentOutput model.
+
+        Returns:
+            AgentOutputCov: The converted AgentOutput
+        """
+        ...


class VersionRunParams(TypedDict):
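For illustration only (not part of the diff): a plain function satisfies this protocol structurally. A hypothetical validator, with Output as a stand-in model:

from typing import Any

from pydantic import BaseModel


class Output(BaseModel):
    answer: str = ""


def lenient_validator(data: dict[str, Any], partial: bool) -> Output:
    """Satisfies OutputValidator[Output]: strict validation for final
    payloads, construct-without-validation for partial ones."""
    if partial:
        # Skip validation for in-flight payloads that may miss required fields
        return Output.model_construct(**data)
    return Output.model_validate(data)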
6 changes: 2 additions & 4 deletions workflowai/core/client/_api.py
@@ -1,19 +1,17 @@
-import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any, Literal, Optional, TypeVar, Union, overload

import httpx
from pydantic import BaseModel, TypeAdapter, ValidationError

+from workflowai.core._logger import logger
from workflowai.core.domain.errors import BaseError, ErrorResponse, WorkflowAIError

# A type for return values
_R = TypeVar("_R")
_M = TypeVar("_M", bound=BaseModel)

-_logger = logging.getLogger("WorkflowAI")


class APIClient:
    def __init__(self, url: str, api_key: str, source_headers: Optional[dict[str, str]] = None):
@@ -154,7 +152,7 @@ async def _wrap_sse(self, raw: AsyncIterator[bytes], termination_chars: bytes =
                in_data = False

        if data:
-            _logger.warning("Data left after processing", extra={"data": data})
+            logger.warning("Data left after processing", extra={"data": data})

    async def stream(
        self,
5 changes: 3 additions & 2 deletions workflowai/core/client/_fn_utils.py
@@ -23,7 +23,7 @@
    RunParams,
    RunTemplate,
)
-from workflowai.core.client._utils import intolerant_validator
+from workflowai.core.client._utils import default_validator
from workflowai.core.client.agent import Agent
from workflowai.core.domain.errors import InvalidGenerationError
from workflowai.core.domain.model import ModelOrStr
@@ -144,14 +144,15 @@ async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutp
        except InvalidGenerationError as e:
            if e.partial_output and e.run_id:
                try:
-                    validator, _ = self._sanitize_validator(kwargs, intolerant_validator(self.output_cls))
+                    validator, _ = self._sanitize_validator(kwargs, default_validator(self.output_cls))
                    run = self._build_run_no_tools(
                        chunk=RunResponse(
                            id=e.run_id,
                            task_output=e.partial_output,
                        ),
                        schema_id=self.schema_id or 0,
                        validator=validator,
+                        partial=False,
                    )
                    run.error = e.error
                    return run
8 changes: 7 additions & 1 deletion workflowai/core/client/_models.py
@@ -134,12 +134,18 @@ def to_domain(
        task_id: str,
        task_schema_id: int,
        validator: OutputValidator[AgentOutput],
+        partial: Optional[bool] = None,
    ) -> Run[AgentOutput]:
+        # We do partial validation if either:
+        # - there are tool call requests, which means that the output can be empty
+        # - the run has not yet finished, for example when streaming, in which case the duration_seconds is None
+        if partial is None:
+            partial = bool(self.tool_call_requests) or self.duration_seconds is None
        return Run(
            id=self.id,
            agent_id=task_id,
            schema_id=task_schema_id,
-            output=validator(self.task_output, self.tool_call_requests is not None),
+            output=validator(self.task_output, partial),
            version=self.version and self.version.to_domain(),
            duration_seconds=self.duration_seconds,
            cost_usd=self.cost_usd,
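The inference above, isolated for clarity (a sketch mirroring the new default, not additional library code):

from typing import Any, Optional


def infer_partial(tool_call_requests: Optional[list[Any]], duration_seconds: Optional[float]) -> bool:
    # Mirrors the default in to_domain: validate partially when tool calls
    # are pending (the output may be empty) or the run is still streaming
    # (duration_seconds not yet set).
    return bool(tool_call_requests) or duration_seconds is None


assert infer_partial(None, None) is True  # still streaming
assert infer_partial([{"id": "1"}], 1.0) is True  # tool calls pending
# Unlike the old `is not None` check, an empty list no longer forces partial validation:
assert infer_partial([], 1.0) is False
assert infer_partial(None, 1.0) is False  # finished run: strict validation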
23 changes: 13 additions & 10 deletions workflowai/core/client/_models_test.py
@@ -5,7 +5,7 @@

from tests.utils import fixture_text
from workflowai.core.client._models import RunResponse
-from workflowai.core.client._utils import intolerant_validator, tolerant_validator
+from workflowai.core.client._utils import default_validator
from workflowai.core.domain.run import Run
from workflowai.core.domain.tool_call import ToolCallRequest

@@ -41,29 +41,30 @@ def test_no_version_not_optional(self):
        with pytest.raises(ValidationError):  # sanity
            _TaskOutput.model_validate({"a": 1})

-        parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput))
+        parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutput))
        assert isinstance(parsed, Run)
        assert parsed.output.a == 1
-        # b is not defined
-        with pytest.raises(AttributeError):
-            assert parsed.output.b
+        assert parsed.output.b == ""

    def test_no_version_optional(self):
        chunk = RunResponse.model_validate_json('{"id": "1", "task_output": {"a": 1}}')
        assert chunk

-        parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutputOpt))
+        parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutputOpt))
        assert isinstance(parsed, Run)
        assert parsed.output.a == 1
        assert parsed.output.b is None

    def test_with_version(self):
+        """Full output is validated since the duration is passed and there are no tool calls"""
        chunk = RunResponse.model_validate_json(
            '{"id": "1", "task_output": {"a": 1, "b": "test"}, "cost_usd": 0.1, "duration_seconds": 1, "version": {"properties": {"a": 1, "b": "test"}}}',  # noqa: E501
        )
        assert chunk

-        parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput))
+        parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutput))
        assert isinstance(parsed, Run)
        assert parsed.output.a == 1
        assert parsed.output.b == "test"
@@ -73,17 +74,19 @@ def test_with_version(self):

    def test_with_version_validation_fails(self):
        chunk = RunResponse.model_validate_json(
-            '{"id": "1", "task_output": {"a": 1}, "version": {"properties": {"a": 1, "b": "test"}}}',
+            """{"id": "1", "task_output": {"a": 1},
+            "version": {"properties": {"a": 1, "b": "test"}}, "duration_seconds": 1}""",
        )
        with pytest.raises(ValidationError):
-            chunk.to_domain(task_id="1", task_schema_id=1, validator=intolerant_validator(_TaskOutput))
+            chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutput))

    def test_with_tool_calls(self):
        chunk = RunResponse.model_validate_json(
-            '{"id": "1", "task_output": {}, "tool_call_requests": [{"id": "1", "name": "test", "input": {"a": 1}}]}',
+            """{"id": "1", "task_output": {},
+            "tool_call_requests": [{"id": "1", "name": "test", "input": {"a": 1}}], "duration_seconds": 1}""",
        )
        assert chunk

-        parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput))
+        parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutput))
        assert isinstance(parsed, Run)
        assert parsed.tool_call_requests == [ToolCallRequest(id="1", name="test", input={"a": 1})]
19 changes: 6 additions & 13 deletions workflowai/core/client/_utils.py
@@ -13,6 +13,7 @@
from workflowai.core.domain.errors import BaseError, WorkflowAIError
from workflowai.core.domain.task import AgentOutput
from workflowai.core.domain.version_reference import VersionReference
+from workflowai.core.utils._pydantic import partial_model

delimiter = re.compile(r'\}\n\ndata: \{"')

@@ -86,20 +87,12 @@ async def _wait_for_exception(e: WorkflowAIError):
    return _should_retry, _wait_for_exception


-def tolerant_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]:
-    def _validator(data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutput:  # noqa: ARG001
-        return m.model_construct(None, **data)
+def default_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]:
+    partial_cls = partial_model(m)
+
+    def _validator(data: dict[str, Any], partial: bool) -> AgentOutput:
+        model_cls = partial_cls if partial else m
+        return model_cls.model_validate(data)

    return _validator
-
-
-def intolerant_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]:
-    def _validator(data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutput:
-        # When we have tool call requests, the output can be empty
-        if has_tool_call_requests:
-            return m.model_construct(None, **data)
-
-        return m.model_validate(data)
-
-    return _validator

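The new partial_model helper lives in workflowai.core.utils._pydantic and is not shown in this diff. As a rough mental model only: the tests above suggest a model whose missing str fields come back as "", and the Recipe test below implies nested models are handled recursively. A simplified, non-recursive sketch under those assumptions:

from typing import Any, Optional

from pydantic import BaseModel, create_model

# Hypothetical stand-in for workflowai.core.utils._pydantic.partial_model;
# the real helper may differ (e.g. it must also recurse into nested models).
_FALLBACKS: dict[Any, Any] = {str: "", int: 0, float: 0.0, bool: False}


def partial_model_sketch(m: type[BaseModel]) -> type[BaseModel]:
    # Every field becomes optional with a type-appropriate fallback default,
    # so partial payloads validate instead of raising.
    fields: dict[str, Any] = {
        name: (Optional[field.annotation], _FALLBACKS.get(field.annotation))
        for name, field in m.model_fields.items()
    }
    return create_model(f"Partial{m.__name__}", **fields)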
29 changes: 28 additions & 1 deletion workflowai/core/client/_utils_test.py
@@ -2,8 +2,14 @@
from unittest.mock import Mock, patch

import pytest
+from pydantic import BaseModel

-from workflowai.core.client._utils import build_retryable_wait, global_default_version_reference, split_chunks
+from workflowai.core.client._utils import (
+    build_retryable_wait,
+    default_validator,
+    global_default_version_reference,
+    split_chunks,
+)
from workflowai.core.domain.errors import BaseError, WorkflowAIError


@@ -42,3 +48,24 @@ async def test_should_retry_count(self, request_error: WorkflowAIError):
def test_global_default_version_reference(env_var: str, expected: Any):
    with patch.dict("os.environ", {"WORKFLOWAI_DEFAULT_VERSION": env_var}):
        assert global_default_version_reference() == expected


+# Create a nested object with only required properties
+class Recipe(BaseModel):
+    class Ingredient(BaseModel):
+        name: str
+        quantity: int
+
+    ingredients: list[Ingredient]
+
+
+class TestValidator:
+    def test_tolerant_validator_nested_object(self):
+        validated = default_validator(Recipe)(
+            {
+                "ingredients": [{"name": "salt"}],
+            },
+            partial=True,
+        )
+        for ingredient in validated.ingredients:
+            assert isinstance(ingredient, Recipe.Ingredient)
Contributor review comment on the test above:

do you want to check ingredient.name == "salt" ?