Add documentation for agents #55

Merged
merged 3 commits on Feb 13, 2025
Changes from all commits
24 changes: 24 additions & 0 deletions README.md
@@ -219,6 +219,27 @@ def analyze_call_feedback(input: CallFeedbackInput) -> AsyncIterator[Run[CallFee
...
```

#### The Agent class

Any agent function (i.e. a function decorated with `@workflowai.agent()`) is in fact an instance
of the `Agent` class. This means that any defined agent exposes the underlying agent methods, mainly
`run`, `stream` and `reply`. The agent's `__call__` method is overridden for convenience to match the
original function signature.

```python
# Any agent definition would also work
@workflowai.agent()
def analyze_call_feedback(input: CallFeedbackInput) -> CallFeedbackOutput:
...

# It is possible to call the run function directly to get a run object if needed
run = await analyze_call_feedback.run(CallFeedbackInput(...))
# Or the stream function to get a stream of run objects (see below)
chunks = [chunk async for chunk in analyze_call_feedback.stream(CallFeedbackInput(...))]
Contributor: @guillaq could you please clarify how to print the chunk as they arrive?

Contributor: actually, I think I've got it. Let me try! :)
# Or reply manually to a given run id (see the reply section below)
run = await analyze_call_feedback.reply(run_id="...", user_message="...", tool_results=...)
```
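
To answer the question in the review thread above, here is a minimal sketch of printing streamed chunks as they arrive. It assumes the `analyze_call_feedback` agent defined above and that each streamed chunk is a run object whose `output` holds the partially parsed result; the input values are placeholders.

```python
import asyncio


async def print_feedback_stream() -> None:
    # Placeholder input; fill in the actual CallFeedbackInput fields.
    feedback_input = CallFeedbackInput(...)

    # Each chunk is a run object whose `output` contains the output parsed so far,
    # so printing it shows the final result being built up incrementally.
    async for chunk in analyze_call_feedback.stream(feedback_input):
        print(chunk.output)


asyncio.run(print_feedback_stream())
```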

### The Run object

Although having an agent only return the run output covers most use cases, some use cases require having more
@@ -438,6 +459,9 @@ construction of a final output.
> have default values for most of their fields. Otherwise the agent will throw a WorkflowAIError on missing fields
> and the run chain will be broken.

> Under the hood, `run.reply` calls the `say_hello.reply` method as described in the
> [Agent class](#the-agent-class) section.
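
As an illustration, here is a minimal sketch of that chain; it assumes the `say_hello` agent and a `SayHelloInput` model from the surrounding README examples, and the follow-up message is a placeholder.

```python
# Minimal sketch: continue a conversation from a previous run.
# The SayHelloInput model and its `name` field are assumptions based on the
# surrounding examples; adapt them to your actual agent.
run = await say_hello.run(SayHelloInput(name="World"))

# run.reply forwards to say_hello.reply with the current run id under the hood,
# so the follow-up message shares the same conversation context.
run = await run.reply(user_message="Say it again, but in French")
print(run.output)
```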

### Tools

Tools enhance an agent's capabilities by allowing it to call external functions.
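
As a rough sketch, a tool can be a plain Python function passed to the agent decorator. This assumes the `tools` argument added to `agent_wrapper` in this PR is exposed through `@workflowai.agent()`; the weather function and the question/answer models below are hypothetical.

```python
from pydantic import BaseModel

import workflowai


def get_current_weather(city: str) -> str:
    """Hypothetical tool: return the current weather for a city (stubbed here)."""
    return f"Sunny and 22°C in {city}"


class WeatherQuestion(BaseModel):
    question: str


class WeatherAnswer(BaseModel):
    answer: str


# Hypothetical agent; the `tools` argument mirrors the parameter added to
# agent_wrapper in this PR and lets the agent call get_current_weather when needed.
@workflowai.agent(tools=[get_current_weather])
def answer_weather_question(input: WeatherQuestion) -> WeatherAnswer:
    """Answer the user's question, calling the weather tool when useful."""
    ...
```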
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "workflowai"
version = "0.6.0.dev10"
version = "0.6.0.dev11"
description = ""
authors = ["Guillaume Aquilina <[email protected]>"]
readme = "README.md"
4 changes: 2 additions & 2 deletions tests/e2e/run_test.py
@@ -52,7 +52,7 @@ async def test_run_task(
],
):
task_input = ExtractProductReviewSentimentTaskInput(review_text="This product is amazing!")
run = await extract_product_review_sentiment_agent.run(task_input=task_input, use_cache="never")
run = await extract_product_review_sentiment_agent.run(task_input, use_cache="never")
assert run.output.sentiment == Sentiment.POSITIVE


@@ -67,7 +67,7 @@ async def test_stream_task(
)

streamed = extract_product_review_sentiment_agent.stream(
task_input=task_input,
task_input,
use_cache="never",
)
chunks = [chunk async for chunk in streamed]
166 changes: 161 additions & 5 deletions workflowai/core/client/_fn_utils.py
@@ -1,4 +1,3 @@
import functools
import inspect
from collections.abc import AsyncIterator, Callable, Iterable, Sequence
from typing import (
@@ -7,6 +6,7 @@
NamedTuple,
Optional,
Union,
cast,
get_args,
get_origin,
get_type_hints,
@@ -104,7 +104,41 @@ def extract_fn_spec(fn: RunTemplate[AgentInput, AgentOutput]) -> RunFunctionSpec

class _RunnableAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]):
async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): # noqa: A002
"""An agent that returns a run object. Handles recoverable errors when possible"""
"""Run the agent and return the full run object. Handles recoverable errors when possible

Args:
_ (AgentInput): The input to the task.
id (Optional[str]): A user defined ID for the run. The ID must be a UUID7, ordered by
creation time. If not provided, a UUID7 will be assigned by the server.
model (Optional[str]): The model to use for this run. Overrides the version's model if
provided.
version (Optional[VersionReference]): The version of the task to run. If not provided,
the version defined in the task is used.
instructions (Optional[str]): Custom instructions for this run. Overrides the version's
instructions if provided.
temperature (Optional[float]): The temperature to use for this run. Overrides the
version's temperature if provided.
use_cache (CacheUsage, optional): How to use the cache. Defaults to "auto".
"auto" (default): if a previous run exists with the same version and input, and if
the temperature is 0, the cached output is returned
"always": the cached output is returned when available, regardless
of the temperature value
"never": the cache is never used
labels (Optional[set[str]], optional): Labels are deprecated, please use metadata instead.
metadata (Optional[dict[str, Any]], optional): A dictionary of metadata to attach to the
run.
max_retry_delay (Optional[float], optional): The maximum delay between retries in
milliseconds. Defaults to 60000.
max_retry_count (Optional[float], optional): The maximum number of retry attempts.
Defaults to 1.
max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles.
Defaults to 10.
validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the
output.

Returns:
Run[AgentOutput]: The task run object.
"""
try:
return await self.run(input, **kwargs)
except InvalidGenerationError as e:
@@ -130,16 +164,125 @@ async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutp

class _RunnableOutputOnlyAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]):
async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): # noqa: A002
"""Run the agent

This variant returns only the output, without the run metadata.

Args:
_ (AgentInput): The input to the task.
id (Optional[str]): A user defined ID for the run. The ID must be a UUID7, ordered by
creation time. If not provided, a UUID7 will be assigned by the server.
model (Optional[str]): The model to use for this run. Overrides the version's model if
provided.
version (Optional[VersionReference]): The version of the task to run. If not provided,
the version defined in the task is used.
instructions (Optional[str]): Custom instructions for this run. Overrides the version's
instructions if provided.
temperature (Optional[float]): The temperature to use for this run. Overrides the
version's temperature if provided.
use_cache (CacheUsage, optional): How to use the cache. Defaults to "auto".
"auto" (default): if a previous run exists with the same version and input, and if
the temperature is 0, the cached output is returned
"always": the cached output is returned when available, regardless
of the temperature value
"never": the cache is never used
labels (Optional[set[str]], optional): Labels are deprecated, please use metadata instead.
metadata (Optional[dict[str, Any]], optional): A dictionary of metadata to attach to the
run.
max_retry_delay (Optional[float], optional): The maximum delay between retries in
milliseconds. Defaults to 60000.
max_retry_count (Optional[float], optional): The maximum number of retry attempts.
Defaults to 1.
max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles.
Defaults to 10.
validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the
output.

Returns:
AgentOutput: The output of the task.
"""
return (await self.run(input, **kwargs)).output


class _RunnableStreamAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]):
def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): # noqa: A002
"""Stream the output of the agent

Args:
_ (AgentInput): The input to the task.
id (Optional[str]): A user defined ID for the run. The ID must be a UUID7, ordered by
creation time. If not provided, a UUID7 will be assigned by the server.
model (Optional[str]): The model to use for this run. Overrides the version's model if
provided.
version (Optional[VersionReference]): The version of the task to run. If not provided,
the version defined in the task is used.
instructions (Optional[str]): Custom instructions for this run. Overrides the version's
instructions if provided.
temperature (Optional[float]): The temperature to use for this run. Overrides the
version's temperature if provided.
use_cache (CacheUsage, optional): How to use the cache. Defaults to "auto".
"auto" (default): if a previous run exists with the same version and input, and if
the temperature is 0, the cached output is returned
"always": the cached output is returned when available, regardless
of the temperature value
"never": the cache is never used
labels (Optional[set[str]], optional): Labels are deprecated, please use metadata instead.
metadata (Optional[dict[str, Any]], optional): A dictionary of metadata to attach to the
run.
max_retry_delay (Optional[float], optional): The maximum delay between retries in
milliseconds. Defaults to 60000.
max_retry_count (Optional[float], optional): The maximum number of retry attempts.
Defaults to 1.
max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles.
Defaults to 10.
validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the
output.

Returns:
AsyncIterator[Run[AgentOutput]]: An async iterator yielding task run objects.
"""
return self.stream(input, **kwargs)


class _RunnableStreamOutputOnlyAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]):
async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): # noqa: A002
"""Stream the output of the agent

This variant yields only the output, without the run metadata.

Args:
_ (AgentInput): The input to the task.
id (Optional[str]): A user defined ID for the run. The ID must be a UUID7, ordered by
creation time. If not provided, a UUID7 will be assigned by the server.
model (Optional[str]): The model to use for this run. Overrides the version's model if
provided.
version (Optional[VersionReference]): The version of the task to run. If not provided,
the version defined in the task is used.
instructions (Optional[str]): Custom instructions for this run. Overrides the version's
instructions if provided.
temperature (Optional[float]): The temperature to use for this run. Overrides the
version's temperature if provided.
use_cache (CacheUsage, optional): How to use the cache. Defaults to "auto".
"auto" (default): if a previous run exists with the same version and input, and if
the temperature is 0, the cached output is returned
"always": the cached output is returned when available, regardless
of the temperature value
"never": the cache is never used
labels (Optional[set[str]], optional): Labels are deprecated, please use metadata instead.
metadata (Optional[dict[str, Any]], optional): A dictionary of metadata to attach to the
run.
max_retry_delay (Optional[float], optional): The maximum delay between retries in
milliseconds. Defaults to 60000.
max_retry_count (Optional[float], optional): The maximum number of retry attempts.
Defaults to 1.
max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles.
Defaults to 10.
validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the
output.

Returns:
AsyncIterator[AgentOutput]: An async iterator yielding task outputs.
"""
async for chunk in self.stream(input, **kwargs):
yield chunk.output

@@ -215,9 +358,22 @@ def agent_wrapper(
model: Optional[ModelOrStr] = None,
tools: Optional[Iterable[Callable[..., Any]]] = None,
) -> AgentDecorator:
def wrap(fn: RunTemplate[AgentInput, AgentOutput]) -> FinalRunTemplate[AgentInput, AgentOutput]:
def wrap(fn: RunTemplate[AgentInput, AgentOutput]):
tid = agent_id or agent_id_from_fn_name(fn)
return functools.wraps(fn)(wrap_run_template(client, tid, schema_id, version, model, fn, tools)) # pyright: ignore [reportReturnType]
# TODO[types]: Not sure why a cast is needed here
agent = cast(
FinalRunTemplate[AgentInput, AgentOutput],
wrap_run_template(client, tid, schema_id, version, model, fn, tools),
)

agent.__doc__ = """A class representing an AI agent that can process inputs and generate outputs.

The Agent class provides functionality to run AI-powered tasks with support for streaming,
tool execution, and version management.
"""
agent.__name__ = fn.__name__

return agent # pyright: ignore [reportReturnType]

# pyright is unhappy with generics
# TODO[types]: pyright is unhappy with generics
return wrap # pyright: ignore [reportReturnType]
67 changes: 48 additions & 19 deletions workflowai/core/client/_fn_utils_test.py
@@ -116,44 +116,72 @@ async def test_fn_stream_output_only(self, mock_api_client: Mock):
assert isinstance(chunks[0], HelloTaskOutput)
assert chunks[0] == HelloTaskOutput(message="Hello, World!")

async def test_agent_functions_and_doc(self, mock_api_client: Mock):
wrapped = agent_wrapper(lambda: mock_api_client, schema_id=1, agent_id="hello")(self.fn_run_output_only)
assert wrapped.__doc__

mock_api_client.post.return_value = RunResponse(id="1", task_output={"message": "Hello, World!"})
output = await wrapped(HelloTaskInput(name="World"))
assert isinstance(output, HelloTaskOutput)

mock_api_client.post.return_value = RunResponse(id="1", task_output={"message": "Hello, World!"})
run = await wrapped.run(HelloTaskInput(name="World"), model="gpt-4o")
assert isinstance(run, Run)

mock_api_client.stream.return_value = mock_aiter(RunResponse(id="1", task_output={"message": "Hello, World!"}))
chunks = [c async for c in wrapped.stream(HelloTaskInput(name="World"))]
assert len(chunks) == 1
assert isinstance(chunks[0], Run)

assert wrapped.run.__doc__
assert wrapped.stream.__doc__
assert wrapped.reply.__doc__
assert wrapped.register.__doc__
assert wrapped.__call__.__doc__


@pytest.mark.parametrize(
    ("value", "expected"),
    [
        # Empty docstrings
        ("", ""),
        (None, ""),
        # Single line docstrings
        ("Hello world", "Hello world"),
        (" Hello world ", "Hello world"),
        # Docstring with empty lines at start/end
        (
            """

            Hello world

            """,
            "Hello world",
        ),
        # Multi-line docstring with indentation
        (
            """
            First line
            Second line
              Indented line
            Last line
            """,
            "First line\nSecond line\n  Indented line\nLast line",
        ),
        # Docstring with empty lines in between
        (
            """
            First line

            Second line

            Third line
            """,
            "First line\n\nSecond line\n\nThird line",
        ),
        # Real-world example
        (
            """
            Find the capital city of the country where the input city is located.

            Guidelines:
@@ -163,13 +191,14 @@ async def test_fn_stream_output_only(self, mock_api_client: Mock):
            4. Be accurate and precise with geographical information
            5. If the input city is itself the capital, still provide the information
            """,
            "Find the capital city of the country where the input city is located.\n\n"
            "Guidelines:\n"
            "1. First identify the country where the input city is located\n"
            "2. Then provide the capital city of that country\n"
            "3. Include an interesting historical or cultural fact about the capital\n"
            "4. Be accurate and precise with geographical information\n"
            "5. If the input city is itself the capital, still provide the information",
        ),
    ],
)
def test_clean_docstring(value: Union[str, None], expected: str):