diff --git a/pyproject.toml b/pyproject.toml index 37b210e..dfd99ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "workflowai" -version = "0.6.0.dev19" +version = "0.6.0.dev20" description = "" authors = ["Guillaume Aquilina "] readme = "README.md" diff --git a/tests/e2e/no_schema_test.py b/tests/e2e/no_schema_test.py index 3cc8784..111977f 100644 --- a/tests/e2e/no_schema_test.py +++ b/tests/e2e/no_schema_test.py @@ -1,8 +1,9 @@ from typing import Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field import workflowai +from workflowai.core.client.agent import Agent class SummarizeTaskInput(BaseModel): @@ -28,3 +29,31 @@ async def test_summarize(): use_cache="never", ) assert summarized.summary_points + + +async def test_same_schema(): + class InputWithNullableList(BaseModel): + opt_list: Optional[list[str]] = None + + class InputWithNonNullableList(BaseModel): + opt_list: list[str] = Field(default_factory=list) + + agent1 = Agent( + agent_id="summarize", + input_cls=InputWithNullableList, + output_cls=SummarizeTaskOutput, + api=lambda: workflowai.shared_client.api, + ) + + schema_id1 = await agent1.register() + + agent2 = Agent( + agent_id="summarize", + input_cls=InputWithNonNullableList, + output_cls=SummarizeTaskOutput, + api=lambda: workflowai.shared_client.api, + ) + + schema_id2 = await agent2.register() + + assert schema_id1 == schema_id2 diff --git a/tests/integration/run_test.py b/tests/integration/run_test.py index 73d2fe0..f0944c9 100644 --- a/tests/integration/run_test.py +++ b/tests/integration/run_test.py @@ -77,11 +77,11 @@ def city_to_capital(task_input: CityToCapitalTaskInput) -> AsyncIterator[CityToC task_input = CityToCapitalTaskInput(city="Hello") chunks = [chunk async for chunk in city_to_capital(task_input)] - assert chunks == [ - CityToCapitalTaskOutput(capital=""), - CityToCapitalTaskOutput(capital="Tok"), - CityToCapitalTaskOutput(capital="Tokyo"), - CityToCapitalTaskOutput(capital="Tokyo"), + assert [chunk.capital for chunk in chunks] == [ + "", + "Tok", + "Tokyo", + "Tokyo", ] @@ -94,11 +94,11 @@ def city_to_capital(task_input: CityToCapitalTaskInput) -> AsyncIterator[CityToC task_input = CityToCapitalTaskInput(city="Hello") chunks = [chunk async for chunk in city_to_capital(task_input)] - assert chunks == [ - CityToCapitalTaskOutput(capital=""), - CityToCapitalTaskOutput(capital="Tok"), - CityToCapitalTaskOutput(capital="Tokyo"), - CityToCapitalTaskOutput(capital="Tokyo"), + assert [chunk.capital for chunk in chunks] == [ + "", + "Tok", + "Tokyo", + "Tokyo", ] diff --git a/workflowai/core/_common_types.py b/workflowai/core/_common_types.py index ad090c4..ea49c85 100644 --- a/workflowai/core/_common_types.py +++ b/workflowai/core/_common_types.py @@ -19,7 +19,18 @@ class OutputValidator(Protocol, Generic[AgentOutputCov]): - def __call__(self, data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutputCov: ... + def __call__(self, data: dict[str, Any], partial: bool) -> AgentOutputCov: + """A way to convert a json object into an AgentOutput + + Args: + data (dict[str, Any]): The json object to convert + partial (bool): Whether the json is partial, meaning that + it may not contain all the fields required by the AgentOutput model. + + Returns: + AgentOutputCov: The converted AgentOutput + """ + ... 
class VersionRunParams(TypedDict): diff --git a/workflowai/core/client/_api.py b/workflowai/core/client/_api.py index 5f09641..8228ab5 100644 --- a/workflowai/core/client/_api.py +++ b/workflowai/core/client/_api.py @@ -1,4 +1,3 @@ -import logging from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import Any, Literal, Optional, TypeVar, Union, overload @@ -6,14 +5,13 @@ import httpx from pydantic import BaseModel, TypeAdapter, ValidationError +from workflowai.core._logger import logger from workflowai.core.domain.errors import BaseError, ErrorResponse, WorkflowAIError # A type for return values _R = TypeVar("_R") _M = TypeVar("_M", bound=BaseModel) -_logger = logging.getLogger("WorkflowAI") - class APIClient: def __init__(self, url: str, api_key: str, source_headers: Optional[dict[str, str]] = None): @@ -154,7 +152,7 @@ async def _wrap_sse(self, raw: AsyncIterator[bytes], termination_chars: bytes = in_data = False if data: - _logger.warning("Data left after processing", extra={"data": data}) + logger.warning("Data left after processing", extra={"data": data}) async def stream( self, diff --git a/workflowai/core/client/_fn_utils.py b/workflowai/core/client/_fn_utils.py index 97495f9..3930669 100644 --- a/workflowai/core/client/_fn_utils.py +++ b/workflowai/core/client/_fn_utils.py @@ -23,7 +23,7 @@ RunParams, RunTemplate, ) -from workflowai.core.client._utils import intolerant_validator +from workflowai.core.client._utils import default_validator from workflowai.core.client.agent import Agent from workflowai.core.domain.errors import InvalidGenerationError from workflowai.core.domain.model import ModelOrStr @@ -144,7 +144,7 @@ async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutp except InvalidGenerationError as e: if e.partial_output and e.run_id: try: - validator, _ = self._sanitize_validator(kwargs, intolerant_validator(self.output_cls)) + validator, _ = self._sanitize_validator(kwargs, default_validator(self.output_cls)) run = self._build_run_no_tools( chunk=RunResponse( id=e.run_id, @@ -152,6 +152,7 @@ async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutp ), schema_id=self.schema_id or 0, validator=validator, + partial=False, ) run.error = e.error return run diff --git a/workflowai/core/client/_models.py b/workflowai/core/client/_models.py index 8cf1bf1..ac888ba 100644 --- a/workflowai/core/client/_models.py +++ b/workflowai/core/client/_models.py @@ -134,12 +134,18 @@ def to_domain( task_id: str, task_schema_id: int, validator: OutputValidator[AgentOutput], + partial: Optional[bool] = None, ) -> Run[AgentOutput]: + # We do partial validation if either: + # - there are tool call requests, which means that the output can be empty + # - the run has not yet finished, for example when streaming, in which case the duration_seconds is None + if partial is None: + partial = bool(self.tool_call_requests) or self.duration_seconds is None return Run( id=self.id, agent_id=task_id, schema_id=task_schema_id, - output=validator(self.task_output, self.tool_call_requests is not None), + output=validator(self.task_output, partial), version=self.version and self.version.to_domain(), duration_seconds=self.duration_seconds, cost_usd=self.cost_usd, diff --git a/workflowai/core/client/_models_test.py b/workflowai/core/client/_models_test.py index e12bdbf..661459c 100644 --- a/workflowai/core/client/_models_test.py +++ b/workflowai/core/client/_models_test.py @@ -5,7 +5,7 @@ from tests.utils import 
fixture_text from workflowai.core.client._models import RunResponse -from workflowai.core.client._utils import intolerant_validator, tolerant_validator +from workflowai.core.client._utils import default_validator from workflowai.core.domain.run import Run from workflowai.core.domain.tool_call import ToolCallRequest @@ -41,29 +41,30 @@ def test_no_version_not_optional(self): with pytest.raises(ValidationError): # sanity _TaskOutput.model_validate({"a": 1}) - parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput)) + parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutput)) assert isinstance(parsed, Run) assert parsed.output.a == 1 # b is not defined - with pytest.raises(AttributeError): - assert parsed.output.b + + assert parsed.output.b == "" def test_no_version_optional(self): chunk = RunResponse.model_validate_json('{"id": "1", "task_output": {"a": 1}}') assert chunk - parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutputOpt)) + parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutputOpt)) assert isinstance(parsed, Run) assert parsed.output.a == 1 assert parsed.output.b is None def test_with_version(self): + """Full output is validated since the duration is passed and there are no tool calls""" chunk = RunResponse.model_validate_json( '{"id": "1", "task_output": {"a": 1, "b": "test"}, "cost_usd": 0.1, "duration_seconds": 1, "version": {"properties": {"a": 1, "b": "test"}}}', # noqa: E501 ) assert chunk - parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput)) + parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutput)) assert isinstance(parsed, Run) assert parsed.output.a == 1 assert parsed.output.b == "test" @@ -73,17 +74,19 @@ def test_with_version(self): def test_with_version_validation_fails(self): chunk = RunResponse.model_validate_json( - '{"id": "1", "task_output": {"a": 1}, "version": {"properties": {"a": 1, "b": "test"}}}', + """{"id": "1", "task_output": {"a": 1}, + "version": {"properties": {"a": 1, "b": "test"}}, "duration_seconds": 1}""", ) with pytest.raises(ValidationError): - chunk.to_domain(task_id="1", task_schema_id=1, validator=intolerant_validator(_TaskOutput)) + chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutput)) def test_with_tool_calls(self): chunk = RunResponse.model_validate_json( - '{"id": "1", "task_output": {}, "tool_call_requests": [{"id": "1", "name": "test", "input": {"a": 1}}]}', + """{"id": "1", "task_output": {}, + "tool_call_requests": [{"id": "1", "name": "test", "input": {"a": 1}}], "duration_seconds": 1}""", ) assert chunk - parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput)) + parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutput)) assert isinstance(parsed, Run) assert parsed.tool_call_requests == [ToolCallRequest(id="1", name="test", input={"a": 1})] diff --git a/workflowai/core/client/_utils.py b/workflowai/core/client/_utils.py index c846ef8..0894082 100644 --- a/workflowai/core/client/_utils.py +++ b/workflowai/core/client/_utils.py @@ -13,6 +13,7 @@ from workflowai.core.domain.errors import BaseError, WorkflowAIError from workflowai.core.domain.task import AgentOutput from workflowai.core.domain.version_reference import VersionReference +from workflowai.core.utils._pydantic 
import partial_model delimiter = re.compile(r'\}\n\ndata: \{"') @@ -86,20 +87,12 @@ async def _wait_for_exception(e: WorkflowAIError): return _should_retry, _wait_for_exception -def tolerant_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]: - def _validator(data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutput: # noqa: ARG001 - return m.model_construct(None, **data) +def default_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]: + partial_cls = partial_model(m) - return _validator - - -def intolerant_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]: - def _validator(data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutput: - # When we have tool call requests, the output can be empty - if has_tool_call_requests: - return m.model_construct(None, **data) - - return m.model_validate(data) + def _validator(data: dict[str, Any], partial: bool) -> AgentOutput: + model_cls = partial_cls if partial else m + return model_cls.model_validate(data) return _validator diff --git a/workflowai/core/client/_utils_test.py b/workflowai/core/client/_utils_test.py index 350584c..3c76bd3 100644 --- a/workflowai/core/client/_utils_test.py +++ b/workflowai/core/client/_utils_test.py @@ -2,8 +2,14 @@ from unittest.mock import Mock, patch import pytest +from pydantic import BaseModel -from workflowai.core.client._utils import build_retryable_wait, global_default_version_reference, split_chunks +from workflowai.core.client._utils import ( + build_retryable_wait, + default_validator, + global_default_version_reference, + split_chunks, +) from workflowai.core.domain.errors import BaseError, WorkflowAIError @@ -42,3 +48,24 @@ async def test_should_retry_count(self, request_error: WorkflowAIError): def test_global_default_version_reference(env_var: str, expected: Any): with patch.dict("os.environ", {"WORKFLOWAI_DEFAULT_VERSION": env_var}): assert global_default_version_reference() == expected + + +# Create a nested object with only required properties +class Recipe(BaseModel): + class Ingredient(BaseModel): + name: str + quantity: int + + ingredients: list[Ingredient] + + +class TestValidator: + def test_tolerant_validator_nested_object(self): + validated = default_validator(Recipe)( + { + "ingredients": [{"name": "salt"}], + }, + partial=True, + ) + for ingredient in validated.ingredients: + assert isinstance(ingredient, Recipe.Ingredient) diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index dcfefb2..e0c29e4 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -3,7 +3,7 @@ from collections.abc import Awaitable, Callable, Iterable from typing import Any, Generic, NamedTuple, Optional, Union, cast -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from typing_extensions import Unpack from workflowai.core._common_types import BaseRunParams, OutputValidator, VersionRunParams @@ -22,9 +22,8 @@ from workflowai.core.client._types import RunParams from workflowai.core.client._utils import ( build_retryable_wait, + default_validator, global_default_version_reference, - intolerant_validator, - tolerant_validator, ) from workflowai.core.domain.completion import Completion from workflowai.core.domain.errors import BaseError, WorkflowAIError @@ -103,6 +102,8 @@ def __init__( self._api = (lambda: api) if isinstance(api, APIClient) else api self._tools = self.build_tools(tools) if tools else None + self._default_validator = default_validator(output_cls) + @classmethod def 
build_tools(cls, tools: Iterable[Callable[..., Any]]): # TODO: we should be more tolerant with errors ? @@ -295,8 +296,9 @@ def _build_run_no_tools( chunk: RunResponse, schema_id: int, validator: OutputValidator[AgentOutput], + partial: Optional[bool] = None, ) -> Run[AgentOutput]: - run = chunk.to_domain(self.agent_id, schema_id, validator) + run = chunk.to_domain(self.agent_id, schema_id, validator, partial) run._agent = self # pyright: ignore [reportPrivateUsage] return run @@ -362,7 +364,7 @@ async def run( Run[AgentOutput]: The task run object. """ prepared_run = await self._prepare_run(agent_input, stream=False, **kwargs) - validator, new_kwargs = self._sanitize_validator(kwargs, intolerant_validator(self.output_cls)) + validator, new_kwargs = self._sanitize_validator(kwargs, self._default_validator) last_error = None while prepared_run.should_retry(): @@ -374,7 +376,6 @@ async def run( validator, current_iteration=0, # TODO[test]: add test with custom validator - # We popped validator above **new_kwargs, ) except WorkflowAIError as e: # noqa: PERF203 @@ -419,10 +420,13 @@ async def stream( AsyncIterator[Run[AgentOutput]]: An async iterator yielding task run objects. """ prepared_run = await self._prepare_run(agent_input, stream=True, **kwargs) - validator, new_kwargs = self._sanitize_validator(kwargs, tolerant_validator(self.output_cls)) + validator, new_kwargs = self._sanitize_validator(kwargs, self._default_validator) while prepared_run.should_retry(): try: + # Will store the error if the final payload fails to validate + final_error: Optional[Exception] = None + chunk: Optional[RunResponse] = None async for chunk in self.api.stream( method="POST", path=prepared_run.route, @@ -430,15 +434,36 @@ async def stream( returns=RunResponse, run=True, ): - yield await self._build_run( - chunk, - prepared_run.schema_id, - validator, - current_iteration=0, - **new_kwargs, - ) + try: + yield await self._build_run( + chunk, + prepared_run.schema_id, + validator, + current_iteration=0, + **new_kwargs, + ) + final_error = None + except ValidationError as e: + logger.debug( + "Client side validation error in stream. There is likely an " + "issue with the validator or the model.", + exc_info=e, + ) + final_error = e + continue + if final_error: + raise WorkflowAIError( + error=BaseError( + message="Client side validation error in stream. This should not " + "happen if the payload is already validated by the server. 
This probably " + "means that there is an issue with the validator or the model.", + ), + response=None, + partial_output=chunk.task_output if chunk else None, + run_id=chunk.id if chunk else None, + ) from final_error return - except WorkflowAIError as e: # noqa: PERF203 + except WorkflowAIError as e: await prepared_run.wait_for_exception(e) async def reply( @@ -462,7 +487,7 @@ async def reply( """ prepared_run = await self._prepare_reply(run_id, user_message, tool_results, stream=False, **kwargs) - validator, new_kwargs = self._sanitize_validator(kwargs, intolerant_validator(self.output_cls)) + validator, new_kwargs = self._sanitize_validator(kwargs, self._default_validator) res = await self.api.post(prepared_run.route, prepared_run.request, returns=RunResponse, run=True) return await self._build_run( diff --git a/workflowai/core/client/agent_test.py b/workflowai/core/client/agent_test.py index 8b4f4ab..8779b47 100644 --- a/workflowai/core/client/agent_test.py +++ b/workflowai/core/client/agent_test.py @@ -1,5 +1,6 @@ import importlib.metadata import json +import logging from unittest.mock import Mock, patch import httpx @@ -120,70 +121,6 @@ async def test_success(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, "stream": False, } - async def test_stream(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]): - httpx_mock.add_response( - stream=IteratorStream( - [ - b'data: {"id":"1","task_output":{"message":""}}\n\n', - b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n', # noqa: E501 - b'data: {"id":"1","task_output":{"message":"hello"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n', # noqa: E501 - ], - ), - ) - - chunks = [chunk async for chunk in agent.stream(HelloTaskInput(name="Alice"))] - - outputs = [chunk.output for chunk in chunks] - assert outputs == [ - HelloTaskOutput(message=""), - HelloTaskOutput(message="hel"), - HelloTaskOutput(message="hello"), - HelloTaskOutput(message="hello"), - ] - last_message = chunks[-1] - assert isinstance(last_message, Run) - assert last_message.version - assert last_message.version.properties.model == "gpt-4o" - assert last_message.version.properties.temperature == 0.5 - assert last_message.cost_usd == 0.01 - assert last_message.duration_seconds == 10.1 - - async def test_stream_not_optional( - self, - httpx_mock: HTTPXMock, - agent_not_optional: Agent[HelloTaskInput, HelloTaskOutputNotOptional], - ): - # Checking that streaming works even with non optional fields - # The first two chunks are missing a required key but the last one has it - httpx_mock.add_response( - stream=IteratorStream( - [ - b'data: {"id":"1","task_output":{"message":""}}\n\n', - b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n', # noqa: E501 - b'data: {"id":"1","task_output":{"message":"hello", "another_field": "test"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n', # noqa: E501 - ], - ), - ) - - chunks = [chunk async for chunk in agent_not_optional.stream(HelloTaskInput(name="Alice"))] - - messages = [chunk.output.message for chunk in chunks] - assert messages == ["", "hel", "hello", "hello"] - - for chunk in chunks[:-1]: - with pytest.raises(AttributeError): - # Since the field is not optional, it will raise an attribute error - assert chunk.output.another_field - assert chunks[-1].output.another_field == 
"test" - - last_message = chunks[-1] - assert isinstance(last_message, Run) - assert last_message.version - assert last_message.version.properties.model == "gpt-4o" - assert last_message.version.properties.temperature == 0.5 - assert last_message.cost_usd == 0.01 - assert last_message.duration_seconds == 10.1 - async def test_run_with_env(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]): httpx_mock.add_response(json=fixtures_json("task_run.json")) @@ -449,558 +386,555 @@ def test_only_model_privider(self, agent: Agent[HelloTaskInput, HelloTaskOutput] } -@pytest.mark.asyncio -async def test_list_models(agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock): - """Test that list_models correctly fetches and returns available models.""" - # Mock the HTTP response instead of the API client method - httpx_mock.add_response( - url="http://localhost:8000/v1/_/agents/123/schemas/1/models", - json={ - "items": [ - { - "id": "gpt-4", - "name": "GPT-4", - "icon_url": "https://example.com/gpt4.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.01, - "is_latest": True, - "metadata": { - "provider_name": "OpenAI", - "price_per_input_token_usd": 0.0001, - "price_per_output_token_usd": 0.0002, - "release_date": "2024-01-01", - "context_window_tokens": 128000, - "quality_index": 0.95, +class TestListModels: + async def test_list_models(self, agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock): + """Test that list_models correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], }, - "is_default": True, - "providers": ["openai"], - }, - { - "id": "claude-3", - "name": "Claude 3", - "icon_url": "https://example.com/claude3.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.015, - "is_latest": True, - "metadata": { - "provider_name": "Anthropic", - "price_per_input_token_usd": 0.00015, - "price_per_output_token_usd": 0.00025, - "release_date": "2024-03-01", - "context_window_tokens": 200000, - "quality_index": 0.98, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + "price_per_input_token_usd": 0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], }, - "is_default": False, - "providers": ["anthropic"], - }, - ], - "count": 2, - }, - ) + ], + "count": 2, + }, + ) - # Call the method - models = await agent.list_models() - - # Verify the HTTP request was made correctly - request = httpx_mock.get_request() - assert request is not None, "Expected an HTTP request to be made" - assert request.method == 
"POST" - assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" - assert json.loads(request.content) == {} - - # Verify we get back the full ModelInfo objects - assert len(models) == 2 - assert isinstance(models[0], ModelInfo) - assert models[0].id == "gpt-4" - assert models[0].name == "GPT-4" - assert models[0].modes == ["chat"] - assert models[0].metadata is not None - assert models[0].metadata.provider_name == "OpenAI" - - assert isinstance(models[1], ModelInfo) - assert models[1].id == "claude-3" - assert models[1].name == "Claude 3" - assert models[1].modes == ["chat"] - assert models[1].metadata is not None - assert models[1].metadata.provider_name == "Anthropic" - - -@pytest.mark.asyncio -async def test_list_models_with_params_override(agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock): - """Test that list_models correctly fetches and returns available models.""" - # Mock the HTTP response instead of the API client method - httpx_mock.add_response( - url="http://localhost:8000/v1/_/agents/123/schemas/1/models", - json={ - "items": [ - { - "id": "gpt-4", - "name": "GPT-4", - "icon_url": "https://example.com/gpt4.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.01, - "is_latest": True, - "metadata": { - "provider_name": "OpenAI", - "price_per_input_token_usd": 0.0001, - "price_per_output_token_usd": 0.0002, - "release_date": "2024-01-01", - "context_window_tokens": 128000, - "quality_index": 0.95, + # Call the method + models = await agent.list_models() + + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == {} + + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" + + async def test_list_models_with_params_override( + self, + agent: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, + ): + """Test that list_models correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], }, - "is_default": True, - "providers": ["openai"], - }, - { - "id": "claude-3", - "name": "Claude 3", - "icon_url": "https://example.com/claude3.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 
0.015, - "is_latest": True, - "metadata": { - "provider_name": "Anthropic", - "price_per_input_token_usd": 0.00015, - "price_per_output_token_usd": 0.00025, - "release_date": "2024-03-01", - "context_window_tokens": 200000, - "quality_index": 0.98, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + "price_per_input_token_usd": 0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], }, - "is_default": False, - "providers": ["anthropic"], - }, - ], - "count": 2, - }, - ) + ], + "count": 2, + }, + ) + + # Call the method + models = await agent.list_models(instructions="Some override instructions", requires_tools=True) + + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == { + "instructions": "Some override instructions", + "requires_tools": True, + } - # Call the method - models = await agent.list_models(instructions="Some override instructions", requires_tools=True) - - # Verify the HTTP request was made correctly - request = httpx_mock.get_request() - assert request is not None, "Expected an HTTP request to be made" - assert request.method == "POST" - assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" - assert json.loads(request.content) == { - "instructions": "Some override instructions", - "requires_tools": True, - } - - # Verify we get back the full ModelInfo objects - assert len(models) == 2 - assert isinstance(models[0], ModelInfo) - assert models[0].id == "gpt-4" - assert models[0].name == "GPT-4" - assert models[0].modes == ["chat"] - assert models[0].metadata is not None - assert models[0].metadata.provider_name == "OpenAI" - - assert isinstance(models[1], ModelInfo) - assert models[1].id == "claude-3" - assert models[1].name == "Claude 3" - assert models[1].modes == ["chat"] - assert models[1].metadata is not None - assert models[1].metadata.provider_name == "Anthropic" - - -@pytest.mark.asyncio -async def test_list_models_with_params_override_and_agent_with_tools_and_instructions( - agent_with_tools_and_instructions: Agent[HelloTaskInput, HelloTaskOutput], - httpx_mock: HTTPXMock, -): - """Test that list_models correctly fetches and returns available models.""" - # Mock the HTTP response instead of the API client method - httpx_mock.add_response( - url="http://localhost:8000/v1/_/agents/123/schemas/1/models", - json={ - "items": [ - { - "id": "gpt-4", - "name": "GPT-4", - "icon_url": "https://example.com/gpt4.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.01, - "is_latest": True, - "metadata": { - "provider_name": "OpenAI", - "price_per_input_token_usd": 0.0001, - "price_per_output_token_usd": 0.0002, - "release_date": "2024-01-01", - "context_window_tokens": 128000, - "quality_index": 0.95, + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + 
assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" + + async def test_list_models_with_params_override_and_agent_with_tools_and_instructions( + self, + agent_with_tools_and_instructions: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, + ): + """Test that list_models correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], }, - "is_default": True, - "providers": ["openai"], - }, - { - "id": "claude-3", - "name": "Claude 3", - "icon_url": "https://example.com/claude3.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.015, - "is_latest": True, - "metadata": { - "provider_name": "Anthropic", - "price_per_input_token_usd": 0.00015, - "price_per_output_token_usd": 0.00025, - "release_date": "2024-03-01", - "context_window_tokens": 200000, - "quality_index": 0.98, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + "price_per_input_token_usd": 0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], }, - "is_default": False, - "providers": ["anthropic"], - }, - ], - "count": 2, - }, - ) + ], + "count": 2, + }, + ) - # Call the method - models = await agent_with_tools_and_instructions.list_models( - instructions="Some override instructions", - requires_tools=False, - ) + # Call the method + models = await agent_with_tools_and_instructions.list_models( + instructions="Some override instructions", + requires_tools=False, + ) - # Verify the HTTP request was made correctly - request = httpx_mock.get_request() - assert request is not None, "Expected an HTTP request to be made" - assert request.method == "POST" - assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" - assert json.loads(request.content) == { - "instructions": "Some override instructions", - "requires_tools": False, - } - - # Verify we get back the full ModelInfo objects - assert len(models) == 2 - assert isinstance(models[0], ModelInfo) - assert models[0].id == "gpt-4" - assert models[0].name == "GPT-4" - assert models[0].modes == ["chat"] - assert models[0].metadata is not None - assert models[0].metadata.provider_name == "OpenAI" - - assert isinstance(models[1], ModelInfo) - assert models[1].id == "claude-3" - assert models[1].name == "Claude 3" - assert models[1].modes == ["chat"] - assert 
models[1].metadata is not None - assert models[1].metadata.provider_name == "Anthropic" - - -@pytest.mark.asyncio -async def test_list_models_registers_if_needed( - agent_no_schema: Agent[HelloTaskInput, HelloTaskOutput], - httpx_mock: HTTPXMock, -): - """Test that list_models registers the agent if it hasn't been registered yet.""" - # Mock the registration response - httpx_mock.add_response( - url="http://localhost:8000/v1/_/agents", - json={"id": "123", "schema_id": 2}, - ) + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == { + "instructions": "Some override instructions", + "requires_tools": False, + } + + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" + + async def test_list_models_registers_if_needed( + self, + agent_no_schema: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, + ): + """Test that list_models registers the agent if it hasn't been registered yet.""" + # Mock the registration response + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents", + json={"id": "123", "schema_id": 2}, + ) - # Mock the models response with the new structure - httpx_mock.add_response( - url="http://localhost:8000/v1/_/agents/123/schemas/2/models", - json={ - "items": [ - { - "id": "gpt-4", - "name": "GPT-4", - "icon_url": "https://example.com/gpt4.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.01, - "is_latest": True, - "metadata": { - "provider_name": "OpenAI", - "price_per_input_token_usd": 0.0001, - "price_per_output_token_usd": 0.0002, - "release_date": "2024-01-01", - "context_window_tokens": 128000, - "quality_index": 0.95, + # Mock the models response with the new structure + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/schemas/2/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], }, - "is_default": True, - "providers": ["openai"], - }, - ], - "count": 1, - }, - ) + ], + "count": 1, + }, + ) + + # Call the method + models = await agent_no_schema.list_models() - # Call the method - models = await agent_no_schema.list_models() - - # Verify both API calls were made - reqs = httpx_mock.get_requests() - assert len(reqs) == 2 - assert reqs[0].url == "http://localhost:8000/v1/_/agents" - assert reqs[1].method == "POST" - assert reqs[1].url == 
"http://localhost:8000/v1/_/agents/123/schemas/2/models" - assert json.loads(reqs[1].content) == {} - - # Verify we get back the full ModelInfo object - assert len(models) == 1 - assert isinstance(models[0], ModelInfo) - assert models[0].id == "gpt-4" - assert models[0].name == "GPT-4" - assert models[0].modes == ["chat"] - assert models[0].metadata is not None - assert models[0].metadata.provider_name == "OpenAI" - - -@pytest.mark.asyncio -async def test_list_models_with_instructions( - agent_with_instructions: Agent[HelloTaskInput, HelloTaskOutput], - httpx_mock: HTTPXMock, -): - """Test that list_models correctly fetches and returns available models.""" - # Mock the HTTP response instead of the API client method - httpx_mock.add_response( - method="POST", - url="http://localhost:8000/v1/_/agents/123/schemas/1/models", - json={ - "items": [ - { - "id": "gpt-4", - "name": "GPT-4", - "icon_url": "https://example.com/gpt4.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.01, - "is_latest": True, - "metadata": { - "provider_name": "OpenAI", - "price_per_input_token_usd": 0.0001, - "price_per_output_token_usd": 0.0002, - "release_date": "2024-01-01", - "context_window_tokens": 128000, - "quality_index": 0.95, + # Verify both API calls were made + reqs = httpx_mock.get_requests() + assert len(reqs) == 2 + assert reqs[0].url == "http://localhost:8000/v1/_/agents" + assert reqs[1].method == "POST" + assert reqs[1].url == "http://localhost:8000/v1/_/agents/123/schemas/2/models" + assert json.loads(reqs[1].content) == {} + + # Verify we get back the full ModelInfo object + assert len(models) == 1 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + async def test_list_models_with_instructions( + self, + agent_with_instructions: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, + ): + """Test that list_models correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + method="POST", + url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], }, - "is_default": True, - "providers": ["openai"], - }, - { - "id": "claude-3", - "name": "Claude 3", - "icon_url": "https://example.com/claude3.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.015, - "is_latest": True, - "metadata": { - "provider_name": "Anthropic", - "price_per_input_token_usd": 0.00015, - "price_per_output_token_usd": 0.00025, - "release_date": "2024-03-01", - "context_window_tokens": 200000, - "quality_index": 0.98, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + 
"price_per_input_token_usd": 0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], }, - "is_default": False, - "providers": ["anthropic"], - }, - ], - "count": 2, - }, - ) + ], + "count": 2, + }, + ) - # Call the method - models = await agent_with_instructions.list_models() - - # Verify the HTTP request was made correctly - request = httpx_mock.get_request() - assert request is not None, "Expected an HTTP request to be made" - assert request.method == "POST" - assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" - assert json.loads(request.content) == {"instructions": "Some instructions"} - - # Verify we get back the full ModelInfo objects - assert len(models) == 2 - assert isinstance(models[0], ModelInfo) - assert models[0].id == "gpt-4" - assert models[0].name == "GPT-4" - assert models[0].modes == ["chat"] - assert models[0].metadata is not None - assert models[0].metadata.provider_name == "OpenAI" - - assert isinstance(models[1], ModelInfo) - assert models[1].id == "claude-3" - assert models[1].name == "Claude 3" - assert models[1].modes == ["chat"] - assert models[1].metadata is not None - assert models[1].metadata.provider_name == "Anthropic" - - -@pytest.mark.asyncio -async def test_list_models_with_tools( - agent_with_tools: Agent[HelloTaskInput, HelloTaskOutput], - httpx_mock: HTTPXMock, -): - """Test that list_models correctly fetches and returns available models.""" - # Mock the HTTP response instead of the API client method - httpx_mock.add_response( - method="POST", - url="http://localhost:8000/v1/_/agents/123/schemas/1/models", - json={ - "items": [ - { - "id": "gpt-4", - "name": "GPT-4", - "icon_url": "https://example.com/gpt4.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.01, - "is_latest": True, - "metadata": { - "provider_name": "OpenAI", - "price_per_input_token_usd": 0.0001, - "price_per_output_token_usd": 0.0002, - "release_date": "2024-01-01", - "context_window_tokens": 128000, - "quality_index": 0.95, + # Call the method + models = await agent_with_instructions.list_models() + + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == {"instructions": "Some instructions"} + + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" + + async def test_list_models_with_tools( + self, + agent_with_tools: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, + ): + """Test that list_models correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + method="POST", + 
url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], }, - "is_default": True, - "providers": ["openai"], - }, - { - "id": "claude-3", - "name": "Claude 3", - "icon_url": "https://example.com/claude3.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.015, - "is_latest": True, - "metadata": { - "provider_name": "Anthropic", - "price_per_input_token_usd": 0.00015, - "price_per_output_token_usd": 0.00025, - "release_date": "2024-03-01", - "context_window_tokens": 200000, - "quality_index": 0.98, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + "price_per_input_token_usd": 0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], }, - "is_default": False, - "providers": ["anthropic"], - }, - ], - "count": 2, - }, - ) + ], + "count": 2, + }, + ) - # Call the method - models = await agent_with_tools.list_models() - - # Verify the HTTP request was made correctly - request = httpx_mock.get_request() - assert request is not None, "Expected an HTTP request to be made" - assert request.method == "POST" - assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" - assert json.loads(request.content) == {"requires_tools": True} - - # Verify we get back the full ModelInfo objects - assert len(models) == 2 - assert isinstance(models[0], ModelInfo) - assert models[0].id == "gpt-4" - assert models[0].name == "GPT-4" - assert models[0].modes == ["chat"] - assert models[0].metadata is not None - assert models[0].metadata.provider_name == "OpenAI" - - assert isinstance(models[1], ModelInfo) - assert models[1].id == "claude-3" - assert models[1].name == "Claude 3" - assert models[1].modes == ["chat"] - assert models[1].metadata is not None - assert models[1].metadata.provider_name == "Anthropic" - - -@pytest.mark.asyncio -async def test_list_models_with_instructions_and_tools( - agent_with_tools_and_instructions: Agent[HelloTaskInput, HelloTaskOutput], - httpx_mock: HTTPXMock, -): - """Test that list_models correctly fetches and returns available models.""" - # Mock the HTTP response instead of the API client method - httpx_mock.add_response( - method="POST", - url="http://localhost:8000/v1/_/agents/123/schemas/1/models", - json={ - "items": [ - { - "id": "gpt-4", - "name": "GPT-4", - "icon_url": "https://example.com/gpt4.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.01, - "is_latest": True, - "metadata": { - "provider_name": "OpenAI", - "price_per_input_token_usd": 0.0001, - "price_per_output_token_usd": 0.0002, - "release_date": "2024-01-01", - "context_window_tokens": 128000, - "quality_index": 0.95, + # Call the method + models = await 
agent_with_tools.list_models() + + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == {"requires_tools": True} + + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" + + async def test_list_models_with_instructions_and_tools( + self, + agent_with_tools_and_instructions: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, + ): + """Test that list_models correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + method="POST", + url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], }, - "is_default": True, - "providers": ["openai"], - }, - { - "id": "claude-3", - "name": "Claude 3", - "icon_url": "https://example.com/claude3.png", - "modes": ["chat"], - "is_not_supported_reason": None, - "average_cost_per_run_usd": 0.015, - "is_latest": True, - "metadata": { - "provider_name": "Anthropic", - "price_per_input_token_usd": 0.00015, - "price_per_output_token_usd": 0.00025, - "release_date": "2024-03-01", - "context_window_tokens": 200000, - "quality_index": 0.98, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + "price_per_input_token_usd": 0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], }, - "is_default": False, - "providers": ["anthropic"], - }, - ], - "count": 2, - }, - ) + ], + "count": 2, + }, + ) - # Call the method - models = await agent_with_tools_and_instructions.list_models() - - # Verify the HTTP request was made correctly - request = httpx_mock.get_request() - assert request is not None, "Expected an HTTP request to be made" - assert request.method == "POST" - assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" - assert json.loads(request.content) == {"instructions": "Some instructions", "requires_tools": True} - - # Verify we get back the full ModelInfo objects - assert len(models) == 2 - assert isinstance(models[0], ModelInfo) - assert models[0].id == "gpt-4" - 
assert models[0].name == "GPT-4" - assert models[0].modes == ["chat"] - assert models[0].metadata is not None - assert models[0].metadata.provider_name == "OpenAI" - - assert isinstance(models[1], ModelInfo) - assert models[1].id == "claude-3" - assert models[1].name == "Claude 3" - assert models[1].modes == ["chat"] - assert models[1].metadata is not None - assert models[1].metadata.provider_name == "Anthropic" + # Call the method + models = await agent_with_tools_and_instructions.list_models() + + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == {"instructions": "Some instructions", "requires_tools": True} + + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" class TestFetchCompletions: @@ -1029,3 +963,118 @@ async def test_fetch_completions(self, agent: Agent[HelloTaskInput, HelloTaskOut ), ), ] + + +class TestStream: + async def test_stream(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]): + httpx_mock.add_response( + stream=IteratorStream( + [ + b'data: {"id":"1","task_output":{"message":""}}\n\n', + b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n', # noqa: E501 + b'data: {"id":"1","task_output":{"message":"hello"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n', # noqa: E501 + ], + ), + ) + + chunks = [chunk async for chunk in agent.stream(HelloTaskInput(name="Alice"))] + + outputs = [chunk.output for chunk in chunks] + assert outputs == [ + HelloTaskOutput(message=""), + HelloTaskOutput(message="hel"), + HelloTaskOutput(message="hello"), + HelloTaskOutput(message="hello"), + ] + last_message = chunks[-1] + assert isinstance(last_message, Run) + assert last_message.version + assert last_message.version.properties.model == "gpt-4o" + assert last_message.version.properties.temperature == 0.5 + assert last_message.cost_usd == 0.01 + assert last_message.duration_seconds == 10.1 + + async def test_stream_not_optional( + self, + httpx_mock: HTTPXMock, + agent_not_optional: Agent[HelloTaskInput, HelloTaskOutputNotOptional], + ): + # Checking that streaming works even with non optional fields + # The first two chunks are missing a required key but the last one has it + httpx_mock.add_response( + stream=IteratorStream( + [ + b'data: {"id":"1","task_output":{"message":""}}\n\n', + b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n', # noqa: E501 + b'data: {"id":"1","task_output":{"message":"hello", "another_field": "test"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n', # noqa: E501 + ], + ), + ) + + chunks = [chunk async for chunk in 
agent_not_optional.stream(HelloTaskInput(name="Alice"))] + + messages = [chunk.output.message for chunk in chunks] + assert messages == ["", "hel", "hello", "hello"] + + for chunk in chunks[:-1]: + assert chunk.output.another_field == "" + assert chunks[-1].output.another_field == "test" + + last_message = chunks[-1] + assert isinstance(last_message, Run) + assert last_message.version + assert last_message.version.properties.model == "gpt-4o" + assert last_message.version.properties.temperature == 0.5 + assert last_message.cost_usd == 0.01 + assert last_message.duration_seconds == 10.1 + + async def test_stream_validation_errors( + self, + agent: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, + caplog: pytest.LogCaptureFixture, + ): + """Test that chunks failing validation are skipped and the errors are logged during streaming""" + httpx_mock.add_response( + stream=IteratorStream( + [ + b'data: {"id":"1","task_output":{"message":""}}\n\n', + # The middle chunk fails validation and is skipped + b'data: {"id":"1","task_output":{"message":1}}\n\n', + b'data: {"id":"1","task_output":{"message":"hello"}}\n\n', + ], + ), + ) + + with caplog.at_level(logging.DEBUG): + chunks = [chunk async for chunk in agent.stream(HelloTaskInput(name="Alice"))] + + assert len(chunks) == 2 + assert chunks[0].output.message == "" + assert chunks[1].output.message == "hello" + logs = [record for record in caplog.records if "Client side validation error in stream" in record.message] + + assert len(logs) == 1 + assert logs[0].levelname == "DEBUG" + assert logs[0].exc_info is not None + + async def test_stream_validation_final_error( + self, + agent: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, + ): + """Check that we properly raise an error if the final payload fails to validate.""" + httpx_mock.add_response( + stream=IteratorStream( + [ + # Stream a single chunk that fails to validate + b'data: {"id":"1","task_output":{"message":1}}\n\n', + ], + ), + ) + + with pytest.raises(WorkflowAIError) as e: + _ = [c async for c in agent.stream(HelloTaskInput(name="Alice"))] + + assert e.value.partial_output == {"message": 1} + assert e.value.run_id == "1" diff --git a/workflowai/core/domain/model.py b/workflowai/core/domain/model.py index 061398c..5b166ed 100644 --- a/workflowai/core/domain/model.py +++ b/workflowai/core/domain/model.py @@ -34,6 +34,7 @@ class Model(str, Enum): O1_PREVIEW_2024_09_12 = "o1-preview-2024-09-12" O1_MINI_LATEST = "o1-mini-latest" O1_MINI_2024_09_12 = "o1-mini-2024-09-12" + GPT_45_PREVIEW_2025_02_27 = "gpt-4.5-preview-2025-02-27" GPT_4O_AUDIO_PREVIEW_2024_12_17 = "gpt-4o-audio-preview-2024-12-17" GPT_40_AUDIO_PREVIEW_2024_10_01 = "gpt-4o-audio-preview-2024-10-01" GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09" diff --git a/workflowai/core/domain/model_test.py b/workflowai/core/domain/model_test.py new file mode 100644 index 0000000..d23bacb --- /dev/null +++ b/workflowai/core/domain/model_test.py @@ -0,0 +1,15 @@ +import httpx + +from workflowai.core.domain.model import Model + + +async def test_model_exhaustive(): + """Make sure the list of models is synchronized with the prod API""" + async with httpx.AsyncClient() as client: + response = await client.get("https://run.workflowai.com/v1/models") + response.raise_for_status() + models: list[str] = response.json() + + # Converting to a set of strings should not be needed + # but it makes pytest errors prettier + assert set(models) == {m.value for m in Model} diff --git a/workflowai/core/utils/_pydantic.py b/workflowai/core/utils/_pydantic.py new 
diff --git a/workflowai/core/utils/_pydantic.py b/workflowai/core/utils/_pydantic.py
new file mode 100644
index 0000000..f8c733f
--- /dev/null
+++ b/workflowai/core/utils/_pydantic.py
@@ -0,0 +1,104 @@
+from collections.abc import Mapping, Sequence
+from typing import Any, TypeVar, get_args, get_origin
+
+from pydantic import BaseModel, Field, create_model
+from pydantic.fields import FieldInfo
+from pydantic_core import PydanticUndefined
+from typing_extensions import TypeGuard
+
+from workflowai.core._logger import logger
+from workflowai.core.utils._vars import BM
+
+_T = TypeVar("_T")
+
+
+def _safe_issubclass(val: type[_T], cls: type[_T]) -> TypeGuard[type[_T]]:
+    try:
+        return issubclass(val, cls)
+    except TypeError:
+        return False
+
+
+def _copy_field_info(field_info: FieldInfo, **overrides: Any):
+    """
+    Return a copy of a pydantic FieldInfo object, allowing certain
+    values to be overridden.
+    """
+
+    kwargs = overrides
+    for k, v in field_info.__repr_args__():
+        if k in kwargs or not k:
+            continue
+        kwargs[k] = v
+
+    return Field(**kwargs)
+
+
+def _default_value_from_annotation(annotation: type[Any]) -> Any:
+    try:
+        # Check whether the annotation can be instantiated with no arguments
+        return annotation()
+    except Exception:  # noqa: BLE001
+        logger.debug("Failed to get default value from annotation", exc_info=True)
+        # Fall back to None
+        return None
+
+
+def _optional_annotation(annotation: type[Any]) -> type[Any]:
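+    # Recursively make an annotation "partial": nested BaseModel classes are swapped
+    # for their partial_model counterpart, and container value types are rewritten too.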
+    if _safe_issubclass(annotation, BaseModel):
+        return partial_model(annotation)
+
+    origin = get_origin(annotation)
+    args = get_args(annotation)
+    if not origin or not args:
+        return annotation
+
+    if _safe_issubclass(origin, Sequence) or _safe_issubclass(origin, set):
+        if not len(args) == 1:
+            raise ValueError("Sequence must have exactly one argument")
+        return origin[_optional_annotation(args[0])]
+    # Other generic origins (unions for example) are left untouched and fall through below
+
+    if _safe_issubclass(origin, Mapping):
+        if not len(args) == 2:
+            raise ValueError("Mapping must have exactly two arguments")
+        if args[0] is not str:
+            raise ValueError("Mapping key must be a string")
+
+        return origin[args[0], _optional_annotation(args[1])]
+    return annotation
+
+
+def partial_model(base: type[BM]) -> type[BM]:
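+    # Return a subclass of `base` where required fields without defaults get a
+    # best-effort default and a partial annotation, so incomplete payloads
+    # (e.g. streamed chunks) still validate.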
+    default_fields: dict[str, tuple[type[Any], FieldInfo]] = {}
+    for name, field in base.model_fields.items():
+        if field.default != PydanticUndefined or field.default_factory or not field.annotation:
+            # No need to do anything here, the field is already optional
+            continue
+
+        overrides: dict[str, Any] = {}
+        try:
+            annotation = _optional_annotation(field.annotation)
+            overrides["annotation"] = annotation
+            overrides["default"] = _default_value_from_annotation(annotation)
+        except Exception:  # noqa: BLE001
+            logger.debug("Failed to make annotation optional", exc_info=True)
+            continue
+
+        default_fields[name] = (annotation, _copy_field_info(field, **overrides))
+
+    if not default_fields:
+        return base
+
+    def custom_eq(o1: BM, o2: Any):
+        if not isinstance(o2, base):
+            return False
+        return o1.model_dump() == o2.model_dump()
+
+    return create_model(  # pyright: ignore [reportCallIssue, reportUnknownVariableType]
+        f"Partial{base.__name__}",
+        __base__=base,
+        __eq__=custom_eq,
+        __hash__=base.__hash__,
+        **default_fields,  # pyright: ignore [reportArgumentType]
+    )
diff --git a/workflowai/core/utils/_pydantic_test.py b/workflowai/core/utils/_pydantic_test.py
new file mode 100644
index 0000000..858a607
--- /dev/null
+++ b/workflowai/core/utils/_pydantic_test.py
@@ -0,0 +1,155 @@
+from typing import Any, Optional
+
+import pytest
+from pydantic import BaseModel, Field, ValidationError
+
+from workflowai.core.utils._pydantic import partial_model
+
+
+class TestPartialModel:
+    def test_partial_model_equals(self):
+        class SimpleModel(BaseModel):
+            name: str
+
+        partial = partial_model(SimpleModel)
+        assert partial.model_validate({"name": "John"}) == SimpleModel(name="John")
+
+        assert SimpleModel(name="John") == partial.model_validate({"name": "John"})
+
+    def test_simple_model(self):
+        class SimpleModel(BaseModel):
+            name1: str
+            name2: str
+            name3: int
+            name4: float
+            bla: dict[str, Any]
+            opt: Optional[str]
+
+        constructed = partial_model(SimpleModel).model_validate({"name1": "John"})
+        assert isinstance(constructed, SimpleModel)
+        assert constructed.name1 == "John"
+        assert constructed.name2 == ""
+        assert constructed.name3 == 0
+        assert constructed.name4 == 0.0
+        assert constructed.bla == {}
+        assert constructed.opt is None
+
+        # Check that we do not raise on an empty payload
+        partial_model(SimpleModel).model_validate({})
+
+        # Check that we do raise when a type is wrong
+        with pytest.raises(ValidationError):
+            partial_model(SimpleModel).model_validate({"name1": 1, "name2": "2"})
+
+    def test_with_some_optional_fields(self):
+        class SomeOptionalFields(BaseModel):
+            name1: str
+            name2: str = "blibly"
+            list1: list[str] = Field(default_factory=lambda: ["1"])
+
+        constructed = partial_model(SomeOptionalFields).model_validate({})
+        assert isinstance(constructed, SomeOptionalFields)
+        assert constructed.name1 == ""
+        assert constructed.name2 == "blibly"
+        assert constructed.list1 == ["1"]
+
+    def test_list_of_strings(self):
+        class ListOfStrings(BaseModel):
+            strings: list[str]
+
+        constructed = partial_model(ListOfStrings).model_validate({"strings": ["a", "b"]})
+        assert isinstance(constructed, ListOfStrings)
+        assert constructed.strings == ["a", "b"]
+
+        # Check that we do not raise on an empty payload
+        partial_model(ListOfStrings).model_validate({})
+
+    @pytest.mark.parametrize(
+        "payload",
+        [
+            {},
+            {"field1": "hello"},
+            {"nested": {"name": "hello", "field2": "world"}},
+            {"nested": {"name": "hello"}},
+        ],
+    )
+    def test_nested_model(self, payload: dict[str, Any]):
+        class NestedModel(BaseModel):
+            name: str
+            field2: str
+
+        class OuterModel(BaseModel):
+            field1: str
+            nested: NestedModel
+
+        constructed = partial_model(OuterModel).model_validate(payload)
+        assert isinstance(constructed, OuterModel), "constructed is not an instance of OuterModel"
+        assert constructed.field1 == payload.get("field1", "")
+        assert isinstance(constructed.nested, NestedModel), "nested is not an instance of NestedModel"
+
+        assert constructed.nested.name == payload.get("nested", {}).get("name", "")
+        assert constructed.nested.field2 == payload.get("nested", {}).get("field2", "")
+
+    def test_list_of_models(self):
+        class NestedModel(BaseModel):
+            name: str
+            field2: str
+
+        class ListOfModels(BaseModel):
+            models: list[NestedModel]
+
+        constructed = partial_model(ListOfModels).model_validate(
+            {"models": [{"name": "hello", "field2": "world"}, {"name": "hello"}]},
+        )
+        assert isinstance(constructed, ListOfModels)
+        assert isinstance(constructed.models, list)
+        assert isinstance(constructed.models[0], NestedModel)
+        assert isinstance(constructed.models[1], NestedModel)
+
+    def test_set_of_models(self):
+        class NestedModel(BaseModel):
+            name: str = "1"
+            field2: str
+
+            def __hash__(self) -> int:
+                return hash(self.name)
+
+        class SetOfModels(BaseModel):
+            models: set[NestedModel]
+
+        constructed = partial_model(SetOfModels).model_validate(
+            {"models": [{"name": "hello", "field2": "world"}, {"name": "hello"}]},
+        )
+        assert isinstance(constructed, SetOfModels)
+        assert isinstance(constructed.models, set)
+        assert all(isinstance(model, NestedModel) for model in constructed.models)
+
+    def test_dict_of_models(self):
+        class NestedModel(BaseModel):
+            name: str
+            field2: str
+
+        class DictOfModels(BaseModel):
+            models: dict[str, NestedModel]
+
+        constructed = partial_model(DictOfModels).model_validate(
+            {"models": {"hello": {"name": "hello", "field2": "world"}, "hello2": {"name": "hello"}}},
+        )
+        assert isinstance(constructed, DictOfModels)
+        assert isinstance(constructed.models, dict)
+        assert isinstance(constructed.models["hello"], NestedModel)
+        assert isinstance(constructed.models["hello2"], NestedModel)
+
+    def test_with_aliases(self):
+        class AliasModel(BaseModel):
+            message: str = Field(alias="message_alias")
+            aliased_ser: str = Field(serialization_alias="aliased_ser_alias")
+            aliased_val: str = Field(validation_alias="aliased_val_alias")
+
+        partial = partial_model(AliasModel)
+
+        payload = {"message_alias": "hello", "aliased_ser": "world", "aliased_val_alias": "!"}
+        v1 = AliasModel.model_validate(payload)
+        v2 = partial.model_validate(payload)
+
+        assert v1 == v2
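
For reviewers, a minimal sketch of what the new partial validation buys us when a streamed payload is incomplete. StreamedOutput is a made-up model used purely for illustration; partial_model and its behavior come from the diff above.

    from pydantic import BaseModel

    from workflowai.core.utils._pydantic import partial_model


    class StreamedOutput(BaseModel):
        message: str
        score: float


    Partial = partial_model(StreamedOutput)

    # An incomplete payload, as received mid-stream, validates with best-effort defaults...
    chunk = Partial.model_validate({"message": "hel"})
    assert chunk.message == "hel" and chunk.score == 0.0

    # ...and a complete payload still compares equal to the base model.
    complete = Partial.model_validate({"message": "hello", "score": 0.9})
    assert complete == StreamedOutput(message="hello", score=0.9)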