From 02bd32ce66cdc036beadd2a5af20d80e32ef69e0 Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Wed, 11 Jun 2025 11:52:22 -0400 Subject: [PATCH 1/2] feat: add use fallback parameter --- workflowai/core/_common_types.py | 3 +++ workflowai/core/client/_models.py | 2 ++ workflowai/core/client/agent.py | 1 + workflowai/core/domain/model.py | 37 ++++++++++++++++++++-------- workflowai/core/domain/model_test.py | 9 +++++-- 5 files changed, 40 insertions(+), 12 deletions(-) diff --git a/workflowai/core/_common_types.py b/workflowai/core/_common_types.py index cde0fa2..6ac9e70 100644 --- a/workflowai/core/_common_types.py +++ b/workflowai/core/_common_types.py @@ -2,9 +2,11 @@ Annotated, Any, Generic, + Literal, Optional, Protocol, TypeVar, + Union, ) from pydantic import BaseModel @@ -42,6 +44,7 @@ class VersionRunParams(TypedDict): class OtherRunParams(TypedDict): use_cache: NotRequired["CacheUsage"] + use_fallback: NotRequired[Union[Literal["auto", "never"], list[str]]] max_retry_delay: NotRequired[float] max_retry_count: NotRequired[float] diff --git a/workflowai/core/client/_models.py b/workflowai/core/client/_models.py index 6da5a3e..92a466c 100644 --- a/workflowai/core/client/_models.py +++ b/workflowai/core/client/_models.py @@ -28,6 +28,8 @@ class RunRequest(BaseModel): use_cache: Optional[CacheUsage] = None + use_fallback: Optional[Union[Literal["auto", "never"], list[str]]] = None + metadata: Optional[dict[str, Any]] = None labels: Optional[set[str]] = None # deprecated, to be included in metadata diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index 07b321d..71ed0d1 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -209,6 +209,7 @@ async def _prepare_run(self, agent_input: AgentInput, stream: bool, **kwargs: Un version=version, stream=stream, use_cache=self._get_run_param("use_cache", kwargs), + use_fallback=self._get_run_param("use_fallback", kwargs), metadata=kwargs.get("metadata"), ) diff --git a/workflowai/core/domain/model.py b/workflowai/core/domain/model.py index 48cd18b..dc92db2 100644 --- a/workflowai/core/domain/model.py +++ b/workflowai/core/domain/model.py @@ -12,22 +12,13 @@ # higher, comment out the line where it should be in "natural" order, and add another one wherever # needed for the order class Model(str, Enum): - # -------------------------------------------------------------------------- - # OpenAI Models - # -------------------------------------------------------------------------- - - GPT_41_LATEST = "gpt-4.1-latest" - GPT_41_2025_04_14 = "gpt-4.1-2025-04-14" - GPT_41_MINI_LATEST = "gpt-4.1-mini-latest" - GPT_41_MINI_2025_04_14 = "gpt-4.1-mini-2025-04-14" - GPT_41_NANO_LATEST = "gpt-4.1-nano-latest" - GPT_41_NANO_2025_04_14 = "gpt-4.1-nano-2025-04-14" GPT_4O_LATEST = "gpt-4o-latest" GPT_4O_2024_11_20 = "gpt-4o-2024-11-20" GPT_4O_2024_08_06 = "gpt-4o-2024-08-06" GPT_4O_2024_05_13 = "gpt-4o-2024-05-13" GPT_4O_MINI_LATEST = "gpt-4o-mini-latest" GPT_4O_MINI_2024_07_18 = "gpt-4o-mini-2024-07-18" + GPT_IMAGE_1 = "gpt-image-1" O3_LATEST_HIGH_REASONING_EFFORT = "o3-latest-high" O3_LATEST_MEDIUM_REASONING_EFFORT = "o3-latest-medium" O3_LATEST_LOW_REASONING_EFFORT = "o3-latest-low" @@ -52,7 +43,14 @@ class Model(str, Enum): O1_PREVIEW_2024_09_12 = "o1-preview-2024-09-12" O1_MINI_LATEST = "o1-mini-latest" O1_MINI_2024_09_12 = "o1-mini-2024-09-12" + GPT_41_LATEST = "gpt-4.1-latest" + GPT_41_2025_04_14 = "gpt-4.1-2025-04-14" + GPT_41_MINI_LATEST = "gpt-4.1-mini-latest" + GPT_41_MINI_2025_04_14 = "gpt-4.1-mini-2025-04-14" + GPT_41_NANO_LATEST = "gpt-4.1-nano-latest" + GPT_41_NANO_2025_04_14 = "gpt-4.1-nano-2025-04-14" GPT_45_PREVIEW_2025_02_27 = "gpt-4.5-preview-2025-02-27" + GPT_4O_AUDIO_PREVIEW_2025_06_03 = "gpt-4o-audio-preview-2025-06-03" GPT_4O_AUDIO_PREVIEW_2024_12_17 = "gpt-4o-audio-preview-2024-12-17" GPT_40_AUDIO_PREVIEW_2024_10_01 = "gpt-4o-audio-preview-2024-10-01" GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09" @@ -66,6 +64,10 @@ class Model(str, Enum): # Gemini Models # -------------------------------------------------------------------------- GEMINI_2_0_FLASH_LATEST = "gemini-2.0-flash-latest" + GEMINI_2_5_PRO_PREVIEW_0605 = "gemini-2.5-pro-preview-06-05" + GEMINI_2_5_FLASH_PREVIEW_0520 = "gemini-2.5-flash-preview-05-20" + GEMINI_2_5_FLASH_THINKING_PREVIEW_0520 = "gemini-2.5-flash-thinking-preview-05-20" + GEMINI_2_5_PRO_PREVIEW_0506 = "gemini-2.5-pro-preview-05-06" GEMINI_2_5_FLASH_PREVIEW_0417 = "gemini-2.5-flash-preview-04-17" GEMINI_2_5_FLASH_THINKING_PREVIEW_0417 = "gemini-2.5-flash-thinking-preview-04-17" GEMINI_2_5_PRO_PREVIEW_0325 = "gemini-2.5-pro-preview-03-25" @@ -93,9 +95,18 @@ class Model(str, Enum): GEMINI_1_0_PRO_001 = "gemini-1.0-pro-001" GEMINI_1_0_PRO_VISION_001 = "gemini-1.0-pro-vision-001" + IMAGEN_3_0_LATEST = "imagen-3.0-generate-latest" + IMAGEN_3_0_002 = "imagen-3.0-generate-002" + IMAGEN_3_0_001 = "imagen-3.0-generate-001" + IMAGEN_3_0_FAST_001 = "imagen-3.0-fast-generate-001" + # -------------------------------------------------------------------------- # Claude Models # -------------------------------------------------------------------------- + CLAUDE_4_SONNET_LATEST = "claude-sonnet-4-latest" + CLAUDE_4_SONNET_20250514 = "claude-sonnet-4-20250514" + CLAUDE_4_OPUS_LATEST = "claude-opus-4-latest" + CLAUDE_4_OPUS_20250514 = "claude-opus-4-20250514" CLAUDE_3_7_SONNET_LATEST = "claude-3-7-sonnet-latest" CLAUDE_3_7_SONNET_20250219 = "claude-3-7-sonnet-20250219" CLAUDE_3_5_SONNET_LATEST = "claude-3-5-sonnet-latest" @@ -139,6 +150,7 @@ class Model(str, Enum): MISTRAL_LARGE_2_2407 = "mistral-large-2-2407" MISTRAL_LARGE_LATEST = "mistral-large-latest" MISTRAL_LARGE_2411 = "mistral-large-2411" + MISTRAL_MEDIUM_2505 = "mistral-medium-2505" PIXTRAL_LARGE_LATEST = "pixtral-large-latest" PIXTRAL_LARGE_2411 = "pixtral-large-2411" PIXTRAL_12B_2409 = "pixtral-12b-2409" @@ -149,6 +161,8 @@ class Model(str, Enum): MISTRAL_SMALL_2501 = "mistral-small-2501" MISTRAL_SMALL_2409 = "mistral-small-2409" MISTRAL_SABA_2502 = "mistral-saba-2502" + MAGISTRAL_SMALL_2506 = "magistral-small-2506" + MAGISTRAL_MEDIUM_2506 = "magistral-medium-2506" CODESTRAL_2501 = "codestral-2501" CODESTRAL_MAMBA_2407 = "codestral-mamba-2407" @@ -157,6 +171,8 @@ class Model(str, Enum): # -------------------------------------------------------------------------- QWEN_QWQ_32B = "qwen-qwq-32b" QWEN_QWQ_32B_PREVIEW = "qwen-v3p2-32b-instruct" + QWEN3_235B_A22B = "qwen3-235b-a22b" + QWEN3_30B_A3B = "qwen3-30b-a3b" # -------------------------------------------------------------------------- # DeepSeek Models @@ -166,6 +182,7 @@ class Model(str, Enum): DEEPSEEK_V3_LATEST = "deepseek-v3-latest" DEEPSEEK_R1_2501 = "deepseek-r1-2501" DEEPSEEK_R1_2501_BASIC = "deepseek-r1-2501-basic" + DEEPSEEK_R1_0528 = "deepseek-r1-0528" # -------------------------------------------------------------------------- # XAI Models diff --git a/workflowai/core/domain/model_test.py b/workflowai/core/domain/model_test.py index d23bacb..6124cae 100644 --- a/workflowai/core/domain/model_test.py +++ b/workflowai/core/domain/model_test.py @@ -6,10 +6,15 @@ async def test_model_exhaustive(): """Make sure the list of models is synchronized with the prod API""" async with httpx.AsyncClient() as client: - response = await client.get("https://run.workflowai.com/v1/models") + response = await client.get("https://run.workflowai.com/v1/models?raw=true") response.raise_for_status() models: list[str] = response.json() # Converting to a set of strings should not be needed # but it makes pytest errors prettier - assert set(models) == {m.value for m in Model} + expected_models = set(models) + actual_models = {m.value for m in Model} + missing_models = expected_models - actual_models + extra_models = actual_models - expected_models + assert not extra_models, f"Extra models: {extra_models}" + assert not missing_models, f"Missing models: {missing_models}" From e97cd03c885751df5a8d3c97b564af0af8e9751f Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Wed, 11 Jun 2025 11:53:25 -0400 Subject: [PATCH 2/2] test: add test_run_task_run_with_fallback --- tests/integration/conftest.py | 2 ++ tests/integration/run_test.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 13ca5f7..877c4dc 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -121,6 +121,8 @@ def check_request( assert request.headers["x-workflowai-source"] == "sdk" assert request.headers["x-workflowai-language"] == "python" + return body + @pytest.fixture def test_client(httpx_mock: HTTPXMock) -> IntTestClient: diff --git a/tests/integration/run_test.py b/tests/integration/run_test.py index f0944c9..ba2ab48 100644 --- a/tests/integration/run_test.py +++ b/tests/integration/run_test.py @@ -38,6 +38,22 @@ async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapit test_client.check_request() +async def test_run_task_run_with_fallback(test_client: IntTestClient) -> None: + @workflowai.agent(schema_id=1) + async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapitalTaskOutput]: ... + + test_client.mock_response() + + task_input = CityToCapitalTaskInput(city="Hello") + with_run = await city_to_capital(task_input, use_fallback="never") + + assert with_run.id == "123" + assert with_run.output.capital == "Tokyo" + + body = test_client.check_request() + assert body["use_fallback"] == "never" + + async def test_run_task_run_version(test_client: IntTestClient) -> None: @workflowai.agent(schema_id=1, version="staging") async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapitalTaskOutput]: ...