@@ -120,6 +120,8 @@ async def test_retrieve_response_no_tools_bypasses_tools(mocker: MockerFixture)
120120 mock_vector_stores = mocker .Mock ()
121121 mock_vector_stores .data = []
122122 mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
123+ # Mock shields.list
124+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
123125
124126 # Ensure system prompt resolution does not require real config
125127 mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
@@ -156,6 +158,8 @@ async def test_retrieve_response_builds_rag_and_mcp_tools(
156158 mock_vector_stores = mocker .Mock ()
157159 mock_vector_stores .data = [mocker .Mock (id = "dbA" )]
158160 mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
161+ # Mock shields.list
162+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
159163
160164 mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
161165 mock_cfg = mocker .Mock ()
@@ -222,6 +226,8 @@ async def test_retrieve_response_parses_output_and_tool_calls(
222226 mock_vector_stores = mocker .Mock ()
223227 mock_vector_stores .data = []
224228 mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
229+ # Mock shields.list
230+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
225231
226232 mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
227233 mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
@@ -267,6 +273,8 @@ async def test_retrieve_response_with_usage_info(mocker: MockerFixture) -> None:
267273 mock_vector_stores = mocker .Mock ()
268274 mock_vector_stores .data = []
269275 mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
276+ # Mock shields.list
277+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
270278
271279 mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
272280 mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
@@ -304,6 +312,8 @@ async def test_retrieve_response_with_usage_dict(mocker: MockerFixture) -> None:
304312 mock_vector_stores = mocker .Mock ()
305313 mock_vector_stores .data = []
306314 mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
315+ # Mock shields.list
316+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
307317
308318 mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
309319 mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
@@ -341,6 +351,8 @@ async def test_retrieve_response_with_empty_usage_dict(mocker: MockerFixture) ->
341351 mock_vector_stores = mocker .Mock ()
342352 mock_vector_stores .data = []
343353 mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
354+ # Mock shields.list
355+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
344356
345357 mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
346358 mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
@@ -369,6 +381,8 @@ async def test_retrieve_response_validates_attachments(mocker: MockerFixture) ->
369381 mock_vector_stores = mocker .Mock ()
370382 mock_vector_stores .data = []
371383 mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
384+ # Mock shields.list
385+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
372386
373387 mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
374388 mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
@@ -515,3 +529,177 @@ async def test_query_endpoint_quota_exceeded(
515529 assert isinstance (detail , dict )
516530 assert detail ["response" ] == "Model quota exceeded" # type: ignore
517531 assert "gpt-4-turbo" in detail ["cause" ] # type: ignore
532+
533+
@pytest.mark.asyncio
async def test_retrieve_response_with_shields_available(mocker: MockerFixture) -> None:
    """Test that shields are listed and passed to responses API when available.

    The mocked ``shields.list`` returns two shields; the test checks that their
    identifiers are forwarded, in order, as ``extra_body["guardrails"]`` of the
    ``responses.create`` call.
    """
    mock_client = mocker.Mock()

    # Mock shields.list to return available shields
    shield1 = mocker.Mock()
    shield1.identifier = "shield-1"
    shield2 = mocker.Mock()
    shield2.identifier = "shield-2"
    mock_client.shields.list = mocker.AsyncMock(return_value=[shield1, shield2])

    output_item = mocker.Mock()
    output_item.type = "message"
    output_item.role = "assistant"
    output_item.content = "Safe response"
    # A bare Mock auto-creates a truthy ``refusal`` attribute. Pin it to None so
    # this safe message cannot be interpreted as a shield refusal (which would
    # presumably bump the real, unpatched validation-error metric).
    output_item.refusal = None

    response_obj = mocker.Mock()
    response_obj.id = "resp-shields"
    response_obj.output = [output_item]
    response_obj.usage = None

    mock_client.responses.create = mocker.AsyncMock(return_value=response_obj)
    mock_vector_stores = mocker.Mock()
    mock_vector_stores.data = []
    mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)

    # Avoid depending on real configuration for prompt / MCP server resolution
    mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
    mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))

    qr = QueryRequest(query="hello")
    summary, conv_id, _referenced_docs, _token_usage = await retrieve_response(
        mock_client, "model-shields", qr, token="tkn", provider_id="test-provider"
    )

    assert conv_id == "resp-shields"
    assert summary.llm_response == "Safe response"

    # Verify that shields were passed in extra_body
    kwargs = mock_client.responses.create.call_args.kwargs
    assert "extra_body" in kwargs
    assert "guardrails" in kwargs["extra_body"]
    assert kwargs["extra_body"]["guardrails"] == ["shield-1", "shield-2"]
577+
578+
@pytest.mark.asyncio
async def test_retrieve_response_with_no_shields_available(
    mocker: MockerFixture,
) -> None:
    """Test that no extra_body is added when no shields are available.

    The mocked ``shields.list`` returns an empty list; the test checks that the
    ``responses.create`` call is made without an ``extra_body`` keyword.
    """
    mock_client = mocker.Mock()

    # Mock shields.list to return no shields
    mock_client.shields.list = mocker.AsyncMock(return_value=[])

    output_item = mocker.Mock()
    output_item.type = "message"
    output_item.role = "assistant"
    output_item.content = "Response without shields"
    # A bare Mock auto-creates a truthy ``refusal`` attribute. Pin it to None so
    # this ordinary message cannot be interpreted as a shield refusal (which
    # would presumably bump the real, unpatched validation-error metric).
    output_item.refusal = None

    response_obj = mocker.Mock()
    response_obj.id = "resp-no-shields"
    response_obj.output = [output_item]
    response_obj.usage = None

    mock_client.responses.create = mocker.AsyncMock(return_value=response_obj)
    mock_vector_stores = mocker.Mock()
    mock_vector_stores.data = []
    mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)

    # Avoid depending on real configuration for prompt / MCP server resolution
    mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
    mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))

    qr = QueryRequest(query="hello")
    summary, conv_id, _referenced_docs, _token_usage = await retrieve_response(
        mock_client, "model-no-shields", qr, token="tkn", provider_id="test-provider"
    )

    assert conv_id == "resp-no-shields"
    assert summary.llm_response == "Response without shields"

    # Verify that no extra_body was added
    kwargs = mock_client.responses.create.call_args.kwargs
    assert "extra_body" not in kwargs
616+
617+
@pytest.mark.asyncio
async def test_retrieve_response_detects_shield_violation(
    mocker: MockerFixture,
) -> None:
    """Test that shield violations are detected and metrics are incremented.

    The mocked response contains a message item with a non-empty ``refusal``;
    the test checks that the patched ``metrics.llm_calls_validation_errors_total``
    counter has ``inc`` called exactly once, while the response text is still
    returned in the summary.
    """
    mock_client = mocker.Mock()

    # Mock shields.list to return available shields
    shield1 = mocker.Mock()
    shield1.identifier = "safety-shield"
    mock_client.shields.list = mocker.AsyncMock(return_value=[shield1])

    # Create output with shield violation (refusal)
    output_item = mocker.Mock()
    output_item.type = "message"
    output_item.role = "assistant"
    output_item.content = "I cannot help with that request"
    output_item.refusal = "Content violates safety policy"

    response_obj = mocker.Mock()
    response_obj.id = "resp-violation"
    response_obj.output = [output_item]
    response_obj.usage = None

    mock_client.responses.create = mocker.AsyncMock(return_value=response_obj)
    mock_vector_stores = mocker.Mock()
    mock_vector_stores.data = []
    mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)

    # Avoid depending on real configuration for prompt / MCP server resolution
    mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
    mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))

    # Mock the validation error metric so calls to it can be inspected
    validation_metric = mocker.patch("metrics.llm_calls_validation_errors_total")

    qr = QueryRequest(query="dangerous query")
    summary, conv_id, _referenced_docs, _token_usage = await retrieve_response(
        mock_client, "model-violation", qr, token="tkn", provider_id="test-provider"
    )

    assert conv_id == "resp-violation"
    assert summary.llm_response == "I cannot help with that request"

    # Verify that the validation error metric was incremented
    validation_metric.inc.assert_called_once()
661+
662+
@pytest.mark.asyncio
async def test_retrieve_response_no_violation_with_shields(
    mocker: MockerFixture,
) -> None:
    """Test that no metric is incremented when there's no shield violation.

    A shield is available, but the mocked message item has ``refusal`` set to
    None; the test checks that ``inc`` is never called on the patched
    ``metrics.llm_calls_validation_errors_total`` counter.
    """
    mock_client = mocker.Mock()

    # Mock shields.list to return available shields
    shield1 = mocker.Mock()
    shield1.identifier = "safety-shield"
    mock_client.shields.list = mocker.AsyncMock(return_value=[shield1])

    # Create output without shield violation
    output_item = mocker.Mock()
    output_item.type = "message"
    output_item.role = "assistant"
    output_item.content = "Safe response"
    # Explicit None: an unset Mock attribute would be a truthy auto-created Mock
    output_item.refusal = None  # No violation

    response_obj = mocker.Mock()
    response_obj.id = "resp-safe"
    response_obj.output = [output_item]
    response_obj.usage = None

    mock_client.responses.create = mocker.AsyncMock(return_value=response_obj)
    mock_vector_stores = mocker.Mock()
    mock_vector_stores.data = []
    mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)

    # Avoid depending on real configuration for prompt / MCP server resolution
    mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
    mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))

    # Mock the validation error metric so calls to it can be inspected
    validation_metric = mocker.patch("metrics.llm_calls_validation_errors_total")

    qr = QueryRequest(query="safe query")
    summary, conv_id, _referenced_docs, _token_usage = await retrieve_response(
        mock_client, "model-safe", qr, token="tkn", provider_id="test-provider"
    )

    assert conv_id == "resp-safe"
    assert summary.llm_response == "Safe response"

    # Verify that the validation error metric was NOT incremented
    validation_metric.inc.assert_not_called()
0 commit comments