Skip to content

Commit 364088f

Browse files
authored
Merge pull request #75 from WorkflowAI/guillaume/max_turns
Add configurable max turns
2 parents 6506e09 + 80b85c7 commit 364088f

File tree

11 files changed

+296
-40
lines changed

11 files changed

+296
-40
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "workflowai"
3-
version = "0.6.1"
3+
version = "0.6.2-dev1"
44
description = "Python SDK for WorkflowAI"
55
authors = ["Guillaume Aquilina <[email protected]>"]
66
readme = "README.md"
@@ -63,7 +63,7 @@ unfixable = []
6363
[tool.ruff.lint.per-file-ignores]
6464
# in bin we use rich.print
6565
"bin/*" = ["T201"]
66-
"*_test.py" = ["S101"]
66+
"*_test.py" = ["S101", "S106"]
6767
"conftest.py" = ["S101"]
6868
"examples/*" = ["INP001", "T201", "ERA001"]
6969

workflowai/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Callable, Iterable
2-
from typing import Any, Optional
2+
from typing import Any, Literal, Optional
33

44
from typing_extensions import deprecated
55

@@ -82,3 +82,12 @@ def agent(
8282
model=model,
8383
tools=tools,
8484
)
85+
86+
87+
def send_feedback(
88+
feedback_token: str,
89+
outcome: Literal["positive", "negative"],
90+
comment: Optional[str] = None,
91+
user_id: Optional[str] = None,
92+
):
93+
return shared_client.send_feedback(feedback_token, outcome, comment, user_id)

workflowai/core/_common_types.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,18 @@ class VersionRunParams(TypedDict):
4040
temperature: NotRequired[Optional[float]]
4141

4242

43-
class BaseRunParams(VersionRunParams):
43+
class OtherRunParams(TypedDict):
4444
use_cache: NotRequired["CacheUsage"]
45-
metadata: NotRequired[Optional[dict[str, Any]]]
46-
labels: NotRequired[Optional[set[str]]]
45+
4746
max_retry_delay: NotRequired[float]
4847
max_retry_count: NotRequired[float]
4948

50-
max_tool_iterations: NotRequired[int] # 10 by default
49+
max_turns: NotRequired[int] # 10 by default
50+
max_turns_raises: NotRequired[bool] # True by default
51+
52+
53+
class BaseRunParams(VersionRunParams, OtherRunParams):
54+
metadata: NotRequired[Optional[dict[str, Any]]]
5155

5256

5357
class RunParams(BaseRunParams, Generic[AgentOutput]):

workflowai/core/client/_fn_utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pydantic import BaseModel, ValidationError
1616
from typing_extensions import Unpack
1717

18+
from workflowai.core._common_types import OtherRunParams
1819
from workflowai.core.client._api import APIClient
1920
from workflowai.core.client._models import RunResponse
2021
from workflowai.core.client._types import (
@@ -131,7 +132,7 @@ async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutp
131132
milliseconds. Defaults to 60000.
132133
max_retry_count (Optional[float], optional): The maximum number of retry attempts.
133134
Defaults to 1.
134-
max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles.
135+
max_turns (Optional[int], optional): Maximum number of tool iteration cycles.
135136
Defaults to 10.
136137
validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the
137138
output.
@@ -194,7 +195,7 @@ async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutp
194195
milliseconds. Defaults to 60000.
195196
max_retry_count (Optional[float], optional): The maximum number of retry attempts.
196197
Defaults to 1.
197-
max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles.
198+
max_turns (Optional[int], optional): Maximum number of tool iteration cycles.
198199
Defaults to 10.
199200
validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the
200201
output.
@@ -234,7 +235,7 @@ def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]):
234235
milliseconds. Defaults to 60000.
235236
max_retry_count (Optional[float], optional): The maximum number of retry attempts.
236237
Defaults to 1.
237-
max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles.
238+
max_turns (Optional[int], optional): Maximum number of tool iteration cycles.
238239
Defaults to 10.
239240
validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the
240241
output.
@@ -276,7 +277,7 @@ async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutp
276277
milliseconds. Defaults to 60000.
277278
max_retry_count (Optional[float], optional): The maximum number of retry attempts.
278279
Defaults to 1.
279-
max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles.
280+
max_turns (Optional[int], optional): Maximum number of tool iteration cycles.
280281
Defaults to 10.
281282
validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the
282283
output.
@@ -318,6 +319,7 @@ def wrap_run_template(
318319
model: Optional[ModelOrStr],
319320
fn: RunTemplate[AgentInput, AgentOutput],
320321
tools: Optional[Iterable[Callable[..., Any]]] = None,
322+
run_params: Optional[OtherRunParams] = None,
321323
) -> Union[
322324
_RunnableAgent[AgentInput, AgentOutput],
323325
_RunnableOutputOnlyAgent[AgentInput, AgentOutput],
@@ -344,6 +346,7 @@ def wrap_run_template(
344346
schema_id=schema_id,
345347
version=version,
346348
tools=tools,
349+
**(run_params or {}),
347350
)
348351

349352

@@ -358,13 +361,14 @@ def agent_wrapper(
358361
version: Optional[VersionReference] = None,
359362
model: Optional[ModelOrStr] = None,
360363
tools: Optional[Iterable[Callable[..., Any]]] = None,
364+
**kwargs: Unpack[OtherRunParams],
361365
) -> AgentDecorator:
362366
def wrap(fn: RunTemplate[AgentInput, AgentOutput]):
363367
tid = agent_id or agent_id_from_fn_name(fn)
364368
# TODO[types]: Not sure why a cast is needed here
365369
agent = cast(
366370
FinalRunTemplate[AgentInput, AgentOutput],
367-
wrap_run_template(client, tid, schema_id, version, model, fn, tools),
371+
wrap_run_template(client, tid, schema_id, version, model, fn, tools, kwargs),
368372
)
369373

370374
agent.__doc__ = """A class representing an AI agent that can process inputs and generate outputs.

workflowai/core/client/_fn_utils_test.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
get_generic_args,
2020
is_async_iterator,
2121
)
22-
from workflowai.core.client._models import RunResponse
22+
from workflowai.core.client._models import RunRequest, RunResponse
2323
from workflowai.core.domain.run import Run
2424

2525

@@ -80,6 +80,29 @@ async def test_fn_run(self, mock_api_client: Mock):
8080
assert run.id == "1"
8181
assert run.output == HelloTaskOutput(message="Hello, World!")
8282

83+
async def test_fn_run_with_default_cache(self, mock_api_client: Mock):
84+
wrapped = agent_wrapper(lambda: mock_api_client, schema_id=1, agent_id="hello", use_cache="never")(self.fn_run)
85+
assert isinstance(wrapped, _RunnableAgent)
86+
87+
mock_api_client.post.return_value = RunResponse(id="1", task_output={"message": "Hello, World!"})
88+
run = await wrapped(HelloTaskInput(name="World"))
89+
assert isinstance(run, Run)
90+
91+
mock_api_client.post.assert_called_once()
92+
req = mock_api_client.post.call_args.args[1]
93+
assert isinstance(req, RunRequest)
94+
assert req.use_cache == "never"
95+
96+
mock_api_client.post.reset_mock()
97+
98+
# Check that it can be overridden
99+
_ = await wrapped(HelloTaskInput(name="World"), use_cache="always")
100+
101+
mock_api_client.post.assert_called_once()
102+
req = mock_api_client.post.call_args.args[1]
103+
assert isinstance(req, RunRequest)
104+
assert req.use_cache == "always"
105+
83106
def fn_stream(self, task_input: HelloTaskInput) -> AsyncIterator[Run[HelloTaskOutput]]: ...
84107

85108
async def test_fn_stream(self, mock_api_client: Mock):

workflowai/core/client/_models.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,13 @@ def tool_call_to_domain(tool_call: ToolCall) -> DToolCall:
103103
)
104104

105105

106-
class ToolCallRequest(TypedDict):
106+
class ToolCallRequestDict(TypedDict):
107107
id: str
108108
name: str
109109
input: dict[str, Any]
110110

111111

112-
def tool_call_request_to_domain(tool_call_request: ToolCallRequest) -> DToolCallRequest:
112+
def tool_call_request_to_domain(tool_call_request: ToolCallRequestDict) -> DToolCallRequest:
113113
return DToolCallRequest(
114114
id=tool_call_request["id"],
115115
name=tool_call_request["name"],
@@ -119,15 +119,15 @@ def tool_call_request_to_domain(tool_call_request: ToolCallRequest) -> DToolCall
119119

120120
class RunResponse(BaseModel):
121121
id: str
122-
task_output: dict[str, Any]
122+
task_output: Optional[dict[str, Any]] = None
123123

124124
version: Optional[Version] = None
125125
duration_seconds: Optional[float] = None
126126
cost_usd: Optional[float] = None
127127
metadata: Optional[dict[str, Any]] = None
128128

129129
tool_calls: Optional[list[ToolCall]] = None
130-
tool_call_requests: Optional[list[ToolCallRequest]] = None
130+
tool_call_requests: Optional[list[ToolCallRequestDict]] = None
131131

132132
feedback_token: Optional[str] = None
133133

@@ -147,7 +147,7 @@ def to_domain(
147147
id=self.id,
148148
agent_id=task_id,
149149
schema_id=task_schema_id,
150-
output=validator(self.task_output, partial),
150+
output=validator(self.task_output or {}, partial),
151151
version=self.version and self.version.to_domain(),
152152
duration_seconds=self.duration_seconds,
153153
cost_usd=self.cost_usd,
@@ -220,3 +220,10 @@ class CompletionsResponse(BaseModel):
220220
"""Response from the completions API endpoint."""
221221

222222
completions: list[Completion]
223+
224+
225+
class CreateFeedbackRequest(BaseModel):
226+
feedback_token: str
227+
outcome: Literal["positive", "negative"]
228+
comment: Optional[str]
229+
user_id: Optional[str]

workflowai/core/client/agent.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66
from pydantic import BaseModel, ValidationError
77
from typing_extensions import Unpack
88

9-
from workflowai.core._common_types import BaseRunParams, OutputValidator, VersionRunParams
9+
from workflowai.core._common_types import (
10+
BaseRunParams,
11+
OtherRunParams,
12+
OutputValidator,
13+
VersionRunParams,
14+
)
1015
from workflowai.core.client._api import APIClient
1116
from workflowai.core.client._models import (
1217
CompletionsResponse,
@@ -27,7 +32,7 @@
2732
global_default_version_reference,
2833
)
2934
from workflowai.core.domain.completion import Completion
30-
from workflowai.core.domain.errors import BaseError, WorkflowAIError
35+
from workflowai.core.domain.errors import BaseError, MaxTurnsReachedError, WorkflowAIError
3136
from workflowai.core.domain.run import Run
3237
from workflowai.core.domain.task import AgentInput, AgentOutput
3338
from workflowai.core.domain.tool import Tool
@@ -83,7 +88,7 @@ class MyOutput(BaseModel):
8388
```
8489
"""
8590

86-
_DEFAULT_MAX_ITERATIONS = 10
91+
_DEFAULT_MAX_TURNS = 10
8792

8893
def __init__(
8994
self,
@@ -94,6 +99,7 @@ def __init__(
9499
schema_id: Optional[int] = None,
95100
version: Optional[VersionReference] = None,
96101
tools: Optional[Iterable[Callable[..., Any]]] = None,
102+
**kwargs: Unpack[OtherRunParams],
97103
):
98104
self.agent_id = agent_id
99105
self.schema_id = schema_id
@@ -104,6 +110,7 @@ def __init__(
104110
self._tools = self.build_tools(tools) if tools else None
105111

106112
self._default_validator = default_validator(output_cls)
113+
self._other_run_params = kwargs
107114

108115
@classmethod
109116
def build_tools(cls, tools: Iterable[Callable[..., Any]]):
@@ -180,6 +187,13 @@ def _sanitize_version(self, params: VersionRunParams) -> Union[str, int, dict[st
180187
dumped["temperature"] = combined.temperature
181188
return dumped
182189

190+
def _get_run_param(self, key: str, params: OtherRunParams, default: Any = None) -> Any:
191+
if key in params:
192+
return params[key] # pyright: ignore [reportUnknownVariableType]
193+
if key in self._other_run_params:
194+
return self._other_run_params[key] # pyright: ignore [reportUnknownVariableType]
195+
return default
196+
183197
async def _prepare_run(self, agent_input: AgentInput, stream: bool, **kwargs: Unpack[RunParams[AgentOutput]]):
184198
schema_id = self.schema_id
185199
if not schema_id:
@@ -192,15 +206,14 @@ async def _prepare_run(self, agent_input: AgentInput, stream: bool, **kwargs: Un
192206
task_input=agent_input.model_dump(by_alias=True),
193207
version=version,
194208
stream=stream,
195-
use_cache=kwargs.get("use_cache"),
209+
use_cache=self._get_run_param("use_cache", kwargs),
196210
metadata=kwargs.get("metadata"),
197-
labels=kwargs.get("labels"),
198211
)
199212

200213
route = f"/v1/_/agents/{self.agent_id}/schemas/{self.schema_id}/run"
201214
should_retry, wait_for_exception = build_retryable_wait(
202-
kwargs.get("max_retry_delay", 60),
203-
kwargs.get("max_retry_count", 1),
215+
self._get_run_param("max_retry_delay", kwargs, 60),
216+
self._get_run_param("max_retry_count", kwargs, 1),
204217
)
205218
return self._PreparedRun(request, route, should_retry, wait_for_exception, schema_id)
206219

@@ -227,8 +240,8 @@ async def _prepare_reply(
227240
)
228241
route = f"/v1/_/agents/{self.agent_id}/runs/{run_id}/reply"
229242
should_retry, wait_for_exception = build_retryable_wait(
230-
kwargs.get("max_retry_delay", 60),
231-
kwargs.get("max_retry_count", 1),
243+
self._get_run_param("max_retry_delay", kwargs, 60),
244+
self._get_run_param("max_retry_count", kwargs, 1),
232245
)
233246

234247
return self._PreparedRun(request, route, should_retry, wait_for_exception, self.schema_id)
@@ -324,8 +337,14 @@ async def _build_run(
324337
run = self._build_run_no_tools(chunk, schema_id, validator)
325338

326339
if run.tool_call_requests:
327-
if current_iteration >= kwargs.get("max_iterations", self._DEFAULT_MAX_ITERATIONS):
328-
raise WorkflowAIError(error=BaseError(message="max tool iterations reached"), response=None)
340+
if current_iteration >= self._get_run_param("max_turns", kwargs, self._DEFAULT_MAX_TURNS):
341+
if self._get_run_param("max_turns_raises", kwargs, default=True):
342+
raise MaxTurnsReachedError(
343+
error=BaseError(message="max tool iterations reached"),
344+
response=None,
345+
tool_call_requests=run.tool_call_requests,
346+
)
347+
return run
329348
with_reply = await self._execute_tools(
330349
run_id=run.id,
331350
tool_call_requests=run.tool_call_requests,
@@ -368,7 +387,9 @@ async def run(
368387
max_retry_delay (Optional[float], optional): The maximum delay between retries in milliseconds.
369388
Defaults to 60000.
370389
max_retry_count (Optional[float], optional): The maximum number of retry attempts. Defaults to 1.
371-
max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles. Defaults to 10.
390+
max_turns (Optional[int], optional): Maximum number of tool iteration cycles. Defaults to 10.
391+
max_turns_raises (Optional[bool], optional): Whether to raise an error when the maximum number of turns is
392+
reached. Defaults to True.
372393
validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the output.
373394
374395
Returns:
@@ -385,7 +406,7 @@ async def run(
385406
res,
386407
prepared_run.schema_id,
387408
validator,
388-
current_iteration=0,
409+
current_iteration=1,
389410
# TODO[test]: add test with custom validator
390411
**new_kwargs,
391412
)
@@ -424,7 +445,7 @@ async def stream(
424445
max_retry_delay (Optional[float], optional): The maximum delay between retries in milliseconds.
425446
Defaults to 60000.
426447
max_retry_count (Optional[float], optional): The maximum number of retry attempts. Defaults to 1.
427-
max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles. Defaults to 10.
448+
max_turns (Optional[int], optional): Maximum number of tool iteration cycles. Defaults to 10.
428449
validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the output.
429450
430451
Returns:

0 commit comments

Comments
 (0)