Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/app/endpoints/query_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
get_topic_summary_system_prompt,
)
from utils.mcp_headers import mcp_headers_dependency
from utils.shields import detect_shield_violations, get_available_shields
from utils.token_counter import TokenCounter
from utils.types import TurnSummary, ToolCallSummary

Expand Down Expand Up @@ -322,7 +323,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
given query, handling shield configuration, tool usage, and
attachment validation.

This function configures system prompts and toolgroups
This function configures system prompts, shields, and toolgroups
(including RAG and MCP integration) as needed based on
the query request and system configuration. It
validates attachments, manages conversation and session
Expand All @@ -342,8 +343,8 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
and the conversation ID, the list of parsed referenced documents,
and token usage information.
"""
# TODO(ltomasbo): implement shields support once available in Responses API
logger.info("Shields are not yet supported in Responses API. Disabling safety")
# List available shields for Responses API
available_shields = await get_available_shields(client)

# use system prompt from request or default one
system_prompt = get_system_prompt(query_request, configuration)
Expand Down Expand Up @@ -381,6 +382,10 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
if query_request.conversation_id:
create_kwargs["previous_response_id"] = query_request.conversation_id

# Add shields to extra_body if available
if available_shields:
create_kwargs["extra_body"] = {"guardrails": available_shields}

response = await client.responses.create(**create_kwargs)
response = cast(OpenAIResponseObject, response)

Expand All @@ -406,6 +411,9 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
if tool_summary:
tool_calls.append(tool_summary)

# Check for shield violations across all output items
detect_shield_violations(response.output)

logger.info(
"Response processing complete - Tool calls: %d, Response length: %d chars",
len(tool_calls),
Expand Down
25 changes: 19 additions & 6 deletions src/app/endpoints/streaming_query_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
get_system_prompt,
)
from utils.mcp_headers import mcp_headers_dependency
from utils.shields import detect_shield_violations, get_available_shields
from utils.token_counter import TokenCounter
from utils.transcripts import store_transcript
from utils.types import TurnSummary, ToolCallSummary
Expand Down Expand Up @@ -243,6 +244,13 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
elif event_type == "response.completed":
# Capture the response object for token usage extraction
latest_response_object = getattr(chunk, "response", None)

# Check for shield violations in the completed response
if latest_response_object:
detect_shield_violations(
getattr(latest_response_object, "output", [])
)

if not emitted_turn_complete:
final_message = summary.llm_response or "".join(text_parts)
if not final_message:
Expand Down Expand Up @@ -348,11 +356,11 @@ async def retrieve_response(
Asynchronously retrieves a streaming response and conversation
ID from the Llama Stack agent for a given user query.

This function configures input/output shields, system prompt,
and tool usage based on the request and environment. It
prepares the agent with appropriate headers and toolgroups,
validates attachments if present, and initiates a streaming
turn with the user's query and any provided documents.
This function configures shields, system prompt, and tool usage
based on the request and environment. It prepares the agent with
appropriate headers and toolgroups, validates attachments if
present, and initiates a streaming turn with the user's query
and any provided documents.

Parameters:
model_id (str): Identifier of the model to use for the query.
Expand All @@ -365,7 +373,8 @@ async def retrieve_response(
tuple: A tuple containing the streaming response object
and the conversation ID.
"""
logger.info("Shields are not yet supported in Responses API.")
# List available shields for Responses API
available_shields = await get_available_shields(client)

# use system prompt from request or default one
system_prompt = get_system_prompt(query_request, configuration)
Expand Down Expand Up @@ -402,6 +411,10 @@ async def retrieve_response(
if query_request.conversation_id:
create_params["previous_response_id"] = query_request.conversation_id

# Add shields to extra_body if available
if available_shields:
create_params["extra_body"] = {"guardrails": available_shields}

response = await client.responses.create(**create_params)
response_stream = cast(AsyncIterator[OpenAIResponseObjectStream], response)

Expand Down
54 changes: 54 additions & 0 deletions src/utils/shields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Utility functions for working with Llama Stack shields."""

import logging
from typing import Any

from llama_stack_client import AsyncLlamaStackClient

import metrics

logger = logging.getLogger(__name__)


async def get_available_shields(client: AsyncLlamaStackClient) -> list[str]:
    """
    Discover and return available shield identifiers.

    Args:
        client: The Llama Stack client to query for available shields.

    Returns:
        List of shield identifiers that are available.
    """
    # Query the stack for registered shields and keep only their identifiers.
    shields = await client.shields.list()
    identifiers = [shield.identifier for shield in shields]
    if identifiers:
        logger.info("Available shields: %s", identifiers)
    else:
        # No shields registered means requests proceed without guardrails.
        logger.info("No available shields. Disabling safety")
    return identifiers


def detect_shield_violations(output_items: list[Any]) -> bool:
    """
    Check output items for shield violations and update metrics.

    Scans the response's output items for message items carrying a
    ``refusal`` attribute. On the first refusal found, the validation
    error metric is incremented, a warning is logged, and scanning
    stops.

    Args:
        output_items: List of output items from the LLM response to check.

    Returns:
        True if a shield violation was detected, False otherwise.
    """
    # Only "message" items can carry a refusal; everything else is skipped.
    messages = (
        item for item in output_items if getattr(item, "type", None) == "message"
    )
    for message in messages:
        refusal_text = getattr(message, "refusal", None)
        if not refusal_text:
            continue
        # Metric for LLM validation errors (shield violations)
        metrics.llm_calls_validation_errors_total.inc()
        logger.warning("Shield violation detected: %s", refusal_text)
        return True
    return False
Loading
Loading