 from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client import APIConnectionError
 from llama_stack_client import LlamaStackClient  # type: ignore
-from llama_stack_client.types import UserMessage  # type: ignore
+from llama_stack_client.types import UserMessage, Shield  # type: ignore
 from llama_stack_client.types.agents.turn_create_params import (
     ToolgroupAgentToolGroupWithArgs,
     Toolgroup,
@@ -72,11 +72,12 @@ def is_transcripts_enabled() -> bool:
     return not configuration.user_data_collection_configuration.transcripts_disabled


-def get_agent(
+def get_agent(  # pylint: disable=too-many-arguments,too-many-positional-arguments
     client: LlamaStackClient,
     model_id: str,
     system_prompt: str,
-    available_shields: list[str],
+    available_input_shields: list[str],
+    available_output_shields: list[str],
     conversation_id: str | None,
 ) -> tuple[Agent, str]:
     """Get existing agent or create a new one with session persistence."""
@@ -92,7 +93,8 @@ def get_agent(
         client,
         model=model_id,
         instructions=system_prompt,
-        input_shields=available_shields if available_shields else [],
+        input_shields=available_input_shields if available_input_shields else [],
+        output_shields=available_output_shields if available_output_shields else [],
         tool_parser=GraniteToolParser.get_parser(model_id),
         enable_session_persistence=True,
     )
@@ -202,6 +204,20 @@ def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> s
     return model_id


+def _is_inout_shield(shield: Shield) -> bool:
+    return shield.identifier.startswith("inout_")
+
+
+def is_output_shield(shield: Shield) -> bool:
+    """Determine if the shield is for monitoring output."""
+    return _is_inout_shield(shield) or shield.identifier.startswith("output_")
+
+
+def is_input_shield(shield: Shield) -> bool:
+    """Determine if the shield is for monitoring input."""
+    return _is_inout_shield(shield) or not is_output_shield(shield)
+
+
 def retrieve_response(
     client: LlamaStackClient,
     model_id: str,
@@ -210,12 +226,20 @@ def retrieve_response(
     mcp_headers: dict[str, dict[str, str]] | None = None,
 ) -> tuple[str, str]:
     """Retrieve response from LLMs and agents."""
-    available_shields = [shield.identifier for shield in client.shields.list()]
-    if not available_shields:
+    available_input_shields = [
+        shield.identifier for shield in filter(is_input_shield, client.shields.list())
+    ]
+    available_output_shields = [
+        shield.identifier for shield in filter(is_output_shield, client.shields.list())
+    ]
+    if not available_input_shields and not available_output_shields:
         logger.info("No available shields. Disabling safety")
     else:
-        logger.info("Available shields found: %s", available_shields)
-
+        logger.info(
+            "Available input shields: %s, output shields: %s",
+            available_input_shields,
+            available_output_shields,
+        )
     # use system prompt from request or default one
     system_prompt = get_system_prompt(query_request, configuration)
     logger.debug("Using system prompt: %s", system_prompt)
@@ -229,7 +253,8 @@ def retrieve_response(
         client,
         model_id,
         system_prompt,
-        available_shields,
+        available_input_shields,
+        available_output_shields,
         query_request.conversation_id,


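For reference, a minimal self-contained sketch (not part of the patch) of how the prefix-based classification above partitions shields. The helpers mirror the ones added in this diff, and the SimpleNamespace objects are hypothetical stand-ins for the Shield entries returned by client.shields.list(), of which only the .identifier attribute is used here:

    from types import SimpleNamespace

    def _is_inout_shield(shield) -> bool:
        return shield.identifier.startswith("inout_")

    def is_output_shield(shield) -> bool:
        return _is_inout_shield(shield) or shield.identifier.startswith("output_")

    def is_input_shield(shield) -> bool:
        return _is_inout_shield(shield) or not is_output_shield(shield)

    # Hypothetical shield identifiers, chosen only to illustrate the naming convention.
    shields = [
        SimpleNamespace(identifier="llama_guard"),           # no prefix  -> input only
        SimpleNamespace(identifier="output_toxicity"),       # "output_"  -> output only
        SimpleNamespace(identifier="inout_content_filter"),  # "inout_"   -> input and output
    ]

    print([s.identifier for s in filter(is_input_shield, shields)])
    # ['llama_guard', 'inout_content_filter']
    print([s.identifier for s in filter(is_output_shield, shields)])
    # ['output_toxicity', 'inout_content_filter']

Note that any shield whose identifier carries neither the "output_" nor the "inout_" prefix is treated as an input shield, so existing shield registrations keep their previous (input-side) behavior.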