Commit 606f66b

pass instructions and requires_tools to /models

1 parent 2073c37 commit 606f66b

File tree

4 files changed: +315 −2 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "workflowai"
-version = "0.6.0.dev17"
+version = "0.6.0.dev18"
 description = ""
 authors = ["Guillaume Aquilina <[email protected]>"]
 readme = "README.md"

workflowai/core/client/_models.py

Lines changed: 5 additions & 0 deletions

@@ -202,6 +202,11 @@ class ListModelsResponse(Page[ModelInfo]):
     """Response from the list models API endpoint."""


+class ListModelsRequest(BaseModel):
+    instructions: Optional[str] = Field(default=None, description="Used to detect internal tools")
+    requires_tools: Optional[bool] = Field(default=False, description="Whether the agent uses external tools")
+
+
 class CompletionsResponse(BaseModel):
     """Response from the completions API endpoint."""
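A note on serialization: the tests below assert request bodies in which instructions is simply absent when unset, which implies None fields are dropped when the model is serialized. A minimal sketch of that behavior, assuming pydantic v2 and exclude_none-style dumping (the client's exact serialization path is not shown in this diff):

from typing import Optional

from pydantic import BaseModel, Field


class ListModelsRequest(BaseModel):
    instructions: Optional[str] = Field(default=None, description="Used to detect internal tools")
    requires_tools: Optional[bool] = Field(default=False, description="Whether the agent uses external tools")


# None fields are omitted, matching the bodies asserted in agent_test.py below:
print(ListModelsRequest().model_dump_json(exclude_none=True))
# {"requires_tools":false}
print(ListModelsRequest(instructions="Some instructions").model_dump_json(exclude_none=True))
# {"instructions":"Some instructions","requires_tools":false}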

workflowai/core/client/agent.py

Lines changed: 12 additions & 1 deletion

@@ -12,6 +12,7 @@
     CompletionsResponse,
     CreateAgentRequest,
     CreateAgentResponse,
+    ListModelsRequest,
     ListModelsResponse,
     ModelInfo,
     ReplyRequest,
@@ -486,12 +487,22 @@ async def list_models(self) -> list[ModelInfo]:
         Raises:
             ValueError: If the agent has not been registered (schema_id is None).
         """
+
         if not self.schema_id:
             self.schema_id = await self.register()

-        response = await self.api.get(
+        data = ListModelsRequest()
+        if self.version and isinstance(self.version, VersionProperties):
+            data.instructions = self.version.instructions
+
+        if self._tools:
+            for _ in self._tools.values():
+                data.requires_tools = True
+
+        response = await self.api.post(
             # The "_" refers to the currently authenticated tenant's namespace
             f"/v1/_/agents/{self.agent_id}/schemas/{self.schema_id}/models",
+            data=data,
             returns=ListModelsResponse,
         )
         return response.items
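The tools check above sets requires_tools once per registered tool; a plain truthiness test is equivalent. A minimal standalone sketch of the request construction (build_list_models_request is a hypothetical helper, not part of the SDK; the import path follows the file tree above):

from typing import Any, Callable, Mapping, Optional

from workflowai.core.client._models import ListModelsRequest


def build_list_models_request(
    instructions: Optional[str],
    tools: Optional[Mapping[str, Callable[..., Any]]],
) -> ListModelsRequest:
    # Hypothetical helper mirroring the logic added to Agent.list_models.
    data = ListModelsRequest()
    data.instructions = instructions
    # Equivalent to the commit's loop over self._tools.values(), which
    # sets requires_tools = True once per tool.
    data.requires_tools = bool(tools)
    return data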

workflowai/core/client/agent_test.py

Lines changed: 297 additions & 0 deletions

@@ -35,6 +35,49 @@ def agent(api_client: APIClient):
     return Agent(agent_id="123", schema_id=1, input_cls=HelloTaskInput, output_cls=HelloTaskOutput, api=api_client)


+@pytest.fixture
+def agent_with_instructions(api_client: APIClient):
+    return Agent(
+        agent_id="123",
+        schema_id=1,
+        input_cls=HelloTaskInput,
+        output_cls=HelloTaskOutput,
+        api=api_client,
+        version=VersionProperties(instructions="Some instructions"),
+    )
+
+
+@pytest.fixture
+def agent_with_tools(api_client: APIClient):
+    def some_tool() -> str:
+        return "Hello, world!"
+
+    return Agent(
+        agent_id="123",
+        schema_id=1,
+        input_cls=HelloTaskInput,
+        output_cls=HelloTaskOutput,
+        api=api_client,
+        tools=[some_tool],
+    )
+
+
+@pytest.fixture
+def agent_with_tools_and_instructions(api_client: APIClient):
+    def some_tool() -> str:
+        return "Hello, world!"
+
+    return Agent(
+        agent_id="123",
+        schema_id=1,
+        input_cls=HelloTaskInput,
+        output_cls=HelloTaskOutput,
+        api=api_client,
+        version=VersionProperties(instructions="Some instructions"),
+        tools=[some_tool],
+    )
+
+
 @pytest.fixture
 def agent_not_optional(api_client: APIClient):
     return Agent(
@@ -463,7 +506,11 @@ async def test_list_models(agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_
     # Verify the HTTP request was made correctly
     request = httpx_mock.get_request()
     assert request is not None, "Expected an HTTP request to be made"
+    assert request.method == "POST"
     assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models"
+    assert json.loads(request.content) == {
+        "requires_tools": False,
+    }

     # Verify we get back the full ModelInfo objects
     assert len(models) == 2
@@ -530,7 +577,11 @@ async def test_list_models_registers_if_needed(
     reqs = httpx_mock.get_requests()
     assert len(reqs) == 2
     assert reqs[0].url == "http://localhost:8000/v1/_/agents"
+    assert reqs[1].method == "POST"
     assert reqs[1].url == "http://localhost:8000/v1/_/agents/123/schemas/2/models"
+    assert json.loads(reqs[1].content) == {
+        "requires_tools": False,
+    }

     # Verify we get back the full ModelInfo object
     assert len(models) == 1
@@ -542,6 +593,252 @@
     assert models[0].metadata.provider_name == "OpenAI"


+@pytest.mark.asyncio
+async def test_list_models_with_instructions(
+    agent_with_instructions: Agent[HelloTaskInput, HelloTaskOutput],
+    httpx_mock: HTTPXMock,
+):
+    """Test that list_models correctly fetches and returns available models."""
+    # Mock the HTTP response instead of the API client method
+    httpx_mock.add_response(
+        method="POST",
+        url="http://localhost:8000/v1/_/agents/123/schemas/1/models",
+        json={
+            "items": [
+                {
+                    "id": "gpt-4",
+                    "name": "GPT-4",
+                    "icon_url": "https://example.com/gpt4.png",
+                    "modes": ["chat"],
+                    "is_not_supported_reason": None,
+                    "average_cost_per_run_usd": 0.01,
+                    "is_latest": True,
+                    "metadata": {
+                        "provider_name": "OpenAI",
+                        "price_per_input_token_usd": 0.0001,
+                        "price_per_output_token_usd": 0.0002,
+                        "release_date": "2024-01-01",
+                        "context_window_tokens": 128000,
+                        "quality_index": 0.95,
+                    },
+                    "is_default": True,
+                    "providers": ["openai"],
+                },
+                {
+                    "id": "claude-3",
+                    "name": "Claude 3",
+                    "icon_url": "https://example.com/claude3.png",
+                    "modes": ["chat"],
+                    "is_not_supported_reason": None,
+                    "average_cost_per_run_usd": 0.015,
+                    "is_latest": True,
+                    "metadata": {
+                        "provider_name": "Anthropic",
+                        "price_per_input_token_usd": 0.00015,
+                        "price_per_output_token_usd": 0.00025,
+                        "release_date": "2024-03-01",
+                        "context_window_tokens": 200000,
+                        "quality_index": 0.98,
+                    },
+                    "is_default": False,
+                    "providers": ["anthropic"],
+                },
+            ],
+            "count": 2,
+        },
+    )
+
+    # Call the method
+    models = await agent_with_instructions.list_models()
+
+    # Verify the HTTP request was made correctly
+    request = httpx_mock.get_request()
+    assert request is not None, "Expected an HTTP request to be made"
+    assert request.method == "POST"
+    assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models"
+    assert json.loads(request.content) == {"instructions": "Some instructions", "requires_tools": False}
+
+    # Verify we get back the full ModelInfo objects
+    assert len(models) == 2
+    assert isinstance(models[0], ModelInfo)
+    assert models[0].id == "gpt-4"
+    assert models[0].name == "GPT-4"
+    assert models[0].modes == ["chat"]
+    assert models[0].metadata is not None
+    assert models[0].metadata.provider_name == "OpenAI"
+
+    assert isinstance(models[1], ModelInfo)
+    assert models[1].id == "claude-3"
+    assert models[1].name == "Claude 3"
+    assert models[1].modes == ["chat"]
+    assert models[1].metadata is not None
+    assert models[1].metadata.provider_name == "Anthropic"
+
+
+@pytest.mark.asyncio
+async def test_list_models_with_tools(
+    agent_with_tools: Agent[HelloTaskInput, HelloTaskOutput],
+    httpx_mock: HTTPXMock,
+):
+    """Test that list_models correctly fetches and returns available models."""
+    # Mock the HTTP response instead of the API client method
+    httpx_mock.add_response(
+        method="POST",
+        url="http://localhost:8000/v1/_/agents/123/schemas/1/models",
+        json={
+            "items": [
+                {
+                    "id": "gpt-4",
+                    "name": "GPT-4",
+                    "icon_url": "https://example.com/gpt4.png",
+                    "modes": ["chat"],
+                    "is_not_supported_reason": None,
+                    "average_cost_per_run_usd": 0.01,
+                    "is_latest": True,
+                    "metadata": {
+                        "provider_name": "OpenAI",
+                        "price_per_input_token_usd": 0.0001,
+                        "price_per_output_token_usd": 0.0002,
+                        "release_date": "2024-01-01",
+                        "context_window_tokens": 128000,
+                        "quality_index": 0.95,
+                    },
+                    "is_default": True,
+                    "providers": ["openai"],
+                },
+                {
+                    "id": "claude-3",
+                    "name": "Claude 3",
+                    "icon_url": "https://example.com/claude3.png",
+                    "modes": ["chat"],
+                    "is_not_supported_reason": None,
+                    "average_cost_per_run_usd": 0.015,
+                    "is_latest": True,
+                    "metadata": {
+                        "provider_name": "Anthropic",
+                        "price_per_input_token_usd": 0.00015,
+                        "price_per_output_token_usd": 0.00025,
+                        "release_date": "2024-03-01",
+                        "context_window_tokens": 200000,
+                        "quality_index": 0.98,
+                    },
+                    "is_default": False,
+                    "providers": ["anthropic"],
+                },
+            ],
+            "count": 2,
+        },
+    )
+
+    # Call the method
+    models = await agent_with_tools.list_models()
+
+    # Verify the HTTP request was made correctly
+    request = httpx_mock.get_request()
+    assert request is not None, "Expected an HTTP request to be made"
+    assert request.method == "POST"
+    assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models"
+    assert json.loads(request.content) == {"requires_tools": True}
+
+    # Verify we get back the full ModelInfo objects
+    assert len(models) == 2
+    assert isinstance(models[0], ModelInfo)
+    assert models[0].id == "gpt-4"
+    assert models[0].name == "GPT-4"
+    assert models[0].modes == ["chat"]
+    assert models[0].metadata is not None
+    assert models[0].metadata.provider_name == "OpenAI"
+
+    assert isinstance(models[1], ModelInfo)
+    assert models[1].id == "claude-3"
+    assert models[1].name == "Claude 3"
+    assert models[1].modes == ["chat"]
+    assert models[1].metadata is not None
+    assert models[1].metadata.provider_name == "Anthropic"
+
+
+@pytest.mark.asyncio
+async def test_list_models_with_instructions_and_tools(
+    agent_with_tools_and_instructions: Agent[HelloTaskInput, HelloTaskOutput],
+    httpx_mock: HTTPXMock,
+):
+    """Test that list_models correctly fetches and returns available models."""
+    # Mock the HTTP response instead of the API client method
+    httpx_mock.add_response(
+        method="POST",
+        url="http://localhost:8000/v1/_/agents/123/schemas/1/models",
+        json={
+            "items": [
+                {
+                    "id": "gpt-4",
+                    "name": "GPT-4",
+                    "icon_url": "https://example.com/gpt4.png",
+                    "modes": ["chat"],
+                    "is_not_supported_reason": None,
+                    "average_cost_per_run_usd": 0.01,
+                    "is_latest": True,
+                    "metadata": {
+                        "provider_name": "OpenAI",
+                        "price_per_input_token_usd": 0.0001,
+                        "price_per_output_token_usd": 0.0002,
+                        "release_date": "2024-01-01",
+                        "context_window_tokens": 128000,
+                        "quality_index": 0.95,
+                    },
+                    "is_default": True,
+                    "providers": ["openai"],
+                },
+                {
+                    "id": "claude-3",
+                    "name": "Claude 3",
+                    "icon_url": "https://example.com/claude3.png",
+                    "modes": ["chat"],
+                    "is_not_supported_reason": None,
+                    "average_cost_per_run_usd": 0.015,
+                    "is_latest": True,
+                    "metadata": {
+                        "provider_name": "Anthropic",
+                        "price_per_input_token_usd": 0.00015,
+                        "price_per_output_token_usd": 0.00025,
+                        "release_date": "2024-03-01",
+                        "context_window_tokens": 200000,
+                        "quality_index": 0.98,
+                    },
+                    "is_default": False,
+                    "providers": ["anthropic"],
+                },
+            ],
+            "count": 2,
+        },
+    )
+
+    # Call the method
+    models = await agent_with_tools_and_instructions.list_models()
+
+    # Verify the HTTP request was made correctly
+    request = httpx_mock.get_request()
+    assert request is not None, "Expected an HTTP request to be made"
+    assert request.method == "POST"
+    assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models"
+    assert json.loads(request.content) == {"instructions": "Some instructions", "requires_tools": True}
+
+    # Verify we get back the full ModelInfo objects
+    assert len(models) == 2
+    assert isinstance(models[0], ModelInfo)
+    assert models[0].id == "gpt-4"
+    assert models[0].name == "GPT-4"
+    assert models[0].modes == ["chat"]
+    assert models[0].metadata is not None
+    assert models[0].metadata.provider_name == "OpenAI"
+
+    assert isinstance(models[1], ModelInfo)
+    assert models[1].id == "claude-3"
+    assert models[1].name == "Claude 3"
+    assert models[1].modes == ["chat"]
+    assert models[1].metadata is not None
+    assert models[1].metadata.provider_name == "Anthropic"
+
+
 class TestFetchCompletions:
     async def test_fetch_completions(self, agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock):
         """Test that fetch_completions correctly fetches and returns completions."""
