
Add tool support #31


Merged: 7 commits, Jan 30, 2025
164 changes: 161 additions & 3 deletions README.md
@@ -116,12 +116,12 @@ run will be created. By default:

#### Using different models

WorkflowAI supports a long list of models. The source of truth for models we support is on [workflowai.com](https://workflowai.com). The [Model](./workflowai/core/domain/model.py) type is a good indication of what models are supported at the time of the sdk release, although it may be missing some models since new ones are added all the time.
WorkflowAI supports a long list of models. The source of truth for the models we support is on [workflowai.com](https://workflowai.com). The [Model enum](./workflowai/core/domain/model.py) is a good indication of which models are supported at the time of the SDK release, although it may be missing some, since new models are added all the time.

You can set the model explicitly in the agent decorator:

```python
@workflowai.agent(model="gpt-4o")
@workflowai.agent(model=Model.GPT_4O_LATEST)
def say_hello(input: Input) -> Output:
...
```
@@ -174,11 +174,20 @@ def say_hello(input: Input) -> AsyncIterator[Run[Output]]:

### Tools

WorkflowAI has a few tools that can be used to enhance the agent's capabilities:
Tools enhance an agent's capabilities by allowing it to call external functions.

#### WorkflowAI Hosted tools

WorkflowAI hosts a few tools:

- `@browser-text` allows fetching the content of a web page
- `@search` allows performing a web search

Hosted tools tend to be faster because there is no back and forth between the client and the WorkflowAI API. Instead,
if a tool call is needed, the WorkflowAI API executes it within a single request.

A single run will be created for all tool iterations.

To use a tool, simply add its handle to the instructions (the function docstring):

@@ -190,6 +199,47 @@

```python
def say_hello(input: Input) -> Output:
    ...
```

#### Custom tools

Custom tools allow using most Python functions as tools within a single agent call. If an agent has custom tools and the model
deems that tools are needed for a particular run, the agent will:

- call all requested tools in parallel
- wait for all tools to complete
- reply to the run with the tool outputs
- continue with the next step of the run, re-executing tools if needed
- and so on, until no more tool calls are requested, the max iteration count (10 by default) is reached, or the agent has run to completion
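The loop described above can be sketched in plain Python. This is a simplified illustration, not the SDK's actual implementation; `run_tool_loop`, `call_model`, and the payload shape are hypothetical names used only to make the steps concrete:

```python
import asyncio
import inspect


async def run_tool_loop(call_model, tools, payload, max_iterations=10):
    """Simplified sketch of the tool-iteration loop described above.

    `call_model` is a hypothetical coroutine returning a (final_output,
    tool_call_requests) pair; a non-empty request list means the model
    wants tools to be executed before it can answer.
    """
    registry = {fn.__name__: fn for fn in tools}
    for _ in range(max_iterations):
        output, tool_requests = await call_model(payload)
        if not tool_requests:
            return output  # the agent ran to completion
        # Call all requested tools in parallel and wait for them to complete
        results = await asyncio.gather(
            *(_invoke(registry[req["name"]], req["input"]) for req in tool_requests),
        )
        # Reply to the run with the tool outputs, then iterate again
        payload = {"tool_results": list(results)}
    raise RuntimeError("max tool iterations reached")


async def _invoke(fn, kwargs):
    # Tools can be sync or async; await only when needed
    res = fn(**kwargs)
    return await res if inspect.isawaitable(res) else res
```

The key point is that the caller sees a single agent call even though several model round trips may happen inside the loop.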

Tools are defined as regular Python functions and can be async or sync. Examples of tools are available in the [end-to-end test file for tools](./tests/e2e/tools_test.py).

> **Important**: It must be possible to determine the schema of a tool from the function signature. This means that
> the function must have type annotations and, for now, use only standard types or `BaseModel`.

```python
# Annotations for parameters are passed as property descriptions in the tool schema
def get_current_time(timezone: Annotated[str, "The timezone to get the current time in, e.g. Europe/Paris"]) -> str:
    """Return the current time in the given timezone in ISO format"""
    return datetime.now(ZoneInfo(timezone)).isoformat()

@agent(
    id="answer-question",
    tools=[get_current_time],
    version=VersionProperties(model=Model.GPT_4O_LATEST),
)
async def answer_question(_: AnswerQuestionInput) -> Run[AnswerQuestionOutput]: ...

run = await answer_question(AnswerQuestionInput(question="What is the current time in Paris?"))
assert run.output.answer
```
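For illustration, here is a rough sketch of how a tool schema could be derived from such a signature. This is not the SDK's actual extraction code; `tool_schema` and `_JSON_TYPES` are hypothetical names, and the sketch only covers scalar parameters:

```python
import inspect
from typing import Annotated, get_args, get_origin, get_type_hints

# Minimal mapping from Python scalar types to JSON schema types
_JSON_TYPES = {str: "string", int: "integer", float: "number", bool: "boolean"}


def tool_schema(fn):
    """Derive a minimal JSON-schema-like dict from a function signature.

    Annotated metadata strings become property descriptions, mirroring
    the get_current_time example above.
    """
    hints = get_type_hints(fn, include_extras=True)
    props = {}
    for name in inspect.signature(fn).parameters:
        hint = hints[name]
        desc = None
        if get_origin(hint) is Annotated:
            hint, *extras = get_args(hint)
            desc = next((e for e in extras if isinstance(e, str)), None)
        prop = {"type": _JSON_TYPES.get(hint, "object")}
        if desc:
            prop["description"] = desc
        props[name] = prop
    return {"type": "object", "properties": props}
```

This is why untyped parameters cannot be supported: without annotations there is nothing to build the schema from.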

> It's important to understand that there are actually two runs in a single agent call:
>
> - the first run returns an empty output with a tool call request containing a timezone
> - the second run returns the current time in the given timezone
>
> Only the last run is returned to the caller.

### Error handling

Agents can raise errors, for example when the underlying model fails to generate a response or when
@@ -210,3 +260,111 @@

```python
try:
    ...
except WorkflowAIError as e:
    print(e.code)
    print(e.message)
```

### Defining input and output types

There are some important subtleties when defining input and output types.

#### Descriptions and examples

Field descriptions and examples are passed to the model and can help steer the output in the right direction. A good
use case is describing a format or style for a string field:

```python
# summary has no examples or description, so the model will likely return a block of text
class SummaryOutput(BaseModel):
    summary: str

# passing a description helps the model return a summary formatted as bullet points
class SummaryOutput(BaseModel):
    summary: str = Field(description="A summary, formatted as bullet points")

# passing examples can help as well
class SummaryOutput(BaseModel):
    summary: str = Field(examples=["- Paris is a city in France\n- London is a city in England"])
```
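These descriptions and examples travel to the model through the generated JSON schema, which is what the model actually sees. With Pydantic v2 this is easy to verify directly:

```python
from pydantic import BaseModel, Field


class SummaryOutput(BaseModel):
    summary: str = Field(description="A summary, formatted as bullet points")


# The Field description is embedded in the generated JSON schema
schema = SummaryOutput.model_json_schema()
print(schema["properties"]["summary"]["description"])
```

Inspecting `model_json_schema()` this way is a quick sanity check that a description or example is phrased the way you want before the model ever sees it.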

Some notes:

- there are very few use cases for descriptions and examples in the **input** type. The model will mostly
infer from the value that is passed. An example use case is using the description for fields that can be missing.
- adding examples that are too numerous or too specific can push the model to restrict the output value

#### Required versus optional fields

In short, we recommend using default values for most output fields.

Pydantic is rather strict on model validation by default: if a field has no default value, it must be provided.
Although the fact that a field is required is passed to the model, the generation can sometimes still omit the
field or return null or empty values.

```python
class Input(BaseModel):
    name: str

class OutputStrict(BaseModel):
    greeting: str

@workflowai.agent()
async def say_hello_strict(_: Input) -> OutputStrict:
    ...

try:
    run = await say_hello_strict(Input(name="John"))
    print(run.output.greeting)  # "Hello, John!"
except WorkflowAIError as e:
    print(e.code)  # "invalid_generation" means the generation did not match the schema

class OutputTolerant(BaseModel):
    greeting: str = ""

@workflowai.agent()
async def say_hello_tolerant(_: Input) -> OutputTolerant:
    ...

# An invalid_generation error is less likely here
run = await say_hello_tolerant(Input(name="John"))
if not run.output.greeting:
    print("No greeting was generated!")
print(run.output.greeting)  # "Hello, John!"
```
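The difference between the strict and the tolerant output model can be seen with plain Pydantic validation (Pydantic v2 assumed), without any agent call:

```python
from pydantic import BaseModel, ValidationError


class OutputStrict(BaseModel):
    greeting: str


class OutputTolerant(BaseModel):
    greeting: str = ""


# A generation that omits the field fails strict validation...
try:
    OutputStrict.model_validate({})
    raised = False
except ValidationError:
    raised = True

# ...but validates fine when a default is provided, falling back to ""
tolerant = OutputTolerant.model_validate({})
```

The agent applies the same validation to the model's raw output, which is why a default value turns a hard `invalid_generation` error into an empty value you can check for.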

> WorkflowAI automatically retries invalid generations once. If a model outputs an object that does not match the
> schema, a new generation is triggered with the previous response and the error message as context.

Another reason to prefer optional fields in the output is streaming. Partial outputs are constructed using
`BaseModel.model_construct` when streaming, so if a field has no default value, accessing it while it is
still absent raises an `AttributeError`.

```python
class Input(BaseModel):
    name: str

class OutputStrict(BaseModel):
    greeting1: str
    greeting2: str

@workflowai.agent()
def say_hello_strict(_: Input) -> AsyncIterator[OutputStrict]:
    ...

async for run in say_hello_strict(Input(name="John")):
    try:
        print(run.output.greeting1)
    except AttributeError:
        # run.output.greeting1 has not been generated yet
        pass


class OutputTolerant(BaseModel):
    greeting1: str = ""
    greeting2: str = ""

@workflowai.agent()
def say_hello_tolerant(_: Input) -> AsyncIterator[OutputTolerant]:
    ...

async for run in say_hello_tolerant(Input(name="John")):
    print(run.output.greeting1)  # will be empty if the model has not generated it yet
```
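The `model_construct` behaviour behind this can be demonstrated with plain Pydantic (v2 assumed), without streaming:

```python
from pydantic import BaseModel


class OutputStrict(BaseModel):
    greeting1: str
    greeting2: str


class OutputTolerant(BaseModel):
    greeting1: str = ""
    greeting2: str = ""


# model_construct skips validation: required fields that were not
# provided are simply absent from the instance
partial = OutputStrict.model_construct(greeting1="Hello")
try:
    partial.greeting2
    missing = False
except AttributeError:
    missing = True

# With defaults, the missing field falls back to ""
tolerant = OutputTolerant.model_construct(greeting1="Hello")
```

This is exactly the situation a streaming consumer is in before the model has emitted every field, which is why defaults make streaming code simpler.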
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "workflowai"
version = "0.6.0.dev2"
version = "0.6.0.dev3"
description = ""
authors = ["Guillaume Aquilina <[email protected]>"]
readme = "README.md"
64 changes: 64 additions & 0 deletions tests/e2e/tools_test.py
@@ -0,0 +1,64 @@
from datetime import datetime
from typing import Annotated

from pydantic import BaseModel
from zoneinfo import ZoneInfo

from workflowai import Run, agent
from workflowai.core.domain.model import Model
from workflowai.core.domain.tool import Tool
from workflowai.core.domain.tool_call import ToolCallResult
from workflowai.core.domain.version_properties import VersionProperties


class AnswerQuestionInput(BaseModel):
    question: str


class AnswerQuestionOutput(BaseModel):
    answer: str = ""


async def test_manual_tool():
    get_current_time_tool = Tool(
        name="get_current_time",
        description="Get the current time",
        input_schema={},
        output_schema={
            "properties": {
                "time": {"type": "string", "description": "The current time"},
            },
        },
    )

    @agent(
        id="answer-question",
        version=VersionProperties(model=Model.GPT_4O_LATEST, enabled_tools=[get_current_time_tool]),
    )
    async def answer_question(_: AnswerQuestionInput) -> Run[AnswerQuestionOutput]: ...

    run = await answer_question(AnswerQuestionInput(question="What is the current time spelled out in French?"))
    assert not run.output.answer

    assert run.tool_call_requests
    assert len(run.tool_call_requests) == 1
    assert run.tool_call_requests[0].name == "get_current_time"

    replied = await run.reply(tool_results=[ToolCallResult(id=run.tool_call_requests[0].id, output={"time": "12:00"})])
    assert replied.output.answer


async def test_auto_tool():
    def get_current_time(timezone: Annotated[str, "The timezone to get the current time in. e-g Europe/Paris"]) -> str:
        """Return the current time in the given timezone in iso format"""
        return datetime.now(ZoneInfo(timezone)).isoformat()

    @agent(
        id="answer-question",
        tools=[get_current_time],
        version=VersionProperties(model=Model.GPT_4O_LATEST),
    )
    async def answer_question(_: AnswerQuestionInput) -> Run[AnswerQuestionOutput]: ...

    run = await answer_question(AnswerQuestionInput(question="What is the current time in Paris?"))
    assert run.output.answer
10 changes: 7 additions & 3 deletions workflowai/__init__.py
@@ -1,10 +1,12 @@
import os
from typing import Optional
from collections.abc import Callable, Iterable
from typing import Any, Optional

from typing_extensions import deprecated

from workflowai.core.client._types import AgentDecorator
from workflowai.core.client.client import WorkflowAI as WorkflowAI
from workflowai.core.domain import model
from workflowai.core.domain.cache_usage import CacheUsage as CacheUsage
from workflowai.core.domain.errors import WorkflowAIError as WorkflowAIError
from workflowai.core.domain.model import Model as Model
@@ -31,7 +33,7 @@ def _build_client(
shared_client: WorkflowAI = _build_client()

# The default model to use when running agents without a deployment
DEFAULT_MODEL: Model = os.getenv("WORKFLOWAI_DEFAULT_MODEL", "gemini-1.5-pro-latest")
DEFAULT_MODEL: "model.ModelOrStr" = os.getenv("WORKFLOWAI_DEFAULT_MODEL", "gemini-1.5-pro-latest")


def init(api_key: Optional[str] = None, url: Optional[str] = None, default_version: Optional[VersionReference] = None):
@@ -66,7 +68,8 @@ def agent(
    id: Optional[str] = None,  # noqa: A002
    schema_id: Optional[int] = None,
    version: Optional[VersionReference] = None,
    model: Optional[Model] = None,
    model: Optional["model.ModelOrStr"] = None,
    tools: Optional[Iterable[Callable[..., Any]]] = None,
) -> AgentDecorator:
    from workflowai.core.client._fn_utils import agent_wrapper

@@ -76,4 +79,5 @@
        agent_id=id,
        version=version,
        model=model,
        tools=tools,
    )
36 changes: 36 additions & 0 deletions workflowai/core/_common_types.py
@@ -0,0 +1,36 @@
from typing import (
    Any,
    Generic,
    Optional,
    Protocol,
    TypeVar,
)

from pydantic import BaseModel
from typing_extensions import NotRequired, TypedDict

from workflowai.core.domain.cache_usage import CacheUsage
from workflowai.core.domain.task import AgentOutput
from workflowai.core.domain.version_reference import VersionReference

AgentInputContra = TypeVar("AgentInputContra", bound=BaseModel, contravariant=True)
AgentOutputCov = TypeVar("AgentOutputCov", bound=BaseModel, covariant=True)


class OutputValidator(Protocol, Generic[AgentOutputCov]):
    def __call__(self, data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutputCov: ...


class BaseRunParams(TypedDict):
    version: NotRequired[Optional["VersionReference"]]
    use_cache: NotRequired["CacheUsage"]
    metadata: NotRequired[Optional[dict[str, Any]]]
    labels: NotRequired[Optional[set[str]]]
    max_retry_delay: NotRequired[float]
    max_retry_count: NotRequired[float]

    max_tool_iterations: NotRequired[int]  # 10 by default


class RunParams(BaseRunParams, Generic[AgentOutput]):
    validator: NotRequired[OutputValidator["AgentOutput"]]
File renamed without changes.
13 changes: 8 additions & 5 deletions workflowai/core/client/_fn_utils.py
@@ -1,6 +1,6 @@
import functools
import inspect
from collections.abc import Callable
from collections.abc import Callable, Iterable
from typing import (
    Any,
    AsyncIterator,
@@ -26,7 +26,7 @@
    RunTemplate,
)
from workflowai.core.client.agent import Agent
from workflowai.core.domain.model import Model
from workflowai.core.domain.model import ModelOrStr
from workflowai.core.domain.run import Run
from workflowai.core.domain.task import AgentInput, AgentOutput
from workflowai.core.domain.version_properties import VersionProperties
@@ -128,8 +128,9 @@ def wrap_run_template(
    agent_id: str,
    schema_id: Optional[int],
    version: Optional[VersionReference],
    model: Optional[Model],
    model: Optional[ModelOrStr],
    fn: RunTemplate[AgentInput, AgentOutput],
    tools: Optional[Iterable[Callable[..., Any]]] = None,
) -> Union[
    _RunnableAgent[AgentInput, AgentOutput],
    _RunnableOutputOnlyAgent[AgentInput, AgentOutput],
@@ -155,6 +156,7 @@
        api=client,
        schema_id=schema_id,
        version=version,
        tools=tools,
    )


@@ -167,11 +169,12 @@
    schema_id: Optional[int] = None,
    agent_id: Optional[str] = None,
    version: Optional[VersionReference] = None,
    model: Optional[Model] = None,
    model: Optional[ModelOrStr] = None,
    tools: Optional[Iterable[Callable[..., Any]]] = None,
) -> AgentDecorator:
    def wrap(fn: RunTemplate[AgentInput, AgentOutput]) -> FinalRunTemplate[AgentInput, AgentOutput]:
        tid = agent_id or agent_id_from_fn_name(fn)
        return functools.wraps(fn)(wrap_run_template(client, tid, schema_id, version, model, fn))  # pyright: ignore [reportReturnType]
        return functools.wraps(fn)(wrap_run_template(client, tid, schema_id, version, model, fn, tools))  # pyright: ignore [reportReturnType]

    # pyright is unhappy with generics
    return wrap  # pyright: ignore [reportReturnType]