Skip to content

Tools feedback #34

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Python: Debug Tests",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"purpose": ["debug-test"],
"console": "integratedTerminal",
"justMyCode": false
}
]
}
2 changes: 1 addition & 1 deletion examples/city_to_capital_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class CityToCapitalTaskOutput(BaseModel):
)


@workflowai.task(schema_id=1)
@workflowai.agent(schema_id=1)
async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTaskOutput: ...


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "workflowai"
version = "0.6.0.dev3"
version = "0.6.0.dev4"
description = ""
authors = ["Guillaume Aquilina <[email protected]>"]
readme = "README.md"
Expand Down
4 changes: 4 additions & 0 deletions workflowai/core/client/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@


class Agent(Generic[AgentInput, AgentOutput]):
_DEFAULT_MAX_ITERATIONS = 10

def __init__(
self,
agent_id: str,
Expand Down Expand Up @@ -216,6 +218,8 @@ async def _build_run(
run._agent = self # pyright: ignore [reportPrivateUsage]

if run.tool_call_requests:
if current_iteration >= kwargs.get("max_iterations", self._DEFAULT_MAX_ITERATIONS):
raise WorkflowAIError(error=BaseError(message="max tool iterations reached"), response=None)
with_reply = await self._execute_tools(
run_id=run.id,
tool_call_requests=run.tool_call_requests,
Expand Down
19 changes: 19 additions & 0 deletions workflowai/core/utils/_schema_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import Any

from pydantic.json_schema import GenerateJsonSchema
from typing_extensions import override


class JsonSchemaGenerator(GenerateJsonSchema):
"""A schema generator that simplifies the schemas generated by pydantic."""

@override
def generate(self, *args: Any, **kwargs: Any):
generated = super().generate(*args, **kwargs)
# Remove the title from the schema
generated.pop("title", None)
return generated

@override
def field_title_should_be_set(self, *args: Any, **kwargs: Any) -> bool:
return False
12 changes: 12 additions & 0 deletions workflowai/core/utils/_schema_generator_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from pydantic import BaseModel

from workflowai.core.utils._schema_generator import JsonSchemaGenerator


class TestJsonSchemaGenerator:
def test_generate(self):
class TestModel(BaseModel):
name: str

schema = TestModel.model_json_schema(schema_generator=JsonSchemaGenerator)
assert schema == {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}
34 changes: 19 additions & 15 deletions workflowai/core/utils/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from pydantic import BaseModel

from workflowai.core.utils._schema_generator import JsonSchemaGenerator

ToolFunction = Callable[..., Any]


Expand Down Expand Up @@ -57,26 +59,33 @@ def _get_type_schema(param_type: type) -> dict[str, Any]:
if param_type is bool:
return {"type": "boolean"}

if isinstance(param_type, BaseModel):
return param_type.model_json_schema()
if issubclass(param_type, BaseModel):
return param_type.model_json_schema(by_alias=True, schema_generator=JsonSchemaGenerator)

raise ValueError(f"Unsupported type: {param_type}")


def _schema_from_type_hint(param_type_hint: Any) -> dict[str, Any]:
param_type = param_type_hint.__origin__ if hasattr(param_type_hint, "__origin__") else param_type_hint
if not isinstance(param_type, type):
raise ValueError(f"Unsupported type: {param_type}")

param_description = param_type_hint.__metadata__[0] if hasattr(param_type_hint, "__metadata__") else None
param_schema = _get_type_schema(param_type)
if param_description:
param_schema["description"] = param_description

return param_schema


def _build_input_schema(sig: inspect.Signature, type_hints: dict[str, Any]) -> dict[str, Any]:
input_schema: dict[str, Any] = {"type": "object", "properties": {}, "required": []}

for param_name, param in sig.parameters.items():
if param_name == "self":
continue

param_type_hint = type_hints[param_name]
param_type = param_type_hint.__origin__ if hasattr(param_type_hint, "__origin__") else param_type_hint
param_description = param_type_hint.__metadata__[0] if hasattr(param_type_hint, "__metadata__") else None

param_schema = _get_type_schema(param_type) if isinstance(param_type, type) else {"type": "string"}
if param_description is not None:
param_schema["description"] = param_description
param_schema = _schema_from_type_hint(type_hints[param_name])

if param.default is inspect.Parameter.empty:
input_schema["required"].append(param_name)
Expand All @@ -91,9 +100,4 @@ def _build_output_schema(type_hints: dict[str, Any]) -> dict[str, Any]:
if not return_type:
raise ValueError("Return type annotation is required")

return_type_base = return_type.__origin__ if hasattr(return_type, "__origin__") else return_type

if not isinstance(return_type_base, type):
raise ValueError(f"Unsupported return type: {return_type_base}")

return _get_type_schema(return_type_base)
return _schema_from_type_hint(return_type)
42 changes: 42 additions & 0 deletions workflowai/core/utils/_tools_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from enum import Enum
from typing import Annotated

from pydantic import BaseModel

from workflowai.core.utils._tools import tool_schema


Expand Down Expand Up @@ -70,3 +72,43 @@ def sample_method(self, value: int) -> str:
assert schema.output_schema == {
"type": "string",
}

def test_with_base_model_in_input(self):
class TestModel(BaseModel):
name: str

def sample_func(model: TestModel) -> str: ...

schema = tool_schema(sample_func)

assert schema.input_schema == {
"type": "object",
"properties": {
"model": {
"properties": {
"name": {
"type": "string",
},
},
"required": [
"name",
],
"type": "object",
},
},
"required": ["model"],
}

def test_with_base_model_in_output(self):
class TestModel(BaseModel):
val: int

def sample_func() -> TestModel: ...

schema = tool_schema(sample_func)

assert schema.output_schema == {
"type": "object",
"properties": {"val": {"type": "integer"}},
"required": ["val"],
}