diff --git a/.vscode/extensions.json b/.vscode/extensions.json
new file mode 100644
index 0000000..5c95d80
--- /dev/null
+++ b/.vscode/extensions.json
@@ -0,0 +1,7 @@
+{
+  "recommendations": [
+    "charliermarsh.ruff",
+    "njpwerner.autodocstring",
+    "editorconfig.editorconfig"
+  ]
+}
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..9fa825e
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,69 @@
+# Contributing to WorkflowAI
+
+## Setup
+
+### Prerequisites
+
+- [Poetry](https://python-poetry.org/docs/#installation) for dependency management and publishing
+
+### Getting started
+
+```bash
+# We recommend configuring the virtual envs in-project with poetry so that
+# they can easily be picked up by IDEs
+
+# poetry config virtualenvs.in-project true
+poetry install --all-extras
+
+# Install the pre-commit hooks
+poetry run pre-commit install
+# or `make install` to install the pre-commit hooks and the dependencies
+
+# Check the code quality
+# Run ruff
+poetry run ruff check .
+# Run pyright
+poetry run pyright
+# or `make lint` to run ruff and pyright
+
+# Run the unit and integration tests
+# They do not require any configuration
+poetry run pytest --ignore=tests/e2e # make test
+
+# Run the end-to-end tests
+# They require the `WORKFLOWAI_TEST_API_URL` and `WORKFLOWAI_TEST_API_KEY` environment variables to be set
+# If they are present in the `.env` file, they will be picked up automatically
+poetry run pytest tests/e2e
+```
+
+#### Configuring VSCode
+
+Suggested extensions are available in the [.vscode/extensions.json](.vscode/extensions.json) file.
+
+### Dependencies
+
+#### Ruff
+
+[Ruff](https://github.com/astral-sh/ruff) is a very fast Python code linter and formatter.
+
+```sh
+ruff check . # check the entire project
+ruff check src/workflowai/core # check a specific directory
+ruff check . --fix # fix linting errors automatically in the entire project
+```
+
+#### Pyright
+
+[Pyright](https://github.com/microsoft/pyright) is a static type checker for Python.
+
+> We preferred it to `mypy` because it is faster and easier to configure.
+
+#### Pydantic
+
+[Pydantic](https://docs.pydantic.dev/) is a data validation library for Python.
+It provides very convenient methods to serialize and deserialize data, introspect its structure, set validation
+rules, etc.
+
+#### HTTPX
+
+[HTTPX](https://www.python-httpx.org/) is a modern HTTP library for Python.
diff --git a/README.md b/README.md
index 6f7671f..4710e6b 100644
--- a/README.md
+++ b/README.md
@@ -121,6 +121,8 @@ WorkflowAI supports a long list of models. The source of truth for models we sup
 You can set the model explicitly in the agent decorator:
 
 ```python
+from workflowai import Model
+
 @workflowai.agent(model=Model.GPT_4O_LATEST)
 def say_hello(input: Input) -> Output:
     ...
@@ -151,16 +153,31 @@ def say_hello(input: Input) -> AsyncIterator[Run[Output]]:
     ...
 ```
 
-### Streaming and advanced usage
+### The Run object
+
+Although having an agent return only its output covers most use cases, some use cases require more
+information about the run.
 
-You can configure the agent function to stream or return the full run object, simply by changing the type annotation.
+By changing the type annotation of the agent function to `Run[Output]`, the generated function will return
+the full run object.
 
 ```python
-# Return the full run object, useful if you want to extract metadata like cost or duration
 @workflowai.agent()
-async def say_hello(input: Input) -> Run[Output]:
-    ...
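+# the Run[Output] annotation below makes the agent return the full run instead of just the output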
+async def say_hello(input: Input) -> Run[Output]: ...
+
+
+run = await say_hello(Input(name="John"))
+print(run.output) # the output, as before
+print(run.model) # the model used for the run
+print(run.cost_usd) # the cost of the run in USD
+print(run.duration_seconds) # the duration of the inference in seconds
+```
+
+### Streaming
+
+You can configure the agent function to stream by changing the type annotation to an `AsyncIterator`.
+
+```python
 # Stream the output, the output is filled as it is generated
 @workflowai.agent()
 def say_hello(input: Input) -> AsyncIterator[Output]:
     ...
@@ -172,6 +189,38 @@ def say_hello(input: Input) -> AsyncIterator[Run[Output]]:
     ...
 ```
 
+### Replying to a run
+
+Some use cases require a back-and-forth between the client and the LLM. For example:
+
+- tools ([see below](#tools)) use the reply ability internally
+- chatbots
+- correcting the LLM output
+
+In WorkflowAI, this is done by replying to a run. A reply can contain:
+
+- a user response
+- tool results
+
+
+```python
+# Returning the full run object is required to use the reply feature
+@workflowai.agent()
+async def say_hello(input: Input) -> Run[Output]:
+    ...
+
+run = await say_hello(Input(name="John"))
+run = await run.reply(user_response="Now say hello to his brother James")
+```
+
+The output of a reply to a run has the same type as the original run, which makes it easy to iterate toward
+the construction of a final output.
+
+> To allow run iterations, it is very important to have outputs that are tolerant to missing fields, i.e., that
+> have default values for most of their fields. Otherwise, the agent will raise a `WorkflowAIError` on missing
+> fields and the run chain will be broken.
+
 ### Tools
 
 Tools enhance an agent's capabilities by allowing it to call external functions.
@@ -222,9 +271,16 @@ def get_current_time(timezone: Annotated[str, "The timezone to get the current t
     """Return the current time in the given timezone in iso format"""
     return datetime.now(ZoneInfo(timezone)).isoformat()
 
+# Tools can also be async
+async def fetch_webpage(url: str) -> str:
+    """Fetch the content of a webpage"""
+    async with httpx.AsyncClient() as client:
+        response = await client.get(url)
+        return response.text
+
 @agent(
     id="answer-question",
-    tools=[get_current_time],
+    tools=[get_current_time, fetch_webpage],
     version=VersionProperties(model=Model.GPT_4O_LATEST),
 )
 async def answer_question(_: AnswerQuestionInput) -> Run[AnswerQuestionOutput]: ...
@@ -261,6 +317,29 @@ except WorkflowAIError as e:
     print(e.message)
 ```
 
+#### Recoverable errors
+
+Sometimes the LLM outputs an object that is only partially valid. Good examples are:
+
+- the model context window was exceeded during the generation
+- the model decided that a tool call result was a failure
+
+In this case, an agent that returns only an output will always raise an `InvalidGenerationError`, which
+subclasses `WorkflowAIError`.
+
+However, an agent that returns a full run object will try to recover from the error by using the partial output.
+
+```python
+run = await agent(input=Input(name="John"))
+
+# The run will have an error
+assert run.error is not None
+
+# The run will have a partial output
+assert run.output is not None
+```
+
 ### Defining input and output types
 
 There are some important subtleties when defining input and output types.
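+
+For example, a minimal sketch of an output model that tolerates partial generations (the field names are
+illustrative):
+
+```python
+class Output(BaseModel):
+    # defaults keep validation from failing while fields are still missing
+    greeting1: str = ""
+    greeting2: str = ""
+```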
@@ -368,3 +447,32 @@ async for run in say_hello(Input(name="John")):
     print(run.output.greeting1) # will be empty if the model has not generated it yet
 
 ```
+
+#### Field properties
+
+Pydantic allows a variety of other validation criteria for fields: minimum, maximum, pattern, etc.
+These additional criteria are included in the JSON Schema that is sent to WorkflowAI, and are forwarded to the model.
+
+```python
+class Input(BaseModel):
+    name: str = Field(min_length=3, max_length=10)
+    age: int = Field(ge=18, le=100)
+    email: str = Field(pattern=r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$")
+```
+
+These arguments can be used to steer the model in the right direction. The caveat is that a validation that is
+too strict can lead to invalid generations. In case of an invalid generation:
+
+- WorkflowAI retries the inference once by providing the model with the invalid output and the validation error
+- if the model still fails to generate a valid output, the run will fail with an `InvalidGenerationError`.
+  The partial output is available in the `partial_output` attribute of the `InvalidGenerationError`:
+
+```python
+@agent()
+async def my_agent(_: Input) -> Output: ...
+
+try:
+    output = await my_agent(Input(name="John"))
+except InvalidGenerationError as e:
+    print(e.partial_output)
+```
+
+## Contributing
+
+See the [CONTRIBUTING.md](./CONTRIBUTING.md) file for more details.
diff --git a/poetry.lock b/poetry.lock
index d9fb309..8f0c0fd 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -485,13 +485,13 @@ windows-terminal = ["colorama (>=0.4.6)"]
 
 [[package]]
 name = "pyright"
-version = "1.1.390"
+version = "1.1.393"
 description = "Command line wrapper for pyright"
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "pyright-1.1.390-py3-none-any.whl", hash = "sha256:ecebfba5b6b50af7c1a44c2ba144ba2ab542c227eb49bc1f16984ff714e0e110"},
-    {file = "pyright-1.1.390.tar.gz", hash = "sha256:aad7f160c49e0fbf8209507a15e17b781f63a86a1facb69ca877c71ef2e9538d"},
+    {file = "pyright-1.1.393-py3-none-any.whl", hash = "sha256:8320629bb7a44ca90944ba599390162bf59307f3d9fb6e27da3b7011b8c17ae5"},
+    {file = "pyright-1.1.393.tar.gz", hash = "sha256:aeeb7ff4e0364775ef416a80111613f91a05c8e01e58ecfefc370ca0db7aed9c"},
 ]
 
 [package.dependencies]
diff --git a/pyproject.toml b/pyproject.toml
index 1fed874..0fbc0c3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -64,6 +64,7 @@ unfixable = []
 # in bin we use rich.print
 "bin/*" = ["T201"]
 "*_test.py" = ["S101"]
+"conftest.py" = ["S101"]
 
 [tool.pyright]
 pythonVersion = "3.9"
diff --git a/tests/e2e/tools_test.py b/tests/e2e/tools_test.py
index 82869a5..6b07240 100644
--- a/tests/e2e/tools_test.py
+++ b/tests/e2e/tools_test.py
@@ -6,7 +6,7 @@
 
 from workflowai import Run, agent
 from workflowai.core.domain.model import Model
-from workflowai.core.domain.tool import Tool
+from workflowai.core.domain.tool import ToolDefinition
 from workflowai.core.domain.tool_call import ToolCallResult
 from workflowai.core.domain.version_properties import VersionProperties
 
@@ -20,7 +20,7 @@ class AnswerQuestionOutput(BaseModel):
 
 
 async def test_manual_tool():
-    get_current_time_tool = Tool(
+    get_current_time_tool = ToolDefinition(
         name="get_current_time",
         description="Get the current time",
         input_schema={},
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
new file mode 100644
index 0000000..ac2cb40
--- /dev/null
+++ b/tests/integration/conftest.py
@@ -0,0 +1,90 @@
+import json
+from typing import Any, Optional
+from unittest.mock import patch
+
+import pytest
+from pydantic import BaseModel
+from pytest_httpx import HTTPXMock, IteratorStream
+
+from workflowai.core.client.client import
WorkflowAI + + +@pytest.fixture(scope="module", autouse=True) +def init_client(): + with patch("workflowai.shared_client", new=WorkflowAI(api_key="test", endpoint="https://run.workflowai.dev")): + yield + + +class CityToCapitalTaskInput(BaseModel): + city: str + + +class CityToCapitalTaskOutput(BaseModel): + capital: str + + +class IntTestClient: + REGISTER_URL = "https://api.workflowai.dev/v1/_/agents" + + def __init__(self, httpx_mock: HTTPXMock): + self.httpx_mock = httpx_mock + + def mock_register(self, schema_id: int = 1, task_id: str = "city-to-capital", variant_id: str = "1"): + self.httpx_mock.add_response( + method="POST", + url=self.REGISTER_URL, + json={"schema_id": schema_id, "variant_id": variant_id, "id": task_id}, + ) + + def mock_response( + self, + task_id: str = "city-to-capital", + capital: str = "Tokyo", + json: Optional[dict[str, Any]] = None, + url: Optional[str] = None, + status_code: int = 200, + ): + self.httpx_mock.add_response( + method="POST", + url=url or f"https://run.workflowai.dev/v1/_/agents/{task_id}/schemas/1/run", + json=json or {"id": "123", "task_output": {"capital": capital}}, + status_code=status_code, + ) + + def mock_stream(self, task_id: str = "city-to-capital"): + self.httpx_mock.add_response( + url=f"https://run.workflowai.dev/v1/_/agents/{task_id}/schemas/1/run", + stream=IteratorStream( + [ + b'data: {"id":"1","task_output":{"capital":""}}\n\n', + b'data: {"id":"1","task_output":{"capital":"Tok"}}\n\ndata: {"id":"1","task_output":{"capital":"Tokyo"}}\n\n', # noqa: E501 + b'data: {"id":"1","task_output":{"capital":"Tokyo"},"cost_usd":0.01,"duration_seconds":10.1}\n\n', + ], + ), + ) + + def check_request( + self, + version: Any = "production", + task_id: str = "city-to-capital", + task_input: Optional[dict[str, Any]] = None, + **matchers: Any, + ): + request = self.httpx_mock.get_request(**matchers) + assert request is not None + assert request.url == f"https://run.workflowai.dev/v1/_/agents/{task_id}/schemas/1/run" + body = json.loads(request.content) + assert body == { + "task_input": task_input or {"city": "Hello"}, + "version": version, + "stream": False, + } + assert request.headers["Authorization"] == "Bearer test" + assert request.headers["Content-Type"] == "application/json" + assert request.headers["x-workflowai-source"] == "sdk" + assert request.headers["x-workflowai-language"] == "python" + + +@pytest.fixture +def test_client(httpx_mock: HTTPXMock) -> IntTestClient: + return IntTestClient(httpx_mock) diff --git a/tests/integration/run_fail_test.py b/tests/integration/run_fail_test.py new file mode 100644 index 0000000..c39a402 --- /dev/null +++ b/tests/integration/run_fail_test.py @@ -0,0 +1,61 @@ +from typing import Any, Optional + +import pytest + +import workflowai +from tests.integration.conftest import CityToCapitalTaskInput, CityToCapitalTaskOutput, IntTestClient +from workflowai.core.domain.errors import InvalidGenerationError +from workflowai.core.domain.run import Run + + +class TestRecoverableError: + def _mock_agent_run_failed(self, test_client: IntTestClient, output: Optional[dict[str, Any]] = None): + # The agent run + test_client.mock_response( + status_code=424, + json={ + "id": "123", + "task_output": output or {"capital": "Tokyo"}, + "error": { + "code": "agent_run_failed", + "message": "Test error message", + }, + }, + ) + + async def test_output_only(self, test_client: IntTestClient): + @workflowai.agent(schema_id=1) + async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTaskOutput: ... 
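+        # An output-only agent cannot recover: the partial output only surfaces on the raised error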
+ + self._mock_agent_run_failed(test_client) + + with pytest.raises(InvalidGenerationError) as e: + await city_to_capital(CityToCapitalTaskInput(city="Hello")) + + assert e.value.run_id == "123" + assert e.value.partial_output == {"capital": "Tokyo"} + + async def test_recover(self, test_client: IntTestClient): + # When the return is a full run object we try and recover the error using a partial output + + @workflowai.agent(schema_id=1) + async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapitalTaskOutput]: ... + + self._mock_agent_run_failed(test_client) + + run = await city_to_capital(CityToCapitalTaskInput(city="Hello")) + assert run.id == "123" + assert run.output.capital == "Tokyo" + + assert run.error is not None + assert run.error.code == "agent_run_failed" + + async def test_unrecoverable_error(self, test_client: IntTestClient): + @workflowai.agent(schema_id=1) + async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapitalTaskOutput]: ... + + # Mocking with an invalid output, CityToCapitalTaskOutput requires the capital field + self._mock_agent_run_failed(test_client, {"capitale": "Tokyo"}) + + with pytest.raises(InvalidGenerationError): + await city_to_capital(CityToCapitalTaskInput(city="Hello")) diff --git a/tests/integration/run_test.py b/tests/integration/run_test.py index 6645225..b6e6f1d 100644 --- a/tests/integration/run_test.py +++ b/tests/integration/run_test.py @@ -1,90 +1,32 @@ import json -from typing import Any, AsyncIterator, Optional +from typing import AsyncIterator -from httpx import Request from pydantic import BaseModel -from pytest_httpx import HTTPXMock, IteratorStream import workflowai +from tests.integration.conftest import CityToCapitalTaskInput, CityToCapitalTaskOutput, IntTestClient from workflowai.core.domain.run import Run -class CityToCapitalTaskInput(BaseModel): - city: str - - -class CityToCapitalTaskOutput(BaseModel): - capital: str - - -workflowai.init(api_key="test", url="https://run.workflowai.dev") - -_REGISTER_URL = "https://api.workflowai.dev/v1/_/agents" - - -def _mock_register(httpx_mock: HTTPXMock, schema_id: int = 1, task_id: str = "city-to-capital", variant_id: str = "1"): - httpx_mock.add_response( - method="POST", - url=_REGISTER_URL, - json={"schema_id": schema_id, "variant_id": variant_id, "id": task_id}, - ) - - -def _mock_response(httpx_mock: HTTPXMock, task_id: str = "city-to-capital", capital: str = "Tokyo"): - httpx_mock.add_response( - method="POST", - url=f"https://run.workflowai.dev/v1/_/agents/{task_id}/schemas/1/run", - json={"id": "123", "task_output": {"capital": capital}}, - ) - - -def _mock_stream(httpx_mock: HTTPXMock, task_id: str = "city-to-capital"): - httpx_mock.add_response( - url=f"https://run.workflowai.dev/v1/_/agents/{task_id}/schemas/1/run", - stream=IteratorStream( - [ - b'data: {"id":"1","task_output":{"capital":""}}\n\n', - b'data: {"id":"1","task_output":{"capital":"Tok"}}\n\ndata: {"id":"1","task_output":{"capital":"Tokyo"}}\n\n', # noqa: E501 - b'data: {"id":"1","task_output":{"capital":"Tokyo"},"cost_usd":0.01,"duration_seconds":10.1}\n\n', - ], - ), - ) - - -def _check_request(request: Optional[Request], version: Any = "production", task_id: str = "city-to-capital"): - assert request is not None - assert request.url == f"https://run.workflowai.dev/v1/_/agents/{task_id}/schemas/1/run" - body = json.loads(request.content) - assert body == { - "task_input": {"city": "Hello"}, - "version": version, - "stream": False, - } - assert 
request.headers["Authorization"] == "Bearer test" - assert request.headers["Content-Type"] == "application/json" - assert request.headers["x-workflowai-source"] == "sdk" - assert request.headers["x-workflowai-language"] == "python" - - -async def test_run_task(httpx_mock: HTTPXMock) -> None: +async def test_run_task(test_client: IntTestClient) -> None: @workflowai.agent(schema_id=1) async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTaskOutput: ... - _mock_response(httpx_mock) + test_client.mock_response() task_input = CityToCapitalTaskInput(city="Hello") output = await city_to_capital(task_input) assert output.capital == "Tokyo" - _check_request(httpx_mock.get_request()) + test_client.check_request() -async def test_run_task_run(httpx_mock: HTTPXMock) -> None: +async def test_run_task_run(test_client: IntTestClient) -> None: @workflowai.agent(schema_id=1) async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapitalTaskOutput]: ... - _mock_response(httpx_mock) + test_client.mock_response() task_input = CityToCapitalTaskInput(city="Hello") with_run = await city_to_capital(task_input) @@ -92,14 +34,14 @@ async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapit assert with_run.id == "123" assert with_run.output.capital == "Tokyo" - _check_request(httpx_mock.get_request()) + test_client.check_request() -async def test_run_task_run_version(httpx_mock: HTTPXMock) -> None: +async def test_run_task_run_version(test_client: IntTestClient) -> None: @workflowai.agent(schema_id=1, version="staging") async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapitalTaskOutput]: ... - _mock_response(httpx_mock) + test_client.mock_response() task_input = CityToCapitalTaskInput(city="Hello") with_run = await city_to_capital(task_input) @@ -107,14 +49,14 @@ async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapit assert with_run.id == "123" assert with_run.output.capital == "Tokyo" - _check_request(httpx_mock.get_request(), version="staging") + test_client.check_request(version="staging") -async def test_stream_task_run(httpx_mock: HTTPXMock) -> None: +async def test_stream_task_run(test_client: IntTestClient) -> None: @workflowai.agent(schema_id=1) def city_to_capital(task_input: CityToCapitalTaskInput) -> AsyncIterator[CityToCapitalTaskOutput]: ... - _mock_stream(httpx_mock) + test_client.mock_stream() task_input = CityToCapitalTaskInput(city="Hello") chunks = [chunk async for chunk in city_to_capital(task_input)] @@ -127,11 +69,11 @@ def city_to_capital(task_input: CityToCapitalTaskInput) -> AsyncIterator[CityToC ] -async def test_stream_task_run_custom_id(httpx_mock: HTTPXMock) -> None: +async def test_stream_task_run_custom_id(test_client: IntTestClient) -> None: @workflowai.agent(schema_id=1, id="custom-id") def city_to_capital(task_input: CityToCapitalTaskInput) -> AsyncIterator[CityToCapitalTaskOutput]: ... - _mock_stream(httpx_mock, task_id="custom-id") + test_client.mock_stream(task_id="custom-id") task_input = CityToCapitalTaskInput(city="Hello") chunks = [chunk async for chunk in city_to_capital(task_input)] @@ -144,25 +86,25 @@ def city_to_capital(task_input: CityToCapitalTaskInput) -> AsyncIterator[CityToC ] -async def test_auto_register(httpx_mock: HTTPXMock): +async def test_auto_register(test_client: IntTestClient): @workflowai.agent() async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTaskOutput: ... 
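+    # No schema_id in the decorator: the agent registers itself before the first run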
-    _mock_register(httpx_mock)
+    test_client.mock_register()
 
-    _mock_response(httpx_mock)
+    test_client.mock_response()
 
     res = await city_to_capital(CityToCapitalTaskInput(city="Hello"))
     assert res.capital == "Tokyo"
 
-    _mock_response(httpx_mock, capital="Paris")
+    test_client.mock_response(capital="Paris")
 
     # Run it a second time
     res = await city_to_capital(CityToCapitalTaskInput(city="Hello"), use_cache="never")
     assert res.capital == "Paris"
 
-    req = httpx_mock.get_requests()
+    req = test_client.httpx_mock.get_requests()
     assert len(req) == 3
-    assert req[0].url == _REGISTER_URL
+    assert req[0].url == test_client.REGISTER_URL
 
     req_body = json.loads(req[0].read())
     assert req_body == {
@@ -194,3 +136,67 @@ async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTa
         "type": "object",
     },
 }
+
+
+async def test_run_with_tool(test_client: IntTestClient):
+    class _SayHelloToolInput(BaseModel):
+        name: str
+
+    class _SayHelloToolOutput(BaseModel):
+        message: str
+
+    def say_hello(tool_input: _SayHelloToolInput) -> _SayHelloToolOutput:
+        return _SayHelloToolOutput(message=f"Hello {tool_input.name}")
+
+    @workflowai.agent(id="city-to-capital", tools=[say_hello])
+    async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTaskOutput:
+        """Say hello to the user"""
+        ...
+
+    test_client.mock_register()
+
+    # The first response returns a tool call request
+    test_client.mock_response(
+        json={
+            "id": "1234",
+            "task_output": {},
+            "tool_call_requests": [
+                {
+                    "id": "say_hello_1",
+                    "name": "say_hello",
+                    "input": {"tool_input": {"name": "john"}},
+                },
+            ],
+        },
+    )
+
+    # The second response returns the final output
+    test_client.mock_response(url="https://run.workflowai.dev/v1/_/agents/city-to-capital/runs/1234/reply")
+
+    task_input = CityToCapitalTaskInput(city="Hello")
+    output = await city_to_capital(task_input)
+    assert output.capital == "Tokyo"
+
+    assert len(test_client.httpx_mock.get_requests()) == 3
+
+    run_req = test_client.httpx_mock.get_request(
+        url="https://run.workflowai.dev/v1/_/agents/city-to-capital/schemas/1/run",
+    )
+    assert run_req is not None
+    run_req_body = json.loads(run_req.content)
+    assert run_req_body["task_input"] == {"city": "Hello"}
+    assert set(run_req_body["version"].keys()) == {"enabled_tools", "instructions", "model"}
+    assert len(run_req_body["version"]["enabled_tools"]) == 1
+    assert run_req_body["version"]["enabled_tools"][0]["name"] == "say_hello"
+
+    reply_req = test_client.httpx_mock.get_request(
+        url="https://run.workflowai.dev/v1/_/agents/city-to-capital/runs/1234/reply",
+    )
+    assert reply_req is not None
+    reply_req_body = json.loads(reply_req.content)
+    assert reply_req_body["tool_results"] == [
+        {
+            "id": "say_hello_1",
+            "output": {"message": "Hello john"},
+        },
+    ]
diff --git a/workflowai/core/client/_api.py b/workflowai/core/client/_api.py
index 8d12a5e..d8d737d 100644
--- a/workflowai/core/client/_api.py
+++ b/workflowai/core/client/_api.py
@@ -112,7 +112,7 @@ def _extract_error(
     ) -> WorkflowAIError:
         try:
             res = ErrorResponse.model_validate_json(data)
-            return WorkflowAIError(error=res.error, task_run_id=res.task_run_id, response=response)
+            return WorkflowAIError(error=res.error, run_id=res.id, response=response, partial_output=res.task_output)
         except ValidationError:
             raise WorkflowAIError(
                 error=BaseError(
diff --git a/workflowai/core/client/_api_test.py b/workflowai/core/client/_api_test.py
index 78622c8..b13452d 100644
--- a/workflowai/core/client/_api_test.py
+++ b/workflowai/core/client/_api_test.py
@@
-21,7 +21,7 @@ def test_extract_error(self): "message": "Test error message", "details": {"key": "value"}, }, - "task_run_id": "test_task_123", + "id": "test_task_123", }, ) @@ -29,7 +29,31 @@ def test_extract_error(self): assert isinstance(error, WorkflowAIError) assert error.error.message == "Test error message" assert error.error.details == {"key": "value"} - assert error.task_run_id == "test_task_123" + assert error.run_id == "test_task_123" + assert error.response == response + + def test_extract_partial_output(self): + client = APIClient(endpoint="test_endpoint", api_key="test_api_key") + + # Test valid JSON error response + response = httpx.Response( + status_code=400, + json={ + "error": { + "message": "Test error message", + "details": {"key": "value"}, + }, + "id": "test_task_123", + "task_output": {"key": "value"}, + }, + ) + + error = client._extract_error(response, response.content) # pyright:ignore[reportPrivateUsage] + assert isinstance(error, WorkflowAIError) + assert error.error.message == "Test error message" + assert error.error.details == {"key": "value"} + assert error.run_id == "test_task_123" + assert error.partial_output == {"key": "value"} assert error.response == response def test_extract_error_invalid_json(self): diff --git a/workflowai/core/client/_fn_utils.py b/workflowai/core/client/_fn_utils.py index 0f097dd..6c29776 100644 --- a/workflowai/core/client/_fn_utils.py +++ b/workflowai/core/client/_fn_utils.py @@ -15,17 +15,20 @@ get_type_hints, ) -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from typing_extensions import Unpack from workflowai.core.client._api import APIClient +from workflowai.core.client._models import RunResponse from workflowai.core.client._types import ( AgentDecorator, FinalRunTemplate, RunParams, RunTemplate, ) +from workflowai.core.client._utils import intolerant_validator from workflowai.core.client.agent import Agent +from workflowai.core.domain.errors import InvalidGenerationError from workflowai.core.domain.model import ModelOrStr from workflowai.core.domain.run import Run from workflowai.core.domain.task import AgentInput, AgentOutput @@ -104,7 +107,28 @@ def extract_fn_spec(fn: RunTemplate[AgentInput, AgentOutput]) -> RunFunctionSpec class _RunnableAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]): async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): # noqa: A002 - return await self.run(input, **kwargs) + """An agent that returns a run object. 
Handles recoverable errors when possible""" + try: + return await self.run(input, **kwargs) + except InvalidGenerationError as e: + if e.partial_output and e.run_id: + try: + validator, _ = self._sanitize_validator(kwargs, intolerant_validator(self.output_cls)) + run = self._build_run_no_tools( + chunk=RunResponse( + id=e.run_id, + task_output=e.partial_output, + ), + schema_id=self.schema_id or 0, + validator=validator, + ) + run.error = e.error + return run + + except ValidationError: + # Error is not recoverable so not returning anything + pass + raise e class _RunnableOutputOnlyAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]): diff --git a/workflowai/core/client/_models.py b/workflowai/core/client/_models.py index 369b9fe..593eb3f 100644 --- a/workflowai/core/client/_models.py +++ b/workflowai/core/client/_models.py @@ -1,6 +1,6 @@ from typing import Any, Literal, Optional, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field # pyright: ignore [reportUnknownVariableType] from typing_extensions import NotRequired, TypedDict from workflowai.core._common_types import OutputValidator @@ -126,7 +126,12 @@ class RunResponse(BaseModel): tool_calls: Optional[list[ToolCall]] = None tool_call_requests: Optional[list[ToolCallRequest]] = None - def to_domain(self, task_id: str, task_schema_id: int, validator: OutputValidator[AgentOutput]) -> Run[AgentOutput]: + def to_domain( + self, + task_id: str, + task_schema_id: int, + validator: OutputValidator[AgentOutput], + ) -> Run[AgentOutput]: return Run( id=self.id, agent_id=task_id, diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index 528f175..f8643f1 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -28,7 +28,6 @@ from workflowai.core.domain.tool_call import ToolCallRequest, ToolCallResult from workflowai.core.domain.version_properties import VersionProperties from workflowai.core.domain.version_reference import VersionReference -from workflowai.core.utils._tools import tool_schema class Agent(Generic[AgentInput, AgentOutput]): @@ -53,9 +52,9 @@ def __init__( self._tools = self.build_tools(tools) if tools else None @classmethod - def build_tools(cls, tools: Iterable[Callable[..., Any]]) -> dict[str, tuple[Tool, Callable[..., Any]]]: + def build_tools(cls, tools: Iterable[Callable[..., Any]]): # TODO: we should be more tolerant with errors ? 
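+        # maps each function name to a Tool carrying the function plus its JSON schemas and (de)serializers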
- return {tool.__name__: (tool_schema(tool), tool) for tool in tools} + return {tool.__name__: Tool.from_fn(tool) for tool in tools} @property def api(self) -> APIClient: @@ -75,7 +74,7 @@ def _sanitize_version(self, version: Optional[VersionReference]) -> Union[str, i if not isinstance(version, VersionProperties): return version - dumped = version.model_dump(by_alias=True) + dumped = version.model_dump(by_alias=True, exclude_unset=True) if not dumped.get("model"): import workflowai @@ -88,7 +87,7 @@ def _sanitize_version(self, version: Optional[VersionReference]) -> Union[str, i "input_schema": tool.input_schema, "output_schema": tool.output_schema, } - for tool, _ in self._tools.values() + for tool in self._tools.values() ] return dumped @@ -159,11 +158,9 @@ async def register(self): return res.schema_id @classmethod - async def _safe_execute_tool(cls, tool_call_request: ToolCallRequest, tool_func: Callable[..., Any]): + async def _safe_execute_tool(cls, tool_call_request: ToolCallRequest, tool: Tool): try: - output: Any = tool_func(**tool_call_request.input) - if isinstance(output, Awaitable): - output = await output + output = await tool(tool_call_request.input) return ToolCallResult( id=tool_call_request.id, output=output, @@ -184,13 +181,13 @@ async def _execute_tools( if not self._tools: return None - executions: list[tuple[ToolCallRequest, Callable[..., Any]]] = [] + executions: list[tuple[ToolCallRequest, Tool]] = [] for tool_call_request in tool_call_requests: if tool_call_request.name not in self._tools: continue - _, tool_func = self._tools[tool_call_request.name] - executions.append((tool_call_request, tool_func)) + tool = self._tools[tool_call_request.name] + executions.append((tool_call_request, tool)) if not executions: return None @@ -206,6 +203,16 @@ async def _execute_tools( **kwargs, ) + def _build_run_no_tools( + self, + chunk: RunResponse, + schema_id: int, + validator: OutputValidator[AgentOutput], + ) -> Run[AgentOutput]: + run = chunk.to_domain(self.agent_id, schema_id, validator) + run._agent = self # pyright: ignore [reportPrivateUsage] + return run + async def _build_run( self, chunk: RunResponse, @@ -214,8 +221,7 @@ async def _build_run( current_iteration: int, **kwargs: Unpack[BaseRunParams], ) -> Run[AgentOutput]: - run = chunk.to_domain(self.agent_id, schema_id, validator) - run._agent = self # pyright: ignore [reportPrivateUsage] + run = self._build_run_no_tools(chunk, schema_id, validator) if run.tool_call_requests: if current_iteration >= kwargs.get("max_iterations", self._DEFAULT_MAX_ITERATIONS): diff --git a/workflowai/core/domain/errors.py b/workflowai/core/domain/errors.py index 5c58606..06232c9 100644 --- a/workflowai/core/domain/errors.py +++ b/workflowai/core/domain/errors.py @@ -36,6 +36,20 @@ # The requested model does not support the requested generation mode # (e-g a model that does not support images generation was sent an image) "model_does_not_support_mode", + # Invalid file provided + "invalid_file", + # The maximum number of tool call iterations was reached + "max_tool_call_iteration", + # The current configuration does not support structured generation + "structured_generation_error", + # The content was moderated + "content_moderation", + # Task banned + "task_banned", + # The request timed out + "timeout", + # Agent run failed + "agent_run_failed", ] ErrorCode = Union[ @@ -70,7 +84,8 @@ class BaseError(BaseModel): class ErrorResponse(BaseModel): error: BaseError - task_run_id: Optional[str] = None + id: Optional[str] = None + 
task_output: Optional[dict[str, Any]] = None def _retry_after_to_delay_seconds(retry_after: Any) -> Optional[float]: @@ -94,17 +109,25 @@ def __init__( self, response: Optional[Response], error: BaseError, - task_run_id: Optional[str] = None, + run_id: Optional[str] = None, retry_after_delay_seconds: Optional[float] = None, + partial_output: Optional[dict[str, Any]] = None, ): self.error = error - self.task_run_id = task_run_id + self.run_id = run_id self.response = response self._retry_after_delay_seconds = retry_after_delay_seconds + self.partial_output = partial_output def __str__(self): return f"WorkflowAIError : [{self.error.code}] ({self.error.status_code}): [{self.error.message}]" + @classmethod + def error_cls(cls, code: str): + if code == "invalid_generation" or code == "failed_generation" or code == "agent_run_failed": + return InvalidGenerationError + return cls + @classmethod def from_response(cls, response: Response): try: @@ -114,15 +137,17 @@ def from_response(cls, response: Response): details = r_error.get("details", {}) error_code = r_error.get("code", "unknown_error") status_code = response.status_code - task_run_id = r_error.get("task_run_id", None) + run_id = response_json.get("id", None) + partial_output = response_json.get("task_output", None) except JSONDecodeError: error_message = "Unknown error" details = {"raw": response.content.decode()} error_code = "unknown_error" status_code = response.status_code - task_run_id = None + run_id = None + partial_output = None - return cls( + return cls.error_cls(error_code)( response=response, error=BaseError( message=error_message, @@ -130,7 +155,8 @@ def from_response(cls, response: Response): status_code=status_code, code=error_code, ), - task_run_id=task_run_id, + run_id=run_id, + partial_output=partial_output, ) @property @@ -142,3 +168,6 @@ def retry_after_delay_seconds(self) -> Optional[float]: return _retry_after_to_delay_seconds(self.response.headers.get("Retry-After")) return None + + +class InvalidGenerationError(WorkflowAIError): ... diff --git a/workflowai/core/domain/run.py b/workflowai/core/domain/run.py index 9affd7e..62c6269 100644 --- a/workflowai/core/domain/run.py +++ b/workflowai/core/domain/run.py @@ -7,6 +7,7 @@ from workflowai.core import _common_types from workflowai.core.client import _types +from workflowai.core.domain.errors import BaseError from workflowai.core.domain.task import AgentOutput from workflowai.core.domain.tool_call import ToolCall, ToolCallRequest, ToolCallResult from workflowai.core.domain.version import Version @@ -42,6 +43,11 @@ class Run(BaseModel, Generic[AgentOutput]): tool_calls: Optional[list[ToolCall]] = None tool_call_requests: Optional[list[ToolCallRequest]] = None + error: Optional[BaseError] = Field( + default=None, + description="An error that occurred during the run. 
Only provided if the run failed.",
+    )
+
     _agent: Optional["_AgentBase[AgentOutput]"] = None
 
     def __eq__(self, other: object) -> bool:
@@ -74,6 +80,12 @@ async def reply(
         **kwargs,
     )
 
+    @property
+    def model(self):
+        if self.version is None:
+            return None
+        return self.version.properties.model
+
 
 class _AgentBase(Protocol, Generic[AgentOutput]):
     async def reply(
diff --git a/workflowai/core/domain/tool.py b/workflowai/core/domain/tool.py
index 13b02bc..07443d8 100644
--- a/workflowai/core/domain/tool.py
+++ b/workflowai/core/domain/tool.py
@@ -1,11 +1,66 @@
-from typing import Any
+import inspect
+from collections.abc import Awaitable, Callable
+from typing import Any, Optional
 
 from pydantic import BaseModel, Field
 
+from workflowai.core.utils._tools import tool_schema
 
-class Tool(BaseModel):
+
+class ToolDefinition(BaseModel):
     name: str = Field(description="The name of the tool")
     description: str = Field(default="", description="The description of the tool")
     input_schema: dict[str, Any] = Field(description="The input class of the tool")
     output_schema: dict[str, Any] = Field(description="The output class of the tool")
+
+
+ToolFunction = Callable[..., Any]
+
+
+class Tool(ToolDefinition):
+    input_deserializer: Optional[Callable[[Any], Any]] = Field(
+        default=None,
+        description="The deserializer for the input class of the tool",
+    )
+
+    output_serializer: Optional[Callable[[Any], Any]] = Field(
+        default=None,
+        description="The serializer for the output class of the tool",
+    )
+
+    tool_fn: Callable[..., Any] = Field(description="The function that implements the tool")
+
+    @classmethod
+    def from_fn(cls, func: ToolFunction):
+        """Create a Tool from a Python callable with annotated types.
+
+        Args:
+            func (ToolFunction): a Python callable with annotated types
+
+        Returns:
+            Tool: a Tool wrapping the function, with the input/output JSON schemas and the
+            matching (de)serializers derived from the annotations
+        """
+
+        tool_description = inspect.getdoc(func)
+        input_schema, output_schema = tool_schema(func)
+        return cls(
+            name=func.__name__,
+            description=tool_description or "",
+            input_schema=input_schema.schema,
+            input_deserializer=input_schema.deserializer,
+            output_schema=output_schema.schema,
+            output_serializer=output_schema.serializer,
+            tool_fn=func,
+        )
+
+    async def __call__(self, tool_input: Any):
+        deserialized_input = self.input_deserializer(tool_input) if self.input_deserializer else tool_input
+        if not deserialized_input:
+            deserialized_input = {}
+        output: Any = self.tool_fn(**deserialized_input)
+        if isinstance(output, Awaitable):
+            output = await output
+        if self.output_serializer:
+            return self.output_serializer(output)
+        return output
diff --git a/workflowai/core/domain/tool_test.py b/workflowai/core/domain/tool_test.py
new file mode 100644
index 0000000..8260345
--- /dev/null
+++ b/workflowai/core/domain/tool_test.py
@@ -0,0 +1,65 @@
+from pydantic import BaseModel
+
+from workflowai.core.domain.tool import Tool
+
+
+class TestToolDefinition:
+    def test_simple(self):
+        def sample_func(name: str, age: int) -> str:
+            """Hello I am a docstring"""
+            return f"Hello {name}, you are {age} years old"
+
+        tool = Tool.from_fn(sample_func)
+        assert tool.name == "sample_func"
+        assert tool.description == "Hello I am a docstring"
+        assert tool.input_schema == {
+            "type": "object",
+            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
+            "required": ["name", "age"],
+        }
+        assert tool.output_schema == {"type": "string"}
+        assert tool.input_deserializer is None
+        assert
tool.output_serializer is None + assert tool.tool_fn == sample_func + + +class TestSanity: + """Test that we can create a tool from a function and then call it from a serialized input""" + + async def test_simple_non_async(self): + def sample_func(name: str, age: int) -> str: + return f"Hello {name}, you are {age} years old" + + tool = Tool.from_fn(sample_func) + + assert await tool({"name": "John", "age": 30}) == "Hello John, you are 30 years old" + + async def test_simple_async(self): + async def sample_func(name: str, age: int) -> str: + return f"Hello {name}, you are {age} years old" + + tool = Tool.from_fn(sample_func) + + assert await tool({"name": "John", "age": 30}) == "Hello John, you are 30 years old" + + async def test_base_model_in_input(self): + class TestModel(BaseModel): + name: str + + def sample_func(model: TestModel) -> str: + return f"Hello {model.name}" + + tool = Tool.from_fn(sample_func) + + assert await tool({"model": {"name": "John"}}) == "Hello John" + + async def test_base_model_in_output(self): + class TestModel(BaseModel): + name: str + + def sample_func() -> TestModel: + return TestModel(name="John") + + tool = Tool.from_fn(sample_func) + + assert await tool({}) == {"name": "John"} diff --git a/workflowai/core/domain/version_properties.py b/workflowai/core/domain/version_properties.py index 8ba1736..7906136 100644 --- a/workflowai/core/domain/version_properties.py +++ b/workflowai/core/domain/version_properties.py @@ -1,9 +1,9 @@ from typing import Optional, Union -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field # pyright: ignore [reportUnknownVariableType] from workflowai.core.domain.model import ModelOrStr -from workflowai.core.domain.tool import Tool +from workflowai.core.domain.tool import ToolDefinition class VersionProperties(BaseModel): @@ -44,7 +44,7 @@ class VersionProperties(BaseModel): description="The version of the runner used", ) - enabled_tools: Optional[list[Union[str, Tool]]] = Field( + enabled_tools: Optional[list[Union[str, ToolDefinition]]] = Field( default=None, description="The tools enabled for the run. A string can be used to refer to a tool hosted by WorkflowAI", ) diff --git a/workflowai/core/utils/_tools.py b/workflowai/core/utils/_tools.py index 7521c12..f16632e 100644 --- a/workflowai/core/utils/_tools.py +++ b/workflowai/core/utils/_tools.py @@ -1,6 +1,6 @@ import inspect from enum import Enum -from typing import Any, Callable, get_type_hints +from typing import Any, Callable, NamedTuple, Optional, cast, get_type_hints from pydantic import BaseModel @@ -9,34 +9,13 @@ ToolFunction = Callable[..., Any] -def tool_schema(func: ToolFunction): - """Creates JSON schemas for function input parameters and return type. 
+class SchemaDeserializer(NamedTuple): + schema: dict[str, Any] + serializer: Optional[Callable[[Any], Any]] = None + deserializer: Optional[Callable[[Any], Any]] = None - Args: - func (Callable[[Any], Any]): a Python callable with annotated types - - Returns: - FunctionJsonSchema: a FunctionJsonSchema object containing the function input/output JSON schemas - """ - from workflowai.core.domain.tool import Tool - - sig = inspect.signature(func) - type_hints = get_type_hints(func, include_extras=True) - input_schema = _build_input_schema(sig, type_hints) - output_schema = _build_output_schema(type_hints) - - tool_description = inspect.getdoc(func) - - return Tool( - name=func.__name__, - description=tool_description or "", - input_schema=input_schema, - output_schema=output_schema, - ) - - -def _get_type_schema(param_type: type) -> dict[str, Any]: +def _get_type_schema(param_type: type): """Convert a Python type to its corresponding JSON schema type. Args: @@ -48,24 +27,31 @@ def _get_type_schema(param_type: type) -> dict[str, Any]: if issubclass(param_type, Enum): if not issubclass(param_type, str): raise ValueError(f"Non string enums are not supported: {param_type}") - return {"type": "string", "enum": [e.value for e in param_type]} + return SchemaDeserializer({"type": "string", "enum": [e.value for e in param_type]}) if param_type is str: - return {"type": "string"} + return SchemaDeserializer({"type": "string"}) + + if param_type is int: + return SchemaDeserializer({"type": "integer"}) - if param_type in (int, float): - return {"type": "number"} + if param_type is float: + return SchemaDeserializer({"type": "number"}) if param_type is bool: - return {"type": "boolean"} + return SchemaDeserializer({"type": "boolean"}) if issubclass(param_type, BaseModel): - return param_type.model_json_schema(by_alias=True, schema_generator=JsonSchemaGenerator) + return SchemaDeserializer( + schema=param_type.model_json_schema(by_alias=True, schema_generator=JsonSchemaGenerator), + serializer=lambda x: cast(BaseModel, x).model_dump(mode="json"), # pyright: ignore [reportUnknownLambdaType] + deserializer=param_type.model_validate, + ) raise ValueError(f"Unsupported type: {param_type}") -def _schema_from_type_hint(param_type_hint: Any) -> dict[str, Any]: +def _schema_from_type_hint(param_type_hint: Any): param_type = param_type_hint.__origin__ if hasattr(param_type_hint, "__origin__") else param_type_hint if not isinstance(param_type, type): raise ValueError(f"Unsupported type: {param_type}") @@ -73,31 +59,58 @@ def _schema_from_type_hint(param_type_hint: Any) -> dict[str, Any]: param_description = param_type_hint.__metadata__[0] if hasattr(param_type_hint, "__metadata__") else None param_schema = _get_type_schema(param_type) if param_description: - param_schema["description"] = param_description + param_schema.schema["description"] = param_description return param_schema -def _build_input_schema(sig: inspect.Signature, type_hints: dict[str, Any]) -> dict[str, Any]: +def _combine_deserializers(deserializers: dict[str, Callable[[Any], Any]]): + def deserializer(_input: dict[str, Any]) -> dict[str, Any]: + return {k: deserializers[k](v) if k in deserializers else v for k, v in _input.items()} + + return deserializer + + +def _build_input_schema(sig: inspect.Signature, type_hints: dict[str, Any]): input_schema: dict[str, Any] = {"type": "object", "properties": {}, "required": []} + deserializers: dict[str, Callable[[Any], Any]] = {} for param_name, param in sig.parameters.items(): if param_name == "self": 
continue - param_schema = _schema_from_type_hint(type_hints[param_name]) + out = _schema_from_type_hint(type_hints[param_name]) if param.default is inspect.Parameter.empty: input_schema["required"].append(param_name) - input_schema["properties"][param_name] = param_schema + input_schema["properties"][param_name] = out.schema + if out.deserializer: + deserializers[param_name] = out.deserializer + + if not input_schema["properties"]: + return SchemaDeserializer({}) - return input_schema + deserializer = _combine_deserializers(deserializers) if deserializers else None + # No need to serialize the input + return SchemaDeserializer(input_schema, deserializer=deserializer) -def _build_output_schema(type_hints: dict[str, Any]) -> dict[str, Any]: + +def _build_output_schema(type_hints: dict[str, Any]): return_type = type_hints.get("return") if not return_type: raise ValueError("Return type annotation is required") - return _schema_from_type_hint(return_type) + # No need to deserialize the output + return _schema_from_type_hint(return_type)._replace(deserializer=None) + + +def tool_schema(func: Callable[..., Any]): + sig = inspect.signature(func) + type_hints = get_type_hints(func, include_extras=True) + + input_schema = _build_input_schema(sig, type_hints) + output_schema = _build_output_schema(type_hints) + + return input_schema, output_schema diff --git a/workflowai/core/utils/_tools_test.py b/workflowai/core/utils/_tools_test.py index 607f3a2..d6c04da 100644 --- a/workflowai/core/utils/_tools_test.py +++ b/workflowai/core/utils/_tools_test.py @@ -22,10 +22,9 @@ def sample_func( """Sample function for testing""" ... - schema = tool_schema(sample_func) + input_schema, output_schema = tool_schema(sample_func) - assert schema.name == "sample_func" - assert schema.input_schema == { + assert input_schema.schema == { "type": "object", "properties": { "name": { @@ -33,7 +32,7 @@ def sample_func( "description": "The name parameter", }, "age": { - "type": "number", + "type": "integer", }, "height": { "type": "number", @@ -48,28 +47,27 @@ def sample_func( }, "required": ["name", "age", "height", "is_active"], # 'mode' is not required } - assert schema.output_schema == { + assert output_schema.schema == { "type": "boolean", } - assert schema.description == "Sample function for testing" def test_method_with_self(self): class TestClass: def sample_method(self, value: int) -> str: return str(value) - schema = tool_schema(TestClass.sample_method) + input_schema, output_schema = tool_schema(TestClass.sample_method) - assert schema.input_schema == { + assert input_schema.schema == { "type": "object", "properties": { "value": { - "type": "number", + "type": "integer", }, }, "required": ["value"], } - assert schema.output_schema == { + assert output_schema.schema == { "type": "string", } @@ -79,9 +77,9 @@ class TestModel(BaseModel): def sample_func(model: TestModel) -> str: ... - schema = tool_schema(sample_func) + input_schema, output_schema = tool_schema(sample_func) - assert schema.input_schema == { + assert input_schema.schema == { "type": "object", "properties": { "model": { @@ -99,16 +97,29 @@ def sample_func(model: TestModel) -> str: ... 
"required": ["model"], } + assert input_schema.deserializer is not None + assert input_schema.deserializer({"model": {"name": "John"}}) == {"model": TestModel(name="John")} + + assert output_schema.schema == { + "type": "string", + } + assert output_schema.deserializer is None + def test_with_base_model_in_output(self): class TestModel(BaseModel): val: int def sample_func() -> TestModel: ... - schema = tool_schema(sample_func) + input_schema, output_schema = tool_schema(sample_func) + + assert input_schema.schema == {} + assert input_schema.deserializer is None - assert schema.output_schema == { + assert output_schema.schema == { "type": "object", "properties": {"val": {"type": "integer"}}, "required": ["val"], } + assert output_schema.serializer is not None + assert output_schema.serializer(TestModel(val=10)) == {"val": 10}