Commit c1e4aee

Merge pull request #760 from luis5tb/responses_v2_shields
Responses v2 shields
2 parents 6bc8850 + dccb61b commit c1e4aee

File tree

5 files changed: +526 -9 lines changed

src/app/endpoints/query_v2.py

Lines changed: 11 additions & 3 deletions
@@ -34,6 +34,7 @@
     get_topic_summary_system_prompt,
 )
 from utils.mcp_headers import mcp_headers_dependency
+from utils.shields import detect_shield_violations, get_available_shields
 from utils.token_counter import TokenCounter
 from utils.types import TurnSummary, ToolCallSummary
 
@@ -322,7 +323,7 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branches
     given query, handling shield configuration, tool usage, and
     attachment validation.
 
-    This function configures system prompts and toolgroups
+    This function configures system prompts, shields, and toolgroups
     (including RAG and MCP integration) as needed based on
     the query request and system configuration. It
     validates attachments, manages conversation and session
@@ -342,8 +343,8 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branches
     and the conversation ID, the list of parsed referenced documents,
     and token usage information.
     """
-    # TODO(ltomasbo): implement shields support once available in Responses API
-    logger.info("Shields are not yet supported in Responses API. Disabling safety")
+    # List available shields for Responses API
+    available_shields = await get_available_shields(client)
 
     # use system prompt from request or default one
     system_prompt = get_system_prompt(query_request, configuration)
@@ -381,6 +382,10 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branches
     if query_request.conversation_id:
         create_kwargs["previous_response_id"] = query_request.conversation_id
 
+    # Add shields to extra_body if available
+    if available_shields:
+        create_kwargs["extra_body"] = {"guardrails": available_shields}
+
     response = await client.responses.create(**create_kwargs)
     response = cast(OpenAIResponseObject, response)
 
@@ -406,6 +411,9 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branches
         if tool_summary:
             tool_calls.append(tool_summary)
 
+    # Check for shield violations across all output items
+    detect_shield_violations(response.output)
+
     logger.info(
         "Response processing complete - Tool calls: %d, Response length: %d chars",
         len(tool_calls),
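
Taken together, these hunks replace the old "shields not yet supported" stub with a discover-then-forward flow: list the registered shields once, pass them to the Responses call as guardrails, and scan the output for refusals afterwards. A minimal sketch of that flow, assuming the helpers from src/utils/shields.py below; `answer_query`, `model`, and `prompt` are illustrative names, not part of the actual endpoint:

```python
# Sketch of the non-streaming flow, mirroring the diff above.
from llama_stack_client import AsyncLlamaStackClient

from utils.shields import detect_shield_violations, get_available_shields


async def answer_query(client: AsyncLlamaStackClient, model: str, prompt: str):
    # Discover registered shields; an empty list effectively disables safety
    available_shields = await get_available_shields(client)

    create_kwargs: dict = {"model": model, "input": prompt}
    if available_shields:
        # extra_body carries fields the generated client does not model
        # natively, here the guardrails list consumed by the backend
        create_kwargs["extra_body"] = {"guardrails": available_shields}

    response = await client.responses.create(**create_kwargs)

    # A refusal on any output message bumps the validation-error metric
    detect_shield_violations(response.output)
    return response
```

Note that a violation is observed rather than raised: the endpoint still returns the response, and the metric increment plus warning log are the visible side effects.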

src/app/endpoints/streaming_query_v2.py

Lines changed: 19 additions & 6 deletions
@@ -41,6 +41,7 @@
     get_system_prompt,
 )
 from utils.mcp_headers import mcp_headers_dependency
+from utils.shields import detect_shield_violations, get_available_shields
 from utils.token_counter import TokenCounter
 from utils.transcripts import store_transcript
 from utils.types import TurnSummary, ToolCallSummary
@@ -243,6 +244,13 @@ async def response_generator(  # pylint: disable=too-many-branches,too-many-statements
         elif event_type == "response.completed":
             # Capture the response object for token usage extraction
             latest_response_object = getattr(chunk, "response", None)
+
+            # Check for shield violations in the completed response
+            if latest_response_object:
+                detect_shield_violations(
+                    getattr(latest_response_object, "output", [])
+                )
+
             if not emitted_turn_complete:
                 final_message = summary.llm_response or "".join(text_parts)
                 if not final_message:
@@ -348,11 +356,11 @@ async def retrieve_response(
     Asynchronously retrieves a streaming response and conversation
     ID from the Llama Stack agent for a given user query.
 
-    This function configures input/output shields, system prompt,
-    and tool usage based on the request and environment. It
-    prepares the agent with appropriate headers and toolgroups,
-    validates attachments if present, and initiates a streaming
-    turn with the user's query and any provided documents.
+    This function configures shields, system prompt, and tool usage
+    based on the request and environment. It prepares the agent with
+    appropriate headers and toolgroups, validates attachments if
+    present, and initiates a streaming turn with the user's query
+    and any provided documents.
 
     Parameters:
         model_id (str): Identifier of the model to use for the query.
@@ -365,7 +373,8 @@ async def retrieve_response(
     tuple: A tuple containing the streaming response object
         and the conversation ID.
     """
-    logger.info("Shields are not yet supported in Responses API.")
+    # List available shields for Responses API
+    available_shields = await get_available_shields(client)
 
     # use system prompt from request or default one
     system_prompt = get_system_prompt(query_request, configuration)
@@ -402,6 +411,10 @@ async def retrieve_response(
     if query_request.conversation_id:
         create_params["previous_response_id"] = query_request.conversation_id
 
+    # Add shields to extra_body if available
+    if available_shields:
+        create_params["extra_body"] = {"guardrails": available_shields}
+
     response = await client.responses.create(**create_params)
     response_stream = cast(AsyncIterator[OpenAIResponseObjectStream], response)
 
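
The streaming path defers the same check to the terminal event, since output items only exist on the completed response object. A trimmed, hypothetical consumer loop showing where the hook fires; the event name and chunk attributes match the ones handled in the diff above:

```python
# Hypothetical stand-alone consumer mirroring response_generator's
# handling of the "response.completed" event.
from utils.shields import detect_shield_violations


async def consume_stream(response_stream) -> None:
    async for chunk in response_stream:
        if getattr(chunk, "type", None) == "response.completed":
            # The completed event carries the full response object,
            # including the output items the violation check inspects
            final = getattr(chunk, "response", None)
            if final and detect_shield_violations(getattr(final, "output", [])):
                print("[shield violation detected]")
```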

src/utils/shields.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+"""Utility functions for working with Llama Stack shields."""
+
+import logging
+from typing import Any
+
+from llama_stack_client import AsyncLlamaStackClient
+
+import metrics
+
+logger = logging.getLogger(__name__)
+
+
+async def get_available_shields(client: AsyncLlamaStackClient) -> list[str]:
+    """
+    Discover and return available shield identifiers.
+
+    Args:
+        client: The Llama Stack client to query for available shields.
+
+    Returns:
+        List of shield identifiers that are available.
+    """
+    available_shields = [shield.identifier for shield in await client.shields.list()]
+    if not available_shields:
+        logger.info("No available shields. Disabling safety")
+    else:
+        logger.info("Available shields: %s", available_shields)
+    return available_shields
+
+
+def detect_shield_violations(output_items: list[Any]) -> bool:
+    """
+    Check output items for shield violations and update metrics.
+
+    Iterates through output items looking for message items with refusal
+    attributes. If a refusal is found, increments the validation error
+    metric and logs a warning.
+
+    Args:
+        output_items: List of output items from the LLM response to check.
+
+    Returns:
+        True if a shield violation was detected, False otherwise.
+    """
+    for output_item in output_items:
+        item_type = getattr(output_item, "type", None)
+        if item_type == "message":
+            refusal = getattr(output_item, "refusal", None)
+            if refusal:
+                # Metric for LLM validation errors (shield violations)
+                metrics.llm_calls_validation_errors_total.inc()
+                logger.warning("Shield violation detected: %s", refusal)
+                return True
+    return False
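
Since detect_shield_violations relies only on getattr-style attribute access, it can be exercised with plain stand-in objects; a quick sketch using types.SimpleNamespace:

```python
# Stand-in output items: only "message" items with a truthy `refusal`
# attribute count as violations; other item types are skipped.
from types import SimpleNamespace

from utils.shields import detect_shield_violations

ok = SimpleNamespace(type="message", refusal=None)
tool_call = SimpleNamespace(type="tool_call")
blocked = SimpleNamespace(type="message", refusal="I can't help with that.")

assert detect_shield_violations([ok, tool_call]) is False
# Increments metrics.llm_calls_validation_errors_total and logs a warning
assert detect_shield_violations([ok, blocked]) is True
```

The helper returns on the first refusal it finds, so items after a violating message are not inspected.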
