diff --git a/Makefile b/Makefile index 8eb40a1..b8295b2 100644 --- a/Makefile +++ b/Makefile @@ -26,7 +26,7 @@ lint: .PHONY: test test: - pytest + pytest --ignore=tests/e2e .PHONY: lock lock: diff --git a/README.md b/README.md index 3889b4b..268bd09 100644 --- a/README.md +++ b/README.md @@ -12,63 +12,158 @@ pip install workflowai ## Usage -Usage examples are available in the [examples](./examples/) directory. +Usage examples are available in the [examples](./examples/) directory or end to [end test](./tests/e2e/) +directory. + +### Getting a workflowai api key + +Create an account on [workflowai.com](https://workflowai.com), generate an API key and set it as +an environment variable. + +``` +WORKFLOWAI_API_KEY=... +``` + +> You can also set the `WORKFLOWAI_API_URL` environment variable to point to your own WorkflowAI instance. + +> The current UI does not allow to generate an API key without creating a task. Take the opportunity to play +> around with the UI. When the task is created, you can generate an API key from the Code section ### Set up the workflowai client +If you have defined the api key using an environment variable, the shared workflowai client will be +correctly configured. + +You can override the shared client by calling the init function. + ```python import workflowai -wai = workflowai.start( +workflowai.init( url=..., # defaults to WORKFLOWAI_API_URL env var or https://api.workflowai.com api_key=..., # defaults to WORKFLOWAI_API_KEY env var ) ``` -### Define a task +#### Using multiple clients -We use pydantic for type definitions. +You might want to avoid using the shared client, for example if you are using multiple API keys or accounts. 
+It is possible to achieve this by manually creating client instances ```python -from pydantic import BaseModel, Field +from workflowai import WorkflowAI -from workflowai import Task, TaskVersionReference +client = WorkflowAI( + url=..., + api_key=..., +) + +# Use the client to create and run agents +@client.agent() +def my_agent(task_input: Input) -> Output: + ... +``` + +### Build agents -class CityToCapitalTaskInput(BaseModel): - city: str +An agent is in essence an async function with the added constraints that: +- it has a single argument that is a pydantic model +- it has a single return value that is a pydantic model +- it is decorated with the `@client.agent()` decorator -class CityToCapitalTaskOutput(BaseModel): - capital: str +> [Pydantic](https://docs.pydantic.dev/latest/) is a very popular and powerful library for data validation and +> parsing. It allows us to extract the input and output schema in a simple way + +Below is an agent that says hello: + +```python +import workflowai +from pydantic import BaseModel -class CityToCapitalTask(Task[CityToCapitalTaskInput, CityToCapitalTaskOutput]): - id: str = "citytocapital" - schema_id: int = 1 - input_class: type[CityToCapitalTaskInput] = CityToCapitalTaskInput - output_class: type[CityToCapitalTaskOutput] = CityToCapitalTaskOutput +class Input(BaseModel): + name: str - # The default version that should be used when running the task - version: TaskVersionReference = TaskVersionReference( - iteration=4, - ) +class Output(BaseModel): + greeting: str + +@workflowai.agent() +async def say_hello(input: Input) -> Output: + """Say hello""" + ... ``` -### Run a task +When you call that function, the associated agent will be created on workflowai.com if it does not exist yet and a +run will be created. 
By default: + +- the docstring will be used as instructions for the agent +- the default model (`workflowai.DEFAULT_MODEL`) is used to run the agent +- the agent id will be a slugified version of the function name (i-e `say-hello`) in this case + +> **What is "..." ?** +> +> The `...` is the ellipsis value in python. It is usually used as a placeholder. You could use "pass" here as well +> or anything really, the implementation of the function is handled by the decorator `@workflowai.agent()` and so +> the function body is not executed. +> `...` is usually the right choice because it signals type checkers that they should ignore the function body. + +> Having the agent id determined at runtime can lead to unexpected changes, since changing the function name will +> change the agent id. A good practice is to set the agent id explicitly, `@workflowai.agent(id="say-hello")`. + +#### Using different models + +WorkflowAI supports a long list of models. The source of truth for models we support is on [workflowai.com](https://workflowai.com). The [Model](./workflowai/core/domain/model.py) type is a good indication of what models are supported at the time of the sdk release, although it may be missing some models since new ones are added all the time. + +You can set the model explicitly in the agent decorator: ```python -task = CityToCapitalTask() -task_input = CityToCapitalTaskInput(city=city) -task_run = await wai.run(task, task_input) +@workflowai.agent(model="gpt-4o") +def say_hello(input: Input) -> Output: + ... +``` -print(task_run.task_output) +> Models do not become invalid on WorkflowAI. When a model is retired, it will be replaced dynamically by +> a newer version of the same model with the same or a lower price so calling the api with +> a retired model will always work. + +### Version from code or deployments + +Setting a docstring or a model in the agent decorator signals the client that the agent parameters are +fixed and configured via code. 
+ +Handling the agent parameters in code is useful to get started but may be limited in the long run: + +- it is somewhat hard to measure the impact of different parameters +- moving to new models or instructions requires a deployment +- iterating on the agent parameters can be very tedious + +Deployments allow you to refer to a version of an agent's parameters from your code that's managed from the +workflowai.com UI. The following code will use the version of the agent named "production" which is a lot +more flexible than changing the function parameters when running in production. + +```python +@workflowai.agent(deployment="production") # or simply @workflowai.agent() +def say_hello(input: Input) -> AsyncIterator[Run[Output]]: + ... ``` -It is also possible to stream a task output +### Streaming and advanced usage + +You can configure the agent function to stream or return the full run object, simply by changing the type annotation. ```python -task = CityToCapitalTask() -task_input = CityToCapitalTaskInput(city=city) -iterator = await wai.run(task, task_input, stream=True) -async for chunk in iterator: - print(chunk) # chunk is a partial (non validated) CityToCapitalTaskOutput +# Return the full run object, useful if you want to extract metadata like cost or duration +@workflowai.agent() +async def say_hello(input: Input) -> Run[Output]: + ... + +# Stream the output, the output is filled as it is generated +@workflowai.agent() +def say_hello(input: Input) -> AsyncIterator[Output]: + ... + +# Stream the run object, the output is filled as it is generated +@workflowai.agent() +def say_hello(input: Input) -> AsyncIterator[Run[Output]]: + ... 
``` diff --git a/pyproject.toml b/pyproject.toml index 30a2b98..82fb11e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "workflowai" -version = "0.5.5" +version = "0.6.0.dev0" description = "" authors = ["Guillaume Aquilina "] readme = "README.md" @@ -52,7 +52,8 @@ ignore = [ "PYI051", "FIX002", "SLF001", #reportPrivateUsage - "PT017", # Do not force using pytest.raises + "PT017", # Do not force using pytest.raises + "PIE790", # ... are not unnecessary for empty functions with docstring ] # Allow fix for all enabled rules (when `--fix`) is provided. diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index b668723..921e66f 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -3,15 +3,15 @@ import pytest from dotenv import load_dotenv -from workflowai import Client -from workflowai.core.client.client import WorkflowAIClient +import workflowai -load_dotenv() +load_dotenv(override=True) -@pytest.fixture(scope="session") -def wai() -> Client: - return WorkflowAIClient( - endpoint=os.environ["WORKFLOWAI_TEST_API_URL"], +@pytest.fixture(scope="session", autouse=True) +def wai(): + workflowai.init( api_key=os.environ["WORKFLOWAI_TEST_API_KEY"], + url=os.environ["WORKFLOWAI_TEST_API_URL"], ) + return workflowai.shared_client diff --git a/tests/e2e/no_schema_test.py b/tests/e2e/no_schema_test.py new file mode 100644 index 0000000..b2a49b2 --- /dev/null +++ b/tests/e2e/no_schema_test.py @@ -0,0 +1,22 @@ +from typing import Optional + +from pydantic import BaseModel + +import workflowai + + +class SummarizeTaskInput(BaseModel): + text: Optional[str] = None + + +class SummarizeTaskOutput(BaseModel): + summary_points: Optional[list[str]] = None + + +@workflowai.agent(id="summarize") +async def summarize(task_input: SummarizeTaskInput) -> SummarizeTaskOutput: ... 
+ + +async def test_summarize(): + summarized = await summarize(SummarizeTaskInput(text="Hello, world!")) + assert summarized.summary_points diff --git a/tests/e2e/run_test.py b/tests/e2e/run_test.py index c623def..7c098d6 100644 --- a/tests/e2e/run_test.py +++ b/tests/e2e/run_test.py @@ -1,10 +1,12 @@ from enum import Enum from typing import AsyncIterator, Optional +import pytest from pydantic import BaseModel import workflowai -from workflowai.core.domain.task import Task +from workflowai.core.client.agent import Agent +from workflowai.core.client.client import WorkflowAI class ExtractProductReviewSentimentTaskInput(BaseModel): @@ -23,36 +25,50 @@ class ExtractProductReviewSentimentTaskOutput(BaseModel): sentiment: Optional[Sentiment] = None -@workflowai.task(schema_id=1) +@workflowai.agent(id="extract-product-review-sentiment", schema_id=1) def extract_product_review_sentiment( task_input: ExtractProductReviewSentimentTaskInput, ) -> AsyncIterator[ExtractProductReviewSentimentTaskOutput]: ... 
-class ExtractProductReviewSentimentTask( - Task[ExtractProductReviewSentimentTaskInput, ExtractProductReviewSentimentTaskOutput], -): - id: str = "extract-product-review-sentiment" - schema_id: int = 1 - input_class: type[ExtractProductReviewSentimentTaskInput] = ExtractProductReviewSentimentTaskInput - output_class: type[ExtractProductReviewSentimentTaskOutput] = ExtractProductReviewSentimentTaskOutput +@pytest.fixture +def extract_product_review_sentiment_agent( + wai: WorkflowAI, +) -> Agent[ExtractProductReviewSentimentTaskInput, ExtractProductReviewSentimentTaskOutput]: + return Agent( + agent_id="extract-product-review-sentiment", + schema_id=1, + input_cls=ExtractProductReviewSentimentTaskInput, + output_cls=ExtractProductReviewSentimentTaskOutput, + api=wai.api, + ) -async def test_run_task(wai: workflowai.Client): - task = ExtractProductReviewSentimentTask() +async def test_run_task( + extract_product_review_sentiment_agent: Agent[ + ExtractProductReviewSentimentTaskInput, + ExtractProductReviewSentimentTaskOutput, + ], +): task_input = ExtractProductReviewSentimentTaskInput(review_text="This product is amazing!") - run = await wai.run(task, task_input=task_input, use_cache="never") + run = await extract_product_review_sentiment_agent.run(task_input=task_input, use_cache="never") assert run.task_output.sentiment == Sentiment.POSITIVE -async def test_stream_task(wai: workflowai.Client): - task = ExtractProductReviewSentimentTask() - +async def test_stream_task( + extract_product_review_sentiment_agent: Agent[ + ExtractProductReviewSentimentTaskInput, + ExtractProductReviewSentimentTaskOutput, + ], +): task_input = ExtractProductReviewSentimentTaskInput( review_text="This product is amazing!", ) - streamed = await wai.run(task, task_input=task_input, stream=True, use_cache="never") + streamed = extract_product_review_sentiment_agent.stream( + task_input=task_input, + use_cache="never", + ) chunks = [chunk async for chunk in streamed] assert len(chunks) > 1 
diff --git a/tests/integration/run_test.py b/tests/integration/run_test.py index 3195da9..4e37ad5 100644 --- a/tests/integration/run_test.py +++ b/tests/integration/run_test.py @@ -19,18 +19,28 @@ class CityToCapitalTaskOutput(BaseModel): workflowai.init(api_key="test", url="http://localhost:8000") +_REGISTER_URL = "http://localhost:8000/v1/_/agents" -def _mock_response(httpx_mock: HTTPXMock, task_id: str = "city-to-capital"): + +def _mock_register(httpx_mock: HTTPXMock, schema_id: int = 1, task_id: str = "city-to-capital", variant_id: str = "1"): + httpx_mock.add_response( + method="POST", + url=_REGISTER_URL, + json={"schema_id": schema_id, "variant_id": variant_id, "id": task_id}, + ) + + +def _mock_response(httpx_mock: HTTPXMock, task_id: str = "city-to-capital", capital: str = "Tokyo"): httpx_mock.add_response( method="POST", - url=f"http://localhost:8000/v1/_/tasks/{task_id}/schemas/1/run", - json={"id": "123", "task_output": {"capital": "Tokyo"}}, + url=f"http://localhost:8000/v1/_/agents/{task_id}/schemas/1/run", + json={"id": "123", "task_output": {"capital": capital}}, ) def _mock_stream(httpx_mock: HTTPXMock, task_id: str = "city-to-capital"): httpx_mock.add_response( - url=f"http://localhost:8000/v1/_/tasks/{task_id}/schemas/1/run", + url=f"http://localhost:8000/v1/_/agents/{task_id}/schemas/1/run", stream=IteratorStream( [ b'data: {"id":"1","task_output":{"capital":""}}\n\n', @@ -43,7 +53,7 @@ def _mock_stream(httpx_mock: HTTPXMock, task_id: str = "city-to-capital"): def _check_request(request: Optional[Request], version: Any = "production", task_id: str = "city-to-capital"): assert request is not None - assert request.url == f"http://localhost:8000/v1/_/tasks/{task_id}/schemas/1/run" + assert request.url == f"http://localhost:8000/v1/_/agents/{task_id}/schemas/1/run" body = json.loads(request.content) assert body == { "task_input": {"city": "Hello"}, @@ -57,7 +67,7 @@ def _check_request(request: Optional[Request], version: Any = "production", task 
async def test_run_task(httpx_mock: HTTPXMock) -> None: - @workflowai.task(schema_id=1) + @workflowai.agent(schema_id=1) async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTaskOutput: ... _mock_response(httpx_mock) @@ -71,7 +81,7 @@ async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTa async def test_run_task_run(httpx_mock: HTTPXMock) -> None: - @workflowai.task(schema_id=1) + @workflowai.agent(schema_id=1) async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapitalTaskOutput]: ... _mock_response(httpx_mock) @@ -86,7 +96,7 @@ async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapit async def test_run_task_run_version(httpx_mock: HTTPXMock) -> None: - @workflowai.task(schema_id=1, version="staging") + @workflowai.agent(schema_id=1, version="staging") async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapitalTaskOutput]: ... _mock_response(httpx_mock) @@ -101,7 +111,7 @@ async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapit async def test_stream_task_run(httpx_mock: HTTPXMock) -> None: - @workflowai.task(schema_id=1) + @workflowai.agent(schema_id=1) def city_to_capital(task_input: CityToCapitalTaskInput) -> AsyncIterator[CityToCapitalTaskOutput]: ... _mock_stream(httpx_mock) @@ -118,7 +128,7 @@ def city_to_capital(task_input: CityToCapitalTaskInput) -> AsyncIterator[CityToC async def test_stream_task_run_custom_id(httpx_mock: HTTPXMock) -> None: - @workflowai.task(schema_id=1, task_id="custom-id") + @workflowai.agent(schema_id=1, id="custom-id") def city_to_capital(task_input: CityToCapitalTaskInput) -> AsyncIterator[CityToCapitalTaskOutput]: ... 
_mock_stream(httpx_mock, task_id="custom-id") @@ -132,3 +142,55 @@ def city_to_capital(task_input: CityToCapitalTaskInput) -> AsyncIterator[CityToC CityToCapitalTaskOutput(capital="Tokyo"), CityToCapitalTaskOutput(capital="Tokyo"), ] + + +async def test_auto_register(httpx_mock: HTTPXMock): + @workflowai.agent() + async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTaskOutput: ... + + _mock_register(httpx_mock) + + _mock_response(httpx_mock) + + res = await city_to_capital(CityToCapitalTaskInput(city="Hello")) + assert res.capital == "Tokyo" + + _mock_response(httpx_mock, capital="Paris") + # Run it a second time + res = await city_to_capital(CityToCapitalTaskInput(city="Hello"), use_cache="never") + assert res.capital == "Paris" + + req = httpx_mock.get_requests() + assert len(req) == 3 + assert req[0].url == _REGISTER_URL + + req_body = json.loads(req[0].read()) + assert req_body == { + "id": "city-to-capital", + "input_schema": { + "properties": { + "city": { + "title": "City", + "type": "string", + }, + }, + "required": [ + "city", + ], + "title": "CityToCapitalTaskInput", + "type": "object", + }, + "output_schema": { + "properties": { + "capital": { + "title": "Capital", + "type": "string", + }, + }, + "required": [ + "capital", + ], + "title": "CityToCapitalTaskOutput", + "type": "object", + }, + } diff --git a/tests/models/hello_task.py b/tests/models/hello_task.py index 5c42bdd..f76a101 100644 --- a/tests/models/hello_task.py +++ b/tests/models/hello_task.py @@ -1,7 +1,5 @@ from pydantic import BaseModel -from workflowai.core.domain.task import Task - class HelloTaskInput(BaseModel): name: str @@ -11,16 +9,6 @@ class HelloTaskOutput(BaseModel): message: str -class HelloTask(Task[HelloTaskInput, HelloTaskOutput]): - input_class: type[HelloTaskInput] = HelloTaskInput - output_class: type[HelloTaskOutput] = HelloTaskOutput - - class HelloTaskOutputNotOptional(HelloTaskOutput): message: str another_field: str - - -class 
HelloTaskNotOptional(Task[HelloTaskInput, HelloTaskOutputNotOptional]): - input_class: type[HelloTaskInput] = HelloTaskInput - output_class: type[HelloTaskOutputNotOptional] = HelloTaskOutputNotOptional diff --git a/workflowai/__init__.py b/workflowai/__init__.py index 6ea1660..adf66ec 100644 --- a/workflowai/__init__.py +++ b/workflowai/__init__.py @@ -1,11 +1,14 @@ import os from typing import Optional -from workflowai.core.client import Client as Client +from typing_extensions import deprecated + +from workflowai.core.client._types import TaskDecorator +from workflowai.core.client.client import WorkflowAI as WorkflowAI from workflowai.core.domain.cache_usage import CacheUsage as CacheUsage from workflowai.core.domain.errors import WorkflowAIError as WorkflowAIError +from workflowai.core.domain.model import Model as Model from workflowai.core.domain.run import Run as Run -from workflowai.core.domain.task import Task as Task from workflowai.core.domain.task_version import TaskVersion as TaskVersion from workflowai.core.domain.version_reference import ( VersionReference as VersionReference, @@ -16,10 +19,8 @@ def _build_client( endpoint: Optional[str] = None, api_key: Optional[str] = None, default_version: Optional[VersionReference] = None, -) -> Client: - from workflowai.core.client.client import WorkflowAIClient - - return WorkflowAIClient( +): + return WorkflowAI( endpoint=endpoint or os.getenv("WORKFLOWAI_API_URL"), api_key=api_key or os.getenv("WORKFLOWAI_API_KEY", ""), default_version=default_version, @@ -27,7 +28,10 @@ def _build_client( # By default the shared client is created using the default environment variables -shared_client: Client = _build_client() +shared_client: WorkflowAI = _build_client() + +# The default model to use when running agents without a deployment +DEFAULT_MODEL: Model = os.getenv("WORKFLOWAI_DEFAULT_MODEL", "gemini-1.5-pro-latest") def init(api_key: Optional[str] = None, url: Optional[str] = None, default_version: 
Optional[VersionReference] = None): @@ -47,11 +51,29 @@ def init(api_key: Optional[str] = None, url: Optional[str] = None, default_versi shared_client = _build_client(url, api_key, default_version) +@deprecated("Use `workflowai.agent` instead") def task( - schema_id: int, + schema_id: Optional[int] = None, task_id: Optional[str] = None, version: Optional[VersionReference] = None, ): - from workflowai.core.client._fn_utils import task_wrapper + from workflowai.core.client._fn_utils import agent_wrapper + + return agent_wrapper(lambda: shared_client.api, schema_id, task_id, version) - return task_wrapper(lambda: shared_client, schema_id, task_id, version) + +def agent( + id: Optional[str] = None, # noqa: A002 + schema_id: Optional[int] = None, + version: Optional[VersionReference] = None, + model: Optional[Model] = None, +) -> TaskDecorator: + from workflowai.core.client._fn_utils import agent_wrapper + + return agent_wrapper( + lambda: shared_client.api, + schema_id=schema_id, + agent_id=id, + version=version, + model=model, + ) diff --git a/workflowai/core/client/__init__.py b/workflowai/core/client/__init__.py index 50758d4..e69de29 100644 --- a/workflowai/core/client/__init__.py +++ b/workflowai/core/client/__init__.py @@ -1 +0,0 @@ -from ._types import Client as Client diff --git a/workflowai/core/client/_api_test.py b/workflowai/core/client/_api_test.py index d5aeaa5..78622c8 100644 --- a/workflowai/core/client/_api_test.py +++ b/workflowai/core/client/_api_test.py @@ -67,20 +67,20 @@ def client() -> APIClient: return APIClient(endpoint="https://blabla.com", api_key="test_api_key") -class TestInputModel(BaseModel): +class _TestInputModel(BaseModel): bla: str = "bla" -class TestOutputModel(BaseModel): +class _TestOutputModel(BaseModel): a: str class TestAPIClientStream: async def test_stream_404(self, httpx_mock: HTTPXMock, client: APIClient): - class TestInputModel(BaseModel): + class _TestInputModel(BaseModel): test_input: str - class 
TestOutputModel(BaseModel): + class _TestOutputModel(BaseModel): test_output: str httpx_mock.add_response(status_code=404) @@ -89,8 +89,8 @@ class TestOutputModel(BaseModel): async for _ in client.stream( method="GET", path="test_path", - data=TestInputModel(test_input="test"), - returns=TestOutputModel, + data=_TestInputModel(test_input="test"), + returns=_TestOutputModel, ): pass @@ -106,8 +106,8 @@ async def _stm(): async for chunk in client.stream( method="GET", path="test_path", - data=TestInputModel(), - returns=TestOutputModel, + data=_TestInputModel(), + returns=_TestOutputModel, ) ] @@ -115,7 +115,7 @@ async def _stm(): async def test_stream_with_single_chunk( self, - stream_fn: Callable[[], Awaitable[list[TestOutputModel]]], + stream_fn: Callable[[], Awaitable[list[_TestOutputModel]]], httpx_mock: HTTPXMock, ): httpx_mock.add_response( @@ -127,7 +127,7 @@ async def test_stream_with_single_chunk( ) chunks = await stream_fn() - assert chunks == [TestOutputModel(a="test")] + assert chunks == [_TestOutputModel(a="test")] @pytest.mark.parametrize( "streamed_chunks", @@ -144,7 +144,7 @@ async def test_stream_with_single_chunk( ) async def test_stream_with_multiple_chunks( self, - stream_fn: Callable[[], Awaitable[list[TestOutputModel]]], + stream_fn: Callable[[], Awaitable[list[_TestOutputModel]]], httpx_mock: HTTPXMock, streamed_chunks: list[bytes], ): @@ -153,7 +153,7 @@ async def test_stream_with_multiple_chunks( httpx_mock.add_response(stream=IteratorStream(streamed_chunks)) chunks = await stream_fn() - assert chunks == [TestOutputModel(a="test"), TestOutputModel(a="test2")] + assert chunks == [_TestOutputModel(a="test"), _TestOutputModel(a="test2")] class TestReadAndConnectError: @@ -162,7 +162,7 @@ async def test_get(self, httpx_mock: HTTPXMock, client: APIClient, exception: Ex httpx_mock.add_exception(exception) with pytest.raises(WorkflowAIError) as e: - await client.get(path="test_path", returns=TestOutputModel) + await client.get(path="test_path", 
returns=_TestOutputModel) assert e.value.error.code == "connection_error" @@ -171,7 +171,7 @@ async def test_post(self, httpx_mock: HTTPXMock, client: APIClient, exception: E httpx_mock.add_exception(exception) with pytest.raises(WorkflowAIError) as e: - await client.post(path="test_path", data=TestInputModel(), returns=TestOutputModel) + await client.post(path="test_path", data=_TestInputModel(), returns=_TestOutputModel) assert e.value.error.code == "connection_error" @@ -183,8 +183,8 @@ async def test_stream(self, httpx_mock: HTTPXMock, client: APIClient, exception: async for _ in client.stream( method="GET", path="test_path", - data=TestInputModel(), - returns=TestOutputModel, + data=_TestInputModel(), + returns=_TestOutputModel, ): pass diff --git a/workflowai/core/client/_fn_utils.py b/workflowai/core/client/_fn_utils.py index d8b754c..13accda 100644 --- a/workflowai/core/client/_fn_utils.py +++ b/workflowai/core/client/_fn_utils.py @@ -3,6 +3,7 @@ from typing import ( Any, AsyncIterator, + Generic, NamedTuple, Optional, Sequence, @@ -16,19 +17,18 @@ from pydantic import BaseModel from typing_extensions import Unpack +from workflowai.core.client._api import APIClient from workflowai.core.client._types import ( - Client, - FinalRunFn, - FinalRunFnOutputOnly, FinalRunTemplate, - FinalStreamRunFn, - FinalStreamRunFnOutputOnly, RunParams, RunTemplate, TaskDecorator, ) +from workflowai.core.client.agent import Agent +from workflowai.core.domain.model import Model from workflowai.core.domain.run import Run -from workflowai.core.domain.task import Task, TaskInput, TaskOutput +from workflowai.core.domain.task import TaskInput, TaskOutput +from workflowai.core.domain.version_properties import VersionProperties from workflowai.core.domain.version_reference import VersionReference # TODO: add sync support @@ -52,7 +52,7 @@ def check_return_type(return_type_hint: Type[Any]) -> tuple[bool, Type[BaseModel raise ValueError("Function must have a return type hint that is a 
subclass of Pydantic's 'BaseModel' or 'Run'") -class ExtractFnData(NamedTuple): +class RunFunctionSpec(NamedTuple): stream: bool output_only: bool input_cls: Type[BaseModel] @@ -66,7 +66,7 @@ def is_async_iterator(t: type[Any]) -> bool: return issubclass(ori, AsyncIterator) -def extract_fn_data(fn: RunTemplate[TaskInput, TaskOutput]) -> ExtractFnData: +def extract_fn_spec(fn: RunTemplate[TaskInput, TaskOutput]) -> RunFunctionSpec: hints = get_type_hints(fn) if "return" not in hints: raise ValueError("Function must have a return type hint") @@ -87,91 +87,79 @@ def extract_fn_data(fn: RunTemplate[TaskInput, TaskOutput]) -> ExtractFnData: stream = False output_only, output_cls = check_return_type(return_type_hint) - return ExtractFnData(stream, output_only, input_cls, output_cls) + return RunFunctionSpec(stream, output_only, input_cls, output_cls) -def _wrap_run(client: Callable[[], Client], task: Task[TaskInput, TaskOutput]) -> FinalRunFn[TaskInput, TaskOutput]: - async def wrap(task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]) -> Run[TaskOutput]: - return await client().run(task, task_input, stream=False, **kwargs) +class _RunnableAgent(Agent[TaskInput, TaskOutput], Generic[TaskInput, TaskOutput]): + async def __call__(self, task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]): + return await self.run(task_input, **kwargs) - return wrap +class _RunnableOutputOnlyAgent(Agent[TaskInput, TaskOutput], Generic[TaskInput, TaskOutput]): + async def __call__(self, task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]): + return (await self.run(task_input, **kwargs)).task_output -def _wrap_run_output_only( - client: Callable[[], Client], - task: Task[TaskInput, TaskOutput], -) -> FinalRunFnOutputOnly[TaskInput, TaskOutput]: - async def wrap(task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]) -> TaskOutput: - run = await client().run(task, task_input, stream=False, **kwargs) - return run.task_output - return wrap +class 
_RunnableStreamAgent(Agent[TaskInput, TaskOutput], Generic[TaskInput, TaskOutput]): + def __call__(self, task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]): + return self.stream(task_input, **kwargs) -def _wrap_stream_run( - client: Callable[[], Client], - task: Task[TaskInput, TaskOutput], -) -> FinalStreamRunFn[TaskInput, TaskOutput]: - async def wrap(task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]) -> AsyncIterator[Run[TaskOutput]]: - s = await client().run(task, task_input, stream=True, **kwargs) - async for chunk in s: - yield chunk - - return wrap - - -def _wrap_stream_run_output_only( - client: Callable[[], Client], - task: Task[TaskInput, TaskOutput], -) -> FinalStreamRunFnOutputOnly[TaskInput, TaskOutput]: - async def wrap(task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]) -> AsyncIterator[TaskOutput]: - s = await client().run(task, task_input, stream=True, **kwargs) - async for chunk in s: +class _RunnableStreamOutputOnlyAgent(Agent[TaskInput, TaskOutput], Generic[TaskInput, TaskOutput]): + async def __call__(self, task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]): + async for chunk in self.stream(task_input, **kwargs): yield chunk.task_output - # TODO: not sure what's going on here... - return wrap # pyright: ignore [reportReturnType] - def wrap_run_template( - client: Callable[[], Client], - task_id: str, - task_schema_id: int, + client: Callable[[], APIClient], + agent_id: str, + schema_id: Optional[int], version: Optional[VersionReference], + model: Optional[Model], fn: RunTemplate[TaskInput, TaskOutput], -): - stream, output_only, input_cls, output_cls = extract_fn_data(fn) - # There is some co / contravariant issue here... 
- task: Task[TaskInput, TaskOutput] = Task( # pyright: ignore [reportAssignmentType] - id=task_id, - schema_id=task_schema_id, - input_class=input_cls, - output_class=output_cls, - version=version, - ) +) -> Union[ + _RunnableAgent[TaskInput, TaskOutput], + _RunnableOutputOnlyAgent[TaskInput, TaskOutput], + _RunnableStreamAgent[TaskInput, TaskOutput], + _RunnableStreamOutputOnlyAgent[TaskInput, TaskOutput], +]: + stream, output_only, input_cls, output_cls = extract_fn_spec(fn) + + if not version and (fn.__doc__ or model): + version = VersionProperties( + instructions=fn.__doc__, + model=model, + ) if stream: - if output_only: - return _wrap_stream_run_output_only(client, task) - return _wrap_stream_run(client, task) - if output_only: - return _wrap_run_output_only(client, task) - return _wrap_run(client, task) + task_cls = _RunnableStreamOutputOnlyAgent if output_only else _RunnableStreamAgent + else: + task_cls = _RunnableOutputOnlyAgent if output_only else _RunnableAgent + return task_cls( # pyright: ignore [reportUnknownVariableType] + agent_id=agent_id, + input_cls=input_cls, + output_cls=output_cls, + api=client, + schema_id=schema_id, + version=version, + ) -def task_id_from_fn_name(fn: Any) -> str: +def agent_id_from_fn_name(fn: Any) -> str: return fn.__name__.replace("_", "-").lower() -def task_wrapper( - client: Callable[[], Client], - schema_id: int, - task_id: Optional[str] = None, +def agent_wrapper( + client: Callable[[], APIClient], + schema_id: Optional[int] = None, + agent_id: Optional[str] = None, version: Optional[VersionReference] = None, + model: Optional[Model] = None, ) -> TaskDecorator: def wrap(fn: RunTemplate[TaskInput, TaskOutput]) -> FinalRunTemplate[TaskInput, TaskOutput]: - tid = task_id or task_id_from_fn_name(fn) - return functools.wraps(fn)(wrap_run_template(client, tid, schema_id, version, fn)) # pyright: ignore [reportReturnType] + tid = agent_id or agent_id_from_fn_name(fn) + return functools.wraps(fn)(wrap_run_template(client, 
tid, schema_id, version, model, fn)) # pyright: ignore [reportReturnType] - # TODO: pyright is unhappy with generics + # pyright is unhappy with generics return wrap # pyright: ignore [reportReturnType] diff --git a/workflowai/core/client/_fn_utils_test.py b/workflowai/core/client/_fn_utils_test.py index 0f0c437..a5a7534 100644 --- a/workflowai/core/client/_fn_utils_test.py +++ b/workflowai/core/client/_fn_utils_test.py @@ -1,9 +1,23 @@ from typing import AsyncIterator +from unittest.mock import Mock +import pytest from pydantic import BaseModel from tests.models.hello_task import HelloTaskInput, HelloTaskOutput -from workflowai.core.client._fn_utils import extract_fn_data, get_generic_args, is_async_iterator +from tests.utils import mock_aiter +from workflowai.core.client._api import APIClient +from workflowai.core.client._fn_utils import ( + _RunnableAgent, # pyright: ignore [reportPrivateUsage] + _RunnableOutputOnlyAgent, # pyright: ignore [reportPrivateUsage] + _RunnableStreamAgent, # pyright: ignore [reportPrivateUsage] + _RunnableStreamOutputOnlyAgent, # pyright: ignore [reportPrivateUsage] + agent_wrapper, + extract_fn_spec, + get_generic_args, + is_async_iterator, +) +from workflowai.core.client._models import RunResponse from workflowai.core.domain.run import Run @@ -33,13 +47,69 @@ def test_is_async_iterator(self): class TestExtractFnData: def test_run_output_only(self): - assert extract_fn_data(say_hello) == (False, True, HelloTaskInput, HelloTaskOutput) + assert extract_fn_spec(say_hello) == (False, True, HelloTaskInput, HelloTaskOutput) def test_run(self): - assert extract_fn_data(say_hello_run) == (False, False, HelloTaskInput, HelloTaskOutput) + assert extract_fn_spec(say_hello_run) == (False, False, HelloTaskInput, HelloTaskOutput) def test_stream_output_only(self): - assert extract_fn_data(stream_hello) == (True, True, HelloTaskInput, HelloTaskOutput) + assert extract_fn_spec(stream_hello) == (True, True, HelloTaskInput, HelloTaskOutput) def 
class TestAgentWrapper:
    """Verify that agent_wrapper returns the right runnable type for each function
    signature, and that calling the wrapped function hits the API as expected."""

    @pytest.fixture
    def mock_api_client(self):
        return Mock(spec=APIClient)

    async def fn_run(self, task_input: HelloTaskInput) -> Run[HelloTaskOutput]: ...

    async def test_fn_run(self, mock_api_client: Mock):
        decorated = agent_wrapper(lambda: mock_api_client, schema_id=1, agent_id="hello")(self.fn_run)
        assert isinstance(decorated, _RunnableAgent)

        mock_api_client.post.return_value = RunResponse(id="1", task_output={"message": "Hello, World!"})
        result = await decorated(HelloTaskInput(name="World"))
        assert isinstance(result, Run)
        assert result.id == "1"
        assert result.task_output == HelloTaskOutput(message="Hello, World!")

    def fn_stream(self, task_input: HelloTaskInput) -> AsyncIterator[Run[HelloTaskOutput]]: ...

    async def test_fn_stream(self, mock_api_client: Mock):
        decorated = agent_wrapper(lambda: mock_api_client, schema_id=1, agent_id="hello")(self.fn_stream)
        assert isinstance(decorated, _RunnableStreamAgent)

        mock_api_client.stream.return_value = mock_aiter(RunResponse(id="1", task_output={"message": "Hello, World!"}))
        streamed = [item async for item in decorated(HelloTaskInput(name="World"))]
        assert len(streamed) == 1
        assert isinstance(streamed[0], Run)
        assert streamed[0].id == "1"
        assert streamed[0].task_output == HelloTaskOutput(message="Hello, World!")

    async def fn_run_output_only(self, task_input: HelloTaskInput) -> HelloTaskOutput: ...

    async def test_fn_run_output_only(self, mock_api_client: Mock):
        decorated = agent_wrapper(lambda: mock_api_client, schema_id=1, agent_id="hello")(self.fn_run_output_only)
        assert isinstance(decorated, _RunnableOutputOnlyAgent)

        mock_api_client.post.return_value = RunResponse(id="1", task_output={"message": "Hello, World!"})
        result = await decorated(HelloTaskInput(name="World"))
        # Output-only variant returns the bare output, not a Run
        assert isinstance(result, HelloTaskOutput)
        assert result == HelloTaskOutput(message="Hello, World!")

    def fn_stream_output_only(self, task_input: HelloTaskInput) -> AsyncIterator[HelloTaskOutput]: ...

    async def test_fn_stream_output_only(self, mock_api_client: Mock):
        decorated = agent_wrapper(lambda: mock_api_client, schema_id=1, agent_id="hello")(self.fn_stream_output_only)
        assert isinstance(decorated, _RunnableStreamOutputOnlyAgent)

        mock_api_client.stream.return_value = mock_aiter(RunResponse(id="1", task_output={"message": "Hello, World!"}))
        streamed = [item async for item in decorated(HelloTaskInput(name="World"))]
        assert len(streamed) == 1
        assert isinstance(streamed[0], HelloTaskOutput)
        assert streamed[0] == HelloTaskOutput(message="Hello, World!")
# Request body for POST /v1/_/agents. Field descriptions are part of the API
# schema, so they are left untouched.
class CreateAgentRequest(BaseModel):
    id: str = Field(description="The agent id, must be unique per tenant and URL safe")
    input_schema: dict[str, Any] = Field(description="The input schema for the agent")
    output_schema: dict[str, Any] = Field(description="The output schema for the agent")


# Response of the agent creation endpoint: the (possibly normalized) agent id
# and the schema id assigned by the backend.
class CreateAgentResponse(BaseModel):
    id: str
    schema_id: int
- - -class Client(Protocol): - """A client to interact with the WorkflowAI API""" - - @overload - async def run( - self, - task: Task[TaskInput, TaskOutput], - task_input: TaskInput, - stream: Literal[False] = False, - **kwargs: Unpack[RunParams[TaskOutput]], - ) -> Run[TaskOutput]: ... - - @overload - async def run( - self, - task: Task[TaskInput, TaskOutput], - task_input: TaskInput, - stream: Literal[True] = True, - **kwargs: Unpack[RunParams[TaskOutput]], - ) -> AsyncIterator[Run[TaskOutput]]: ... - - async def run( - self, - task: Task[TaskInput, TaskOutput], - task_input: TaskInput, - stream: bool = False, - **kwargs: Unpack[RunParams[TaskOutput]], - ) -> Union[Run[TaskOutput], AsyncIterator[Run[TaskOutput]]]: - """Run a task - - Args: - task (Task[TaskInput, TaskOutput]): the task to run - task_input (TaskInput): the input to the task - version (Optional[TaskVersionReference], optional): the version of the task to run. If not provided, - the version defined in the task is used. Defaults to None. - environment (Optional[str], optional): the environment to run the task in. If not provided, the environment - defined in the task is used. Defaults to None. - iteration (Optional[int], optional): the iteration of the task to run. If not provided, the iteration - defined in the task is used. Defaults to None. - stream (bool, optional): whether to stream the output. If True, the function returns an async iterator of - partial output objects. Defaults to False. - use_cache (CacheUsage, optional): how to use the cache. Defaults to "auto". - "auto" (default): if a previous run exists with the same version and input, and if - the temperature is 0, the cached output is returned - "always": the cached output is returned when available, regardless - of the temperature value - "never": the cache is never used - labels (Optional[set[str]], optional): a set of labels to attach to the run. - Labels are indexed and searchable. Defaults to None. 
def global_default_version_reference() -> VersionReference:
    """Resolve the default version from the WORKFLOWAI_DEFAULT_VERSION env var.

    Accepts an environment name ("dev", "staging", "production") or an iteration
    number; anything else is logged and falls back to "production".
    """
    raw = os.getenv("WORKFLOWAI_DEFAULT_VERSION")
    if not raw:
        return "production"

    if raw in ("dev", "staging", "production"):
        return raw  # pyright: ignore [reportReturnType]

    try:
        return int(raw)
    except ValueError:
        logger.warning("Invalid default version: %s", raw)
        return "production"
class Agent(Generic[TaskInput, TaskOutput]):
    """A client-side handle on a WorkflowAI agent.

    Wraps the API calls needed to register an agent (deriving its schemas from the
    input/output pydantic models) and to run it, either returning a completed run
    or streaming partial runs.
    """

    def __init__(
        self,
        agent_id: str,
        input_cls: type[TaskInput],
        output_cls: type[TaskOutput],
        api: Union[APIClient, Callable[[], APIClient]],
        schema_id: Optional[int] = None,
        version: Optional[VersionReference] = None,
    ):
        self.agent_id = agent_id
        # When None, the schema id is lazily resolved by register() on first run
        self.schema_id = schema_id
        self.input_cls = input_cls
        self.output_cls = output_cls
        self.version: VersionReference = version or global_default_version_reference()
        # Accept either a client instance or a factory, so the shared client can
        # be (re)configured after this agent object is created
        self._api = (lambda: api) if isinstance(api, APIClient) else api

    @property
    def api(self) -> APIClient:
        return self._api()

    class _PreparedRun(NamedTuple):
        # Everything needed to execute (and retry) a single run request
        request: RunRequest
        route: str
        should_retry: Callable[[], bool]
        wait_for_exception: Callable[[WorkflowAIError], Awaitable[None]]
        schema_id: int

    def _sanitize_version(self, version: Optional[VersionReference]) -> Union[str, int, dict[str, Any]]:
        """Normalize a version reference into its wire format.

        Environment names and iteration numbers pass through unchanged;
        VersionProperties are dumped to a dict, defaulting the model when unset.
        """
        if not version:
            version = self.version
        if not isinstance(version, VersionProperties):
            return version

        dumped = version.model_dump(by_alias=True)
        if not dumped.get("model"):
            # Imported locally to avoid a circular import at module load time
            import workflowai

            dumped["model"] = workflowai.DEFAULT_MODEL
        return dumped

    async def _prepare_run(self, task_input: TaskInput, stream: bool, **kwargs: Unpack[RunParams[TaskOutput]]):
        """Build the run request, route and retry helpers, registering the agent first if needed."""
        schema_id = self.schema_id
        if not schema_id:
            schema_id = await self.register()

        version = self._sanitize_version(kwargs.get("version"))

        request = RunRequest(
            task_input=task_input.model_dump(by_alias=True),
            version=version,
            stream=stream,
            use_cache=kwargs.get("use_cache"),
            metadata=kwargs.get("metadata"),
            labels=kwargs.get("labels"),
        )

        # Use the locally resolved schema id (not self.schema_id) so the route is
        # consistent with the _PreparedRun returned below
        route = f"/v1/_/agents/{self.agent_id}/schemas/{schema_id}/run"
        should_retry, wait_for_exception = build_retryable_wait(
            kwargs.get("max_retry_delay", 60),
            kwargs.get("max_retry_count", 1),
        )
        return self._PreparedRun(request, route, should_retry, wait_for_exception, schema_id)

    async def register(self) -> int:
        """Register the agent with the API, cache and return its schema id."""
        res = await self.api.post(
            "/v1/_/agents",
            CreateAgentRequest(
                id=self.agent_id,
                input_schema=self.input_cls.model_json_schema(),
                output_schema=self.output_cls.model_json_schema(),
            ),
            returns=CreateAgentResponse,
        )
        self.schema_id = res.schema_id
        return res.schema_id

    async def run(
        self,
        task_input: TaskInput,
        **kwargs: Unpack[RunParams[TaskOutput]],
    ) -> Run[TaskOutput]:
        """Run the agent and return the completed run.

        Args:
            task_input (TaskInput): the input to the agent
            version (Optional[VersionReference], optional): the version to run. Defaults to the
                agent's version (itself defaulting to the global default version).
            use_cache (CacheUsage, optional): how to use the cache. Defaults to "auto".
                "auto" (default): if a previous run exists with the same version and input, and if
                the temperature is 0, the cached output is returned
                "always": the cached output is returned when available, regardless
                of the temperature value
                "never": the cache is never used
            labels (Optional[set[str]], optional): a set of labels to attach to the run.
                Labels are indexed and searchable. Defaults to None.
            metadata (Optional[dict[str, Any]], optional): a dictionary of metadata to attach to the run.
                Defaults to None.
            max_retry_delay (float, optional): the maximum delay between retries. Defaults to 60.
                NOTE(review): forwarded straight to build_retryable_wait — confirm the unit is seconds.
            max_retry_count (int, optional): the maximum number of retry attempts. Defaults to 1.

        Returns:
            Run[TaskOutput]: the completed run.

        Raises:
            WorkflowAIError: if the run still fails once retries are exhausted.
        """
        prepared_run = await self._prepare_run(task_input, stream=False, **kwargs)
        validator = kwargs.get("validator") or self.output_cls.model_validate

        last_error = None
        while prepared_run.should_retry():
            try:
                res = await self.api.post(prepared_run.route, prepared_run.request, returns=RunResponse)
                return res.to_domain(self.agent_id, prepared_run.schema_id, validator)
            except WorkflowAIError as e:  # noqa: PERF203
                last_error = e
                await prepared_run.wait_for_exception(e)

        raise last_error or WorkflowAIError(error=BaseError(message="max retries reached"), response=None)

    async def stream(
        self,
        task_input: TaskInput,
        **kwargs: Unpack[RunParams[TaskOutput]],
    ):
        """Stream the output of the agent as partial runs.

        Args:
            task_input (TaskInput): the input to the agent
            version (Optional[VersionReference], optional): the version to run. Defaults to the
                agent's version (itself defaulting to the global default version).
            use_cache (CacheUsage, optional): how to use the cache. Defaults to "auto".
                "auto" (default): if a previous run exists with the same version and input, and if
                the temperature is 0, the cached output is returned
                "always": the cached output is returned when available, regardless
                of the temperature value
                "never": the cache is never used
            labels (Optional[set[str]], optional): a set of labels to attach to the run.
                Labels are indexed and searchable. Defaults to None.
            metadata (Optional[dict[str, Any]], optional): a dictionary of metadata to attach to the run.
                Defaults to None.
            max_retry_delay (float, optional): the maximum delay between retries. Defaults to 60.
                NOTE(review): forwarded straight to build_retryable_wait — confirm the unit is seconds.
            max_retry_count (int, optional): the maximum number of retry attempts. Defaults to 1.

        Returns:
            AsyncIterator[Run[TaskOutput]]: an async iterator of partial runs; the final
            item carries the full version/cost/duration metadata.
        """
        prepared_run = await self._prepare_run(task_input, stream=True, **kwargs)
        # Streamed chunks may be missing required output fields, so use the
        # tolerant validator by default
        validator = kwargs.get("validator") or tolerant_validator(self.output_cls)

        while prepared_run.should_retry():
            try:
                async for chunk in self.api.stream(
                    method="POST",
                    path=prepared_run.route,
                    data=prepared_run.request,
                    returns=RunResponse,
                ):
                    yield chunk.to_domain(self.agent_id, prepared_run.schema_id, validator)
                return
            except WorkflowAIError as e:  # noqa: PERF203
                await prepared_run.wait_for_exception(e)
        # NOTE(review): unlike run(), exhausting retries here ends the iterator
        # silently instead of raising — mirrors the previous client implementation.
class TestRun:
    """Exercises Agent.run / Agent.stream against a mocked HTTP transport."""

    async def test_success(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]):
        httpx_mock.add_response(json=fixtures_json("task_run.json"))

        result = await agent.run(task_input=HelloTaskInput(name="Alice"))

        assert result.id == "8f635b73-f403-47ee-bff9-18320616c6cc"
        assert result.task_id == "123"
        assert result.task_schema_id == 1

        sent = httpx_mock.get_requests()
        assert len(sent) == 1
        assert sent[0].url == "http://localhost:8000/v1/_/agents/123/schemas/1/run"

        payload = json.loads(sent[0].content)
        assert payload == {
            "task_input": {"name": "Alice"},
            "version": "production",
            "stream": False,
        }

    async def test_stream(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]):
        httpx_mock.add_response(
            stream=IteratorStream(
                [
                    b'data: {"id":"1","task_output":{"message":""}}\n\n',
                    b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n',  # noqa: E501
                    b'data: {"id":"1","task_output":{"message":"hello"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n',  # noqa: E501
                ],
            ),
        )

        received = [chunk async for chunk in agent.stream(task_input=HelloTaskInput(name="Alice"))]

        assert [c.task_output for c in received] == [
            HelloTaskOutput(message=""),
            HelloTaskOutput(message="hel"),
            HelloTaskOutput(message="hello"),
            HelloTaskOutput(message="hello"),
        ]
        # The final chunk carries the run metadata
        final = received[-1]
        assert isinstance(final, Run)
        assert final.version
        assert final.version.properties.model == "gpt-4o"
        assert final.version.properties.temperature == 0.5
        assert final.cost_usd == 0.01
        assert final.duration_seconds == 10.1

    async def test_stream_not_optional(
        self,
        httpx_mock: HTTPXMock,
        agent_not_optional: Agent[HelloTaskInput, HelloTaskOutputNotOptional],
    ):
        # Streaming must still work with non optional output fields: the first
        # chunks are missing a required key, only the last one has it
        httpx_mock.add_response(
            stream=IteratorStream(
                [
                    b'data: {"id":"1","task_output":{"message":""}}\n\n',
                    b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n',  # noqa: E501
                    b'data: {"id":"1","task_output":{"message":"hello", "another_field": "test"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n',  # noqa: E501
                ],
            ),
        )

        received = [chunk async for chunk in agent_not_optional.stream(task_input=HelloTaskInput(name="Alice"))]

        assert [c.task_output.message for c in received] == ["", "hel", "hello", "hello"]

        for partial in received[:-1]:
            with pytest.raises(AttributeError):
                # Partial chunks were built with model_construct, so the missing
                # non-optional field raises on access
                assert partial.task_output.another_field
        assert received[-1].task_output.another_field == "test"

        final = received[-1]
        assert isinstance(final, Run)
        assert final.version
        assert final.version.properties.model == "gpt-4o"
        assert final.version.properties.temperature == 0.5
        assert final.cost_usd == 0.01
        assert final.duration_seconds == 10.1

    async def test_run_with_env(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]):
        httpx_mock.add_response(json=fixtures_json("task_run.json"))

        await agent.run(
            task_input=HelloTaskInput(name="Alice"),
            version="dev",
        )

        sent = httpx_mock.get_requests()
        assert len(sent) == 1
        assert sent[0].url == "http://localhost:8000/v1/_/agents/123/schemas/1/run"

        payload = json.loads(sent[0].content)
        assert payload == {
            "task_input": {"name": "Alice"},
            "version": "dev",
            "stream": False,
        }

    async def test_success_with_headers(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]):
        httpx_mock.add_response(json=fixtures_json("task_run.json"))

        result = await agent.run(task_input=HelloTaskInput(name="Alice"))

        assert result.id == "8f635b73-f403-47ee-bff9-18320616c6cc"

        sent = httpx_mock.get_requests()
        assert len(sent) == 1
        assert sent[0].url == "http://localhost:8000/v1/_/agents/123/schemas/1/run"
        expected_headers = {
            "x-workflowai-source": "sdk",
            "x-workflowai-language": "python",
            "x-workflowai-version": importlib.metadata.version("workflowai"),
        }

        payload = json.loads(sent[0].content)
        assert payload == {
            "task_input": {"name": "Alice"},
            "version": "production",
            "stream": False,
        }
        # The SDK identification headers must be present on every request
        for key, value in expected_headers.items():
            assert sent[0].headers[key] == value

    async def test_run_retries_on_too_many_requests(
        self,
        httpx_mock: HTTPXMock,
        agent: Agent[HelloTaskInput, HelloTaskOutput],
    ):
        httpx_mock.add_response(
            url="http://localhost:8000/v1/_/agents/123/schemas/1/run",
            headers={"Retry-After": "0.01"},
            status_code=429,
        )
        httpx_mock.add_response(
            url="http://localhost:8000/v1/_/agents/123/schemas/1/run",
            json=fixtures_json("task_run.json"),
        )

        result = await agent.run(task_input=HelloTaskInput(name="Alice"), max_retry_count=5)

        assert result.id == "8f635b73-f403-47ee-bff9-18320616c6cc"

        assert len(httpx_mock.get_requests()) == 2

    async def test_run_retries_on_connection_error(
        self,
        httpx_mock: HTTPXMock,
        agent: Agent[HelloTaskInput, HelloTaskOutput],
    ):
        httpx_mock.add_exception(httpx.ConnectError("arg"))
        httpx_mock.add_exception(httpx.ConnectError("arg"))
        httpx_mock.add_response(json=fixtures_json("task_run.json"))

        result = await agent.run(task_input=HelloTaskInput(name="Alice"), max_retry_count=5)
        assert result.id == "8f635b73-f403-47ee-bff9-18320616c6cc"

    async def test_max_retries(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]):
        httpx_mock.add_exception(httpx.ConnectError("arg"), is_reusable=True)
        httpx_mock.add_exception(httpx.ConnectError("arg"), is_reusable=True)

        with pytest.raises(WorkflowAIError):
            await agent.run(task_input=HelloTaskInput(name="Alice"), max_retry_count=5)

        assert len(httpx_mock.get_requests()) == 5

    async def test_auto_register(self, httpx_mock: HTTPXMock, agent_no_schema: Agent[HelloTaskInput, HelloTaskOutput]):
        # First call must register the agent to discover its schema id
        httpx_mock.add_response(
            url="http://localhost:8000/v1/_/agents",
            json={
                "id": "123",
                "schema_id": 2,
            },
        )
        run_fixture = fixtures_json("task_run.json")
        httpx_mock.add_response(
            url="http://localhost:8000/v1/_/agents/123/schemas/2/run",
            json=run_fixture,
        )

        first = await agent_no_schema.run(task_input=HelloTaskInput(name="Alice"))
        assert first.id == "8f635b73-f403-47ee-bff9-18320616c6cc"

        run_fixture["id"] = "8f635b73-f403-47ee-bff9-18320616c6cc"
        # A second run reuses the cached schema id, no extra registration call
        httpx_mock.add_response(
            url="http://localhost:8000/v1/_/agents/123/schemas/2/run",
            json=run_fixture,
        )
        second = await agent_no_schema.run(task_input=HelloTaskInput(name="Alice"))
        assert second.id == "8f635b73-f403-47ee-bff9-18320616c6cc"

        sent = httpx_mock.get_requests()
        assert len(sent) == 3
        assert sent[0].url == "http://localhost:8000/v1/_/agents"
        assert sent[1].url == "http://localhost:8000/v1/_/agents/123/schemas/2/run"
        assert sent[2].url == "http://localhost:8000/v1/_/agents/123/schemas/2/run"
workflowai.core.client._fn_utils import task_wrapper -from workflowai.core.client._models import ( - RunRequest, - RunResponse, -) -from workflowai.core.client._types import ( - OutputValidator, - RunParams, -) -from workflowai.core.client._utils import build_retryable_wait, tolerant_validator -from workflowai.core.domain.errors import BaseError, WorkflowAIError -from workflowai.core.domain.run import Run -from workflowai.core.domain.task import Task, TaskInput, TaskOutput +from workflowai.core.client._fn_utils import agent_wrapper +from workflowai.core.client._utils import global_default_version_reference from workflowai.core.domain.version_reference import VersionReference -_logger = logging.getLogger("WorkflowAI") - - -def _compute_default_version_reference() -> VersionReference: - version = os.getenv("WORKFLOWAI_DEFAULT_VERSION") - if not version: - return "production" - - if version in {"dev", "staging", "production"}: - return version # pyright: ignore [reportReturnType] - - try: - return int(version) - except ValueError: - pass - - _logger.warning("Invalid default version: %s", version) - return "production" - - -class WorkflowAIClient: +class WorkflowAI: def __init__( self, api_key: str, @@ -62,112 +22,7 @@ def __init__( "x-workflowai-version": importlib.metadata.version("workflowai"), } self.api = APIClient(endpoint or "https://run.workflowai.com", api_key, self.additional_headers) - self.default_version: VersionReference = default_version or _compute_default_version_reference() - - @overload - async def run( - self, - task: Task[TaskInput, TaskOutput], - task_input: TaskInput, - stream: Literal[False] = False, - **kwargs: Unpack[RunParams[TaskOutput]], - ) -> Run[TaskOutput]: ... - - @overload - async def run( - self, - task: Task[TaskInput, TaskOutput], - task_input: TaskInput, - stream: Literal[True] = True, - **kwargs: Unpack[RunParams[TaskOutput]], - ) -> AsyncIterator[Run[TaskOutput]]: ... 
- - async def run( - self, - task: Task[TaskInput, TaskOutput], - task_input: TaskInput, - stream: bool = False, - **kwargs: Unpack[RunParams[TaskOutput]], - ) -> Union[Run[TaskOutput], AsyncIterator[Run[TaskOutput]]]: - request = RunRequest( - task_input=task_input.model_dump(by_alias=True), - version=kwargs.get("version") or task.version or self.default_version, - stream=stream, - use_cache=kwargs.get("use_cache"), - metadata=kwargs.get("metadata"), - labels=kwargs.get("labels"), - ) - - route = f"/v1/_/tasks/{task.id}/schemas/{task.schema_id}/run" - should_retry, wait_for_exception = build_retryable_wait( - kwargs.get("max_retry_delay", 60), - kwargs.get("max_retry_count", 1), - ) - - if not stream: - return await self._retriable_run( - route, - request, - should_retry=should_retry, - wait_for_exception=wait_for_exception, - task_id=task.id, - task_schema_id=task.schema_id, - validator=kwargs.get("validator") or task.output_class.model_validate, - ) - - return self._retriable_stream( - route, - request, - should_retry=should_retry, - wait_for_exception=wait_for_exception, - task_id=task.id, - task_schema_id=task.schema_id, - validator=kwargs.get("validator") or tolerant_validator(task.output_class), - ) - - async def _retriable_run( - self, - route: str, - request: RunRequest, - should_retry: Callable[[], bool], - wait_for_exception: Callable[[WorkflowAIError], Awaitable[None]], - task_id: str, - task_schema_id: int, - validator: OutputValidator[TaskOutput], - ): - last_error = None - while should_retry(): - try: - res = await self.api.post(route, request, returns=RunResponse) - return res.to_domain(task_id, task_schema_id, validator) - except WorkflowAIError as e: # noqa: PERF203 - last_error = e - await wait_for_exception(e) - - raise last_error or WorkflowAIError(error=BaseError(message="max retries reached"), response=None) - - async def _retriable_stream( - self, - route: str, - request: RunRequest, - should_retry: Callable[[], bool], - wait_for_exception: 
Callable[[WorkflowAIError], Awaitable[None]], - task_id: str, - task_schema_id: int, - validator: OutputValidator[TaskOutput], - ): - while should_retry(): - try: - async for chunk in self.api.stream( - method="POST", - path=route, - data=request, - returns=RunResponse, - ): - yield chunk.to_domain(task_id, task_schema_id, validator) - return - except WorkflowAIError as e: # noqa: PERF203 - await wait_for_exception(e) + self.default_version: VersionReference = default_version or global_default_version_reference() def task( self, @@ -175,4 +30,12 @@ def task( task_id: Optional[str] = None, version: Optional[VersionReference] = None, ): - return task_wrapper(lambda: self, schema_id, task_id=task_id, version=version) + return agent_wrapper(lambda: self.api, schema_id, agent_id=task_id, version=version) + + def agent( + self, + id: Optional[str] = None, # noqa: A002 + schema_id: Optional[int] = None, + version: Optional[VersionReference] = None, + ): + return agent_wrapper(lambda: self.api, schema_id=schema_id, agent_id=id, version=version) diff --git a/workflowai/core/client/client_test.py b/workflowai/core/client/client_test.py index 023aa36..f268340 100644 --- a/workflowai/core/client/client_test.py +++ b/workflowai/core/client/client_test.py @@ -1,225 +1,44 @@ -import importlib.metadata -import json -from typing import Any, AsyncIterator -from unittest.mock import AsyncMock, patch +from collections.abc import AsyncIterator +from typing import Any +from unittest.mock import Mock, patch -import httpx import pytest from pytest_httpx import HTTPXMock, IteratorStream -from tests.models.hello_task import HelloTask, HelloTaskInput, HelloTaskNotOptional, HelloTaskOutput -from tests.utils import fixtures_json -from workflowai.core.client import Client -from workflowai.core.client.client import ( - WorkflowAIClient, - _compute_default_version_reference, # pyright: ignore [reportPrivateUsage] -) -from workflowai.core.domain.errors import WorkflowAIError +from 
tests.models.hello_task import HelloTaskInput, HelloTaskOutput +from workflowai.core.client.client import WorkflowAI from workflowai.core.domain.run import Run -@pytest.fixture -def client(): - return WorkflowAIClient(endpoint="http://localhost:8000", api_key="test") - - -class TestRun: - async def test_success(self, httpx_mock: HTTPXMock, client: Client): - httpx_mock.add_response(json=fixtures_json("task_run.json")) - task = HelloTask(id="123", schema_id=1) - - task_run = await client.run(task, task_input=HelloTaskInput(name="Alice")) - - assert task_run.id == "8f635b73-f403-47ee-bff9-18320616c6cc" - assert task_run.task_id == "123" - assert task_run.task_schema_id == 1 - - reqs = httpx_mock.get_requests() - assert len(reqs) == 1 - assert reqs[0].url == "http://localhost:8000/v1/_/tasks/123/schemas/1/run" - - body = json.loads(reqs[0].content) - assert body == { - "task_input": {"name": "Alice"}, - "version": "production", - "stream": False, - } - - async def test_stream(self, httpx_mock: HTTPXMock, client: Client): - httpx_mock.add_response( - stream=IteratorStream( - [ - b'data: {"id":"1","task_output":{"message":""}}\n\n', - b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n', # noqa: E501 - b'data: {"id":"1","task_output":{"message":"hello"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n', # noqa: E501 - ], - ), - ) - task = HelloTask(id="123", schema_id=1) - - streamed = await client.run( - task, - task_input=HelloTaskInput(name="Alice"), - stream=True, - ) - chunks = [chunk async for chunk in streamed] - - outputs = [chunk.task_output for chunk in chunks] - assert outputs == [ - HelloTaskOutput(message=""), - HelloTaskOutput(message="hel"), - HelloTaskOutput(message="hello"), - HelloTaskOutput(message="hello"), - ] - last_message = chunks[-1] - assert isinstance(last_message, Run) - assert last_message.version - assert 
last_message.version.properties.model == "gpt-4o" - assert last_message.version.properties.temperature == 0.5 - assert last_message.cost_usd == 0.01 - assert last_message.duration_seconds == 10.1 - - async def test_stream_not_optional(self, httpx_mock: HTTPXMock, client: Client): - # Checking that streaming works even with non optional fields - # The first two chunks are missing a required key but the last one has it - httpx_mock.add_response( - stream=IteratorStream( - [ - b'data: {"id":"1","task_output":{"message":""}}\n\n', - b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n', # noqa: E501 - b'data: {"id":"1","task_output":{"message":"hello", "another_field": "test"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n', # noqa: E501 - ], - ), - ) - task = HelloTaskNotOptional(id="123", schema_id=1) - - streamed = await client.run( - task, - task_input=HelloTaskInput(name="Alice"), - stream=True, - ) - chunks = [chunk async for chunk in streamed] - - messages = [chunk.task_output.message for chunk in chunks] - assert messages == ["", "hel", "hello", "hello"] - - for chunk in chunks[:-1]: - with pytest.raises(AttributeError): - # Since the field is not optional, it will raise an attribute error - assert chunk.task_output.another_field - assert chunks[-1].task_output.another_field == "test" - - last_message = chunks[-1] - assert isinstance(last_message, Run) - assert last_message.version - assert last_message.version.properties.model == "gpt-4o" - assert last_message.version.properties.temperature == 0.5 - assert last_message.cost_usd == 0.01 - assert last_message.duration_seconds == 10.1 - - async def test_run_with_env(self, httpx_mock: HTTPXMock, client: Client): - httpx_mock.add_response(json=fixtures_json("task_run.json")) - task = HelloTask(id="123", schema_id=1) - - await client.run( - task, - task_input=HelloTaskInput(name="Alice"), - 
version="dev", - ) - - reqs = httpx_mock.get_requests() - assert len(reqs) == 1 - assert reqs[0].url == "http://localhost:8000/v1/_/tasks/123/schemas/1/run" - - body = json.loads(reqs[0].content) - assert body == { - "task_input": {"name": "Alice"}, - "version": "dev", - "stream": False, - } - - async def test_success_with_headers(self, httpx_mock: HTTPXMock, client: Client): - httpx_mock.add_response(json=fixtures_json("task_run.json")) - task = HelloTask(id="123", schema_id=1) - - task_run = await client.run(task, task_input=HelloTaskInput(name="Alice")) - - assert task_run.id == "8f635b73-f403-47ee-bff9-18320616c6cc" - - reqs = httpx_mock.get_requests() - assert len(reqs) == 1 - assert reqs[0].url == "http://localhost:8000/v1/_/tasks/123/schemas/1/run" - headers = { - "x-workflowai-source": "sdk", - "x-workflowai-language": "python", - "x-workflowai-version": importlib.metadata.version("workflowai"), - } - - body = json.loads(reqs[0].content) - assert body == { - "task_input": {"name": "Alice"}, - "version": "production", - "stream": False, - } - # Check for additional headers - for key, value in headers.items(): - assert reqs[0].headers[key] == value - - async def test_run_retries_on_too_many_requests(self, httpx_mock: HTTPXMock, client: Client): - task = HelloTask(id="123", schema_id=1) - - httpx_mock.add_response(headers={"Retry-After": "0.01"}, status_code=429) - httpx_mock.add_response(json=fixtures_json("task_run.json")) - - task_run = await client.run(task, task_input=HelloTaskInput(name="Alice"), max_retry_count=5) - - assert task_run.id == "8f635b73-f403-47ee-bff9-18320616c6cc" - - reqs = httpx_mock.get_requests() - assert len(reqs) == 2 - assert reqs[0].url == "http://localhost:8000/v1/_/tasks/123/schemas/1/run" - assert reqs[1].url == "http://localhost:8000/v1/_/tasks/123/schemas/1/run" - - async def test_run_retries_on_connection_error(self, httpx_mock: HTTPXMock, client: Client): - task = HelloTask(id="123", schema_id=1) - - 
httpx_mock.add_exception(httpx.ConnectError("arg")) - httpx_mock.add_response(json=fixtures_json("task_run.json")) - - task_run = await client.run(task, task_input=HelloTaskInput(name="Alice"), max_retry_count=5) - assert task_run.id == "8f635b73-f403-47ee-bff9-18320616c6cc" - - async def test_max_retries(self, httpx_mock: HTTPXMock, client: Client): - task = HelloTask(id="123", schema_id=1) - - httpx_mock.add_exception(httpx.ConnectError("arg"), is_reusable=True) - - with pytest.raises(WorkflowAIError): - await client.run(task, task_input=HelloTaskInput(name="Alice"), max_retry_count=5) - - reqs = httpx_mock.get_requests() - assert len(reqs) == 5 +class TestAgentDecorator: + @pytest.fixture + def workflowai(self): + # using httpx_mock to make sure we don't actually call the api + return WorkflowAI(api_key="test") + @pytest.fixture + def mock_run_fn(self): + with patch("workflowai.core.client.agent.Agent.run", autospec=True) as run_mock: + yield run_mock -class TestTask: @pytest.fixture - def patched_run_fn(self, client: Client): - with patch.object(client, "run", spec=client.run) as run_mock: + def mock_stream_fn(self): + with patch("workflowai.core.client.agent.Agent.stream", autospec=True) as run_mock: yield run_mock - def test_fn_name(self, client: Client): - @client.task(schema_id=1, task_id="123") + def test_fn_name(self, workflowai: WorkflowAI): + @workflowai.task(schema_id=1, task_id="123") async def fn(task_input: HelloTaskInput) -> HelloTaskOutput: ... assert fn.__name__ == "fn" assert fn.__doc__ is None assert callable(fn) - async def test_run_output_only(self, client: Client, patched_run_fn: AsyncMock): - @client.task(schema_id=1, task_id="123") + async def test_run_output_only(self, workflowai: WorkflowAI, mock_run_fn: Mock): + @workflowai.task(schema_id=1, task_id="123") async def fn(task_input: HelloTaskInput) -> HelloTaskOutput: ... 
- patched_run_fn.return_value = Run( + mock_run_fn.return_value = Run( task_output=HelloTaskOutput(message="hello"), task_id="123", task_schema_id=1, @@ -229,11 +48,11 @@ async def fn(task_input: HelloTaskInput) -> HelloTaskOutput: ... assert output == HelloTaskOutput(message="hello") - async def test_run_with_version(self, client: Client, patched_run_fn: AsyncMock): - @client.task(schema_id=1, task_id="123") + async def test_run_with_version(self, workflowai: WorkflowAI, mock_run_fn: Mock): + @workflowai.task(schema_id=1, task_id="123") async def fn(task_input: HelloTaskInput) -> Run[HelloTaskOutput]: ... - patched_run_fn.return_value = Run( + mock_run_fn.return_value = Run( id="1", task_output=HelloTaskOutput(message="hello"), task_id="123", @@ -246,12 +65,12 @@ async def fn(task_input: HelloTaskInput) -> Run[HelloTaskOutput]: ... assert output.task_output == HelloTaskOutput(message="hello") assert isinstance(output, Run) - async def test_stream(self, client: Client, httpx_mock: HTTPXMock): + async def test_stream(self, workflowai: WorkflowAI, httpx_mock: HTTPXMock): # We avoid mocking the run fn directly here, python does weird things with # having to await async iterators depending on how they are defined so instead we mock # the underlying api call to check that we don't need the extra await - @client.task(schema_id=1, task_id="123") + @workflowai.task(schema_id=1, task_id="123") def fn(task_input: HelloTaskInput) -> AsyncIterator[Run[HelloTaskOutput]]: ... httpx_mock.add_response( @@ -276,8 +95,8 @@ def _run(output: HelloTaskOutput, **kwargs: Any) -> Run[HelloTaskOutput]: _run(HelloTaskOutput(message="hello"), duration_seconds=10.1, cost_usd=0.01), ] - async def test_stream_output_only(self, client: Client, httpx_mock: HTTPXMock): - @client.task(schema_id=1) + async def test_stream_output_only(self, workflowai: WorkflowAI, httpx_mock: HTTPXMock): + @workflowai.task(schema_id=1) def fn(task_input: HelloTaskInput) -> AsyncIterator[HelloTaskOutput]: ... 
httpx_mock.add_response( @@ -300,12 +119,3 @@ def fn(task_input: HelloTaskInput) -> AsyncIterator[HelloTaskOutput]: ... HelloTaskOutput(message="hello"), HelloTaskOutput(message="hello"), ] - - -@pytest.mark.parametrize( - ("env_var", "expected"), - [("p", "production"), ("production", "production"), ("dev", "dev"), ("staging", "staging"), ("1", 1)], -) -def test_compute_default_version_reference(env_var: str, expected: Any): - with patch.dict("os.environ", {"WORKFLOWAI_DEFAULT_VERSION": env_var}): - assert _compute_default_version_reference() == expected diff --git a/workflowai/core/domain/model.py b/workflowai/core/domain/model.py new file mode 100644 index 0000000..56d8e7a --- /dev/null +++ b/workflowai/core/domain/model.py @@ -0,0 +1,105 @@ +from typing import Literal, Union + +Model = Union[ + Literal[ + # -------------------------------------------------------------------------- + # OpenAI Models + # -------------------------------------------------------------------------- + "gpt-4o-latest", + "gpt-4o-2024-11-20", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "gpt-4o-mini-latest", + "gpt-4o-mini-2024-07-18", + "o1-2024-12-17-high", + "o1-2024-12-17", + "o1-2024-12-17-low", + "o1-preview-2024-09-12", + "o1-mini-latest", + "o1-mini-2024-09-12", + "gpt-4o-audio-preview-2024-12-17", + "gpt-4o-audio-preview-2024-10-01", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-1106-preview", + "gpt-4-1106-vision-preview", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-1106", + # -------------------------------------------------------------------------- + # Gemini Models + # -------------------------------------------------------------------------- + "gemini-2.0-flash-exp", + "gemini-2.0-flash-thinking-exp-1219", + "gemini-1.5-pro-latest", + "gemini-1.5-pro-002", + "gemini-1.5-pro-001", + "gemini-1.5-pro-preview-0514", + "gemini-1.5-pro-preview-0409", + "gemini-1.5-flash-latest", + "gemini-1.5-flash-002", + "gemini-1.5-flash-001", + "gemini-1.5-flash-8b", + 
"gemini-1.5-flash-preview-0514", + "gemini-exp-1206", + "gemini-exp-1121", + "gemini-1.0-pro-002", + "gemini-1.0-pro-001", + "gemini-1.0-pro-vision-001", + # -------------------------------------------------------------------------- + # Claude Models + # -------------------------------------------------------------------------- + "claude-3-5-sonnet-latest", + "claude-3-5-sonnet-20241022", + "claude-3-5-sonnet-20240620", + "claude-3-5-haiku-latest", + "claude-3-5-haiku-20241022", + "claude-3-opus-20240229", + "claude-3-sonnet-20240229", + "claude-3-haiku-20240307", + # -------------------------------------------------------------------------- + # Llama Models + # -------------------------------------------------------------------------- + "llama-3.3-70b", + "llama-3.2-90b", + "llama-3.2-11b", + "llama-3.2-11b-vision", + "llama-3.2-3b", + "llama-3.2-1b", + "llama-3.2-90b-vision-preview", + "llama-3.2-90b-text-preview", + "llama-3.2-11b-text-preview", + "llama-3.2-3b-preview", + "llama-3.2-1b-preview", + "llama-3.1-405b", + "llama-3.1-70b", + "llama-3.1-8b", + "llama3-70b-8192", + "llama3-8b-8192", + # -------------------------------------------------------------------------- + # Mistral AI Models + # -------------------------------------------------------------------------- + "mixtral-8x7b-32768", + "mistral-large-2-latest", + "mistral-large-2-2407", + "mistral-large-latest", + "mistral-large-2411", + "pixtral-large-latest", + "pixtral-large-2411", + "pixtral-12b-2409", + "ministral-3b-2410", + "ministral-8b-2410", + "mistral-small-2409", + "codestral-mamba-2407", + # -------------------------------------------------------------------------- + # Qwen Models + # -------------------------------------------------------------------------- + "qwen-v3p2-32b-instruct", + # -------------------------------------------------------------------------- + # DeepSeek Models + # -------------------------------------------------------------------------- + "deepseek-v3-2412", + 
"deepseek-r1-2501", + ], + # Adding string to allow for any model not currently in the SDK but supported by the API + str, +] diff --git a/workflowai/core/domain/task.py b/workflowai/core/domain/task.py index 300574b..285cba9 100644 --- a/workflowai/core/domain/task.py +++ b/workflowai/core/domain/task.py @@ -1,25 +1,6 @@ -from typing import Generic, Optional, TypeVar +from typing import TypeVar from pydantic import BaseModel -from workflowai.core.domain.version_reference import VersionReference - TaskInput = TypeVar("TaskInput", bound=BaseModel) TaskOutput = TypeVar("TaskOutput", bound=BaseModel) - - -class Task(BaseModel, Generic[TaskInput, TaskOutput]): - """ - A blueprint for a task. Used to instantiate task runs. - - It should not be used as is but subclassed to provide the necessary information for the task. - Default values are provided so that they can be overriden in subclasses - """ - - id: str = "" - schema_id: int = 0 - - version: Optional[VersionReference] = None - - input_class: type[TaskInput] = BaseModel # pyright: ignore [reportAssignmentType] - output_class: type[TaskOutput] = BaseModel # pyright: ignore [reportAssignmentType] diff --git a/workflowai/core/domain/task_version.py b/workflowai/core/domain/task_version.py index 3be4b01..a930923 100644 --- a/workflowai/core/domain/task_version.py +++ b/workflowai/core/domain/task_version.py @@ -1,10 +1,10 @@ from pydantic import BaseModel, Field -from workflowai.core.domain.task_version_properties import TaskVersionProperties +from workflowai.core.domain.version_properties import VersionProperties class TaskVersion(BaseModel): - properties: TaskVersionProperties = Field( - default_factory=TaskVersionProperties, + properties: VersionProperties = Field( + default_factory=VersionProperties, description="The properties used for executing the run.", ) diff --git a/workflowai/core/domain/task_version_properties.py b/workflowai/core/domain/version_properties.py similarity index 90% rename from 
workflowai/core/domain/task_version_properties.py rename to workflowai/core/domain/version_properties.py index 85a37ef..6cb6570 100644 --- a/workflowai/core/domain/task_version_properties.py +++ b/workflowai/core/domain/version_properties.py @@ -2,15 +2,17 @@ from pydantic import BaseModel, ConfigDict, Field +from workflowai.core.domain.model import Model -class TaskVersionProperties(BaseModel): + +class VersionProperties(BaseModel): """Properties that described a way a task run was executed. Although some keys are provided as an example, any key:value are accepted""" # Allow extra fields to support custom options model_config = ConfigDict(extra="allow") - model: Optional[str] = Field( + model: Optional[Model] = Field( default=None, description="The LLM model used for the run", ) diff --git a/workflowai/core/domain/version_properties_test.py b/workflowai/core/domain/version_properties_test.py new file mode 100644 index 0000000..b49c04f --- /dev/null +++ b/workflowai/core/domain/version_properties_test.py @@ -0,0 +1,18 @@ +from typing import Any + +import pytest + +from workflowai.core.domain.version_properties import VersionProperties + + +@pytest.mark.parametrize( + "payload", + [ + {"model": "gpt-4o-latest"}, + {"model": "gpt-4o-latest", "provider": "openai"}, + {"model": "whatever"}, + ], +) +def test_version_properties_validate(payload: dict[str, Any]): + # Check that we don't raise an error + assert VersionProperties.model_validate(payload) diff --git a/workflowai/core/domain/version_reference.py b/workflowai/core/domain/version_reference.py index 8faf26b..34a70fe 100644 --- a/workflowai/core/domain/version_reference.py +++ b/workflowai/core/domain/version_reference.py @@ -1,5 +1,7 @@ from typing import Literal, Union +from workflowai.core.domain.version_properties import VersionProperties + VersionEnvironment = Literal["dev", "staging", "production"] -VersionReference = Union[int, VersionEnvironment] +VersionReference = Union[int, VersionEnvironment, 
VersionProperties] diff --git a/workflowai/core/logger.py b/workflowai/core/logger.py new file mode 100644 index 0000000..9b4512e --- /dev/null +++ b/workflowai/core/logger.py @@ -0,0 +1,3 @@ +import logging + +logger = logging.getLogger("WorkflowAI")