
Conversation


@luis5tb luis5tb commented Nov 5, 2025

Description

Extend the Responses API support (v2 endpoints) by adding the option to use shields.

Note that there is a limitation in LlamaStack: the same shields must be used for both input and output.
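For illustration, here is a minimal sketch (not the PR's exact code) of how discovered shields end up on a Responses API call; it assumes the AsyncLlamaStackClient used elsewhere in this repository, and model_id/query are hypothetical variables resolved earlier in the endpoint:

# Hedged sketch: discover shield IDs at request time and forward them as
# extra_body guardrails; the same shield list applies to input and output.
available_shields = [shield.identifier for shield in await client.shields.list()]

create_kwargs = {
    "model": model_id,   # hypothetical, selected earlier in the endpoint
    "input": query,      # hypothetical user query
    "stream": False,
}
if available_shields:
    create_kwargs["extra_body"] = {"guardrails": available_shields}

response = await client.responses.create(**create_kwargs)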

Type of change

  • Refactor
  • New feature
  • Bug fix
  • CVE fix
  • Optimization
  • Documentation Update
  • Configuration Update
  • Bump-up service version
  • Bump-up dependent library
  • Bump-up library or tool used for development (does not change the final image)
  • CI configuration change
  • Konflux configuration change
  • Unit tests improvement
  • Integration tests improvement
  • End to end tests improvement

Related Tickets & Documents

Checklist before requesting a review

  • I have performed a self-review of my code.
  • PR has passed all pre-merge test jobs.
  • If it is a core feature, I have added thorough tests.

Testing

  • Please provide detailed steps to perform tests related to this code change.
  • How were the fix/results from this change verified? Please provide relevant screenshots or results.

Summary by CodeRabbit

  • New Features

    • Runtime discovery and integration of safety shields into query and streaming flows
    • Shield-based guardrails included in API request handling when shields are present
    • Shield violation detection that increments validation error metrics and emits warnings
  • Tests

    • Expanded coverage for shield availability, guardrail propagation, and violation detection across query and streaming endpoints


coderabbitai bot commented Nov 5, 2025

Walkthrough

Adds runtime shield discovery to Responses API v2 endpoints: queries available shields at request time, attaches shield IDs as extra_body.guardrails when present, and inspects response output items for refusal messages to increment a validation-error metric; tests expanded to cover shield presence/absence and violation detection.

Changes

Cohort / File(s) and summary:

  • Endpoints (sync + streaming): src/app/endpoints/query_v2.py, src/app/endpoints/streaming_query_v2.py
    Call get_available_shields(client) (wraps client.shields.list()), include extra_body.guardrails with shield IDs when creating Responses API requests, and run detect_shield_violations against response output items to increment llm_calls_validation_errors_total on refusal detections; minor docstring/log updates.
  • Shield utilities: src/utils/shields.py
    New module exposing get_available_shields(client: AsyncLlamaStackClient) -> list[str] and detect_shield_violations(output_items: list[Any]) -> bool with logging and metric increment on detected refusals.
  • Unit tests (query v2): tests/unit/app/endpoints/test_query_v2.py
    Add mocks for shields.list() across tests; new tests asserting guardrail propagation when shields are present, absence when none, and metric increment on detected shield refusal; adjust existing mocks to avoid false negatives.
  • Unit tests (streaming v2): tests/unit/app/endpoints/test_streaming_query_v2.py
    Mock shields.list() in existing tests; new tests for guardrail propagation (present/absent), streaming-path shield violation detection (metric increment), and non-violation cases; adjust streaming mocks and SSE sequencing in tests.
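A rough sketch of the new src/utils/shields.py module, based on the signatures summarized above (the merged code may differ in details, and the llama_stack_client import path is assumed):

"""Utilities for discovering safety shields and detecting shield violations."""

import logging
from typing import Any

from llama_stack_client import AsyncLlamaStackClient  # assumed import path

import metrics  # project-local metrics module used throughout this PR

logger = logging.getLogger(__name__)


async def get_available_shields(client: AsyncLlamaStackClient) -> list[str]:
    """Return identifiers of all shields registered in Llama Stack."""
    available_shields = [shield.identifier for shield in await client.shields.list()]
    if not available_shields:
        logger.info("No available shields. Disabling safety")
    else:
        logger.info("Available shields: %s", available_shields)
    return available_shields


def detect_shield_violations(output_items: list[Any]) -> bool:
    """Increment the validation-error metric if a message item carries a refusal."""
    for output_item in output_items:
        if getattr(output_item, "type", None) == "message":
            refusal = getattr(output_item, "refusal", None)
            if refusal:
                metrics.llm_calls_validation_errors_total.inc()
                logger.warning("Shield violation detected: %s", refusal)
                return True
    return False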

Sequence Diagram(s)

sequenceDiagram
    participant C as Client
    participant H as Endpoint Handler
    participant S as Shields util
    participant R as Responses API
    participant P as Processor

    C->>H: QueryRequest v2

    rect rgb(240, 250, 240)
    Note over H,S: Discover shields
    H->>S: get_available_shields(client)
    S-->>H: [shield_id,...] or []
    end

    rect rgb(240, 240, 255)
    Note over H,R: Create Responses request
    alt Shields present
        H->>R: create(..., extra_body: { guardrails: [ids] })
    else No shields
        H->>R: create(..., no extra_body.guardrails)
    end
    end

    R-->>H: response (output items...)

    rect rgb(255, 250, 230)
    Note over H,P: Inspect outputs for refusals
    loop per output item
        alt message with refusal
            P->>P: detect_shield_violations -> increment llm_calls_validation_errors_total
            P->>H: log shield violation
        else
            P->>H: normal processing
        end
    end
    end

    H-->>C: Final response / stream

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • Pay attention to error handling from client.shields.list() and how failures are logged/propagated.
  • Verify extra_body.guardrails shape matches Responses API expectations.
  • Confirm detect_shield_violations reliably identifies refusal messages without false positives.
  • Review tests for correct mocking of client.shields.list() and metrics assertions.

Possibly related PRs

Suggested reviewers

  • tisnik
  • manstis

Pre-merge checks

❌ Failed checks (1 inconclusive)

  • Title check: ❓ Inconclusive. The title is vague and uses a work-in-progress indicator without providing clear specifics about the main change. Consider a more descriptive title that clearly explains the feature, such as 'Add shield support to Responses v2 endpoints' or 'Enable shields integration in Responses API v2'.

✅ Passed checks (2 passed)

  • Description Check: ✅ Passed. Check skipped because CodeRabbit’s high-level summary is enabled.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 92.59%, which is above the required threshold of 80.00%.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/app/endpoints/query_v2.py (1)

468-558: Prevent double-counting llm_calls_total.

extract_token_usage_from_responses_api always bumps llm_calls_total, but the shared streaming base already increments that metric before yielding the SSE stream. For the Responses API path this results in every request being counted twice. Please make the metric increment optional so we can disable it for streaming callers while keeping the existing behavior for synchronous flows.

-def extract_token_usage_from_responses_api(
-    response: OpenAIResponseObject,
-    model: str,
-    provider: str,
-    system_prompt: str = "",  # pylint: disable=unused-argument
-) -> TokenCounter:
+def extract_token_usage_from_responses_api(
+    response: OpenAIResponseObject,
+    model: str,
+    provider: str,
+    system_prompt: str = "",  # pylint: disable=unused-argument
+    *,
+    increment_llm_call_metric: bool = True,
+) -> TokenCounter:
@@
-                # Update Prometheus metrics only when we have actual usage data
-                try:
-                    metrics.llm_token_sent_total.labels(provider, model).inc(
-                        token_counter.input_tokens
-                    )
-                    metrics.llm_token_received_total.labels(provider, model).inc(
-                        token_counter.output_tokens
-                    )
-                except (AttributeError, TypeError, ValueError) as e:
-                    logger.warning("Failed to update token metrics: %s", e)
-                _increment_llm_call_metric(provider, model)
+                # Update Prometheus metrics only when we have actual usage data
+                try:
+                    metrics.llm_token_sent_total.labels(provider, model).inc(
+                        token_counter.input_tokens
+                    )
+                    metrics.llm_token_received_total.labels(provider, model).inc(
+                        token_counter.output_tokens
+                    )
+                except (AttributeError, TypeError, ValueError) as e:
+                    logger.warning("Failed to update token metrics: %s", e)
+                if increment_llm_call_metric:
+                    _increment_llm_call_metric(provider, model)
@@
-                _increment_llm_call_metric(provider, model)
+                if increment_llm_call_metric:
+                    _increment_llm_call_metric(provider, model)
@@
-        _increment_llm_call_metric(provider, model)
+        if increment_llm_call_metric:
+            _increment_llm_call_metric(provider, model)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4f39c33 and fb25b8c.

📒 Files selected for processing (8)
  • src/app/endpoints/query_v2.py (7 hunks)
  • src/app/endpoints/streaming_query.py (6 hunks)
  • src/app/endpoints/streaming_query_v2.py (1 hunks)
  • src/app/routers.py (2 hunks)
  • src/models/context.py (1 hunks)
  • tests/unit/app/endpoints/test_query_v2.py (9 hunks)
  • tests/unit/app/endpoints/test_streaming_query_v2.py (1 hunks)
  • tests/unit/app/test_routers.py (5 hunks)
🧰 Additional context used
📓 Path-based instructions (9)
src/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

Use absolute imports for internal modules (e.g., from auth import get_auth_dependency)

Files:

  • src/app/routers.py
  • src/models/context.py
  • src/app/endpoints/streaming_query_v2.py
  • src/app/endpoints/streaming_query.py
  • src/app/endpoints/query_v2.py
src/app/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

Use standard FastAPI imports (from fastapi import APIRouter, HTTPException, Request, status, Depends) in FastAPI app code

Files:

  • src/app/routers.py
  • src/app/endpoints/streaming_query_v2.py
  • src/app/endpoints/streaming_query.py
  • src/app/endpoints/query_v2.py
**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

**/*.py: All modules start with descriptive module-level docstrings explaining purpose
Use logger = logging.getLogger(name) for module logging after import logging
Define type aliases at module level for clarity
All functions require docstrings with brief descriptions
Provide complete type annotations for all function parameters and return types
Use typing_extensions.Self in model validators where appropriate
Use modern union syntax (str | int) and Optional[T] or T | None consistently
Function names use snake_case with descriptive, action-oriented prefixes (get_, validate_, check_)
Avoid in-place parameter modification; return new data structures instead of mutating arguments
Use appropriate logging levels: debug, info, warning, error with clear messages
All classes require descriptive docstrings explaining purpose
Class names use PascalCase with conventional suffixes (Configuration, Error/Exception, Resolver, Interface)
Abstract base classes should use abc.ABC and @AbstractMethod for interfaces
Provide complete type annotations for all class attributes
Follow Google Python docstring style for modules, classes, and functions, including Args, Returns, Raises, Attributes sections as needed

Files:

  • src/app/routers.py
  • src/models/context.py
  • src/app/endpoints/streaming_query_v2.py
  • tests/unit/app/endpoints/test_query_v2.py
  • tests/unit/app/test_routers.py
  • src/app/endpoints/streaming_query.py
  • src/app/endpoints/query_v2.py
  • tests/unit/app/endpoints/test_streaming_query_v2.py
src/{app/**/*.py,client.py}

📄 CodeRabbit inference engine (CLAUDE.md)

Use async def for I/O-bound operations and external API calls

Files:

  • src/app/routers.py
  • src/app/endpoints/streaming_query_v2.py
  • src/app/endpoints/streaming_query.py
  • src/app/endpoints/query_v2.py
src/{models/**/*.py,configuration.py}

📄 CodeRabbit inference engine (CLAUDE.md)

src/{models/**/*.py,configuration.py}: Use @field_validator and @model_validator for custom validation in Pydantic models
Use precise type hints in configuration (e.g., Optional[FilePath], PositiveInt, SecretStr)

Files:

  • src/models/context.py
src/models/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

src/models/**/*.py: Pydantic models: use BaseModel for data models and extend ConfigurationBase for configuration
Use @model_validator and @field_validator for Pydantic model validation

Files:

  • src/models/context.py
src/app/endpoints/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

In API endpoints, raise FastAPI HTTPException with appropriate status codes for error handling

Files:

  • src/app/endpoints/streaming_query_v2.py
  • src/app/endpoints/streaming_query.py
  • src/app/endpoints/query_v2.py
tests/{unit,integration}/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/{unit,integration}/**/*.py: Use pytest for all unit and integration tests
Do not use unittest in tests; pytest is the standard

Files:

  • tests/unit/app/endpoints/test_query_v2.py
  • tests/unit/app/test_routers.py
  • tests/unit/app/endpoints/test_streaming_query_v2.py
tests/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/**/*.py: Use pytest-mock to create AsyncMock objects for async interactions in tests
Use the shared auth mock constant: MOCK_AUTH = ("mock_user_id", "mock_username", False, "mock_token") in tests

Files:

  • tests/unit/app/endpoints/test_query_v2.py
  • tests/unit/app/test_routers.py
  • tests/unit/app/endpoints/test_streaming_query_v2.py
🧬 Code graph analysis (7)
src/app/routers.py (1)
tests/unit/app/test_routers.py (1)
  • include_router (37-52)
src/models/context.py (1)
src/models/requests.py (1)
  • QueryRequest (73-225)
src/app/endpoints/streaming_query_v2.py (12)
src/app/database.py (1)
  • get_session (34-40)
src/app/endpoints/query.py (3)
  • is_transcripts_enabled (98-104)
  • persist_user_conversation_details (107-139)
  • validate_attachments_metadata (801-830)
src/app/endpoints/query_v2.py (4)
  • extract_token_usage_from_responses_api (468-558)
  • get_topic_summary (235-274)
  • prepare_tools_for_responses_api (634-685)
  • retrieve_response (304-444)
src/app/endpoints/streaming_query.py (6)
  • format_stream_data (126-137)
  • stream_end_event (164-220)
  • stream_start_event (140-161)
  • streaming_query_endpoint_handler_base (851-983)
  • response_generator (716-846)
  • retrieve_response (1018-1139)
src/models/cache_entry.py (1)
  • CacheEntry (7-24)
src/models/context.py (1)
  • ResponseGeneratorContext (12-48)
src/models/responses.py (2)
  • ForbiddenResponse (1120-1142)
  • UnauthorizedResponse (1094-1117)
src/utils/endpoints.py (2)
  • create_referenced_documents_with_metadata (563-577)
  • store_conversation_into_cache (231-251)
src/utils/mcp_headers.py (1)
  • mcp_headers_dependency (15-26)
src/utils/token_counter.py (1)
  • TokenCounter (18-41)
src/utils/transcripts.py (1)
  • store_transcript (40-99)
src/utils/types.py (2)
  • TurnSummary (89-163)
  • ToolCallSummary (73-86)
tests/unit/app/endpoints/test_query_v2.py (2)
src/models/config.py (1)
  • ModelContextProtocolServer (169-174)
src/app/endpoints/query_v2.py (2)
  • get_mcp_tools (592-631)
  • retrieve_response (304-444)
src/app/endpoints/streaming_query.py (8)
src/models/context.py (1)
  • ResponseGeneratorContext (12-48)
src/utils/endpoints.py (4)
  • get_system_prompt (126-190)
  • create_rag_chunks_dict (383-396)
  • create_referenced_documents_with_metadata (563-577)
  • store_conversation_into_cache (231-251)
src/metrics/utils.py (1)
  • update_llm_token_count_from_turn (60-77)
src/utils/token_counter.py (2)
  • extract_token_usage_from_turn (44-94)
  • TokenCounter (18-41)
src/app/endpoints/query.py (1)
  • persist_user_conversation_details (107-139)
src/utils/transcripts.py (1)
  • store_transcript (40-99)
src/app/database.py (1)
  • get_session (34-40)
src/models/database/conversations.py (1)
  • UserConversation (11-38)
src/app/endpoints/query_v2.py (2)
src/configuration.py (3)
  • configuration (73-77)
  • AppConfig (39-181)
  • mcp_servers (101-105)
src/models/requests.py (1)
  • QueryRequest (73-225)
tests/unit/app/endpoints/test_streaming_query_v2.py (4)
src/models/requests.py (1)
  • QueryRequest (73-225)
src/models/config.py (3)
  • config (140-146)
  • Action (329-375)
  • ModelContextProtocolServer (169-174)
src/app/endpoints/streaming_query_v2.py (2)
  • retrieve_response (397-478)
  • streaming_query_endpoint_handler_v2 (367-394)
src/configuration.py (1)
  • mcp_servers (101-105)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: e2e_tests (azure)
  • GitHub Check: e2e_tests (ci)
  • GitHub Check: build-pr
  • GitHub Check: Konflux kflux-prd-rh02 / lightspeed-stack-on-pull-request

Comment on lines 134 to 275
chunk_id = 0
summary = TurnSummary(llm_response="No response from the model", tool_calls=[])

# Determine media type for response formatting
media_type = context.query_request.media_type or MEDIA_TYPE_JSON

# Accumulators for Responses API
text_parts: list[str] = []
tool_item_registry: dict[str, dict[str, str]] = {}
emitted_turn_complete = False

# Handle conversation id and start event in-band on response.created
conv_id = context.conversation_id

# Track the latest response object from response.completed event
latest_response_object: Any | None = None

logger.debug("Starting streaming response (Responses API) processing")

async for chunk in turn_response:
    event_type = getattr(chunk, "type", None)
    logger.debug("Processing chunk %d, type: %s", chunk_id, event_type)

    # Emit start on response.created
    if event_type == "response.created":
        try:
            conv_id = getattr(chunk, "response").id
        except Exception:  # pylint: disable=broad-except
            conv_id = ""
        yield stream_start_event(conv_id)
        continue

    # Text streaming
    if event_type == "response.output_text.delta":
        delta = getattr(chunk, "delta", "")
        if delta:
            text_parts.append(delta)
            yield format_stream_data(
                {
                    "event": "token",
                    "data": {
                        "id": chunk_id,
                        "token": delta,
                    },
                }
            )
            chunk_id += 1

    # Final text of the output (capture, but emit at response.completed)
    elif event_type == "response.output_text.done":
        final_text = getattr(chunk, "text", "")
        if final_text:
            summary.llm_response = final_text

    # Content part started - emit an empty token to kick off UI streaming if desired
    elif event_type == "response.content_part.added":
        yield format_stream_data(
            {
                "event": "token",
                "data": {
                    "id": chunk_id,
                    "token": "",
                },
            }
        )
        chunk_id += 1

    # Track tool call items as they are added so we can build a summary later
    elif event_type == "response.output_item.added":
        item = getattr(chunk, "item", None)
        item_type = getattr(item, "type", None)
        if item and item_type == "function_call":
            item_id = getattr(item, "id", "")
            name = getattr(item, "name", "function_call")
            call_id = getattr(item, "call_id", item_id)
            if item_id:
                tool_item_registry[item_id] = {
                    "name": name,
                    "call_id": call_id,
                }

    # Stream tool call arguments as tool_call events
    elif event_type == "response.function_call_arguments.delta":
        delta = getattr(chunk, "delta", "")
        yield format_stream_data(
            {
                "event": "tool_call",
                "data": {
                    "id": chunk_id,
                    "role": "tool_execution",
                    "token": delta,
                },
            }
        )
        chunk_id += 1

    # Finalize tool call arguments and append to summary
    elif event_type in (
        "response.function_call_arguments.done",
        "response.mcp_call.arguments.done",
    ):
        item_id = getattr(chunk, "item_id", "")
        arguments = getattr(chunk, "arguments", "")
        meta = tool_item_registry.get(item_id, {})
        summary.tool_calls.append(
            ToolCallSummary(
                id=meta.get("call_id", item_id or "unknown"),
                name=meta.get("name", "tool_call"),
                args=arguments,
                response=None,
            )
        )

    # Completed response - capture final text and response object
    elif event_type == "response.completed":
        # Capture the response object for token usage extraction
        latest_response_object = getattr(chunk, "response", None)

        # Check for shield violations in the completed response
        if latest_response_object:
            for output_item in getattr(latest_response_object, "output", []):
                item_type = getattr(output_item, "type", None)
                if item_type == "message":
                    refusal = getattr(output_item, "refusal", None)
                    if refusal:
                        # Metric for LLM validation errors (shield violations)
                        metrics.llm_calls_validation_errors_total.inc()
                        logger.warning("Shield violation detected: %s", refusal)

        if not emitted_turn_complete:
            final_message = summary.llm_response or "".join(text_parts)
            yield format_stream_data(
                {
                    "event": "turn_complete",
                    "data": {
                        "id": chunk_id,
                        "token": final_message,
                    },
                }
            )
            chunk_id += 1
            emitted_turn_complete = True


⚠️ Potential issue | 🟠 Major

Persist the streamed text into the summary.

When the Responses API only emits response.output_text.delta chunks (and never provides a non-empty response.output_text.done), summary.llm_response stays at the placeholder value. We then emit "No response from the model" in the final turn_complete event and store the wrong text in transcripts/cache. Capture the accumulated deltas in the summary before finishing the turn.

-        summary = TurnSummary(llm_response="No response from the model", tool_calls=[])
+        summary = TurnSummary(llm_response="", tool_calls=[])
@@
-                if not emitted_turn_complete:
-                    final_message = summary.llm_response or "".join(text_parts)
+                if not emitted_turn_complete:
+                    final_message = summary.llm_response or "".join(text_parts)
+                    summary.llm_response = final_message
🤖 Prompt for AI Agents
In src/app/endpoints/streaming_query_v2.py around lines 134-276, the
summary.llm_response is left as the placeholder when only
response.output_text.delta chunks are received; before emitting the final
turn_complete (inside the response.completed handling) set summary.llm_response
= summary.llm_response if it already contains a non-placeholder value else
"".join(text_parts) (or simply assign "".join(text_parts) if
summary.llm_response is the placeholder) so the accumulated token deltas are
persisted into the summary and used for the final emitted token and stored
transcripts/cache.

Comment on lines 284 to 296

# Extract token usage from the response object
token_usage = (
    extract_token_usage_from_responses_api(
        latest_response_object, context.model_id, context.provider_id
    )
    if latest_response_object is not None
    else TokenCounter()
)

yield stream_end_event(context.metadata_map, summary, token_usage, media_type)

if not is_transcripts_enabled():

⚠️ Potential issue | 🟠 Major

Avoid incrementing llm_calls_total twice.

After the streaming base already increments metrics.llm_calls_total, calling extract_token_usage_from_responses_api here triggers a second increment. Once the helper accepts an increment_llm_call_metric flag, pass False from the streaming path so the metric stays accurate while we still update token counters.

-            extract_token_usage_from_responses_api(
-                latest_response_object, context.model_id, context.provider_id
-            )
+            extract_token_usage_from_responses_api(
+                latest_response_object,
+                context.model_id,
+                context.provider_id,
+                increment_llm_call_metric=False,
+            )

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/app/endpoints/streaming_query_v2.py around lines 284 to 296, calling
extract_token_usage_from_responses_api here causes metrics.llm_calls_total to be
incremented a second time because the streaming base already incremented it;
update the call to pass increment_llm_call_metric=False (once the helper accepts
that flag) so the helper extracts token counters without bumping the LLM call
metric, ensuring only token counters are updated and metrics.llm_calls_total
remains accurate.

@luis5tb luis5tb force-pushed the responses_v2_shields branch from fb25b8c to b669b32 on November 14, 2025 at 08:09

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (1)
src/app/endpoints/streaming_query_v2.py (1)

285-288: Be aware of double‑counting llm_calls_total in streaming (already flagged earlier).

extract_token_usage_from_responses_api internally increments metrics.llm_calls_total, and the streaming base handler also increments that metric; calling it here effectively bumps llm_calls_total twice per streaming call. A previous review already suggested adding an increment_llm_call_metric flag to the helper and passing False from this path to avoid double-counting.

Not a new regression in this diff, but still worth addressing in a follow-up.

🧹 Nitpick comments (3)
tests/unit/app/endpoints/test_query_v2.py (1)

618-661: Shield violation metric test aligns with implementation.

This test accurately models a shield violation (assistant message with a non-empty refusal), patches metrics.llm_calls_validation_errors_total, and asserts inc() is called exactly once. That matches the new loop in query_v2.retrieve_response that scans output items for refusals and bumps the validation-error metric.

One minor thought: if a response ever contained multiple violating message items, the metric would be incremented multiple times per call. If you intend this metric to be “per call” rather than “per offending message”, you could break after the first refusal, but that’s a behavioral choice, not a blocker.

src/app/endpoints/query_v2.py (1)

417-424: Shield violation detection is implemented as expected; consider counting once per call.

The loop over response.output checks each message-type item for a non-empty refusal and increments metrics.llm_calls_validation_errors_total for each. That matches the new tests and will correctly mark shield-triggered refusals.

If you intend llm_calls_validation_errors_total to be “per call” rather than “per offending message”, you could short-circuit after the first violation:

-    for output_item in response.output:
+    for output_item in response.output:
         ...
-        if item_type == "message":
-            refusal = getattr(output_item, "refusal", None)
-            if refusal:
-                metrics.llm_calls_validation_errors_total.inc()
-                logger.warning("Shield violation detected: %s", refusal)
+        if item_type == "message":
+            refusal = getattr(output_item, "refusal", None)
+            if refusal:
+                metrics.llm_calls_validation_errors_total.inc()
+                logger.warning("Shield violation detected: %s", refusal)
+                break

Not urgent, but it may make the metric easier to reason about.

src/app/endpoints/streaming_query_v2.py (1)

248-258: Streaming shield violation handling is consistent with the non‑streaming path.

On response.completed, you:

  • Grab the final response object.
  • Scan output for message items with a non-empty refusal.
  • Increment metrics.llm_calls_validation_errors_total and log a warning once per offending message.

This matches the semantics in query_v2.retrieve_response, and the new tests exercise both violation and no-violation cases. Same optional note as in the REST path: if you want a “per call” metric, consider breaking after the first refusal, but behavior is otherwise correct.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fb25b8c and b669b32.

📒 Files selected for processing (4)
  • src/app/endpoints/query_v2.py (4 hunks)
  • src/app/endpoints/streaming_query_v2.py (5 hunks)
  • tests/unit/app/endpoints/test_query_v2.py (8 hunks)
  • tests/unit/app/endpoints/test_streaming_query_v2.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/unit/app/endpoints/test_query_v2.py (3)
tests/unit/app/endpoints/test_streaming_query_v2.py (2)
  • test_retrieve_response_with_shields_available (232-262)
  • test_retrieve_response_with_no_shields_available (266-290)
src/models/requests.py (1)
  • QueryRequest (73-225)
src/app/endpoints/query_v2.py (1)
  • retrieve_response (309-449)
tests/unit/app/endpoints/test_streaming_query_v2.py (3)
tests/unit/app/endpoints/test_query_v2.py (3)
  • test_retrieve_response_with_shields_available (535-576)
  • test_retrieve_response_with_no_shields_available (580-615)
  • dummy_request (32-35)
src/app/endpoints/query_v2.py (1)
  • retrieve_response (309-449)
src/app/endpoints/streaming_query_v2.py (2)
  • retrieve_response (351-432)
  • streaming_query_endpoint_handler_v2 (321-348)
🪛 GitHub Actions: Unit tests
src/app/endpoints/streaming_query_v2.py

[error] 1-1: AttributeError: module 'app.endpoints.streaming_query_v2' does not have the attribute 'get_session'.

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: build-pr
  • GitHub Check: Konflux kflux-prd-rh02 / lightspeed-stack-on-pull-request
  • GitHub Check: e2e_tests (azure)
  • GitHub Check: e2e_tests (ci)
🔇 Additional comments (10)
tests/unit/app/endpoints/test_streaming_query_v2.py (2)

41-42: Good: shields.list is mocked in existing streaming retrieve_response tests.

Adding mock_client.shields.list = mocker.AsyncMock(return_value=[]) here keeps the updated retrieve_response implementation from blowing up on missing shields while keeping the original test focus (tools & streaming flags) intact. No issues from this change.

Also applies to: 73-74


231-263: LGTM: streaming retrieve_response shields plumbing is covered.

These tests correctly:

  • Mock client.shields.list to return identifiers.
  • Assert that extra_body.guardrails is present when shields exist and absent when not.
  • Keep the rest of the behavior (vector stores, tools, system prompt) mocked but unconstrained, which is appropriate for this focused check.

No changes needed here.

tests/unit/app/endpoints/test_query_v2.py (4)

123-124: Good: existing retrieve_response tests are adapted to shields.

Adding mock_client.shields.list = mocker.AsyncMock(return_value=[]) in all the existing retrieve_response tests makes them compatible with the new shield discovery logic without changing what they assert (tools, usage, attachments, etc.). This is the right, minimal adjustment.

Also applies to: 161-162, 229-230, 276-277, 315-316, 354-355, 384-385


535-577: Non‑streaming shields plumbing test looks solid.

This test:

  • Mocks shields.list to return identifiers.
  • Verifies summary.llm_response and conv_id.
  • Asserts extra_body.guardrails matches the shield identifiers.

This directly exercises the new available_shields → extra_body["guardrails"] flow in query_v2.retrieve_response. No changes needed.


579-616: Correctly covers “no shields → no extra_body” behavior.

The test for the empty-shields case ensures:

  • client.shields.list returns [].
  • retrieve_response still returns the expected summary and conversation id.
  • extra_body is not present in the responses.create kwargs.

This matches the intended behavior and protects against regressions where we might start sending an empty guardrails list.


663-705: “No violation” test correctly guards against false positives.

Here you explicitly set output_item.refusal = None and assert that validation_metric.inc is never called. That’s important because with bare Mock instances, accessing an unset attribute would otherwise yield a truthy Mock, which would incorrectly look like a violation. This test ensures the “no refusal” path stays clean.
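As a small illustration of the pitfall described here (not taken from the PR's test code), unset attributes on a bare Mock are truthy, whereas an explicit None keeps the "no violation" path clean:

from types import SimpleNamespace
from unittest.mock import Mock

bare_item = Mock()
# Accessing an unset attribute on a Mock auto-creates a child Mock, which is
# truthy, so it would look like a refusal to the detection loop.
assert bool(bare_item.refusal)

safe_item = SimpleNamespace(type="message", refusal=None)
# Explicitly setting refusal=None keeps the "no violation" path clean.
assert not bool(safe_item.refusal)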

src/app/endpoints/query_v2.py (2)

325-331: Docstring update matches new shields behavior.

The extended description (“configures system prompts, shields, and toolgroups…”) accurately reflects the new logic in retrieve_response. No issues here.


345-350: Shields discovery and guardrails wiring look correct.

  • available_shields = [shield.identifier for shield in await client.shields.list()] keeps the interface generic and only passes identifiers downstream.
  • Logging both “no shields” and the list of shields should help debugging.
  • Adding extra_body = {"guardrails": available_shields} only when the list is non-empty aligns with the tests and avoids sending a meaningless empty guardrails list.

All of this is consistent and matches the non-streaming tests.

Also applies to: 388-391

src/app/endpoints/streaming_query_v2.py (2)

35-35: metrics import is appropriate and scoped to new shield-violation tracking.

metrics is only used for llm_calls_validation_errors_total in this file, so this import is justified and doesn’t introduce unused-symbol noise.


364-368: Streaming retrieve_response shields integration looks correct.

  • Docstring now mentions shields, system prompt, and tools, reflecting what the function actually configures.
  • available_shields is derived from await client.shields.list(), with useful logging for both empty and non-empty cases.
  • extra_body = {"guardrails": available_shields} is only attached when shields exist, matching the tests and the non-streaming implementation.

This keeps the streaming path aligned with the REST v2 query behavior.

Also applies to: 381-386, 423-426

Comment on lines 293 to 385
@pytest.mark.asyncio
async def test_streaming_response_detects_shield_violation(mocker, dummy_request):
    """Test that shield violations in streaming responses are detected and metrics incremented."""
    # Skip real config checks
    mocker.patch("app.endpoints.streaming_query.check_configuration_loaded")

    # Model selection plumbing
    mock_client = mocker.Mock()
    mock_client.models.list = mocker.AsyncMock(return_value=[mocker.Mock()])
    mocker.patch(
        "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client
    )
    mocker.patch(
        "app.endpoints.streaming_query.evaluate_model_hints",
        return_value=(None, None),
    )
    mocker.patch(
        "app.endpoints.streaming_query.select_model_and_provider_id",
        return_value=("llama/m", "m", "p"),
    )

    # SSE helpers
    mocker.patch(
        "app.endpoints.streaming_query_v2.stream_start_event",
        lambda conv_id: f"START:{conv_id}\n",
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.format_stream_data",
        lambda obj: f"EV:{obj['event']}:{obj['data'].get('token','')}\n",
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.stream_end_event",
        lambda _m, _s, _t, _media: "END\n",
    )

    # Conversation persistence and transcripts disabled
    mocker.patch(
        "app.endpoints.streaming_query_v2.persist_user_conversation_details",
        return_value=None,
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.is_transcripts_enabled", return_value=False
    )

    # Mock database and topic summary
    mock_session = mocker.Mock()
    mock_session.query.return_value.filter_by.return_value.first.return_value = (
        mocker.Mock()
    )
    mock_context_manager = mocker.Mock()
    mock_context_manager.__enter__ = mocker.Mock(return_value=mock_session)
    mock_context_manager.__exit__ = mocker.Mock(return_value=None)
    mocker.patch(
        "app.endpoints.streaming_query_v2.get_session",
        return_value=mock_context_manager,
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.get_topic_summary",
        mocker.AsyncMock(return_value=""),
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.store_conversation_into_cache",
        return_value=None,
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.create_referenced_documents_with_metadata",
        return_value=[],
    )

    # Mock the validation error metric
    validation_metric = mocker.patch("metrics.llm_calls_validation_errors_total")

    # Build a fake async stream with shield violation
    async def fake_stream_with_violation():
        yield SimpleNamespace(
            type="response.created", response=SimpleNamespace(id="conv-violation")
        )
        yield SimpleNamespace(type="response.output_text.delta", delta="I cannot ")
        yield SimpleNamespace(type="response.output_text.done", text="I cannot help")
        # Response completed with refusal in output
        violation_item = SimpleNamespace(
            type="message",
            role="assistant",
            refusal="Content violates safety policy",
        )
        response_with_violation = SimpleNamespace(
            id="conv-violation", output=[violation_item]
        )
        yield SimpleNamespace(
            type="response.completed", response=response_with_violation
        )

    mocker.patch(
        "app.endpoints.streaming_query_v2.retrieve_response",
        return_value=(fake_stream_with_violation(), ""),
    )

    mocker.patch("metrics.llm_calls_total")

    resp = await streaming_query_endpoint_handler_v2(
        request=dummy_request,
        query_request=QueryRequest(query="dangerous query"),
        auth=("user123", "", True, "token-abc"),
        mcp_headers={},
    )

    assert isinstance(resp, StreamingResponse)

    # Collect emitted events to trigger the generator
    events: list[str] = []
    async for chunk in resp.body_iterator:
        s = chunk.decode() if isinstance(chunk, (bytes, bytearray)) else str(chunk)
        events.append(s)

    # Verify that the validation error metric was incremented
    validation_metric.inc.assert_called_once()


⚠️ Potential issue | 🔴 Critical

Fix invalid patch target and stub cleanup dependencies at the right module.

mocker.patch("app.endpoints.streaming_query_v2.get_session", ...) fails because get_session is not defined in streaming_query_v2, matching the pipeline error, and even if it existed, it wouldn’t affect cleanup_after_streaming (which lives in utils.endpoints and uses its own module globals).

For this test to both pass and avoid touching real DB/cache, you should patch the dependencies where cleanup_after_streaming actually resolves them (likely utils.endpoints), and drop the invalid patch on streaming_query_v2.get_session.

For example:

-    mocker.patch(
-        "app.endpoints.streaming_query_v2.get_session",
-        return_value=mock_context_manager,
-    )
-    mocker.patch(
-        "app.endpoints.streaming_query_v2.store_conversation_into_cache",
-        return_value=None,
-    )
-    mocker.patch(
-        "app.endpoints.streaming_query_v2.create_referenced_documents_with_metadata",
-        return_value=[],
-    )
+    mocker.patch(
+        "utils.endpoints.get_session",
+        return_value=mock_context_manager,
+    )
+    mocker.patch(
+        "utils.endpoints.store_conversation_into_cache",
+        return_value=None,
+    )
+    mocker.patch(
+        "utils.endpoints.create_referenced_documents_with_metadata",
+        return_value=[],
+    )

This resolves the AttributeError and ensures the test stubs the actual call sites used by cleanup_after_streaming.

🤖 Prompt for AI Agents
In tests/unit/app/endpoints/test_streaming_query_v2.py around lines 293 to 409,
the test patches get_session on app.endpoints.streaming_query_v2 which does not
exist and cleanup_after_streaming resolves its dependencies from
utils.endpoints; remove the mocker.patch targeting
"app.endpoints.streaming_query_v2.get_session" and instead patch the real
call-sites in utils.endpoints (e.g., "utils.endpoints.get_session",
"utils.endpoints.get_topic_summary",
"utils.endpoints.store_conversation_into_cache",
"utils.endpoints.create_referenced_documents_with_metadata" or whichever
functions cleanup_after_streaming imports) so the cleanup logic uses the test
stubs and no real DB/cache is touched.

Comment on lines 411 to 472
@pytest.mark.asyncio
async def test_streaming_response_no_shield_violation(mocker, dummy_request):
    """Test that no metric is incremented when there's no shield violation in streaming."""
    # Skip real config checks
    mocker.patch("app.endpoints.streaming_query.check_configuration_loaded")

    # Model selection plumbing
    mock_client = mocker.Mock()
    mock_client.models.list = mocker.AsyncMock(return_value=[mocker.Mock()])
    mocker.patch(
        "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client
    )
    mocker.patch(
        "app.endpoints.streaming_query.evaluate_model_hints",
        return_value=(None, None),
    )
    mocker.patch(
        "app.endpoints.streaming_query.select_model_and_provider_id",
        return_value=("llama/m", "m", "p"),
    )

    # SSE helpers
    mocker.patch(
        "app.endpoints.streaming_query_v2.stream_start_event",
        lambda conv_id: f"START:{conv_id}\n",
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.format_stream_data",
        lambda obj: f"EV:{obj['event']}:{obj['data'].get('token','')}\n",
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.stream_end_event",
        lambda _m, _s, _t, _media: "END\n",
    )

    # Conversation persistence and transcripts disabled
    mocker.patch(
        "app.endpoints.streaming_query_v2.persist_user_conversation_details",
        return_value=None,
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.is_transcripts_enabled", return_value=False
    )

    # Mock database and topic summary
    mock_session = mocker.Mock()
    mock_session.query.return_value.filter_by.return_value.first.return_value = (
        mocker.Mock()
    )
    mock_context_manager = mocker.Mock()
    mock_context_manager.__enter__ = mocker.Mock(return_value=mock_session)
    mock_context_manager.__exit__ = mocker.Mock(return_value=None)
    mocker.patch(
        "app.endpoints.streaming_query_v2.get_session",
        return_value=mock_context_manager,
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.get_topic_summary",
        mocker.AsyncMock(return_value=""),
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.store_conversation_into_cache",
        return_value=None,
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.create_referenced_documents_with_metadata",
        return_value=[],
    )

    # Mock the validation error metric
    validation_metric = mocker.patch("metrics.llm_calls_validation_errors_total")

    # Build a fake async stream without violation
    async def fake_stream_without_violation():
        yield SimpleNamespace(
            type="response.created", response=SimpleNamespace(id="conv-safe")
        )
        yield SimpleNamespace(type="response.output_text.delta", delta="Safe ")
        yield SimpleNamespace(type="response.output_text.done", text="Safe response")
        # Response completed without refusal
        safe_item = SimpleNamespace(
            type="message",
            role="assistant",
            refusal=None,  # No violation
        )
        response_safe = SimpleNamespace(id="conv-safe", output=[safe_item])
        yield SimpleNamespace(type="response.completed", response=response_safe)

    mocker.patch(
        "app.endpoints.streaming_query_v2.retrieve_response",
        return_value=(fake_stream_without_violation(), ""),
    )

    mocker.patch("metrics.llm_calls_total")

    resp = await streaming_query_endpoint_handler_v2(
        request=dummy_request,
        query_request=QueryRequest(query="safe query"),
        auth=("user123", "", True, "token-abc"),
        mcp_headers={},
    )

    assert isinstance(resp, StreamingResponse)

    # Collect emitted events to trigger the generator
    events: list[str] = []
    async for chunk in resp.body_iterator:
        s = chunk.decode() if isinstance(chunk, (bytes, bytearray)) else str(chunk)
        events.append(s)

    # Verify that the validation error metric was NOT incremented
    validation_metric.inc.assert_not_called()

⚠️ Potential issue | 🔴 Critical

Apply the same patch-target fix in the “no shield violation” streaming test.

The second streaming shield test mirrors the first one and repeats the same incorrect patch targets (app.endpoints.streaming_query_v2.get_session, store_conversation_into_cache, create_referenced_documents_with_metadata). Update them to patch utils.endpoints.* instead, just like suggested above, so both tests stub the real dependencies and avoid the same AttributeError.

🤖 Prompt for AI Agents
In tests/unit/app/endpoints/test_streaming_query_v2.py around lines 411 to 522,
the test currently patches app.endpoints.streaming_query_v2.get_session,
app.endpoints.streaming_query_v2.store_conversation_into_cache and
app.endpoints.streaming_query_v2.create_referenced_documents_with_metadata which
is the wrong target and causes AttributeError; change those patch targets to the
utils.endpoints module counterparts (utils.endpoints.get_session,
utils.endpoints.store_conversation_into_cache,
utils.endpoints.create_referenced_documents_with_metadata) keeping the same
return values/AsyncMock usage so the test stubs the real dependencies correctly.

It includes both streaming and non-streaming support by leveraging the refusal field on the response.
@luis5tb luis5tb force-pushed the responses_v2_shields branch from b669b32 to fe1d7b6 on November 14, 2025 at 08:25

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (5)
tests/unit/app/endpoints/test_query_v2.py (1)

123-124: Consider a pytest fixture for the shields mock to reduce duplication.

Multiple tests now mock shields.list identically. A shared fixture could eliminate this repetition and make future updates easier.

Example:

@pytest.fixture
def mock_empty_shields(mocker: MockerFixture) -> None:
    """Mock shields.list to return empty list."""
    return mocker.AsyncMock(return_value=[])

Then update tests to use it:

-    # Mock shields.list
-    mock_client.shields.list = mocker.AsyncMock(return_value=[])
+    mock_client.shields.list = mock_empty_shields

Also applies to: 161-162, 229-230, 276-277, 315-316, 354-355, 384-385

tests/unit/app/endpoints/test_streaming_query_v2.py (1)

41-42: Consider a pytest fixture for the shields mock (applies here too).

Same as in test_query_v2.py, a shared fixture could reduce duplication across streaming tests.

Also applies to: 73-74

src/app/endpoints/query_v2.py (2)

345-350: Shield discovery logic is duplicated in streaming_query_v2.py.

The shield discovery (lines 345-350) and propagation (lines 388-390) logic is identical in streaming_query_v2.py (lines 381-386, 423-425). Consider extracting this into a shared utility function to maintain consistency and reduce duplication.

Example:

# In utils/shields.py or similar
async def get_available_shields(
    client: AsyncLlamaStackClient
) -> list[str]:
    """Discover and return available shield identifiers."""
    shields = [shield.identifier for shield in await client.shields.list()]
    if not shields:
        logger.info("No available shields. Disabling safety")
    else:
        logger.info("Available shields: %s", shields)
    return shields

Then use it in both files:

available_shields = await get_available_shields(client)
# ...
if available_shields:
    create_kwargs["extra_body"] = {"guardrails": available_shields}

Also applies to: 388-390


417-424: Shield violation detection is duplicated in streaming_query_v2.py.

The violation detection logic (lines 417-424) is nearly identical in streaming_query_v2.py (lines 248-257). Consider extracting this into a shared utility function alongside the shield discovery refactor.

Example:

# In utils/shields.py
def detect_shield_violations(output_items: list[Any]) -> bool:
    """Check output items for shield violations and update metrics."""
    for output_item in output_items:
        item_type = getattr(output_item, "type", None)
        if item_type == "message":
            refusal = getattr(output_item, "refusal", None)
            if refusal:
                metrics.llm_calls_validation_errors_total.inc()
                logger.warning("Shield violation detected: %s", refusal)
                return True
    return False
src/app/endpoints/streaming_query_v2.py (1)

381-386: Shield discovery and propagation duplicated from query_v2.py.

This code is identical to query_v2.py lines 345-350 and 388-390. Please see the refactoring suggestion in the query_v2.py review to extract this into a shared utility.

Also applies to: 423-425

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b669b32 and fe1d7b6.

📒 Files selected for processing (4)
  • src/app/endpoints/query_v2.py (4 hunks)
  • src/app/endpoints/streaming_query_v2.py (5 hunks)
  • tests/unit/app/endpoints/test_query_v2.py (8 hunks)
  • tests/unit/app/endpoints/test_streaming_query_v2.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/unit/app/endpoints/test_streaming_query_v2.py (3)
tests/unit/app/endpoints/test_query_v2.py (2)
  • test_retrieve_response_with_shields_available (535-576)
  • test_retrieve_response_with_no_shields_available (580-617)
src/models/requests.py (1)
  • QueryRequest (73-225)
src/app/endpoints/streaming_query_v2.py (2)
  • retrieve_response (351-432)
  • streaming_query_endpoint_handler_v2 (321-348)
tests/unit/app/endpoints/test_query_v2.py (2)
tests/unit/app/endpoints/test_streaming_query_v2.py (2)
  • test_retrieve_response_with_shields_available (232-262)
  • test_retrieve_response_with_no_shields_available (266-292)
src/app/endpoints/query_v2.py (1)
  • retrieve_response (309-449)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: Konflux kflux-prd-rh02 / lightspeed-stack-on-pull-request
  • GitHub Check: build-pr
  • GitHub Check: e2e_tests (azure)
  • GitHub Check: e2e_tests (ci)
🔇 Additional comments (6)
tests/unit/app/endpoints/test_query_v2.py (1)

534-711: Well-structured test coverage for shields integration.

The four new tests comprehensively verify shield behavior: availability detection, guardrail propagation, violation detection, and metric updates. The test structure is clear and follows established patterns.

tests/unit/app/endpoints/test_streaming_query_v2.py (1)

231-472: Excellent streaming shield test coverage.

The new tests verify shields integration at both the retrieve_response level (guardrail propagation) and the full streaming endpoint level (violation detection in SSE flow). The mock streaming chunks with refusal data are properly structured.

src/app/endpoints/query_v2.py (1)

325-330: Good documentation update.

The docstring now accurately reflects that shields are configured alongside system prompts and toolgroups.

src/app/endpoints/streaming_query_v2.py (3)

35-35: Metrics import is correct.

The import is necessary for incrementing llm_calls_validation_errors_total when shield violations are detected.


248-257: Shield violation detection is correctly placed in streaming flow.

The violation check happens after response.completed, using the captured latest_response_object. This ensures all output items are available for inspection. However, note the duplication concern flagged in query_v2.py.


364-368: Docstring accurately reflects shield configuration.

The updated docstring correctly describes that shields are configured alongside system prompts and tools.

@luis5tb luis5tb changed the title from "[WIP] Responses v2 shields" to "Responses v2 shields" on Nov 14, 2025

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (3)
src/utils/shields.py (2)

13-28: Make shield discovery more robust and avoid failing the request on list() errors

Right now any failure in await client.shields.list() will bubble up and fail the whole request, even though shields are optional. It also assumes that every returned shield has a non‑empty identifier attribute.

To make this path safer and less brittle, consider:

  • Catching exceptions from client.shields.list() and falling back to [] while logging a warning.
  • Normalizing identifiers via getattr and filtering out None/empty ones.
  • Tweaking the log message to reflect that we’re “running without guardrails” rather than “disabling safety,” since this function doesn’t actually toggle anything.

For example:

-    available_shields = [shield.identifier for shield in await client.shields.list()]
-    if not available_shields:
-        logger.info("No available shields. Disabling safety")
-    else:
-        logger.info("Available shields: %s", available_shields)
-    return available_shields
+    try:
+        shields = await client.shields.list()
+    except Exception as exc:  # pylint: disable=broad-exception-caught
+        logger.warning(
+            "Failed to list shields from Llama Stack, proceeding without guardrails: %s",
+            exc,
+        )
+        return []
+
+    available_shields = [
+        getattr(shield, "identifier", None) for shield in shields
+    ]
+    available_shields = [s for s in available_shields if s]
+
+    if not available_shields:
+        logger.info("No available shields discovered; proceeding without guardrails")
+    else:
+        logger.info("Available shields: %s", available_shields)
+    return available_shields

This keeps shields best‑effort and avoids a hard dependency on the exact return type shape.


31-53: Clarify what counts as a “shield violation” and consider expanding detection coverage

The implementation is simple and effective for the specific case of a message output item with a top‑level, non‑empty refusal attribute, and the metric increment + warning log make sense.

A couple of follow‑ups to consider:

  • If the Responses API ever encodes refusals inside content parts (e.g., a part with a refusal field, similar to your _extract_text_from_response_output_item handling), this helper will miss them. You may want to either:
    • Explicitly document that you only consider top‑level output_item.refusal, or
    • Extend detection to mirror the traversal you already do in _extract_text_from_response_output_item.
  • The function returns bool but the current call site ignores the return value; if you don’t plan to branch on this, you could either drop the return value (and treat this as a pure side‑effect helper) or have the caller log/annotate based on it for observability.

No blocking issues here, but tightening the definition and coverage now will avoid surprises when the response schema evolves.
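A hedged sketch of the broader traversal suggested above; the content/refusal attribute names on content parts are assumptions about the Responses output schema, not confirmed by this PR, and the function name is deliberately distinct from the PR's helper:

import logging
from typing import Any

import metrics  # project-local metrics module, as used elsewhere in this PR

logger = logging.getLogger(__name__)


def detect_shield_violations_deep(output_items: list[Any]) -> bool:
    """Check top-level refusals and (assumed) per-part refusals on message items."""
    for output_item in output_items:
        if getattr(output_item, "type", None) != "message":
            continue
        candidates = [getattr(output_item, "refusal", None)]
        # Assumption: refusals may also appear on individual content parts.
        for part in getattr(output_item, "content", []) or []:
            candidates.append(getattr(part, "refusal", None))
        for refusal in candidates:
            if refusal:
                metrics.llm_calls_validation_errors_total.inc()
                logger.warning("Shield violation detected: %s", refusal)
                return True
    return False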

src/app/endpoints/query_v2.py (1)

346-348: Shield integration is correct at a high level; consider resilience, caching, and return‑value use

The overall flow makes sense:

  • Discover shields once via available_shields = await get_available_shields(client).
  • If any are present, inject them into create_kwargs["extra_body"] = {"guardrails": available_shields} so they apply to the Responses call.
  • After processing all output items, run detect_shield_violations(response.output) to bump the validation‑error metric when a refusal is present.

A few refinements worth considering:

  1. Don’t let shield listing failures break the request

    As get_available_shields is called on every request, any transient failure in client.shields.list() will currently abort the whole /query call. If shields are meant to be best‑effort, it’s safer to let get_available_shields swallow errors and return [] (see suggested refactor in src/utils/shields.py) so this endpoint remains functional even when shield discovery is flaky.

  2. Avoid repeated remote shield discovery per request

    You’re now doing a round‑trip to Llama Stack shields on every call to retrieve_response. If the set of shields is relatively static per process/config, consider caching them at process level (or with a short TTL) and reusing them instead of calling .list() every time. This will reduce latency and load on Llama Stack. A minimal sketch of this caching idea follows this comment.

  3. Validate extra_body / guardrails shape against the client API

    The extra_body={"guardrails": available_shields} convention relies on the current AsyncLlamaStackClient.responses.create semantics. If the guardrails field ever changes name/shape upstream, this will fail at runtime. It’s worth double‑checking this against the version of the Llama Stack client you’re targeting.

  4. Use or drop the boolean return from detect_shield_violations

    Since the return value is currently ignored, either:

    • Capture it and, for example, log a high‑level “shield violation occurred” log with context (conversation_id, model), or
    • Change detect_shield_violations to return None if you intend it to be purely side‑effectful. This would make the API intent clearer.

None of these are blockers, but addressing (1) and (2) in particular will make this path more robust in production.

Also applies to: 385-388, 414-416
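To make point 2 concrete, here is a minimal per-process caching sketch; the TTL value, the cache variables, and the wrapper name are illustrative assumptions, not part of the PR:

import time
from typing import Any

# Module-level cache of shield identifiers (hypothetical names).
_SHIELDS_CACHE: list[str] | None = None
_SHIELDS_CACHE_AT: float = 0.0
_SHIELDS_TTL_SECONDS = 300.0


async def get_available_shields_cached(client: Any) -> list[str]:
    """Return cached shield identifiers, refreshing them after the TTL expires."""
    global _SHIELDS_CACHE, _SHIELDS_CACHE_AT
    now = time.monotonic()
    if _SHIELDS_CACHE is None or now - _SHIELDS_CACHE_AT > _SHIELDS_TTL_SECONDS:
        # Same discovery call the PR already makes, just done at most once per TTL.
        _SHIELDS_CACHE = [shield.identifier for shield in await client.shields.list()]
        _SHIELDS_CACHE_AT = now
    return _SHIELDS_CACHE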

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fe1d7b6 and dccb61b.

📒 Files selected for processing (3)
  • src/app/endpoints/query_v2.py (5 hunks)
  • src/app/endpoints/streaming_query_v2.py (5 hunks)
  • src/utils/shields.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/app/endpoints/streaming_query_v2.py
🧰 Additional context used
🧬 Code graph analysis (1)
src/app/endpoints/query_v2.py (1)
src/utils/shields.py (2)
  • detect_shield_violations (31-54)
  • get_available_shields (13-28)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: build-pr
  • GitHub Check: e2e_tests (ci)
  • GitHub Check: e2e_tests (azure)
  • GitHub Check: Konflux kflux-prd-rh02 / lightspeed-stack-on-pull-request
🔇 Additional comments (3)
src/utils/shields.py (1)

1-11: Module setup looks straightforward

Docstring, imports, and logger initialization are consistent with the rest of the codebase style; no issues from a structure standpoint.

src/app/endpoints/query_v2.py (2)

37-37: Shields utilities import is appropriately scoped

Pulling detect_shield_violations and get_available_shields into this module keeps the response logic focused and avoids leaking shield details into unrelated layers.


326-332: Docstring accurately reflects the new shields behavior

The expanded description calling out shields alongside system prompts and toolgroups matches the actual implementation below (guardrails in extra_body and violation metrics), and keeps the function contract clear.

@eranco74

/lgtm


@tisnik tisnik left a comment


LGTM

@tisnik tisnik merged commit c1e4aee into lightspeed-core:main Nov 19, 2025
21 of 23 checks passed