Skip to content

Commit 48ecb43

Browse files
authored
Merge pull request #57 from WorkflowAI/feature/add-audio-field-support
feat(models): add ModelInfo and list_models functionality
2 parents 386c244 + a08d5b4 commit 48ecb43

File tree

7 files changed

+322
-1
lines changed

7 files changed

+322
-1
lines changed

tests/e2e/assets/call.mp3

428 KB
Binary file not shown.

tests/e2e/audio_models_test.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""
2+
This test verifies model availability for audio processing tasks.
3+
It checks which models support audio processing and which don't,
4+
ensuring proper handling of unsupported models.
5+
"""
6+
7+
import base64
8+
import os
9+
10+
import pytest
11+
from pydantic import BaseModel, Field # pyright: ignore [reportUnknownVariableType]
12+
13+
import workflowai
14+
from workflowai import Model, Run
15+
from workflowai.fields import Audio
16+
17+
18+
class AudioInput(BaseModel):
    """Input containing the audio file to analyze."""

    # NOTE(review): field descriptions are emitted into the pydantic JSON schema,
    # which is presumably what the WorkflowAI backend receives — treat the wording
    # as API-visible, not just documentation.
    audio: Audio = Field(
        description="The audio recording to analyze for spam/robocall detection",
    )
23+
24+
25+
class SpamIndicator(BaseModel):
    """A specific indicator that suggests the call might be spam."""

    # Human-readable explanation of the suspicious pattern found in the audio.
    description: str = Field(
        description="Description of the spam indicator found in the audio",
        examples=[
            "Uses urgency to pressure the listener",
            "Mentions winning a prize without entering a contest",
            "Automated/robotic voice detected",
        ],
    )
    # Supporting evidence: either a verbatim quote or a timestamp range,
    # as shown by the examples below.
    quote: str = Field(
        description="The exact quote or timestamp where this indicator appears",
        examples=[
            "'You must act now before it's too late'",
            "'You've been selected as our prize winner'",
            "0:05-0:15 - Synthetic voice pattern detected",
        ],
    )
43+
44+
45+
class AudioClassification(BaseModel):
    """Output containing the spam classification results."""

    # Final binary verdict for the recording.
    is_spam: bool = Field(
        description="Whether the audio is classified as spam/robocall",
    )
    # ge/le constrain the score to [0.0, 1.0] at validation time.
    confidence_score: float = Field(
        description="Confidence score for the classification (0.0 to 1.0)",
        ge=0.0,
        le=1.0,
    )
    # default_factory=list avoids the shared-mutable-default pitfall and lets
    # a legitimate (non-spam) call validly carry zero indicators.
    spam_indicators: list[SpamIndicator] = Field(
        default_factory=list,
        description="List of specific indicators that suggest this is spam",
    )
    # Free-form justification tying the verdict back to the indicators.
    reasoning: str = Field(
        description="Detailed explanation of why this was classified as spam or legitimate",
    )
62+
63+
64+
# NOTE(review): the docstring below presumably doubles as the agent's prompt /
# instructions sent to the model (workflowai agent convention — confirm), so any
# wording change here changes runtime behavior. Body is intentionally empty (...):
# the decorator supplies the implementation.
@workflowai.agent(
    id="audio-spam-detector",
    model=Model.GEMINI_1_5_FLASH_LATEST,
)
async def classify_audio(audio_input: AudioInput) -> Run[AudioClassification]:
    """
    Analyze the audio recording to determine if it's a spam/robocall.

    Guidelines:
    1. Listen for common spam/robocall indicators:
       - Use of urgency or pressure tactics
       - Unsolicited offers or prizes
       - Automated/synthetic voices
       - Requests for personal/financial information
       - Impersonation of legitimate organizations

    2. Consider both content and delivery:
       - What is being said (transcribe key parts)
       - How it's being said (tone, pacing, naturalness)
       - Background noise and call quality

    3. Provide clear reasoning:
       - Cite specific examples from the audio
       - Explain confidence level
       - Note any uncertainty
    """
    ...
91+
92+
93+
@pytest.fixture
def audio_file() -> Audio:
    """Load the test audio file shipped under tests/e2e/assets."""
    assets_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
    audio_path = os.path.join(assets_dir, "call.mp3")

    # Fail fast with a helpful message when the binary asset is missing.
    if not os.path.exists(audio_path):
        raise FileNotFoundError(
            f"Audio file not found at {audio_path}. "
            "Please make sure you have the example audio file in the correct location.",
        )

    with open(audio_path, "rb") as audio_fp:
        raw_bytes = audio_fp.read()

    # NOTE(review): "audio/mpeg" is the IANA-registered MIME type for MP3;
    # confirm the backend expects the non-standard "audio/mp3" before changing.
    return Audio(
        content_type="audio/mp3",
        data=base64.b64encode(raw_bytes).decode(),
    )

workflowai/core/client/_models.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Literal, Optional, Union
1+
from typing import Any, Generic, Literal, Optional, TypeVar, Union
22

33
from pydantic import BaseModel, Field # pyright: ignore [reportUnknownVariableType]
44
from typing_extensions import NotRequired, TypedDict
@@ -156,3 +156,42 @@ class CreateAgentRequest(BaseModel):
156156
class CreateAgentResponse(BaseModel):
157157
id: str
158158
schema_id: int
159+
160+
161+
class ModelMetadata(BaseModel):
    """Metadata for a model."""

    # All fields except provider_name are optional: the API may omit pricing or
    # capability data for some providers, in which case they default to None.
    provider_name: str = Field(description="Name of the model provider")
    price_per_input_token_usd: Optional[float] = Field(None, description="Cost per input token in USD")
    price_per_output_token_usd: Optional[float] = Field(None, description="Cost per output token in USD")
    # Serialized as a plain string; format (e.g. ISO date) is not enforced here.
    release_date: Optional[str] = Field(None, description="Release date of the model")
    context_window_tokens: Optional[int] = Field(None, description="Size of the context window in tokens")
    quality_index: Optional[float] = Field(None, description="Quality index of the model")
169+
170+
171+
class ModelInfo(BaseModel):
    """Information about a model."""

    id: str = Field(description="Unique identifier for the model")
    name: str = Field(description="Display name of the model")
    icon_url: Optional[str] = Field(None, description="URL for the model's icon")
    # default_factory keeps each instance's list independent (no shared default).
    modes: list[str] = Field(default_factory=list, description="Supported modes for this model")
    # None means the model IS supported; a string explains why it is not.
    is_not_supported_reason: Optional[str] = Field(
        None,
        description="Reason why the model is not supported, if applicable",
    )
    average_cost_per_run_usd: Optional[float] = Field(None, description="Average cost per run in USD")
    is_latest: bool = Field(default=False, description="Whether this is the latest version of the model")
    metadata: Optional[ModelMetadata] = Field(None, description="Additional metadata about the model")
    is_default: bool = Field(default=False, description="Whether this is the default model")
    providers: list[str] = Field(default_factory=list, description="List of providers that offer this model")
186+
187+
188+
# Item type carried by a Page; bound at subclassing time (e.g. Page[ModelInfo]).
T = TypeVar("T")


class Page(BaseModel, Generic[T]):
    """A generic paginated response."""

    # items holds the current page only; count is the overall total, which may
    # be None when the API does not report one.
    items: list[T] = Field(description="List of items in this page")
    count: Optional[int] = Field(None, description="Total number of items available")
194+
195+
196+
# Concrete page specialization used by Agent.list_models; no extra fields —
# the subclass exists to pin the item type for deserialization.
class ListModelsResponse(Page[ModelInfo]):
    """Response from the list models API endpoint."""

workflowai/core/client/agent.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from workflowai.core.client._models import (
1212
CreateAgentRequest,
1313
CreateAgentResponse,
14+
ListModelsResponse,
15+
ModelInfo,
1416
ReplyRequest,
1517
RunRequest,
1618
RunResponse,
@@ -469,3 +471,22 @@ async def reply(
469471
def _sanitize_validator(cls, kwargs: RunParams[AgentOutput], default: OutputValidator[AgentOutput]):
470472
validator = kwargs.pop("validator", default)
471473
return validator, cast(BaseRunParams, kwargs)
474+
475+
async def list_models(self) -> list[ModelInfo]:
    """Fetch the list of available models from the API for this agent.

    If the agent has not been registered yet (no ``schema_id``), it is
    registered first, so this method does not raise merely because the
    schema id is unset.

    Returns:
        list[ModelInfo]: List of available models with their full information.
    """
    # Lazily register to obtain a schema_id — the models endpoint below is
    # scoped to a specific agent schema.
    if not self.schema_id:
        self.schema_id = await self.register()

    response = await self.api.get(
        # The "_" refers to the currently authenticated tenant's namespace
        f"/v1/_/agents/{self.agent_id}/schemas/{self.schema_id}/models",
        returns=ListModelsResponse,
    )
    # Unwrap the pagination envelope; callers only need the items.
    return response.items

workflowai/core/client/agent_test.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414
from tests.utils import fixtures_json
1515
from workflowai.core.client._api import APIClient
16+
from workflowai.core.client._models import ModelInfo
1617
from workflowai.core.client.agent import Agent
1718
from workflowai.core.client.client import (
1819
WorkflowAI,
@@ -367,3 +368,139 @@ def test_version_properties_with_model(self, agent: Agent[HelloTaskInput, HelloT
367368
def test_version_with_models_and_version(self, agent: Agent[HelloTaskInput, HelloTaskOutput]):
368369
# If version is explicitly provided then it takes priority and we log a warning
369370
assert agent._sanitize_version({"version": "staging", "model": "gemini-1.5-pro-latest"}) == "staging" # pyright: ignore [reportPrivateUsage]
371+
372+
373+
@pytest.mark.asyncio
async def test_list_models(agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock):
    """Test that list_models correctly fetches and returns available models."""
    gpt4_payload = {
        "id": "gpt-4",
        "name": "GPT-4",
        "icon_url": "https://example.com/gpt4.png",
        "modes": ["chat"],
        "is_not_supported_reason": None,
        "average_cost_per_run_usd": 0.01,
        "is_latest": True,
        "metadata": {
            "provider_name": "OpenAI",
            "price_per_input_token_usd": 0.0001,
            "price_per_output_token_usd": 0.0002,
            "release_date": "2024-01-01",
            "context_window_tokens": 128000,
            "quality_index": 0.95,
        },
        "is_default": True,
        "providers": ["openai"],
    }
    claude_payload = {
        "id": "claude-3",
        "name": "Claude 3",
        "icon_url": "https://example.com/claude3.png",
        "modes": ["chat"],
        "is_not_supported_reason": None,
        "average_cost_per_run_usd": 0.015,
        "is_latest": True,
        "metadata": {
            "provider_name": "Anthropic",
            "price_per_input_token_usd": 0.00015,
            "price_per_output_token_usd": 0.00025,
            "release_date": "2024-03-01",
            "context_window_tokens": 200000,
            "quality_index": 0.98,
        },
        "is_default": False,
        "providers": ["anthropic"],
    }

    # Mock at the HTTP layer rather than patching the API client method.
    httpx_mock.add_response(
        url="http://localhost:8000/v1/_/agents/123/schemas/1/models",
        json={"items": [gpt4_payload, claude_payload], "count": 2},
    )

    models = await agent.list_models()

    # The schema-scoped models endpoint must have been hit exactly as expected.
    sent = httpx_mock.get_request()
    assert sent is not None, "Expected an HTTP request to be made"
    assert sent.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models"

    # Full ModelInfo objects come back, metadata included.
    assert len(models) == 2

    first = models[0]
    assert isinstance(first, ModelInfo)
    assert first.id == "gpt-4"
    assert first.name == "GPT-4"
    assert first.modes == ["chat"]
    assert first.metadata is not None
    assert first.metadata.provider_name == "OpenAI"

    second = models[1]
    assert isinstance(second, ModelInfo)
    assert second.id == "claude-3"
    assert second.name == "Claude 3"
    assert second.modes == ["chat"]
    assert second.metadata is not None
    assert second.metadata.provider_name == "Anthropic"
447+
448+
449+
@pytest.mark.asyncio
async def test_list_models_registers_if_needed(
    agent_no_schema: Agent[HelloTaskInput, HelloTaskOutput],
    httpx_mock: HTTPXMock,
):
    """Test that list_models registers the agent if it hasn't been registered yet."""
    # First call: the registration endpoint assigns schema_id 2.
    httpx_mock.add_response(
        url="http://localhost:8000/v1/_/agents",
        json={"id": "123", "schema_id": 2},
    )

    model_payload = {
        "id": "gpt-4",
        "name": "GPT-4",
        "icon_url": "https://example.com/gpt4.png",
        "modes": ["chat"],
        "is_not_supported_reason": None,
        "average_cost_per_run_usd": 0.01,
        "is_latest": True,
        "metadata": {
            "provider_name": "OpenAI",
            "price_per_input_token_usd": 0.0001,
            "price_per_output_token_usd": 0.0002,
            "release_date": "2024-01-01",
            "context_window_tokens": 128000,
            "quality_index": 0.95,
        },
        "is_default": True,
        "providers": ["openai"],
    }
    # Second call: the schema-scoped models listing.
    httpx_mock.add_response(
        url="http://localhost:8000/v1/_/agents/123/schemas/2/models",
        json={"items": [model_payload], "count": 1},
    )

    models = await agent_no_schema.list_models()

    # Registration must happen first, then the models fetch against schema 2.
    requests = httpx_mock.get_requests()
    assert len(requests) == 2
    assert requests[0].url == "http://localhost:8000/v1/_/agents"
    assert requests[1].url == "http://localhost:8000/v1/_/agents/123/schemas/2/models"

    # A single, fully-populated ModelInfo is returned.
    assert len(models) == 1
    model = models[0]
    assert isinstance(model, ModelInfo)
    assert model.id == "gpt-4"
    assert model.name == "GPT-4"
    assert model.modes == ["chat"]
    assert model.metadata is not None
    assert model.metadata.provider_name == "OpenAI"

workflowai/core/fields/audio.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
"""Audio field for handling audio file inputs."""
2+
3+
from workflowai.core.fields.file import File
4+
5+
6+
class Audio(File):
    """A field representing an audio file.

    This field is used to handle audio inputs in various formats (MP3, WAV, etc.).
    The audio can be provided either as base64-encoded data or as a URL.
    """
    # No additional fields or behavior: the subclass exists purely to give audio
    # inputs a distinct type. The redundant `pass` after the docstring was removed —
    # a docstring already makes the class body non-empty.

workflowai/fields.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from workflowai.core.fields.audio import Audio as Audio
12
from workflowai.core.fields.chat_message import ChatMessage as ChatMessage
23
from workflowai.core.fields.email_address import EmailAddressStr as EmailAddressStr
34
from workflowai.core.fields.file import File as File

0 commit comments

Comments
 (0)