|
5 | 5 | # the root directory of this source tree. |
6 | 6 |
|
7 | 7 |
|
| 8 | +from llama_stack.providers.inline.agents.meta_reference.responses.streaming import ( |
| 9 | + _process_tool_choice, |
| 10 | +) |
8 | 11 | from llama_stack.providers.inline.agents.meta_reference.responses.types import ToolContext |
| 12 | +from llama_stack_api.inference import ( |
| 13 | + OpenAIChatCompletionToolChoiceAllowedTools, |
| 14 | + OpenAIChatCompletionToolChoiceCustomTool, |
| 15 | + OpenAIChatCompletionToolChoiceFunctionTool, |
| 16 | +) |
9 | 17 | from llama_stack_api.openai_responses import ( |
10 | 18 | MCPListToolsTool, |
| 19 | + OpenAIResponseInputToolChoiceAllowedTools, |
| 20 | + OpenAIResponseInputToolChoiceCustomTool, |
| 21 | + OpenAIResponseInputToolChoiceFileSearch, |
| 22 | + OpenAIResponseInputToolChoiceFunctionTool, |
| 23 | + OpenAIResponseInputToolChoiceMCPTool, |
| 24 | + OpenAIResponseInputToolChoiceMode, |
| 25 | + OpenAIResponseInputToolChoiceWebSearch, |
11 | 26 | OpenAIResponseInputToolFileSearch, |
12 | 27 | OpenAIResponseInputToolFunction, |
13 | 28 | OpenAIResponseInputToolMCP, |
@@ -181,3 +196,326 @@ def test_mismatched_allowed_tools(self): |
181 | 196 | assert len(context.previous_tool_listings) == 1 |
182 | 197 | assert len(context.previous_tool_listings[0].tools) == 1 |
183 | 198 | assert context.previous_tool_listings[0].server_label == "anotherlabel" |
| 199 | + |
| 200 | + |
| 201 | +class TestProcessToolChoice: |
| 202 | + """Comprehensive test suite for _process_tool_choice function.""" |
| 203 | + |
| 204 | + def setup_method(self): |
| 205 | + """Set up common test fixtures.""" |
| 206 | + self.chat_tools = [ |
| 207 | + {"type": "function", "function": {"name": "get_weather"}}, |
| 208 | + {"type": "function", "function": {"name": "calculate"}}, |
| 209 | + {"type": "function", "function": {"name": "file_search"}}, |
| 210 | + {"type": "function", "function": {"name": "web_search"}}, |
| 211 | + ] |
| 212 | + self.server_label_to_tools = { |
| 213 | + "mcp_server_1": ["mcp_tool_1", "mcp_tool_2"], |
| 214 | + "mcp_server_2": ["mcp_tool_3"], |
| 215 | + } |
| 216 | + |
| 217 | + async def test_mode_auto(self): |
| 218 | + """Test auto mode - should return 'auto' string.""" |
| 219 | + tool_choice = OpenAIResponseInputToolChoiceMode.auto |
| 220 | + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) |
| 221 | + assert result == "auto" |
| 222 | + |
| 223 | + async def test_mode_none(self): |
| 224 | + """Test none mode - should return 'none' string.""" |
| 225 | + tool_choice = OpenAIResponseInputToolChoiceMode.none |
| 226 | + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) |
| 227 | + assert result == "none" |
| 228 | + |
| 229 | + async def test_mode_required_with_tools(self): |
| 230 | + """Test required mode with available tools - should return AllowedTools with all function tools.""" |
| 231 | + tool_choice = OpenAIResponseInputToolChoiceMode.required |
| 232 | + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) |
| 233 | + |
| 234 | + assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools) |
| 235 | + assert result.allowed_tools.mode == "required" |
| 236 | + assert len(result.allowed_tools.tools) == 4 |
| 237 | + tool_names = [tool["function"]["name"] for tool in result.allowed_tools.tools] |
| 238 | + assert "get_weather" in tool_names |
| 239 | + assert "calculate" in tool_names |
| 240 | + assert "file_search" in tool_names |
| 241 | + assert "web_search" in tool_names |
| 242 | + |
| 243 | + async def test_mode_required_without_tools(self): |
| 244 | + """Test required mode without available tools - should return None.""" |
| 245 | + tool_choice = OpenAIResponseInputToolChoiceMode.required |
| 246 | + result = await _process_tool_choice([], tool_choice, self.server_label_to_tools) |
| 247 | + assert result is None |
| 248 | + |
| 249 | + async def test_allowed_tools_function(self): |
| 250 | + """Test allowed_tools with function tool types.""" |
| 251 | + tool_choice = OpenAIResponseInputToolChoiceAllowedTools( |
| 252 | + mode="required", |
| 253 | + tools=[ |
| 254 | + {"type": "function", "name": "get_weather"}, |
| 255 | + {"type": "function", "name": "calculate"}, |
| 256 | + ], |
| 257 | + ) |
| 258 | + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) |
| 259 | + |
| 260 | + assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools) |
| 261 | + assert result.allowed_tools.mode == "required" |
| 262 | + assert len(result.allowed_tools.tools) == 2 |
| 263 | + assert result.allowed_tools.tools[0]["function"]["name"] == "get_weather" |
| 264 | + assert result.allowed_tools.tools[1]["function"]["name"] == "calculate" |
| 265 | + |
| 266 | + async def test_allowed_tools_custom(self): |
| 267 | + """Test allowed_tools with custom tool types.""" |
| 268 | + chat_tools = [{"type": "function", "function": {"name": "custom_tool_1"}}] |
| 269 | + tool_choice = OpenAIResponseInputToolChoiceAllowedTools( |
| 270 | + mode="auto", |
| 271 | + tools=[{"type": "custom", "name": "custom_tool_1"}], |
| 272 | + ) |
| 273 | + result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools) |
| 274 | + |
| 275 | + assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools) |
| 276 | + assert result.allowed_tools.mode == "auto" |
| 277 | + assert len(result.allowed_tools.tools) == 1 |
| 278 | + assert result.allowed_tools.tools[0]["type"] == "custom" |
| 279 | + assert result.allowed_tools.tools[0]["custom"]["name"] == "custom_tool_1" |
| 280 | + |
| 281 | + async def test_allowed_tools_file_search(self): |
| 282 | + """Test allowed_tools with file_search.""" |
| 283 | + tool_choice = OpenAIResponseInputToolChoiceAllowedTools( |
| 284 | + mode="required", |
| 285 | + tools=[{"type": "file_search"}], |
| 286 | + ) |
| 287 | + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) |
| 288 | + |
| 289 | + assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools) |
| 290 | + assert len(result.allowed_tools.tools) == 1 |
| 291 | + assert result.allowed_tools.tools[0]["function"]["name"] == "file_search" |
| 292 | + |
| 293 | + async def test_allowed_tools_web_search(self): |
| 294 | + """Test allowed_tools with web_search.""" |
| 295 | + tool_choice = OpenAIResponseInputToolChoiceAllowedTools( |
| 296 | + mode="required", |
| 297 | + tools=[ |
| 298 | + {"type": "web_search_preview_2025_03_11"}, |
| 299 | + {"type": "web_search_2025_08_26"}, |
| 300 | + {"type": "web_search_preview"}, |
| 301 | + ], |
| 302 | + ) |
| 303 | + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) |
| 304 | + |
| 305 | + assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools) |
| 306 | + assert len(result.allowed_tools.tools) == 3 |
| 307 | + assert result.allowed_tools.tools[0]["function"]["name"] == "web_search" |
| 308 | + assert result.allowed_tools.tools[0]["type"] == "function" |
| 309 | + assert result.allowed_tools.tools[1]["function"]["name"] == "web_search" |
| 310 | + assert result.allowed_tools.tools[1]["type"] == "function" |
| 311 | + assert result.allowed_tools.tools[2]["function"]["name"] == "web_search" |
| 312 | + assert result.allowed_tools.tools[2]["type"] == "function" |
| 313 | + |
| 314 | + async def test_allowed_tools_mcp_server_label(self): |
| 315 | + """Test allowed_tools with MCP server label (no specific tool name).""" |
| 316 | + chat_tools = [ |
| 317 | + {"type": "function", "function": {"name": "mcp_tool_1"}}, |
| 318 | + {"type": "function", "function": {"name": "mcp_tool_2"}}, |
| 319 | + ] |
| 320 | + tool_choice = OpenAIResponseInputToolChoiceAllowedTools( |
| 321 | + mode="required", |
| 322 | + tools=[{"type": "mcp", "server_label": "mcp_server_1"}], |
| 323 | + ) |
| 324 | + result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools) |
| 325 | + |
| 326 | + assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools) |
| 327 | + assert len(result.allowed_tools.tools) == 2 |
| 328 | + tool_names = [tool["function"]["name"] for tool in result.allowed_tools.tools] |
| 329 | + assert "mcp_tool_1" in tool_names |
| 330 | + assert "mcp_tool_2" in tool_names |
| 331 | + |
| 332 | + async def test_allowed_tools_mixed_types(self): |
| 333 | + """Test allowed_tools with mixed tool types.""" |
| 334 | + chat_tools = [ |
| 335 | + {"type": "function", "function": {"name": "get_weather"}}, |
| 336 | + {"type": "function", "function": {"name": "file_search"}}, |
| 337 | + {"type": "function", "function": {"name": "web_search"}}, |
| 338 | + {"type": "function", "function": {"name": "mcp_tool_1"}}, |
| 339 | + ] |
| 340 | + tool_choice = OpenAIResponseInputToolChoiceAllowedTools( |
| 341 | + mode="auto", |
| 342 | + tools=[ |
| 343 | + {"type": "function", "name": "get_weather"}, |
| 344 | + {"type": "file_search"}, |
| 345 | + {"type": "web_search"}, |
| 346 | + {"type": "mcp", "server_label": "mcp_server_1"}, |
| 347 | + ], |
| 348 | + ) |
| 349 | + result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools) |
| 350 | + |
| 351 | + assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools) |
| 352 | + # Should have: get_weather, file_search, web_search, mcp_tool_1, mcp_tool_2 |
| 353 | + assert len(result.allowed_tools.tools) >= 3 |
| 354 | + |
| 355 | + async def test_allowed_tools_invalid_type(self): |
| 356 | + """Test allowed_tools with an unsupported tool type - should skip it.""" |
| 357 | + tool_choice = OpenAIResponseInputToolChoiceAllowedTools( |
| 358 | + mode="required", |
| 359 | + tools=[ |
| 360 | + {"type": "function", "name": "get_weather"}, |
| 361 | + {"type": "unsupported_type", "name": "bad_tool"}, |
| 362 | + ], |
| 363 | + ) |
| 364 | + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) |
| 365 | + |
| 366 | + assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools) |
| 367 | + # Should only include the valid function tool |
| 368 | + assert len(result.allowed_tools.tools) == 1 |
| 369 | + assert result.allowed_tools.tools[0]["function"]["name"] == "get_weather" |
| 370 | + |
| 371 | + async def test_specific_custom_tool_valid(self): |
| 372 | + """Test specific custom tool choice when tool exists.""" |
| 373 | + chat_tools = [{"type": "function", "function": {"name": "custom_tool"}}] |
| 374 | + tool_choice = OpenAIResponseInputToolChoiceCustomTool(name="custom_tool") |
| 375 | + result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools) |
| 376 | + |
| 377 | + assert isinstance(result, OpenAIChatCompletionToolChoiceCustomTool) |
| 378 | + assert result.custom.name == "custom_tool" |
| 379 | + |
| 380 | + async def test_specific_custom_tool_invalid(self): |
| 381 | + """Test specific custom tool choice when tool doesn't exist - should return None.""" |
| 382 | + tool_choice = OpenAIResponseInputToolChoiceCustomTool(name="nonexistent_tool") |
| 383 | + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) |
| 384 | + assert result is None |
| 385 | + |
| 386 | + async def test_specific_function_tool_valid(self): |
| 387 | + """Test specific function tool choice when tool exists.""" |
| 388 | + tool_choice = OpenAIResponseInputToolChoiceFunctionTool(name="get_weather") |
| 389 | + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) |
| 390 | + |
| 391 | + assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool) |
| 392 | + assert result.function.name == "get_weather" |
| 393 | + |
| 394 | + async def test_specific_function_tool_invalid(self): |
| 395 | + """Test specific function tool choice when tool doesn't exist - should return None.""" |
| 396 | + tool_choice = OpenAIResponseInputToolChoiceFunctionTool(name="nonexistent_function") |
| 397 | + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) |
| 398 | + assert result is None |
| 399 | + |
| 400 | + async def test_specific_file_search_valid(self): |
| 401 | + """Test file_search tool choice when available.""" |
| 402 | + tool_choice = OpenAIResponseInputToolChoiceFileSearch() |
| 403 | + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) |
| 404 | + |
| 405 | + assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool) |
| 406 | + assert result.function.name == "file_search" |
| 407 | + |
| 408 | + async def test_specific_file_search_invalid(self): |
| 409 | + """Test file_search tool choice when not available - should return None.""" |
| 410 | + chat_tools = [{"type": "function", "function": {"name": "get_weather"}}] |
| 411 | + tool_choice = OpenAIResponseInputToolChoiceFileSearch() |
| 412 | + result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools) |
| 413 | + assert result is None |
| 414 | + |
| 415 | + async def test_specific_web_search_valid(self): |
| 416 | + """Test web_search tool choice when available.""" |
| 417 | + tool_choice = OpenAIResponseInputToolChoiceWebSearch() |
| 418 | + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) |
| 419 | + |
| 420 | + assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool) |
| 421 | + assert result.function.name == "web_search" |
| 422 | + |
| 423 | + async def test_specific_web_search_invalid(self): |
| 424 | + """Test web_search tool choice when not available - should return None.""" |
| 425 | + chat_tools = [{"type": "function", "function": {"name": "get_weather"}}] |
| 426 | + tool_choice = OpenAIResponseInputToolChoiceWebSearch() |
| 427 | + result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools) |
| 428 | + assert result is None |
| 429 | + |
| 430 | + async def test_specific_mcp_tool_with_name(self): |
| 431 | + """Test MCP tool choice with specific tool name.""" |
| 432 | + chat_tools = [{"type": "function", "function": {"name": "mcp_tool_1"}}] |
| 433 | + tool_choice = OpenAIResponseInputToolChoiceMCPTool( |
| 434 | + server_label="mcp_server_1", |
| 435 | + name="mcp_tool_1", |
| 436 | + ) |
| 437 | + result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools) |
| 438 | + |
| 439 | + assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool) |
| 440 | + assert result.function.name == "mcp_tool_1" |
| 441 | + |
| 442 | + async def test_specific_mcp_tool_with_name_not_in_chat_tools(self): |
| 443 | + """Test MCP tool choice with specific tool name that doesn't exist in chat_tools.""" |
| 444 | + chat_tools = [{"type": "function", "function": {"name": "other_tool"}}] |
| 445 | + tool_choice = OpenAIResponseInputToolChoiceMCPTool( |
| 446 | + server_label="mcp_server_1", |
| 447 | + name="mcp_tool_1", |
| 448 | + ) |
| 449 | + result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools) |
| 450 | + assert result is None |
| 451 | + |
| 452 | + async def test_specific_mcp_tool_server_label_only(self): |
| 453 | + """Test MCP tool choice with only server label (no specific tool name).""" |
| 454 | + chat_tools = [ |
| 455 | + {"type": "function", "function": {"name": "mcp_tool_1"}}, |
| 456 | + {"type": "function", "function": {"name": "mcp_tool_2"}}, |
| 457 | + ] |
| 458 | + tool_choice = OpenAIResponseInputToolChoiceMCPTool(server_label="mcp_server_1") |
| 459 | + result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools) |
| 460 | + |
| 461 | + assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools) |
| 462 | + assert result.allowed_tools.mode == "required" |
| 463 | + assert len(result.allowed_tools.tools) == 2 |
| 464 | + tool_names = [tool["function"]["name"] for tool in result.allowed_tools.tools] |
| 465 | + assert "mcp_tool_1" in tool_names |
| 466 | + assert "mcp_tool_2" in tool_names |
| 467 | + |
| 468 | + async def test_specific_mcp_tool_unknown_server(self): |
| 469 | + """Test MCP tool choice with unknown server label.""" |
| 470 | + tool_choice = OpenAIResponseInputToolChoiceMCPTool( |
| 471 | + server_label="unknown_server", |
| 472 | + name="some_tool", |
| 473 | + ) |
| 474 | + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) |
| 475 | + # Should return None because server not found |
| 476 | + assert result is None |
| 477 | + |
| 478 | + async def test_empty_chat_tools(self): |
| 479 | + """Test with empty chat_tools list.""" |
| 480 | + tool_choice = OpenAIResponseInputToolChoiceFunctionTool(name="get_weather") |
| 481 | + result = await _process_tool_choice([], tool_choice, self.server_label_to_tools) |
| 482 | + assert result is None |
| 483 | + |
| 484 | + async def test_empty_server_label_to_tools(self): |
| 485 | + """Test with empty server_label_to_tools mapping.""" |
| 486 | + tool_choice = OpenAIResponseInputToolChoiceMCPTool(server_label="mcp_server_1") |
| 487 | + result = await _process_tool_choice(self.chat_tools, tool_choice, {}) |
| 488 | + # Should handle gracefully |
| 489 | + assert result is None or isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools) |
| 490 | + |
| 491 | + async def test_allowed_tools_empty_list(self): |
| 492 | + """Test allowed_tools with empty tools list.""" |
| 493 | + tool_choice = OpenAIResponseInputToolChoiceAllowedTools(mode="auto", tools=[]) |
| 494 | + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) |
| 495 | + |
| 496 | + assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools) |
| 497 | + assert len(result.allowed_tools.tools) == 0 |
| 498 | + |
| 499 | + async def test_mcp_tool_multiple_servers(self): |
| 500 | + """Test MCP tool choice with multiple server labels.""" |
| 501 | + chat_tools = [ |
| 502 | + {"type": "function", "function": {"name": "mcp_tool_1"}}, |
| 503 | + {"type": "function", "function": {"name": "mcp_tool_2"}}, |
| 504 | + {"type": "function", "function": {"name": "mcp_tool_3"}}, |
| 505 | + ] |
| 506 | + server_label_to_tools = { |
| 507 | + "server_a": ["mcp_tool_1"], |
| 508 | + "server_b": ["mcp_tool_2", "mcp_tool_3"], |
| 509 | + } |
| 510 | + |
| 511 | + # Test server_a |
| 512 | + tool_choice_a = OpenAIResponseInputToolChoiceMCPTool(server_label="server_a") |
| 513 | + result_a = await _process_tool_choice(chat_tools, tool_choice_a, server_label_to_tools) |
| 514 | + assert isinstance(result_a, OpenAIChatCompletionToolChoiceAllowedTools) |
| 515 | + assert len(result_a.allowed_tools.tools) == 1 |
| 516 | + |
| 517 | + # Test server_b |
| 518 | + tool_choice_b = OpenAIResponseInputToolChoiceMCPTool(server_label="server_b") |
| 519 | + result_b = await _process_tool_choice(chat_tools, tool_choice_b, server_label_to_tools) |
| 520 | + assert isinstance(result_b, OpenAIChatCompletionToolChoiceAllowedTools) |
| 521 | + assert len(result_b.allowed_tools.tools) == 2 |
0 commit comments