Commit dccb61b

Responses API: avoid shields code duplication
1 parent fe1d7b6 commit dccb61b

3 files changed: +63 −27 lines changed


src/app/endpoints/query_v2.py

Lines changed: 4 additions & 13 deletions
@@ -34,6 +34,7 @@
     get_topic_summary_system_prompt,
 )
 from utils.mcp_headers import mcp_headers_dependency
+from utils.shields import detect_shield_violations, get_available_shields
 from utils.token_counter import TokenCounter
 from utils.types import TurnSummary, ToolCallSummary

@@ -343,11 +344,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
         and token usage information.
     """
     # List available shields for Responses API
-    available_shields = [shield.identifier for shield in await client.shields.list()]
-    if not available_shields:
-        logger.info("No available shields. Disabling safety")
-    else:
-        logger.info("Available shields: %s", available_shields)
+    available_shields = await get_available_shields(client)

     # use system prompt from request or default one
     system_prompt = get_system_prompt(query_request, configuration)
@@ -414,14 +411,8 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
         if tool_summary:
             tool_calls.append(tool_summary)

-        # Check for shield violations
-        item_type = getattr(output_item, "type", None)
-        if item_type == "message":
-            refusal = getattr(output_item, "refusal", None)
-            if refusal:
-                # Metric for LLM validation errors (shield violations)
-                metrics.llm_calls_validation_errors_total.inc()
-                logger.warning("Shield violation detected: %s", refusal)
+    # Check for shield violations across all output items
+    detect_shield_violations(response.output)

     logger.info(
         "Response processing complete - Tool calls: %d, Response length: %d chars",
src/app/endpoints/streaming_query_v2.py

Lines changed: 5 additions & 14 deletions
@@ -32,7 +32,6 @@
 from authorization.middleware import authorize
 from configuration import configuration
 from constants import MEDIA_TYPE_JSON
-import metrics
 from models.config import Action
 from models.context import ResponseGeneratorContext
 from models.requests import QueryRequest
@@ -42,6 +41,7 @@
     get_system_prompt,
 )
 from utils.mcp_headers import mcp_headers_dependency
+from utils.shields import detect_shield_violations, get_available_shields
 from utils.token_counter import TokenCounter
 from utils.transcripts import store_transcript
 from utils.types import TurnSummary, ToolCallSummary
@@ -247,14 +247,9 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat

     # Check for shield violations in the completed response
     if latest_response_object:
-        for output_item in getattr(latest_response_object, "output", []):
-            item_type = getattr(output_item, "type", None)
-            if item_type == "message":
-                refusal = getattr(output_item, "refusal", None)
-                if refusal:
-                    # Metric for LLM validation errors (shield violations)
-                    metrics.llm_calls_validation_errors_total.inc()
-                    logger.warning("Shield violation detected: %s", refusal)
+        detect_shield_violations(
+            getattr(latest_response_object, "output", [])
+        )

     if not emitted_turn_complete:
         final_message = summary.llm_response or "".join(text_parts)
@@ -379,11 +374,7 @@ async def retrieve_response(
         and the conversation ID.
     """
     # List available shields for Responses API
-    available_shields = [shield.identifier for shield in await client.shields.list()]
-    if not available_shields:
-        logger.info("No available shields. Disabling safety")
-    else:
-        logger.info("Available shields: %s", available_shields)
+    available_shields = await get_available_shields(client)

     # use system prompt from request or default one
     system_prompt = get_system_prompt(query_request, configuration)
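In the streaming path the completed response object may not expose an output attribute at all, which is why the call site wraps the lookup in getattr(..., "output", []). A small illustrative sketch of that guard; the stub response objects are hypothetical and not from the commit, and it again assumes utils.shields is importable.

# Hypothetical response stubs for the streaming call site.
from types import SimpleNamespace

from utils.shields import detect_shield_violations

# A completed response whose message item carries a refusal.
flagged = SimpleNamespace(
    output=[SimpleNamespace(type="message", refusal="I can't help with that.")]
)
# A response object with no output attribute, e.g. an early-terminated stream.
bare = SimpleNamespace()

detect_shield_violations(getattr(flagged, "output", []))  # warns, bumps the metric, returns True
detect_shield_violations(getattr(bare, "output", []))     # nothing to scan, returns False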

src/utils/shields.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+"""Utility functions for working with Llama Stack shields."""
+
+import logging
+from typing import Any
+
+from llama_stack_client import AsyncLlamaStackClient
+
+import metrics
+
+logger = logging.getLogger(__name__)
+
+
+async def get_available_shields(client: AsyncLlamaStackClient) -> list[str]:
+    """
+    Discover and return available shield identifiers.
+
+    Args:
+        client: The Llama Stack client to query for available shields.
+
+    Returns:
+        List of shield identifiers that are available.
+    """
+    available_shields = [shield.identifier for shield in await client.shields.list()]
+    if not available_shields:
+        logger.info("No available shields. Disabling safety")
+    else:
+        logger.info("Available shields: %s", available_shields)
+    return available_shields
+
+
+def detect_shield_violations(output_items: list[Any]) -> bool:
+    """
+    Check output items for shield violations and update metrics.
+
+    Iterates through output items looking for message items with refusal
+    attributes. If a refusal is found, increments the validation error
+    metric and logs a warning.
+
+    Args:
+        output_items: List of output items from the LLM response to check.
+
+    Returns:
+        True if a shield violation was detected, False otherwise.
+    """
+    for output_item in output_items:
+        item_type = getattr(output_item, "type", None)
+        if item_type == "message":
+            refusal = getattr(output_item, "refusal", None)
+            if refusal:
+                # Metric for LLM validation errors (shield violations)
+                metrics.llm_calls_validation_errors_total.inc()
+                logger.warning("Shield violation detected: %s", refusal)
+                return True
+    return False
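A pytest-style sketch of how the new helper could be unit-tested; the test names, stub items, and the patch target are assumptions for illustration, not part of this commit. Patching utils.shields.metrics keeps the real counter untouched while asserting the increment.

# Hypothetical tests for utils.shields.detect_shield_violations.
from types import SimpleNamespace
from unittest.mock import patch

from utils.shields import detect_shield_violations


def test_refusal_in_message_item_is_reported():
    items = [
        SimpleNamespace(type="reasoning"),  # non-message items are skipped
        SimpleNamespace(type="message", refusal="Request blocked by shield."),
    ]
    with patch("utils.shields.metrics") as fake_metrics:
        assert detect_shield_violations(items) is True
        fake_metrics.llm_calls_validation_errors_total.inc.assert_called_once()


def test_clean_output_is_not_flagged():
    items = [SimpleNamespace(type="message", refusal=None)]
    with patch("utils.shields.metrics") as fake_metrics:
        assert detect_shield_violations(items) is False
        fake_metrics.llm_calls_validation_errors_total.inc.assert_not_called()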
