From fe1d7b69cbc1faa76e3f61563dd47b5cc83e988b Mon Sep 17 00:00:00 2001 From: Luis Tomas Bolivar Date: Wed, 5 Nov 2025 15:00:24 +0100 Subject: [PATCH 1/2] Add shields support to the responses API implementation It includes both streaming and not streaming support, by leveraging the refusal field on the response --- src/app/endpoints/query_v2.py | 23 +- src/app/endpoints/streaming_query_v2.py | 34 ++- tests/unit/app/endpoints/test_query_v2.py | 194 ++++++++++++++ .../app/endpoints/test_streaming_query_v2.py | 248 ++++++++++++++++++ 4 files changed, 490 insertions(+), 9 deletions(-) diff --git a/src/app/endpoints/query_v2.py b/src/app/endpoints/query_v2.py index 0aed8c2a1..a9c475ef6 100644 --- a/src/app/endpoints/query_v2.py +++ b/src/app/endpoints/query_v2.py @@ -322,7 +322,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche given query, handling shield configuration, tool usage, and attachment validation. - This function configures system prompts and toolgroups + This function configures system prompts, shields, and toolgroups (including RAG and MCP integration) as needed based on the query request and system configuration. It validates attachments, manages conversation and session @@ -342,8 +342,12 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche and the conversation ID, the list of parsed referenced documents, and token usage information. """ - # TODO(ltomasbo): implement shields support once available in Responses API - logger.info("Shields are not yet supported in Responses API. Disabling safety") + # List available shields for Responses API + available_shields = [shield.identifier for shield in await client.shields.list()] + if not available_shields: + logger.info("No available shields. Disabling safety") + else: + logger.info("Available shields: %s", available_shields) # use system prompt from request or default one system_prompt = get_system_prompt(query_request, configuration) @@ -381,6 +385,10 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche if query_request.conversation_id: create_kwargs["previous_response_id"] = query_request.conversation_id + # Add shields to extra_body if available + if available_shields: + create_kwargs["extra_body"] = {"guardrails": available_shields} + response = await client.responses.create(**create_kwargs) response = cast(OpenAIResponseObject, response) @@ -406,6 +414,15 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche if tool_summary: tool_calls.append(tool_summary) + # Check for shield violations + item_type = getattr(output_item, "type", None) + if item_type == "message": + refusal = getattr(output_item, "refusal", None) + if refusal: + # Metric for LLM validation errors (shield violations) + metrics.llm_calls_validation_errors_total.inc() + logger.warning("Shield violation detected: %s", refusal) + logger.info( "Response processing complete - Tool calls: %d, Response length: %d chars", len(tool_calls), diff --git a/src/app/endpoints/streaming_query_v2.py b/src/app/endpoints/streaming_query_v2.py index 908a72373..6a77e4eaf 100644 --- a/src/app/endpoints/streaming_query_v2.py +++ b/src/app/endpoints/streaming_query_v2.py @@ -32,6 +32,7 @@ from authorization.middleware import authorize from configuration import configuration from constants import MEDIA_TYPE_JSON +import metrics from models.config import Action from models.context import ResponseGeneratorContext from models.requests import QueryRequest @@ -243,6 +244,18 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat elif event_type == "response.completed": # Capture the response object for token usage extraction latest_response_object = getattr(chunk, "response", None) + + # Check for shield violations in the completed response + if latest_response_object: + for output_item in getattr(latest_response_object, "output", []): + item_type = getattr(output_item, "type", None) + if item_type == "message": + refusal = getattr(output_item, "refusal", None) + if refusal: + # Metric for LLM validation errors (shield violations) + metrics.llm_calls_validation_errors_total.inc() + logger.warning("Shield violation detected: %s", refusal) + if not emitted_turn_complete: final_message = summary.llm_response or "".join(text_parts) if not final_message: @@ -348,11 +361,11 @@ async def retrieve_response( Asynchronously retrieves a streaming response and conversation ID from the Llama Stack agent for a given user query. - This function configures input/output shields, system prompt, - and tool usage based on the request and environment. It - prepares the agent with appropriate headers and toolgroups, - validates attachments if present, and initiates a streaming - turn with the user's query and any provided documents. + This function configures shields, system prompt, and tool usage + based on the request and environment. It prepares the agent with + appropriate headers and toolgroups, validates attachments if + present, and initiates a streaming turn with the user's query + and any provided documents. Parameters: model_id (str): Identifier of the model to use for the query. @@ -365,7 +378,12 @@ async def retrieve_response( tuple: A tuple containing the streaming response object and the conversation ID. """ - logger.info("Shields are not yet supported in Responses API.") + # List available shields for Responses API + available_shields = [shield.identifier for shield in await client.shields.list()] + if not available_shields: + logger.info("No available shields. Disabling safety") + else: + logger.info("Available shields: %s", available_shields) # use system prompt from request or default one system_prompt = get_system_prompt(query_request, configuration) @@ -402,6 +420,10 @@ async def retrieve_response( if query_request.conversation_id: create_params["previous_response_id"] = query_request.conversation_id + # Add shields to extra_body if available + if available_shields: + create_params["extra_body"] = {"guardrails": available_shields} + response = await client.responses.create(**create_params) response_stream = cast(AsyncIterator[OpenAIResponseObjectStream], response) diff --git a/tests/unit/app/endpoints/test_query_v2.py b/tests/unit/app/endpoints/test_query_v2.py index 6fefd4f13..10d46a684 100644 --- a/tests/unit/app/endpoints/test_query_v2.py +++ b/tests/unit/app/endpoints/test_query_v2.py @@ -120,6 +120,8 @@ async def test_retrieve_response_no_tools_bypasses_tools(mocker: MockerFixture) mock_vector_stores = mocker.Mock() mock_vector_stores.data = [] mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) + # Mock shields.list + mock_client.shields.list = mocker.AsyncMock(return_value=[]) # Ensure system prompt resolution does not require real config mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") @@ -156,6 +158,8 @@ async def test_retrieve_response_builds_rag_and_mcp_tools( mock_vector_stores = mocker.Mock() mock_vector_stores.data = [mocker.Mock(id="dbA")] mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) + # Mock shields.list + mock_client.shields.list = mocker.AsyncMock(return_value=[]) mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") mock_cfg = mocker.Mock() @@ -222,6 +226,8 @@ async def test_retrieve_response_parses_output_and_tool_calls( mock_vector_stores = mocker.Mock() mock_vector_stores.data = [] mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) + # Mock shields.list + mock_client.shields.list = mocker.AsyncMock(return_value=[]) mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) @@ -267,6 +273,8 @@ async def test_retrieve_response_with_usage_info(mocker: MockerFixture) -> None: mock_vector_stores = mocker.Mock() mock_vector_stores.data = [] mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) + # Mock shields.list + mock_client.shields.list = mocker.AsyncMock(return_value=[]) mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) @@ -304,6 +312,8 @@ async def test_retrieve_response_with_usage_dict(mocker: MockerFixture) -> None: mock_vector_stores = mocker.Mock() mock_vector_stores.data = [] mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) + # Mock shields.list + mock_client.shields.list = mocker.AsyncMock(return_value=[]) mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) @@ -341,6 +351,8 @@ async def test_retrieve_response_with_empty_usage_dict(mocker: MockerFixture) -> mock_vector_stores = mocker.Mock() mock_vector_stores.data = [] mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) + # Mock shields.list + mock_client.shields.list = mocker.AsyncMock(return_value=[]) mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) @@ -369,6 +381,8 @@ async def test_retrieve_response_validates_attachments(mocker: MockerFixture) -> mock_vector_stores = mocker.Mock() mock_vector_stores.data = [] mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) + # Mock shields.list + mock_client.shields.list = mocker.AsyncMock(return_value=[]) mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) @@ -515,3 +529,183 @@ async def test_query_endpoint_quota_exceeded( assert isinstance(detail, dict) assert detail["response"] == "Model quota exceeded" # type: ignore assert "gpt-4-turbo" in detail["cause"] # type: ignore + + +@pytest.mark.asyncio +async def test_retrieve_response_with_shields_available(mocker: MockerFixture) -> None: + """Test that shields are listed and passed to responses API when available.""" + mock_client = mocker.Mock() + + # Mock shields.list to return available shields + shield1 = mocker.Mock() + shield1.identifier = "shield-1" + shield2 = mocker.Mock() + shield2.identifier = "shield-2" + mock_client.shields.list = mocker.AsyncMock(return_value=[shield1, shield2]) + + output_item = mocker.Mock() + output_item.type = "message" + output_item.role = "assistant" + output_item.content = "Safe response" + + response_obj = mocker.Mock() + response_obj.id = "resp-shields" + response_obj.output = [output_item] + response_obj.usage = None + + mock_client.responses.create = mocker.AsyncMock(return_value=response_obj) + mock_vector_stores = mocker.Mock() + mock_vector_stores.data = [] + mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) + + mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") + mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) + + qr = QueryRequest(query="hello") + summary, conv_id, _referenced_docs, _token_usage = await retrieve_response( + mock_client, "model-shields", qr, token="tkn", provider_id="test-provider" + ) + + assert conv_id == "resp-shields" + assert summary.llm_response == "Safe response" + + # Verify that shields were passed in extra_body + kwargs = mock_client.responses.create.call_args.kwargs + assert "extra_body" in kwargs + assert "guardrails" in kwargs["extra_body"] + assert kwargs["extra_body"]["guardrails"] == ["shield-1", "shield-2"] + + +@pytest.mark.asyncio +async def test_retrieve_response_with_no_shields_available( + mocker: MockerFixture, +) -> None: + """Test that no extra_body is added when no shields are available.""" + mock_client = mocker.Mock() + + # Mock shields.list to return no shields + mock_client.shields.list = mocker.AsyncMock(return_value=[]) + + output_item = mocker.Mock() + output_item.type = "message" + output_item.role = "assistant" + output_item.content = "Response without shields" + + response_obj = mocker.Mock() + response_obj.id = "resp-no-shields" + response_obj.output = [output_item] + response_obj.usage = None + + mock_client.responses.create = mocker.AsyncMock(return_value=response_obj) + mock_vector_stores = mocker.Mock() + mock_vector_stores.data = [] + mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) + + mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") + mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) + + qr = QueryRequest(query="hello") + summary, conv_id, _referenced_docs, _token_usage = await retrieve_response( + mock_client, "model-no-shields", qr, token="tkn", provider_id="test-provider" + ) + + assert conv_id == "resp-no-shields" + assert summary.llm_response == "Response without shields" + + # Verify that no extra_body was added + kwargs = mock_client.responses.create.call_args.kwargs + assert "extra_body" not in kwargs + + +@pytest.mark.asyncio +async def test_retrieve_response_detects_shield_violation( + mocker: MockerFixture, +) -> None: + """Test that shield violations are detected and metrics are incremented.""" + mock_client = mocker.Mock() + + # Mock shields.list to return available shields + shield1 = mocker.Mock() + shield1.identifier = "safety-shield" + mock_client.shields.list = mocker.AsyncMock(return_value=[shield1]) + + # Create output with shield violation (refusal) + output_item = mocker.Mock() + output_item.type = "message" + output_item.role = "assistant" + output_item.content = "I cannot help with that request" + output_item.refusal = "Content violates safety policy" + + response_obj = mocker.Mock() + response_obj.id = "resp-violation" + response_obj.output = [output_item] + response_obj.usage = None + + mock_client.responses.create = mocker.AsyncMock(return_value=response_obj) + mock_vector_stores = mocker.Mock() + mock_vector_stores.data = [] + mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) + + mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") + mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) + + # Mock the validation error metric + validation_metric = mocker.patch("metrics.llm_calls_validation_errors_total") + + qr = QueryRequest(query="dangerous query") + summary, conv_id, _referenced_docs, _token_usage = await retrieve_response( + mock_client, "model-violation", qr, token="tkn", provider_id="test-provider" + ) + + assert conv_id == "resp-violation" + assert summary.llm_response == "I cannot help with that request" + + # Verify that the validation error metric was incremented + validation_metric.inc.assert_called_once() + + +@pytest.mark.asyncio +async def test_retrieve_response_no_violation_with_shields( + mocker: MockerFixture, +) -> None: + """Test that no metric is incremented when there's no shield violation.""" + mock_client = mocker.Mock() + + # Mock shields.list to return available shields + shield1 = mocker.Mock() + shield1.identifier = "safety-shield" + mock_client.shields.list = mocker.AsyncMock(return_value=[shield1]) + + # Create output without shield violation + output_item = mocker.Mock() + output_item.type = "message" + output_item.role = "assistant" + output_item.content = "Safe response" + output_item.refusal = None # No violation + + response_obj = mocker.Mock() + response_obj.id = "resp-safe" + response_obj.output = [output_item] + response_obj.usage = None + + mock_client.responses.create = mocker.AsyncMock(return_value=response_obj) + mock_vector_stores = mocker.Mock() + mock_vector_stores.data = [] + mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) + + mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") + mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) + + # Mock the validation error metric + validation_metric = mocker.patch("metrics.llm_calls_validation_errors_total") + + qr = QueryRequest(query="safe query") + summary, conv_id, _referenced_docs, _token_usage = await retrieve_response( + mock_client, "model-safe", qr, token="tkn", provider_id="test-provider" + ) + + assert conv_id == "resp-safe" + assert summary.llm_response == "Safe response" + + # Verify that the validation error metric was NOT incremented + validation_metric.inc.assert_not_called() diff --git a/tests/unit/app/endpoints/test_streaming_query_v2.py b/tests/unit/app/endpoints/test_streaming_query_v2.py index 8d81c9e48..450b4cece 100644 --- a/tests/unit/app/endpoints/test_streaming_query_v2.py +++ b/tests/unit/app/endpoints/test_streaming_query_v2.py @@ -38,6 +38,8 @@ async def test_retrieve_response_builds_rag_and_mcp_tools( mock_vector_stores.data = [mocker.Mock(id="db1")] mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) mock_client.responses.create = mocker.AsyncMock(return_value=mocker.Mock()) + # Mock shields.list + mock_client.shields.list = mocker.AsyncMock(return_value=[]) mocker.patch( "app.endpoints.streaming_query_v2.get_system_prompt", return_value="PROMPT" @@ -68,6 +70,8 @@ async def test_retrieve_response_no_tools_passes_none(mocker: MockerFixture) -> mock_vector_stores.data = [] mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) mock_client.responses.create = mocker.AsyncMock(return_value=mocker.Mock()) + # Mock shields.list + mock_client.shields.list = mocker.AsyncMock(return_value=[]) mocker.patch( "app.endpoints.streaming_query_v2.get_system_prompt", return_value="PROMPT" @@ -222,3 +226,247 @@ def _raise(*_a: Any, **_k: Any) -> None: assert exc.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert "Unable to connect to Llama Stack" in str(exc.value.detail) fail_metric.inc.assert_called_once() + + +@pytest.mark.asyncio +async def test_retrieve_response_with_shields_available(mocker: MockerFixture) -> None: + """Test that shields are listed and passed to streaming responses API.""" + mock_client = mocker.Mock() + + # Mock shields.list to return available shields + shield1 = mocker.Mock() + shield1.identifier = "shield-1" + shield2 = mocker.Mock() + shield2.identifier = "shield-2" + mock_client.shields.list = mocker.AsyncMock(return_value=[shield1, shield2]) + + mock_vector_stores = mocker.Mock() + mock_vector_stores.data = [] + mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) + mock_client.responses.create = mocker.AsyncMock(return_value=mocker.Mock()) + + mocker.patch( + "app.endpoints.streaming_query_v2.get_system_prompt", return_value="PROMPT" + ) + mocker.patch( + "app.endpoints.streaming_query_v2.configuration", mocker.Mock(mcp_servers=[]) + ) + + qr = QueryRequest(query="hello") + await retrieve_response(mock_client, "model-shields", qr, token="tok") + + # Verify that shields were passed in extra_body + kwargs = mock_client.responses.create.call_args.kwargs + assert "extra_body" in kwargs + assert "guardrails" in kwargs["extra_body"] + assert kwargs["extra_body"]["guardrails"] == ["shield-1", "shield-2"] + + +@pytest.mark.asyncio +async def test_retrieve_response_with_no_shields_available( + mocker: MockerFixture, +) -> None: + """Test that no extra_body is added when no shields are available.""" + mock_client = mocker.Mock() + + # Mock shields.list to return no shields + mock_client.shields.list = mocker.AsyncMock(return_value=[]) + + mock_vector_stores = mocker.Mock() + mock_vector_stores.data = [] + mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) + mock_client.responses.create = mocker.AsyncMock(return_value=mocker.Mock()) + + mocker.patch( + "app.endpoints.streaming_query_v2.get_system_prompt", return_value="PROMPT" + ) + mocker.patch( + "app.endpoints.streaming_query_v2.configuration", mocker.Mock(mcp_servers=[]) + ) + + qr = QueryRequest(query="hello") + await retrieve_response(mock_client, "model-no-shields", qr, token="tok") + + # Verify that no extra_body was added + kwargs = mock_client.responses.create.call_args.kwargs + assert "extra_body" not in kwargs + + +@pytest.mark.asyncio +async def test_streaming_response_detects_shield_violation( + mocker: MockerFixture, dummy_request: Request +) -> None: + """Test that shield violations in streaming responses are detected and metrics incremented.""" + # Skip real config checks + mocker.patch("app.endpoints.streaming_query.check_configuration_loaded") + + # Model selection plumbing + mock_client = mocker.Mock() + mock_client.models.list = mocker.AsyncMock(return_value=[mocker.Mock()]) + mocker.patch( + "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client + ) + mocker.patch( + "app.endpoints.streaming_query.evaluate_model_hints", + return_value=(None, None), + ) + mocker.patch( + "app.endpoints.streaming_query.select_model_and_provider_id", + return_value=("llama/m", "m", "p"), + ) + + # SSE helpers + mocker.patch( + "app.endpoints.streaming_query_v2.stream_start_event", + lambda conv_id: f"START:{conv_id}\n", + ) + mocker.patch( + "app.endpoints.streaming_query_v2.format_stream_data", + lambda obj: f"EV:{obj['event']}:{obj['data'].get('token','')}\n", + ) + mocker.patch( + "app.endpoints.streaming_query_v2.stream_end_event", + lambda _m, _s, _t, _media: "END\n", + ) + + # Mock the cleanup function that handles all post-streaming database/cache work + mocker.patch( + "app.endpoints.streaming_query_v2.cleanup_after_streaming", + mocker.AsyncMock(return_value=None), + ) + + # Mock the validation error metric + validation_metric = mocker.patch("metrics.llm_calls_validation_errors_total") + + # Build a fake async stream with shield violation + async def fake_stream_with_violation() -> AsyncIterator[SimpleNamespace]: + yield SimpleNamespace( + type="response.created", response=SimpleNamespace(id="conv-violation") + ) + yield SimpleNamespace(type="response.output_text.delta", delta="I cannot ") + yield SimpleNamespace(type="response.output_text.done", text="I cannot help") + # Response completed with refusal in output + violation_item = SimpleNamespace( + type="message", + role="assistant", + refusal="Content violates safety policy", + ) + response_with_violation = SimpleNamespace( + id="conv-violation", output=[violation_item] + ) + yield SimpleNamespace( + type="response.completed", response=response_with_violation + ) + + mocker.patch( + "app.endpoints.streaming_query_v2.retrieve_response", + return_value=(fake_stream_with_violation(), ""), + ) + + mocker.patch("metrics.llm_calls_total") + + resp = await streaming_query_endpoint_handler_v2( + request=dummy_request, + query_request=QueryRequest(query="dangerous query"), + auth=("user123", "", True, "token-abc"), + mcp_headers={}, + ) + + assert isinstance(resp, StreamingResponse) + + # Collect emitted events to trigger the generator + events: list[str] = [] + async for chunk in resp.body_iterator: + s = chunk.decode() if isinstance(chunk, (bytes, bytearray)) else str(chunk) + events.append(s) + + # Verify that the validation error metric was incremented + validation_metric.inc.assert_called_once() + + +@pytest.mark.asyncio +async def test_streaming_response_no_shield_violation( + mocker: MockerFixture, dummy_request: Request +) -> None: + """Test that no metric is incremented when there's no shield violation in streaming.""" + # Skip real config checks + mocker.patch("app.endpoints.streaming_query.check_configuration_loaded") + + # Model selection plumbing + mock_client = mocker.Mock() + mock_client.models.list = mocker.AsyncMock(return_value=[mocker.Mock()]) + mocker.patch( + "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client + ) + mocker.patch( + "app.endpoints.streaming_query.evaluate_model_hints", + return_value=(None, None), + ) + mocker.patch( + "app.endpoints.streaming_query.select_model_and_provider_id", + return_value=("llama/m", "m", "p"), + ) + + # SSE helpers + mocker.patch( + "app.endpoints.streaming_query_v2.stream_start_event", + lambda conv_id: f"START:{conv_id}\n", + ) + mocker.patch( + "app.endpoints.streaming_query_v2.format_stream_data", + lambda obj: f"EV:{obj['event']}:{obj['data'].get('token','')}\n", + ) + mocker.patch( + "app.endpoints.streaming_query_v2.stream_end_event", + lambda _m, _s, _t, _media: "END\n", + ) + + # Mock the cleanup function that handles all post-streaming database/cache work + mocker.patch( + "app.endpoints.streaming_query_v2.cleanup_after_streaming", + mocker.AsyncMock(return_value=None), + ) + + # Mock the validation error metric + validation_metric = mocker.patch("metrics.llm_calls_validation_errors_total") + + # Build a fake async stream without violation + async def fake_stream_without_violation() -> AsyncIterator[SimpleNamespace]: + yield SimpleNamespace( + type="response.created", response=SimpleNamespace(id="conv-safe") + ) + yield SimpleNamespace(type="response.output_text.delta", delta="Safe ") + yield SimpleNamespace(type="response.output_text.done", text="Safe response") + # Response completed without refusal + safe_item = SimpleNamespace( + type="message", + role="assistant", + refusal=None, # No violation + ) + response_safe = SimpleNamespace(id="conv-safe", output=[safe_item]) + yield SimpleNamespace(type="response.completed", response=response_safe) + + mocker.patch( + "app.endpoints.streaming_query_v2.retrieve_response", + return_value=(fake_stream_without_violation(), ""), + ) + + mocker.patch("metrics.llm_calls_total") + + resp = await streaming_query_endpoint_handler_v2( + request=dummy_request, + query_request=QueryRequest(query="safe query"), + auth=("user123", "", True, "token-abc"), + mcp_headers={}, + ) + + assert isinstance(resp, StreamingResponse) + + # Collect emitted events to trigger the generator + events: list[str] = [] + async for chunk in resp.body_iterator: + s = chunk.decode() if isinstance(chunk, (bytes, bytearray)) else str(chunk) + events.append(s) + + # Verify that the validation error metric was NOT incremented + validation_metric.inc.assert_not_called() From dccb61b528c6e70e90cc8f3db407391aa27a7f4a Mon Sep 17 00:00:00 2001 From: Luis Tomas Bolivar Date: Fri, 14 Nov 2025 11:13:00 +0100 Subject: [PATCH 2/2] Responses API: avoid shields code duplication --- src/app/endpoints/query_v2.py | 17 ++------ src/app/endpoints/streaming_query_v2.py | 19 +++------ src/utils/shields.py | 54 +++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 27 deletions(-) create mode 100644 src/utils/shields.py diff --git a/src/app/endpoints/query_v2.py b/src/app/endpoints/query_v2.py index a9c475ef6..3f0a01d1f 100644 --- a/src/app/endpoints/query_v2.py +++ b/src/app/endpoints/query_v2.py @@ -34,6 +34,7 @@ get_topic_summary_system_prompt, ) from utils.mcp_headers import mcp_headers_dependency +from utils.shields import detect_shield_violations, get_available_shields from utils.token_counter import TokenCounter from utils.types import TurnSummary, ToolCallSummary @@ -343,11 +344,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche and token usage information. """ # List available shields for Responses API - available_shields = [shield.identifier for shield in await client.shields.list()] - if not available_shields: - logger.info("No available shields. Disabling safety") - else: - logger.info("Available shields: %s", available_shields) + available_shields = await get_available_shields(client) # use system prompt from request or default one system_prompt = get_system_prompt(query_request, configuration) @@ -414,14 +411,8 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche if tool_summary: tool_calls.append(tool_summary) - # Check for shield violations - item_type = getattr(output_item, "type", None) - if item_type == "message": - refusal = getattr(output_item, "refusal", None) - if refusal: - # Metric for LLM validation errors (shield violations) - metrics.llm_calls_validation_errors_total.inc() - logger.warning("Shield violation detected: %s", refusal) + # Check for shield violations across all output items + detect_shield_violations(response.output) logger.info( "Response processing complete - Tool calls: %d, Response length: %d chars", diff --git a/src/app/endpoints/streaming_query_v2.py b/src/app/endpoints/streaming_query_v2.py index 6a77e4eaf..bf4080c4d 100644 --- a/src/app/endpoints/streaming_query_v2.py +++ b/src/app/endpoints/streaming_query_v2.py @@ -32,7 +32,6 @@ from authorization.middleware import authorize from configuration import configuration from constants import MEDIA_TYPE_JSON -import metrics from models.config import Action from models.context import ResponseGeneratorContext from models.requests import QueryRequest @@ -42,6 +41,7 @@ get_system_prompt, ) from utils.mcp_headers import mcp_headers_dependency +from utils.shields import detect_shield_violations, get_available_shields from utils.token_counter import TokenCounter from utils.transcripts import store_transcript from utils.types import TurnSummary, ToolCallSummary @@ -247,14 +247,9 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat # Check for shield violations in the completed response if latest_response_object: - for output_item in getattr(latest_response_object, "output", []): - item_type = getattr(output_item, "type", None) - if item_type == "message": - refusal = getattr(output_item, "refusal", None) - if refusal: - # Metric for LLM validation errors (shield violations) - metrics.llm_calls_validation_errors_total.inc() - logger.warning("Shield violation detected: %s", refusal) + detect_shield_violations( + getattr(latest_response_object, "output", []) + ) if not emitted_turn_complete: final_message = summary.llm_response or "".join(text_parts) @@ -379,11 +374,7 @@ async def retrieve_response( and the conversation ID. """ # List available shields for Responses API - available_shields = [shield.identifier for shield in await client.shields.list()] - if not available_shields: - logger.info("No available shields. Disabling safety") - else: - logger.info("Available shields: %s", available_shields) + available_shields = await get_available_shields(client) # use system prompt from request or default one system_prompt = get_system_prompt(query_request, configuration) diff --git a/src/utils/shields.py b/src/utils/shields.py new file mode 100644 index 000000000..f9c96831e --- /dev/null +++ b/src/utils/shields.py @@ -0,0 +1,54 @@ +"""Utility functions for working with Llama Stack shields.""" + +import logging +from typing import Any + +from llama_stack_client import AsyncLlamaStackClient + +import metrics + +logger = logging.getLogger(__name__) + + +async def get_available_shields(client: AsyncLlamaStackClient) -> list[str]: + """ + Discover and return available shield identifiers. + + Args: + client: The Llama Stack client to query for available shields. + + Returns: + List of shield identifiers that are available. + """ + available_shields = [shield.identifier for shield in await client.shields.list()] + if not available_shields: + logger.info("No available shields. Disabling safety") + else: + logger.info("Available shields: %s", available_shields) + return available_shields + + +def detect_shield_violations(output_items: list[Any]) -> bool: + """ + Check output items for shield violations and update metrics. + + Iterates through output items looking for message items with refusal + attributes. If a refusal is found, increments the validation error + metric and logs a warning. + + Args: + output_items: List of output items from the LLM response to check. + + Returns: + True if a shield violation was detected, False otherwise. + """ + for output_item in output_items: + item_type = getattr(output_item, "type", None) + if item_type == "message": + refusal = getattr(output_item, "refusal", None) + if refusal: + # Metric for LLM validation errors (shield violations) + metrics.llm_calls_validation_errors_total.inc() + logger.warning("Shield violation detected: %s", refusal) + return True + return False