
Commit 64570aa

Merge pull request #36 from WorkflowAI/guillaume/tool-enhancements

Tool enhancements

2 parents 57baccc + 6e6ac9d

21 files changed: +764, -178 lines

.vscode/extensions.json (7 additions, 0 deletions)

@@ -0,0 +1,7 @@
+{
+    "recommendations": [
+        "charliermarsh.ruff",
+        "njpwerner.autodocstring",
+        "editorconfig.editorconfig"
+    ]
+}

CONTRIBUTING.md (69 additions, 0 deletions)

@@ -0,0 +1,69 @@
+# Contributing to WorkflowAI
+
+## Setup
+
+### Prerequisites
+
+- [Poetry](https://python-poetry.org/docs/#installation) for dependency management and publishing
+
+### Getting started
+
+```bash
+# We recommend configuring in-project virtual envs with poetry so that
+# they can easily be picked up by IDEs
+
+# poetry config virtualenvs.in-project true
+poetry install --all-extras
+
+# Install the pre-commit hooks
+poetry run pre-commit install
+# or `make install` to install the pre-commit hooks and the dependencies
+
+# Check the code quality
+# Run ruff
+poetry run ruff check .
+# Run pyright
+poetry run pyright
+# or `make lint` to run ruff and pyright
+
+# Run the unit and integration tests
+# They do not require any configuration
+poetry run pytest --ignore=tests/e2e # make test
+
+# Run the end-to-end tests
+# They require the `WORKFLOWAI_TEST_API_URL` and `WORKFLOWAI_TEST_API_KEY` environment variables to be set
+# If they are present in the `.env` file, they will be picked up automatically
+poetry run pytest tests/e2e
+```
+
+#### Configuring VSCode
+
+Suggested extensions are available in the [.vscode/extensions.json](.vscode/extensions.json) file.
+
+### Dependencies
+
+#### Ruff
+
+[Ruff](https://github.com/astral-sh/ruff) is a very fast Python code linter and formatter.
+
+```sh
+ruff check . # check the entire project
+ruff check src/workflowai/core # check a specific path
+ruff check . --fix # fix linting errors automatically in the entire project
+```
+
+#### Pyright
+
+[Pyright](https://github.com/microsoft/pyright) is a static type checker for Python.
+
+> We preferred it to `mypy` because it is faster and easier to configure.
+
+#### Pydantic
+
+[Pydantic](https://docs.pydantic.dev/) is a data validation library for Python.
+It provides very convenient methods to serialize and deserialize data, introspect its structure, set validation
+rules, etc.
+
+#### HTTPX
+
+[HTTPX](https://www.python-httpx.org/) is a modern HTTP library for Python.
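For instance, the serialize/deserialize round-trip mentioned above looks like this; a minimal sketch using Pydantic v2 method names, not code from this repository:

```python
from pydantic import BaseModel


class City(BaseModel):
    name: str
    population: int


# Deserialize (with validation) from a plain dict, then serialize back to JSON
city = City.model_validate({"name": "Tokyo", "population": 13960000})
print(city.model_dump_json())  # {"name":"Tokyo","population":13960000}
```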

README.md (114 additions, 6 deletions)
@@ -121,6 +121,8 @@ WorkflowAI supports a long list of models. The source of truth for models we sup
 You can set the model explicitly in the agent decorator:

 ```python
+from workflowai import Model
+
 @workflowai.agent(model=Model.GPT_4O_LATEST)
 def say_hello(input: Input) -> Output:
     ...
@@ -151,16 +153,31 @@ def say_hello(input: Input) -> AsyncIterator[Run[Output]]:
     ...
 ```

-### Streaming and advanced usage
+### The Run object
+
+Although having an agent return only the run output covers most use cases, some use cases require more
+information about the run.

-You can configure the agent function to stream or return the full run object, simply by changing the type annotation.
+By changing the type annotation of the agent function to `Run[Output]`, the generated function will return
+the full run object.

 ```python
-# Return the full run object, useful if you want to extract metadata like cost or duration
 @workflowai.agent()
-async def say_hello(input: Input) -> Run[Output]:
-    ...
+async def say_hello(input: Input) -> Run[Output]: ...
+
+
+run = await say_hello(Input(name="John"))
+print(run.output) # the output, as before
+print(run.model) # the model used for the run
+print(run.cost_usd) # the cost of the run in USD
+print(run.duration_seconds) # the duration of the inference in seconds
 ```

+### Streaming
+
+You can configure the agent function to stream by changing the type annotation to an `AsyncIterator`.
+
+```python
 # Stream the output, the output is filled as it is generated
 @workflowai.agent()
 def say_hello(input: Input) -> AsyncIterator[Output]:
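For reference, a stream declared this way is consumed with `async for`. Below is a minimal sketch, where `say_hello` and `Input` are the illustrative names used in this README and the `greeting` field is an assumption:

```python
import asyncio


async def main():
    # Each yielded chunk is a partially filled Output object
    async for chunk in say_hello(Input(name="John")):
        print(chunk.greeting)


asyncio.run(main())
```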
@@ -172,6 +189,38 @@ def say_hello(input: Input) -> AsyncIterator[Run[Output]]:
     ...
 ```

+### Replying to a run
+
+Some use cases require a back and forth between the client and the LLM. For example:
+
+- tools ([see below](#tools)) use the reply ability internally
+- chatbots
+- correcting the LLM output
+
+In WorkflowAI, this is done by replying to a run. A reply can contain:
+
+- a user response
+- tool results
+
+<!-- TODO: find a better example for reply -->
+
+```python
+# Returning the full run object is required to use the reply feature
+@workflowai.agent()
+async def say_hello(input: Input) -> Run[Output]:
+    ...
+
+run = await say_hello(Input(name="John"))
+run = await run.reply(user_response="Now say hello to his brother James")
+```
+
+The output of a reply to a run has the same type as the original run, which makes it easy to iterate towards the
+construction of a final output.
+
+> To allow run iterations, it is very important to have outputs that are tolerant to missing fields, i.e. that
+> have default values for most of their fields. Otherwise the agent will throw a `WorkflowAIError` on missing fields
+> and the run chain will be broken.
+
 ### Tools

 Tools allow enhancing an agent's capabilities by allowing it to call external functions.
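A side note on the reply caveat above: a reply-friendly output model typically gives every field a default, so intermediate runs stay valid even when some fields are missing. A minimal sketch; the field names are illustrative:

```python
from pydantic import BaseModel


class Output(BaseModel):
    # Defaults keep intermediate outputs valid while the
    # back and forth is still filling them in
    greeting: str = ""
    language: str = "en"
```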
@@ -222,9 +271,16 @@ def get_current_time(timezone: Annotated[str, "The timezone to get the current t
     """Return the current time in the given timezone in iso format"""
     return datetime.now(ZoneInfo(timezone)).isoformat()

+# Tools can also be async
+async def fetch_webpage(url: str) -> str:
+    """Fetch the content of a webpage"""
+    async with httpx.AsyncClient() as client:
+        response = await client.get(url)
+        return response.text
+
 @agent(
     id="answer-question",
-    tools=[get_current_time],
+    tools=[get_current_time, fetch_webpage],
     version=VersionProperties(model=Model.GPT_4O_LATEST),
 )
 async def answer_question(_: AnswerQuestionInput) -> Run[AnswerQuestionOutput]: ...
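Calling a tool-equipped agent is no different from calling any other agent; the tool round-trips happen internally before the final run is returned. A minimal usage sketch, where the `question` and `answer` field names are assumptions:

```python
run = await answer_question(AnswerQuestionInput(question="What time is it in Tokyo?"))
print(run.output.answer)
```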
@@ -261,6 +317,29 @@ except WorkflowAIError as e:
     print(e.message)
 ```

+#### Recoverable errors
+
+Sometimes the LLM outputs an object that is only partially valid. Good examples are:
+
+- the model context window was exceeded during the generation
+- the model decided that a tool call result was a failure
+
+In this case, an agent that returns only an output will always raise an `InvalidGenerationError`, which
+subclasses `WorkflowAIError`.
+
+However, an agent that returns a full run object will try to recover from the error by using the partial output.
+
+```python
+run = await agent(input=Input(name="John"))
+
+# The run will have an error
+assert run.error is not None
+
+# The run will have a partial output
+assert run.output is not None
+```
+
 ### Defining input and output types

 There are some important subtleties when defining input and output types.
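To make the contrast in the recoverable-errors section above concrete: an output-only variant of the same agent surfaces the failure as an exception instead. A minimal sketch, where `output_only_agent` is a hypothetical name and the import path is an assumption:

```python
from workflowai import WorkflowAIError  # import path assumed

try:
    output = await output_only_agent(Input(name="John"))
except WorkflowAIError as e:
    # InvalidGenerationError subclasses WorkflowAIError
    print(e.message)
```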
@@ -368,3 +447,32 @@ async for run in say_hello(Input(name="John")):
     print(run.output.greeting1) # will be empty if the model has not generated it yet

 ```
+
+#### Field properties
+
+Pydantic allows a variety of other validation criteria for fields: minimum, maximum, pattern, etc.
+These additional criteria are included in the JSON Schema that is sent to WorkflowAI, and are sent to the model.
+
+```python
+class Input(BaseModel):
+    name: str = Field(min_length=3, max_length=10)
+    age: int = Field(ge=18, le=100)
+    email: str = Field(pattern=r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$")
+```
+
+These arguments can be used to steer the model in the right direction. The caveat is that overly strict
+validation can lead to invalid generations. In case of an invalid generation:
+
+- WorkflowAI retries the inference once by providing the model with the invalid output and the validation error
+- if the model still fails to generate a valid output, the run will fail with an `InvalidGenerationError`;
+  the partial output is available in the `partial_output` attribute of the `InvalidGenerationError`
+
+```python
+@agent()
+async def my_agent(_: Input) -> Output: ...
+```
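Per the bullet above, the partial output can then be recovered from the exception itself. A minimal sketch; the import path of `InvalidGenerationError` is an assumption:

```python
from workflowai import InvalidGenerationError  # import path assumed

try:
    output = await my_agent(Input(name="John", age=30, email="john@example.com"))
except InvalidGenerationError as e:
    print(e.partial_output)  # whatever the model managed to generate despite the failure
```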
+
+## Contributing
+
+See the [CONTRIBUTING.md](./CONTRIBUTING.md) file for more details.

poetry.lock (3 additions, 3 deletions)

Generated file; diff not rendered.

pyproject.toml (1 addition, 0 deletions)

@@ -64,6 +64,7 @@ unfixable = []
 # in bin we use rich.print
 "bin/*" = ["T201"]
 "*_test.py" = ["S101"]
+"conftest.py" = ["S101"]

 [tool.pyright]
 pythonVersion = "3.9"

tests/e2e/tools_test.py (2 additions, 2 deletions)

@@ -6,7 +6,7 @@

 from workflowai import Run, agent
 from workflowai.core.domain.model import Model
-from workflowai.core.domain.tool import Tool
+from workflowai.core.domain.tool import ToolDefinition
 from workflowai.core.domain.tool_call import ToolCallResult
 from workflowai.core.domain.version_properties import VersionProperties

@@ -20,7 +20,7 @@ class AnswerQuestionOutput(BaseModel):


 async def test_manual_tool():
-    get_current_time_tool = Tool(
+    get_current_time_tool = ToolDefinition(
         name="get_current_time",
         description="Get the current time",
         input_schema={},

tests/integration/conftest.py (90 additions, 0 deletions)

@@ -0,0 +1,90 @@
+import json
+from typing import Any, Optional
+from unittest.mock import patch
+
+import pytest
+from pydantic import BaseModel
+from pytest_httpx import HTTPXMock, IteratorStream
+
+from workflowai.core.client.client import WorkflowAI
+
+
+@pytest.fixture(scope="module", autouse=True)
+def init_client():
+    with patch("workflowai.shared_client", new=WorkflowAI(api_key="test", endpoint="https://run.workflowai.dev")):
+        yield
+
+
+class CityToCapitalTaskInput(BaseModel):
+    city: str
+
+
+class CityToCapitalTaskOutput(BaseModel):
+    capital: str
+
+
+class IntTestClient:
+    REGISTER_URL = "https://api.workflowai.dev/v1/_/agents"
+
+    def __init__(self, httpx_mock: HTTPXMock):
+        self.httpx_mock = httpx_mock
+
+    def mock_register(self, schema_id: int = 1, task_id: str = "city-to-capital", variant_id: str = "1"):
+        self.httpx_mock.add_response(
+            method="POST",
+            url=self.REGISTER_URL,
+            json={"schema_id": schema_id, "variant_id": variant_id, "id": task_id},
+        )
+
+    def mock_response(
+        self,
+        task_id: str = "city-to-capital",
+        capital: str = "Tokyo",
+        json: Optional[dict[str, Any]] = None,
+        url: Optional[str] = None,
+        status_code: int = 200,
+    ):
+        self.httpx_mock.add_response(
+            method="POST",
+            url=url or f"https://run.workflowai.dev/v1/_/agents/{task_id}/schemas/1/run",
+            json=json or {"id": "123", "task_output": {"capital": capital}},
+            status_code=status_code,
+        )
+
+    def mock_stream(self, task_id: str = "city-to-capital"):
+        self.httpx_mock.add_response(
+            url=f"https://run.workflowai.dev/v1/_/agents/{task_id}/schemas/1/run",
+            stream=IteratorStream(
+                [
+                    b'data: {"id":"1","task_output":{"capital":""}}\n\n',
+                    b'data: {"id":"1","task_output":{"capital":"Tok"}}\n\ndata: {"id":"1","task_output":{"capital":"Tokyo"}}\n\n',  # noqa: E501
+                    b'data: {"id":"1","task_output":{"capital":"Tokyo"},"cost_usd":0.01,"duration_seconds":10.1}\n\n',
+                ],
+            ),
+        )
+
+    def check_request(
+        self,
+        version: Any = "production",
+        task_id: str = "city-to-capital",
+        task_input: Optional[dict[str, Any]] = None,
+        **matchers: Any,
+    ):
+        request = self.httpx_mock.get_request(**matchers)
+        assert request is not None
+        assert request.url == f"https://run.workflowai.dev/v1/_/agents/{task_id}/schemas/1/run"
+        body = json.loads(request.content)
+        assert body == {
+            "task_input": task_input or {"city": "Hello"},
+            "version": version,
+            "stream": False,
+        }
+        assert request.headers["Authorization"] == "Bearer test"
+        assert request.headers["Content-Type"] == "application/json"
+        assert request.headers["x-workflowai-source"] == "sdk"
+        assert request.headers["x-workflowai-language"] == "python"
+
+
+@pytest.fixture
+def test_client(httpx_mock: HTTPXMock) -> IntTestClient:
+    return IntTestClient(httpx_mock)
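A hypothetical test built on these fixtures might look like the sketch below; the agent declaration, the schema-1 resolution, and the default "production" version are assumptions, not part of this commit:

```python
import workflowai


# Hypothetical agent under test; registration is assumed to resolve to schema 1
@workflowai.agent(id="city-to-capital")
async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTaskOutput: ...


async def test_city_to_capital(test_client: IntTestClient):
    test_client.mock_register()
    test_client.mock_response(capital="Paris")

    output = await city_to_capital(CityToCapitalTaskInput(city="Paris"))

    assert output.capital == "Paris"
    test_client.check_request(task_input={"city": "Paris"})
```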
