Responses v2 shields #760

Conversation
Walkthrough
Adds runtime shield discovery to the Responses API v2 endpoints: queries available shields at request time, attaches the shield IDs as guardrails in the request's extra_body, and inspects completed outputs for refusals (shield violations).

Changes
Sequence Diagram(s)
sequenceDiagram
participant C as Client
participant H as Endpoint Handler
participant S as Shields util
participant R as Responses API
participant P as Processor
C->>H: QueryRequest v2
rect rgb(240, 250, 240)
Note over H,S: Discover shields
H->>S: get_available_shields(client)
S-->>H: [shield_id,...] or []
end
rect rgb(240, 240, 255)
Note over H,R: Create Responses request
alt Shields present
H->>R: create(..., extra_body: { guardrails: [ids] })
else No shields
H->>R: create(..., no extra_body.guardrails)
end
end
R-->>H: response (output items...)
rect rgb(255, 250, 230)
Note over H,P: Inspect outputs for refusals
loop per output item
alt message with refusal
P->>P: detect_shield_violations -> increment llm_calls_validation_errors_total
P->>H: log shield violation
else
P->>H: normal processing
end
end
end
H-->>C: Final response / stream
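To make the flow concrete, here is a minimal sketch of the pattern the diagram describes. It assumes the `get_available_shields` and `detect_shield_violations` helpers introduced by this PR's shields utility and a generic Responses `create(...)` call; treat it as an illustration, not the exact endpoint code.

```python
# Illustrative sketch of the shields flow described above, not the real handler.
async def answer_with_shields(client, create_kwargs: dict):
    # 1. Discover shields at request time.
    available_shields = await get_available_shields(client)

    # 2. Attach shield IDs as guardrails only when any exist.
    if available_shields:
        create_kwargs["extra_body"] = {"guardrails": available_shields}

    # 3. Create the Responses API request.
    response = await client.responses.create(**create_kwargs)

    # 4. Inspect output items for refusals (shield violations) and update metrics.
    detect_shield_violations(response.output)
    return response
```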
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Possibly related PRs
Suggested reviewers
Pre-merge checks and finishing touches
❌ Failed checks (1 inconclusive)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/app/endpoints/query_v2.py (1)
468-558: Prevent double-counting `llm_calls_total`.

`extract_token_usage_from_responses_api` always bumps `llm_calls_total`, but the shared streaming base already increments that metric before yielding the SSE stream. For the Responses API path this results in every request being counted twice. Please make the metric increment optional so we can disable it for streaming callers while keeping the existing behavior for synchronous flows.

```diff
-def extract_token_usage_from_responses_api(
-    response: OpenAIResponseObject,
-    model: str,
-    provider: str,
-    system_prompt: str = "",  # pylint: disable=unused-argument
-) -> TokenCounter:
+def extract_token_usage_from_responses_api(
+    response: OpenAIResponseObject,
+    model: str,
+    provider: str,
+    system_prompt: str = "",  # pylint: disable=unused-argument
+    *,
+    increment_llm_call_metric: bool = True,
+) -> TokenCounter:
@@
     # Update Prometheus metrics only when we have actual usage data
     try:
         metrics.llm_token_sent_total.labels(provider, model).inc(
             token_counter.input_tokens
         )
         metrics.llm_token_received_total.labels(provider, model).inc(
             token_counter.output_tokens
         )
     except (AttributeError, TypeError, ValueError) as e:
         logger.warning("Failed to update token metrics: %s", e)
-    _increment_llm_call_metric(provider, model)
+    if increment_llm_call_metric:
+        _increment_llm_call_metric(provider, model)
@@
-    _increment_llm_call_metric(provider, model)
+    if increment_llm_call_metric:
+        _increment_llm_call_metric(provider, model)
@@
-    _increment_llm_call_metric(provider, model)
+    if increment_llm_call_metric:
+        _increment_llm_call_metric(provider, model)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- src/app/endpoints/query_v2.py (7 hunks)
- src/app/endpoints/streaming_query.py (6 hunks)
- src/app/endpoints/streaming_query_v2.py (1 hunks)
- src/app/routers.py (2 hunks)
- src/models/context.py (1 hunks)
- tests/unit/app/endpoints/test_query_v2.py (9 hunks)
- tests/unit/app/endpoints/test_streaming_query_v2.py (1 hunks)
- tests/unit/app/test_routers.py (5 hunks)
🧰 Additional context used
📓 Path-based instructions (9)
src/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
Use absolute imports for internal modules (e.g., from auth import get_auth_dependency)
Files:
- src/app/routers.py
- src/models/context.py
- src/app/endpoints/streaming_query_v2.py
- src/app/endpoints/streaming_query.py
- src/app/endpoints/query_v2.py
src/app/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
Use standard FastAPI imports (from fastapi import APIRouter, HTTPException, Request, status, Depends) in FastAPI app code
Files:
- src/app/routers.py
- src/app/endpoints/streaming_query_v2.py
- src/app/endpoints/streaming_query.py
- src/app/endpoints/query_v2.py
**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
**/*.py: All modules start with descriptive module-level docstrings explaining purpose
Use logger = logging.getLogger(name) for module logging after import logging
Define type aliases at module level for clarity
All functions require docstrings with brief descriptions
Provide complete type annotations for all function parameters and return types
Use typing_extensions.Self in model validators where appropriate
Use modern union syntax (str | int) and Optional[T] or T | None consistently
Function names use snake_case with descriptive, action-oriented prefixes (get_, validate_, check_)
Avoid in-place parameter modification; return new data structures instead of mutating arguments
Use appropriate logging levels: debug, info, warning, error with clear messages
All classes require descriptive docstrings explaining purpose
Class names use PascalCase with conventional suffixes (Configuration, Error/Exception, Resolver, Interface)
Abstract base classes should use abc.ABC and @AbstractMethod for interfaces
Provide complete type annotations for all class attributes
Follow Google Python docstring style for modules, classes, and functions, including Args, Returns, Raises, Attributes sections as needed
Files:
- src/app/routers.py
- src/models/context.py
- src/app/endpoints/streaming_query_v2.py
- tests/unit/app/endpoints/test_query_v2.py
- tests/unit/app/test_routers.py
- src/app/endpoints/streaming_query.py
- src/app/endpoints/query_v2.py
- tests/unit/app/endpoints/test_streaming_query_v2.py
src/{app/**/*.py,client.py}
📄 CodeRabbit inference engine (CLAUDE.md)
Use async def for I/O-bound operations and external API calls
Files:
- src/app/routers.py
- src/app/endpoints/streaming_query_v2.py
- src/app/endpoints/streaming_query.py
- src/app/endpoints/query_v2.py
src/{models/**/*.py,configuration.py}
📄 CodeRabbit inference engine (CLAUDE.md)
src/{models/**/*.py,configuration.py}: Use @field_validator and @model_validator for custom validation in Pydantic models
Use precise type hints in configuration (e.g., Optional[FilePath], PositiveInt, SecretStr)
Files:
src/models/context.py
src/models/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
src/models/**/*.py: Pydantic models: use BaseModel for data models and extend ConfigurationBase for configuration
Use @model_validator and @field_validator for Pydantic model validation
Files:
src/models/context.py
src/app/endpoints/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
In API endpoints, raise FastAPI HTTPException with appropriate status codes for error handling
Files:
- src/app/endpoints/streaming_query_v2.py
- src/app/endpoints/streaming_query.py
- src/app/endpoints/query_v2.py
tests/{unit,integration}/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
tests/{unit,integration}/**/*.py: Use pytest for all unit and integration tests
Do not use unittest in tests; pytest is the standard
Files:
- tests/unit/app/endpoints/test_query_v2.py
- tests/unit/app/test_routers.py
- tests/unit/app/endpoints/test_streaming_query_v2.py
tests/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
tests/**/*.py: Use pytest-mock to create AsyncMock objects for async interactions in tests
Use the shared auth mock constant: MOCK_AUTH = ("mock_user_id", "mock_username", False, "mock_token") in tests
Files:
- tests/unit/app/endpoints/test_query_v2.py
- tests/unit/app/test_routers.py
- tests/unit/app/endpoints/test_streaming_query_v2.py
🧬 Code graph analysis (7)
src/app/routers.py (1)
- tests/unit/app/test_routers.py (1)
  - include_router (37-52)

src/models/context.py (1)
- src/models/requests.py (1)
  - QueryRequest (73-225)

src/app/endpoints/streaming_query_v2.py (12)
- src/app/database.py (1)
  - get_session (34-40)
- src/app/endpoints/query.py (3)
  - is_transcripts_enabled (98-104)
  - persist_user_conversation_details (107-139)
  - validate_attachments_metadata (801-830)
- src/app/endpoints/query_v2.py (4)
  - extract_token_usage_from_responses_api (468-558)
  - get_topic_summary (235-274)
  - prepare_tools_for_responses_api (634-685)
  - retrieve_response (304-444)
- src/app/endpoints/streaming_query.py (6)
  - format_stream_data (126-137)
  - stream_end_event (164-220)
  - stream_start_event (140-161)
  - streaming_query_endpoint_handler_base (851-983)
  - response_generator (716-846)
  - retrieve_response (1018-1139)
- src/models/cache_entry.py (1)
  - CacheEntry (7-24)
- src/models/context.py (1)
  - ResponseGeneratorContext (12-48)
- src/models/responses.py (2)
  - ForbiddenResponse (1120-1142)
  - UnauthorizedResponse (1094-1117)
- src/utils/endpoints.py (2)
  - create_referenced_documents_with_metadata (563-577)
  - store_conversation_into_cache (231-251)
- src/utils/mcp_headers.py (1)
  - mcp_headers_dependency (15-26)
- src/utils/token_counter.py (1)
  - TokenCounter (18-41)
- src/utils/transcripts.py (1)
  - store_transcript (40-99)
- src/utils/types.py (2)
  - TurnSummary (89-163)
  - ToolCallSummary (73-86)

tests/unit/app/endpoints/test_query_v2.py (2)
- src/models/config.py (1)
  - ModelContextProtocolServer (169-174)
- src/app/endpoints/query_v2.py (2)
  - get_mcp_tools (592-631)
  - retrieve_response (304-444)

src/app/endpoints/streaming_query.py (8)
- src/models/context.py (1)
  - ResponseGeneratorContext (12-48)
- src/utils/endpoints.py (4)
  - get_system_prompt (126-190)
  - create_rag_chunks_dict (383-396)
  - create_referenced_documents_with_metadata (563-577)
  - store_conversation_into_cache (231-251)
- src/metrics/utils.py (1)
  - update_llm_token_count_from_turn (60-77)
- src/utils/token_counter.py (2)
  - extract_token_usage_from_turn (44-94)
  - TokenCounter (18-41)
- src/app/endpoints/query.py (1)
  - persist_user_conversation_details (107-139)
- src/utils/transcripts.py (1)
  - store_transcript (40-99)
- src/app/database.py (1)
  - get_session (34-40)
- src/models/database/conversations.py (1)
  - UserConversation (11-38)

src/app/endpoints/query_v2.py (2)
- src/configuration.py (3)
  - configuration (73-77)
  - AppConfig (39-181)
  - mcp_servers (101-105)
- src/models/requests.py (1)
  - QueryRequest (73-225)

tests/unit/app/endpoints/test_streaming_query_v2.py (4)
- src/models/requests.py (1)
  - QueryRequest (73-225)
- src/models/config.py (3)
  - config (140-146)
  - Action (329-375)
  - ModelContextProtocolServer (169-174)
- src/app/endpoints/streaming_query_v2.py (2)
  - retrieve_response (397-478)
  - streaming_query_endpoint_handler_v2 (321-348)
- src/configuration.py (1)
  - mcp_servers (101-105)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: e2e_tests (azure)
- GitHub Check: e2e_tests (ci)
- GitHub Check: build-pr
- GitHub Check: Konflux kflux-prd-rh02 / lightspeed-stack-on-pull-request
```python
chunk_id = 0
summary = TurnSummary(llm_response="No response from the model", tool_calls=[])

# Determine media type for response formatting
media_type = context.query_request.media_type or MEDIA_TYPE_JSON

# Accumulators for Responses API
text_parts: list[str] = []
tool_item_registry: dict[str, dict[str, str]] = {}
emitted_turn_complete = False

# Handle conversation id and start event in-band on response.created
conv_id = context.conversation_id

# Track the latest response object from response.completed event
latest_response_object: Any | None = None

logger.debug("Starting streaming response (Responses API) processing")

async for chunk in turn_response:
    event_type = getattr(chunk, "type", None)
    logger.debug("Processing chunk %d, type: %s", chunk_id, event_type)

    # Emit start on response.created
    if event_type == "response.created":
        try:
            conv_id = getattr(chunk, "response").id
        except Exception:  # pylint: disable=broad-except
            conv_id = ""
        yield stream_start_event(conv_id)
        continue

    # Text streaming
    if event_type == "response.output_text.delta":
        delta = getattr(chunk, "delta", "")
        if delta:
            text_parts.append(delta)
            yield format_stream_data(
                {
                    "event": "token",
                    "data": {
                        "id": chunk_id,
                        "token": delta,
                    },
                }
            )
            chunk_id += 1

    # Final text of the output (capture, but emit at response.completed)
    elif event_type == "response.output_text.done":
        final_text = getattr(chunk, "text", "")
        if final_text:
            summary.llm_response = final_text

    # Content part started - emit an empty token to kick off UI streaming if desired
    elif event_type == "response.content_part.added":
        yield format_stream_data(
            {
                "event": "token",
                "data": {
                    "id": chunk_id,
                    "token": "",
                },
            }
        )
        chunk_id += 1

    # Track tool call items as they are added so we can build a summary later
    elif event_type == "response.output_item.added":
        item = getattr(chunk, "item", None)
        item_type = getattr(item, "type", None)
        if item and item_type == "function_call":
            item_id = getattr(item, "id", "")
            name = getattr(item, "name", "function_call")
            call_id = getattr(item, "call_id", item_id)
            if item_id:
                tool_item_registry[item_id] = {
                    "name": name,
                    "call_id": call_id,
                }

    # Stream tool call arguments as tool_call events
    elif event_type == "response.function_call_arguments.delta":
        delta = getattr(chunk, "delta", "")
        yield format_stream_data(
            {
                "event": "tool_call",
                "data": {
                    "id": chunk_id,
                    "role": "tool_execution",
                    "token": delta,
                },
            }
        )
        chunk_id += 1

    # Finalize tool call arguments and append to summary
    elif event_type in (
        "response.function_call_arguments.done",
        "response.mcp_call.arguments.done",
    ):
        item_id = getattr(chunk, "item_id", "")
        arguments = getattr(chunk, "arguments", "")
        meta = tool_item_registry.get(item_id, {})
        summary.tool_calls.append(
            ToolCallSummary(
                id=meta.get("call_id", item_id or "unknown"),
                name=meta.get("name", "tool_call"),
                args=arguments,
                response=None,
            )
        )

    # Completed response - capture final text and response object
    elif event_type == "response.completed":
        # Capture the response object for token usage extraction
        latest_response_object = getattr(chunk, "response", None)

        # Check for shield violations in the completed response
        if latest_response_object:
            for output_item in getattr(latest_response_object, "output", []):
                item_type = getattr(output_item, "type", None)
                if item_type == "message":
                    refusal = getattr(output_item, "refusal", None)
                    if refusal:
                        # Metric for LLM validation errors (shield violations)
                        metrics.llm_calls_validation_errors_total.inc()
                        logger.warning("Shield violation detected: %s", refusal)

        if not emitted_turn_complete:
            final_message = summary.llm_response or "".join(text_parts)
            yield format_stream_data(
                {
                    "event": "turn_complete",
                    "data": {
                        "id": chunk_id,
                        "token": final_message,
                    },
                }
            )
            chunk_id += 1
            emitted_turn_complete = True
```
Persist the streamed text into the summary.
When the Responses API only emits `response.output_text.delta` chunks (and never provides a non-empty `response.output_text.done`), `summary.llm_response` stays at the placeholder value. We then emit "No response from the model" in the final `turn_complete` event and we store the wrong text into transcripts/cache. Capture the accumulated deltas in the summary before finishing the turn.
```diff
-summary = TurnSummary(llm_response="No response from the model", tool_calls=[])
+summary = TurnSummary(llm_response="", tool_calls=[])
@@
 if not emitted_turn_complete:
     final_message = summary.llm_response or "".join(text_parts)
+    summary.llm_response = final_message
```

🤖 Prompt for AI Agents
In src/app/endpoints/streaming_query_v2.py around lines 134-276, the
summary.llm_response is left as the placeholder when only
response.output_text.delta chunks are received; before emitting the final
turn_complete (inside the response.completed handling) set summary.llm_response
= summary.llm_response if it already contains a non-placeholder value else
"".join(text_parts) (or simply assign "".join(text_parts) if
summary.llm_response is the placeholder) so the accumulated token deltas are
persisted into the summary and used for the final emitted token and stored
transcripts/cache.
```python
# Extract token usage from the response object
token_usage = (
    extract_token_usage_from_responses_api(
        latest_response_object, context.model_id, context.provider_id
    )
    if latest_response_object is not None
    else TokenCounter()
)

yield stream_end_event(context.metadata_map, summary, token_usage, media_type)

if not is_transcripts_enabled():
```
Avoid incrementing llm_calls_total twice.
After the streaming base already increments metrics.llm_calls_total, calling extract_token_usage_from_responses_api here triggers a second increment. Once the helper accepts an increment_llm_call_metric flag, pass False from the streaming path so the metric stays accurate while we still update token counters.
```diff
-extract_token_usage_from_responses_api(
-    latest_response_object, context.model_id, context.provider_id
-)
+extract_token_usage_from_responses_api(
+    latest_response_object,
+    context.model_id,
+    context.provider_id,
+    increment_llm_call_metric=False,
+)
```

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/app/endpoints/streaming_query_v2.py around lines 284 to 296, calling
extract_token_usage_from_responses_api here causes metrics.llm_calls_total to be
incremented a second time because the streaming base already incremented it;
update the call to pass increment_llm_call_metric=False (once the helper accepts
that flag) so the helper extracts token counters without bumping the LLM call
metric, ensuring only token counters are updated and metrics.llm_calls_total
remains accurate.
fb25b8c to b669b32 — Compare
Actionable comments posted: 2
♻️ Duplicate comments (1)
src/app/endpoints/streaming_query_v2.py (1)
285-288: Be aware of double‑counting `llm_calls_total` in streaming (already flagged earlier).

`extract_token_usage_from_responses_api` internally increments `metrics.llm_calls_total`, and the streaming base handler also increments that metric; calling it here effectively bumps `llm_calls_total` twice per streaming call. A previous review already suggested adding an `increment_llm_call_metric` flag to the helper and passing `False` from this path to avoid double-counting.

Not a new regression in this diff, but still worth addressing in a follow-up.
🧹 Nitpick comments (3)
tests/unit/app/endpoints/test_query_v2.py (1)
618-661: Shield violation metric test aligns with implementation.

This test accurately models a shield violation (assistant `message` with a non-empty `refusal`), patches `metrics.llm_calls_validation_errors_total`, and asserts `inc()` is called exactly once. That matches the new loop in `query_v2.retrieve_response` that scans output items for refusals and bumps the validation-error metric.

One minor thought: if a response ever contained multiple violating `message` items, the metric would be incremented multiple times per call. If you intend this metric to be “per call” rather than “per offending message”, you could break after the first refusal, but that’s a behavioral choice, not a blocker.

src/app/endpoints/query_v2.py (1)

417-424: Shield violation detection is implemented as expected; consider counting once per call.

The loop over `response.output` checks each `message`-type item for a non-empty `refusal` and increments `metrics.llm_calls_validation_errors_total` for each. That matches the new tests and will correctly mark shield-triggered refusals.

If you intend `llm_calls_validation_errors_total` to be “per call” rather than “per offending message”, you could short-circuit after the first violation:

```diff
 for output_item in response.output:
     ...
     if item_type == "message":
         refusal = getattr(output_item, "refusal", None)
         if refusal:
             metrics.llm_calls_validation_errors_total.inc()
             logger.warning("Shield violation detected: %s", refusal)
+            break
```

Not urgent, but it may make the metric easier to reason about.

src/app/endpoints/streaming_query_v2.py (1)
src/app/endpoints/streaming_query_v2.py (1)
248-258: Streaming shield violation handling is consistent with the non‑streaming path.

On `response.completed`, you:

- Grab the final response object.
- Scan `output` for `message` items with a non-empty `refusal`.
- Increment `metrics.llm_calls_validation_errors_total` and log a warning once per offending message.

This matches the semantics in `query_v2.retrieve_response`, and the new tests exercise both violation and no-violation cases. Same optional note as in the REST path: if you want a “per call” metric, consider breaking after the first refusal, but behavior is otherwise correct.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- src/app/endpoints/query_v2.py (4 hunks)
- src/app/endpoints/streaming_query_v2.py (1 hunks)
- src/app/routers.py (2 hunks)
- tests/unit/app/endpoints/test_query_v2.py (8 hunks)
- tests/unit/app/endpoints/test_streaming_query_v2.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/unit/app/endpoints/test_query_v2.py (3)
- tests/unit/app/endpoints/test_streaming_query_v2.py (2)
  - test_retrieve_response_with_shields_available (232-262)
  - test_retrieve_response_with_no_shields_available (266-290)
- src/models/requests.py (1)
  - QueryRequest (73-225)
- src/app/endpoints/query_v2.py (1)
  - retrieve_response (309-449)

tests/unit/app/endpoints/test_streaming_query_v2.py (3)
- tests/unit/app/endpoints/test_query_v2.py (3)
  - test_retrieve_response_with_shields_available (535-576)
  - test_retrieve_response_with_no_shields_available (580-615)
  - dummy_request (32-35)
- src/app/endpoints/query_v2.py (1)
  - retrieve_response (309-449)
- src/app/endpoints/streaming_query_v2.py (2)
  - retrieve_response (351-432)
  - streaming_query_endpoint_handler_v2 (321-348)
🪛 GitHub Actions: Unit tests
src/app/endpoints/streaming_query_v2.py
[error] 1-1: AttributeError: module 'app.endpoints.streaming_query_v2' does not have the attribute 'get_session'.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: build-pr
- GitHub Check: Konflux kflux-prd-rh02 / lightspeed-stack-on-pull-request
- GitHub Check: e2e_tests (azure)
- GitHub Check: e2e_tests (ci)
🔇 Additional comments (10)
tests/unit/app/endpoints/test_streaming_query_v2.py (2)
41-42: Good: shields.list is mocked in existing streaming retrieve_response tests.

Adding `mock_client.shields.list = mocker.AsyncMock(return_value=[])` here keeps the updated `retrieve_response` implementation from blowing up on missing `shields` while keeping the original test focus (tools & streaming flags) intact. No issues from this change.

Also applies to: 73-74

231-263: LGTM: streaming retrieve_response shields plumbing is covered.

These tests correctly:

- Mock `client.shields.list` to return identifiers.
- Assert that `extra_body.guardrails` is present when shields exist and absent when not.
- Keep the rest of the behavior (vector stores, tools, system prompt) mocked but unconstrained, which is appropriate for this focused check.
No changes needed here.
tests/unit/app/endpoints/test_query_v2.py (4)
123-124: Good: existing retrieve_response tests are adapted to shields.

Adding `mock_client.shields.list = mocker.AsyncMock(return_value=[])` in all the existing `retrieve_response` tests makes them compatible with the new shield discovery logic without changing what they assert (tools, usage, attachments, etc.). This is the right, minimal adjustment.

Also applies to: 161-162, 229-230, 276-277, 315-316, 354-355, 384-385

535-577: Non‑streaming shields plumbing test looks solid.

This test:

- Mocks `shields.list` to return identifiers.
- Verifies `summary.llm_response` and `conv_id`.
- Asserts `extra_body.guardrails` matches the shield identifiers.

This directly exercises the new `available_shields` → `extra_body["guardrails"]` flow in `query_v2.retrieve_response`. No changes needed.

579-616: Correctly covers “no shields → no extra_body” behavior.

The test for the empty-shields case ensures:

- `client.shields.list` returns `[]`.
- `retrieve_response` still returns the expected summary and conversation id.
- `extra_body` is not present in the responses.create kwargs.

This matches the intended behavior and protects against regressions where we might start sending an empty guardrails list.

663-705: “No violation” test correctly guards against false positives.

Here you explicitly set `output_item.refusal = None` and assert that `validation_metric.inc` is never called. That’s important because with bare `Mock` instances, accessing an unset attribute would otherwise yield a truthy `Mock`, which would incorrectly look like a violation (see the short illustration below). This test ensures the “no refusal” path stays clean.
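For illustration, a tiny standalone example of the `Mock` truthiness pitfall the comment refers to (plain `unittest.mock` here just to show the behavior; the tests themselves use pytest-mock):

```python
from unittest.mock import Mock

item = Mock()
# An attribute that was never configured still returns a (truthy) child Mock,
# so a naive `if item.refusal:` check would treat it as a violation.
assert bool(item.refusal) is True

# Explicitly setting it to None is what makes the "no refusal" path testable.
item.refusal = None
assert not item.refusal
```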
src/app/endpoints/query_v2.py (2)

325-331: Docstring update matches new shields behavior.

The extended description (“configures system prompts, shields, and toolgroups…”) accurately reflects the new logic in `retrieve_response`. No issues here.
345-350: Shields discovery and guardrails wiring look correct.

- `available_shields = [shield.identifier for shield in await client.shields.list()]` keeps the interface generic and only passes identifiers downstream.
- Logging both “no shields” and the list of shields should help debugging.
- Adding `extra_body = {"guardrails": available_shields}` only when the list is non-empty aligns with the tests and avoids sending a meaningless empty guardrails list.

All of this is consistent and matches the non-streaming tests.
Also applies to: 388-391
src/app/endpoints/streaming_query_v2.py (2)
35-35: metrics import is appropriate and scoped to new shield-violation tracking.

`metrics` is only used for `llm_calls_validation_errors_total` in this file, so this import is justified and doesn’t introduce unused-symbol noise.

364-368: Streaming retrieve_response shields integration looks correct.

- The docstring now mentions shields, system prompt, and tools, reflecting what the function actually configures.
- `available_shields` is derived from `await client.shields.list()`, with useful logging for both empty and non-empty cases.
- `extra_body = {"guardrails": available_shields}` is only attached when shields exist, matching the tests and the non-streaming implementation.

This keeps the streaming path aligned with the REST v2 query behavior.
Also applies to: 381-386, 423-426
```python
@pytest.mark.asyncio
async def test_streaming_response_detects_shield_violation(mocker, dummy_request):
    """Test that shield violations in streaming responses are detected and metrics incremented."""
    # Skip real config checks
    mocker.patch("app.endpoints.streaming_query.check_configuration_loaded")

    # Model selection plumbing
    mock_client = mocker.Mock()
    mock_client.models.list = mocker.AsyncMock(return_value=[mocker.Mock()])
    mocker.patch(
        "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client
    )
    mocker.patch(
        "app.endpoints.streaming_query.evaluate_model_hints",
        return_value=(None, None),
    )
    mocker.patch(
        "app.endpoints.streaming_query.select_model_and_provider_id",
        return_value=("llama/m", "m", "p"),
    )

    # SSE helpers
    mocker.patch(
        "app.endpoints.streaming_query_v2.stream_start_event",
        lambda conv_id: f"START:{conv_id}\n",
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.format_stream_data",
        lambda obj: f"EV:{obj['event']}:{obj['data'].get('token','')}\n",
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.stream_end_event",
        lambda _m, _s, _t, _media: "END\n",
    )

    # Conversation persistence and transcripts disabled
    mocker.patch(
        "app.endpoints.streaming_query_v2.persist_user_conversation_details",
        return_value=None,
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.is_transcripts_enabled", return_value=False
    )

    # Mock database and topic summary
    mock_session = mocker.Mock()
    mock_session.query.return_value.filter_by.return_value.first.return_value = (
        mocker.Mock()
    )
    mock_context_manager = mocker.Mock()
    mock_context_manager.__enter__ = mocker.Mock(return_value=mock_session)
    mock_context_manager.__exit__ = mocker.Mock(return_value=None)
    mocker.patch(
        "app.endpoints.streaming_query_v2.get_session",
        return_value=mock_context_manager,
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.get_topic_summary",
        mocker.AsyncMock(return_value=""),
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.store_conversation_into_cache",
        return_value=None,
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.create_referenced_documents_with_metadata",
        return_value=[],
    )

    # Mock the validation error metric
    validation_metric = mocker.patch("metrics.llm_calls_validation_errors_total")

    # Build a fake async stream with shield violation
    async def fake_stream_with_violation():
        yield SimpleNamespace(
            type="response.created", response=SimpleNamespace(id="conv-violation")
        )
        yield SimpleNamespace(type="response.output_text.delta", delta="I cannot ")
        yield SimpleNamespace(type="response.output_text.done", text="I cannot help")
        # Response completed with refusal in output
        violation_item = SimpleNamespace(
            type="message",
            role="assistant",
            refusal="Content violates safety policy",
        )
        response_with_violation = SimpleNamespace(
            id="conv-violation", output=[violation_item]
        )
        yield SimpleNamespace(
            type="response.completed", response=response_with_violation
        )

    mocker.patch(
        "app.endpoints.streaming_query_v2.retrieve_response",
        return_value=(fake_stream_with_violation(), ""),
    )

    mocker.patch("metrics.llm_calls_total")

    resp = await streaming_query_endpoint_handler_v2(
        request=dummy_request,
        query_request=QueryRequest(query="dangerous query"),
        auth=("user123", "", True, "token-abc"),
        mcp_headers={},
    )

    assert isinstance(resp, StreamingResponse)

    # Collect emitted events to trigger the generator
    events: list[str] = []
    async for chunk in resp.body_iterator:
        s = chunk.decode() if isinstance(chunk, (bytes, bytearray)) else str(chunk)
        events.append(s)

    # Verify that the validation error metric was incremented
    validation_metric.inc.assert_called_once()
```
Fix invalid patch target and stub cleanup dependencies at the right module.
mocker.patch("app.endpoints.streaming_query_v2.get_session", ...) fails because get_session is not defined in streaming_query_v2, matching the pipeline error, and even if it existed, it wouldn’t affect cleanup_after_streaming (which lives in utils.endpoints and uses its own module globals).
For this test to both pass and avoid touching real DB/cache, you should patch the dependencies where cleanup_after_streaming actually resolves them (likely utils.endpoints), and drop the invalid patch on streaming_query_v2.get_session.
For example:
```diff
-    mocker.patch(
-        "app.endpoints.streaming_query_v2.get_session",
-        return_value=mock_context_manager,
-    )
-    mocker.patch(
-        "app.endpoints.streaming_query_v2.store_conversation_into_cache",
-        return_value=None,
-    )
-    mocker.patch(
-        "app.endpoints.streaming_query_v2.create_referenced_documents_with_metadata",
-        return_value=[],
-    )
+    mocker.patch(
+        "utils.endpoints.get_session",
+        return_value=mock_context_manager,
+    )
+    mocker.patch(
+        "utils.endpoints.store_conversation_into_cache",
+        return_value=None,
+    )
+    mocker.patch(
+        "utils.endpoints.create_referenced_documents_with_metadata",
+        return_value=[],
+    )
```

This resolves the AttributeError and ensures the test stubs the actual call sites used by cleanup_after_streaming.
🤖 Prompt for AI Agents
In tests/unit/app/endpoints/test_streaming_query_v2.py around lines 293 to 409,
the test patches get_session on app.endpoints.streaming_query_v2 which does not
exist and cleanup_after_streaming resolves its dependencies from
utils.endpoints; remove the mocker.patch targeting
"app.endpoints.streaming_query_v2.get_session" and instead patch the real
call-sites in utils.endpoints (e.g., "utils.endpoints.get_session",
"utils.endpoints.get_topic_summary",
"utils.endpoints.store_conversation_into_cache",
"utils.endpoints.create_referenced_documents_with_metadata" or whichever
functions cleanup_after_streaming imports) so the cleanup logic uses the test
stubs and no real DB/cache is touched.
```python
@pytest.mark.asyncio
async def test_streaming_response_no_shield_violation(mocker, dummy_request):
    """Test that no metric is incremented when there's no shield violation in streaming."""
    # Skip real config checks
    mocker.patch("app.endpoints.streaming_query.check_configuration_loaded")

    # Model selection plumbing
    mock_client = mocker.Mock()
    mock_client.models.list = mocker.AsyncMock(return_value=[mocker.Mock()])
    mocker.patch(
        "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client
    )
    mocker.patch(
        "app.endpoints.streaming_query.evaluate_model_hints",
        return_value=(None, None),
    )
    mocker.patch(
        "app.endpoints.streaming_query.select_model_and_provider_id",
        return_value=("llama/m", "m", "p"),
    )

    # SSE helpers
    mocker.patch(
        "app.endpoints.streaming_query_v2.stream_start_event",
        lambda conv_id: f"START:{conv_id}\n",
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.format_stream_data",
        lambda obj: f"EV:{obj['event']}:{obj['data'].get('token','')}\n",
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.stream_end_event",
        lambda _m, _s, _t, _media: "END\n",
    )

    # Conversation persistence and transcripts disabled
    mocker.patch(
        "app.endpoints.streaming_query_v2.persist_user_conversation_details",
        return_value=None,
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.is_transcripts_enabled", return_value=False
    )

    # Mock database and topic summary
    mock_session = mocker.Mock()
    mock_session.query.return_value.filter_by.return_value.first.return_value = (
        mocker.Mock()
    )
    mock_context_manager = mocker.Mock()
    mock_context_manager.__enter__ = mocker.Mock(return_value=mock_session)
    mock_context_manager.__exit__ = mocker.Mock(return_value=None)
    mocker.patch(
        "app.endpoints.streaming_query_v2.get_session",
        return_value=mock_context_manager,
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.get_topic_summary",
        mocker.AsyncMock(return_value=""),
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.store_conversation_into_cache",
        return_value=None,
    )
    mocker.patch(
        "app.endpoints.streaming_query_v2.create_referenced_documents_with_metadata",
        return_value=[],
    )

    # Mock the validation error metric
    validation_metric = mocker.patch("metrics.llm_calls_validation_errors_total")

    # Build a fake async stream without violation
    async def fake_stream_without_violation():
        yield SimpleNamespace(
            type="response.created", response=SimpleNamespace(id="conv-safe")
        )
        yield SimpleNamespace(type="response.output_text.delta", delta="Safe ")
        yield SimpleNamespace(type="response.output_text.done", text="Safe response")
        # Response completed without refusal
        safe_item = SimpleNamespace(
            type="message",
            role="assistant",
            refusal=None,  # No violation
        )
        response_safe = SimpleNamespace(id="conv-safe", output=[safe_item])
        yield SimpleNamespace(type="response.completed", response=response_safe)

    mocker.patch(
        "app.endpoints.streaming_query_v2.retrieve_response",
        return_value=(fake_stream_without_violation(), ""),
    )

    mocker.patch("metrics.llm_calls_total")

    resp = await streaming_query_endpoint_handler_v2(
        request=dummy_request,
        query_request=QueryRequest(query="safe query"),
        auth=("user123", "", True, "token-abc"),
        mcp_headers={},
    )

    assert isinstance(resp, StreamingResponse)

    # Collect emitted events to trigger the generator
    events: list[str] = []
    async for chunk in resp.body_iterator:
        s = chunk.decode() if isinstance(chunk, (bytes, bytearray)) else str(chunk)
        events.append(s)

    # Verify that the validation error metric was NOT incremented
    validation_metric.inc.assert_not_called()
```
Apply the same patch-target fix in the “no shield violation” streaming test.
The second streaming shield test mirrors the first one and repeats the same incorrect patch targets (app.endpoints.streaming_query_v2.get_session, store_conversation_into_cache, create_referenced_documents_with_metadata). Update them to patch utils.endpoints.* instead, just like suggested above, so both tests stub the real dependencies and avoid the same AttributeError.
🤖 Prompt for AI Agents
In tests/unit/app/endpoints/test_streaming_query_v2.py around lines 411 to 522,
the test currently patches app.endpoints.streaming_query_v2.get_session,
app.endpoints.streaming_query_v2.store_conversation_into_cache and
app.endpoints.streaming_query_v2.create_referenced_documents_with_metadata which
is the wrong target and causes AttributeError; change those patch targets to the
utils.endpoints module counterparts (utils.endpoints.get_session,
utils.endpoints.store_conversation_into_cache,
utils.endpoints.create_referenced_documents_with_metadata) keeping the same
return values/AsyncMock usage so the test stubs the real dependencies correctly.
It includes both streaming and non-streaming support, by leveraging the refusal field on the response.
b669b32 to fe1d7b6 — Compare
Actionable comments posted: 0
🧹 Nitpick comments (5)
tests/unit/app/endpoints/test_query_v2.py (1)
123-124: Consider a pytest fixture for the shields mock to reduce duplication.

Multiple tests now mock `shields.list` identically. A shared fixture could eliminate this repetition and make future updates easier.

Example:

```python
@pytest.fixture
def mock_empty_shields(mocker: MockerFixture) -> None:
    """Mock shields.list to return empty list."""
    return mocker.AsyncMock(return_value=[])
```

Then update tests to use it:

```diff
-    # Mock shields.list
-    mock_client.shields.list = mocker.AsyncMock(return_value=[])
+    mock_client.shields.list = mock_empty_shields
```

Also applies to: 161-162, 229-230, 276-277, 315-316, 354-355, 384-385
tests/unit/app/endpoints/test_streaming_query_v2.py (1)
41-42: Consider a pytest fixture for the shields mock (applies here too).

Same as in `test_query_v2.py`, a shared fixture could reduce duplication across streaming tests.

Also applies to: 73-74
src/app/endpoints/query_v2.py (2)
345-350: Shield discovery logic is duplicated in streaming_query_v2.py.

The shield discovery (lines 345-350) and propagation (lines 388-390) logic is identical in `streaming_query_v2.py` (lines 381-386, 423-425). Consider extracting this into a shared utility function to maintain consistency and reduce duplication.

Example:

```python
# In utils/shields.py or similar
async def get_available_shields(client: AsyncLlamaStackClient) -> list[str]:
    """Discover and return available shield identifiers."""
    shields = [shield.identifier for shield in await client.shields.list()]
    if not shields:
        logger.info("No available shields. Disabling safety")
    else:
        logger.info("Available shields: %s", shields)
    return shields
```

Then use it in both files:

```python
available_shields = await get_available_shields(client)
# ...
if available_shields:
    create_kwargs["extra_body"] = {"guardrails": available_shields}
```

Also applies to: 388-390
417-424: Shield violation detection is duplicated in streaming_query_v2.py.

The violation detection logic (lines 417-424) is nearly identical in `streaming_query_v2.py` (lines 248-257). Consider extracting this into a shared utility function alongside the shield discovery refactor.

Example:

```python
# In utils/shields.py
def detect_shield_violations(output_items: list[Any]) -> bool:
    """Check output items for shield violations and update metrics."""
    for output_item in output_items:
        item_type = getattr(output_item, "type", None)
        if item_type == "message":
            refusal = getattr(output_item, "refusal", None)
            if refusal:
                metrics.llm_calls_validation_errors_total.inc()
                logger.warning("Shield violation detected: %s", refusal)
                return True
    return False
```

src/app/endpoints/streaming_query_v2.py (1)
381-386: Shield discovery and propagation duplicated from query_v2.py.

This code is identical to `query_v2.py` lines 345-350 and 388-390. Please see the refactoring suggestion in the `query_v2.py` review to extract this into a shared utility.

Also applies to: 423-425
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- src/app/endpoints/query_v2.py (4 hunks)
- src/app/endpoints/streaming_query_v2.py (5 hunks)
- tests/unit/app/endpoints/test_query_v2.py (8 hunks)
- tests/unit/app/endpoints/test_streaming_query_v2.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/unit/app/endpoints/test_streaming_query_v2.py (3)
- tests/unit/app/endpoints/test_query_v2.py (2)
  - test_retrieve_response_with_shields_available (535-576)
  - test_retrieve_response_with_no_shields_available (580-617)
- src/models/requests.py (1)
  - QueryRequest (73-225)
- src/app/endpoints/streaming_query_v2.py (2)
  - retrieve_response (351-432)
  - streaming_query_endpoint_handler_v2 (321-348)

tests/unit/app/endpoints/test_query_v2.py (2)
- tests/unit/app/endpoints/test_streaming_query_v2.py (2)
  - test_retrieve_response_with_shields_available (232-262)
  - test_retrieve_response_with_no_shields_available (266-292)
- src/app/endpoints/query_v2.py (1)
  - retrieve_response (309-449)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: Konflux kflux-prd-rh02 / lightspeed-stack-on-pull-request
- GitHub Check: build-pr
- GitHub Check: e2e_tests (azure)
- GitHub Check: e2e_tests (ci)
🔇 Additional comments (6)
tests/unit/app/endpoints/test_query_v2.py (1)
534-711: Well-structured test coverage for shields integration.

The four new tests comprehensively verify shield behavior: availability detection, guardrail propagation, violation detection, and metric updates. The test structure is clear and follows established patterns.

tests/unit/app/endpoints/test_streaming_query_v2.py (1)

231-472: Excellent streaming shield test coverage.

The new tests verify shields integration at both the retrieve_response level (guardrail propagation) and the full streaming endpoint level (violation detection in SSE flow). The mock streaming chunks with refusal data are properly structured.

src/app/endpoints/query_v2.py (1)

325-330: Good documentation update.

The docstring now accurately reflects that shields are configured alongside system prompts and toolgroups.

src/app/endpoints/streaming_query_v2.py (3)

35-35: Metrics import is correct.

The import is necessary for incrementing `llm_calls_validation_errors_total` when shield violations are detected.

248-257: Shield violation detection is correctly placed in streaming flow.

The violation check happens after `response.completed`, using the captured `latest_response_object`. This ensures all output items are available for inspection. However, note the duplication concern flagged in `query_v2.py`.

364-368: Docstring accurately reflects shield configuration.

The updated docstring correctly describes that shields are configured alongside system prompts and tools.
Actionable comments posted: 0
🧹 Nitpick comments (3)
src/utils/shields.py (2)
13-28: Make shield discovery more robust and avoid failing the request on list() errors.

Right now any failure in `await client.shields.list()` will bubble up and fail the whole request, even though shields are optional. It also assumes that every returned shield has a non‑empty `identifier` attribute.

To make this path safer and less brittle, consider:

- Catching exceptions from `client.shields.list()` and falling back to `[]` while logging a warning.
- Normalizing identifiers via `getattr` and filtering out `None`/empty ones.
- Tweaking the log message to reflect that we’re “running without guardrails” rather than “disabling safety,” since this function doesn’t actually toggle anything.

For example:

```diff
-    available_shields = [shield.identifier for shield in await client.shields.list()]
-    if not available_shields:
-        logger.info("No available shields. Disabling safety")
-    else:
-        logger.info("Available shields: %s", available_shields)
-    return available_shields
+    try:
+        shields = await client.shields.list()
+    except Exception as exc:  # pylint: disable=broad-exception-caught
+        logger.warning(
+            "Failed to list shields from Llama Stack, proceeding without guardrails: %s",
+            exc,
+        )
+        return []
+
+    available_shields = [
+        getattr(shield, "identifier", None) for shield in shields
+    ]
+    available_shields = [s for s in available_shields if s]
+
+    if not available_shields:
+        logger.info("No available shields discovered; proceeding without guardrails")
+    else:
+        logger.info("Available shields: %s", available_shields)
+    return available_shields
```

This keeps shields best‑effort and avoids a hard dependency on the exact return type shape.
31-53: Clarify what counts as a “shield violation” and consider expanding detection coverage.

The implementation is simple and effective for the specific case of a `message` output item with a top‑level, non‑empty `refusal` attribute, and the metric increment + warning log make sense.

A couple of follow‑ups to consider:

- If the Responses API ever encodes refusals inside `content` parts (e.g., a part with a `refusal` field, similar to your `_extract_text_from_response_output_item` handling), this helper will miss them. You may want to either:
  - Explicitly document that you only consider top‑level `output_item.refusal`, or
  - Extend detection to mirror the traversal you already do in `_extract_text_from_response_output_item` (a rough sketch follows this comment).
- The function returns `bool`, but the current call site ignores the return value; if you don’t plan to branch on this, you could either drop the return value (and treat this as a pure side‑effect helper) or have the caller log/annotate based on it for observability.

No blocking issues here, but tightening the definition and coverage now will avoid surprises when the response schema evolves.
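A rough sketch of what that extended traversal could look like. It assumes output items may expose a `content` list whose parts carry a `refusal` attribute (mirroring the shape handled by `_extract_text_from_response_output_item`), and it reuses the module-level `metrics`, `logger`, and `Any` already present in `src/utils/shields.py`; verify the real response schema before adopting it.

```python
def detect_shield_violations(output_items: list[Any]) -> bool:
    """Detect refusals on message items or their content parts (sketch only)."""
    violated = False
    for output_item in output_items:
        if getattr(output_item, "type", None) != "message":
            continue
        # Top-level refusal on the message item (current behavior).
        refusals = [getattr(output_item, "refusal", None)]
        # Hypothetical: refusals nested inside content parts.
        for part in getattr(output_item, "content", []) or []:
            refusals.append(getattr(part, "refusal", None))
        for refusal in refusals:
            if refusal:
                metrics.llm_calls_validation_errors_total.inc()
                logger.warning("Shield violation detected: %s", refusal)
                violated = True
    return violated
```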
src/app/endpoints/query_v2.py (1)
346-348: Shield integration is correct at a high level; consider resilience, caching, and return‑value use.

The overall flow makes sense:

- Discover shields once via `available_shields = await get_available_shields(client)`.
- If any are present, inject them into `create_kwargs["extra_body"] = {"guardrails": available_shields}` so they apply to the Responses call.
- After processing all output items, run `detect_shield_violations(response.output)` to bump the validation‑error metric when a refusal is present.

A few refinements worth considering:

1. Don’t let shield listing failures break the request.
   As `get_available_shields` is called on every request, any transient failure in `client.shields.list()` will currently abort the whole `/query` call. If shields are meant to be best‑effort, it’s safer to let `get_available_shields` swallow errors and return `[]` (see suggested refactor in `src/utils/shields.py`) so this endpoint remains functional even when shield discovery is flaky.

2. Avoid repeated remote shield discovery per request.
   You’re now doing a round‑trip to Llama Stack shields on every call to `retrieve_response`. If the set of shields is relatively static per process/config, consider caching them at process level (or with a short TTL) and reusing them instead of calling `.list()` every time. This will reduce latency and load on Llama Stack (a small caching sketch follows this comment).

3. Validate extra_body / guardrails shape against the client API.
   The `extra_body={"guardrails": available_shields}` convention relies on the current `AsyncLlamaStackClient.responses.create` semantics. If the guardrails field ever changes name/shape upstream, this will fail at runtime. It’s worth double‑checking this against the version of the Llama Stack client you’re targeting.

4. Use or drop the boolean return from `detect_shield_violations`.
   Since the return value is currently ignored, either:
   - Capture it and, for example, log a high‑level “shield violation occurred” log with context (conversation_id, model), or
   - Change `detect_shield_violations` to return `None` if you intend it to be purely side‑effectful. This would make the API intent clearer.

None of these are blockers, but addressing (1) and (2) in particular will make this path more robust in production.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/app/endpoints/query_v2.py (5 hunks)
- src/app/endpoints/streaming_query_v2.py (5 hunks)
- src/utils/shields.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/app/endpoints/streaming_query_v2.py
🧰 Additional context used
🧬 Code graph analysis (1)
src/app/endpoints/query_v2.py (1)
src/utils/shields.py (2)
- detect_shield_violations (31-54)
- get_available_shields (13-28)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: build-pr
- GitHub Check: e2e_tests (ci)
- GitHub Check: e2e_tests (azure)
- GitHub Check: Konflux kflux-prd-rh02 / lightspeed-stack-on-pull-request
🔇 Additional comments (3)
src/utils/shields.py (1)
1-11: Module setup looks straightforward.

Docstring, imports, and logger initialization are consistent with the rest of the codebase style; no issues from a structure standpoint.

src/app/endpoints/query_v2.py (2)

37-37: Shields utilities import is appropriately scoped.

Pulling `detect_shield_violations` and `get_available_shields` into this module keeps the response logic focused and avoids leaking shield details into unrelated layers.

326-332: Docstring accurately reflects the new shields behavior.

The expanded description calling out shields alongside system prompts and toolgroups matches the actual implementation below (guardrails in `extra_body` and violation metrics), and keeps the function contract clear.
/lgtm
tisnik left a comment
LGTM
Description
Extend the Responses API support (v2 endpoints) by also adding the option to use shields.
Note there is a limitation in LlamaStack where the same shields must be used for input and output.
Type of change
Related Tickets & Documents
Checklist before requesting a review
Testing
Summary by CodeRabbit
New Features
Tests