From ad6ee6d04623d17493937714e945a346994d6efb Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Tue, 4 Mar 2025 15:27:20 -0500 Subject: [PATCH 1/8] test: add test_same_schema --- tests/e2e/no_schema_test.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/tests/e2e/no_schema_test.py b/tests/e2e/no_schema_test.py index 3cc8784..111977f 100644 --- a/tests/e2e/no_schema_test.py +++ b/tests/e2e/no_schema_test.py @@ -1,8 +1,9 @@ from typing import Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field import workflowai +from workflowai.core.client.agent import Agent class SummarizeTaskInput(BaseModel): @@ -28,3 +29,31 @@ async def test_summarize(): use_cache="never", ) assert summarized.summary_points + + +async def test_same_schema(): + class InputWithNullableList(BaseModel): + opt_list: Optional[list[str]] = None + + class InputWithNonNullableList(BaseModel): + opt_list: list[str] = Field(default_factory=list) + + agent1 = Agent( + agent_id="summarize", + input_cls=InputWithNullableList, + output_cls=SummarizeTaskOutput, + api=lambda: workflowai.shared_client.api, + ) + + schema_id1 = await agent1.register() + + agent2 = Agent( + agent_id="summarize", + input_cls=InputWithNonNullableList, + output_cls=SummarizeTaskOutput, + api=lambda: workflowai.shared_client.api, + ) + + schema_id2 = await agent2.register() + + assert schema_id1 == schema_id2 From 55e988158fb44a821f5d023cfb53d42e5c799437 Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Tue, 4 Mar 2025 16:37:07 -0500 Subject: [PATCH 2/8] fix: nested model construct on streams --- pyproject.toml | 2 +- workflowai/core/client/_utils.py | 5 +- workflowai/core/client/_utils_test.py | 29 +++++- workflowai/core/utils/_pydantic.py | 79 +++++++++++++++ workflowai/core/utils/_pydantic_test.py | 123 ++++++++++++++++++++++++ 5 files changed, 234 insertions(+), 4 deletions(-) create mode 100644 workflowai/core/utils/_pydantic.py create mode 100644 workflowai/core/utils/_pydantic_test.py diff --git a/pyproject.toml b/pyproject.toml index 37b210e..dfd99ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "workflowai" -version = "0.6.0.dev19" +version = "0.6.0.dev20" description = "" authors = ["Guillaume Aquilina "] readme = "README.md" diff --git a/workflowai/core/client/_utils.py b/workflowai/core/client/_utils.py index c846ef8..12bcb61 100644 --- a/workflowai/core/client/_utils.py +++ b/workflowai/core/client/_utils.py @@ -13,6 +13,7 @@ from workflowai.core.domain.errors import BaseError, WorkflowAIError from workflowai.core.domain.task import AgentOutput from workflowai.core.domain.version_reference import VersionReference +from workflowai.core.utils._pydantic import construct_model_recursive delimiter = re.compile(r'\}\n\ndata: \{"') @@ -88,7 +89,7 @@ async def _wait_for_exception(e: WorkflowAIError): def tolerant_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]: def _validator(data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutput: # noqa: ARG001 - return m.model_construct(None, **data) + return construct_model_recursive(m, data) return _validator @@ -97,7 +98,7 @@ def intolerant_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]: def _validator(data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutput: # When we have tool call requests, the output can be empty if has_tool_call_requests: - return m.model_construct(None, **data) + return tolerant_validator(m)(data, 
has_tool_call_requests) return m.model_validate(data) diff --git a/workflowai/core/client/_utils_test.py b/workflowai/core/client/_utils_test.py index 350584c..dcf4f79 100644 --- a/workflowai/core/client/_utils_test.py +++ b/workflowai/core/client/_utils_test.py @@ -2,8 +2,14 @@ from unittest.mock import Mock, patch import pytest +from pydantic import BaseModel -from workflowai.core.client._utils import build_retryable_wait, global_default_version_reference, split_chunks +from workflowai.core.client._utils import ( + build_retryable_wait, + global_default_version_reference, + split_chunks, + tolerant_validator, +) from workflowai.core.domain.errors import BaseError, WorkflowAIError @@ -42,3 +48,24 @@ async def test_should_retry_count(self, request_error: WorkflowAIError): def test_global_default_version_reference(env_var: str, expected: Any): with patch.dict("os.environ", {"WORKFLOWAI_DEFAULT_VERSION": env_var}): assert global_default_version_reference() == expected + + +# Create a nested object with only required properties +class Recipe(BaseModel): + class Ingredient(BaseModel): + name: str + quantity: int + + ingredients: list[Ingredient] + + +class TestTolerantValidator: + def test_tolerant_validator_nested_object(self): + validated = tolerant_validator(Recipe)( + { + "ingredients": [{"name": "salt"}], + }, + has_tool_call_requests=False, + ) + for ingredient in validated.ingredients: + assert isinstance(ingredient, Recipe.Ingredient) diff --git a/workflowai/core/utils/_pydantic.py b/workflowai/core/utils/_pydantic.py new file mode 100644 index 0000000..c57d6b5 --- /dev/null +++ b/workflowai/core/utils/_pydantic.py @@ -0,0 +1,79 @@ +from collections.abc import Collection, Mapping +from typing import Any, Optional, get_args, get_origin + +from pydantic import BaseModel + +from workflowai.core.utils._vars import BM + + +def _safe_issubclass(cls: type[Any], other: type[Any]) -> bool: + try: + return issubclass(cls, other) + except TypeError: + return False + + +def _construct_list(annotation: type[Any], payload: list[Any]) -> Collection[Any]: + origin = get_origin(annotation) + # If the type annotation is not a collection, the we can't handle it so we return as is + if not origin or not _safe_issubclass(origin, Collection): + return payload + + args = get_args(annotation) + if not args: + return payload + + constructor = set if _safe_issubclass(origin, set) else list + + if len(args) == 1 and _safe_issubclass(args[0], BaseModel): + return constructor([construct_model_recursive(args[0], item) for item in payload]) # pyright: ignore [reportUnknownVariableType, reportCallIssue] + + return payload + + +def _construct_object(annotation: type[Any], payload: dict[str, Any]) -> Any: + if _safe_issubclass(annotation, BaseModel): + return construct_model_recursive(annotation, payload) # pyright: ignore [reportUnknownVariableType, reportArgumentType] + + # Try to map dict of objects + origin = get_origin(annotation) + if not origin or not _safe_issubclass(origin, Mapping): + return payload + + args = get_args(annotation) + if len(args) != 2: + return payload + + key_type, value_type = args + if key_type is not str: + return payload + return {k: _construct_for_annotation(value_type, v) for k, v in payload.items()} + + +def _construct_for_annotation(annotation: Optional[type[Any]], payload: Any) -> Any: + if annotation is None: + return payload + + if isinstance(payload, dict): + return _construct_object(annotation, payload) # pyright: ignore [reportUnknownArgumentType] + if isinstance(payload, list): 
+ return _construct_list(annotation, payload) # pyright: ignore [reportUnknownArgumentType] + + return payload + + +# It does not look like there is an easy way to construct models from partial json objects +# - https://github.com/team23/pydantic-partial uses a heavy approach by constructing a new dynamic +# model class with non required fields +# - partial validation https://docs.pydantic.dev/latest/concepts/experimental/#partial-validation +# handles partial jsons but still validates that each field is present so it fails in our case +# where we just want to handle missing fields +def construct_model_recursive(model: type[BM], payload: dict[str, Any]) -> BM: + """ + Recursively calls model construct to build a model from partial json object + """ + mapped: dict[str, Any] = {} + for k, v in payload.items(): + field = model.model_fields[k] + mapped[k] = _construct_for_annotation(field.annotation, v) + return model.model_construct(None, **mapped) diff --git a/workflowai/core/utils/_pydantic_test.py b/workflowai/core/utils/_pydantic_test.py new file mode 100644 index 0000000..dcb78e3 --- /dev/null +++ b/workflowai/core/utils/_pydantic_test.py @@ -0,0 +1,123 @@ +from typing import Any + +import pytest +from pydantic import BaseModel + +from workflowai.core.utils._pydantic import construct_model_recursive + + +class TestConstructModelRecursive: + def test_simple_model(self): + class SimpleModel(BaseModel): + name1: str + name2: str + + constructed = construct_model_recursive(SimpleModel, {"name1": "John"}) + assert isinstance(constructed, SimpleModel) + assert constructed.name1 == "John" + with pytest.raises(AttributeError): + _ = constructed.name2 + + def test_list_of_strings(self): + class ListOfStrings(BaseModel): + strings: list[str] + + constructed = construct_model_recursive(ListOfStrings, {"strings": ["a", "b"]}) + assert isinstance(constructed, ListOfStrings) + assert constructed.strings == ["a", "b"] + + @pytest.mark.parametrize( + "payload", + [ + {}, + {"field1": "hello"}, + {"nested": {"name": "hello", "field2": "world"}}, + {"nested": {"name": "hello"}}, + ], + ) + def test_nested_model(self, payload: dict[str, Any]): + class NestedModel(BaseModel): + name: str + field2: str + + class OuterModel(BaseModel): + field1: str + nested: NestedModel + + constructed = construct_model_recursive(OuterModel, payload) + assert isinstance(constructed, OuterModel), "constructed is not an instance of OuterModel" + if "field1" in payload: + assert isinstance(constructed.field1, str), "field1 is not a string" + else: + with pytest.raises(AttributeError): + _ = constructed.field1 + + if "nested" in payload: + assert isinstance(constructed.nested, NestedModel), "nested is not an instance of NestedModel" + if "name" in payload["nested"]: + assert isinstance(constructed.nested.name, str) + else: + with pytest.raises(AttributeError): + _ = constructed.nested.name + + if "field2" in payload["nested"]: + assert isinstance(constructed.nested.field2, str) + else: + with pytest.raises(AttributeError): + _ = constructed.nested.field2 + else: + with pytest.raises(AttributeError): + _ = constructed.nested + + def test_list_of_models(self): + class NestedModel(BaseModel): + name: str + field2: str + + class ListOfModels(BaseModel): + models: list[NestedModel] + + constructed = construct_model_recursive( + ListOfModels, + {"models": [{"name": "hello", "field2": "world"}, {"name": "hello"}]}, + ) + assert isinstance(constructed, ListOfModels) + assert isinstance(constructed.models, list) + assert 
isinstance(constructed.models[0], NestedModel) + assert isinstance(constructed.models[1], NestedModel) + + def test_set_of_models(self): + class NestedModel(BaseModel): + name: str = "1" + field2: str + + def __hash__(self) -> int: + return hash(self.name) + + class SetOfModels(BaseModel): + models: set[NestedModel] + + constructed = construct_model_recursive( + SetOfModels, + {"models": [{"name": "hello", "field2": "world"}, {"name": "hello"}]}, + ) + assert isinstance(constructed, SetOfModels) + assert isinstance(constructed.models, set) + assert all(isinstance(model, NestedModel) for model in constructed.models) + + def test_dict_of_models(self): + class NestedModel(BaseModel): + name: str + field2: str + + class DictOfModels(BaseModel): + models: dict[str, NestedModel] + + constructed = construct_model_recursive( + DictOfModels, + {"models": {"hello": {"name": "hello", "field2": "world"}, "hello2": {"name": "hello"}}}, + ) + assert isinstance(constructed, DictOfModels) + assert isinstance(constructed.models, dict) + assert isinstance(constructed.models["hello"], NestedModel) + assert isinstance(constructed.models["hello2"], NestedModel) From d0ea9e78578841757becfb1513e81546f6a77fb5 Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Tue, 4 Mar 2025 21:08:42 -0500 Subject: [PATCH 3/8] fix: issue with streaming partial nested objects --- workflowai/core/_common_types.py | 13 ++++++++++++- workflowai/core/client/_fn_utils.py | 5 +++-- workflowai/core/client/_models.py | 8 +++++++- workflowai/core/client/_models_test.py | 19 +++++++++++-------- workflowai/core/client/_utils.py | 19 ++++++++----------- workflowai/core/client/_utils_test.py | 8 ++++---- workflowai/core/client/agent.py | 14 +++++++------- 7 files changed, 52 insertions(+), 34 deletions(-) diff --git a/workflowai/core/_common_types.py b/workflowai/core/_common_types.py index ad090c4..ea49c85 100644 --- a/workflowai/core/_common_types.py +++ b/workflowai/core/_common_types.py @@ -19,7 +19,18 @@ class OutputValidator(Protocol, Generic[AgentOutputCov]): - def __call__(self, data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutputCov: ... + def __call__(self, data: dict[str, Any], partial: bool) -> AgentOutputCov: + """A way to convert a json object into an AgentOutput + + Args: + data (dict[str, Any]): The json object to convert + partial (bool): Whether the json is partial, meaning that + it may not contain all the fields required by the AgentOutput model. + + Returns: + AgentOutputCov: The converted AgentOutput + """ + ... 
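For illustration, a callable satisfying this revised protocol could look like the sketch below. It is not part of the patch: `ExampleOutput` and the lenient `model_construct` fallback are assumptions used only to show how the new `partial` flag might be handled by a custom validator.

    from typing import Any

    from pydantic import BaseModel


    class ExampleOutput(BaseModel):
        title: str = ""
        body: str = ""


    def example_validator(data: dict[str, Any], partial: bool) -> ExampleOutput:
        # Partial payloads (e.g. intermediate stream chunks, or runs interrupted
        # by tool calls) may be missing required fields, so skip strict validation.
        if partial:
            return ExampleOutput.model_construct(**data)
        # Final payloads are validated strictly.
        return ExampleOutput.model_validate(data)

Any callable with this shape can be passed wherever an OutputValidator is expected, since the protocol only constrains the call signature.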
class VersionRunParams(TypedDict): diff --git a/workflowai/core/client/_fn_utils.py b/workflowai/core/client/_fn_utils.py index 97495f9..3930669 100644 --- a/workflowai/core/client/_fn_utils.py +++ b/workflowai/core/client/_fn_utils.py @@ -23,7 +23,7 @@ RunParams, RunTemplate, ) -from workflowai.core.client._utils import intolerant_validator +from workflowai.core.client._utils import default_validator from workflowai.core.client.agent import Agent from workflowai.core.domain.errors import InvalidGenerationError from workflowai.core.domain.model import ModelOrStr @@ -144,7 +144,7 @@ async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutp except InvalidGenerationError as e: if e.partial_output and e.run_id: try: - validator, _ = self._sanitize_validator(kwargs, intolerant_validator(self.output_cls)) + validator, _ = self._sanitize_validator(kwargs, default_validator(self.output_cls)) run = self._build_run_no_tools( chunk=RunResponse( id=e.run_id, @@ -152,6 +152,7 @@ async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutp ), schema_id=self.schema_id or 0, validator=validator, + partial=False, ) run.error = e.error return run diff --git a/workflowai/core/client/_models.py b/workflowai/core/client/_models.py index 8cf1bf1..5b61628 100644 --- a/workflowai/core/client/_models.py +++ b/workflowai/core/client/_models.py @@ -134,12 +134,18 @@ def to_domain( task_id: str, task_schema_id: int, validator: OutputValidator[AgentOutput], + partial: Optional[bool] = None, ) -> Run[AgentOutput]: + # We do partial validation if either: + # - there are tool call requests, which means that the output can be empty + # - the run has not yet finished, for exmaple when streaming, in which case the duration_seconds is None + if partial is None: + partial = bool(self.tool_call_requests) or self.duration_seconds is None return Run( id=self.id, agent_id=task_id, schema_id=task_schema_id, - output=validator(self.task_output, self.tool_call_requests is not None), + output=validator(self.task_output, partial), version=self.version and self.version.to_domain(), duration_seconds=self.duration_seconds, cost_usd=self.cost_usd, diff --git a/workflowai/core/client/_models_test.py b/workflowai/core/client/_models_test.py index e12bdbf..2363fa7 100644 --- a/workflowai/core/client/_models_test.py +++ b/workflowai/core/client/_models_test.py @@ -5,7 +5,7 @@ from tests.utils import fixture_text from workflowai.core.client._models import RunResponse -from workflowai.core.client._utils import intolerant_validator, tolerant_validator +from workflowai.core.client._utils import default_validator from workflowai.core.domain.run import Run from workflowai.core.domain.tool_call import ToolCallRequest @@ -41,7 +41,7 @@ def test_no_version_not_optional(self): with pytest.raises(ValidationError): # sanity _TaskOutput.model_validate({"a": 1}) - parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput)) + parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutput)) assert isinstance(parsed, Run) assert parsed.output.a == 1 # b is not defined @@ -52,18 +52,19 @@ def test_no_version_optional(self): chunk = RunResponse.model_validate_json('{"id": "1", "task_output": {"a": 1}}') assert chunk - parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutputOpt)) + parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutputOpt)) assert isinstance(parsed, 
Run) assert parsed.output.a == 1 assert parsed.output.b is None def test_with_version(self): + """Full output is validated since the duration is passed and there are no tool calls""" chunk = RunResponse.model_validate_json( '{"id": "1", "task_output": {"a": 1, "b": "test"}, "cost_usd": 0.1, "duration_seconds": 1, "version": {"properties": {"a": 1, "b": "test"}}}', # noqa: E501 ) assert chunk - parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput)) + parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutput)) assert isinstance(parsed, Run) assert parsed.output.a == 1 assert parsed.output.b == "test" @@ -73,17 +74,19 @@ def test_with_version(self): def test_with_version_validation_fails(self): chunk = RunResponse.model_validate_json( - '{"id": "1", "task_output": {"a": 1}, "version": {"properties": {"a": 1, "b": "test"}}}', + """{"id": "1", "task_output": {"a": 1}, + "version": {"properties": {"a": 1, "b": "test"}}, "duration_seconds": 1}""", ) with pytest.raises(ValidationError): - chunk.to_domain(task_id="1", task_schema_id=1, validator=intolerant_validator(_TaskOutput)) + chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutput)) def test_with_tool_calls(self): chunk = RunResponse.model_validate_json( - '{"id": "1", "task_output": {}, "tool_call_requests": [{"id": "1", "name": "test", "input": {"a": 1}}]}', + """{"id": "1", "task_output": {}, + "tool_call_requests": [{"id": "1", "name": "test", "input": {"a": 1}}], "duration_seconds": 1}""", ) assert chunk - parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput)) + parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutput)) assert isinstance(parsed, Run) assert parsed.tool_call_requests == [ToolCallRequest(id="1", name="test", input={"a": 1})] diff --git a/workflowai/core/client/_utils.py b/workflowai/core/client/_utils.py index 12bcb61..8390727 100644 --- a/workflowai/core/client/_utils.py +++ b/workflowai/core/client/_utils.py @@ -87,18 +87,15 @@ async def _wait_for_exception(e: WorkflowAIError): return _should_retry, _wait_for_exception -def tolerant_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]: - def _validator(data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutput: # noqa: ARG001 - return construct_model_recursive(m, data) - - return _validator - - -def intolerant_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]: - def _validator(data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutput: +def default_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]: + def _validator(data: dict[str, Any], partial: bool) -> AgentOutput: # When we have tool call requests, the output can be empty - if has_tool_call_requests: - return tolerant_validator(m)(data, has_tool_call_requests) + if partial: + try: + return construct_model_recursive(m, data) + except Exception: # noqa: BLE001 + logger.warning("Failed to validate partial data: %s", data) + return m.model_construct(None, **data) return m.model_validate(data) diff --git a/workflowai/core/client/_utils_test.py b/workflowai/core/client/_utils_test.py index dcf4f79..3c76bd3 100644 --- a/workflowai/core/client/_utils_test.py +++ b/workflowai/core/client/_utils_test.py @@ -6,9 +6,9 @@ from workflowai.core.client._utils import ( build_retryable_wait, + default_validator, global_default_version_reference, split_chunks, - tolerant_validator, ) from 
workflowai.core.domain.errors import BaseError, WorkflowAIError @@ -59,13 +59,13 @@ class Ingredient(BaseModel): ingredients: list[Ingredient] -class TestTolerantValidator: +class TestValidator: def test_tolerant_validator_nested_object(self): - validated = tolerant_validator(Recipe)( + validated = default_validator(Recipe)( { "ingredients": [{"name": "salt"}], }, - has_tool_call_requests=False, + partial=True, ) for ingredient in validated.ingredients: assert isinstance(ingredient, Recipe.Ingredient) diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index dcfefb2..ef96aa6 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -22,9 +22,8 @@ from workflowai.core.client._types import RunParams from workflowai.core.client._utils import ( build_retryable_wait, + default_validator, global_default_version_reference, - intolerant_validator, - tolerant_validator, ) from workflowai.core.domain.completion import Completion from workflowai.core.domain.errors import BaseError, WorkflowAIError @@ -295,8 +294,9 @@ def _build_run_no_tools( chunk: RunResponse, schema_id: int, validator: OutputValidator[AgentOutput], + partial: Optional[bool] = None, ) -> Run[AgentOutput]: - run = chunk.to_domain(self.agent_id, schema_id, validator) + run = chunk.to_domain(self.agent_id, schema_id, validator, partial) run._agent = self # pyright: ignore [reportPrivateUsage] return run @@ -362,7 +362,7 @@ async def run( Run[AgentOutput]: The task run object. """ prepared_run = await self._prepare_run(agent_input, stream=False, **kwargs) - validator, new_kwargs = self._sanitize_validator(kwargs, intolerant_validator(self.output_cls)) + validator, new_kwargs = self._sanitize_validator(kwargs, default_validator(self.output_cls)) last_error = None while prepared_run.should_retry(): @@ -374,7 +374,6 @@ async def run( validator, current_iteration=0, # TODO[test]: add test with custom validator - # We popped validator above **new_kwargs, ) except WorkflowAIError as e: # noqa: PERF203 @@ -419,10 +418,11 @@ async def stream( AsyncIterator[Run[AgentOutput]]: An async iterator yielding task run objects. 
""" prepared_run = await self._prepare_run(agent_input, stream=True, **kwargs) - validator, new_kwargs = self._sanitize_validator(kwargs, tolerant_validator(self.output_cls)) + validator, new_kwargs = self._sanitize_validator(kwargs, default_validator(self.output_cls)) while prepared_run.should_retry(): try: + chunk: Optional[RunResponse] = None async for chunk in self.api.stream( method="POST", path=prepared_run.route, @@ -462,7 +462,7 @@ async def reply( """ prepared_run = await self._prepare_reply(run_id, user_message, tool_results, stream=False, **kwargs) - validator, new_kwargs = self._sanitize_validator(kwargs, intolerant_validator(self.output_cls)) + validator, new_kwargs = self._sanitize_validator(kwargs, default_validator(self.output_cls)) res = await self.api.post(prepared_run.route, prepared_run.request, returns=RunResponse, run=True) return await self._build_run( From d794f3a7dfc7adec78eafa17e8169c2041521a6d Mon Sep 17 00:00:00 2001 From: guillaq Date: Wed, 5 Mar 2025 08:46:05 -0500 Subject: [PATCH 4/8] Fix typo in _models Co-authored-by: Yann BURY --- workflowai/core/client/_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/workflowai/core/client/_models.py b/workflowai/core/client/_models.py index 5b61628..ac888ba 100644 --- a/workflowai/core/client/_models.py +++ b/workflowai/core/client/_models.py @@ -138,7 +138,7 @@ def to_domain( ) -> Run[AgentOutput]: # We do partial validation if either: # - there are tool call requests, which means that the output can be empty - # - the run has not yet finished, for exmaple when streaming, in which case the duration_seconds is None + # - the run has not yet finished, for example when streaming, in which case the duration_seconds is None if partial is None: partial = bool(self.tool_call_requests) or self.duration_seconds is None return Run( From 27d907098d40583817095776b83a4a0a03a7e57d Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Wed, 5 Mar 2025 10:33:40 -0500 Subject: [PATCH 5/8] feat: use full partial model validation --- tests/integration/run_test.py | 20 ++-- workflowai/core/client/_api.py | 6 +- workflowai/core/client/_models_test.py | 4 +- workflowai/core/client/_utils.py | 15 +-- workflowai/core/client/agent.py | 45 ++++++-- workflowai/core/client/agent_test.py | 4 +- workflowai/core/utils/_pydantic.py | 143 ++++++++++++++---------- workflowai/core/utils/_pydantic_test.py | 106 ++++++++++++------ 8 files changed, 206 insertions(+), 137 deletions(-) diff --git a/tests/integration/run_test.py b/tests/integration/run_test.py index 73d2fe0..f0944c9 100644 --- a/tests/integration/run_test.py +++ b/tests/integration/run_test.py @@ -77,11 +77,11 @@ def city_to_capital(task_input: CityToCapitalTaskInput) -> AsyncIterator[CityToC task_input = CityToCapitalTaskInput(city="Hello") chunks = [chunk async for chunk in city_to_capital(task_input)] - assert chunks == [ - CityToCapitalTaskOutput(capital=""), - CityToCapitalTaskOutput(capital="Tok"), - CityToCapitalTaskOutput(capital="Tokyo"), - CityToCapitalTaskOutput(capital="Tokyo"), + assert [chunk.capital for chunk in chunks] == [ + "", + "Tok", + "Tokyo", + "Tokyo", ] @@ -94,11 +94,11 @@ def city_to_capital(task_input: CityToCapitalTaskInput) -> AsyncIterator[CityToC task_input = CityToCapitalTaskInput(city="Hello") chunks = [chunk async for chunk in city_to_capital(task_input)] - assert chunks == [ - CityToCapitalTaskOutput(capital=""), - CityToCapitalTaskOutput(capital="Tok"), - CityToCapitalTaskOutput(capital="Tokyo"), - 
CityToCapitalTaskOutput(capital="Tokyo"), + assert [chunk.capital for chunk in chunks] == [ + "", + "Tok", + "Tokyo", + "Tokyo", ] diff --git a/workflowai/core/client/_api.py b/workflowai/core/client/_api.py index 5f09641..8228ab5 100644 --- a/workflowai/core/client/_api.py +++ b/workflowai/core/client/_api.py @@ -1,4 +1,3 @@ -import logging from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import Any, Literal, Optional, TypeVar, Union, overload @@ -6,14 +5,13 @@ import httpx from pydantic import BaseModel, TypeAdapter, ValidationError +from workflowai.core._logger import logger from workflowai.core.domain.errors import BaseError, ErrorResponse, WorkflowAIError # A type for return values _R = TypeVar("_R") _M = TypeVar("_M", bound=BaseModel) -_logger = logging.getLogger("WorkflowAI") - class APIClient: def __init__(self, url: str, api_key: str, source_headers: Optional[dict[str, str]] = None): @@ -154,7 +152,7 @@ async def _wrap_sse(self, raw: AsyncIterator[bytes], termination_chars: bytes = in_data = False if data: - _logger.warning("Data left after processing", extra={"data": data}) + logger.warning("Data left after processing", extra={"data": data}) async def stream( self, diff --git a/workflowai/core/client/_models_test.py b/workflowai/core/client/_models_test.py index 2363fa7..661459c 100644 --- a/workflowai/core/client/_models_test.py +++ b/workflowai/core/client/_models_test.py @@ -45,8 +45,8 @@ def test_no_version_not_optional(self): assert isinstance(parsed, Run) assert parsed.output.a == 1 # b is not defined - with pytest.raises(AttributeError): - assert parsed.output.b + + assert parsed.output.b == "" def test_no_version_optional(self): chunk = RunResponse.model_validate_json('{"id": "1", "task_output": {"a": 1}}') diff --git a/workflowai/core/client/_utils.py b/workflowai/core/client/_utils.py index 8390727..0894082 100644 --- a/workflowai/core/client/_utils.py +++ b/workflowai/core/client/_utils.py @@ -13,7 +13,7 @@ from workflowai.core.domain.errors import BaseError, WorkflowAIError from workflowai.core.domain.task import AgentOutput from workflowai.core.domain.version_reference import VersionReference -from workflowai.core.utils._pydantic import construct_model_recursive +from workflowai.core.utils._pydantic import partial_model delimiter = re.compile(r'\}\n\ndata: \{"') @@ -88,16 +88,11 @@ async def _wait_for_exception(e: WorkflowAIError): def default_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]: - def _validator(data: dict[str, Any], partial: bool) -> AgentOutput: - # When we have tool call requests, the output can be empty - if partial: - try: - return construct_model_recursive(m, data) - except Exception: # noqa: BLE001 - logger.warning("Failed to validate partial data: %s", data) - return m.model_construct(None, **data) + partial_cls = partial_model(m) - return m.model_validate(data) + def _validator(data: dict[str, Any], partial: bool) -> AgentOutput: + model_cls = partial_cls if partial else m + return model_cls.model_validate(data) return _validator diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index ef96aa6..f6d5db1 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -3,7 +3,7 @@ from collections.abc import Awaitable, Callable, Iterable from typing import Any, Generic, NamedTuple, Optional, Union, cast -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from typing_extensions import Unpack from 
workflowai.core._common_types import BaseRunParams, OutputValidator, VersionRunParams @@ -102,6 +102,8 @@ def __init__( self._api = (lambda: api) if isinstance(api, APIClient) else api self._tools = self.build_tools(tools) if tools else None + self._default_validator = default_validator(output_cls) + @classmethod def build_tools(cls, tools: Iterable[Callable[..., Any]]): # TODO: we should be more tolerant with errors ? @@ -362,7 +364,7 @@ async def run( Run[AgentOutput]: The task run object. """ prepared_run = await self._prepare_run(agent_input, stream=False, **kwargs) - validator, new_kwargs = self._sanitize_validator(kwargs, default_validator(self.output_cls)) + validator, new_kwargs = self._sanitize_validator(kwargs, self._default_validator) last_error = None while prepared_run.should_retry(): @@ -418,10 +420,12 @@ async def stream( AsyncIterator[Run[AgentOutput]]: An async iterator yielding task run objects. """ prepared_run = await self._prepare_run(agent_input, stream=True, **kwargs) - validator, new_kwargs = self._sanitize_validator(kwargs, default_validator(self.output_cls)) + validator, new_kwargs = self._sanitize_validator(kwargs, self._default_validator) while prepared_run.should_retry(): try: + # Will store the error if the final payload fails to validate + final_error: Optional[Exception] = None chunk: Optional[RunResponse] = None async for chunk in self.api.stream( method="POST", @@ -430,15 +434,32 @@ async def stream( returns=RunResponse, run=True, ): - yield await self._build_run( - chunk, - prepared_run.schema_id, - validator, - current_iteration=0, - **new_kwargs, - ) + try: + yield await self._build_run( + chunk, + prepared_run.schema_id, + validator, + current_iteration=0, + **new_kwargs, + ) + final_error = None + except ValidationError as e: + logger.debug("Validation error in stream", exc_info=e) + final_error = e + continue + if final_error: + raise WorkflowAIError( + error=BaseError( + message="Client side validation error in stream. This should not " + "happen is the payload is already validated by the server. 
This probably" + "means that there is an issue with the validator or the model.", + ), + response=None, + partial_output=chunk.task_output if chunk else None, + run_id=chunk.id if chunk else None, + ) from final_error return - except WorkflowAIError as e: # noqa: PERF203 + except WorkflowAIError as e: await prepared_run.wait_for_exception(e) async def reply( @@ -462,7 +483,7 @@ async def reply( """ prepared_run = await self._prepare_reply(run_id, user_message, tool_results, stream=False, **kwargs) - validator, new_kwargs = self._sanitize_validator(kwargs, default_validator(self.output_cls)) + validator, new_kwargs = self._sanitize_validator(kwargs, self._default_validator) res = await self.api.post(prepared_run.route, prepared_run.request, returns=RunResponse, run=True) return await self._build_run( diff --git a/workflowai/core/client/agent_test.py b/workflowai/core/client/agent_test.py index 8b4f4ab..864c6d2 100644 --- a/workflowai/core/client/agent_test.py +++ b/workflowai/core/client/agent_test.py @@ -171,9 +171,7 @@ async def test_stream_not_optional( assert messages == ["", "hel", "hello", "hello"] for chunk in chunks[:-1]: - with pytest.raises(AttributeError): - # Since the field is not optional, it will raise an attribute error - assert chunk.output.another_field + assert chunk.output.another_field == "" assert chunks[-1].output.another_field == "test" last_message = chunks[-1] diff --git a/workflowai/core/utils/_pydantic.py b/workflowai/core/utils/_pydantic.py index c57d6b5..f8c733f 100644 --- a/workflowai/core/utils/_pydantic.py +++ b/workflowai/core/utils/_pydantic.py @@ -1,79 +1,104 @@ -from collections.abc import Collection, Mapping -from typing import Any, Optional, get_args, get_origin +from collections.abc import Mapping, Sequence +from typing import Any, TypeVar, get_args, get_origin -from pydantic import BaseModel +from pydantic import BaseModel, Field, create_model +from pydantic.fields import FieldInfo +from pydantic_core import PydanticUndefined +from typing_extensions import TypeGuard +from workflowai.core._logger import logger from workflowai.core.utils._vars import BM +_T = TypeVar("_T") -def _safe_issubclass(cls: type[Any], other: type[Any]) -> bool: + +def _safe_issubclass(val: type[_T], cls: type[_T]) -> TypeGuard[type[_T]]: try: - return issubclass(cls, other) + return issubclass(val, cls) except TypeError: return False -def _construct_list(annotation: type[Any], payload: list[Any]) -> Collection[Any]: - origin = get_origin(annotation) - # If the type annotation is not a collection, the we can't handle it so we return as is - if not origin or not _safe_issubclass(origin, Collection): - return payload +def _copy_field_info(field_info: FieldInfo, **overrides: Any): + """ + Return a copy of a pydantic FieldInfo object, allow to override + certain values. 
+ """ - args = get_args(annotation) - if not args: - return payload + kwargs = overrides + for k, v in field_info.__repr_args__(): + if k in kwargs or not k: + continue + kwargs[k] = v - constructor = set if _safe_issubclass(origin, set) else list + return Field(**kwargs) - if len(args) == 1 and _safe_issubclass(args[0], BaseModel): - return constructor([construct_model_recursive(args[0], item) for item in payload]) # pyright: ignore [reportUnknownVariableType, reportCallIssue] - return payload +def _default_value_from_annotation(annotation: type[Any]) -> Any: + try: + # Trying to see if the object is instantiable with no value + return annotation() + except Exception: # noqa: BLE001 + logger.debug("Failed to get default value from annotation", exc_info=True) + # Fallback to None + return None -def _construct_object(annotation: type[Any], payload: dict[str, Any]) -> Any: +def _optional_annotation(annotation: type[Any]) -> type[Any]: if _safe_issubclass(annotation, BaseModel): - return construct_model_recursive(annotation, payload) # pyright: ignore [reportUnknownVariableType, reportArgumentType] + return partial_model(annotation) - # Try to map dict of objects origin = get_origin(annotation) - if not origin or not _safe_issubclass(origin, Mapping): - return payload - args = get_args(annotation) - if len(args) != 2: - return payload - - key_type, value_type = args - if key_type is not str: - return payload - return {k: _construct_for_annotation(value_type, v) for k, v in payload.items()} - - -def _construct_for_annotation(annotation: Optional[type[Any]], payload: Any) -> Any: - if annotation is None: - return payload - - if isinstance(payload, dict): - return _construct_object(annotation, payload) # pyright: ignore [reportUnknownArgumentType] - if isinstance(payload, list): - return _construct_list(annotation, payload) # pyright: ignore [reportUnknownArgumentType] - - return payload - - -# It does not look like there is an easy way to construct models from partial json objects -# - https://github.com/team23/pydantic-partial uses a heavy approach by constructing a new dynamic -# model class with non required fields -# - partial validation https://docs.pydantic.dev/latest/concepts/experimental/#partial-validation -# handles partial jsons but still validates that each field is present so it fails in our case -# where we just want to handle missing fields -def construct_model_recursive(model: type[BM], payload: dict[str, Any]) -> BM: - """ - Recursively calls model construct to build a model from partial json object - """ - mapped: dict[str, Any] = {} - for k, v in payload.items(): - field = model.model_fields[k] - mapped[k] = _construct_for_annotation(field.annotation, v) - return model.model_construct(None, **mapped) + if not origin or not args: + return annotation + + if _safe_issubclass(origin, Sequence) or _safe_issubclass(origin, set): + if not len(args) == 1: + raise ValueError("Sequence must have exactly one argument") + return origin[_optional_annotation(args[0])] + # No need to do anything here ? 
+ + if _safe_issubclass(origin, Mapping): + if not len(args) == 2: + raise ValueError("Mapping must have exactly two arguments") + if args[0] is not str: + raise ValueError("Mapping key must be a string") + + return origin[args[0], _optional_annotation(args[1])] + return annotation + + +def partial_model(base: type[BM]) -> type[BM]: + default_fields: dict[str, tuple[type[Any], FieldInfo]] = {} + for name, field in base.model_fields.items(): + if field.default != PydanticUndefined or field.default_factory or not field.annotation: + # No need to do anything here, the field is already optional + continue + + overrides: dict[str, Any] = {} + try: + annotation = _optional_annotation(field.annotation) + overrides["annotation"] = annotation + overrides["default"] = _default_value_from_annotation(annotation) + except Exception: # noqa: BLE001 + logger.debug("Failed to make annotation optional", exc_info=True) + continue + + default_fields[name] = (annotation, _copy_field_info(field, **overrides)) + + if not default_fields: + return base + + def custom_eq(o1: BM, o2: Any): + if not isinstance(o2, base): + return False + return o1.model_dump() == o2.model_dump() + + return create_model( # pyright: ignore [reportCallIssue, reportUnknownVariableType] + f"Partial{base.__name__}", + __base__=base, + __eq__=custom_eq, + __hash__=base.__hash__, + **default_fields, # pyright: ignore [reportArgumentType] + ) diff --git a/workflowai/core/utils/_pydantic_test.py b/workflowai/core/utils/_pydantic_test.py index dcb78e3..858a607 100644 --- a/workflowai/core/utils/_pydantic_test.py +++ b/workflowai/core/utils/_pydantic_test.py @@ -1,31 +1,69 @@ -from typing import Any +from typing import Any, Optional import pytest -from pydantic import BaseModel +from pydantic import BaseModel, Field, ValidationError -from workflowai.core.utils._pydantic import construct_model_recursive +from workflowai.core.utils._pydantic import partial_model -class TestConstructModelRecursive: +class TestPartialModel: + def test_partial_model_equals(self): + class SimpleModel(BaseModel): + name: str + + partial = partial_model(SimpleModel) + assert partial.model_validate({"name": "John"}) == SimpleModel(name="John") + + assert SimpleModel(name="John") == partial.model_validate({"name": "John"}) + def test_simple_model(self): class SimpleModel(BaseModel): name1: str name2: str + name3: int + name4: float + bla: dict[str, Any] + opt: Optional[str] - constructed = construct_model_recursive(SimpleModel, {"name1": "John"}) + constructed = partial_model(SimpleModel).model_validate({"name1": "John"}) assert isinstance(constructed, SimpleModel) assert constructed.name1 == "John" - with pytest.raises(AttributeError): - _ = constructed.name2 + assert constructed.name2 == "" + assert constructed.name3 == 0 + assert constructed.name4 == 0.0 + assert constructed.bla == {} + assert constructed.opt is None + + # Check that we do not raise on an empty payload + partial_model(SimpleModel).model_validate({}) + + # Check that we do raise when a type is wrong + with pytest.raises(ValidationError): + partial_model(SimpleModel).model_validate({"name1": 1, "name2": "2"}) + + def test_with_some_optional_fields(self): + class SomeOptionalFields(BaseModel): + name1: str + name2: str = "blibly" + list1: list[str] = Field(default_factory=lambda: ["1"]) + + constructed = partial_model(SomeOptionalFields).model_validate({}) + assert isinstance(constructed, SomeOptionalFields) + assert constructed.name1 == "" + assert constructed.name2 == "blibly" + assert 
constructed.list1 == ["1"] def test_list_of_strings(self): class ListOfStrings(BaseModel): strings: list[str] - constructed = construct_model_recursive(ListOfStrings, {"strings": ["a", "b"]}) + constructed = partial_model(ListOfStrings).model_validate({"strings": ["a", "b"]}) assert isinstance(constructed, ListOfStrings) assert constructed.strings == ["a", "b"] + # Check that we do not raise on an empty payload + partial_model(ListOfStrings).model_validate({}) + @pytest.mark.parametrize( "payload", [ @@ -44,30 +82,13 @@ class OuterModel(BaseModel): field1: str nested: NestedModel - constructed = construct_model_recursive(OuterModel, payload) + constructed = partial_model(OuterModel).model_validate(payload) assert isinstance(constructed, OuterModel), "constructed is not an instance of OuterModel" - if "field1" in payload: - assert isinstance(constructed.field1, str), "field1 is not a string" - else: - with pytest.raises(AttributeError): - _ = constructed.field1 - - if "nested" in payload: - assert isinstance(constructed.nested, NestedModel), "nested is not an instance of NestedModel" - if "name" in payload["nested"]: - assert isinstance(constructed.nested.name, str) - else: - with pytest.raises(AttributeError): - _ = constructed.nested.name - - if "field2" in payload["nested"]: - assert isinstance(constructed.nested.field2, str) - else: - with pytest.raises(AttributeError): - _ = constructed.nested.field2 - else: - with pytest.raises(AttributeError): - _ = constructed.nested + assert constructed.field1 == payload.get("field1", "") + assert isinstance(constructed.nested, NestedModel), "nested is not an instance of NestedModel" + + assert constructed.nested.name == payload.get("nested", {}).get("name", "") + assert constructed.nested.field2 == payload.get("nested", {}).get("field2", "") def test_list_of_models(self): class NestedModel(BaseModel): @@ -77,8 +98,7 @@ class NestedModel(BaseModel): class ListOfModels(BaseModel): models: list[NestedModel] - constructed = construct_model_recursive( - ListOfModels, + constructed = partial_model(ListOfModels).model_validate( {"models": [{"name": "hello", "field2": "world"}, {"name": "hello"}]}, ) assert isinstance(constructed, ListOfModels) @@ -97,8 +117,7 @@ def __hash__(self) -> int: class SetOfModels(BaseModel): models: set[NestedModel] - constructed = construct_model_recursive( - SetOfModels, + constructed = partial_model(SetOfModels).model_validate( {"models": [{"name": "hello", "field2": "world"}, {"name": "hello"}]}, ) assert isinstance(constructed, SetOfModels) @@ -113,11 +132,24 @@ class NestedModel(BaseModel): class DictOfModels(BaseModel): models: dict[str, NestedModel] - constructed = construct_model_recursive( - DictOfModels, + constructed = partial_model(DictOfModels).model_validate( {"models": {"hello": {"name": "hello", "field2": "world"}, "hello2": {"name": "hello"}}}, ) assert isinstance(constructed, DictOfModels) assert isinstance(constructed.models, dict) assert isinstance(constructed.models["hello"], NestedModel) assert isinstance(constructed.models["hello2"], NestedModel) + + def test_with_aliases(self): + class AliasModel(BaseModel): + message: str = Field(alias="message_alias") + aliased_ser: str = Field(serialization_alias="aliased_ser_alias") + aliased_val: str = Field(validation_alias="aliased_val_alias") + + partial = partial_model(AliasModel) + + payload = {"message_alias": "hello", "aliased_ser": "world", "aliased_val_alias": "!"} + v1 = AliasModel.model_validate(payload) + v2 = partial.model_validate(payload) + + assert 
v1 == v2 From 6624df1280bf36f64e588d195ba2841677146e21 Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Wed, 5 Mar 2025 10:35:12 -0500 Subject: [PATCH 6/8] chore: group TestListModels --- workflowai/core/client/agent_test.py | 1057 +++++++++++++------------- 1 file changed, 527 insertions(+), 530 deletions(-) diff --git a/workflowai/core/client/agent_test.py b/workflowai/core/client/agent_test.py index 864c6d2..8904dc3 100644 --- a/workflowai/core/client/agent_test.py +++ b/workflowai/core/client/agent_test.py @@ -447,558 +447,555 @@ def test_only_model_privider(self, agent: Agent[HelloTaskInput, HelloTaskOutput] } -@pytest.mark.asyncio -async def test_list_models(agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock): - """Test that list_models correctly fetches and returns available models.""" - # Mock the HTTP response instead of the API client method - httpx_mock.add_response( - url="http://localhost:8000/v1/_/agents/123/schemas/1/models", - json={ - "items": [ - { - "id": "gpt-4", - "name": "GPT-4", - "icon_url": "https://example.com/gpt4.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.01, - "is_latest": True, - "metadata": { - "provider_name": "OpenAI", - "price_per_input_token_usd": 0.0001, - "price_per_output_token_usd": 0.0002, - "release_date": "2024-01-01", - "context_window_tokens": 128000, - "quality_index": 0.95, +class TestListModels: + async def test_list_models(self, agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock): + """Test that list_models correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], }, - "is_default": True, - "providers": ["openai"], - }, - { - "id": "claude-3", - "name": "Claude 3", - "icon_url": "https://example.com/claude3.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.015, - "is_latest": True, - "metadata": { - "provider_name": "Anthropic", - "price_per_input_token_usd": 0.00015, - "price_per_output_token_usd": 0.00025, - "release_date": "2024-03-01", - "context_window_tokens": 200000, - "quality_index": 0.98, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + "price_per_input_token_usd": 0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], }, - "is_default": False, - "providers": ["anthropic"], - }, - ], - "count": 2, - }, - ) + ], + "count": 2, + }, + ) - # Call the method - models = await agent.list_models() - - # Verify the HTTP request was made correctly - request = httpx_mock.get_request() - assert request is not None, "Expected an HTTP 
request to be made" - assert request.method == "POST" - assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" - assert json.loads(request.content) == {} - - # Verify we get back the full ModelInfo objects - assert len(models) == 2 - assert isinstance(models[0], ModelInfo) - assert models[0].id == "gpt-4" - assert models[0].name == "GPT-4" - assert models[0].modes == ["chat"] - assert models[0].metadata is not None - assert models[0].metadata.provider_name == "OpenAI" - - assert isinstance(models[1], ModelInfo) - assert models[1].id == "claude-3" - assert models[1].name == "Claude 3" - assert models[1].modes == ["chat"] - assert models[1].metadata is not None - assert models[1].metadata.provider_name == "Anthropic" - - -@pytest.mark.asyncio -async def test_list_models_with_params_override(agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock): - """Test that list_models correctly fetches and returns available models.""" - # Mock the HTTP response instead of the API client method - httpx_mock.add_response( - url="http://localhost:8000/v1/_/agents/123/schemas/1/models", - json={ - "items": [ - { - "id": "gpt-4", - "name": "GPT-4", - "icon_url": "https://example.com/gpt4.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.01, - "is_latest": True, - "metadata": { - "provider_name": "OpenAI", - "price_per_input_token_usd": 0.0001, - "price_per_output_token_usd": 0.0002, - "release_date": "2024-01-01", - "context_window_tokens": 128000, - "quality_index": 0.95, + # Call the method + models = await agent.list_models() + + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == {} + + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" + + async def test_list_models_with_params_override( + self, + agent: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, + ): + """Test that list_models correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], }, - "is_default": True, - "providers": ["openai"], - }, - { - "id": "claude-3", - "name": "Claude 3", - "icon_url": "https://example.com/claude3.png", - "modes": ["chat"], - 
"is_not_supported_reason": None, - "average_cost_per_run_usd": 0.015, - "is_latest": True, - "metadata": { - "provider_name": "Anthropic", - "price_per_input_token_usd": 0.00015, - "price_per_output_token_usd": 0.00025, - "release_date": "2024-03-01", - "context_window_tokens": 200000, - "quality_index": 0.98, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + "price_per_input_token_usd": 0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], }, - "is_default": False, - "providers": ["anthropic"], - }, - ], - "count": 2, - }, - ) + ], + "count": 2, + }, + ) + + # Call the method + models = await agent.list_models(instructions="Some override instructions", requires_tools=True) + + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == { + "instructions": "Some override instructions", + "requires_tools": True, + } - # Call the method - models = await agent.list_models(instructions="Some override instructions", requires_tools=True) - - # Verify the HTTP request was made correctly - request = httpx_mock.get_request() - assert request is not None, "Expected an HTTP request to be made" - assert request.method == "POST" - assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" - assert json.loads(request.content) == { - "instructions": "Some override instructions", - "requires_tools": True, - } - - # Verify we get back the full ModelInfo objects - assert len(models) == 2 - assert isinstance(models[0], ModelInfo) - assert models[0].id == "gpt-4" - assert models[0].name == "GPT-4" - assert models[0].modes == ["chat"] - assert models[0].metadata is not None - assert models[0].metadata.provider_name == "OpenAI" - - assert isinstance(models[1], ModelInfo) - assert models[1].id == "claude-3" - assert models[1].name == "Claude 3" - assert models[1].modes == ["chat"] - assert models[1].metadata is not None - assert models[1].metadata.provider_name == "Anthropic" - - -@pytest.mark.asyncio -async def test_list_models_with_params_override_and_agent_with_tools_and_instructions( - agent_with_tools_and_instructions: Agent[HelloTaskInput, HelloTaskOutput], - httpx_mock: HTTPXMock, -): - """Test that list_models correctly fetches and returns available models.""" - # Mock the HTTP response instead of the API client method - httpx_mock.add_response( - url="http://localhost:8000/v1/_/agents/123/schemas/1/models", - json={ - "items": [ - { - "id": "gpt-4", - "name": "GPT-4", - "icon_url": "https://example.com/gpt4.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.01, - "is_latest": True, - "metadata": { - "provider_name": "OpenAI", - "price_per_input_token_usd": 0.0001, - "price_per_output_token_usd": 0.0002, - "release_date": "2024-01-01", - "context_window_tokens": 128000, - "quality_index": 0.95, + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert 
models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" + + async def test_list_models_with_params_override_and_agent_with_tools_and_instructions( + self, + agent_with_tools_and_instructions: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, + ): + """Test that list_models correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], }, - "is_default": True, - "providers": ["openai"], - }, - { - "id": "claude-3", - "name": "Claude 3", - "icon_url": "https://example.com/claude3.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.015, - "is_latest": True, - "metadata": { - "provider_name": "Anthropic", - "price_per_input_token_usd": 0.00015, - "price_per_output_token_usd": 0.00025, - "release_date": "2024-03-01", - "context_window_tokens": 200000, - "quality_index": 0.98, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + "price_per_input_token_usd": 0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], }, - "is_default": False, - "providers": ["anthropic"], - }, - ], - "count": 2, - }, - ) + ], + "count": 2, + }, + ) - # Call the method - models = await agent_with_tools_and_instructions.list_models( - instructions="Some override instructions", - requires_tools=False, - ) + # Call the method + models = await agent_with_tools_and_instructions.list_models( + instructions="Some override instructions", + requires_tools=False, + ) - # Verify the HTTP request was made correctly - request = httpx_mock.get_request() - assert request is not None, "Expected an HTTP request to be made" - assert request.method == "POST" - assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" - assert json.loads(request.content) == { - "instructions": "Some override instructions", - "requires_tools": False, - } - - # Verify we get back the full ModelInfo objects - assert len(models) == 2 - assert isinstance(models[0], ModelInfo) - assert models[0].id == "gpt-4" - assert models[0].name == "GPT-4" - assert models[0].modes == ["chat"] - assert models[0].metadata is not None - assert models[0].metadata.provider_name == "OpenAI" - - assert isinstance(models[1], ModelInfo) - assert models[1].id == "claude-3" - assert models[1].name == 
"Claude 3" - assert models[1].modes == ["chat"] - assert models[1].metadata is not None - assert models[1].metadata.provider_name == "Anthropic" - - -@pytest.mark.asyncio -async def test_list_models_registers_if_needed( - agent_no_schema: Agent[HelloTaskInput, HelloTaskOutput], - httpx_mock: HTTPXMock, -): - """Test that list_models registers the agent if it hasn't been registered yet.""" - # Mock the registration response - httpx_mock.add_response( - url="http://localhost:8000/v1/_/agents", - json={"id": "123", "schema_id": 2}, - ) + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == { + "instructions": "Some override instructions", + "requires_tools": False, + } + + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" + + async def test_list_models_registers_if_needed( + self, + agent_no_schema: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, + ): + """Test that list_models registers the agent if it hasn't been registered yet.""" + # Mock the registration response + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents", + json={"id": "123", "schema_id": 2}, + ) - # Mock the models response with the new structure - httpx_mock.add_response( - url="http://localhost:8000/v1/_/agents/123/schemas/2/models", - json={ - "items": [ - { - "id": "gpt-4", - "name": "GPT-4", - "icon_url": "https://example.com/gpt4.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.01, - "is_latest": True, - "metadata": { - "provider_name": "OpenAI", - "price_per_input_token_usd": 0.0001, - "price_per_output_token_usd": 0.0002, - "release_date": "2024-01-01", - "context_window_tokens": 128000, - "quality_index": 0.95, + # Mock the models response with the new structure + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/schemas/2/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], }, - "is_default": True, - "providers": ["openai"], - }, - ], - "count": 1, - }, - ) + ], + "count": 1, + }, + ) + + # Call the method + models = await agent_no_schema.list_models() - # Call the method - models = await agent_no_schema.list_models() - - # Verify both API calls were made - reqs = httpx_mock.get_requests() - assert len(reqs) == 2 - assert reqs[0].url == "http://localhost:8000/v1/_/agents" - assert reqs[1].method == "POST" - assert reqs[1].url 
== "http://localhost:8000/v1/_/agents/123/schemas/2/models" - assert json.loads(reqs[1].content) == {} - - # Verify we get back the full ModelInfo object - assert len(models) == 1 - assert isinstance(models[0], ModelInfo) - assert models[0].id == "gpt-4" - assert models[0].name == "GPT-4" - assert models[0].modes == ["chat"] - assert models[0].metadata is not None - assert models[0].metadata.provider_name == "OpenAI" - - -@pytest.mark.asyncio -async def test_list_models_with_instructions( - agent_with_instructions: Agent[HelloTaskInput, HelloTaskOutput], - httpx_mock: HTTPXMock, -): - """Test that list_models correctly fetches and returns available models.""" - # Mock the HTTP response instead of the API client method - httpx_mock.add_response( - method="POST", - url="http://localhost:8000/v1/_/agents/123/schemas/1/models", - json={ - "items": [ - { - "id": "gpt-4", - "name": "GPT-4", - "icon_url": "https://example.com/gpt4.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.01, - "is_latest": True, - "metadata": { - "provider_name": "OpenAI", - "price_per_input_token_usd": 0.0001, - "price_per_output_token_usd": 0.0002, - "release_date": "2024-01-01", - "context_window_tokens": 128000, - "quality_index": 0.95, + # Verify both API calls were made + reqs = httpx_mock.get_requests() + assert len(reqs) == 2 + assert reqs[0].url == "http://localhost:8000/v1/_/agents" + assert reqs[1].method == "POST" + assert reqs[1].url == "http://localhost:8000/v1/_/agents/123/schemas/2/models" + assert json.loads(reqs[1].content) == {} + + # Verify we get back the full ModelInfo object + assert len(models) == 1 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + async def test_list_models_with_instructions( + self, + agent_with_instructions: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, + ): + """Test that list_models correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + method="POST", + url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], }, - "is_default": True, - "providers": ["openai"], - }, - { - "id": "claude-3", - "name": "Claude 3", - "icon_url": "https://example.com/claude3.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.015, - "is_latest": True, - "metadata": { - "provider_name": "Anthropic", - "price_per_input_token_usd": 0.00015, - "price_per_output_token_usd": 0.00025, - "release_date": "2024-03-01", - "context_window_tokens": 200000, - "quality_index": 0.98, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + 
"price_per_input_token_usd": 0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], }, - "is_default": False, - "providers": ["anthropic"], - }, - ], - "count": 2, - }, - ) + ], + "count": 2, + }, + ) - # Call the method - models = await agent_with_instructions.list_models() - - # Verify the HTTP request was made correctly - request = httpx_mock.get_request() - assert request is not None, "Expected an HTTP request to be made" - assert request.method == "POST" - assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" - assert json.loads(request.content) == {"instructions": "Some instructions"} - - # Verify we get back the full ModelInfo objects - assert len(models) == 2 - assert isinstance(models[0], ModelInfo) - assert models[0].id == "gpt-4" - assert models[0].name == "GPT-4" - assert models[0].modes == ["chat"] - assert models[0].metadata is not None - assert models[0].metadata.provider_name == "OpenAI" - - assert isinstance(models[1], ModelInfo) - assert models[1].id == "claude-3" - assert models[1].name == "Claude 3" - assert models[1].modes == ["chat"] - assert models[1].metadata is not None - assert models[1].metadata.provider_name == "Anthropic" - - -@pytest.mark.asyncio -async def test_list_models_with_tools( - agent_with_tools: Agent[HelloTaskInput, HelloTaskOutput], - httpx_mock: HTTPXMock, -): - """Test that list_models correctly fetches and returns available models.""" - # Mock the HTTP response instead of the API client method - httpx_mock.add_response( - method="POST", - url="http://localhost:8000/v1/_/agents/123/schemas/1/models", - json={ - "items": [ - { - "id": "gpt-4", - "name": "GPT-4", - "icon_url": "https://example.com/gpt4.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.01, - "is_latest": True, - "metadata": { - "provider_name": "OpenAI", - "price_per_input_token_usd": 0.0001, - "price_per_output_token_usd": 0.0002, - "release_date": "2024-01-01", - "context_window_tokens": 128000, - "quality_index": 0.95, + # Call the method + models = await agent_with_instructions.list_models() + + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == {"instructions": "Some instructions"} + + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" + + async def test_list_models_with_tools( + self, + agent_with_tools: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, + ): + """Test that list_models correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + method="POST", + 
url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], }, - "is_default": True, - "providers": ["openai"], - }, - { - "id": "claude-3", - "name": "Claude 3", - "icon_url": "https://example.com/claude3.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.015, - "is_latest": True, - "metadata": { - "provider_name": "Anthropic", - "price_per_input_token_usd": 0.00015, - "price_per_output_token_usd": 0.00025, - "release_date": "2024-03-01", - "context_window_tokens": 200000, - "quality_index": 0.98, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + "price_per_input_token_usd": 0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], }, - "is_default": False, - "providers": ["anthropic"], - }, - ], - "count": 2, - }, - ) + ], + "count": 2, + }, + ) - # Call the method - models = await agent_with_tools.list_models() - - # Verify the HTTP request was made correctly - request = httpx_mock.get_request() - assert request is not None, "Expected an HTTP request to be made" - assert request.method == "POST" - assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" - assert json.loads(request.content) == {"requires_tools": True} - - # Verify we get back the full ModelInfo objects - assert len(models) == 2 - assert isinstance(models[0], ModelInfo) - assert models[0].id == "gpt-4" - assert models[0].name == "GPT-4" - assert models[0].modes == ["chat"] - assert models[0].metadata is not None - assert models[0].metadata.provider_name == "OpenAI" - - assert isinstance(models[1], ModelInfo) - assert models[1].id == "claude-3" - assert models[1].name == "Claude 3" - assert models[1].modes == ["chat"] - assert models[1].metadata is not None - assert models[1].metadata.provider_name == "Anthropic" - - -@pytest.mark.asyncio -async def test_list_models_with_instructions_and_tools( - agent_with_tools_and_instructions: Agent[HelloTaskInput, HelloTaskOutput], - httpx_mock: HTTPXMock, -): - """Test that list_models correctly fetches and returns available models.""" - # Mock the HTTP response instead of the API client method - httpx_mock.add_response( - method="POST", - url="http://localhost:8000/v1/_/agents/123/schemas/1/models", - json={ - "items": [ - { - "id": "gpt-4", - "name": "GPT-4", - "icon_url": "https://example.com/gpt4.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.01, - "is_latest": True, - "metadata": { - "provider_name": "OpenAI", - "price_per_input_token_usd": 0.0001, - "price_per_output_token_usd": 0.0002, - "release_date": "2024-01-01", - "context_window_tokens": 128000, - "quality_index": 0.95, + # Call the method + models = await 
agent_with_tools.list_models() + + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == {"requires_tools": True} + + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" + + async def test_list_models_with_instructions_and_tools( + self, + agent_with_tools_and_instructions: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, + ): + """Test that list_models correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + method="POST", + url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], }, - "is_default": True, - "providers": ["openai"], - }, - { - "id": "claude-3", - "name": "Claude 3", - "icon_url": "https://example.com/claude3.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.015, - "is_latest": True, - "metadata": { - "provider_name": "Anthropic", - "price_per_input_token_usd": 0.00015, - "price_per_output_token_usd": 0.00025, - "release_date": "2024-03-01", - "context_window_tokens": 200000, - "quality_index": 0.98, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + "price_per_input_token_usd": 0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], }, - "is_default": False, - "providers": ["anthropic"], - }, - ], - "count": 2, - }, - ) + ], + "count": 2, + }, + ) - # Call the method - models = await agent_with_tools_and_instructions.list_models() - - # Verify the HTTP request was made correctly - request = httpx_mock.get_request() - assert request is not None, "Expected an HTTP request to be made" - assert request.method == "POST" - assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" - assert json.loads(request.content) == {"instructions": "Some instructions", "requires_tools": True} - - # Verify we get back the full ModelInfo objects - assert len(models) == 2 - assert isinstance(models[0], ModelInfo) - assert models[0].id == "gpt-4" - 
assert models[0].name == "GPT-4" - assert models[0].modes == ["chat"] - assert models[0].metadata is not None - assert models[0].metadata.provider_name == "OpenAI" - - assert isinstance(models[1], ModelInfo) - assert models[1].id == "claude-3" - assert models[1].name == "Claude 3" - assert models[1].modes == ["chat"] - assert models[1].metadata is not None - assert models[1].metadata.provider_name == "Anthropic" + # Call the method + models = await agent_with_tools_and_instructions.list_models() + + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == {"instructions": "Some instructions", "requires_tools": True} + + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" class TestFetchCompletions: From 2e838c962f27d6b901996371b14c1e7305147717 Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Wed, 5 Mar 2025 10:45:17 -0500 Subject: [PATCH 7/8] test: add test for validation errors during streaming --- workflowai/core/client/agent.py | 6 +- workflowai/core/client/agent_test.py | 178 +++++++++++++++++---------- 2 files changed, 121 insertions(+), 63 deletions(-) diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index f6d5db1..e0c29e4 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -444,7 +444,11 @@ async def stream( ) final_error = None except ValidationError as e: - logger.debug("Validation error in stream", exc_info=e) + logger.debug( + "Client side validation error in stream. 
There is likely an " + "issue with the validator or the model.", + exc_info=e, + ) final_error = e continue if final_error: diff --git a/workflowai/core/client/agent_test.py b/workflowai/core/client/agent_test.py index 8904dc3..8779b47 100644 --- a/workflowai/core/client/agent_test.py +++ b/workflowai/core/client/agent_test.py @@ -1,5 +1,6 @@ import importlib.metadata import json +import logging from unittest.mock import Mock, patch import httpx @@ -120,68 +121,6 @@ async def test_success(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, "stream": False, } - async def test_stream(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]): - httpx_mock.add_response( - stream=IteratorStream( - [ - b'data: {"id":"1","task_output":{"message":""}}\n\n', - b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n', # noqa: E501 - b'data: {"id":"1","task_output":{"message":"hello"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n', # noqa: E501 - ], - ), - ) - - chunks = [chunk async for chunk in agent.stream(HelloTaskInput(name="Alice"))] - - outputs = [chunk.output for chunk in chunks] - assert outputs == [ - HelloTaskOutput(message=""), - HelloTaskOutput(message="hel"), - HelloTaskOutput(message="hello"), - HelloTaskOutput(message="hello"), - ] - last_message = chunks[-1] - assert isinstance(last_message, Run) - assert last_message.version - assert last_message.version.properties.model == "gpt-4o" - assert last_message.version.properties.temperature == 0.5 - assert last_message.cost_usd == 0.01 - assert last_message.duration_seconds == 10.1 - - async def test_stream_not_optional( - self, - httpx_mock: HTTPXMock, - agent_not_optional: Agent[HelloTaskInput, HelloTaskOutputNotOptional], - ): - # Checking that streaming works even with non optional fields - # The first two chunks are missing a required key but the last one has it - httpx_mock.add_response( - stream=IteratorStream( - [ - b'data: {"id":"1","task_output":{"message":""}}\n\n', - b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n', # noqa: E501 - b'data: {"id":"1","task_output":{"message":"hello", "another_field": "test"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n', # noqa: E501 - ], - ), - ) - - chunks = [chunk async for chunk in agent_not_optional.stream(HelloTaskInput(name="Alice"))] - - messages = [chunk.output.message for chunk in chunks] - assert messages == ["", "hel", "hello", "hello"] - - for chunk in chunks[:-1]: - assert chunk.output.another_field == "" - assert chunks[-1].output.another_field == "test" - - last_message = chunks[-1] - assert isinstance(last_message, Run) - assert last_message.version - assert last_message.version.properties.model == "gpt-4o" - assert last_message.version.properties.temperature == 0.5 - assert last_message.cost_usd == 0.01 - assert last_message.duration_seconds == 10.1 - async def test_run_with_env(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]): httpx_mock.add_response(json=fixtures_json("task_run.json")) @@ -1024,3 +963,118 @@ async def test_fetch_completions(self, agent: Agent[HelloTaskInput, HelloTaskOut ), ), ] + + +class TestStream: + async def test_stream(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]): + httpx_mock.add_response( + stream=IteratorStream( + [ + b'data: 
{"id":"1","task_output":{"message":""}}\n\n', + b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n', # noqa: E501 + b'data: {"id":"1","task_output":{"message":"hello"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n', # noqa: E501 + ], + ), + ) + + chunks = [chunk async for chunk in agent.stream(HelloTaskInput(name="Alice"))] + + outputs = [chunk.output for chunk in chunks] + assert outputs == [ + HelloTaskOutput(message=""), + HelloTaskOutput(message="hel"), + HelloTaskOutput(message="hello"), + HelloTaskOutput(message="hello"), + ] + last_message = chunks[-1] + assert isinstance(last_message, Run) + assert last_message.version + assert last_message.version.properties.model == "gpt-4o" + assert last_message.version.properties.temperature == 0.5 + assert last_message.cost_usd == 0.01 + assert last_message.duration_seconds == 10.1 + + async def test_stream_not_optional( + self, + httpx_mock: HTTPXMock, + agent_not_optional: Agent[HelloTaskInput, HelloTaskOutputNotOptional], + ): + # Checking that streaming works even with non optional fields + # The first two chunks are missing a required key but the last one has it + httpx_mock.add_response( + stream=IteratorStream( + [ + b'data: {"id":"1","task_output":{"message":""}}\n\n', + b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n', # noqa: E501 + b'data: {"id":"1","task_output":{"message":"hello", "another_field": "test"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n', # noqa: E501 + ], + ), + ) + + chunks = [chunk async for chunk in agent_not_optional.stream(HelloTaskInput(name="Alice"))] + + messages = [chunk.output.message for chunk in chunks] + assert messages == ["", "hel", "hello", "hello"] + + for chunk in chunks[:-1]: + assert chunk.output.another_field == "" + assert chunks[-1].output.another_field == "test" + + last_message = chunks[-1] + assert isinstance(last_message, Run) + assert last_message.version + assert last_message.version.properties.model == "gpt-4o" + assert last_message.version.properties.temperature == 0.5 + assert last_message.cost_usd == 0.01 + assert last_message.duration_seconds == 10.1 + + async def test_stream_validation_errors( + self, + agent: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, + caplog: pytest.LogCaptureFixture, + ): + """Test that validation errors are properly skipped and logged during streaming""" + httpx_mock.add_response( + stream=IteratorStream( + [ + b'data: {"id":"1","task_output":{"message":""}}\n\n', + # Middle chunk is passed + b'data: {"id":"1","task_output":{"message":1}}\n\n', + b'data: {"id":"1","task_output":{"message":"hello"}}\n\n', + ], + ), + ) + + with caplog.at_level(logging.DEBUG): + chunks = [chunk async for chunk in agent.stream(HelloTaskInput(name="Alice"))] + + assert len(chunks) == 2 + assert chunks[0].output.message == "" + assert chunks[1].output.message == "hello" + logs = [record for record in caplog.records if "Client side validation error in stream" in record.message] + + assert len(logs) == 1 + assert logs[0].levelname == "DEBUG" + assert logs[0].exc_info is not None + + async def test_stream_validation_final_error( + self, + agent: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, + ): + """Check that we properly raise an error if the final payload fails to validate.""" + httpx_mock.add_response( + 
stream=IteratorStream( + [ + # Stream a single chunk that fails to validate + b'data: {"id":"1","task_output":{"message":1}}\n\n', + ], + ), + ) + + with pytest.raises(WorkflowAIError) as e: + _ = [c async for c in agent.stream(HelloTaskInput(name="Alice"))] + + assert e.value.partial_output == {"message": 1} + assert e.value.run_id == "1" From b4c57531a498caae20729837a96c14eb738a8ec0 Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Wed, 5 Mar 2025 11:12:56 -0500 Subject: [PATCH 8/8] feat: add missing gpt 45 preview --- workflowai/core/domain/model.py | 1 + workflowai/core/domain/model_test.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+) create mode 100644 workflowai/core/domain/model_test.py diff --git a/workflowai/core/domain/model.py b/workflowai/core/domain/model.py index 061398c..5b166ed 100644 --- a/workflowai/core/domain/model.py +++ b/workflowai/core/domain/model.py @@ -34,6 +34,7 @@ class Model(str, Enum): O1_PREVIEW_2024_09_12 = "o1-preview-2024-09-12" O1_MINI_LATEST = "o1-mini-latest" O1_MINI_2024_09_12 = "o1-mini-2024-09-12" + GPT_45_PREVIEW_2025_02_27 = "gpt-4.5-preview-2025-02-27" GPT_4O_AUDIO_PREVIEW_2024_12_17 = "gpt-4o-audio-preview-2024-12-17" GPT_40_AUDIO_PREVIEW_2024_10_01 = "gpt-4o-audio-preview-2024-10-01" GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09" diff --git a/workflowai/core/domain/model_test.py b/workflowai/core/domain/model_test.py new file mode 100644 index 0000000..d23bacb --- /dev/null +++ b/workflowai/core/domain/model_test.py @@ -0,0 +1,15 @@ +import httpx + +from workflowai.core.domain.model import Model + + +async def test_model_exhaustive(): + """Make sure the list of models is synchronized with the prod API""" + async with httpx.AsyncClient() as client: + response = await client.get("https://run.workflowai.com/v1/models") + response.raise_for_status() + models: list[str] = response.json() + + # Converting to a set of strings should not be needed + # but it makes pytest errors prettier + assert set(models) == {m.value for m in Model}
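
If test_model_exhaustive above starts failing, it can help to see the drift in both directions rather than a single set comparison. Below is a minimal standalone sketch of the same check. It assumes, as the test does, that https://run.workflowai.com/v1/models returns a flat JSON array of model id strings, and it prints which identifiers are missing on each side instead of asserting. The helper name report_model_drift is illustrative only; the SDK does not ship such a function.

import asyncio

import httpx

from workflowai.core.domain.model import Model


async def report_model_drift() -> None:
    # Fetch the model ids the production API currently serves.
    # Assumption: the endpoint returns a plain JSON array of strings,
    # mirroring what test_model_exhaustive expects.
    async with httpx.AsyncClient() as client:
        response = await client.get("https://run.workflowai.com/v1/models")
        response.raise_for_status()
        api_models = set(response.json())

    sdk_models = {m.value for m in Model}

    # Ids the API serves that the Model enum does not know about yet.
    print("missing from SDK:", sorted(api_models - sdk_models) or "none")
    # Enum members the API no longer advertises.
    print("missing from API:", sorted(sdk_models - api_models) or "none")


if __name__ == "__main__":
    asyncio.run(report_model_drift())

Running this locally lists exactly which identifiers (such as a newly released preview model) need to be added to or removed from the enum, which is the same information the pytest assertion encodes but in a form that is quicker to act on.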