diff --git a/pyproject.toml b/pyproject.toml index 753c4b6..febbfcd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "workflowai" -version = "0.6.0.dev21" +version = "0.6.0.dev22" description = "" authors = ["Guillaume Aquilina "] readme = "README.md" diff --git a/workflowai/__init__.py b/workflowai/__init__.py index 1635a32..0c4f5a5 100644 --- a/workflowai/__init__.py +++ b/workflowai/__init__.py @@ -12,6 +12,7 @@ from workflowai.core.domain.model import Model as Model from workflowai.core.domain.run import Run as Run from workflowai.core.domain.version import Version as Version +from workflowai.core.domain.version_properties import VersionProperties as VersionProperties from workflowai.core.domain.version_reference import ( VersionReference as VersionReference, ) diff --git a/workflowai/core/client/_utils.py b/workflowai/core/client/_utils.py index 0894082..c87f2b8 100644 --- a/workflowai/core/client/_utils.py +++ b/workflowai/core/client/_utils.py @@ -4,14 +4,18 @@ import asyncio import os import re +from collections.abc import Mapping from json import JSONDecodeError from time import time -from typing import Any +from typing import Any, NamedTuple, Optional, Union + +from typing_extensions import Self from workflowai.core._common_types import OutputValidator from workflowai.core._logger import logger from workflowai.core.domain.errors import BaseError, WorkflowAIError from workflowai.core.domain.task import AgentOutput +from workflowai.core.domain.version_properties import VersionProperties from workflowai.core.domain.version_reference import VersionReference from workflowai.core.utils._pydantic import partial_model @@ -113,3 +117,38 @@ def global_default_version_reference() -> VersionReference: logger.warning("Invalid default version: %s", version) return "production" + + +class ModelInstructionTemperature(NamedTuple): + """A combination of run properties, with useful method + for combination""" + + model: Optional[str] = None + instructions: Optional[str] = None + temperature: Optional[float] = None + + @classmethod + def from_dict(cls, d: Mapping[str, Any]): + return cls( + model=d.get("model"), + instructions=d.get("instructions"), + temperature=d.get("temperature"), + ) + + @classmethod + def from_version(cls, version: Union[int, str, VersionProperties, None]): + if isinstance(version, VersionProperties): + return cls( + model=version.model, + instructions=version.instructions, + temperature=version.temperature, + ) + return cls() + + @classmethod + def combine(cls, *args: Self): + return cls( + model=next((a.model for a in args if a.model is not None), None), + instructions=next((a.instructions for a in args if a.instructions is not None), None), + temperature=next((a.temperature for a in args if a.temperature is not None), None), + ) diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index c480d4d..20a0311 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -21,6 +21,7 @@ ) from workflowai.core.client._types import RunParams from workflowai.core.client._utils import ( + ModelInstructionTemperature, build_retryable_wait, default_validator, global_default_version_reference, @@ -123,32 +124,42 @@ class _PreparedRun(NamedTuple): def _sanitize_version(self, params: VersionRunParams) -> Union[str, int, dict[str, Any]]: """Combine a version requested at runtime and the version requested at build time.""" + # Version contains either the requested version or the default version + # this is important to combine the check below of whether the version is a remote version (e-g production) + # or a local version (VersionProperties) version = params.get("version", self.version) - model = params.get("model") - instructions = params.get("instructions") - temperature = params.get("temperature") - has_property_overrides = bool(model or instructions or temperature or self._tools) + # Combine all overrides in a tuple + overrides = ModelInstructionTemperature.from_dict(params) + has_property_overrides = bool(self._tools or any(o is not None for o in overrides)) + # Version exists and is a remote version if version and not isinstance(version, VersionProperties): + # No property override so we return as is if not has_property_overrides and not self._tools: return version # In the case where the version requested a build time was a remote version # (either an ID or an environment), we use an empty template for the version - logger.warning("Overriding remove version with a local one") + logger.warning("Overriding remote version with a local one") version = VersionProperties() + # Version does not exist and there are no overrides + # We return the default version if not version and not has_property_overrides: g = global_default_version_reference() return g.model_dump(by_alias=True, exclude_unset=True) if isinstance(g, VersionProperties) else g dumped = version.model_dump(by_alias=True, exclude_unset=True) if version else {} - if not dumped.get("model"): + requested = ModelInstructionTemperature.from_version(version) + defaults = ModelInstructionTemperature.from_version(self.version) + combined = ModelInstructionTemperature.combine(overrides, requested, defaults) + + if not combined.model: # We always provide a default model since it is required by the API import workflowai - dumped["model"] = workflowai.DEFAULT_MODEL + combined = combined._replace(model=workflowai.DEFAULT_MODEL) if self._tools: dumped["enabled_tools"] = [ @@ -161,12 +172,12 @@ def _sanitize_version(self, params: VersionRunParams) -> Union[str, int, dict[st for tool in self._tools.values() ] # Finally we apply the property overrides - if model: - dumped["model"] = model - if instructions: - dumped["instructions"] = instructions - if temperature: - dumped["temperature"] = temperature + if combined.model is not None: + dumped["model"] = combined.model + if combined.instructions is not None: + dumped["instructions"] = combined.instructions + if combined.temperature is not None: + dumped["temperature"] = combined.temperature return dumped async def _prepare_run(self, agent_input: AgentInput, stream: bool, **kwargs: Unpack[RunParams[AgentOutput]]): diff --git a/workflowai/core/client/agent_test.py b/workflowai/core/client/agent_test.py index ca18480..925f044 100644 --- a/workflowai/core/client/agent_test.py +++ b/workflowai/core/client/agent_test.py @@ -385,6 +385,26 @@ def test_only_model_privider(self, agent: Agent[HelloTaskInput, HelloTaskOutput] "instructions": "You are a helpful assistant.", } + def test_with_explicit_version_without_instructions(self, agent: Agent[HelloTaskInput, HelloTaskOutput]): + """In the case where the agent has instructions but we send a version without instructions, + we use the instructions from the agent""" + + agent.version = VersionProperties(instructions="You are a helpful assistant.") + sanitized = agent._sanitize_version({"version": VersionProperties(model="gpt-4o-latest")}) # pyright: ignore [reportPrivateUsage] + assert sanitized == { + "model": "gpt-4o-latest", + "instructions": "You are a helpful assistant.", + } + + def test_override_with_0_temperature(self, agent: Agent[HelloTaskInput, HelloTaskOutput]): + """Test that a 0 temperature is not overridden by the default version""" + agent.version = VersionProperties(temperature=0.7) + sanitized = agent._sanitize_version({"version": VersionProperties(temperature=0)}) # pyright: ignore [reportPrivateUsage] + assert sanitized == { + "model": "gemini-1.5-pro-latest", + "temperature": 0.0, + } + class TestListModels: async def test_list_models(self, agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock):