Commit 811e573

add unit tests
Signed-off-by: Jaideep Rao <[email protected]>
1 parent 06481a6 commit 811e573

2 files changed: +343 −6 lines changed

src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py

Lines changed: 5 additions & 6 deletions
@@ -8,8 +8,8 @@
 from collections.abc import AsyncIterator
 from typing import Any
 
-from opentelemetry import trace
 from openai.types.chat import ChatCompletionToolParam
+from opentelemetry import trace
 
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
@@ -260,7 +260,7 @@ async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
             self.server_label_to_tools,
         )
         # chat_tool_choice can be str, dict-like object, or None
-        if isinstance(chat_tool_choice, str):
+        if isinstance(chat_tool_choice, str | type(None)):
             self.ctx.chat_tool_choice = chat_tool_choice
         else:
             self.ctx.chat_tool_choice = chat_tool_choice.model_dump()
@@ -1364,12 +1364,11 @@ async def _process_tool_choice(
         # ensure that specified tool choices are available in the chat tools, if not, remove them from the list
         final_tools = []
         for tool in responses_tool_choice.tools:
-            tool_name = tool.get("name")
             match tool.get("type"):
                 case "function":
-                    final_tools.append({"type": "function", "function": {"name": tool_name}})
+                    final_tools.append({"type": "function", "function": {"name": tool.get("name")}})
                 case "custom":
-                    final_tools.append({"type": "custom", "custom": {"name": tool_name}})
+                    final_tools.append({"type": "custom", "custom": {"name": tool.get("name")}})
                 case "mcp":
                     mcp_tools = convert_mcp_tool_choice(
                         chat_tool_names, tool.get("server_label"), server_label_to_tools, None
@@ -1396,7 +1395,7 @@ async def _process_tool_choice(
     else:
         # Handle specific tool choice by type
        # Each case validates the tool exists in chat_tools before returning
-        tool_name = responses_tool_choice.name if responses_tool_choice.name else None
+        tool_name = getattr(responses_tool_choice, "name", None)
         match responses_tool_choice:
             case OpenAIResponseInputToolChoiceCustomTool():
                 if tool_name and tool_name not in chat_tool_names:
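
A note on the behavior changed above: `chat_tool_choice` may now be a plain string, `None`, or a pydantic model, and the tool name is read with `getattr(..., "name", None)` so choice types without a `name` field no longer raise. A minimal standalone sketch of that normalization pattern (the `FakeChoice` model below is hypothetical, standing in for the real tool-choice objects):

    from pydantic import BaseModel

    class FakeChoice(BaseModel):
        # hypothetical stand-in for a tool-choice pydantic model
        name: str

    def normalize(chat_tool_choice):
        # strings such as "auto"/"none" and None pass through unchanged;
        # model objects are converted to plain dicts
        if isinstance(chat_tool_choice, str | type(None)):
            return chat_tool_choice
        return chat_tool_choice.model_dump()

    assert normalize("auto") == "auto"
    assert normalize(None) is None
    assert normalize(FakeChoice(name="get_weather")) == {"name": "get_weather"}
    # a choice object without a `name` attribute now yields None instead of raising
    assert getattr(object(), "name", None) is None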

tests/unit/providers/agents/meta_reference/test_response_tool_context.py

Lines changed: 338 additions & 0 deletions
@@ -5,9 +5,24 @@
 # the root directory of this source tree.
 
 
+from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
+    _process_tool_choice,
+)
 from llama_stack.providers.inline.agents.meta_reference.responses.types import ToolContext
+from llama_stack_api.inference import (
+    OpenAIChatCompletionToolChoiceAllowedTools,
+    OpenAIChatCompletionToolChoiceCustomTool,
+    OpenAIChatCompletionToolChoiceFunctionTool,
+)
 from llama_stack_api.openai_responses import (
     MCPListToolsTool,
+    OpenAIResponseInputToolChoiceAllowedTools,
+    OpenAIResponseInputToolChoiceCustomTool,
+    OpenAIResponseInputToolChoiceFileSearch,
+    OpenAIResponseInputToolChoiceFunctionTool,
+    OpenAIResponseInputToolChoiceMCPTool,
+    OpenAIResponseInputToolChoiceMode,
+    OpenAIResponseInputToolChoiceWebSearch,
     OpenAIResponseInputToolFileSearch,
     OpenAIResponseInputToolFunction,
     OpenAIResponseInputToolMCP,
@@ -181,3 +196,326 @@ def test_mismatched_allowed_tools(self):
         assert len(context.previous_tool_listings) == 1
         assert len(context.previous_tool_listings[0].tools) == 1
         assert context.previous_tool_listings[0].server_label == "anotherlabel"
+
+
+class TestProcessToolChoice:
+    """Comprehensive test suite for _process_tool_choice function."""
+
+    def setup_method(self):
+        """Set up common test fixtures."""
+        self.chat_tools = [
+            {"type": "function", "function": {"name": "get_weather"}},
+            {"type": "function", "function": {"name": "calculate"}},
+            {"type": "function", "function": {"name": "file_search"}},
+            {"type": "function", "function": {"name": "web_search"}},
+        ]
+        self.server_label_to_tools = {
+            "mcp_server_1": ["mcp_tool_1", "mcp_tool_2"],
+            "mcp_server_2": ["mcp_tool_3"],
+        }
+
+    async def test_mode_auto(self):
+        """Test auto mode - should return 'auto' string."""
+        tool_choice = OpenAIResponseInputToolChoiceMode.auto
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+        assert result == "auto"
+
+    async def test_mode_none(self):
+        """Test none mode - should return 'none' string."""
+        tool_choice = OpenAIResponseInputToolChoiceMode.none
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+        assert result == "none"
+
+    async def test_mode_required_with_tools(self):
+        """Test required mode with available tools - should return AllowedTools with all function tools."""
+        tool_choice = OpenAIResponseInputToolChoiceMode.required
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
+        assert result.allowed_tools.mode == "required"
+        assert len(result.allowed_tools.tools) == 4
+        tool_names = [tool["function"]["name"] for tool in result.allowed_tools.tools]
+        assert "get_weather" in tool_names
+        assert "calculate" in tool_names
+        assert "file_search" in tool_names
+        assert "web_search" in tool_names
+
+    async def test_mode_required_without_tools(self):
+        """Test required mode without available tools - should return None."""
+        tool_choice = OpenAIResponseInputToolChoiceMode.required
+        result = await _process_tool_choice([], tool_choice, self.server_label_to_tools)
+        assert result is None
+
+    async def test_allowed_tools_function(self):
+        """Test allowed_tools with function tool types."""
+        tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
+            mode="required",
+            tools=[
+                {"type": "function", "name": "get_weather"},
+                {"type": "function", "name": "calculate"},
+            ],
+        )
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
+        assert result.allowed_tools.mode == "required"
+        assert len(result.allowed_tools.tools) == 2
+        assert result.allowed_tools.tools[0]["function"]["name"] == "get_weather"
+        assert result.allowed_tools.tools[1]["function"]["name"] == "calculate"
+
+    async def test_allowed_tools_custom(self):
+        """Test allowed_tools with custom tool types."""
+        chat_tools = [{"type": "function", "function": {"name": "custom_tool_1"}}]
+        tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
+            mode="auto",
+            tools=[{"type": "custom", "name": "custom_tool_1"}],
+        )
+        result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
+        assert result.allowed_tools.mode == "auto"
+        assert len(result.allowed_tools.tools) == 1
+        assert result.allowed_tools.tools[0]["type"] == "custom"
+        assert result.allowed_tools.tools[0]["custom"]["name"] == "custom_tool_1"
+
+    async def test_allowed_tools_file_search(self):
+        """Test allowed_tools with file_search."""
+        tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
+            mode="required",
+            tools=[{"type": "file_search"}],
+        )
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
+        assert len(result.allowed_tools.tools) == 1
+        assert result.allowed_tools.tools[0]["function"]["name"] == "file_search"
+
+    async def test_allowed_tools_web_search(self):
+        """Test allowed_tools with web_search."""
+        tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
+            mode="required",
+            tools=[
+                {"type": "web_search_preview_2025_03_11"},
+                {"type": "web_search_2025_08_26"},
+                {"type": "web_search_preview"},
+            ],
+        )
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
+        assert len(result.allowed_tools.tools) == 3
+        assert result.allowed_tools.tools[0]["function"]["name"] == "web_search"
+        assert result.allowed_tools.tools[0]["type"] == "function"
+        assert result.allowed_tools.tools[1]["function"]["name"] == "web_search"
+        assert result.allowed_tools.tools[1]["type"] == "function"
+        assert result.allowed_tools.tools[2]["function"]["name"] == "web_search"
+        assert result.allowed_tools.tools[2]["type"] == "function"
+
+    async def test_allowed_tools_mcp_server_label(self):
+        """Test allowed_tools with MCP server label (no specific tool name)."""
+        chat_tools = [
+            {"type": "function", "function": {"name": "mcp_tool_1"}},
+            {"type": "function", "function": {"name": "mcp_tool_2"}},
+        ]
+        tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
+            mode="required",
+            tools=[{"type": "mcp", "server_label": "mcp_server_1"}],
+        )
+        result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
+        assert len(result.allowed_tools.tools) == 2
+        tool_names = [tool["function"]["name"] for tool in result.allowed_tools.tools]
+        assert "mcp_tool_1" in tool_names
+        assert "mcp_tool_2" in tool_names
+
+    async def test_allowed_tools_mixed_types(self):
+        """Test allowed_tools with mixed tool types."""
+        chat_tools = [
+            {"type": "function", "function": {"name": "get_weather"}},
+            {"type": "function", "function": {"name": "file_search"}},
+            {"type": "function", "function": {"name": "web_search"}},
+            {"type": "function", "function": {"name": "mcp_tool_1"}},
+        ]
+        tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
+            mode="auto",
+            tools=[
+                {"type": "function", "name": "get_weather"},
+                {"type": "file_search"},
+                {"type": "web_search"},
+                {"type": "mcp", "server_label": "mcp_server_1"},
+            ],
+        )
+        result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
+        # Should have: get_weather, file_search, web_search, mcp_tool_1, mcp_tool_2
+        assert len(result.allowed_tools.tools) >= 3
+
+    async def test_allowed_tools_invalid_type(self):
+        """Test allowed_tools with an unsupported tool type - should skip it."""
+        tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
+            mode="required",
+            tools=[
+                {"type": "function", "name": "get_weather"},
+                {"type": "unsupported_type", "name": "bad_tool"},
+            ],
+        )
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
+        # Should only include the valid function tool
+        assert len(result.allowed_tools.tools) == 1
+        assert result.allowed_tools.tools[0]["function"]["name"] == "get_weather"
+
+    async def test_specific_custom_tool_valid(self):
+        """Test specific custom tool choice when tool exists."""
+        chat_tools = [{"type": "function", "function": {"name": "custom_tool"}}]
+        tool_choice = OpenAIResponseInputToolChoiceCustomTool(name="custom_tool")
+        result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceCustomTool)
+        assert result.custom.name == "custom_tool"
+
+    async def test_specific_custom_tool_invalid(self):
+        """Test specific custom tool choice when tool doesn't exist - should return None."""
+        tool_choice = OpenAIResponseInputToolChoiceCustomTool(name="nonexistent_tool")
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+        assert result is None
+
+    async def test_specific_function_tool_valid(self):
+        """Test specific function tool choice when tool exists."""
+        tool_choice = OpenAIResponseInputToolChoiceFunctionTool(name="get_weather")
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool)
+        assert result.function.name == "get_weather"
+
+    async def test_specific_function_tool_invalid(self):
+        """Test specific function tool choice when tool doesn't exist - should return None."""
+        tool_choice = OpenAIResponseInputToolChoiceFunctionTool(name="nonexistent_function")
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+        assert result is None
+
+    async def test_specific_file_search_valid(self):
+        """Test file_search tool choice when available."""
+        tool_choice = OpenAIResponseInputToolChoiceFileSearch()
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool)
+        assert result.function.name == "file_search"
+
+    async def test_specific_file_search_invalid(self):
+        """Test file_search tool choice when not available - should return None."""
+        chat_tools = [{"type": "function", "function": {"name": "get_weather"}}]
+        tool_choice = OpenAIResponseInputToolChoiceFileSearch()
+        result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
+        assert result is None
+
+    async def test_specific_web_search_valid(self):
+        """Test web_search tool choice when available."""
+        tool_choice = OpenAIResponseInputToolChoiceWebSearch()
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool)
+        assert result.function.name == "web_search"
+
+    async def test_specific_web_search_invalid(self):
+        """Test web_search tool choice when not available - should return None."""
+        chat_tools = [{"type": "function", "function": {"name": "get_weather"}}]
+        tool_choice = OpenAIResponseInputToolChoiceWebSearch()
+        result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
+        assert result is None
+
+    async def test_specific_mcp_tool_with_name(self):
+        """Test MCP tool choice with specific tool name."""
+        chat_tools = [{"type": "function", "function": {"name": "mcp_tool_1"}}]
+        tool_choice = OpenAIResponseInputToolChoiceMCPTool(
+            server_label="mcp_server_1",
+            name="mcp_tool_1",
+        )
+        result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool)
+        assert result.function.name == "mcp_tool_1"
+
+    async def test_specific_mcp_tool_with_name_not_in_chat_tools(self):
+        """Test MCP tool choice with specific tool name that doesn't exist in chat_tools."""
+        chat_tools = [{"type": "function", "function": {"name": "other_tool"}}]
+        tool_choice = OpenAIResponseInputToolChoiceMCPTool(
+            server_label="mcp_server_1",
+            name="mcp_tool_1",
+        )
+        result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
+        assert result is None
+
+    async def test_specific_mcp_tool_server_label_only(self):
+        """Test MCP tool choice with only server label (no specific tool name)."""
+        chat_tools = [
+            {"type": "function", "function": {"name": "mcp_tool_1"}},
+            {"type": "function", "function": {"name": "mcp_tool_2"}},
+        ]
+        tool_choice = OpenAIResponseInputToolChoiceMCPTool(server_label="mcp_server_1")
+        result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
+        assert result.allowed_tools.mode == "required"
+        assert len(result.allowed_tools.tools) == 2
+        tool_names = [tool["function"]["name"] for tool in result.allowed_tools.tools]
+        assert "mcp_tool_1" in tool_names
+        assert "mcp_tool_2" in tool_names
+
+    async def test_specific_mcp_tool_unknown_server(self):
+        """Test MCP tool choice with unknown server label."""
+        tool_choice = OpenAIResponseInputToolChoiceMCPTool(
+            server_label="unknown_server",
+            name="some_tool",
+        )
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+        # Should return None because server not found
+        assert result is None
+
+    async def test_empty_chat_tools(self):
+        """Test with empty chat_tools list."""
+        tool_choice = OpenAIResponseInputToolChoiceFunctionTool(name="get_weather")
+        result = await _process_tool_choice([], tool_choice, self.server_label_to_tools)
+        assert result is None
+
+    async def test_empty_server_label_to_tools(self):
+        """Test with empty server_label_to_tools mapping."""
+        tool_choice = OpenAIResponseInputToolChoiceMCPTool(server_label="mcp_server_1")
+        result = await _process_tool_choice(self.chat_tools, tool_choice, {})
+        # Should handle gracefully
+        assert result is None or isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
+
+    async def test_allowed_tools_empty_list(self):
+        """Test allowed_tools with empty tools list."""
+        tool_choice = OpenAIResponseInputToolChoiceAllowedTools(mode="auto", tools=[])
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
+        assert len(result.allowed_tools.tools) == 0
+
+    async def test_mcp_tool_multiple_servers(self):
+        """Test MCP tool choice with multiple server labels."""
+        chat_tools = [
+            {"type": "function", "function": {"name": "mcp_tool_1"}},
+            {"type": "function", "function": {"name": "mcp_tool_2"}},
+            {"type": "function", "function": {"name": "mcp_tool_3"}},
+        ]
+        server_label_to_tools = {
+            "server_a": ["mcp_tool_1"],
+            "server_b": ["mcp_tool_2", "mcp_tool_3"],
+        }
+
+        # Test server_a
+        tool_choice_a = OpenAIResponseInputToolChoiceMCPTool(server_label="server_a")
+        result_a = await _process_tool_choice(chat_tools, tool_choice_a, server_label_to_tools)
+        assert isinstance(result_a, OpenAIChatCompletionToolChoiceAllowedTools)
+        assert len(result_a.allowed_tools.tools) == 1
+
+        # Test server_b
+        tool_choice_b = OpenAIResponseInputToolChoiceMCPTool(server_label="server_b")
+        result_b = await _process_tool_choice(chat_tools, tool_choice_b, server_label_to_tools)
+        assert isinstance(result_b, OpenAIChatCompletionToolChoiceAllowedTools)
+        assert len(result_b.allowed_tools.tools) == 2
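
The new test methods are all `async def`; running just this module locally might look like the following, assuming pytest plus an asyncio-capable plugin (e.g. pytest-asyncio or anyio) are configured as in the rest of the suite:

    import pytest

    # run only the new tool-choice tests, verbosely; the path comes from the diff header above
    pytest.main(["tests/unit/providers/agents/meta_reference/test_response_tool_context.py", "-v"])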
