Skip to content

Commit 8229c4f

Browse files
authored
Merge pull request #23 from WorkflowAI/guillaume/task-annotation
Task annotation
2 parents 1712f61 + 5f9d6a9 commit 8229c4f

23 files changed

+813
-246
lines changed

examples/city_to_capital_task.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
from asyncio import run as aiorun
2+
3+
import typer
14
from pydantic import BaseModel, Field # pyright: ignore [reportUnknownVariableType]
5+
from rich import print as rprint
26

3-
from workflowai import Task, VersionReference
7+
import workflowai
48

59

610
class CityToCapitalTaskInput(BaseModel):
@@ -17,10 +21,19 @@ class CityToCapitalTaskOutput(BaseModel):
1721
)
1822

1923

20-
class CityToCapitalTask(Task[CityToCapitalTaskInput, CityToCapitalTaskOutput]):
21-
id: str = "citytocapital"
22-
schema_id: int = 1
23-
input_class: type[CityToCapitalTaskInput] = CityToCapitalTaskInput
24-
output_class: type[CityToCapitalTaskOutput] = CityToCapitalTaskOutput
24+
@workflowai.task(schema_id=1)
25+
async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTaskOutput: ...
26+
27+
28+
def main(city: str) -> None:
29+
async def _inner() -> None:
30+
task_input = CityToCapitalTaskInput(city=city)
31+
task_output = await city_to_capital(task_input)
32+
33+
rprint(task_output)
34+
35+
aiorun(_inner())
36+
2537

26-
version: VersionReference = 4
38+
if __name__ == "__main__":
39+
typer.run(main)

examples/run_task.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

poetry.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "workflowai"
3-
version = "0.4.2"
3+
version = "0.5.0a0"
44
description = ""
55
authors = ["Guillaume Aquilina <[email protected]>"]
66
readme = "README.md"
@@ -13,7 +13,7 @@ httpx = "^0.27.0"
1313

1414

1515
[tool.poetry.group.dev.dependencies]
16-
pyright = "^1.1.389"
16+
pyright = "^1.1.390"
1717
pytest = "^8.2.2"
1818
pytest-asyncio = "^0.24.0"
1919
ruff = "^0.7.4"

tests/e2e/conftest.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
import pytest
44
from dotenv import load_dotenv
55

6-
from workflowai import Client, start
6+
from workflowai import Client
7+
from workflowai.core.client._client import WorkflowAIClient
78

89
load_dotenv()
910

1011

1112
@pytest.fixture(scope="session")
1213
def wai() -> Client:
13-
return start(
14-
url=os.environ["WORKFLOWAI_TEST_API_URL"],
14+
return WorkflowAIClient(
15+
endpoint=os.environ["WORKFLOWAI_TEST_API_URL"],
1516
api_key=os.environ["WORKFLOWAI_TEST_API_KEY"],
1617
)

tests/e2e/run_test.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from enum import Enum
2-
from typing import Optional
2+
from typing import AsyncIterator, Optional
33

44
from pydantic import BaseModel
55

@@ -23,6 +23,12 @@ class ExtractProductReviewSentimentTaskOutput(BaseModel):
2323
sentiment: Optional[Sentiment] = None
2424

2525

26+
@workflowai.task(schema_id=1)
27+
def extract_product_review_sentiment(
28+
task_input: ExtractProductReviewSentimentTaskInput,
29+
) -> AsyncIterator[ExtractProductReviewSentimentTaskOutput]: ...
30+
31+
2632
class ExtractProductReviewSentimentTask(
2733
Task[ExtractProductReviewSentimentTaskInput, ExtractProductReviewSentimentTaskOutput],
2834
):

tests/integration/__init__.py

Whitespace-only changes.

tests/integration/run_test.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import json
2+
from typing import Any, AsyncIterator, Optional
3+
4+
from httpx import Request
5+
from pydantic import BaseModel
6+
from pytest_httpx import HTTPXMock, IteratorStream
7+
8+
import workflowai
9+
from workflowai.core.domain.task_run import Run
10+
11+
12+
class CityToCapitalTaskInput(BaseModel):
13+
city: str
14+
15+
16+
class CityToCapitalTaskOutput(BaseModel):
17+
capital: str
18+
19+
20+
workflowai.init(api_key="test", url="http://localhost:8000")
21+
22+
23+
def _mock_response(httpx_mock: HTTPXMock, task_id: str = "city-to-capital"):
24+
httpx_mock.add_response(
25+
method="POST",
26+
url=f"http://localhost:8000/v1/_/tasks/{task_id}/schemas/1/run",
27+
json={"id": "123", "task_output": {"capital": "Tokyo"}},
28+
)
29+
30+
31+
def _mock_stream(httpx_mock: HTTPXMock, task_id: str = "city-to-capital"):
32+
httpx_mock.add_response(
33+
url=f"http://localhost:8000/v1/_/tasks/{task_id}/schemas/1/run",
34+
stream=IteratorStream(
35+
[
36+
b'data: {"id":"1","task_output":{"capital":""}}\n\n',
37+
b'data: {"id":"1","task_output":{"capital":"Tok"}}\n\ndata: {"id":"1","task_output":{"capital":"Tokyo"}}\n\n', # noqa: E501
38+
b'data: {"id":"1","task_output":{"capital":"Tokyo"},"cost_usd":0.01,"duration_seconds":10.1}\n\n',
39+
],
40+
),
41+
)
42+
43+
44+
def _check_request(request: Optional[Request], version: Any = "production", task_id: str = "city-to-capital"):
45+
assert request is not None
46+
assert request.url == f"http://localhost:8000/v1/_/tasks/{task_id}/schemas/1/run"
47+
body = json.loads(request.content)
48+
assert body == {
49+
"task_input": {"city": "Hello"},
50+
"version": version,
51+
"stream": False,
52+
}
53+
assert request.headers["Authorization"] == "Bearer test"
54+
assert request.headers["Content-Type"] == "application/json"
55+
assert request.headers["x-workflowai-source"] == "sdk"
56+
assert request.headers["x-workflowai-language"] == "python"
57+
58+
59+
async def test_run_task(httpx_mock: HTTPXMock) -> None:
60+
@workflowai.task(schema_id=1)
61+
async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTaskOutput: ...
62+
63+
_mock_response(httpx_mock)
64+
65+
task_input = CityToCapitalTaskInput(city="Hello")
66+
task_output = await city_to_capital(task_input)
67+
68+
assert task_output.capital == "Tokyo"
69+
70+
_check_request(httpx_mock.get_request())
71+
72+
73+
async def test_run_task_run(httpx_mock: HTTPXMock) -> None:
74+
@workflowai.task(schema_id=1)
75+
async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapitalTaskOutput]: ...
76+
77+
_mock_response(httpx_mock)
78+
79+
task_input = CityToCapitalTaskInput(city="Hello")
80+
with_run = await city_to_capital(task_input)
81+
82+
assert with_run.id == "123"
83+
assert with_run.task_output.capital == "Tokyo"
84+
85+
_check_request(httpx_mock.get_request())
86+
87+
88+
async def test_run_task_run_version(httpx_mock: HTTPXMock) -> None:
89+
@workflowai.task(schema_id=1, version="staging")
90+
async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapitalTaskOutput]: ...
91+
92+
_mock_response(httpx_mock)
93+
94+
task_input = CityToCapitalTaskInput(city="Hello")
95+
with_run = await city_to_capital(task_input)
96+
97+
assert with_run.id == "123"
98+
assert with_run.task_output.capital == "Tokyo"
99+
100+
_check_request(httpx_mock.get_request(), version="staging")
101+
102+
103+
async def test_stream_task_run(httpx_mock: HTTPXMock) -> None:
104+
@workflowai.task(schema_id=1)
105+
def city_to_capital(task_input: CityToCapitalTaskInput) -> AsyncIterator[CityToCapitalTaskOutput]: ...
106+
107+
_mock_stream(httpx_mock)
108+
109+
task_input = CityToCapitalTaskInput(city="Hello")
110+
chunks = [chunk async for chunk in city_to_capital(task_input)]
111+
112+
assert chunks == [
113+
CityToCapitalTaskOutput(capital=""),
114+
CityToCapitalTaskOutput(capital="Tok"),
115+
CityToCapitalTaskOutput(capital="Tokyo"),
116+
CityToCapitalTaskOutput(capital="Tokyo"),
117+
]
118+
119+
120+
async def test_stream_task_run_custom_id(httpx_mock: HTTPXMock) -> None:
121+
@workflowai.task(schema_id=1, task_id="custom-id")
122+
def city_to_capital(task_input: CityToCapitalTaskInput) -> AsyncIterator[CityToCapitalTaskOutput]: ...
123+
124+
_mock_stream(httpx_mock, task_id="custom-id")
125+
126+
task_input = CityToCapitalTaskInput(city="Hello")
127+
chunks = [chunk async for chunk in city_to_capital(task_input)]
128+
129+
assert chunks == [
130+
CityToCapitalTaskOutput(capital=""),
131+
CityToCapitalTaskOutput(capital="Tok"),
132+
CityToCapitalTaskOutput(capital="Tokyo"),
133+
CityToCapitalTaskOutput(capital="Tokyo"),
134+
]

workflowai/__init__.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,27 @@
1+
import os
12
from typing import Optional
23

34
from workflowai.core.client import Client as Client
5+
from workflowai.core.client._client import DEFAULT_VERSION_REFERENCE
6+
from workflowai.core.client._client import WorkflowAIClient as WorkflowAIClient
7+
from workflowai.core.client._types import TaskDecorator
48
from workflowai.core.domain.cache_usage import CacheUsage as CacheUsage
59
from workflowai.core.domain.errors import WorkflowAIError as WorkflowAIError
610
from workflowai.core.domain.task import Task as Task
11+
from workflowai.core.domain.task_run import Run as Run
712
from workflowai.core.domain.task_version import TaskVersion as TaskVersion
813
from workflowai.core.domain.task_version_reference import (
914
VersionReference as VersionReference,
1015
)
1116

17+
# By default the shared client is created using the default environment variables
18+
_shared_client = WorkflowAIClient(
19+
endpoint=os.getenv("WORKFLOWAI_API_URL"),
20+
api_key=os.getenv("WORKFLOWAI_API_KEY", ""),
21+
)
22+
1223

13-
def start(url: Optional[str] = None, api_key: Optional[str] = None) -> Client:
24+
def init(api_key: str, url: Optional[str] = None):
1425
"""Create a new workflowai client
1526
1627
Args:
@@ -21,6 +32,14 @@ def start(url: Optional[str] = None, api_key: Optional[str] = None) -> Client:
2132
Returns:
2233
client.Client: a client instance
2334
"""
24-
from workflowai.core.client.client import WorkflowAIClient
2535

26-
return WorkflowAIClient(url, api_key)
36+
global _shared_client # noqa: PLW0603
37+
_shared_client = WorkflowAIClient(endpoint=url, api_key=api_key)
38+
39+
40+
def task(
41+
schema_id: int,
42+
task_id: Optional[str] = None,
43+
version: VersionReference = DEFAULT_VERSION_REFERENCE,
44+
) -> TaskDecorator:
45+
return _shared_client.task(schema_id, task_id, version)

workflowai/core/client/__init__.py

Lines changed: 1 addition & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,75 +1 @@
1-
from typing import Any, AsyncIterator, Literal, Optional, Protocol, Union, overload
2-
3-
from workflowai.core.domain.cache_usage import CacheUsage
4-
from workflowai.core.domain.task import Task, TaskInput, TaskOutput
5-
from workflowai.core.domain.task_run import Run, RunChunk
6-
from workflowai.core.domain.task_version_reference import VersionReference
7-
8-
9-
class Client(Protocol):
10-
"""A client to interact with the WorkflowAI API"""
11-
12-
@overload
13-
async def run(
14-
self,
15-
task: Task[TaskInput, TaskOutput],
16-
task_input: TaskInput,
17-
stream: Literal[False] = False,
18-
version: Optional[VersionReference] = None,
19-
use_cache: CacheUsage = "when_available",
20-
metadata: Optional[dict[str, Any]] = None,
21-
max_retry_delay: float = 60,
22-
max_retry_count: float = 1,
23-
) -> Run[TaskOutput]: ...
24-
25-
@overload
26-
async def run(
27-
self,
28-
task: Task[TaskInput, TaskOutput],
29-
task_input: TaskInput,
30-
stream: Literal[True] = True,
31-
version: Optional[VersionReference] = None,
32-
use_cache: CacheUsage = "when_available",
33-
metadata: Optional[dict[str, Any]] = None,
34-
max_retry_delay: float = 60,
35-
max_retry_count: float = 1,
36-
) -> AsyncIterator[Union[RunChunk[TaskOutput], Run[TaskOutput]]]: ...
37-
38-
async def run(
39-
self,
40-
task: Task[TaskInput, TaskOutput],
41-
task_input: TaskInput,
42-
stream: bool = False,
43-
version: Optional[VersionReference] = None,
44-
use_cache: CacheUsage = "when_available",
45-
metadata: Optional[dict[str, Any]] = None,
46-
max_retry_delay: float = 60,
47-
max_retry_count: float = 1,
48-
) -> Union[Run[TaskOutput], AsyncIterator[Union[RunChunk[TaskOutput], Run[TaskOutput]]]]:
49-
"""Run a task
50-
51-
Args:
52-
task (Task[TaskInput, TaskOutput]): the task to run
53-
task_input (TaskInput): the input to the task
54-
version (Optional[TaskVersionReference], optional): the version of the task to run. If not provided,
55-
the version defined in the task is used. Defaults to None.
56-
environment (Optional[str], optional): the environment to run the task in. If not provided, the environment
57-
defined in the task is used. Defaults to None.
58-
iteration (Optional[int], optional): the iteration of the task to run. If not provided, the iteration
59-
defined in the task is used. Defaults to None.
60-
stream (bool, optional): whether to stream the output. If True, the function returns an async iterator of
61-
partial output objects. Defaults to False.
62-
use_cache (CacheUsage, optional): how to use the cache. Defaults to "when_available".
63-
labels (Optional[set[str]], optional): a set of labels to attach to the run.
64-
Labels are indexed and searchable. Defaults to None.
65-
metadata (Optional[dict[str, Any]], optional): a dictionary of metadata to attach to the run.
66-
Defaults to None.
67-
retry_delay (int, optional): The initial delay between retries in milliseconds. Defaults to 5000.
68-
max_retry_delay (int, optional): The maximum delay between retries in milliseconds. Defaults to 60000.
69-
max_retry_count (int, optional): The maximum number of retry attempts. Defaults to 1.
70-
71-
Returns:
72-
Union[TaskRun[TaskInput, TaskOutput], AsyncIterator[TaskOutput]]: the task run object
73-
or an async iterator of output objects
74-
"""
75-
...
1+
from ._types import Client as Client

0 commit comments

Comments
 (0)