Skip to content

Sanitize version #73

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "workflowai"
version = "0.6.0.dev21"
version = "0.6.0.dev22"
description = ""
authors = ["Guillaume Aquilina <[email protected]>"]
readme = "README.md"
Expand Down
1 change: 1 addition & 0 deletions workflowai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from workflowai.core.domain.model import Model as Model
from workflowai.core.domain.run import Run as Run
from workflowai.core.domain.version import Version as Version
from workflowai.core.domain.version_properties import VersionProperties as VersionProperties
from workflowai.core.domain.version_reference import (
VersionReference as VersionReference,
)
Expand Down
41 changes: 40 additions & 1 deletion workflowai/core/client/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@
import asyncio
import os
import re
from collections.abc import Mapping
from json import JSONDecodeError
from time import time
from typing import Any
from typing import Any, NamedTuple, Optional, Union

from typing_extensions import Self

from workflowai.core._common_types import OutputValidator
from workflowai.core._logger import logger
from workflowai.core.domain.errors import BaseError, WorkflowAIError
from workflowai.core.domain.task import AgentOutput
from workflowai.core.domain.version_properties import VersionProperties
from workflowai.core.domain.version_reference import VersionReference
from workflowai.core.utils._pydantic import partial_model

Expand Down Expand Up @@ -113,3 +117,38 @@ def global_default_version_reference() -> VersionReference:
logger.warning("Invalid default version: %s", version)

return "production"


class ModelInstructionTemperature(NamedTuple):
"""A combination of run properties, with useful method
for combination"""

model: Optional[str] = None
instructions: Optional[str] = None
temperature: Optional[float] = None

@classmethod
def from_dict(cls, d: Mapping[str, Any]):
return cls(
model=d.get("model"),
instructions=d.get("instructions"),
temperature=d.get("temperature"),
)

@classmethod
def from_version(cls, version: Union[int, str, VersionProperties, None]):
if isinstance(version, VersionProperties):
return cls(
model=version.model,
instructions=version.instructions,
temperature=version.temperature,
)
return cls()

@classmethod
def combine(cls, *args: Self):
return cls(
model=next((a.model for a in args if a.model is not None), None),
instructions=next((a.instructions for a in args if a.instructions is not None), None),
temperature=next((a.temperature for a in args if a.temperature is not None), None),
)
37 changes: 24 additions & 13 deletions workflowai/core/client/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from workflowai.core.client._types import RunParams
from workflowai.core.client._utils import (
ModelInstructionTemperature,
build_retryable_wait,
default_validator,
global_default_version_reference,
Expand Down Expand Up @@ -123,32 +124,42 @@ class _PreparedRun(NamedTuple):

def _sanitize_version(self, params: VersionRunParams) -> Union[str, int, dict[str, Any]]:
"""Combine a version requested at runtime and the version requested at build time."""
# Version contains either the requested version or the default version
# this is important to combine the check below of whether the version is a remote version (e-g production)
# or a local version (VersionProperties)
version = params.get("version", self.version)
model = params.get("model")
instructions = params.get("instructions")
temperature = params.get("temperature")

has_property_overrides = bool(model or instructions or temperature or self._tools)
# Combine all overrides in a tuple
overrides = ModelInstructionTemperature.from_dict(params)
has_property_overrides = bool(self._tools or any(o is not None for o in overrides))

# Version exists and is a remote version
if version and not isinstance(version, VersionProperties):
# No property override so we return as is
if not has_property_overrides and not self._tools:
return version
# In the case where the version requested a build time was a remote version
# (either an ID or an environment), we use an empty template for the version
logger.warning("Overriding remove version with a local one")
logger.warning("Overriding remote version with a local one")
version = VersionProperties()

# Version does not exist and there are no overrides
# We return the default version
if not version and not has_property_overrides:
g = global_default_version_reference()
return g.model_dump(by_alias=True, exclude_unset=True) if isinstance(g, VersionProperties) else g

dumped = version.model_dump(by_alias=True, exclude_unset=True) if version else {}

if not dumped.get("model"):
requested = ModelInstructionTemperature.from_version(version)
defaults = ModelInstructionTemperature.from_version(self.version)
combined = ModelInstructionTemperature.combine(overrides, requested, defaults)

if not combined.model:
# We always provide a default model since it is required by the API
import workflowai

dumped["model"] = workflowai.DEFAULT_MODEL
combined = combined._replace(model=workflowai.DEFAULT_MODEL)

if self._tools:
dumped["enabled_tools"] = [
Expand All @@ -161,12 +172,12 @@ def _sanitize_version(self, params: VersionRunParams) -> Union[str, int, dict[st
for tool in self._tools.values()
]
# Finally we apply the property overrides
if model:
dumped["model"] = model
if instructions:
dumped["instructions"] = instructions
if temperature:
dumped["temperature"] = temperature
if combined.model is not None:
dumped["model"] = combined.model
if combined.instructions is not None:
dumped["instructions"] = combined.instructions
if combined.temperature is not None:
dumped["temperature"] = combined.temperature
return dumped

async def _prepare_run(self, agent_input: AgentInput, stream: bool, **kwargs: Unpack[RunParams[AgentOutput]]):
Expand Down
20 changes: 20 additions & 0 deletions workflowai/core/client/agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,26 @@ def test_only_model_privider(self, agent: Agent[HelloTaskInput, HelloTaskOutput]
"instructions": "You are a helpful assistant.",
}

def test_with_explicit_version_without_instructions(self, agent: Agent[HelloTaskInput, HelloTaskOutput]):
"""In the case where the agent has instructions but we send a version without instructions,
we use the instructions from the agent"""

agent.version = VersionProperties(instructions="You are a helpful assistant.")
sanitized = agent._sanitize_version({"version": VersionProperties(model="gpt-4o-latest")}) # pyright: ignore [reportPrivateUsage]
assert sanitized == {
"model": "gpt-4o-latest",
"instructions": "You are a helpful assistant.",
}

def test_override_with_0_temperature(self, agent: Agent[HelloTaskInput, HelloTaskOutput]):
"""Test that a 0 temperature is not overridden by the default version"""
agent.version = VersionProperties(temperature=0.7)
sanitized = agent._sanitize_version({"version": VersionProperties(temperature=0)}) # pyright: ignore [reportPrivateUsage]
assert sanitized == {
"model": "gemini-1.5-pro-latest",
"temperature": 0.0,
}


class TestListModels:
async def test_list_models(self, agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock):
Expand Down