Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "workflowai"
version = "0.6.0.dev14"
version = "0.6.0.dev15"
description = ""
authors = ["Guillaume Aquilina <[email protected]>"]
readme = "README.md"
Expand Down
27 changes: 15 additions & 12 deletions workflowai/core/client/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(
self.schema_id = schema_id
self.input_cls = input_cls
self.output_cls = output_cls
self.version: VersionReference = version or global_default_version_reference()
self.version = version
self._api = (lambda: api) if isinstance(api, APIClient) else api
self._tools = self.build_tools(tools) if tools else None

Expand All @@ -118,24 +118,27 @@ class _PreparedRun(NamedTuple):
schema_id: int

def _sanitize_version(self, params: VersionRunParams) -> Union[str, int, dict[str, Any]]:
version = params.get("version")
"""Combine a version requested at runtime and the version requested at build time."""
version = params.get("version", self.version)
model = params.get("model")
instructions = params.get("instructions")
temperature = params.get("temperature")

has_property_overrides = bool(model or instructions or temperature)
has_property_overrides = bool(model or instructions or temperature or self._tools)

if not version:
# If versions is not specified, we fill with the default agent version only if
# there are no additional properties
version = self.version if not has_property_overrides else VersionProperties()
if version and not isinstance(version, VersionProperties):
if not has_property_overrides and not self._tools:
return version
# In the case where the version requested a build time was a remote version
# (either an ID or an environment), we use an empty template for the version
logger.warning("Overriding remove version with a local one")
version = VersionProperties()

if not isinstance(version, VersionProperties):
if has_property_overrides or self._tools:
logger.warning("Property overrides are ignored when version is not a VersionProperties")
return version
if not version and not has_property_overrides:
g = global_default_version_reference()
return g.model_dump(by_alias=True, exclude_unset=True) if isinstance(g, VersionProperties) else g

dumped = version.model_dump(by_alias=True, exclude_unset=True)
dumped = version.model_dump(by_alias=True, exclude_unset=True) if version else {}

if not dumped.get("model"):
# We always provide a default model since it is required by the API
Expand Down
47 changes: 41 additions & 6 deletions workflowai/core/client/agent_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import importlib.metadata
import json
from unittest.mock import Mock, patch

import httpx
import pytest
Expand Down Expand Up @@ -346,28 +347,62 @@ class AliasOutput(BaseModel):


class TestSanitizeVersion:
def test_global_default(self, agent: Agent[HelloTaskInput, HelloTaskOutput]):
"""Test that the global default version is used when no version is provided"""
assert agent.version is None, "sanity"
assert agent._sanitize_version({}) == "production" # pyright: ignore [reportPrivateUsage]

@patch("workflowai.core.client.agent.global_default_version_reference")
def test_global_default_with_properties(
self,
mock_global_default: Mock,
agent: Agent[HelloTaskInput, HelloTaskOutput],
):
"""Check that a dict is returned when the global default version is a VersionProperties"""
mock_global_default.return_value = VersionProperties(model="gpt-4o")
assert agent._sanitize_version({}) == { # pyright: ignore [reportPrivateUsage]
"model": "gpt-4o",
}

def test_string_version(self, agent: Agent[HelloTaskInput, HelloTaskOutput]):
"""Check that a string is returned when the version is a string"""
assert agent._sanitize_version({"version": "production"}) == "production" # pyright: ignore [reportPrivateUsage]

def test_default_version(self, agent: Agent[HelloTaskInput, HelloTaskOutput]):
assert agent._sanitize_version({}) == "production" # pyright: ignore [reportPrivateUsage]
def test_override_remove_version(self, agent: Agent[HelloTaskInput, HelloTaskOutput]):
"""Check that a remote version is overridden by a local one"""
agent.version = "staging"
assert agent._sanitize_version({"model": "gpt-4o-latest"}) == { # pyright: ignore [reportPrivateUsage]
"model": "gpt-4o-latest",
}

def test_version_properties(self, agent: Agent[HelloTaskInput, HelloTaskOutput]):
"""Check that a dict is returned when the version is a VersionProperties"""
assert agent._sanitize_version({"version": VersionProperties(temperature=0.7)}) == { # pyright: ignore [reportPrivateUsage]
"temperature": 0.7,
"model": "gemini-1.5-pro-latest",
}

def test_version_properties_with_model(self, agent: Agent[HelloTaskInput, HelloTaskOutput]):
# When the default version is used and we pass the model, the model has priority
assert agent.version == "production", "sanity"
"""When the default version is used and we pass the model, the model has priority"""
assert agent.version is None, "sanity"
assert agent._sanitize_version({"model": "gemini-1.5-pro-latest"}) == { # pyright: ignore [reportPrivateUsage]
"model": "gemini-1.5-pro-latest",
}

def test_version_with_models_and_version(self, agent: Agent[HelloTaskInput, HelloTaskOutput]):
# If version is explcitly provided then it takes priority and we log a warning
assert agent._sanitize_version({"version": "staging", "model": "gemini-1.5-pro-latest"}) == "staging" # pyright: ignore [reportPrivateUsage]
"""If the runtime version is a remote version but a model is passed, we use an empty template for the version"""
assert agent._sanitize_version({"version": "staging", "model": "gemini-1.5-pro-latest"}) == { # pyright: ignore [reportPrivateUsage]
"model": "gemini-1.5-pro-latest",
}

def test_only_model_privider(self, agent: Agent[HelloTaskInput, HelloTaskOutput]):
"""Test that when an agent has instructions we use the instructions when overriding the model"""
agent.version = VersionProperties(model="gpt-4o", instructions="You are a helpful assistant.")
sanitized = agent._sanitize_version({"model": "gemini-1.5-pro-latest"}) # pyright: ignore [reportPrivateUsage]
assert sanitized == {
"model": "gemini-1.5-pro-latest",
"instructions": "You are a helpful assistant.",
}


@pytest.mark.asyncio
Expand Down
2 changes: 1 addition & 1 deletion workflowai/core/domain/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __str__(self) -> str:

@property
def run_url(self):
return f"{env.WORKFLOWAI_APP_URL}/agents/{self.agent_id}/runs/{self.id}"
return f"{env.WORKFLOWAI_APP_URL}/_/agents/{self.agent_id}/runs/{self.id}"


class _AgentBase(Protocol, Generic[AgentOutput]):
Expand Down
2 changes: 1 addition & 1 deletion workflowai/core/domain/run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,4 @@ def test_format_output_no_cost_latency() -> None:
class TestRunURL:
@patch("workflowai.env.WORKFLOWAI_APP_URL", "https://workflowai.hello")
def test_run_url(self, run1: Run[_TestOutput]):
assert run1.run_url == "https://workflowai.hello/agents/agent-1/runs/test-id"
assert run1.run_url == "https://workflowai.hello/_/agents/agent-1/runs/test-id"