|
18 | 18 | from authentication import get_auth_dependency |
19 | 19 | from authentication.interface import AuthTuple |
20 | 20 | from authorization.middleware import authorize |
21 | | -from configuration import configuration |
| 21 | +from configuration import AppConfig, configuration |
22 | 22 | import metrics |
23 | 23 | from models.config import Action |
24 | 24 | from models.requests import QueryRequest |
@@ -355,31 +355,9 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche |
355 | 355 | validate_attachments_metadata(query_request.attachments) |
356 | 356 |
|
357 | 357 | # Prepare tools for responses API |
358 | | - toolgroups: list[dict[str, Any]] | None = None |
359 | | - if not query_request.no_tools: |
360 | | - toolgroups = [] |
361 | | - # Get vector stores for RAG tools |
362 | | - vector_store_ids = [ |
363 | | - vector_store.id for vector_store in (await client.vector_stores.list()).data |
364 | | - ] |
365 | | - |
366 | | - # Add RAG tools if vector stores are available |
367 | | - rag_tools = get_rag_tools(vector_store_ids) |
368 | | - if rag_tools: |
369 | | - toolgroups.extend(rag_tools) |
370 | | - |
371 | | - # Add MCP server tools |
372 | | - mcp_tools = get_mcp_tools(configuration.mcp_servers, token, mcp_headers) |
373 | | - if mcp_tools: |
374 | | - toolgroups.extend(mcp_tools) |
375 | | - logger.debug( |
376 | | - "Configured %d MCP tools: %s", |
377 | | - len(mcp_tools), |
378 | | - [tool.get("server_label", "unknown") for tool in mcp_tools], |
379 | | - ) |
380 | | - # Convert empty list to None for consistency with existing behavior |
381 | | - if not toolgroups: |
382 | | - toolgroups = None |
| 358 | + toolgroups = await prepare_tools_for_responses_api( |
| 359 | + client, query_request, token, configuration, mcp_headers |
| 360 | + ) |
383 | 361 |
|
384 | 362 | # Prepare input for Responses API |
385 | 363 | # Convert attachments to text and concatenate with query |
@@ -625,11 +603,71 @@ def get_mcp_tools( |
625 | 603 | "require_approval": "never", |
626 | 604 | } |
627 | 605 |
|
628 | | - # Add authentication if headers or token provided (Response API format) |
629 | | - headers = (mcp_headers or {}).get(mcp_server.url) |
630 | | - if headers: |
| 606 | + # Build headers: start with token auth, then merge in per-server headers |
| 607 | + if token or mcp_headers: |
| 608 | + headers = {} |
| 609 | + # Add token-based auth if available |
| 610 | + if token: |
| 611 | + headers["Authorization"] = f"Bearer {token}" |
| 612 | + # Merge in per-server headers (can override Authorization if needed) |
| 613 | + server_headers = (mcp_headers or {}).get(mcp_server.url) |
| 614 | + if server_headers: |
| 615 | + headers.update(server_headers) |
631 | 616 | tool_def["headers"] = headers |
632 | | - elif token: |
633 | | - tool_def["headers"] = {"Authorization": f"Bearer {token}"} |
| 617 | + |
634 | 618 | tools.append(tool_def) |
635 | 619 | return tools |
| 620 | + |
| 621 | + |
async def prepare_tools_for_responses_api(
    client: AsyncLlamaStackClient,
    query_request: QueryRequest,
    token: str,
    config: AppConfig,
    mcp_headers: dict[str, dict[str, str]] | None = None,
) -> list[dict[str, Any]] | None:
    """
    Assemble the tool definitions handed to the Responses API.

    RAG tools are built from the IDs of the vector stores listed by the
    client; MCP tools are built from the MCP servers in ``config``. RAG
    tools come first in the returned list, followed by MCP tools.

    Args:
        client: Llama Stack client used to list the vector stores
        query_request: The user's query request; its ``no_tools`` flag
            disables tool preparation entirely
        token: Token forwarded to ``get_mcp_tools``
        config: Configuration object whose ``mcp_servers`` are forwarded
            to ``get_mcp_tools``
        mcp_headers: Per-request MCP headers forwarded to ``get_mcp_tools``

    Returns:
        list[dict[str, Any]] | None: The combined tool definitions, or None
        when ``no_tools`` is set or when no tools were produced
    """
    # Tools are disabled for this request: skip the vector store lookup.
    if query_request.no_tools:
        return None

    # RAG tools are derived from whatever vector stores the client reports.
    stores_page = await client.vector_stores.list()
    store_ids = [store.id for store in stores_page.data]
    rag_tools = get_rag_tools(store_ids)

    # MCP tools are derived from the configured servers.
    mcp_tools = get_mcp_tools(config.mcp_servers, token, mcp_headers)
    if mcp_tools:
        logger.debug(
            "Configured %d MCP tools: %s",
            len(mcp_tools),
            [tool.get("server_label", "unknown") for tool in mcp_tools],
        )

    # Falsy helper results contribute nothing, matching a skipped extend().
    combined: list[dict[str, Any]] = [*(rag_tools or []), *(mcp_tools or [])]

    # An empty result is reported as None rather than [].
    return combined or None
0 commit comments