From 4b235ec1b964ecbea5a458782d806ae756effd10 Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Thu, 30 Jan 2025 21:15:54 -0500 Subject: [PATCH 1/3] fix: better tool type introspection --- .vscode/launch.json | 14 +++++++ workflowai/core/utils/_schema_generator.py | 19 +++++++++ .../core/utils/_schema_generator_test.py | 12 ++++++ workflowai/core/utils/_tools.py | 34 ++++++++------- workflowai/core/utils/_tools_test.py | 42 +++++++++++++++++++ 5 files changed, 106 insertions(+), 15 deletions(-) create mode 100644 .vscode/launch.json create mode 100644 workflowai/core/utils/_schema_generator.py create mode 100644 workflowai/core/utils/_schema_generator_test.py diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..e51fec7 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,14 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Debug Tests", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "purpose": ["debug-test"], + "console": "integratedTerminal", + "justMyCode": false + } + ] +} diff --git a/workflowai/core/utils/_schema_generator.py b/workflowai/core/utils/_schema_generator.py new file mode 100644 index 0000000..fb8eccf --- /dev/null +++ b/workflowai/core/utils/_schema_generator.py @@ -0,0 +1,19 @@ +from typing import Any + +from pydantic.json_schema import GenerateJsonSchema +from typing_extensions import override + + +class JsonSchemaGenerator(GenerateJsonSchema): + """A schema generator that simplifies the schemas generated by pydantic.""" + + @override + def generate(self, *args: Any, **kwargs: Any): + generated = super().generate(*args, **kwargs) + # Remove the title from the schema + generated.pop("title", None) + return generated + + @override + def field_title_should_be_set(self, *args: Any, **kwargs: Any) -> bool: + return False diff --git a/workflowai/core/utils/_schema_generator_test.py b/workflowai/core/utils/_schema_generator_test.py new file mode 100644 index 0000000..6998fee --- /dev/null +++ b/workflowai/core/utils/_schema_generator_test.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel + +from workflowai.core.utils._schema_generator import JsonSchemaGenerator + + +class TestJsonSchemaGenerator: + def test_generate(self): + class TestModel(BaseModel): + name: str + + schema = TestModel.model_json_schema(schema_generator=JsonSchemaGenerator) + assert schema == {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]} diff --git a/workflowai/core/utils/_tools.py b/workflowai/core/utils/_tools.py index e7be95d..7521c12 100644 --- a/workflowai/core/utils/_tools.py +++ b/workflowai/core/utils/_tools.py @@ -4,6 +4,8 @@ from pydantic import BaseModel +from workflowai.core.utils._schema_generator import JsonSchemaGenerator + ToolFunction = Callable[..., Any] @@ -57,12 +59,25 @@ def _get_type_schema(param_type: type) -> dict[str, Any]: if param_type is bool: return {"type": "boolean"} - if isinstance(param_type, BaseModel): - return param_type.model_json_schema() + if issubclass(param_type, BaseModel): + return param_type.model_json_schema(by_alias=True, schema_generator=JsonSchemaGenerator) raise ValueError(f"Unsupported type: {param_type}") +def _schema_from_type_hint(param_type_hint: Any) -> dict[str, Any]: + param_type = param_type_hint.__origin__ if hasattr(param_type_hint, "__origin__") else param_type_hint + if not isinstance(param_type, type): + raise ValueError(f"Unsupported type: {param_type}") + + param_description = param_type_hint.__metadata__[0] if hasattr(param_type_hint, "__metadata__") else None + param_schema = _get_type_schema(param_type) + if param_description: + param_schema["description"] = param_description + + return param_schema + + def _build_input_schema(sig: inspect.Signature, type_hints: dict[str, Any]) -> dict[str, Any]: input_schema: dict[str, Any] = {"type": "object", "properties": {}, "required": []} @@ -70,13 +85,7 @@ def _build_input_schema(sig: inspect.Signature, type_hints: dict[str, Any]) -> d if param_name == "self": continue - param_type_hint = type_hints[param_name] - param_type = param_type_hint.__origin__ if hasattr(param_type_hint, "__origin__") else param_type_hint - param_description = param_type_hint.__metadata__[0] if hasattr(param_type_hint, "__metadata__") else None - - param_schema = _get_type_schema(param_type) if isinstance(param_type, type) else {"type": "string"} - if param_description is not None: - param_schema["description"] = param_description + param_schema = _schema_from_type_hint(type_hints[param_name]) if param.default is inspect.Parameter.empty: input_schema["required"].append(param_name) @@ -91,9 +100,4 @@ def _build_output_schema(type_hints: dict[str, Any]) -> dict[str, Any]: if not return_type: raise ValueError("Return type annotation is required") - return_type_base = return_type.__origin__ if hasattr(return_type, "__origin__") else return_type - - if not isinstance(return_type_base, type): - raise ValueError(f"Unsupported return type: {return_type_base}") - - return _get_type_schema(return_type_base) + return _schema_from_type_hint(return_type) diff --git a/workflowai/core/utils/_tools_test.py b/workflowai/core/utils/_tools_test.py index 6dc7628..607f3a2 100644 --- a/workflowai/core/utils/_tools_test.py +++ b/workflowai/core/utils/_tools_test.py @@ -1,6 +1,8 @@ from enum import Enum from typing import Annotated +from pydantic import BaseModel + from workflowai.core.utils._tools import tool_schema @@ -70,3 +72,43 @@ def sample_method(self, value: int) -> str: assert schema.output_schema == { "type": "string", } + + def test_with_base_model_in_input(self): + class TestModel(BaseModel): + name: str + + def sample_func(model: TestModel) -> str: ... + + schema = tool_schema(sample_func) + + assert schema.input_schema == { + "type": "object", + "properties": { + "model": { + "properties": { + "name": { + "type": "string", + }, + }, + "required": [ + "name", + ], + "type": "object", + }, + }, + "required": ["model"], + } + + def test_with_base_model_in_output(self): + class TestModel(BaseModel): + val: int + + def sample_func() -> TestModel: ... + + schema = tool_schema(sample_func) + + assert schema.output_schema == { + "type": "object", + "properties": {"val": {"type": "integer"}}, + "required": ["val"], + } From a2206fc91a79f832bf51ab3268a9a4116f76638b Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Thu, 30 Jan 2025 22:12:29 -0500 Subject: [PATCH 2/3] feat: add max iteration check --- examples/city_to_capital_task.py | 2 +- workflowai/core/client/agent.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/city_to_capital_task.py b/examples/city_to_capital_task.py index e095e38..1dccb05 100644 --- a/examples/city_to_capital_task.py +++ b/examples/city_to_capital_task.py @@ -21,7 +21,7 @@ class CityToCapitalTaskOutput(BaseModel): ) -@workflowai.task(schema_id=1) +@workflowai.agent(schema_id=1) async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTaskOutput: ... diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index cbe06f5..528f175 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -32,6 +32,8 @@ class Agent(Generic[AgentInput, AgentOutput]): + _DEFAULT_MAX_ITERATIONS = 10 + def __init__( self, agent_id: str, @@ -216,6 +218,8 @@ async def _build_run( run._agent = self # pyright: ignore [reportPrivateUsage] if run.tool_call_requests: + if current_iteration >= kwargs.get("max_iterations", self._DEFAULT_MAX_ITERATIONS): + raise WorkflowAIError(error=BaseError(message="max tool iterations reached"), response=None) with_reply = await self._execute_tools( run_id=run.id, tool_call_requests=run.tool_call_requests, From 31831d27d21fee2e8a879fe48fdd55dd31f02aa1 Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Thu, 30 Jan 2025 22:12:51 -0500 Subject: [PATCH 3/3] chore: bump version dev4 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 77ac7c5..1fed874 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "workflowai" -version = "0.6.0.dev3" +version = "0.6.0.dev4" description = "" authors = ["Guillaume Aquilina "] readme = "README.md"