Skip to content

Commit fe1d7b6

Browse files
committed
Add shields support to the responses API implementation
It includes both streaming and non-streaming support by leveraging the refusal field on the response
1 parent 41e89f6 commit fe1d7b6

File tree

4 files changed

+490
-9
lines changed

4 files changed

+490
-9
lines changed

src/app/endpoints/query_v2.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
322322
given query, handling shield configuration, tool usage, and
323323
attachment validation.
324324
325-
This function configures system prompts and toolgroups
325+
This function configures system prompts, shields, and toolgroups
326326
(including RAG and MCP integration) as needed based on
327327
the query request and system configuration. It
328328
validates attachments, manages conversation and session
@@ -342,8 +342,12 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
342342
and the conversation ID, the list of parsed referenced documents,
343343
and token usage information.
344344
"""
345-
# TODO(ltomasbo): implement shields support once available in Responses API
346-
logger.info("Shields are not yet supported in Responses API. Disabling safety")
345+
# List available shields for Responses API
346+
available_shields = [shield.identifier for shield in await client.shields.list()]
347+
if not available_shields:
348+
logger.info("No available shields. Disabling safety")
349+
else:
350+
logger.info("Available shields: %s", available_shields)
347351

348352
# use system prompt from request or default one
349353
system_prompt = get_system_prompt(query_request, configuration)
@@ -381,6 +385,10 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
381385
if query_request.conversation_id:
382386
create_kwargs["previous_response_id"] = query_request.conversation_id
383387

388+
# Add shields to extra_body if available
389+
if available_shields:
390+
create_kwargs["extra_body"] = {"guardrails": available_shields}
391+
384392
response = await client.responses.create(**create_kwargs)
385393
response = cast(OpenAIResponseObject, response)
386394

@@ -406,6 +414,15 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
406414
if tool_summary:
407415
tool_calls.append(tool_summary)
408416

417+
# Check for shield violations
418+
item_type = getattr(output_item, "type", None)
419+
if item_type == "message":
420+
refusal = getattr(output_item, "refusal", None)
421+
if refusal:
422+
# Metric for LLM validation errors (shield violations)
423+
metrics.llm_calls_validation_errors_total.inc()
424+
logger.warning("Shield violation detected: %s", refusal)
425+
409426
logger.info(
410427
"Response processing complete - Tool calls: %d, Response length: %d chars",
411428
len(tool_calls),

src/app/endpoints/streaming_query_v2.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from authorization.middleware import authorize
3333
from configuration import configuration
3434
from constants import MEDIA_TYPE_JSON
35+
import metrics
3536
from models.config import Action
3637
from models.context import ResponseGeneratorContext
3738
from models.requests import QueryRequest
@@ -243,6 +244,18 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
243244
elif event_type == "response.completed":
244245
# Capture the response object for token usage extraction
245246
latest_response_object = getattr(chunk, "response", None)
247+
248+
# Check for shield violations in the completed response
249+
if latest_response_object:
250+
for output_item in getattr(latest_response_object, "output", []):
251+
item_type = getattr(output_item, "type", None)
252+
if item_type == "message":
253+
refusal = getattr(output_item, "refusal", None)
254+
if refusal:
255+
# Metric for LLM validation errors (shield violations)
256+
metrics.llm_calls_validation_errors_total.inc()
257+
logger.warning("Shield violation detected: %s", refusal)
258+
246259
if not emitted_turn_complete:
247260
final_message = summary.llm_response or "".join(text_parts)
248261
if not final_message:
@@ -348,11 +361,11 @@ async def retrieve_response(
348361
Asynchronously retrieves a streaming response and conversation
349362
ID from the Llama Stack agent for a given user query.
350363
351-
This function configures input/output shields, system prompt,
352-
and tool usage based on the request and environment. It
353-
prepares the agent with appropriate headers and toolgroups,
354-
validates attachments if present, and initiates a streaming
355-
turn with the user's query and any provided documents.
364+
This function configures shields, system prompt, and tool usage
365+
based on the request and environment. It prepares the agent with
366+
appropriate headers and toolgroups, validates attachments if
367+
present, and initiates a streaming turn with the user's query
368+
and any provided documents.
356369
357370
Parameters:
358371
model_id (str): Identifier of the model to use for the query.
@@ -365,7 +378,12 @@ async def retrieve_response(
365378
tuple: A tuple containing the streaming response object
366379
and the conversation ID.
367380
"""
368-
logger.info("Shields are not yet supported in Responses API.")
381+
# List available shields for Responses API
382+
available_shields = [shield.identifier for shield in await client.shields.list()]
383+
if not available_shields:
384+
logger.info("No available shields. Disabling safety")
385+
else:
386+
logger.info("Available shields: %s", available_shields)
369387

370388
# use system prompt from request or default one
371389
system_prompt = get_system_prompt(query_request, configuration)
@@ -402,6 +420,10 @@ async def retrieve_response(
402420
if query_request.conversation_id:
403421
create_params["previous_response_id"] = query_request.conversation_id
404422

423+
# Add shields to extra_body if available
424+
if available_shields:
425+
create_params["extra_body"] = {"guardrails": available_shields}
426+
405427
response = await client.responses.create(**create_params)
406428
response_stream = cast(AsyncIterator[OpenAIResponseObjectStream], response)
407429

tests/unit/app/endpoints/test_query_v2.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ async def test_retrieve_response_no_tools_bypasses_tools(mocker: MockerFixture)
120120
mock_vector_stores = mocker.Mock()
121121
mock_vector_stores.data = []
122122
mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
123+
# Mock shields.list
124+
mock_client.shields.list = mocker.AsyncMock(return_value=[])
123125

124126
# Ensure system prompt resolution does not require real config
125127
mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
@@ -156,6 +158,8 @@ async def test_retrieve_response_builds_rag_and_mcp_tools(
156158
mock_vector_stores = mocker.Mock()
157159
mock_vector_stores.data = [mocker.Mock(id="dbA")]
158160
mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
161+
# Mock shields.list
162+
mock_client.shields.list = mocker.AsyncMock(return_value=[])
159163

160164
mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
161165
mock_cfg = mocker.Mock()
@@ -222,6 +226,8 @@ async def test_retrieve_response_parses_output_and_tool_calls(
222226
mock_vector_stores = mocker.Mock()
223227
mock_vector_stores.data = []
224228
mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
229+
# Mock shields.list
230+
mock_client.shields.list = mocker.AsyncMock(return_value=[])
225231

226232
mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
227233
mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
@@ -267,6 +273,8 @@ async def test_retrieve_response_with_usage_info(mocker: MockerFixture) -> None:
267273
mock_vector_stores = mocker.Mock()
268274
mock_vector_stores.data = []
269275
mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
276+
# Mock shields.list
277+
mock_client.shields.list = mocker.AsyncMock(return_value=[])
270278

271279
mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
272280
mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
@@ -304,6 +312,8 @@ async def test_retrieve_response_with_usage_dict(mocker: MockerFixture) -> None:
304312
mock_vector_stores = mocker.Mock()
305313
mock_vector_stores.data = []
306314
mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
315+
# Mock shields.list
316+
mock_client.shields.list = mocker.AsyncMock(return_value=[])
307317

308318
mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
309319
mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
@@ -341,6 +351,8 @@ async def test_retrieve_response_with_empty_usage_dict(mocker: MockerFixture) ->
341351
mock_vector_stores = mocker.Mock()
342352
mock_vector_stores.data = []
343353
mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
354+
# Mock shields.list
355+
mock_client.shields.list = mocker.AsyncMock(return_value=[])
344356

345357
mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
346358
mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
@@ -369,6 +381,8 @@ async def test_retrieve_response_validates_attachments(mocker: MockerFixture) ->
369381
mock_vector_stores = mocker.Mock()
370382
mock_vector_stores.data = []
371383
mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
384+
# Mock shields.list
385+
mock_client.shields.list = mocker.AsyncMock(return_value=[])
372386

373387
mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
374388
mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
@@ -515,3 +529,183 @@ async def test_query_endpoint_quota_exceeded(
515529
assert isinstance(detail, dict)
516530
assert detail["response"] == "Model quota exceeded" # type: ignore
517531
assert "gpt-4-turbo" in detail["cause"] # type: ignore
532+
533+
534+
@pytest.mark.asyncio
535+
async def test_retrieve_response_with_shields_available(mocker: MockerFixture) -> None:
536+
"""Test that shields are listed and passed to responses API when available."""
537+
mock_client = mocker.Mock()
538+
539+
# Mock shields.list to return available shields
540+
shield1 = mocker.Mock()
541+
shield1.identifier = "shield-1"
542+
shield2 = mocker.Mock()
543+
shield2.identifier = "shield-2"
544+
mock_client.shields.list = mocker.AsyncMock(return_value=[shield1, shield2])
545+
546+
output_item = mocker.Mock()
547+
output_item.type = "message"
548+
output_item.role = "assistant"
549+
output_item.content = "Safe response"
550+
551+
response_obj = mocker.Mock()
552+
response_obj.id = "resp-shields"
553+
response_obj.output = [output_item]
554+
response_obj.usage = None
555+
556+
mock_client.responses.create = mocker.AsyncMock(return_value=response_obj)
557+
mock_vector_stores = mocker.Mock()
558+
mock_vector_stores.data = []
559+
mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
560+
561+
mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
562+
mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
563+
564+
qr = QueryRequest(query="hello")
565+
summary, conv_id, _referenced_docs, _token_usage = await retrieve_response(
566+
mock_client, "model-shields", qr, token="tkn", provider_id="test-provider"
567+
)
568+
569+
assert conv_id == "resp-shields"
570+
assert summary.llm_response == "Safe response"
571+
572+
# Verify that shields were passed in extra_body
573+
kwargs = mock_client.responses.create.call_args.kwargs
574+
assert "extra_body" in kwargs
575+
assert "guardrails" in kwargs["extra_body"]
576+
assert kwargs["extra_body"]["guardrails"] == ["shield-1", "shield-2"]
577+
578+
579+
@pytest.mark.asyncio
580+
async def test_retrieve_response_with_no_shields_available(
581+
mocker: MockerFixture,
582+
) -> None:
583+
"""Test that no extra_body is added when no shields are available."""
584+
mock_client = mocker.Mock()
585+
586+
# Mock shields.list to return no shields
587+
mock_client.shields.list = mocker.AsyncMock(return_value=[])
588+
589+
output_item = mocker.Mock()
590+
output_item.type = "message"
591+
output_item.role = "assistant"
592+
output_item.content = "Response without shields"
593+
594+
response_obj = mocker.Mock()
595+
response_obj.id = "resp-no-shields"
596+
response_obj.output = [output_item]
597+
response_obj.usage = None
598+
599+
mock_client.responses.create = mocker.AsyncMock(return_value=response_obj)
600+
mock_vector_stores = mocker.Mock()
601+
mock_vector_stores.data = []
602+
mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
603+
604+
mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
605+
mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
606+
607+
qr = QueryRequest(query="hello")
608+
summary, conv_id, _referenced_docs, _token_usage = await retrieve_response(
609+
mock_client, "model-no-shields", qr, token="tkn", provider_id="test-provider"
610+
)
611+
612+
assert conv_id == "resp-no-shields"
613+
assert summary.llm_response == "Response without shields"
614+
615+
# Verify that no extra_body was added
616+
kwargs = mock_client.responses.create.call_args.kwargs
617+
assert "extra_body" not in kwargs
618+
619+
620+
@pytest.mark.asyncio
621+
async def test_retrieve_response_detects_shield_violation(
622+
mocker: MockerFixture,
623+
) -> None:
624+
"""Test that shield violations are detected and metrics are incremented."""
625+
mock_client = mocker.Mock()
626+
627+
# Mock shields.list to return available shields
628+
shield1 = mocker.Mock()
629+
shield1.identifier = "safety-shield"
630+
mock_client.shields.list = mocker.AsyncMock(return_value=[shield1])
631+
632+
# Create output with shield violation (refusal)
633+
output_item = mocker.Mock()
634+
output_item.type = "message"
635+
output_item.role = "assistant"
636+
output_item.content = "I cannot help with that request"
637+
output_item.refusal = "Content violates safety policy"
638+
639+
response_obj = mocker.Mock()
640+
response_obj.id = "resp-violation"
641+
response_obj.output = [output_item]
642+
response_obj.usage = None
643+
644+
mock_client.responses.create = mocker.AsyncMock(return_value=response_obj)
645+
mock_vector_stores = mocker.Mock()
646+
mock_vector_stores.data = []
647+
mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
648+
649+
mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
650+
mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
651+
652+
# Mock the validation error metric
653+
validation_metric = mocker.patch("metrics.llm_calls_validation_errors_total")
654+
655+
qr = QueryRequest(query="dangerous query")
656+
summary, conv_id, _referenced_docs, _token_usage = await retrieve_response(
657+
mock_client, "model-violation", qr, token="tkn", provider_id="test-provider"
658+
)
659+
660+
assert conv_id == "resp-violation"
661+
assert summary.llm_response == "I cannot help with that request"
662+
663+
# Verify that the validation error metric was incremented
664+
validation_metric.inc.assert_called_once()
665+
666+
667+
@pytest.mark.asyncio
668+
async def test_retrieve_response_no_violation_with_shields(
669+
mocker: MockerFixture,
670+
) -> None:
671+
"""Test that no metric is incremented when there's no shield violation."""
672+
mock_client = mocker.Mock()
673+
674+
# Mock shields.list to return available shields
675+
shield1 = mocker.Mock()
676+
shield1.identifier = "safety-shield"
677+
mock_client.shields.list = mocker.AsyncMock(return_value=[shield1])
678+
679+
# Create output without shield violation
680+
output_item = mocker.Mock()
681+
output_item.type = "message"
682+
output_item.role = "assistant"
683+
output_item.content = "Safe response"
684+
output_item.refusal = None # No violation
685+
686+
response_obj = mocker.Mock()
687+
response_obj.id = "resp-safe"
688+
response_obj.output = [output_item]
689+
response_obj.usage = None
690+
691+
mock_client.responses.create = mocker.AsyncMock(return_value=response_obj)
692+
mock_vector_stores = mocker.Mock()
693+
mock_vector_stores.data = []
694+
mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
695+
696+
mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
697+
mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
698+
699+
# Mock the validation error metric
700+
validation_metric = mocker.patch("metrics.llm_calls_validation_errors_total")
701+
702+
qr = QueryRequest(query="safe query")
703+
summary, conv_id, _referenced_docs, _token_usage = await retrieve_response(
704+
mock_client, "model-safe", qr, token="tkn", provider_id="test-provider"
705+
)
706+
707+
assert conv_id == "resp-safe"
708+
assert summary.llm_response == "Safe response"
709+
710+
# Verify that the validation error metric was NOT incremented
711+
validation_metric.inc.assert_not_called()

0 commit comments

Comments
 (0)