Task annotation #23

Merged · 6 commits · Dec 5, 2024
27 changes: 20 additions & 7 deletions examples/city_to_capital_task.py
@@ -1,6 +1,10 @@
from asyncio import run as aiorun

import typer
from pydantic import BaseModel, Field # pyright: ignore [reportUnknownVariableType]
from rich import print as rprint

from workflowai import Task, VersionReference
import workflowai


class CityToCapitalTaskInput(BaseModel):
@@ -17,10 +21,19 @@ class CityToCapitalTaskOutput(BaseModel):
    )


class CityToCapitalTask(Task[CityToCapitalTaskInput, CityToCapitalTaskOutput]):
    id: str = "citytocapital"
    schema_id: int = 1
    input_class: type[CityToCapitalTaskInput] = CityToCapitalTaskInput
    output_class: type[CityToCapitalTaskOutput] = CityToCapitalTaskOutput
@workflowai.task(schema_id=1)
async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTaskOutput: ...


def main(city: str) -> None:
    async def _inner() -> None:
        task_input = CityToCapitalTaskInput(city=city)
        task_output = await city_to_capital(task_input)

        rprint(task_output)

    aiorun(_inner())


    version: VersionReference = 4
if __name__ == "__main__":
    typer.run(main)
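
Beyond returning the output model directly, the same decorator (per the integration tests later in this diff) also accepts a stub annotated to return Run[CityToCapitalTaskOutput], which gives the caller the run id alongside the parsed output. A minimal sketch reusing the models above — not part of this diff:

from workflowai.core.domain.task_run import Run


@workflowai.task(schema_id=1)
async def city_to_capital_run(task_input: CityToCapitalTaskInput) -> Run[CityToCapitalTaskOutput]: ...


async def show_run_id(city: str) -> None:
    # Run exposes the server-side run id plus the parsed task_output.
    run = await city_to_capital_run(CityToCapitalTaskInput(city=city))
    rprint(run.id, run.task_output.capital)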
24 changes: 0 additions & 24 deletions examples/run_task.py

This file was deleted.

8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "workflowai"
version = "0.4.2"
version = "0.5.0a0"
description = ""
authors = ["Guillaume Aquilina <[email protected]>"]
readme = "README.md"
@@ -13,7 +13,7 @@ httpx = "^0.27.0"


[tool.poetry.group.dev.dependencies]
pyright = "^1.1.389"
pyright = "^1.1.390"
pytest = "^8.2.2"
pytest-asyncio = "^0.24.0"
ruff = "^0.7.4"
7 changes: 4 additions & 3 deletions tests/e2e/conftest.py
@@ -3,14 +3,15 @@
import pytest
from dotenv import load_dotenv

from workflowai import Client, start
from workflowai import Client
from workflowai.core.client._client import WorkflowAIClient

load_dotenv()


@pytest.fixture(scope="session")
def wai() -> Client:
    return start(
        url=os.environ["WORKFLOWAI_TEST_API_URL"],
    return WorkflowAIClient(
        endpoint=os.environ["WORKFLOWAI_TEST_API_URL"],
        api_key=os.environ["WORKFLOWAI_TEST_API_KEY"],
    )
8 changes: 7 additions & 1 deletion tests/e2e/run_test.py
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Optional
from typing import AsyncIterator, Optional

from pydantic import BaseModel

@@ -23,6 +23,12 @@ class ExtractProductReviewSentimentTaskOutput(BaseModel):
    sentiment: Optional[Sentiment] = None


@workflowai.task(schema_id=1)
def extract_product_review_sentiment(
    task_input: ExtractProductReviewSentimentTaskInput,
) -> AsyncIterator[ExtractProductReviewSentimentTaskOutput]: ...


class ExtractProductReviewSentimentTask(
    Task[ExtractProductReviewSentimentTaskInput, ExtractProductReviewSentimentTaskOutput],
):
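
Streaming stubs like extract_product_review_sentiment above are declared as plain def functions annotated to return an AsyncIterator, and the decorated function is consumed with async for. A minimal consumption sketch, not part of this diff; the review_text field name is an assumption, since the input model's fields are elided here:

async def print_review_sentiment() -> None:
    # review_text is a hypothetical field name used only for illustration.
    task_input = ExtractProductReviewSentimentTaskInput(review_text="Loved it, works great.")
    async for chunk in extract_product_review_sentiment(task_input):
        # Each chunk is a partial ExtractProductReviewSentimentTaskOutput.
        print(chunk.sentiment)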
Empty file added tests/integration/__init__.py
134 changes: 134 additions & 0 deletions tests/integration/run_test.py
@@ -0,0 +1,134 @@
import json
from typing import Any, AsyncIterator, Optional

from httpx import Request
from pydantic import BaseModel
from pytest_httpx import HTTPXMock, IteratorStream

import workflowai
from workflowai.core.domain.task_run import Run


class CityToCapitalTaskInput(BaseModel):
    city: str


class CityToCapitalTaskOutput(BaseModel):
    capital: str


workflowai.init(api_key="test", url="http://localhost:8000")


def _mock_response(httpx_mock: HTTPXMock, task_id: str = "city-to-capital"):
    httpx_mock.add_response(
        method="POST",
        url=f"http://localhost:8000/v1/_/tasks/{task_id}/schemas/1/run",
        json={"id": "123", "task_output": {"capital": "Tokyo"}},
    )


def _mock_stream(httpx_mock: HTTPXMock, task_id: str = "city-to-capital"):
    httpx_mock.add_response(
        url=f"http://localhost:8000/v1/_/tasks/{task_id}/schemas/1/run",
        stream=IteratorStream(
            [
                b'data: {"id":"1","task_output":{"capital":""}}\n\n',
                b'data: {"id":"1","task_output":{"capital":"Tok"}}\n\ndata: {"id":"1","task_output":{"capital":"Tokyo"}}\n\n',  # noqa: E501
                b'data: {"id":"1","task_output":{"capital":"Tokyo"},"cost_usd":0.01,"duration_seconds":10.1}\n\n',
            ],
        ),
    )


def _check_request(request: Optional[Request], version: Any = "production", task_id: str = "city-to-capital"):
    assert request is not None
    assert request.url == f"http://localhost:8000/v1/_/tasks/{task_id}/schemas/1/run"
    body = json.loads(request.content)
    assert body == {
        "task_input": {"city": "Hello"},
        "version": version,
        "stream": False,
    }
    assert request.headers["Authorization"] == "Bearer test"
    assert request.headers["Content-Type"] == "application/json"
    assert request.headers["x-workflowai-source"] == "sdk"
    assert request.headers["x-workflowai-language"] == "python"


async def test_run_task(httpx_mock: HTTPXMock) -> None:
    @workflowai.task(schema_id=1)
    async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTaskOutput: ...

    _mock_response(httpx_mock)

    task_input = CityToCapitalTaskInput(city="Hello")
    task_output = await city_to_capital(task_input)

    assert task_output.capital == "Tokyo"

    _check_request(httpx_mock.get_request())


async def test_run_task_run(httpx_mock: HTTPXMock) -> None:
    @workflowai.task(schema_id=1)
    async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapitalTaskOutput]: ...

    _mock_response(httpx_mock)

    task_input = CityToCapitalTaskInput(city="Hello")
    with_run = await city_to_capital(task_input)

    assert with_run.id == "123"
    assert with_run.task_output.capital == "Tokyo"

    _check_request(httpx_mock.get_request())


async def test_run_task_run_version(httpx_mock: HTTPXMock) -> None:
    @workflowai.task(schema_id=1, version="staging")
    async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapitalTaskOutput]: ...

    _mock_response(httpx_mock)

    task_input = CityToCapitalTaskInput(city="Hello")
    with_run = await city_to_capital(task_input)

    assert with_run.id == "123"
    assert with_run.task_output.capital == "Tokyo"

    _check_request(httpx_mock.get_request(), version="staging")


async def test_stream_task_run(httpx_mock: HTTPXMock) -> None:
    @workflowai.task(schema_id=1)
    def city_to_capital(task_input: CityToCapitalTaskInput) -> AsyncIterator[CityToCapitalTaskOutput]: ...

    _mock_stream(httpx_mock)

    task_input = CityToCapitalTaskInput(city="Hello")
    chunks = [chunk async for chunk in city_to_capital(task_input)]

    assert chunks == [
        CityToCapitalTaskOutput(capital=""),
        CityToCapitalTaskOutput(capital="Tok"),
        CityToCapitalTaskOutput(capital="Tokyo"),
        CityToCapitalTaskOutput(capital="Tokyo"),
    ]


async def test_stream_task_run_custom_id(httpx_mock: HTTPXMock) -> None:
    @workflowai.task(schema_id=1, task_id="custom-id")
    def city_to_capital(task_input: CityToCapitalTaskInput) -> AsyncIterator[CityToCapitalTaskOutput]: ...

    _mock_stream(httpx_mock, task_id="custom-id")

    task_input = CityToCapitalTaskInput(city="Hello")
    chunks = [chunk async for chunk in city_to_capital(task_input)]

    assert chunks == [
        CityToCapitalTaskOutput(capital=""),
        CityToCapitalTaskOutput(capital="Tok"),
        CityToCapitalTaskOutput(capital="Tokyo"),
        CityToCapitalTaskOutput(capital="Tokyo"),
    ]
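
For readers tracing the streaming assertions above: the mocked stream deliberately packs two data: events into one network chunk, and the four expected chunks come from splitting on the blank-line terminator. A standalone sketch of that framing (an illustration only, not the SDK's parser):

import json

raw_chunks = [
    b'data: {"id":"1","task_output":{"capital":""}}\n\n',
    b'data: {"id":"1","task_output":{"capital":"Tok"}}\n\ndata: {"id":"1","task_output":{"capital":"Tokyo"}}\n\n',
    b'data: {"id":"1","task_output":{"capital":"Tokyo"},"cost_usd":0.01,"duration_seconds":10.1}\n\n',
]

events = []
for chunk in raw_chunks:
    # One HTTP chunk may carry several SSE events; split on the blank-line
    # terminator and strip the "data: " prefix before decoding the JSON payload.
    for frame in chunk.decode().split("\n\n"):
        if frame.startswith("data: "):
            events.append(json.loads(frame[len("data: "):]))

assert [e["task_output"]["capital"] for e in events] == ["", "Tok", "Tokyo", "Tokyo"]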
25 changes: 22 additions & 3 deletions workflowai/__init__.py
@@ -1,16 +1,27 @@
import os
from typing import Optional

from workflowai.core.client import Client as Client
from workflowai.core.client._client import DEFAULT_VERSION_REFERENCE
from workflowai.core.client._client import WorkflowAIClient as WorkflowAIClient
from workflowai.core.client._types import TaskDecorator
from workflowai.core.domain.cache_usage import CacheUsage as CacheUsage
from workflowai.core.domain.errors import WorkflowAIError as WorkflowAIError
from workflowai.core.domain.task import Task as Task
from workflowai.core.domain.task_run import Run as Run
from workflowai.core.domain.task_version import TaskVersion as TaskVersion
from workflowai.core.domain.task_version_reference import (
    VersionReference as VersionReference,
)

# By default the shared client is created using the default environment variables
_shared_client = WorkflowAIClient(
    endpoint=os.getenv("WORKFLOWAI_API_URL"),
    api_key=os.getenv("WORKFLOWAI_API_KEY", ""),
)


def start(url: Optional[str] = None, api_key: Optional[str] = None) -> Client:
def init(api_key: str, url: Optional[str] = None):
    """Create a new workflowai client

    Args:
@@ -21,6 +32,14 @@ def start(url: Optional[str] = None, api_key: Optional[str] = None) -> Client:
    Returns:
        client.Client: a client instance
    """
    from workflowai.core.client.client import WorkflowAIClient

    return WorkflowAIClient(url, api_key)
    global _shared_client  # noqa: PLW0603
    _shared_client = WorkflowAIClient(endpoint=url, api_key=api_key)


def task(
    schema_id: int,
    task_id: Optional[str] = None,
    version: VersionReference = DEFAULT_VERSION_REFERENCE,
) -> TaskDecorator:
    return _shared_client.task(schema_id, task_id, version)
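
Taken together, the new module-level API is: call workflowai.init(...) once (or rely on the WORKFLOWAI_API_URL and WORKFLOWAI_API_KEY environment variables picked up when the package is imported), then declare task stubs with @workflowai.task and await them. A minimal sketch with placeholder credentials and a hypothetical task id and schema — not part of this diff:

import asyncio

from pydantic import BaseModel

import workflowai


class GreetTaskInput(BaseModel):
    name: str


class GreetTaskOutput(BaseModel):
    greeting: str


# Placeholder credentials and URL; real values come from your WorkflowAI account.
workflowai.init(api_key="wai-***", url="https://example.invalid")


@workflowai.task(schema_id=1, task_id="greet")  # hypothetical task id and schema id
async def greet(task_input: GreetTaskInput) -> GreetTaskOutput: ...


# The stub's body is never executed: the decorator wires the call to the shared
# client, so callers simply await the function (commented out to avoid a real call).
# asyncio.run(greet(GreetTaskInput(name="Ada")))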
76 changes: 1 addition & 75 deletions workflowai/core/client/__init__.py
@@ -1,75 +1 @@
from typing import Any, AsyncIterator, Literal, Optional, Protocol, Union, overload

from workflowai.core.domain.cache_usage import CacheUsage
from workflowai.core.domain.task import Task, TaskInput, TaskOutput
from workflowai.core.domain.task_run import Run, RunChunk
from workflowai.core.domain.task_version_reference import VersionReference


class Client(Protocol):
    """A client to interact with the WorkflowAI API"""

    @overload
    async def run(
        self,
        task: Task[TaskInput, TaskOutput],
        task_input: TaskInput,
        stream: Literal[False] = False,
        version: Optional[VersionReference] = None,
        use_cache: CacheUsage = "when_available",
        metadata: Optional[dict[str, Any]] = None,
        max_retry_delay: float = 60,
        max_retry_count: float = 1,
    ) -> Run[TaskOutput]: ...

    @overload
    async def run(
        self,
        task: Task[TaskInput, TaskOutput],
        task_input: TaskInput,
        stream: Literal[True] = True,
        version: Optional[VersionReference] = None,
        use_cache: CacheUsage = "when_available",
        metadata: Optional[dict[str, Any]] = None,
        max_retry_delay: float = 60,
        max_retry_count: float = 1,
    ) -> AsyncIterator[Union[RunChunk[TaskOutput], Run[TaskOutput]]]: ...

    async def run(
        self,
        task: Task[TaskInput, TaskOutput],
        task_input: TaskInput,
        stream: bool = False,
        version: Optional[VersionReference] = None,
        use_cache: CacheUsage = "when_available",
        metadata: Optional[dict[str, Any]] = None,
        max_retry_delay: float = 60,
        max_retry_count: float = 1,
    ) -> Union[Run[TaskOutput], AsyncIterator[Union[RunChunk[TaskOutput], Run[TaskOutput]]]]:
        """Run a task

        Args:
            task (Task[TaskInput, TaskOutput]): the task to run
            task_input (TaskInput): the input to the task
            version (Optional[TaskVersionReference], optional): the version of the task to run. If not provided,
                the version defined in the task is used. Defaults to None.
            environment (Optional[str], optional): the environment to run the task in. If not provided, the environment
                defined in the task is used. Defaults to None.
            iteration (Optional[int], optional): the iteration of the task to run. If not provided, the iteration
                defined in the task is used. Defaults to None.
            stream (bool, optional): whether to stream the output. If True, the function returns an async iterator of
                partial output objects. Defaults to False.
            use_cache (CacheUsage, optional): how to use the cache. Defaults to "when_available".
            labels (Optional[set[str]], optional): a set of labels to attach to the run.
                Labels are indexed and searchable. Defaults to None.
            metadata (Optional[dict[str, Any]], optional): a dictionary of metadata to attach to the run.
                Defaults to None.
            retry_delay (int, optional): The initial delay between retries in milliseconds. Defaults to 5000.
            max_retry_delay (int, optional): The maximum delay between retries in milliseconds. Defaults to 60000.
            max_retry_count (int, optional): The maximum number of retry attempts. Defaults to 1.

        Returns:
            Union[TaskRun[TaskInput, TaskOutput], AsyncIterator[TaskOutput]]: the task run object
            or an async iterator of output objects
        """
        ...
from ._types import Client as Client