diff --git a/examples/city_to_capital_task.py b/examples/city_to_capital_task.py index 45880a0..e0045dd 100644 --- a/examples/city_to_capital_task.py +++ b/examples/city_to_capital_task.py @@ -1,6 +1,10 @@ +from asyncio import run as aiorun + +import typer from pydantic import BaseModel, Field # pyright: ignore [reportUnknownVariableType] +from rich import print as rprint -from workflowai import Task, VersionReference +import workflowai class CityToCapitalTaskInput(BaseModel): @@ -17,10 +21,19 @@ class CityToCapitalTaskOutput(BaseModel): ) -class CityToCapitalTask(Task[CityToCapitalTaskInput, CityToCapitalTaskOutput]): - id: str = "citytocapital" - schema_id: int = 1 - input_class: type[CityToCapitalTaskInput] = CityToCapitalTaskInput - output_class: type[CityToCapitalTaskOutput] = CityToCapitalTaskOutput +@workflowai.task(schema_id=1) +async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTaskOutput: ... + + +def main(city: str) -> None: + async def _inner() -> None: + task_input = CityToCapitalTaskInput(city=city) + task_output = await city_to_capital(task_input) + + rprint(task_output) + + aiorun(_inner()) + - version: VersionReference = 4 +if __name__ == "__main__": + typer.run(main) diff --git a/examples/run_task.py b/examples/run_task.py deleted file mode 100644 index 8d2df90..0000000 --- a/examples/run_task.py +++ /dev/null @@ -1,24 +0,0 @@ -from asyncio import run as aiorun - -import typer -from rich import print as rprint - -import workflowai -from examples.city_to_capital_task import CityToCapitalTask, CityToCapitalTaskInput - - -def main(city: str) -> None: - client = workflowai.start() - task = CityToCapitalTask() - - async def _inner() -> None: - task_input = CityToCapitalTaskInput(city=city) - task_run = await client.run(task, task_input) - - rprint(task_run.task_output) - - aiorun(_inner()) - - -if __name__ == "__main__": - typer.run(main) diff --git a/poetry.lock b/poetry.lock index 564cdb1..d9fb309 100644 --- a/poetry.lock 
+++ b/poetry.lock @@ -485,13 +485,13 @@ windows-terminal = ["colorama (>=0.4.6)"] [[package]] name = "pyright" -version = "1.1.389" +version = "1.1.390" description = "Command line wrapper for pyright" optional = false python-versions = ">=3.7" files = [ - {file = "pyright-1.1.389-py3-none-any.whl", hash = "sha256:41e9620bba9254406dc1f621a88ceab5a88af4c826feb4f614d95691ed243a60"}, - {file = "pyright-1.1.389.tar.gz", hash = "sha256:716bf8cc174ab8b4dcf6828c3298cac05c5ed775dda9910106a5dcfe4c7fe220"}, + {file = "pyright-1.1.390-py3-none-any.whl", hash = "sha256:ecebfba5b6b50af7c1a44c2ba144ba2ab542c227eb49bc1f16984ff714e0e110"}, + {file = "pyright-1.1.390.tar.gz", hash = "sha256:aad7f160c49e0fbf8209507a15e17b781f63a86a1facb69ca877c71ef2e9538d"}, ] [package.dependencies] @@ -792,4 +792,4 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "81f3624451cc3d6628d5bc836a01afbc8b79bc64ff4700e824fa92a16a64ffb5" +content-hash = "443ebf66cd815e9cec42d1f95cf15810a564acbb10cebdccfff1846839a382aa" diff --git a/pyproject.toml b/pyproject.toml index a3d9039..ff27e4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "workflowai" -version = "0.4.2" +version = "0.5.0a0" description = "" authors = ["Guillaume Aquilina "] readme = "README.md" @@ -13,7 +13,7 @@ httpx = "^0.27.0" [tool.poetry.group.dev.dependencies] -pyright = "^1.1.389" +pyright = "^1.1.390" pytest = "^8.2.2" pytest-asyncio = "^0.24.0" ruff = "^0.7.4" diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 24e60cb..69c109a 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -3,14 +3,15 @@ import pytest from dotenv import load_dotenv -from workflowai import Client, start +from workflowai import Client +from workflowai.core.client._client import WorkflowAIClient load_dotenv() @pytest.fixture(scope="session") def wai() -> Client: - return start( - 
url=os.environ["WORKFLOWAI_TEST_API_URL"], + return WorkflowAIClient( + endpoint=os.environ["WORKFLOWAI_TEST_API_URL"], api_key=os.environ["WORKFLOWAI_TEST_API_KEY"], ) diff --git a/tests/e2e/run_test.py b/tests/e2e/run_test.py index 90b2002..c623def 100644 --- a/tests/e2e/run_test.py +++ b/tests/e2e/run_test.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Optional +from typing import AsyncIterator, Optional from pydantic import BaseModel @@ -23,6 +23,12 @@ class ExtractProductReviewSentimentTaskOutput(BaseModel): sentiment: Optional[Sentiment] = None +@workflowai.task(schema_id=1) +def extract_product_review_sentiment( + task_input: ExtractProductReviewSentimentTaskInput, +) -> AsyncIterator[ExtractProductReviewSentimentTaskOutput]: ... + + class ExtractProductReviewSentimentTask( Task[ExtractProductReviewSentimentTaskInput, ExtractProductReviewSentimentTaskOutput], ): diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/run_test.py b/tests/integration/run_test.py new file mode 100644 index 0000000..dc3ab92 --- /dev/null +++ b/tests/integration/run_test.py @@ -0,0 +1,134 @@ +import json +from typing import Any, AsyncIterator, Optional + +from httpx import Request +from pydantic import BaseModel +from pytest_httpx import HTTPXMock, IteratorStream + +import workflowai +from workflowai.core.domain.task_run import Run + + +class CityToCapitalTaskInput(BaseModel): + city: str + + +class CityToCapitalTaskOutput(BaseModel): + capital: str + + +workflowai.init(api_key="test", url="http://localhost:8000") + + +def _mock_response(httpx_mock: HTTPXMock, task_id: str = "city-to-capital"): + httpx_mock.add_response( + method="POST", + url=f"http://localhost:8000/v1/_/tasks/{task_id}/schemas/1/run", + json={"id": "123", "task_output": {"capital": "Tokyo"}}, + ) + + +def _mock_stream(httpx_mock: HTTPXMock, task_id: str = "city-to-capital"): + httpx_mock.add_response( + 
url=f"http://localhost:8000/v1/_/tasks/{task_id}/schemas/1/run", + stream=IteratorStream( + [ + b'data: {"id":"1","task_output":{"capital":""}}\n\n', + b'data: {"id":"1","task_output":{"capital":"Tok"}}\n\ndata: {"id":"1","task_output":{"capital":"Tokyo"}}\n\n', # noqa: E501 + b'data: {"id":"1","task_output":{"capital":"Tokyo"},"cost_usd":0.01,"duration_seconds":10.1}\n\n', + ], + ), + ) + + +def _check_request(request: Optional[Request], version: Any = "production", task_id: str = "city-to-capital"): + assert request is not None + assert request.url == f"http://localhost:8000/v1/_/tasks/{task_id}/schemas/1/run" + body = json.loads(request.content) + assert body == { + "task_input": {"city": "Hello"}, + "version": version, + "stream": False, + } + assert request.headers["Authorization"] == "Bearer test" + assert request.headers["Content-Type"] == "application/json" + assert request.headers["x-workflowai-source"] == "sdk" + assert request.headers["x-workflowai-language"] == "python" + + +async def test_run_task(httpx_mock: HTTPXMock) -> None: + @workflowai.task(schema_id=1) + async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTaskOutput: ... + + _mock_response(httpx_mock) + + task_input = CityToCapitalTaskInput(city="Hello") + task_output = await city_to_capital(task_input) + + assert task_output.capital == "Tokyo" + + _check_request(httpx_mock.get_request()) + + +async def test_run_task_run(httpx_mock: HTTPXMock) -> None: + @workflowai.task(schema_id=1) + async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapitalTaskOutput]: ... 
+ + _mock_response(httpx_mock) + + task_input = CityToCapitalTaskInput(city="Hello") + with_run = await city_to_capital(task_input) + + assert with_run.id == "123" + assert with_run.task_output.capital == "Tokyo" + + _check_request(httpx_mock.get_request()) + + +async def test_run_task_run_version(httpx_mock: HTTPXMock) -> None: + @workflowai.task(schema_id=1, version="staging") + async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapitalTaskOutput]: ... + + _mock_response(httpx_mock) + + task_input = CityToCapitalTaskInput(city="Hello") + with_run = await city_to_capital(task_input) + + assert with_run.id == "123" + assert with_run.task_output.capital == "Tokyo" + + _check_request(httpx_mock.get_request(), version="staging") + + +async def test_stream_task_run(httpx_mock: HTTPXMock) -> None: + @workflowai.task(schema_id=1) + def city_to_capital(task_input: CityToCapitalTaskInput) -> AsyncIterator[CityToCapitalTaskOutput]: ... + + _mock_stream(httpx_mock) + + task_input = CityToCapitalTaskInput(city="Hello") + chunks = [chunk async for chunk in city_to_capital(task_input)] + + assert chunks == [ + CityToCapitalTaskOutput(capital=""), + CityToCapitalTaskOutput(capital="Tok"), + CityToCapitalTaskOutput(capital="Tokyo"), + CityToCapitalTaskOutput(capital="Tokyo"), + ] + + +async def test_stream_task_run_custom_id(httpx_mock: HTTPXMock) -> None: + @workflowai.task(schema_id=1, task_id="custom-id") + def city_to_capital(task_input: CityToCapitalTaskInput) -> AsyncIterator[CityToCapitalTaskOutput]: ... 
+ + _mock_stream(httpx_mock, task_id="custom-id") + + task_input = CityToCapitalTaskInput(city="Hello") + chunks = [chunk async for chunk in city_to_capital(task_input)] + + assert chunks == [ + CityToCapitalTaskOutput(capital=""), + CityToCapitalTaskOutput(capital="Tok"), + CityToCapitalTaskOutput(capital="Tokyo"), + CityToCapitalTaskOutput(capital="Tokyo"), + ] diff --git a/workflowai/__init__.py b/workflowai/__init__.py index 9d96059..ad21e64 100644 --- a/workflowai/__init__.py +++ b/workflowai/__init__.py @@ -1,16 +1,27 @@ +import os from typing import Optional from workflowai.core.client import Client as Client +from workflowai.core.client._client import DEFAULT_VERSION_REFERENCE +from workflowai.core.client._client import WorkflowAIClient as WorkflowAIClient +from workflowai.core.client._types import TaskDecorator from workflowai.core.domain.cache_usage import CacheUsage as CacheUsage from workflowai.core.domain.errors import WorkflowAIError as WorkflowAIError from workflowai.core.domain.task import Task as Task +from workflowai.core.domain.task_run import Run as Run from workflowai.core.domain.task_version import TaskVersion as TaskVersion from workflowai.core.domain.task_version_reference import ( VersionReference as VersionReference, ) +# By default the shared client is created using the default environment variables +_shared_client = WorkflowAIClient( + endpoint=os.getenv("WORKFLOWAI_API_URL"), + api_key=os.getenv("WORKFLOWAI_API_KEY", ""), +) + -def start(url: Optional[str] = None, api_key: Optional[str] = None) -> Client: +def init(api_key: str, url: Optional[str] = None): """Create a new workflowai client Args: @@ -21,6 +32,14 @@ def start(url: Optional[str] = None, api_key: Optional[str] = None) -> Client: Returns: client.Client: a client instance """ - from workflowai.core.client.client import WorkflowAIClient - return WorkflowAIClient(url, api_key) + global _shared_client # noqa: PLW0603 + _shared_client = WorkflowAIClient(endpoint=url, 
api_key=api_key) + + +def task( + schema_id: int, + task_id: Optional[str] = None, + version: VersionReference = DEFAULT_VERSION_REFERENCE, +) -> TaskDecorator: + return _shared_client.task(schema_id, task_id, version) diff --git a/workflowai/core/client/__init__.py b/workflowai/core/client/__init__.py index 9865f9f..50758d4 100644 --- a/workflowai/core/client/__init__.py +++ b/workflowai/core/client/__init__.py @@ -1,75 +1 @@ -from typing import Any, AsyncIterator, Literal, Optional, Protocol, Union, overload - -from workflowai.core.domain.cache_usage import CacheUsage -from workflowai.core.domain.task import Task, TaskInput, TaskOutput -from workflowai.core.domain.task_run import Run, RunChunk -from workflowai.core.domain.task_version_reference import VersionReference - - -class Client(Protocol): - """A client to interact with the WorkflowAI API""" - - @overload - async def run( - self, - task: Task[TaskInput, TaskOutput], - task_input: TaskInput, - stream: Literal[False] = False, - version: Optional[VersionReference] = None, - use_cache: CacheUsage = "when_available", - metadata: Optional[dict[str, Any]] = None, - max_retry_delay: float = 60, - max_retry_count: float = 1, - ) -> Run[TaskOutput]: ... - - @overload - async def run( - self, - task: Task[TaskInput, TaskOutput], - task_input: TaskInput, - stream: Literal[True] = True, - version: Optional[VersionReference] = None, - use_cache: CacheUsage = "when_available", - metadata: Optional[dict[str, Any]] = None, - max_retry_delay: float = 60, - max_retry_count: float = 1, - ) -> AsyncIterator[Union[RunChunk[TaskOutput], Run[TaskOutput]]]: ... 
- - async def run( - self, - task: Task[TaskInput, TaskOutput], - task_input: TaskInput, - stream: bool = False, - version: Optional[VersionReference] = None, - use_cache: CacheUsage = "when_available", - metadata: Optional[dict[str, Any]] = None, - max_retry_delay: float = 60, - max_retry_count: float = 1, - ) -> Union[Run[TaskOutput], AsyncIterator[Union[RunChunk[TaskOutput], Run[TaskOutput]]]]: - """Run a task - - Args: - task (Task[TaskInput, TaskOutput]): the task to run - task_input (TaskInput): the input to the task - version (Optional[TaskVersionReference], optional): the version of the task to run. If not provided, - the version defined in the task is used. Defaults to None. - environment (Optional[str], optional): the environment to run the task in. If not provided, the environment - defined in the task is used. Defaults to None. - iteration (Optional[int], optional): the iteration of the task to run. If not provided, the iteration - defined in the task is used. Defaults to None. - stream (bool, optional): whether to stream the output. If True, the function returns an async iterator of - partial output objects. Defaults to False. - use_cache (CacheUsage, optional): how to use the cache. Defaults to "when_available". - labels (Optional[set[str]], optional): a set of labels to attach to the run. - Labels are indexed and searchable. Defaults to None. - metadata (Optional[dict[str, Any]], optional): a dictionary of metadata to attach to the run. - Defaults to None. - retry_delay (int, optional): The initial delay between retries in milliseconds. Defaults to 5000. - max_retry_delay (int, optional): The maximum delay between retries in milliseconds. Defaults to 60000. - max_retry_count (int, optional): The maximum number of retry attempts. Defaults to 1. - - Returns: - Union[TaskRun[TaskInput, TaskOutput], AsyncIterator[TaskOutput]]: the task run object - or an async iterator of output objects - """ - ... 
+from ._types import Client as Client diff --git a/workflowai/core/client/api.py b/workflowai/core/client/_api.py similarity index 96% rename from workflowai/core/client/api.py rename to workflowai/core/client/_api.py index 5416387..1634636 100644 --- a/workflowai/core/client/api.py +++ b/workflowai/core/client/_api.py @@ -3,7 +3,7 @@ import httpx from pydantic import BaseModel, TypeAdapter, ValidationError -from workflowai.core.client.utils import split_chunks +from workflowai.core.client._utils import split_chunks from workflowai.core.domain.errors import BaseError, ErrorResponse, WorkflowAIError # A type for return values @@ -17,9 +17,6 @@ def __init__(self, endpoint: str, api_key: str, source_headers: Optional[dict[st self.api_key = api_key self.source_headers = source_headers or {} - if not self.endpoint or not self.api_key: - raise ValueError("Missing API URL or key") - def _client(self) -> httpx.AsyncClient: source_headers = self.source_headers or {} client = httpx.AsyncClient( diff --git a/workflowai/core/client/api_test.py b/workflowai/core/client/_api_test.py similarity index 98% rename from workflowai/core/client/api_test.py rename to workflowai/core/client/_api_test.py index b346369..4b57eb0 100644 --- a/workflowai/core/client/api_test.py +++ b/workflowai/core/client/_api_test.py @@ -3,7 +3,7 @@ from pydantic import BaseModel from pytest_httpx import HTTPXMock -from workflowai.core.client.api import APIClient +from workflowai.core.client._api import APIClient from workflowai.core.domain.errors import WorkflowAIError diff --git a/workflowai/core/client/client.py b/workflowai/core/client/_client.py similarity index 53% rename from workflowai/core/client/client.py rename to workflowai/core/client/_client.py index 4bdb912..81d158c 100644 --- a/workflowai/core/client/client.py +++ b/workflowai/core/client/_client.py @@ -1,8 +1,9 @@ +import functools import importlib.metadata +import logging import os from collections.abc import Awaitable, Callable from 
typing import ( - Any, AsyncIterator, Literal, Optional, @@ -10,33 +11,60 @@ overload, ) +from typing_extensions import Unpack + from workflowai.core.client import Client -from workflowai.core.client.api import APIClient -from workflowai.core.client.models import ( +from workflowai.core.client._api import APIClient +from workflowai.core.client._fn_utils import task_id_from_fn_name, wrap_run_template +from workflowai.core.client._models import ( RunRequest, RunResponse, - RunStreamChunk, ) -from workflowai.core.client.utils import build_retryable_wait -from workflowai.core.domain.cache_usage import CacheUsage +from workflowai.core.client._types import ( + FinalRunTemplate, + OutputValidator, + RunParams, + RunTemplate, + TaskDecorator, +) +from workflowai.core.client._utils import build_retryable_wait, tolerant_validator from workflowai.core.domain.errors import BaseError, WorkflowAIError from workflowai.core.domain.task import Task, TaskInput, TaskOutput -from workflowai.core.domain.task_run import Run, RunChunk +from workflowai.core.domain.task_run import Run from workflowai.core.domain.task_version_reference import VersionReference +_logger = logging.getLogger("WorkflowAI") + + +def _compute_default_version_reference() -> VersionReference: + version = os.getenv("WORKFLOWAI_DEFAULT_VERSION") + if not version: + return "production" + + if version in {"dev", "staging", "production"}: + return version # pyright: ignore [reportReturnType] + + try: + return int(version) + except ValueError: + pass + + _logger.warning("Invalid default version: %s", version) + + return "production" + + +DEFAULT_VERSION_REFERENCE = _compute_default_version_reference() + class WorkflowAIClient(Client): - def __init__(self, endpoint: Optional[str] = None, api_key: Optional[str] = None): + def __init__(self, api_key: str, endpoint: Optional[str] = None): self.additional_headers = { "x-workflowai-source": "sdk", "x-workflowai-language": "python", "x-workflowai-version": 
importlib.metadata.version("workflowai"), } - self.api = APIClient( - endpoint or os.getenv("WORKFLOWAI_API_URL", "https://run.workflowai.com"), - api_key or os.getenv("WORKFLOWAI_API_KEY", ""), - self.additional_headers, - ) + self.api = APIClient(endpoint or "https://run.workflowai.com", api_key, self.additional_headers) @overload async def run( @@ -44,11 +72,7 @@ async def run( task: Task[TaskInput, TaskOutput], task_input: TaskInput, stream: Literal[False] = False, - version: Optional[VersionReference] = None, - use_cache: CacheUsage = "when_available", - metadata: Optional[dict[str, Any]] = None, - max_retry_delay: float = 60, - max_retry_count: float = 1, + **kwargs: Unpack[RunParams[TaskOutput]], ) -> Run[TaskOutput]: ... @overload @@ -57,65 +81,60 @@ async def run( task: Task[TaskInput, TaskOutput], task_input: TaskInput, stream: Literal[True] = True, - version: Optional[VersionReference] = None, - use_cache: CacheUsage = "when_available", - metadata: Optional[dict[str, Any]] = None, - max_retry_delay: float = 60, - max_retry_count: float = 1, - ) -> AsyncIterator[Union[RunChunk[TaskOutput], Run[TaskOutput]]]: ... + **kwargs: Unpack[RunParams[TaskOutput]], + ) -> AsyncIterator[Run[TaskOutput]]: ... 
async def run( self, task: Task[TaskInput, TaskOutput], task_input: TaskInput, stream: bool = False, - version: Optional[VersionReference] = None, - use_cache: CacheUsage = "when_available", - metadata: Optional[dict[str, Any]] = None, - max_retry_delay: float = 60, - max_retry_count: float = 1, - ) -> Union[Run[TaskOutput], AsyncIterator[Union[RunChunk[TaskOutput], Run[TaskOutput]]]]: + **kwargs: Unpack[RunParams[TaskOutput]], + ) -> Union[Run[TaskOutput], AsyncIterator[Run[TaskOutput]]]: request = RunRequest( task_input=task_input.model_dump(by_alias=True), - version=version or task.version, + version=kwargs.get("version") or task.version, stream=stream, - use_cache=use_cache, - metadata=metadata, + use_cache=kwargs.get("use_cache"), + metadata=kwargs.get("metadata"), ) route = f"/v1/_/tasks/{task.id}/schemas/{task.schema_id}/run" - should_retry, wait_for_exception = build_retryable_wait(max_retry_delay, max_retry_count) + should_retry, wait_for_exception = build_retryable_wait( + kwargs.get("max_retry_delay", 60), + kwargs.get("max_retry_count", 1), + ) if not stream: return await self._retriable_run( route, request, - task, should_retry=should_retry, wait_for_exception=wait_for_exception, + validator=kwargs.get("validator") or task.output_class.model_validate, ) return self._retriable_stream( route, request, - task, should_retry=should_retry, wait_for_exception=wait_for_exception, + validator=kwargs.get("validator") or tolerant_validator(task.output_class), ) async def _retriable_run( self, route: str, request: RunRequest, - task: Task[TaskInput, TaskOutput], should_retry: Callable[[], bool], wait_for_exception: Callable[[WorkflowAIError], Awaitable[None]], + validator: OutputValidator[TaskOutput], ): last_error = None while should_retry(): try: res = await self.api.post(route, request, returns=RunResponse) - return res.to_domain(task) + return res.to_domain(validator) except WorkflowAIError as e: # noqa: PERF203 last_error = e await wait_for_exception(e) @@ 
-126,9 +145,9 @@ async def _retriable_stream( self, route: str, request: RunRequest, - task: Task[TaskInput, TaskOutput], should_retry: Callable[[], bool], wait_for_exception: Callable[[WorkflowAIError], Awaitable[None]], + validator: OutputValidator[TaskOutput], ): while should_retry(): try: @@ -136,9 +155,21 @@ async def _retriable_stream( method="POST", path=route, data=request, - returns=RunStreamChunk, + returns=RunResponse, ): - yield chunk.to_domain(task) + yield chunk.to_domain(validator) return except WorkflowAIError as e: # noqa: PERF203 await wait_for_exception(e) + + def task( + self, + schema_id: int, + task_id: Optional[str] = None, + version: VersionReference = DEFAULT_VERSION_REFERENCE, + ) -> TaskDecorator: + def wrap(fn: RunTemplate[TaskInput, TaskOutput]) -> FinalRunTemplate[TaskInput, TaskOutput]: + tid = task_id or task_id_from_fn_name(fn) + return functools.wraps(fn)(wrap_run_template(self, tid, schema_id, version, fn)) # pyright: ignore [reportReturnType] + + return wrap # pyright: ignore [reportReturnType] diff --git a/workflowai/core/client/client_test.py b/workflowai/core/client/_client_test.py similarity index 60% rename from workflowai/core/client/client_test.py rename to workflowai/core/client/_client_test.py index 7407a67..ea84039 100644 --- a/workflowai/core/client/client_test.py +++ b/workflowai/core/client/_client_test.py @@ -1,5 +1,7 @@ import importlib.metadata import json +from typing import Any, AsyncIterator +from unittest.mock import AsyncMock, patch import pytest from pytest_httpx import HTTPXMock, IteratorStream @@ -7,7 +9,10 @@ from tests.models.hello_task import HelloTask, HelloTaskInput, HelloTaskNotOptional, HelloTaskOutput from tests.utils import fixtures_json from workflowai.core.client import Client -from workflowai.core.client.client import WorkflowAIClient +from workflowai.core.client._client import ( + WorkflowAIClient, + _compute_default_version_reference, # pyright: ignore [reportPrivateUsage] +) from 
workflowai.core.domain.task_run import Run @@ -34,7 +39,6 @@ async def test_success(self, httpx_mock: HTTPXMock, client: Client): "task_input": {"name": "Alice"}, "version": "production", "stream": False, - "use_cache": "when_available", } async def test_stream(self, httpx_mock: HTTPXMock, client: Client): @@ -65,6 +69,7 @@ async def test_stream(self, httpx_mock: HTTPXMock, client: Client): ] last_message = chunks[-1] assert isinstance(last_message, Run) + assert last_message.version assert last_message.version.properties.model == "gpt-4o" assert last_message.version.properties.temperature == 0.5 assert last_message.cost_usd == 0.01 @@ -102,6 +107,7 @@ async def test_stream_not_optional(self, httpx_mock: HTTPXMock, client: Client): last_message = chunks[-1] assert isinstance(last_message, Run) + assert last_message.version assert last_message.version.properties.model == "gpt-4o" assert last_message.version.properties.temperature == 0.5 assert last_message.cost_usd == 0.01 @@ -126,7 +132,6 @@ async def test_run_with_env(self, httpx_mock: HTTPXMock, client: Client): "task_input": {"name": "Alice"}, "version": "dev", "stream": False, - "use_cache": "when_available", } async def test_success_with_headers(self, httpx_mock: HTTPXMock, client: Client): @@ -151,7 +156,6 @@ async def test_success_with_headers(self, httpx_mock: HTTPXMock, client: Client) "task_input": {"name": "Alice"}, "version": "production", "stream": False, - "use_cache": "when_available", } # Check for additional headers for key, value in headers.items(): @@ -171,3 +175,101 @@ async def test_run_retries_on_too_many_requests(self, httpx_mock: HTTPXMock, cli assert len(reqs) == 2 assert reqs[0].url == "http://localhost:8000/v1/_/tasks/123/schemas/1/run" assert reqs[1].url == "http://localhost:8000/v1/_/tasks/123/schemas/1/run" + + +class TestTask: + @pytest.fixture + def patched_run_fn(self, client: Client): + with patch.object(client, "run", spec=client.run) as run_mock: + yield run_mock + + def 
test_fn_name(self, client: Client): + @client.task(schema_id=1, task_id="123") + async def fn(task_input: HelloTaskInput) -> HelloTaskOutput: ... + + assert fn.__name__ == "fn" + assert fn.__doc__ is None + assert callable(fn) + + async def test_run_output_only(self, client: Client, patched_run_fn: AsyncMock): + @client.task(schema_id=1, task_id="123") + async def fn(task_input: HelloTaskInput) -> HelloTaskOutput: ... + + patched_run_fn.return_value = Run(task_output=HelloTaskOutput(message="hello")) + + output = await fn(HelloTaskInput(name="Alice")) + + assert output == HelloTaskOutput(message="hello") + + async def test_run_with_version(self, client: Client, patched_run_fn: AsyncMock): + @client.task(schema_id=1, task_id="123") + async def fn(task_input: HelloTaskInput) -> Run[HelloTaskOutput]: ... + + patched_run_fn.return_value = Run(id="1", task_output=HelloTaskOutput(message="hello")) + + output = await fn(HelloTaskInput(name="Alice")) + + assert output.id == "1" + assert output.task_output == HelloTaskOutput(message="hello") + assert isinstance(output, Run) + + async def test_stream(self, client: Client, httpx_mock: HTTPXMock): + # We avoid mocking the run fn directly here, python does weird things with + # having to await async iterators depending on how they are defined so instead we mock + # the underlying api call to check that we don't need the extra await + + @client.task(schema_id=1, task_id="123") + def fn(task_input: HelloTaskInput) -> AsyncIterator[Run[HelloTaskOutput]]: ... 
+ + httpx_mock.add_response( + stream=IteratorStream( + [ + b'data: {"id":"1","task_output":{"message":""}}\n\n', + b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n', # noqa: E501 + b'data: {"id":"1","task_output":{"message":"hello"},"cost_usd":0.01,"duration_seconds":10.1}\n\n', + ], + ), + ) + + chunks = [chunk async for chunk in fn(HelloTaskInput(name="Alice"))] + + assert chunks == [ + Run(id="1", task_output=HelloTaskOutput(message="")), + Run(id="1", task_output=HelloTaskOutput(message="hel")), + Run(id="1", task_output=HelloTaskOutput(message="hello")), + Run(id="1", task_output=HelloTaskOutput(message="hello"), duration_seconds=10.1, cost_usd=0.01), + ] + + async def test_stream_output_only(self, client: Client, httpx_mock: HTTPXMock): + @client.task(schema_id=1) + def fn(task_input: HelloTaskInput) -> AsyncIterator[HelloTaskOutput]: ... + + httpx_mock.add_response( + stream=IteratorStream( + [ + b'data: {"id":"1","task_output":{"message":""}}\n\n', + b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n', # noqa: E501 + b'data: {"id":"1","task_output":{"message":"hello"},"cost_usd":0.01,"duration_seconds":10.1}\n\n', + ], + ), + ) + + chunks = [chunk async for chunk in fn(HelloTaskInput(name="Alice"))] + + # We could remove duplicates but it would add a condition for everyone and every chunk + # that might not be useful. 
+ assert chunks == [ + HelloTaskOutput(message=""), + HelloTaskOutput(message="hel"), + HelloTaskOutput(message="hello"), + HelloTaskOutput(message="hello"), + ] + + +@pytest.mark.parametrize( + ("env_var", "expected"), + [("p", "production"), ("production", "production"), ("dev", "dev"), ("staging", "staging"), ("1", 1)], +) +def test_compute_default_version_reference(env_var: str, expected: Any): + with patch.dict("os.environ", {"WORKFLOWAI_DEFAULT_VERSION": env_var}): + assert _compute_default_version_reference() == expected diff --git a/workflowai/core/client/_fn_utils.py b/workflowai/core/client/_fn_utils.py new file mode 100644 index 0000000..d95cf39 --- /dev/null +++ b/workflowai/core/client/_fn_utils.py @@ -0,0 +1,155 @@ +from typing import ( + Any, + AsyncIterator, + NamedTuple, + Sequence, + Type, + Union, + get_args, + get_origin, + get_type_hints, +) + +from pydantic import BaseModel +from typing_extensions import Unpack + +from workflowai.core.client._types import ( + Client, + FinalRunFn, + FinalRunFnOutputOnly, + FinalStreamRunFn, + FinalStreamRunFnOutputOnly, + RunParams, + RunTemplate, +) +from workflowai.core.domain.task import Task, TaskInput, TaskOutput +from workflowai.core.domain.task_run import Run +from workflowai.core.domain.task_version_reference import VersionReference + +# TODO: add sync support + + +def get_generic_args(t: type[BaseModel]) -> Union[Sequence[type], None]: + return t.__pydantic_generic_metadata__.get("args") + + +def check_return_type(return_type_hint: Type[Any]) -> tuple[bool, Type[BaseModel]]: + if issubclass(return_type_hint, Run): + args = get_generic_args(return_type_hint) # pyright: ignore [reportUnknownArgumentType] + if not args: + raise ValueError("Run must have a generic argument") + output_cls = args[0] + if not issubclass(output_cls, BaseModel): + raise ValueError("Run generic argument must be a subclass of BaseModel") + return False, output_cls + if issubclass(return_type_hint, BaseModel): + return True, 
return_type_hint + raise ValueError("Function must have a return type hint that is a subclass of Pydantic's 'BaseModel' or 'Run'") + + +class ExtractFnData(NamedTuple): + stream: bool + output_only: bool + input_cls: Type[BaseModel] + output_cls: Type[BaseModel] + + +def is_async_iterator(t: type[Any]) -> bool: + ori: Any = get_origin(t) + if not ori: + return False + return issubclass(ori, AsyncIterator) + + +def extract_fn_data(fn: RunTemplate[TaskInput, TaskOutput]) -> ExtractFnData: + hints = get_type_hints(fn) + if "return" not in hints: + raise ValueError("Function must have a return type hint") + if "task_input" not in hints: + raise ValueError("Function must have a task_input parameter") + + return_type_hint = hints["return"] + input_cls = hints["task_input"] + if not issubclass(input_cls, BaseModel): + raise ValueError("task_input must be a subclass of BaseModel") + + output_cls = None + + if is_async_iterator(return_type_hint): + stream = True + output_only, output_cls = check_return_type(get_args(return_type_hint)[0]) + else: + stream = False + output_only, output_cls = check_return_type(return_type_hint) + + return ExtractFnData(stream, output_only, input_cls, output_cls) + + +def _wrap_run(client: Client, task: Task[TaskInput, TaskOutput]) -> FinalRunFn[TaskInput, TaskOutput]: + async def wrap(task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]) -> Run[TaskOutput]: + return await client.run(task, task_input, stream=False, **kwargs) + + return wrap + + +def _wrap_run_output_only( + client: Client, + task: Task[TaskInput, TaskOutput], +) -> FinalRunFnOutputOnly[TaskInput, TaskOutput]: + async def wrap(task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]) -> TaskOutput: + run = await client.run(task, task_input, stream=False, **kwargs) + return run.task_output + + return wrap + + +def _wrap_stream_run(client: Client, task: Task[TaskInput, TaskOutput]) -> FinalStreamRunFn[TaskInput, TaskOutput]: + async def wrap(task_input: TaskInput, 
**kwargs: Unpack[RunParams[TaskOutput]]) -> AsyncIterator[Run[TaskOutput]]: + s = await client.run(task, task_input, stream=True, **kwargs) + async for chunk in s: + yield chunk + + return wrap + + +def _wrap_stream_run_output_only( + client: Client, + task: Task[TaskInput, TaskOutput], +) -> FinalStreamRunFnOutputOnly[TaskInput, TaskOutput]: + async def wrap(task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]) -> AsyncIterator[TaskOutput]: + s = await client.run(task, task_input, stream=True, **kwargs) + async for chunk in s: + yield chunk.task_output + + # NOTE(review): pyright rejects the return type, presumably because FinalStreamRunFnOutputOnly + # annotates kwargs as RunParams[TaskOutput] (the module-level TypeVar) while the protocol itself is + # generic over TaskOutputCov — confirm and align the TypeVars in _types.py, then drop this ignore. + return wrap # pyright: ignore [reportReturnType] + + +def wrap_run_template( + client: Client, + task_id: str, + task_schema_id: int, + version: VersionReference, + fn: RunTemplate[TaskInput, TaskOutput], +): + stream, output_only, input_cls, output_cls = extract_fn_data(fn) + # Task is invariant in TaskInput/TaskOutput while the run templates use contravariant input and + # covariant output TypeVars, so pyright cannot prove this assignment — hence the targeted ignore. + task: Task[TaskInput, TaskOutput] = Task( # pyright: ignore [reportAssignmentType] + id=task_id, + schema_id=task_schema_id, + input_class=input_cls, + output_class=output_cls, + version=version, + ) + + if stream: + if output_only: + return _wrap_stream_run_output_only(client, task) + return _wrap_stream_run(client, task) + if output_only: + return _wrap_run_output_only(client, task) + return _wrap_run(client, task) + + +def task_id_from_fn_name(fn: Any) -> str: + return fn.__name__.replace("_", "-").lower() diff --git a/workflowai/core/client/_fn_utils_test.py b/workflowai/core/client/_fn_utils_test.py new file mode 100644 index 0000000..4cb9acf --- /dev/null +++ b/workflowai/core/client/_fn_utils_test.py @@ -0,0 +1,45 @@ +from typing import AsyncIterator + +from pydantic import BaseModel + +from tests.models.hello_task import HelloTaskInput, HelloTaskOutput +from workflowai.core.client._fn_utils import extract_fn_data, get_generic_args, is_async_iterator +from workflowai.core.domain.task_run import Run + + +async def 
say_hello(task_input: HelloTaskInput) -> HelloTaskOutput: ... + + +async def say_hello_run(task_input: HelloTaskInput) -> Run[HelloTaskOutput]: ... + + +def stream_hello(task_input: HelloTaskInput) -> AsyncIterator[HelloTaskOutput]: ... + + +def stream_hello_run(task_input: HelloTaskInput) -> AsyncIterator[Run[HelloTaskOutput]]: ... + + +class TestGetGenericArgs: + def test_get_generic_arg(self): + assert get_generic_args(Run[HelloTaskOutput]) == (HelloTaskOutput,) + + +class TestIsAsyncIterator: + def test_is_async_iterator(self): + assert is_async_iterator(AsyncIterator[HelloTaskOutput]) + assert not is_async_iterator(HelloTaskOutput) + assert not is_async_iterator(BaseModel) + + +class TestExtractFnData: + def test_run_output_only(self): + assert extract_fn_data(say_hello) == (False, True, HelloTaskInput, HelloTaskOutput) + + def test_run(self): + assert extract_fn_data(say_hello_run) == (False, False, HelloTaskInput, HelloTaskOutput) + + def test_stream_output_only(self): + assert extract_fn_data(stream_hello) == (True, True, HelloTaskInput, HelloTaskOutput) + + def test_stream(self): + assert extract_fn_data(stream_hello_run) == (True, False, HelloTaskInput, HelloTaskOutput) diff --git a/workflowai/core/client/models.py b/workflowai/core/client/_models.py similarity index 53% rename from workflowai/core/client/models.py rename to workflowai/core/client/_models.py index 509a8fe..7a11e4f 100644 --- a/workflowai/core/client/models.py +++ b/workflowai/core/client/_models.py @@ -3,9 +3,10 @@ from pydantic import BaseModel from typing_extensions import NotRequired, TypedDict +from workflowai.core.client._types import OutputValidator from workflowai.core.domain.cache_usage import CacheUsage -from workflowai.core.domain.task import Task, TaskInput, TaskOutput -from workflowai.core.domain.task_run import Run, RunChunk +from workflowai.core.domain.task import TaskOutput +from workflowai.core.domain.task_run import Run from workflowai.core.domain.task_version import 
TaskVersion from workflowai.core.domain.task_version_properties import TaskVersionProperties @@ -40,44 +41,16 @@ class RunResponse(BaseModel): id: str task_output: dict[str, Any] - version: Version - duration_seconds: Optional[float] = None - cost_usd: Optional[float] = None - - def to_domain(self, task: Task[TaskInput, TaskOutput]) -> Run[TaskOutput]: - return Run( - id=self.id, - task_output=task.output_class.model_validate(self.task_output), - version=TaskVersion( - properties=TaskVersionProperties.model_construct( - None, - **self.version.properties, - ), - ), - duration_seconds=self.duration_seconds, - cost_usd=self.cost_usd, - ) - - -class RunStreamChunk(BaseModel): - id: str - task_output: dict[str, Any] - version: Optional[Version] = None duration_seconds: Optional[float] = None cost_usd: Optional[float] = None - def to_domain(self, task: Task[TaskInput, TaskOutput]) -> Union[Run[TaskOutput], RunChunk[TaskOutput]]: - if self.version is None: - return RunChunk( - id=self.id, - task_output=task.output_class.model_construct(None, **self.task_output), - ) - + def to_domain(self, validator: OutputValidator[TaskOutput]) -> Run[TaskOutput]: return Run( id=self.id, - task_output=task.output_class.model_validate(self.task_output), - version=TaskVersion( + task_output=validator(self.task_output), + version=self.version + and TaskVersion( properties=TaskVersionProperties.model_construct( None, **self.version.properties, diff --git a/workflowai/core/client/models_test.py b/workflowai/core/client/_models_test.py similarity index 59% rename from workflowai/core/client/models_test.py rename to workflowai/core/client/_models_test.py index 330b5bf..2eba341 100644 --- a/workflowai/core/client/models_test.py +++ b/workflowai/core/client/_models_test.py @@ -4,9 +4,9 @@ from pydantic import BaseModel, ValidationError from tests.utils import fixture_text -from workflowai.core.client.models import RunResponse, RunStreamChunk -from workflowai.core.domain.task import Task -from 
workflowai.core.domain.task_run import Run, RunChunk +from workflowai.core.client._models import RunResponse +from workflowai.core.client._utils import tolerant_validator +from workflowai.core.domain.task_run import Run @pytest.mark.parametrize( @@ -31,52 +31,38 @@ class _TaskOutputOpt(BaseModel): b: Optional[str] = None -class _Task(Task[_TaskOutput, _TaskOutput]): - id: str = "test-task" - schema_id: int = 1 - input_class: type[_TaskOutput] = _TaskOutput - output_class: type[_TaskOutput] = _TaskOutput - - -class _TaskOpt(Task[_TaskOutputOpt, _TaskOutputOpt]): - id: str = "test-task" - schema_id: int = 1 - input_class: type[_TaskOutputOpt] = _TaskOutputOpt - output_class: type[_TaskOutputOpt] = _TaskOutputOpt - - -class TestRunStreamChunkToDomain: +class TestRunResponseToDomain: def test_no_version_not_optional(self): # Check that partial model is ok - chunk = RunStreamChunk.model_validate_json('{"id": "1", "task_output": {"a": 1}}') + chunk = RunResponse.model_validate_json('{"id": "1", "task_output": {"a": 1}}') assert chunk.task_output == {"a": 1} with pytest.raises(ValidationError): # sanity _TaskOutput.model_validate({"a": 1}) - parsed = chunk.to_domain(_Task()) - assert isinstance(parsed, RunChunk) + parsed = chunk.to_domain(tolerant_validator(_TaskOutput)) + assert isinstance(parsed, Run) assert parsed.task_output.a == 1 # b is not defined with pytest.raises(AttributeError): assert parsed.task_output.b def test_no_version_optional(self): - chunk = RunStreamChunk.model_validate_json('{"id": "1", "task_output": {"a": 1}}') + chunk = RunResponse.model_validate_json('{"id": "1", "task_output": {"a": 1}}') assert chunk - parsed = chunk.to_domain(_TaskOpt()) - assert isinstance(parsed, RunChunk) + parsed = chunk.to_domain(tolerant_validator(_TaskOutputOpt)) + assert isinstance(parsed, Run) assert parsed.task_output.a == 1 assert parsed.task_output.b is None def test_with_version(self): - chunk = RunStreamChunk.model_validate_json( + chunk = 
RunResponse.model_validate_json( '{"id": "1", "task_output": {"a": 1, "b": "test"}, "cost_usd": 0.1, "duration_seconds": 1, "version": {"properties": {"a": 1, "b": "test"}}}', # noqa: E501 ) assert chunk - parsed = chunk.to_domain(_Task()) + parsed = chunk.to_domain(tolerant_validator(_TaskOutput)) assert isinstance(parsed, Run) assert parsed.task_output.a == 1 assert parsed.task_output.b == "test" @@ -85,8 +71,8 @@ def test_with_version(self): assert parsed.duration_seconds == 1 def test_with_version_validation_fails(self): - chunk = RunStreamChunk.model_validate_json( + chunk = RunResponse.model_validate_json( '{"id": "1", "task_output": {"a": 1}, "version": {"properties": {"a": 1, "b": "test"}}}', ) with pytest.raises(ValidationError): - chunk.to_domain(_Task()) + chunk.to_domain(_TaskOutput.model_validate) diff --git a/workflowai/core/client/_types.py b/workflowai/core/client/_types.py new file mode 100644 index 0000000..d57d97e --- /dev/null +++ b/workflowai/core/client/_types.py @@ -0,0 +1,198 @@ +from collections.abc import Callable +from typing import ( + Any, + AsyncIterator, + Generic, + Literal, + Optional, + Protocol, + TypeVar, + Union, + overload, +) + +from pydantic import BaseModel +from typing_extensions import NotRequired, TypedDict, Unpack + +from workflowai.core.domain.cache_usage import CacheUsage +from workflowai.core.domain.task import Task, TaskInput, TaskOutput +from workflowai.core.domain.task_run import Run +from workflowai.core.domain.task_version_reference import VersionReference + +TaskInputContra = TypeVar("TaskInputContra", bound=BaseModel, contravariant=True) +TaskOutputCov = TypeVar("TaskOutputCov", bound=BaseModel, covariant=True) + +OutputValidator = Callable[[dict[str, Any]], TaskOutput] + + +class RunParams(TypedDict, Generic[TaskOutput]): + version: NotRequired[Optional[VersionReference]] + use_cache: NotRequired[CacheUsage] + metadata: NotRequired[Optional[dict[str, Any]]] + max_retry_delay: NotRequired[float] + 
max_retry_count: NotRequired[float] + validator: NotRequired[OutputValidator[TaskOutput]] + + +class RunFn(Protocol, Generic[TaskInputContra, TaskOutput]): + async def __call__(self, task_input: TaskInputContra) -> Run[TaskOutput]: ... + + +class RunFnOutputOnly(Protocol, Generic[TaskInputContra, TaskOutputCov]): + async def __call__(self, task_input: TaskInputContra) -> TaskOutputCov: ... + + +class StreamRunFn(Protocol, Generic[TaskInputContra, TaskOutput]): + def __call__( + self, + task_input: TaskInputContra, + ) -> AsyncIterator[Run[TaskOutput]]: ... + + +class StreamRunFnOutputOnly(Protocol, Generic[TaskInputContra, TaskOutputCov]): + def __call__( + self, + task_input: TaskInputContra, + ) -> AsyncIterator[TaskOutputCov]: ... + + +RunTemplate = Union[ + RunFn[TaskInput, TaskOutput], + RunFnOutputOnly[TaskInput, TaskOutput], + StreamRunFn[TaskInput, TaskOutput], + StreamRunFnOutputOnly[TaskInput, TaskOutput], +] + + +class _BaseProtocol(Protocol): + __name__: str + __doc__: Optional[str] + __module__: str + __qualname__: str + __annotations__: dict[str, Any] + __defaults__: Optional[tuple[Any, ...]] + __kwdefaults__: Optional[dict[str, Any]] + __code__: Any + + +class FinalRunFn(_BaseProtocol, Protocol, Generic[TaskInputContra, TaskOutput]): + async def __call__( + self, + task_input: TaskInputContra, + **kwargs: Unpack[RunParams[TaskOutput]], + ) -> Run[TaskOutput]: ... + + +class FinalRunFnOutputOnly(_BaseProtocol, Protocol, Generic[TaskInputContra, TaskOutput]): + async def __call__( + self, + task_input: TaskInputContra, + **kwargs: Unpack[RunParams[TaskOutput]], + ) -> TaskOutput: ... + + +class FinalStreamRunFn(_BaseProtocol, Protocol, Generic[TaskInputContra, TaskOutput]): + def __call__( + self, + task_input: TaskInputContra, + **kwargs: Unpack[RunParams[TaskOutput]], + ) -> AsyncIterator[Run[TaskOutput]]: ... 
+ + +class FinalStreamRunFnOutputOnly(_BaseProtocol, Protocol, Generic[TaskInputContra, TaskOutputCov]): + def __call__( + self, + task_input: TaskInputContra, + **kwargs: Unpack[RunParams[TaskOutput]], + ) -> AsyncIterator[TaskOutputCov]: ... + + +FinalRunTemplate = Union[ + FinalRunFn[TaskInput, TaskOutput], + FinalRunFnOutputOnly[TaskInput, TaskOutput], + FinalStreamRunFn[TaskInput, TaskOutput], + FinalStreamRunFnOutputOnly[TaskInput, TaskOutput], +] + + +class TaskDecorator(Protocol): + @overload + def __call__(self, fn: RunFn[TaskInput, TaskOutput]) -> FinalRunFn[TaskInput, TaskOutput]: ... + + @overload + def __call__(self, fn: RunFnOutputOnly[TaskInput, TaskOutput]) -> FinalRunFnOutputOnly[TaskInput, TaskOutput]: ... + + @overload + def __call__(self, fn: StreamRunFn[TaskInput, TaskOutput]) -> FinalStreamRunFn[TaskInput, TaskOutput]: ... + + @overload + def __call__( + self, + fn: StreamRunFnOutputOnly[TaskInput, TaskOutput], + ) -> FinalStreamRunFnOutputOnly[TaskInput, TaskOutput]: ... + + def __call__(self, fn: RunTemplate[TaskInput, TaskOutput]) -> FinalRunTemplate[TaskInput, TaskOutput]: ... + + +class Client(Protocol): + """A client to interact with the WorkflowAI API""" + + @overload + async def run( + self, + task: Task[TaskInput, TaskOutput], + task_input: TaskInput, + stream: Literal[False] = False, + **kwargs: Unpack[RunParams[TaskOutput]], + ) -> Run[TaskOutput]: ... + + @overload + async def run( + self, + task: Task[TaskInput, TaskOutput], + task_input: TaskInput, + stream: Literal[True] = True, + **kwargs: Unpack[RunParams[TaskOutput]], + ) -> AsyncIterator[Run[TaskOutput]]: ... 
+ + async def run( + self, + task: Task[TaskInput, TaskOutput], + task_input: TaskInput, + stream: bool = False, + **kwargs: Unpack[RunParams[TaskOutput]], + ) -> Union[Run[TaskOutput], AsyncIterator[Run[TaskOutput]]]: + """Run a task + + Args: + task (Task[TaskInput, TaskOutput]): the task to run + task_input (TaskInput): the input to the task + stream (bool, optional): whether to stream the output. If True, the function returns an async iterator of + partial run objects. Defaults to False. + version (Optional[VersionReference], optional): the version of the task to run. If not provided, + the version defined in the task is used. Defaults to None. + use_cache (CacheUsage, optional): how to use the cache. Defaults to "when_available". + metadata (Optional[dict[str, Any]], optional): a dictionary of metadata to attach to the run. + Defaults to None. + max_retry_delay (float, optional): the maximum delay between retries. NOTE(review): unit appears to be + seconds — confirm against the client implementation defaults. + max_retry_count (float, optional): the maximum number of retry attempts. + validator (OutputValidator[TaskOutput], optional): a callable converting the raw output payload dict + into a TaskOutput instance; a tolerant validator may be used for partial streamed chunks. + + Returns: + Union[Run[TaskOutput], AsyncIterator[Run[TaskOutput]]]: the run object, or an async iterator of + run objects when stream is True + """ + ... + + def task( + self, + schema_id: int, + task_id: Optional[str] = None, + version: VersionReference = "production", + ) -> TaskDecorator: ... 
diff --git a/workflowai/core/client/utils.py b/workflowai/core/client/_utils.py similarity index 93% rename from workflowai/core/client/utils.py rename to workflowai/core/client/_utils.py index ceed7df..08a16e5 100644 --- a/workflowai/core/client/utils.py +++ b/workflowai/core/client/_utils.py @@ -8,7 +8,9 @@ from time import time from typing import Any, Optional +from workflowai.core.client._types import OutputValidator from workflowai.core.domain.errors import BaseError, WorkflowAIError +from workflowai.core.domain.task import TaskOutput delimiter = re.compile(r'\}\n\ndata: \{"') @@ -93,3 +95,7 @@ async def _wait_for_exception(e: WorkflowAIError): retry_count += 1 return _should_retry, _wait_for_exception + + +def tolerant_validator(m: type[TaskOutput]) -> OutputValidator[TaskOutput]: + return lambda payload: m.model_construct(None, **payload) diff --git a/workflowai/core/client/utils_test.py b/workflowai/core/client/_utils_test.py similarity index 93% rename from workflowai/core/client/utils_test.py rename to workflowai/core/client/_utils_test.py index 61e06ee..e399971 100644 --- a/workflowai/core/client/utils_test.py +++ b/workflowai/core/client/_utils_test.py @@ -5,7 +5,7 @@ from freezegun import freeze_time from httpx import HTTPStatusError -from workflowai.core.client.utils import build_retryable_wait, retry_after_to_delay_seconds, split_chunks +from workflowai.core.client._utils import build_retryable_wait, retry_after_to_delay_seconds, split_chunks from workflowai.core.domain.errors import WorkflowAIError diff --git a/workflowai/core/domain/task.py b/workflowai/core/domain/task.py index 97eacc8..cb8ca0e 100644 --- a/workflowai/core/domain/task.py +++ b/workflowai/core/domain/task.py @@ -1,5 +1,4 @@ -from datetime import datetime -from typing import Generic, Optional, TypeVar +from typing import Generic, TypeVar from pydantic import BaseModel @@ -24,5 +23,3 @@ class Task(BaseModel, Generic[TaskInput, TaskOutput]): input_class: type[TaskInput] = BaseModel # 
pyright: ignore [reportAssignmentType] output_class: type[TaskOutput] = BaseModel # pyright: ignore [reportAssignmentType] - - created_at: Optional[datetime] = None diff --git a/workflowai/core/domain/task_run.py b/workflowai/core/domain/task_run.py index d7ed71d..704c055 100644 --- a/workflowai/core/domain/task_run.py +++ b/workflowai/core/domain/task_run.py @@ -7,15 +7,7 @@ from workflowai.core.domain.task_version import TaskVersion -class RunChunk(BaseModel, Generic[TaskOutput]): - id: str = Field( - default_factory=lambda: str(uuid.uuid4()), - description="The unique identifier of the task run", - ) - task_output: TaskOutput - - -class Run(RunChunk[TaskOutput]): +class Run(BaseModel, Generic[TaskOutput]): """ A task run is an instance of a task with a specific input and output. @@ -23,9 +15,19 @@ class Run(RunChunk[TaskOutput]): been evaluated """ + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), + description="The unique identifier of the task run. Defaults to a client-generated UUIDv4.", + ) + task_output: TaskOutput + duration_seconds: Optional[float] = None cost_usd: Optional[float] = None - version: TaskVersion + version: Optional[TaskVersion] = Field( + default=None, + description="The version of the task that was run. Only provided if the version differs from the version" + " specified in the request, for example in case of a model fallback", + ) metadata: Optional[dict[str, Any]] = None