Skip to content

Commit fdbe09a

Browse files
committed
add override params to list_models
1 parent 30232d8 commit fdbe09a

File tree

3 files changed

+185
-12
lines changed

3 files changed

+185
-12
lines changed

workflowai/core/client/agent.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,11 @@ def _sanitize_validator(cls, kwargs: RunParams[AgentOutput], default: OutputVali
478478
validator = kwargs.pop("validator", default)
479479
return validator, cast(BaseRunParams, kwargs)
480480

481-
async def list_models(self) -> list[ModelInfo]:
481+
async def list_models(
482+
self,
483+
instructions: Optional[str] = None,
484+
requires_tools: Optional[bool] = None,
485+
) -> list[ModelInfo]:
482486
"""Fetch the list of available models from the API for this agent.
483487
484488
Returns:
@@ -491,18 +495,18 @@ async def list_models(self) -> list[ModelInfo]:
491495
if not self.schema_id:
492496
self.schema_id = await self.register()
493497

494-
data = ListModelsRequest()
495-
if self.version and isinstance(self.version, VersionProperties):
496-
data.instructions = self.version.instructions
498+
request_data = ListModelsRequest(instructions=instructions, requires_tools=requires_tools)
497499

498-
if self._tools:
499-
for _ in self._tools.values():
500-
data.requires_tools = True
500+
if instructions is None and self.version and isinstance(self.version, VersionProperties):
501+
request_data.instructions = self.version.instructions
502+
503+
if requires_tools is None and self._tools:
504+
request_data.requires_tools = True
501505

502506
response = await self.api.post(
503507
# The "_" refers to the currently authenticated tenant's namespace
504508
f"/v1/_/agents/{self.agent_id}/schemas/{self.schema_id}/models",
505-
data=data,
509+
data=request_data,
506510
returns=ListModelsResponse,
507511
)
508512
return response.items

workflowai/core/client/agent_test.py

Lines changed: 168 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -503,12 +503,178 @@ async def test_list_models(agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_
503503
# Call the method
504504
models = await agent.list_models()
505505

506+
# Verify the HTTP request was made correctly
507+
request = httpx_mock.get_request()
508+
assert request is not None, "Expected an HTTP request to be made"
509+
assert request.method == "POST"
510+
assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models"
511+
assert json.loads(request.content) == {}
512+
513+
# Verify we get back the full ModelInfo objects
514+
assert len(models) == 2
515+
assert isinstance(models[0], ModelInfo)
516+
assert models[0].id == "gpt-4"
517+
assert models[0].name == "GPT-4"
518+
assert models[0].modes == ["chat"]
519+
assert models[0].metadata is not None
520+
assert models[0].metadata.provider_name == "OpenAI"
521+
522+
assert isinstance(models[1], ModelInfo)
523+
assert models[1].id == "claude-3"
524+
assert models[1].name == "Claude 3"
525+
assert models[1].modes == ["chat"]
526+
assert models[1].metadata is not None
527+
assert models[1].metadata.provider_name == "Anthropic"
528+
529+
530+
@pytest.mark.asyncio
531+
async def test_list_models_with_params_override(agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock):
532+
"""Test that list_models correctly fetches and returns available models."""
533+
# Mock the HTTP response instead of the API client method
534+
httpx_mock.add_response(
535+
url="http://localhost:8000/v1/_/agents/123/schemas/1/models",
536+
json={
537+
"items": [
538+
{
539+
"id": "gpt-4",
540+
"name": "GPT-4",
541+
"icon_url": "https://example.com/gpt4.png",
542+
"modes": ["chat"],
543+
"is_not_supported_reason": None,
544+
"average_cost_per_run_usd": 0.01,
545+
"is_latest": True,
546+
"metadata": {
547+
"provider_name": "OpenAI",
548+
"price_per_input_token_usd": 0.0001,
549+
"price_per_output_token_usd": 0.0002,
550+
"release_date": "2024-01-01",
551+
"context_window_tokens": 128000,
552+
"quality_index": 0.95,
553+
},
554+
"is_default": True,
555+
"providers": ["openai"],
556+
},
557+
{
558+
"id": "claude-3",
559+
"name": "Claude 3",
560+
"icon_url": "https://example.com/claude3.png",
561+
"modes": ["chat"],
562+
"is_not_supported_reason": None,
563+
"average_cost_per_run_usd": 0.015,
564+
"is_latest": True,
565+
"metadata": {
566+
"provider_name": "Anthropic",
567+
"price_per_input_token_usd": 0.00015,
568+
"price_per_output_token_usd": 0.00025,
569+
"release_date": "2024-03-01",
570+
"context_window_tokens": 200000,
571+
"quality_index": 0.98,
572+
},
573+
"is_default": False,
574+
"providers": ["anthropic"],
575+
},
576+
],
577+
"count": 2,
578+
},
579+
)
580+
581+
# Call the method
582+
models = await agent.list_models(instructions="Some override instructions", requires_tools=True)
583+
584+
# Verify the HTTP request was made correctly
585+
request = httpx_mock.get_request()
586+
assert request is not None, "Expected an HTTP request to be made"
587+
assert request.method == "POST"
588+
assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models"
589+
assert json.loads(request.content) == {
590+
"instructions": "Some override instructions",
591+
"requires_tools": True,
592+
}
593+
594+
# Verify we get back the full ModelInfo objects
595+
assert len(models) == 2
596+
assert isinstance(models[0], ModelInfo)
597+
assert models[0].id == "gpt-4"
598+
assert models[0].name == "GPT-4"
599+
assert models[0].modes == ["chat"]
600+
assert models[0].metadata is not None
601+
assert models[0].metadata.provider_name == "OpenAI"
602+
603+
assert isinstance(models[1], ModelInfo)
604+
assert models[1].id == "claude-3"
605+
assert models[1].name == "Claude 3"
606+
assert models[1].modes == ["chat"]
607+
assert models[1].metadata is not None
608+
assert models[1].metadata.provider_name == "Anthropic"
609+
610+
611+
@pytest.mark.asyncio
612+
async def test_list_models_with_params_override_and_agent_with_tools_and_instructions(
613+
agent_with_tools_and_instructions: Agent[HelloTaskInput, HelloTaskOutput],
614+
httpx_mock: HTTPXMock,
615+
):
616+
"""Test that list_models correctly fetches and returns available models."""
617+
# Mock the HTTP response instead of the API client method
618+
httpx_mock.add_response(
619+
url="http://localhost:8000/v1/_/agents/123/schemas/1/models",
620+
json={
621+
"items": [
622+
{
623+
"id": "gpt-4",
624+
"name": "GPT-4",
625+
"icon_url": "https://example.com/gpt4.png",
626+
"modes": ["chat"],
627+
"is_not_supported_reason": None,
628+
"average_cost_per_run_usd": 0.01,
629+
"is_latest": True,
630+
"metadata": {
631+
"provider_name": "OpenAI",
632+
"price_per_input_token_usd": 0.0001,
633+
"price_per_output_token_usd": 0.0002,
634+
"release_date": "2024-01-01",
635+
"context_window_tokens": 128000,
636+
"quality_index": 0.95,
637+
},
638+
"is_default": True,
639+
"providers": ["openai"],
640+
},
641+
{
642+
"id": "claude-3",
643+
"name": "Claude 3",
644+
"icon_url": "https://example.com/claude3.png",
645+
"modes": ["chat"],
646+
"is_not_supported_reason": None,
647+
"average_cost_per_run_usd": 0.015,
648+
"is_latest": True,
649+
"metadata": {
650+
"provider_name": "Anthropic",
651+
"price_per_input_token_usd": 0.00015,
652+
"price_per_output_token_usd": 0.00025,
653+
"release_date": "2024-03-01",
654+
"context_window_tokens": 200000,
655+
"quality_index": 0.98,
656+
},
657+
"is_default": False,
658+
"providers": ["anthropic"],
659+
},
660+
],
661+
"count": 2,
662+
},
663+
)
664+
665+
# Call the method
666+
models = await agent_with_tools_and_instructions.list_models(
667+
instructions="Some override instructions",
668+
requires_tools=False,
669+
)
670+
506671
# Verify the HTTP request was made correctly
507672
request = httpx_mock.get_request()
508673
assert request is not None, "Expected an HTTP request to be made"
509674
assert request.method == "POST"
510675
assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models"
511676
assert json.loads(request.content) == {
677+
"instructions": "Some override instructions",
512678
"requires_tools": False,
513679
}
514680

@@ -579,9 +745,7 @@ async def test_list_models_registers_if_needed(
579745
assert reqs[0].url == "http://localhost:8000/v1/_/agents"
580746
assert reqs[1].method == "POST"
581747
assert reqs[1].url == "http://localhost:8000/v1/_/agents/123/schemas/2/models"
582-
assert json.loads(reqs[1].content) == {
583-
"requires_tools": False,
584-
}
748+
assert json.loads(reqs[1].content) == {}
585749

586750
# Verify we get back the full ModelInfo object
587751
assert len(models) == 1
@@ -656,7 +820,7 @@ async def test_list_models_with_instructions(
656820
assert request is not None, "Expected an HTTP request to be made"
657821
assert request.method == "POST"
658822
assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models"
659-
assert json.loads(request.content) == {"instructions": "Some instructions", "requires_tools": False}
823+
assert json.loads(request.content) == {"instructions": "Some instructions"}
660824

661825
# Verify we get back the full ModelInfo objects
662826
assert len(models) == 2

workflowai/core/domain/run.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ async def fetch_completions(self) -> list[Completion]:
150150

151151

152152
class _AgentBase(Protocol, Generic[AgentOutput]):
153+
# TODO: fix circular dep
154+
from workflowai.core.client._models import ModelInfo
155+
153156
async def reply(
154157
self,
155158
run_id: str,
@@ -161,3 +164,5 @@ async def reply(
161164
...
162165

163166
async def fetch_completions(self, run_id: str) -> list[Completion]: ...
167+
168+
async def list_models(self) -> list[ModelInfo]: ...

0 commit comments

Comments
 (0)