diff --git a/prompts/default.yaml b/prompts/default.yaml
index 1b20ca00..aa70cda0 100644
--- a/prompts/default.yaml
+++ b/prompts/default.yaml
@@ -50,3 +50,7 @@ red_team: "You are a red team member conducting a security assessment. Identify
 
 # BlueTeam prompts
 blue_team: "You are a blue team member conducting a security assessment. Identify security controls, misconfigurations, and potential vulnerabilities."
+
+# Per client prompts
+client_prompts:
+  kodu: "If malicious packages or leaked secrets are found, please end the task, sending the problems found embedded in tags"
diff --git a/src/codegate/pipeline/factory.py b/src/codegate/pipeline/factory.py
index 0fdd66c4..3d0bb326 100644
--- a/src/codegate/pipeline/factory.py
+++ b/src/codegate/pipeline/factory.py
@@ -29,7 +29,9 @@ def create_input_pipeline(self, client_type: ClientType) -> SequentialPipelinePr
             CodegateSecrets(),
             CodegateCli(),
             CodegateContextRetriever(),
-            SystemPrompt(Config.get_config().prompts.default_chat),
+            SystemPrompt(
+                Config.get_config().prompts.default_chat, Config.get_config().prompts.client_prompts
+            ),
         ]
         return SequentialPipelineProcessor(
             input_steps,
diff --git a/src/codegate/pipeline/system_prompt/codegate.py b/src/codegate/pipeline/system_prompt/codegate.py
index 76bcf9d1..0dbf39a8 100644
--- a/src/codegate/pipeline/system_prompt/codegate.py
+++ b/src/codegate/pipeline/system_prompt/codegate.py
@@ -1,5 +1,6 @@
 from typing import Optional
 
 from litellm import ChatCompletionRequest, ChatCompletionSystemMessage
 
+from codegate.clients.clients import ClientType
 from codegate.pipeline.base import (
@@ -16,8 +17,9 @@ class SystemPrompt(PipelineStep):
     the word "codegate" in the user message.
     """
 
-    def __init__(self, system_prompt: str):
+    def __init__(self, system_prompt: str, client_prompts: Optional[dict[str, str]] = None):
         self.codegate_system_prompt = system_prompt
+        self.client_prompts = client_prompts or {}  # maps ClientType.value -> extra prompt
 
     @property
     def name(self) -> str:
@@ -36,6 +38,7 @@ async def _get_workspace_custom_instructions(self) -> str:
 
     async def _construct_system_prompt(
         self,
+        client: ClientType,
         wrksp_custom_instr: str,
         req_sys_prompt: Optional[str],
         should_add_codegate_sys_prompt: bool,
@@ -59,6 +62,10 @@ def _start_or_append(existing_prompt: str, new_prompt: str) -> str:
         if req_sys_prompt and "codegate" not in req_sys_prompt.lower():
             system_prompt = _start_or_append(system_prompt, req_sys_prompt)
 
+        # Append the prompt configured for this client type, if one exists
+        if client and client.value in self.client_prompts:
+            system_prompt = _start_or_append(system_prompt, self.client_prompts[client.value])
+
         return system_prompt
 
     async def _should_add_codegate_system_prompt(self, context: PipelineContext) -> bool:
@@ -92,7 +99,10 @@ async def process(
             req_sys_prompt = request_system_message.get("content")
 
         system_prompt = await self._construct_system_prompt(
-            wrksp_custom_instructions, req_sys_prompt, should_add_codegate_sys_prompt
+            context.client,
+            wrksp_custom_instructions,
+            req_sys_prompt,
+            should_add_codegate_sys_prompt,
         )
         context.add_alert(self.name, trigger_string=system_prompt)
         if not request_system_message:
diff --git a/src/codegate/prompts.py b/src/codegate/prompts.py
index 63405a08..6629382c 100644
--- a/src/codegate/prompts.py
+++ b/src/codegate/prompts.py
@@ -41,11 +41,19 @@ def from_file(cls, prompt_path: Union[str, Path]) -> "PromptConfig":
             if not isinstance(prompt_data, dict):
                 raise ConfigurationError("Prompts file must contain a YAML dictionary")
 
-            # Validate all values are strings
-            for key, value in prompt_data.items():
-                if not isinstance(value, str):
-                    raise ConfigurationError(f"Prompt '{key}' must be a string, got {type(value)}")
-
+            def validate_prompts(data, parent_key=""):
+                """Recursively check that every non-dict value under data is a string."""
+                for key, value in data.items():
+                    full_key = f"{parent_key}.{key}" if parent_key else key
+                    if isinstance(value, dict):
+                        validate_prompts(value, full_key)  # Recurse into nested dictionaries
+                    elif not isinstance(value, str):
+                        raise ConfigurationError(
+                            f"Prompt '{full_key}' must be a string, got {type(value)}"
+                        )
+
+            # Validate the entire structure
+            validate_prompts(prompt_data)
             return cls(prompts=prompt_data)
         except yaml.YAMLError as e:
             raise ConfigurationError(f"Failed to parse prompts file: {e}")
diff --git a/tests/pipeline/system_prompt/test_system_prompt.py b/tests/pipeline/system_prompt/test_system_prompt.py
index 6ea36a93..c9d1937d 100644
--- a/tests/pipeline/system_prompt/test_system_prompt.py
+++ b/tests/pipeline/system_prompt/test_system_prompt.py
@@ -13,7 +13,7 @@ def test_init_with_system_message(self):
         Test initialization with a system message
         """
         test_message = "Test system prompt"
-        step = SystemPrompt(system_prompt=test_message)
+        step = SystemPrompt(system_prompt=test_message, client_prompts={})
         assert step.codegate_system_prompt == test_message
 
     @pytest.mark.asyncio
@@ -28,7 +28,7 @@ async def test_process_system_prompt_insertion(self):
 
         # Create system prompt step
         system_prompt = "Security analysis system prompt"
-        step = SystemPrompt(system_prompt=system_prompt)
+        step = SystemPrompt(system_prompt=system_prompt, client_prompts={})
         step._get_workspace_custom_instructions = AsyncMock(return_value="")
 
         # Mock the get_last_user_message method
@@ -62,7 +62,7 @@ async def test_process_system_prompt_update(self):
 
         # Create system prompt step
         system_prompt = "Security analysis system prompt"
-        step = SystemPrompt(system_prompt=system_prompt)
+        step = SystemPrompt(system_prompt=system_prompt, client_prompts={})
         step._get_workspace_custom_instructions = AsyncMock(return_value="")
 
         # Mock the get_last_user_message method
@@ -97,7 +97,7 @@ async def test_edge_cases(self, edge_case):
         mock_context = Mock(spec=PipelineContext)
 
         system_prompt = "Security edge case prompt"
-        step = SystemPrompt(system_prompt=system_prompt)
+        step = SystemPrompt(system_prompt=system_prompt, client_prompts={})
         step._get_workspace_custom_instructions = AsyncMock(return_value="")
 
         # Mock get_last_user_message to return None