
Commit 0cb49fc

Merge pull request #248 from lightspeed-core/TamiTakamiya/support-output-shields-in-agents
Add support output_shields in agents
2 parents 1c873fb + 52034bf commit 0cb49fc

5 files changed: +249 -37 lines changed


README.md

Lines changed: 10 additions & 0 deletions
@@ -154,6 +154,16 @@ customization:
   disable_query_system_prompt: true
 ```
 
+## Safety Shields
+
+A single Llama Stack configuration file can include multiple safety shields, which are utilized in agent
+configurations to monitor input and/or output streams. LCS uses the following naming convention to specify how each safety shield is
+utilized:
+
+1. If the `shield_id` starts with `input_`, it will be used for input only.
+1. If the `shield_id` starts with `output_`, it will be used for output only.
+1. If the `shield_id` starts with `inout_`, it will be used both for input and output.
+1. Otherwise, it will be used for input only.
 
 # Usage

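The convention above only looks at the `shield_id` prefix. A minimal illustrative sketch of how that classification plays out, using hypothetical shield identifiers (the helper names here are made up for illustration; the actual helpers added by this PR are `is_input_shield` / `is_output_shield` in `src/app/endpoints/query.py` below):

```python
# Illustrative sketch of the prefix convention; shield IDs are hypothetical.
def uses_input(shield_id: str) -> bool:
    # input_*, inout_*, and unprefixed shields monitor the input stream
    return shield_id.startswith(("input_", "inout_")) or not shield_id.startswith("output_")


def uses_output(shield_id: str) -> bool:
    # output_* and inout_* shields monitor the output stream
    return shield_id.startswith(("output_", "inout_"))


for sid in ["input_llama_guard", "output_llama_guard", "inout_llama_guard", "llama_guard"]:
    print(f"{sid}: input={uses_input(sid)}, output={uses_output(sid)}")
# input_llama_guard: input=True, output=False
# output_llama_guard: input=False, output=True
# inout_llama_guard: input=True, output=True
# llama_guard: input=True, output=False
```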
src/app/endpoints/query.py

Lines changed: 34 additions & 9 deletions
@@ -12,7 +12,7 @@
 from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client import APIConnectionError
 from llama_stack_client import LlamaStackClient  # type: ignore
-from llama_stack_client.types import UserMessage  # type: ignore
+from llama_stack_client.types import UserMessage, Shield  # type: ignore
 from llama_stack_client.types.agents.turn_create_params import (
     ToolgroupAgentToolGroupWithArgs,
     Toolgroup,
@@ -72,11 +72,12 @@ def is_transcripts_enabled() -> bool:
     return not configuration.user_data_collection_configuration.transcripts_disabled
 
 
-def get_agent(
+def get_agent(  # pylint: disable=too-many-arguments,too-many-positional-arguments
     client: LlamaStackClient,
     model_id: str,
     system_prompt: str,
-    available_shields: list[str],
+    available_input_shields: list[str],
+    available_output_shields: list[str],
     conversation_id: str | None,
 ) -> tuple[Agent, str]:
     """Get existing agent or create a new one with session persistence."""
@@ -92,7 +93,8 @@ def get_agent(
         client,
         model=model_id,
         instructions=system_prompt,
-        input_shields=available_shields if available_shields else [],
+        input_shields=available_input_shields if available_input_shields else [],
+        output_shields=available_output_shields if available_output_shields else [],
         tool_parser=GraniteToolParser.get_parser(model_id),
         enable_session_persistence=True,
     )
@@ -202,6 +204,20 @@ def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> s
     return model_id
 
 
+def _is_inout_shield(shield: Shield) -> bool:
+    return shield.identifier.startswith("inout_")
+
+
+def is_output_shield(shield: Shield) -> bool:
+    """Determine if the shield is for monitoring output."""
+    return _is_inout_shield(shield) or shield.identifier.startswith("output_")
+
+
+def is_input_shield(shield: Shield) -> bool:
+    """Determine if the shield is for monitoring input."""
+    return _is_inout_shield(shield) or not is_output_shield(shield)
+
+
 def retrieve_response(
     client: LlamaStackClient,
     model_id: str,
@@ -210,12 +226,20 @@ def retrieve_response(
     mcp_headers: dict[str, dict[str, str]] | None = None,
 ) -> tuple[str, str]:
     """Retrieve response from LLMs and agents."""
-    available_shields = [shield.identifier for shield in client.shields.list()]
-    if not available_shields:
+    available_input_shields = [
+        shield.identifier for shield in filter(is_input_shield, client.shields.list())
+    ]
+    available_output_shields = [
+        shield.identifier for shield in filter(is_output_shield, client.shields.list())
+    ]
+    if not available_input_shields and not available_output_shields:
         logger.info("No available shields. Disabling safety")
     else:
-        logger.info("Available shields found: %s", available_shields)
-
+        logger.info(
+            "Available input shields: %s, output shields: %s",
+            available_input_shields,
+            available_output_shields,
+        )
     # use system prompt from request or default one
     system_prompt = get_system_prompt(query_request, configuration)
     logger.debug("Using system prompt: %s", system_prompt)
@@ -229,7 +253,8 @@ def retrieve_response(
         client,
         model_id,
         system_prompt,
-        available_shields,
+        available_input_shields,
+        available_output_shields,
         query_request.conversation_id,
     )

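The new helpers partition whatever `client.shields.list()` returns into input and output lists before the agent is built. A rough standalone sketch of how the pieces fit together; the `base_url`, model ID, and prompt are placeholders rather than values from this repository, while the `Agent` keyword arguments mirror the call shown in the diff above:

```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent

from app.endpoints.query import is_input_shield, is_output_shield

# Placeholder endpoint; point this at a running Llama Stack instance.
client = LlamaStackClient(base_url="http://localhost:8321")

# Partition registered shields by naming convention.
shields = client.shields.list()
input_shields = [s.identifier for s in shields if is_input_shield(s)]
output_shields = [s.identifier for s in shields if is_output_shield(s)]

agent = Agent(
    client,
    model="example-model",                        # placeholder model ID
    instructions="You are a helpful assistant.",  # placeholder system prompt
    input_shields=input_shields,
    output_shields=output_shields,
    enable_session_persistence=True,
)
```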
src/app/endpoints/streaming_query.py

Lines changed: 32 additions & 10 deletions
@@ -3,6 +3,7 @@
 import json
 import logging
 import re
+from json import JSONDecodeError
 from typing import Any, AsyncIterator
 
 from cachetools import TTLCache  # type: ignore
@@ -29,6 +30,8 @@
 from app.endpoints.conversations import conversation_id_to_agent_id
 from app.endpoints.query import (
     get_rag_toolgroups,
+    is_input_shield,
+    is_output_shield,
     is_transcripts_enabled,
     store_transcript,
     select_model_id,
@@ -43,11 +46,12 @@
 _agent_cache: TTLCache[str, AsyncAgent] = TTLCache(maxsize=1000, ttl=3600)
 
 
-async def get_agent(
+async def get_agent(  # pylint: disable=too-many-arguments,too-many-positional-arguments
     client: AsyncLlamaStackClient,
     model_id: str,
     system_prompt: str,
-    available_shields: list[str],
+    available_input_shields: list[str],
+    available_output_shields: list[str],
     conversation_id: str | None,
 ) -> tuple[AsyncAgent, str]:
     """Get existing agent or create a new one with session persistence."""
@@ -62,7 +66,8 @@ async def get_agent(
         client,  # type: ignore[arg-type]
         model=model_id,
         instructions=system_prompt,
-        input_shields=available_shields if available_shields else [],
+        input_shields=available_input_shields if available_input_shields else [],
+        output_shields=available_output_shields if available_output_shields else [],
         tool_parser=GraniteToolParser.get_parser(model_id),
         enable_session_persistence=True,
     )
@@ -166,8 +171,14 @@ def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> str | N
                 for match in METADATA_PATTERN.findall(
                     text_content_item.text
                 ):
-                    meta = json.loads(match.replace("'", '"'))
-                    metadata_map[meta["document_id"]] = meta
+                    try:
+                        meta = json.loads(match.replace("'", '"'))
+                        metadata_map[meta["document_id"]] = meta
+                    except JSONDecodeError:
+                        logger.debug(
+                            "JSONDecodeError was thrown in processing %s",
+                            match,
+                        )
             if chunk.event.payload.step_details.tool_calls:
                 tool_name = str(
                     chunk.event.payload.step_details.tool_calls[0].tool_name
@@ -268,12 +279,22 @@ async def retrieve_response(
     mcp_headers: dict[str, dict[str, str]] | None = None,
 ) -> tuple[Any, str]:
     """Retrieve response from LLMs and agents."""
-    available_shields = [shield.identifier for shield in await client.shields.list()]
-    if not available_shields:
+    available_input_shields = [
+        shield.identifier
+        for shield in filter(is_input_shield, await client.shields.list())
+    ]
+    available_output_shields = [
+        shield.identifier
+        for shield in filter(is_output_shield, await client.shields.list())
+    ]
+    if not available_input_shields and not available_output_shields:
         logger.info("No available shields. Disabling safety")
     else:
-        logger.info("Available shields found: %s", available_shields)
-
+        logger.info(
+            "Available input shields: %s, output shields: %s",
+            available_input_shields,
+            available_output_shields,
+        )
     # use system prompt from request or default one
     system_prompt = get_system_prompt(query_request, configuration)
     logger.debug("Using system prompt: %s", system_prompt)
@@ -287,7 +308,8 @@ async def retrieve_response(
         client,
         model_id,
         system_prompt,
-        available_shields,
+        available_input_shields,
+        available_output_shields,
         query_request.conversation_id,
     )

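The `try`/`except JSONDecodeError` added in `stream_build_event` tolerates RAG metadata matches that are still not valid JSON after the single-to-double quote swap. A small sketch of the failure mode it guards against; the metadata strings below are made up, and only the parsing logic mirrors the diff:

```python
import json
from json import JSONDecodeError

# Hypothetical metadata matches as they might appear in tool output.
matches = [
    "{'document_id': 'doc-1', 'docs_url': 'https://example.com/doc-1'}",
    "{'document_id': 'doc-2', 'title': \"it's a summary\"}",  # embedded apostrophe
]

metadata_map: dict[str, dict] = {}
for match in matches:
    try:
        meta = json.loads(match.replace("'", '"'))
        metadata_map[meta["document_id"]] = meta
    except JSONDecodeError:
        # The quote swap turns the apostrophe into a stray '"', so parsing fails;
        # the new code logs at debug level and skips instead of breaking the stream.
        print(f"Skipping malformed metadata: {match}")

print(list(metadata_map))  # ['doc-1']
```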
0 commit comments
