Skip to content

Commit fd8a163

Browse files
authored
Merge pull request #71 from WorkflowAI/guillaume/fix-model-construct
Fix model construct
2 parents 0a90daa + b4c5753 commit fd8a163

File tree

16 files changed

+1071
-654
lines changed

16 files changed

+1071
-654
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "workflowai"
3-
version = "0.6.0.dev19"
3+
version = "0.6.0.dev20"
44
description = ""
55
authors = ["Guillaume Aquilina <[email protected]>"]
66
readme = "README.md"

tests/e2e/no_schema_test.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from typing import Optional
22

3-
from pydantic import BaseModel
3+
from pydantic import BaseModel, Field
44

55
import workflowai
6+
from workflowai.core.client.agent import Agent
67

78

89
class SummarizeTaskInput(BaseModel):
@@ -28,3 +29,31 @@ async def test_summarize():
2829
use_cache="never",
2930
)
3031
assert summarized.summary_points
32+
33+
34+
async def test_same_schema():
35+
class InputWithNullableList(BaseModel):
36+
opt_list: Optional[list[str]] = None
37+
38+
class InputWithNonNullableList(BaseModel):
39+
opt_list: list[str] = Field(default_factory=list)
40+
41+
agent1 = Agent(
42+
agent_id="summarize",
43+
input_cls=InputWithNullableList,
44+
output_cls=SummarizeTaskOutput,
45+
api=lambda: workflowai.shared_client.api,
46+
)
47+
48+
schema_id1 = await agent1.register()
49+
50+
agent2 = Agent(
51+
agent_id="summarize",
52+
input_cls=InputWithNonNullableList,
53+
output_cls=SummarizeTaskOutput,
54+
api=lambda: workflowai.shared_client.api,
55+
)
56+
57+
schema_id2 = await agent2.register()
58+
59+
assert schema_id1 == schema_id2

tests/integration/run_test.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,11 @@ def city_to_capital(task_input: CityToCapitalTaskInput) -> AsyncIterator[CityToC
7777
task_input = CityToCapitalTaskInput(city="Hello")
7878
chunks = [chunk async for chunk in city_to_capital(task_input)]
7979

80-
assert chunks == [
81-
CityToCapitalTaskOutput(capital=""),
82-
CityToCapitalTaskOutput(capital="Tok"),
83-
CityToCapitalTaskOutput(capital="Tokyo"),
84-
CityToCapitalTaskOutput(capital="Tokyo"),
80+
assert [chunk.capital for chunk in chunks] == [
81+
"",
82+
"Tok",
83+
"Tokyo",
84+
"Tokyo",
8585
]
8686

8787

@@ -94,11 +94,11 @@ def city_to_capital(task_input: CityToCapitalTaskInput) -> AsyncIterator[CityToC
9494
task_input = CityToCapitalTaskInput(city="Hello")
9595
chunks = [chunk async for chunk in city_to_capital(task_input)]
9696

97-
assert chunks == [
98-
CityToCapitalTaskOutput(capital=""),
99-
CityToCapitalTaskOutput(capital="Tok"),
100-
CityToCapitalTaskOutput(capital="Tokyo"),
101-
CityToCapitalTaskOutput(capital="Tokyo"),
97+
assert [chunk.capital for chunk in chunks] == [
98+
"",
99+
"Tok",
100+
"Tokyo",
101+
"Tokyo",
102102
]
103103

104104

workflowai/core/_common_types.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,18 @@
1919

2020

2121
class OutputValidator(Protocol, Generic[AgentOutputCov]):
22-
def __call__(self, data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutputCov: ...
22+
def __call__(self, data: dict[str, Any], partial: bool) -> AgentOutputCov:
23+
"""A way to convert a json object into an AgentOutput
24+
25+
Args:
26+
data (dict[str, Any]): The json object to convert
27+
partial (bool): Whether the json is partial, meaning that
28+
it may not contain all the fields required by the AgentOutput model.
29+
30+
Returns:
31+
AgentOutputCov: The converted AgentOutput
32+
"""
33+
...
2334

2435

2536
class VersionRunParams(TypedDict):

workflowai/core/client/_api.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
1-
import logging
21
from collections.abc import AsyncIterator
32
from contextlib import asynccontextmanager
43
from typing import Any, Literal, Optional, TypeVar, Union, overload
54

65
import httpx
76
from pydantic import BaseModel, TypeAdapter, ValidationError
87

8+
from workflowai.core._logger import logger
99
from workflowai.core.domain.errors import BaseError, ErrorResponse, WorkflowAIError
1010

1111
# A type for return values
1212
_R = TypeVar("_R")
1313
_M = TypeVar("_M", bound=BaseModel)
1414

15-
_logger = logging.getLogger("WorkflowAI")
16-
1715

1816
class APIClient:
1917
def __init__(self, url: str, api_key: str, source_headers: Optional[dict[str, str]] = None):
@@ -154,7 +152,7 @@ async def _wrap_sse(self, raw: AsyncIterator[bytes], termination_chars: bytes =
154152
in_data = False
155153

156154
if data:
157-
_logger.warning("Data left after processing", extra={"data": data})
155+
logger.warning("Data left after processing", extra={"data": data})
158156

159157
async def stream(
160158
self,

workflowai/core/client/_fn_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
RunParams,
2424
RunTemplate,
2525
)
26-
from workflowai.core.client._utils import intolerant_validator
26+
from workflowai.core.client._utils import default_validator
2727
from workflowai.core.client.agent import Agent
2828
from workflowai.core.domain.errors import InvalidGenerationError
2929
from workflowai.core.domain.model import ModelOrStr
@@ -144,14 +144,15 @@ async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutp
144144
except InvalidGenerationError as e:
145145
if e.partial_output and e.run_id:
146146
try:
147-
validator, _ = self._sanitize_validator(kwargs, intolerant_validator(self.output_cls))
147+
validator, _ = self._sanitize_validator(kwargs, default_validator(self.output_cls))
148148
run = self._build_run_no_tools(
149149
chunk=RunResponse(
150150
id=e.run_id,
151151
task_output=e.partial_output,
152152
),
153153
schema_id=self.schema_id or 0,
154154
validator=validator,
155+
partial=False,
155156
)
156157
run.error = e.error
157158
return run

workflowai/core/client/_models.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,18 @@ def to_domain(
134134
task_id: str,
135135
task_schema_id: int,
136136
validator: OutputValidator[AgentOutput],
137+
partial: Optional[bool] = None,
137138
) -> Run[AgentOutput]:
139+
# When `partial` is not passed explicitly, we default to partial validation if either:
140+
# - there are tool call requests, which means that the output can be empty
141+
# - the run has not yet finished, for example when streaming, in which case the duration_seconds is None
142+
if partial is None:
143+
partial = bool(self.tool_call_requests) or self.duration_seconds is None
138144
return Run(
139145
id=self.id,
140146
agent_id=task_id,
141147
schema_id=task_schema_id,
142-
output=validator(self.task_output, self.tool_call_requests is not None),
148+
output=validator(self.task_output, partial),
143149
version=self.version and self.version.to_domain(),
144150
duration_seconds=self.duration_seconds,
145151
cost_usd=self.cost_usd,

workflowai/core/client/_models_test.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from tests.utils import fixture_text
77
from workflowai.core.client._models import RunResponse
8-
from workflowai.core.client._utils import intolerant_validator, tolerant_validator
8+
from workflowai.core.client._utils import default_validator
99
from workflowai.core.domain.run import Run
1010
from workflowai.core.domain.tool_call import ToolCallRequest
1111

@@ -41,29 +41,30 @@ def test_no_version_not_optional(self):
4141
with pytest.raises(ValidationError): # sanity
4242
_TaskOutput.model_validate({"a": 1})
4343

44-
parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput))
44+
parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutput))
4545
assert isinstance(parsed, Run)
4646
assert parsed.output.a == 1
4747
# b is not provided in the task output
48-
with pytest.raises(AttributeError):
49-
assert parsed.output.b
48+
49+
assert parsed.output.b == ""
5050

5151
def test_no_version_optional(self):
5252
chunk = RunResponse.model_validate_json('{"id": "1", "task_output": {"a": 1}}')
5353
assert chunk
5454

55-
parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutputOpt))
55+
parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutputOpt))
5656
assert isinstance(parsed, Run)
5757
assert parsed.output.a == 1
5858
assert parsed.output.b is None
5959

6060
def test_with_version(self):
61+
"""Full output is validated since the duration is passed and there are no tool calls"""
6162
chunk = RunResponse.model_validate_json(
6263
'{"id": "1", "task_output": {"a": 1, "b": "test"}, "cost_usd": 0.1, "duration_seconds": 1, "version": {"properties": {"a": 1, "b": "test"}}}', # noqa: E501
6364
)
6465
assert chunk
6566

66-
parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput))
67+
parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutput))
6768
assert isinstance(parsed, Run)
6869
assert parsed.output.a == 1
6970
assert parsed.output.b == "test"
@@ -73,17 +74,19 @@ def test_with_version(self):
7374

7475
def test_with_version_validation_fails(self):
7576
chunk = RunResponse.model_validate_json(
76-
'{"id": "1", "task_output": {"a": 1}, "version": {"properties": {"a": 1, "b": "test"}}}',
77+
"""{"id": "1", "task_output": {"a": 1},
78+
"version": {"properties": {"a": 1, "b": "test"}}, "duration_seconds": 1}""",
7779
)
7880
with pytest.raises(ValidationError):
79-
chunk.to_domain(task_id="1", task_schema_id=1, validator=intolerant_validator(_TaskOutput))
81+
chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutput))
8082

8183
def test_with_tool_calls(self):
8284
chunk = RunResponse.model_validate_json(
83-
'{"id": "1", "task_output": {}, "tool_call_requests": [{"id": "1", "name": "test", "input": {"a": 1}}]}',
85+
"""{"id": "1", "task_output": {},
86+
"tool_call_requests": [{"id": "1", "name": "test", "input": {"a": 1}}], "duration_seconds": 1}""",
8487
)
8588
assert chunk
8689

87-
parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput))
90+
parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutput))
8891
assert isinstance(parsed, Run)
8992
assert parsed.tool_call_requests == [ToolCallRequest(id="1", name="test", input={"a": 1})]

workflowai/core/client/_utils.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from workflowai.core.domain.errors import BaseError, WorkflowAIError
1414
from workflowai.core.domain.task import AgentOutput
1515
from workflowai.core.domain.version_reference import VersionReference
16+
from workflowai.core.utils._pydantic import partial_model
1617

1718
delimiter = re.compile(r'\}\n\ndata: \{"')
1819

@@ -86,20 +87,12 @@ async def _wait_for_exception(e: WorkflowAIError):
8687
return _should_retry, _wait_for_exception
8788

8889

89-
def tolerant_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]:
90-
def _validator(data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutput: # noqa: ARG001
91-
return m.model_construct(None, **data)
90+
def default_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]:
91+
partial_cls = partial_model(m)
9292

93-
return _validator
94-
95-
96-
def intolerant_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]:
97-
def _validator(data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutput:
98-
# When we have tool call requests, the output can be empty
99-
if has_tool_call_requests:
100-
return m.model_construct(None, **data)
101-
102-
return m.model_validate(data)
93+
def _validator(data: dict[str, Any], partial: bool) -> AgentOutput:
94+
model_cls = partial_cls if partial else m
95+
return model_cls.model_validate(data)
10396

10497
return _validator
10598

workflowai/core/client/_utils_test.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,14 @@
22
from unittest.mock import Mock, patch
33

44
import pytest
5+
from pydantic import BaseModel
56

6-
from workflowai.core.client._utils import build_retryable_wait, global_default_version_reference, split_chunks
7+
from workflowai.core.client._utils import (
8+
build_retryable_wait,
9+
default_validator,
10+
global_default_version_reference,
11+
split_chunks,
12+
)
713
from workflowai.core.domain.errors import BaseError, WorkflowAIError
814

915

@@ -42,3 +48,24 @@ async def test_should_retry_count(self, request_error: WorkflowAIError):
4248
def test_global_default_version_reference(env_var: str, expected: Any):
4349
with patch.dict("os.environ", {"WORKFLOWAI_DEFAULT_VERSION": env_var}):
4450
assert global_default_version_reference() == expected
51+
52+
53+
# Create a nested object with only required properties
54+
class Recipe(BaseModel):
55+
class Ingredient(BaseModel):
56+
name: str
57+
quantity: int
58+
59+
ingredients: list[Ingredient]
60+
61+
62+
class TestValidator:
63+
def test_tolerant_validator_nested_object(self):
64+
validated = default_validator(Recipe)(
65+
{
66+
"ingredients": [{"name": "salt"}],
67+
},
68+
partial=True,
69+
)
70+
for ingredient in validated.ingredients:
71+
assert isinstance(ingredient, Recipe.Ingredient)

0 commit comments

Comments
 (0)