@@ -120,6 +120,8 @@ async def test_retrieve_response_no_tools_bypasses_tools(mocker: MockerFixture)
120120 mock_vector_stores = mocker .Mock ()
121121 mock_vector_stores .data = []
122122 mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
123+ # Mock shields.list
124+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
123125
124126 # Ensure system prompt resolution does not require real config
125127 mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
@@ -156,6 +158,8 @@ async def test_retrieve_response_builds_rag_and_mcp_tools(
156158 mock_vector_stores = mocker .Mock ()
157159 mock_vector_stores .data = [mocker .Mock (id = "dbA" )]
158160 mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
161+ # Mock shields.list
162+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
159163
160164 mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
161165 mock_cfg = mocker .Mock ()
@@ -222,6 +226,8 @@ async def test_retrieve_response_parses_output_and_tool_calls(
222226 mock_vector_stores = mocker .Mock ()
223227 mock_vector_stores .data = []
224228 mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
229+ # Mock shields.list
230+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
225231
226232 mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
227233 mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
@@ -267,6 +273,8 @@ async def test_retrieve_response_with_usage_info(mocker: MockerFixture) -> None:
267273 mock_vector_stores = mocker .Mock ()
268274 mock_vector_stores .data = []
269275 mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
276+ # Mock shields.list
277+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
270278
271279 mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
272280 mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
@@ -304,6 +312,8 @@ async def test_retrieve_response_with_usage_dict(mocker: MockerFixture) -> None:
304312 mock_vector_stores = mocker .Mock ()
305313 mock_vector_stores .data = []
306314 mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
315+ # Mock shields.list
316+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
307317
308318 mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
309319 mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
@@ -341,6 +351,8 @@ async def test_retrieve_response_with_empty_usage_dict(mocker: MockerFixture) ->
341351 mock_vector_stores = mocker .Mock ()
342352 mock_vector_stores .data = []
343353 mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
354+ # Mock shields.list
355+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
344356
345357 mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
346358 mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
@@ -369,6 +381,8 @@ async def test_retrieve_response_validates_attachments(mocker: MockerFixture) ->
369381 mock_vector_stores = mocker .Mock ()
370382 mock_vector_stores .data = []
371383 mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
384+ # Mock shields.list
385+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
372386
373387 mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
374388 mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
@@ -515,3 +529,183 @@ async def test_query_endpoint_quota_exceeded(
515529 assert isinstance (detail , dict )
516530 assert detail ["response" ] == "Model quota exceeded" # type: ignore
517531 assert "gpt-4-turbo" in detail ["cause" ] # type: ignore
532+
533+
534+ @pytest .mark .asyncio
535+ async def test_retrieve_response_with_shields_available (mocker : MockerFixture ) -> None :
536+ """Test that shields are listed and passed to responses API when available."""
537+ mock_client = mocker .Mock ()
538+
539+ # Mock shields.list to return available shields
540+ shield1 = mocker .Mock ()
541+ shield1 .identifier = "shield-1"
542+ shield2 = mocker .Mock ()
543+ shield2 .identifier = "shield-2"
544+ mock_client .shields .list = mocker .AsyncMock (return_value = [shield1 , shield2 ])
545+
546+ output_item = mocker .Mock ()
547+ output_item .type = "message"
548+ output_item .role = "assistant"
549+ output_item .content = "Safe response"
550+
551+ response_obj = mocker .Mock ()
552+ response_obj .id = "resp-shields"
553+ response_obj .output = [output_item ]
554+ response_obj .usage = None
555+
556+ mock_client .responses .create = mocker .AsyncMock (return_value = response_obj )
557+ mock_vector_stores = mocker .Mock ()
558+ mock_vector_stores .data = []
559+ mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
560+
561+ mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
562+ mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
563+
564+ qr = QueryRequest (query = "hello" )
565+ summary , conv_id , _referenced_docs , _token_usage = await retrieve_response (
566+ mock_client , "model-shields" , qr , token = "tkn" , provider_id = "test-provider"
567+ )
568+
569+ assert conv_id == "resp-shields"
570+ assert summary .llm_response == "Safe response"
571+
572+ # Verify that shields were passed in extra_body
573+ kwargs = mock_client .responses .create .call_args .kwargs
574+ assert "extra_body" in kwargs
575+ assert "guardrails" in kwargs ["extra_body" ]
576+ assert kwargs ["extra_body" ]["guardrails" ] == ["shield-1" , "shield-2" ]
577+
578+
579+ @pytest .mark .asyncio
580+ async def test_retrieve_response_with_no_shields_available (
581+ mocker : MockerFixture ,
582+ ) -> None :
583+ """Test that no extra_body is added when no shields are available."""
584+ mock_client = mocker .Mock ()
585+
586+ # Mock shields.list to return no shields
587+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
588+
589+ output_item = mocker .Mock ()
590+ output_item .type = "message"
591+ output_item .role = "assistant"
592+ output_item .content = "Response without shields"
593+
594+ response_obj = mocker .Mock ()
595+ response_obj .id = "resp-no-shields"
596+ response_obj .output = [output_item ]
597+ response_obj .usage = None
598+
599+ mock_client .responses .create = mocker .AsyncMock (return_value = response_obj )
600+ mock_vector_stores = mocker .Mock ()
601+ mock_vector_stores .data = []
602+ mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
603+
604+ mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
605+ mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
606+
607+ qr = QueryRequest (query = "hello" )
608+ summary , conv_id , _referenced_docs , _token_usage = await retrieve_response (
609+ mock_client , "model-no-shields" , qr , token = "tkn" , provider_id = "test-provider"
610+ )
611+
612+ assert conv_id == "resp-no-shields"
613+ assert summary .llm_response == "Response without shields"
614+
615+ # Verify that no extra_body was added
616+ kwargs = mock_client .responses .create .call_args .kwargs
617+ assert "extra_body" not in kwargs
618+
619+
620+ @pytest .mark .asyncio
621+ async def test_retrieve_response_detects_shield_violation (
622+ mocker : MockerFixture ,
623+ ) -> None :
624+ """Test that shield violations are detected and metrics are incremented."""
625+ mock_client = mocker .Mock ()
626+
627+ # Mock shields.list to return available shields
628+ shield1 = mocker .Mock ()
629+ shield1 .identifier = "safety-shield"
630+ mock_client .shields .list = mocker .AsyncMock (return_value = [shield1 ])
631+
632+ # Create output with shield violation (refusal)
633+ output_item = mocker .Mock ()
634+ output_item .type = "message"
635+ output_item .role = "assistant"
636+ output_item .content = "I cannot help with that request"
637+ output_item .refusal = "Content violates safety policy"
638+
639+ response_obj = mocker .Mock ()
640+ response_obj .id = "resp-violation"
641+ response_obj .output = [output_item ]
642+ response_obj .usage = None
643+
644+ mock_client .responses .create = mocker .AsyncMock (return_value = response_obj )
645+ mock_vector_stores = mocker .Mock ()
646+ mock_vector_stores .data = []
647+ mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
648+
649+ mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
650+ mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
651+
652+ # Mock the validation error metric
653+ validation_metric = mocker .patch ("metrics.llm_calls_validation_errors_total" )
654+
655+ qr = QueryRequest (query = "dangerous query" )
656+ summary , conv_id , _referenced_docs , _token_usage = await retrieve_response (
657+ mock_client , "model-violation" , qr , token = "tkn" , provider_id = "test-provider"
658+ )
659+
660+ assert conv_id == "resp-violation"
661+ assert summary .llm_response == "I cannot help with that request"
662+
663+ # Verify that the validation error metric was incremented
664+ validation_metric .inc .assert_called_once ()
665+
666+
667+ @pytest .mark .asyncio
668+ async def test_retrieve_response_no_violation_with_shields (
669+ mocker : MockerFixture ,
670+ ) -> None :
671+ """Test that no metric is incremented when there's no shield violation."""
672+ mock_client = mocker .Mock ()
673+
674+ # Mock shields.list to return available shields
675+ shield1 = mocker .Mock ()
676+ shield1 .identifier = "safety-shield"
677+ mock_client .shields .list = mocker .AsyncMock (return_value = [shield1 ])
678+
679+ # Create output without shield violation
680+ output_item = mocker .Mock ()
681+ output_item .type = "message"
682+ output_item .role = "assistant"
683+ output_item .content = "Safe response"
684+ output_item .refusal = None # No violation
685+
686+ response_obj = mocker .Mock ()
687+ response_obj .id = "resp-safe"
688+ response_obj .output = [output_item ]
689+ response_obj .usage = None
690+
691+ mock_client .responses .create = mocker .AsyncMock (return_value = response_obj )
692+ mock_vector_stores = mocker .Mock ()
693+ mock_vector_stores .data = []
694+ mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
695+
696+ mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
697+ mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
698+
699+ # Mock the validation error metric
700+ validation_metric = mocker .patch ("metrics.llm_calls_validation_errors_total" )
701+
702+ qr = QueryRequest (query = "safe query" )
703+ summary , conv_id , _referenced_docs , _token_usage = await retrieve_response (
704+ mock_client , "model-safe" , qr , token = "tkn" , provider_id = "test-provider"
705+ )
706+
707+ assert conv_id == "resp-safe"
708+ assert summary .llm_response == "Safe response"
709+
710+ # Verify that the validation error metric was NOT incremented
711+ validation_metric .inc .assert_not_called ()
0 commit comments