diff --git a/README.md b/README.md index 61b5696..d7f8233 100644 --- a/README.md +++ b/README.md @@ -219,6 +219,27 @@ def analyze_call_feedback(input: CallFeedbackInput) -> AsyncIterator[Run[CallFee ... ``` +#### The Agent class + +Any agent function (aka a function decorated with `@workflowai.agent()`) is in fact an instance +of the `Agent` class. Which means that any defined agent can access the underlying agent functions, mainly +`run`, `stream` and `reply`. The `__call__` method of the agent is overriden for convenience to match the original +function signature. + +```python +# Any agent definition would also work +@workflowai.agent() +def analyze_call_feedback(input: CallFeedbackInput) -> CallFeedbackOutput: + ... + +# It is possible to call the run function directly to get a run object if needed +run = await agent.run(CallFeedbackInput(...)) +# Or the stream function to get a stream of run objects (see below) +chunks = [chunk async for chunk in agent.stream(CallFeedbackInput(...)) +# Or reply to manually to a given run id (see reply below) +run = await agent.reply(run_id="...", user_message="...", tool_results=...) +``` + ### The Run object Although having an agent only return the run output covers most use cases, some use cases require having more @@ -438,6 +459,9 @@ construction of a final output. > have default values for most of their fields. Otherwise the agent will throw a WorkflowAIError on missing fields > and the run chain will be broken. +> Under the hood, `run.reply` calls the `say_hello.reply` method as described in the +> [Agent class](#the-agent-class) section. + ### Tools Tools allow enhancing an agent's capabilities by allowing it to call external functions. diff --git a/pyproject.toml b/pyproject.toml index 45b7b79..97f75ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "workflowai" -version = "0.6.0.dev10" +version = "0.6.0.dev11" description = "" authors = ["Guillaume Aquilina "] readme = "README.md" diff --git a/tests/e2e/run_test.py b/tests/e2e/run_test.py index 0249683..82b0a4d 100644 --- a/tests/e2e/run_test.py +++ b/tests/e2e/run_test.py @@ -52,7 +52,7 @@ async def test_run_task( ], ): task_input = ExtractProductReviewSentimentTaskInput(review_text="This product is amazing!") - run = await extract_product_review_sentiment_agent.run(task_input=task_input, use_cache="never") + run = await extract_product_review_sentiment_agent.run(task_input, use_cache="never") assert run.output.sentiment == Sentiment.POSITIVE @@ -67,7 +67,7 @@ async def test_stream_task( ) streamed = extract_product_review_sentiment_agent.stream( - task_input=task_input, + task_input, use_cache="never", ) chunks = [chunk async for chunk in streamed] diff --git a/workflowai/core/client/_fn_utils.py b/workflowai/core/client/_fn_utils.py index c56a8db..97495f9 100644 --- a/workflowai/core/client/_fn_utils.py +++ b/workflowai/core/client/_fn_utils.py @@ -1,4 +1,3 @@ -import functools import inspect from collections.abc import AsyncIterator, Callable, Iterable, Sequence from typing import ( @@ -7,6 +6,7 @@ NamedTuple, Optional, Union, + cast, get_args, get_origin, get_type_hints, @@ -104,7 +104,41 @@ def extract_fn_spec(fn: RunTemplate[AgentInput, AgentOutput]) -> RunFunctionSpec class _RunnableAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]): async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): # noqa: A002 - """An agent that returns a run object. Handles recoverable errors when possible""" + """Run the agent and return the full run object. Handles recoverable errors when possible + + Args: + _ (AgentInput): The input to the task. + id (Optional[str]): A user defined ID for the run. The ID must be a UUID7, ordered by + creation time. If not provided, a UUID7 will be assigned by the server. + model (Optional[str]): The model to use for this run. Overrides the version's model if + provided. + version (Optional[VersionReference]): The version of the task to run. If not provided, + the version defined in the task is used. + instructions (Optional[str]): Custom instructions for this run. Overrides the version's + instructions if provided. + temperature (Optional[float]): The temperature to use for this run. Overrides the + version's temperature if provided. + use_cache (CacheUsage, optional): How to use the cache. Defaults to "auto". + "auto" (default): if a previous run exists with the same version and input, and if + the temperature is 0, the cached output is returned + "always": the cached output is returned when available, regardless + of the temperature value + "never": the cache is never used + labels (Optional[set[str]], optional): Labels are deprecated, please use metadata instead. + metadata (Optional[dict[str, Any]], optional): A dictionary of metadata to attach to the + run. + max_retry_delay (Optional[float], optional): The maximum delay between retries in + milliseconds. Defaults to 60000. + max_retry_count (Optional[float], optional): The maximum number of retry attempts. + Defaults to 1. + max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles. + Defaults to 10. + validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the + output. + + Returns: + Run[AgentOutput]: The task run object. + """ try: return await self.run(input, **kwargs) except InvalidGenerationError as e: @@ -130,16 +164,125 @@ async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutp class _RunnableOutputOnlyAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]): async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): # noqa: A002 + """Run the agent + + This variant returns only the output, without the run metadata. + + Args: + _ (AgentInput): The input to the task. + id (Optional[str]): A user defined ID for the run. The ID must be a UUID7, ordered by + creation time. If not provided, a UUID7 will be assigned by the server. + model (Optional[str]): The model to use for this run. Overrides the version's model if + provided. + version (Optional[VersionReference]): The version of the task to run. If not provided, + the version defined in the task is used. + instructions (Optional[str]): Custom instructions for this run. Overrides the version's + instructions if provided. + temperature (Optional[float]): The temperature to use for this run. Overrides the + version's temperature if provided. + use_cache (CacheUsage, optional): How to use the cache. Defaults to "auto". + "auto" (default): if a previous run exists with the same version and input, and if + the temperature is 0, the cached output is returned + "always": the cached output is returned when available, regardless + of the temperature value + "never": the cache is never used + labels (Optional[set[str]], optional): Labels are deprecated, please use metadata instead. + metadata (Optional[dict[str, Any]], optional): A dictionary of metadata to attach to the + run. + max_retry_delay (Optional[float], optional): The maximum delay between retries in + milliseconds. Defaults to 60000. + max_retry_count (Optional[float], optional): The maximum number of retry attempts. + Defaults to 1. + max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles. + Defaults to 10. + validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the + output. + + Returns: + AgentOutput: The output of the task. + """ return (await self.run(input, **kwargs)).output class _RunnableStreamAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]): def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): # noqa: A002 + """Stream the output of the agent + + Args: + _ (AgentInput): The input to the task. + id (Optional[str]): A user defined ID for the run. The ID must be a UUID7, ordered by + creation time. If not provided, a UUID7 will be assigned by the server. + model (Optional[str]): The model to use for this run. Overrides the version's model if + provided. + version (Optional[VersionReference]): The version of the task to run. If not provided, + the version defined in the task is used. + instructions (Optional[str]): Custom instructions for this run. Overrides the version's + instructions if provided. + temperature (Optional[float]): The temperature to use for this run. Overrides the + version's temperature if provided. + use_cache (CacheUsage, optional): How to use the cache. Defaults to "auto". + "auto" (default): if a previous run exists with the same version and input, and if + the temperature is 0, the cached output is returned + "always": the cached output is returned when available, regardless + of the temperature value + "never": the cache is never used + labels (Optional[set[str]], optional): Labels are deprecated, please use metadata instead. + metadata (Optional[dict[str, Any]], optional): A dictionary of metadata to attach to the + run. + max_retry_delay (Optional[float], optional): The maximum delay between retries in + milliseconds. Defaults to 60000. + max_retry_count (Optional[float], optional): The maximum number of retry attempts. + Defaults to 1. + max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles. + Defaults to 10. + validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the + output. + + Returns: + AsyncIterator[Run[AgentOutput]]: An async iterator yielding task run objects. + """ return self.stream(input, **kwargs) class _RunnableStreamOutputOnlyAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]): async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): # noqa: A002 + """Stream the output of the agent + + This variant yields only the output, without the run metadata. + + Args: + _ (AgentInput): The input to the task. + id (Optional[str]): A user defined ID for the run. The ID must be a UUID7, ordered by + creation time. If not provided, a UUID7 will be assigned by the server. + model (Optional[str]): The model to use for this run. Overrides the version's model if + provided. + version (Optional[VersionReference]): The version of the task to run. If not provided, + the version defined in the task is used. + instructions (Optional[str]): Custom instructions for this run. Overrides the version's + instructions if provided. + temperature (Optional[float]): The temperature to use for this run. Overrides the + version's temperature if provided. + use_cache (CacheUsage, optional): How to use the cache. Defaults to "auto". + "auto" (default): if a previous run exists with the same version and input, and if + the temperature is 0, the cached output is returned + "always": the cached output is returned when available, regardless + of the temperature value + "never": the cache is never used + labels (Optional[set[str]], optional): Labels are deprecated, please use metadata instead. + metadata (Optional[dict[str, Any]], optional): A dictionary of metadata to attach to the + run. + max_retry_delay (Optional[float], optional): The maximum delay between retries in + milliseconds. Defaults to 60000. + max_retry_count (Optional[float], optional): The maximum number of retry attempts. + Defaults to 1. + max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles. + Defaults to 10. + validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the + output. + + Returns: + AsyncIterator[AgentOutput]: An async iterator yielding task outputs. + """ async for chunk in self.stream(input, **kwargs): yield chunk.output @@ -215,9 +358,22 @@ def agent_wrapper( model: Optional[ModelOrStr] = None, tools: Optional[Iterable[Callable[..., Any]]] = None, ) -> AgentDecorator: - def wrap(fn: RunTemplate[AgentInput, AgentOutput]) -> FinalRunTemplate[AgentInput, AgentOutput]: + def wrap(fn: RunTemplate[AgentInput, AgentOutput]): tid = agent_id or agent_id_from_fn_name(fn) - return functools.wraps(fn)(wrap_run_template(client, tid, schema_id, version, model, fn, tools)) # pyright: ignore [reportReturnType] + # TODO[types]: Not sure why a cast is needed here + agent = cast( + FinalRunTemplate[AgentInput, AgentOutput], + wrap_run_template(client, tid, schema_id, version, model, fn, tools), + ) + + agent.__doc__ = """A class representing an AI agent that can process inputs and generate outputs. + + The Agent class provides functionality to run AI-powered tasks with support for streaming, + tool execution, and version management. +""" + agent.__name__ = fn.__name__ + + return agent # pyright: ignore [reportReturnType] - # pyright is unhappy with generics + # TODO[types]: pyright is unhappy with generics return wrap # pyright: ignore [reportReturnType] diff --git a/workflowai/core/client/_fn_utils_test.py b/workflowai/core/client/_fn_utils_test.py index 90c4dbc..f04cfab 100644 --- a/workflowai/core/client/_fn_utils_test.py +++ b/workflowai/core/client/_fn_utils_test.py @@ -116,6 +116,29 @@ async def test_fn_stream_output_only(self, mock_api_client: Mock): assert isinstance(chunks[0], HelloTaskOutput) assert chunks[0] == HelloTaskOutput(message="Hello, World!") + async def test_agent_functions_and_doc(self, mock_api_client: Mock): + wrapped = agent_wrapper(lambda: mock_api_client, schema_id=1, agent_id="hello")(self.fn_run_output_only) + assert wrapped.__doc__ + + mock_api_client.post.return_value = RunResponse(id="1", task_output={"message": "Hello, World!"}) + output = await wrapped(HelloTaskInput(name="World")) + assert isinstance(output, HelloTaskOutput) + + mock_api_client.post.return_value = RunResponse(id="1", task_output={"message": "Hello, World!"}) + run = await wrapped.run(HelloTaskInput(name="World"), model="gpt-4o") + assert isinstance(run, Run) + + mock_api_client.stream.return_value = mock_aiter(RunResponse(id="1", task_output={"message": "Hello, World!"})) + chunks = [c async for c in wrapped.stream(HelloTaskInput(name="World"))] + assert len(chunks) == 1 + assert isinstance(chunks[0], Run) + + assert wrapped.run.__doc__ + assert wrapped.stream.__doc__ + assert wrapped.reply.__doc__ + assert wrapped.register.__doc__ + assert wrapped.__call__.__doc__ + @pytest.mark.parametrize( ("value", "expected"), @@ -123,37 +146,42 @@ async def test_fn_stream_output_only(self, mock_api_client: Mock): # Empty docstrings ("", ""), (None, ""), - # Single line docstrings ("Hello world", "Hello world"), (" Hello world ", "Hello world"), - # Docstring with empty lines at start/end - (""" + ( + """ Hello world - """, "Hello world"), - + """, + "Hello world", + ), # Multi-line docstring with indentation - (""" + ( + """ First line Second line Indented line Last line - """, "First line\nSecond line\n Indented line\nLast line"), - + """, + "First line\nSecond line\n Indented line\nLast line", + ), # Docstring with empty lines in between - (""" + ( + """ First line Second line Third line - """, "First line\n\nSecond line\n\nThird line"), - + """, + "First line\n\nSecond line\n\nThird line", + ), # Real-world example - (""" + ( + """ Find the capital city of the country where the input city is located. Guidelines: @@ -163,13 +191,14 @@ async def test_fn_stream_output_only(self, mock_api_client: Mock): 4. Be accurate and precise with geographical information 5. If the input city is itself the capital, still provide the information """, - "Find the capital city of the country where the input city is located.\n\n" - "Guidelines:\n" - "1. First identify the country where the input city is located\n" - "2. Then provide the capital city of that country\n" - "3. Include an interesting historical or cultural fact about the capital\n" - "4. Be accurate and precise with geographical information\n" - "5. If the input city is itself the capital, still provide the information"), + "Find the capital city of the country where the input city is located.\n\n" + "Guidelines:\n" + "1. First identify the country where the input city is located\n" + "2. Then provide the capital city of that country\n" + "3. Include an interesting historical or cultural fact about the capital\n" + "4. Be accurate and precise with geographical information\n" + "5. If the input city is itself the capital, still provide the information", + ), ], ) def test_clean_docstring(value: Union[str, None], expected: str): diff --git a/workflowai/core/client/_types.py b/workflowai/core/client/_types.py index 27334c5..9802db2 100644 --- a/workflowai/core/client/_types.py +++ b/workflowai/core/client/_types.py @@ -1,4 +1,4 @@ -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Iterable from typing import ( Any, Generic, @@ -13,21 +13,31 @@ from workflowai.core._common_types import AgentInputContra, AgentOutputCov, RunParams from workflowai.core.domain.run import Run from workflowai.core.domain.task import AgentInput, AgentOutput +from workflowai.core.domain.tool_call import ToolCallResult -class RunFn(Protocol, Generic[AgentInputContra, AgentOutput]): +class _BaseObject(Protocol): + __name__: str + __doc__: Optional[str] + __module__: str + __qualname__: str + __annotations__: dict[str, Any] + __defaults__: Optional[tuple[Any, ...]] + + +class RunFn(_BaseObject, Generic[AgentInputContra, AgentOutput], Protocol): async def __call__(self, _: AgentInputContra, /) -> Run[AgentOutput]: ... -class RunFnOutputOnly(Protocol, Generic[AgentInputContra, AgentOutputCov]): +class RunFnOutputOnly(_BaseObject, Generic[AgentInputContra, AgentOutputCov], Protocol): async def __call__(self, _: AgentInputContra, /) -> AgentOutputCov: ... -class StreamRunFn(Protocol, Generic[AgentInputContra, AgentOutput]): +class StreamRunFn(_BaseObject, Generic[AgentInputContra, AgentOutput], Protocol): def __call__(self, _: AgentInputContra, /) -> AsyncIterator[Run[AgentOutput]]: ... -class StreamRunFnOutputOnly(Protocol, Generic[AgentInputContra, AgentOutputCov]): +class StreamRunFnOutputOnly(_BaseObject, Generic[AgentInputContra, AgentOutputCov], Protocol): def __call__(self, _: AgentInputContra, /) -> AsyncIterator[AgentOutputCov]: ... @@ -39,18 +49,34 @@ def __call__(self, _: AgentInputContra, /) -> AsyncIterator[AgentOutputCov]: ... ] -class _BaseProtocol(Protocol): - __name__: str - __doc__: Optional[str] - __module__: str - __qualname__: str - __annotations__: dict[str, Any] - __defaults__: Optional[tuple[Any, ...]] +class _BaseProtocol(_BaseObject, Generic[AgentInputContra, AgentOutput], Protocol): __kwdefaults__: Optional[dict[str, Any]] __code__: Any + async def run( + self, + agent_input: AgentInputContra, + **kwargs: Unpack[RunParams[AgentOutput]], + ) -> Run[AgentOutput]: ... + + def stream( + self, + agent_input: AgentInputContra, + **kwargs: Unpack[RunParams[AgentOutput]], + ) -> AsyncIterator[Run[AgentOutput]]: ... + + async def register(self): ... + + async def reply( + self, + run_id: str, + user_message: Optional[str] = None, + tool_results: Optional[Iterable[ToolCallResult]] = None, + **kwargs: Unpack[RunParams[AgentOutput]], + ): ... + -class FinalRunFn(_BaseProtocol, Protocol, Generic[AgentInputContra, AgentOutput]): +class RunnableAgent(_BaseProtocol[AgentInputContra, AgentOutput], Protocol): async def __call__( self, _: AgentInputContra, @@ -59,7 +85,7 @@ async def __call__( ) -> Run[AgentOutput]: ... -class FinalRunFnOutputOnly(_BaseProtocol, Protocol, Generic[AgentInputContra, AgentOutput]): +class RunnableOutputAgent(_BaseProtocol[AgentInputContra, AgentOutput], Protocol): async def __call__( self, _: AgentInputContra, @@ -68,7 +94,7 @@ async def __call__( ) -> AgentOutput: ... -class FinalStreamRunFn(_BaseProtocol, Protocol, Generic[AgentInputContra, AgentOutput]): +class StreamableAgent(_BaseProtocol[AgentInputContra, AgentOutput], Protocol): def __call__( self, _: AgentInputContra, @@ -77,40 +103,40 @@ def __call__( ) -> AsyncIterator[Run[AgentOutput]]: ... -class FinalStreamRunFnOutputOnly(_BaseProtocol, Protocol, Generic[AgentInputContra, AgentOutputCov]): +class StreamableOutputAgent(_BaseProtocol[AgentInputContra, AgentOutput], Protocol): def __call__( self, _: AgentInputContra, /, **kwargs: Unpack[RunParams[AgentOutput]], - ) -> AsyncIterator[AgentOutputCov]: ... + ) -> AsyncIterator[AgentOutput]: ... FinalRunTemplate = Union[ - FinalRunFn[AgentInput, AgentOutput], - FinalRunFnOutputOnly[AgentInput, AgentOutput], - FinalStreamRunFn[AgentInput, AgentOutput], - FinalStreamRunFnOutputOnly[AgentInput, AgentOutput], + RunnableAgent[AgentInput, AgentOutput], + RunnableOutputAgent[AgentInput, AgentOutput], + StreamableAgent[AgentInput, AgentOutput], + StreamableOutputAgent[AgentInput, AgentOutput], ] class AgentDecorator(Protocol): @overload - def __call__(self, fn: RunFn[AgentInput, AgentOutput]) -> FinalRunFn[AgentInput, AgentOutput]: ... + def __call__(self, fn: RunFn[AgentInput, AgentOutput]) -> RunnableAgent[AgentInput, AgentOutput]: ... @overload def __call__( self, fn: RunFnOutputOnly[AgentInput, AgentOutput], - ) -> FinalRunFnOutputOnly[AgentInput, AgentOutput]: ... + ) -> RunnableOutputAgent[AgentInput, AgentOutput]: ... @overload - def __call__(self, fn: StreamRunFn[AgentInput, AgentOutput]) -> FinalStreamRunFn[AgentInput, AgentOutput]: ... + def __call__(self, fn: StreamRunFn[AgentInput, AgentOutput]) -> StreamableAgent[AgentInput, AgentOutput]: ... @overload def __call__( self, fn: StreamRunFnOutputOnly[AgentInput, AgentOutput], - ) -> FinalStreamRunFnOutputOnly[AgentInput, AgentOutput]: ... + ) -> StreamableOutputAgent[AgentInput, AgentOutput]: ... def __call__(self, fn: RunTemplate[AgentInput, AgentOutput]) -> FinalRunTemplate[AgentInput, AgentOutput]: ... diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index eea91d7..837139e 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -33,6 +33,51 @@ class Agent(Generic[AgentInput, AgentOutput]): + """A class representing an AI agent that can process inputs and generate outputs. + + The Agent class provides functionality to run AI-powered tasks with support for streaming, + tool execution, and version management. This class is not intended to be used directly, + instead use the `agent` decorator to create an agent. + + Args: + agent_id (str): Unique identifier for the agent. + input_cls (type[AgentInput]): The Pydantic model class defining the expected input structure. + output_cls (type[AgentOutput]): The Pydantic model class defining the expected output structure. + api (Union[APIClient, Callable[[], APIClient]]): The API client instance or factory function. + schema_id (Optional[int], optional): The schema ID for the agent. Defaults to None. + version (Optional[VersionReference], optional): The version reference for the agent. + If not provided, uses the global default version. Defaults to None. + tools (Optional[Iterable[Callable[..., Any]]], optional): Collection of tool functions + that the agent can use. Defaults to None. + + Attributes: + agent_id (str): The agent's unique identifier. + schema_id (Optional[int]): The schema ID associated with the agent. + input_cls (type[AgentInput]): The input model class. + output_cls (type[AgentOutput]): The output model class. + version (VersionReference): The version reference for the agent. + + Example: + ```python + from pydantic import BaseModel + + class MyInput(BaseModel): + query: str + + class MyOutput(BaseModel): + response: str + + agent = Agent( + agent_id="my-agent", + input_cls=MyInput, + output_cls=MyOutput, + api=api_client + ) + + result = await agent.run(MyInput(query="Hello")) + ``` + """ + _DEFAULT_MAX_ITERATIONS = 10 def __init__( @@ -115,7 +160,7 @@ def _sanitize_version(self, params: VersionRunParams) -> Union[str, int, dict[st dumped["temperature"] = temperature return dumped - async def _prepare_run(self, task_input: AgentInput, stream: bool, **kwargs: Unpack[RunParams[AgentOutput]]): + async def _prepare_run(self, agent_input: AgentInput, stream: bool, **kwargs: Unpack[RunParams[AgentOutput]]): schema_id = self.schema_id if not schema_id: schema_id = await self.register() @@ -124,7 +169,7 @@ async def _prepare_run(self, task_input: AgentInput, stream: bool, **kwargs: Unp request = RunRequest( id=kwargs.get("id"), - task_input=task_input.model_dump(by_alias=True), + task_input=agent_input.model_dump(by_alias=True), version=version, stream=stream, use_cache=kwargs.get("use_cache"), @@ -169,7 +214,10 @@ async def _prepare_reply( return self._PreparedRun(request, route, should_retry, wait_for_exception, self.schema_id) async def register(self): - """Registers the agent and returns the schema id""" + """ + Registers the agent and returns the schema id. This function is called + when the agent is first used and the result is cached in the agent's definition. + """ res = await self.api.post( "/v1/_/agents", CreateAgentRequest( @@ -272,34 +320,40 @@ async def _build_run( async def run( self, - task_input: AgentInput, + agent_input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]], ) -> Run[AgentOutput]: """Run the agent Args: - task_input (AgentInput): the input to the task - version (Optional[TaskVersionReference], optional): the version of the task to run. If not provided, - the version defined in the task is used. Defaults to None. - use_cache (CacheUsage, optional): how to use the cache. Defaults to "auto". + agent_input (AgentInput): The input to the task. + id (Optional[str]): A user defined ID for the run. The ID must be a UUID7, ordered by creation time. + If not provided, a UUID7 will be assigned by the server. + model (Optional[str]): The model to use for this run. Overrides the version's model if provided. + version (Optional[VersionReference]): The version of the task to run. If not provided, + the version defined in the task is used. + instructions (Optional[str]): Custom instructions for this run. Overrides the version's instructions if + provided. + temperature (Optional[float]): The temperature to use for this run. Overrides the version's temperature if + provided. + use_cache (CacheUsage, optional): How to use the cache. Defaults to "auto". "auto" (default): if a previous run exists with the same version and input, and if the temperature is 0, the cached output is returned "always": the cached output is returned when available, regardless of the temperature value "never": the cache is never used - labels (Optional[set[str]], optional): a set of labels to attach to the run. - Labels are indexed and searchable. Defaults to None. - metadata (Optional[dict[str, Any]], optional): a dictionary of metadata to attach to the run. - Defaults to None. - retry_delay (int, optional): The initial delay between retries in milliseconds. Defaults to 5000. - max_retry_delay (int, optional): The maximum delay between retries in milliseconds. Defaults to 60000. - max_retry_count (int, optional): The maximum number of retry attempts. Defaults to 1. + labels (Optional[set[str]], optional): Labels are deprecated, please use metadata instead. + metadata (Optional[dict[str, Any]], optional): A dictionary of metadata to attach to the run. + max_retry_delay (Optional[float], optional): The maximum delay between retries in milliseconds. + Defaults to 60000. + max_retry_count (Optional[float], optional): The maximum number of retry attempts. Defaults to 1. + max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles. Defaults to 10. + validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the output. Returns: - Union[TaskRun[AgentInput, AgentOutput], AsyncIterator[AgentOutput]]: the task run object - or an async iterator of output objects + Run[AgentOutput]: The task run object. """ - prepared_run = await self._prepare_run(task_input, stream=False, **kwargs) + prepared_run = await self._prepare_run(agent_input, stream=False, **kwargs) validator, new_kwargs = self._sanitize_validator(kwargs, intolerant_validator(self.output_cls)) last_error = None @@ -323,34 +377,40 @@ async def run( async def stream( self, - task_input: AgentInput, + agent_input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]], ): """Stream the output of the agent Args: - task_input (AgentInput): the input to the task - version (Optional[TaskVersionReference], optional): the version of the task to run. If not provided, - the version defined in the task is used. Defaults to None. - use_cache (CacheUsage, optional): how to use the cache. Defaults to "auto". + agent_input (AgentInput): The input to the task. + id (Optional[str]): A user defined ID for the run. The ID must be a UUID7, ordered by creation time. + If not provided, a UUID7 will be assigned by the server. + model (Optional[str]): The model to use for this run. Overrides the version's model if provided. + version (Optional[VersionReference]): The version of the task to run. If not provided, + the version defined in the task is used. + instructions (Optional[str]): Custom instructions for this run. + Overrides the version's instructions if provided. + temperature (Optional[float]): The temperature to use for this run. + Overrides the version's temperature if provided. + use_cache (CacheUsage, optional): How to use the cache. Defaults to "auto". "auto" (default): if a previous run exists with the same version and input, and if the temperature is 0, the cached output is returned "always": the cached output is returned when available, regardless of the temperature value "never": the cache is never used - labels (Optional[set[str]], optional): a set of labels to attach to the run. - Labels are indexed and searchable. Defaults to None. - metadata (Optional[dict[str, Any]], optional): a dictionary of metadata to attach to the run. - Defaults to None. - retry_delay (int, optional): The initial delay between retries in milliseconds. Defaults to 5000. - max_retry_delay (int, optional): The maximum delay between retries in milliseconds. Defaults to 60000. - max_retry_count (int, optional): The maximum number of retry attempts. Defaults to 1. + labels (Optional[set[str]], optional): Labels are deprecated, please use metadata instead. + metadata (Optional[dict[str, Any]], optional): A dictionary of metadata to attach to the run. + max_retry_delay (Optional[float], optional): The maximum delay between retries in milliseconds. + Defaults to 60000. + max_retry_count (Optional[float], optional): The maximum number of retry attempts. Defaults to 1. + max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles. Defaults to 10. + validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the output. Returns: - Union[TaskRun[AgentInput, AgentOutput], AsyncIterator[AgentOutput]]: the task run object - or an async iterator of output objects + AsyncIterator[Run[AgentOutput]]: An async iterator yielding task run objects. """ - prepared_run = await self._prepare_run(task_input, stream=True, **kwargs) + prepared_run = await self._prepare_run(agent_input, stream=True, **kwargs) validator, new_kwargs = self._sanitize_validator(kwargs, tolerant_validator(self.output_cls)) while prepared_run.should_retry(): @@ -381,6 +441,18 @@ async def reply( current_iteration: int = 0, **kwargs: Unpack[RunParams[AgentOutput]], ): + """Reply to a run to provide additional information or context. + + Args: + run_id (str): The id of the run to reply to. + user_message (Optional[str]): The message to reply with. + tool_results (Optional[Iterable[ToolCallResult]]): The results of the tool calls. + **kwargs: Additional keyword arguments. + + Returns: + Run[AgentOutput]: The task run object. + """ + prepared_run = await self._prepare_reply(run_id, user_message, tool_results, stream=False, **kwargs) validator, new_kwargs = self._sanitize_validator(kwargs, intolerant_validator(self.output_cls)) diff --git a/workflowai/core/client/agent_test.py b/workflowai/core/client/agent_test.py index 2ae6dbd..559f41d 100644 --- a/workflowai/core/client/agent_test.py +++ b/workflowai/core/client/agent_test.py @@ -57,7 +57,7 @@ class TestRun: async def test_success(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]): httpx_mock.add_response(json=fixtures_json("task_run.json")) - task_run = await agent.run(task_input=HelloTaskInput(name="Alice")) + task_run = await agent.run(HelloTaskInput(name="Alice")) assert task_run.id == "8f635b73-f403-47ee-bff9-18320616c6cc" assert task_run.agent_id == "123" @@ -85,7 +85,7 @@ async def test_stream(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, ), ) - chunks = [chunk async for chunk in agent.stream(task_input=HelloTaskInput(name="Alice"))] + chunks = [chunk async for chunk in agent.stream(HelloTaskInput(name="Alice"))] outputs = [chunk.output for chunk in chunks] assert outputs == [ @@ -119,7 +119,7 @@ async def test_stream_not_optional( ), ) - chunks = [chunk async for chunk in agent_not_optional.stream(task_input=HelloTaskInput(name="Alice"))] + chunks = [chunk async for chunk in agent_not_optional.stream(HelloTaskInput(name="Alice"))] messages = [chunk.output.message for chunk in chunks] assert messages == ["", "hel", "hello", "hello"] @@ -142,7 +142,7 @@ async def test_run_with_env(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskI httpx_mock.add_response(json=fixtures_json("task_run.json")) await agent.run( - task_input=HelloTaskInput(name="Alice"), + HelloTaskInput(name="Alice"), version="dev", ) @@ -160,7 +160,7 @@ async def test_run_with_env(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskI async def test_success_with_headers(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]): httpx_mock.add_response(json=fixtures_json("task_run.json")) - task_run = await agent.run(task_input=HelloTaskInput(name="Alice")) + task_run = await agent.run(HelloTaskInput(name="Alice")) assert task_run.id == "8f635b73-f403-47ee-bff9-18320616c6cc" @@ -198,7 +198,7 @@ async def test_run_retries_on_too_many_requests( json=fixtures_json("task_run.json"), ) - task_run = await agent.run(task_input=HelloTaskInput(name="Alice"), max_retry_count=5) + task_run = await agent.run(HelloTaskInput(name="Alice"), max_retry_count=5) assert task_run.id == "8f635b73-f403-47ee-bff9-18320616c6cc" @@ -214,7 +214,7 @@ async def test_run_retries_on_connection_error( httpx_mock.add_exception(httpx.ConnectError("arg")) httpx_mock.add_response(json=fixtures_json("task_run.json")) - task_run = await agent.run(task_input=HelloTaskInput(name="Alice"), max_retry_count=5) + task_run = await agent.run(HelloTaskInput(name="Alice"), max_retry_count=5) assert task_run.id == "8f635b73-f403-47ee-bff9-18320616c6cc" async def test_max_retries(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]): @@ -222,7 +222,7 @@ async def test_max_retries(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskIn httpx_mock.add_exception(httpx.ConnectError("arg"), is_reusable=True) with pytest.raises(WorkflowAIError): - await agent.run(task_input=HelloTaskInput(name="Alice"), max_retry_count=5) + await agent.run(HelloTaskInput(name="Alice"), max_retry_count=5) reqs = httpx_mock.get_requests() assert len(reqs) == 5 @@ -241,7 +241,7 @@ async def test_auto_register(self, httpx_mock: HTTPXMock, agent_no_schema: Agent json=run_response, ) - out = await agent_no_schema.run(task_input=HelloTaskInput(name="Alice")) + out = await agent_no_schema.run(HelloTaskInput(name="Alice")) assert out.id == "8f635b73-f403-47ee-bff9-18320616c6cc" run_response["id"] = "8f635b73-f403-47ee-bff9-18320616c6cc" @@ -250,7 +250,7 @@ async def test_auto_register(self, httpx_mock: HTTPXMock, agent_no_schema: Agent url="http://localhost:8000/v1/_/agents/123/schemas/2/run", json=run_response, ) - out2 = await agent_no_schema.run(task_input=HelloTaskInput(name="Alice")) + out2 = await agent_no_schema.run(HelloTaskInput(name="Alice")) assert out2.id == "8f635b73-f403-47ee-bff9-18320616c6cc" reqs = httpx_mock.get_requests() diff --git a/workflowai/core/client/client_test.py b/workflowai/core/client/client_test.py index f152628..96c6f67 100644 --- a/workflowai/core/client/client_test.py +++ b/workflowai/core/client/client_test.py @@ -31,7 +31,7 @@ def test_fn_name(self, workflowai: WorkflowAI): async def fn(task_input: HelloTaskInput) -> HelloTaskOutput: ... assert fn.__name__ == "fn" - assert fn.__doc__ is None + assert fn.__doc__ is not None assert callable(fn) async def test_run_output_only(self, workflowai: WorkflowAI, mock_run_fn: Mock):