diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index f9d4b00b..ad349e56 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -3,7 +3,7 @@ import requests import structlog -from fastapi import APIRouter, Depends, HTTPException, Response +from fastapi import APIRouter, Depends, HTTPException, Query, Response from fastapi.responses import StreamingResponse from fastapi.routing import APIRoute from pydantic import BaseModel, ValidationError @@ -11,8 +11,9 @@ import codegate.muxing.models as mux_models from codegate import __version__ from codegate.api import v1_models, v1_processing +from codegate.config import API_DEFAULT_PAGE_SIZE, API_MAX_PAGE_SIZE from codegate.db.connection import AlreadyExistsError, DbReader -from codegate.db.models import AlertSeverity, WorkspaceWithModel +from codegate.db.models import AlertSeverity, AlertTriggerType, WorkspaceWithModel from codegate.providers import crud as provendcrud from codegate.workspaces import crud @@ -429,7 +430,13 @@ async def get_workspace_alerts_summary(workspace_name: str) -> v1_models.AlertSu tags=["Workspaces"], generate_unique_id_function=uniq_name, ) -async def get_workspace_messages(workspace_name: str) -> List[v1_models.Conversation]: +async def get_workspace_messages( + workspace_name: str, + page: int = Query(1, ge=1), + page_size: int = Query(API_DEFAULT_PAGE_SIZE, ge=1, le=API_MAX_PAGE_SIZE), + filter_by_ids: Optional[List[str]] = Query(None), + filter_by_alert_trigger_types: Optional[List[AlertTriggerType]] = Query(None), +) -> v1_models.PaginatedMessagesResponse: """Get messages for a workspace.""" try: ws = await wscrud.get_workspace_by_name(workspace_name) @@ -439,16 +446,40 @@ async def get_workspace_messages(workspace_name: str) -> List[v1_models.Conversa logger.exception("Error while getting workspace") raise HTTPException(status_code=500, detail="Internal server error") + offset = (page - 1) * page_size + fetched_messages: List[v1_models.Conversation] = [] + try: - prompts_with_output_alerts_usage = ( - await dbreader.get_prompts_with_output_alerts_usage_by_workspace_id( - ws.id, AlertSeverity.CRITICAL.value + while len(fetched_messages) < page_size: + messages_batch = await dbreader.get_messages( + ws.id, + offset, + page_size, + filter_by_ids, + list([AlertSeverity.CRITICAL.value]), # TODO: Configurable severity + filter_by_alert_trigger_types, + ) + if not messages_batch: + break + parsed_conversations, _ = await v1_processing.parse_messages_in_conversations( + messages_batch ) + fetched_messages.extend(parsed_conversations) + + offset += len(messages_batch) + + final_messages = fetched_messages[:page_size] + + # Fetch total message count + total_count = await dbreader.get_total_messages_count_by_workspace_id( + ws.id, AlertSeverity.CRITICAL.value ) - conversations, _ = await v1_processing.parse_messages_in_conversations( - prompts_with_output_alerts_usage + return v1_models.PaginatedMessagesResponse( + data=final_messages, + limit=page_size, + offset=(page - 1) * page_size, + total=total_count, ) - return conversations except Exception: logger.exception("Error while getting messages") raise HTTPException(status_code=500, detail="Internal server error") diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py index 51f65ea9..fd427cba 100644 --- a/src/codegate/api/v1_models.py +++ b/src/codegate/api/v1_models.py @@ -322,3 +322,10 @@ class ModelByProvider(pydantic.BaseModel): def __str__(self): return f"{self.provider_name} / {self.name}" + + +class PaginatedMessagesResponse(pydantic.BaseModel): + data: List[Conversation] + limit: int + offset: int + total: int diff --git a/src/codegate/api/v1_processing.py b/src/codegate/api/v1_processing.py index 10f42075..0f5f013d 100644 --- a/src/codegate/api/v1_processing.py +++ b/src/codegate/api/v1_processing.py @@ -21,7 +21,7 @@ TokenUsageByModel, ) from codegate.db.connection import alert_queue -from codegate.db.models import Alert, GetPromptWithOutputsRow, TokenUsage +from codegate.db.models import Alert, GetMessagesRow, TokenUsage logger = structlog.get_logger("codegate") @@ -152,7 +152,7 @@ def _parse_single_output(single_output: dict) -> str: async def _get_partial_question_answer( - row: GetPromptWithOutputsRow, + row: GetMessagesRow, ) -> Optional[PartialQuestionAnswer]: """ Parse a row from the get_prompt_with_outputs query and return a PartialConversation @@ -423,7 +423,7 @@ async def match_conversations( async def _process_prompt_output_to_partial_qa( - prompts_outputs: List[GetPromptWithOutputsRow], + prompts_outputs: List[GetMessagesRow], ) -> List[PartialQuestionAnswer]: """ Process the prompts and outputs to PartialQuestionAnswer objects. @@ -435,7 +435,7 @@ async def _process_prompt_output_to_partial_qa( async def parse_messages_in_conversations( - prompts_outputs: List[GetPromptWithOutputsRow], + prompts_outputs: List[GetMessagesRow], ) -> Tuple[List[Conversation], Dict[str, Conversation]]: """ Get all the messages from the database and return them as a list of conversations. @@ -477,7 +477,7 @@ async def parse_row_alert_conversation( async def parse_get_alert_conversation( alerts: List[Alert], - prompts_outputs: List[GetPromptWithOutputsRow], + prompts_outputs: List[GetMessagesRow], ) -> List[AlertConversation]: """ Parse a list of rows from the get_alerts_with_prompt_and_output query and return a list of @@ -496,7 +496,7 @@ async def parse_get_alert_conversation( async def parse_workspace_token_usage( - prompts_outputs: List[GetPromptWithOutputsRow], + prompts_outputs: List[GetMessagesRow], ) -> TokenUsageAggregate: """ Parse the token usage from the workspace. @@ -515,7 +515,6 @@ async def remove_duplicate_alerts(alerts: List[v1_models.Alert]) -> List[v1_mode for alert in sorted( alerts, key=lambda x: x.timestamp, reverse=True ): # Sort alerts by timestamp descending - # Handle trigger_string based on its type trigger_string_content = "" if isinstance(alert.trigger_string, dict): diff --git a/src/codegate/config.py b/src/codegate/config.py index 761ca09e..1167a9d8 100644 --- a/src/codegate/config.py +++ b/src/codegate/config.py @@ -25,6 +25,9 @@ "llamacpp": "./codegate_volume/models", # Default LlamaCpp model path } +API_DEFAULT_PAGE_SIZE = 50 +API_MAX_PAGE_SIZE = 100 + @dataclass class Config: diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 38bf6010..56ba16bb 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -11,7 +11,7 @@ from alembic import command as alembic_command from alembic.config import Config as AlembicConfig from pydantic import BaseModel -from sqlalchemy import CursorResult, TextClause, event, text +from sqlalchemy import CursorResult, TextClause, bindparam, event, text from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.ext.asyncio import create_async_engine @@ -20,9 +20,10 @@ from codegate.db.models import ( ActiveWorkspace, Alert, - GetPromptWithOutputsRow, + AlertTriggerType, + GetMessagesRow, GetWorkspaceByNameConditions, - IntermediatePromptWithOutputUsageAlerts, + IntermediateMessagesRow, MuxRule, Output, Persona, @@ -629,7 +630,7 @@ async def _exec_vec_db_query_to_pydantic( conn.close() return results - async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithOutputsRow]: + async def get_prompts_with_output(self, workpace_id: str) -> List[GetMessagesRow]: sql = text( """ SELECT @@ -649,44 +650,120 @@ async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithO ) conditions = {"workspace_id": workpace_id} prompts = await self._exec_select_conditions_to_pydantic( - GetPromptWithOutputsRow, sql, conditions, should_raise=True + GetMessagesRow, sql, conditions, should_raise=True ) return prompts - async def get_prompts_with_output_alerts_usage_by_workspace_id( - self, workspace_id: str, trigger_category: Optional[str] = None - ) -> List[GetPromptWithOutputsRow]: + async def get_messages( + self, + workspace_id: str, + offset: int = 0, + page_size: int = 20, + filter_by_ids: Optional[List[str]] = None, + filter_by_alert_trigger_categories: Optional[List[str]] = None, + filter_by_alert_trigger_types: Optional[List[str]] = None, + ) -> List[GetMessagesRow]: """ - Get all prompts with their outputs, alerts and token usage by workspace_id. + Retrieve prompts with their associated outputs and alerts, with filtering and pagination. + + Args: + workspace_id: The ID of the workspace to fetch prompts from + offset: Number of records to skip (for pagination) + page_size: Number of records per page + filter_by_ids: Optional list of prompt IDs to filter by + filter_by_alert_trigger_categories: Optional list of alert categories to filter by + filter_by_alert_trigger_types: Optional list of alert trigger types to filter by + + Returns: + List of GetPromptWithOutputsRow containing prompts, outputs, and alerts """ - - sql = text( - """ + # Build base query + base_query = """ + WITH filtered_prompts AS ( + SELECT distinct p.id, p.timestamp, p.provider, p.request, p.type + FROM prompts p + LEFT JOIN alerts a ON p.id = a.prompt_id + WHERE p.workspace_id = :workspace_id + {filter_conditions} + ORDER BY p.timestamp DESC + LIMIT :page_size OFFSET :offset + ) SELECT - p.id as prompt_id, p.timestamp as prompt_timestamp, p.provider, p.request, p.type, - o.id as output_id, o.output, o.timestamp as output_timestamp, o.input_tokens, o.output_tokens, o.input_cost, o.output_cost, - a.id as alert_id, a.code_snippet, a.trigger_string, a.trigger_type, a.trigger_category, a.timestamp as alert_timestamp - FROM prompts p + p.id as prompt_id, + p.timestamp as prompt_timestamp, + p.provider, + p.request, + p.type, + o.id as output_id, + o.output, + o.timestamp as output_timestamp, + o.input_tokens, + o.output_tokens, + o.input_cost, + o.output_cost, + a.id as alert_id, + a.code_snippet, + a.trigger_string, + a.trigger_type, + a.trigger_category, + a.timestamp as alert_timestamp + FROM filtered_prompts p LEFT JOIN outputs o ON p.id = o.prompt_id LEFT JOIN alerts a ON p.id = a.prompt_id - WHERE p.workspace_id = :workspace_id - AND (a.trigger_category = :trigger_category OR a.trigger_category is NULL) - ORDER BY o.timestamp DESC, a.timestamp DESC - """ # noqa: E501 - ) - # If trigger category is None we want to get all alerts - trigger_category = trigger_category if trigger_category else "%" - conditions = {"workspace_id": workspace_id, "trigger_category": trigger_category} - rows: List[IntermediatePromptWithOutputUsageAlerts] = ( - await self._exec_select_conditions_to_pydantic( - IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True + """ + + # Build conditions and filters + conditions = { + "workspace_id": workspace_id, + "page_size": page_size, + "offset": offset, + } + + # Conditionally add filter clauses and conditions + + filter_conditions = [] + + if filter_by_alert_trigger_categories: + filter_conditions.append( + "AND a.trigger_category IN :filter_by_alert_trigger_categories" ) + conditions["filter_by_alert_trigger_categories"] = filter_by_alert_trigger_categories + + if filter_by_alert_trigger_types: + filter_conditions.append( + "AND EXISTS (SELECT 1 FROM alerts a2 WHERE a2.prompt_id = p.id AND a2.trigger_type IN :filter_by_alert_trigger_types)" # noqa: E501 + ) + conditions["filter_by_alert_trigger_types"] = filter_by_alert_trigger_types + + if filter_by_ids: + filter_conditions.append("AND p.id IN :filter_by_ids") + conditions["filter_by_ids"] = filter_by_ids + + filter_clause = " ".join(filter_conditions) + query = base_query.format(filter_conditions=filter_clause) + + sql = text(query) + + # Bind optional params + + if filter_by_alert_trigger_categories: + sql = sql.bindparams(bindparam("filter_by_alert_trigger_categories", expanding=True)) + if filter_by_alert_trigger_types: + sql = sql.bindparams(bindparam("filter_by_alert_trigger_types", expanding=True)) + if filter_by_ids: + sql = sql.bindparams(bindparam("filter_by_ids", expanding=True)) + + # Execute query + rows = await self._exec_select_conditions_to_pydantic( + IntermediateMessagesRow, sql, conditions, should_raise=True ) - prompts_dict: Dict[str, GetPromptWithOutputsRow] = {} + + # Process results + prompts_dict: Dict[str, GetMessagesRow] = {} + for row in rows: - prompt_id = row.prompt_id - if prompt_id not in prompts_dict: - prompts_dict[prompt_id] = GetPromptWithOutputsRow( + if row.prompt_id not in prompts_dict: + prompts_dict[row.prompt_id] = GetMessagesRow( id=row.prompt_id, timestamp=row.prompt_timestamp, provider=row.provider, @@ -701,6 +778,7 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id( output_cost=row.output_cost, alerts=[], ) + if row.alert_id: alert = Alert( id=row.alert_id, @@ -711,10 +789,41 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id( trigger_category=row.trigger_category, timestamp=row.alert_timestamp, ) - prompts_dict[prompt_id].alerts.append(alert) + if alert not in prompts_dict[row.prompt_id].alerts: + prompts_dict[row.prompt_id].alerts.append(alert) return list(prompts_dict.values()) + async def get_total_messages_count_by_workspace_id( + self, workspace_id: str, trigger_category: Optional[str] = None + ) -> int: + """ + Get total count of unique messages for a given workspace_id, + considering trigger_category. + """ + sql = text( + """ + SELECT COUNT(DISTINCT p.id) + FROM prompts p + LEFT JOIN alerts a ON p.id = a.prompt_id + WHERE p.workspace_id = :workspace_id + """ + ) + conditions = {"workspace_id": workspace_id} + + if trigger_category: + sql = text(sql.text + " AND a.trigger_category = :trigger_category") + conditions["trigger_category"] = trigger_category + + async with self._async_db_engine.begin() as conn: + try: + result = await conn.execute(sql, conditions) + count = result.scalar() # Fetches the integer result directly + return count or 0 # Ensure it returns an integer + except Exception as e: + logger.error(f"Failed to fetch message count. Error: {e}") + return 0 # Return 0 in case of failure + async def get_alerts_by_workspace( self, workspace_id: str, trigger_category: Optional[str] = None ) -> List[Alert]: @@ -749,19 +858,19 @@ async def get_alerts_by_workspace( async def get_alerts_summary_by_workspace(self, workspace_id: str) -> dict: """Get aggregated alert summary counts for a given workspace_id.""" sql = text( - """ + f""" SELECT - COUNT(*) AS total_alerts, - SUM(CASE WHEN a.trigger_type = 'codegate-secrets' THEN 1 ELSE 0 END) - AS codegate_secrets_count, - SUM(CASE WHEN a.trigger_type = 'codegate-context-retriever' THEN 1 ELSE 0 END) - AS codegate_context_retriever_count, - SUM(CASE WHEN a.trigger_type = 'codegate-pii' THEN 1 ELSE 0 END) - AS codegate_pii_count + COUNT(*) AS total_alerts, + SUM(CASE WHEN a.trigger_type = '{AlertTriggerType.CODEGATE_SECRETS.value}' THEN 1 ELSE 0 END) + AS codegate_secrets_count, + SUM(CASE WHEN a.trigger_type = '{AlertTriggerType.CODEGATE_CONTEXT_RETRIEVER.value}' THEN 1 ELSE 0 END) + AS codegate_context_retriever_count, + SUM(CASE WHEN a.trigger_type = '{AlertTriggerType.CODEGATE_PII.value}' THEN 1 ELSE 0 END) + AS codegate_pii_count FROM alerts a INNER JOIN prompts p ON p.id = a.prompt_id WHERE p.workspace_id = :workspace_id - """ + """ # noqa: E501 # nosec ) conditions = {"workspace_id": workspace_id} diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index a5941e96..63422d96 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -11,6 +11,12 @@ class AlertSeverity(str, Enum): CRITICAL = "critical" +class AlertTriggerType(str, Enum): + CODEGATE_PII = "codegate-pii" + CODEGATE_CONTEXT_RETRIEVER = "codegate-context-retriever" + CODEGATE_SECRETS = "codegate-secrets" + + class Alert(BaseModel): id: str prompt_id: str @@ -137,7 +143,7 @@ class ProviderType(str, Enum): openrouter = "openrouter" -class IntermediatePromptWithOutputUsageAlerts(BaseModel): +class IntermediateMessagesRow(BaseModel): """ An intermediate model to represent the result of a query for a prompt and related outputs, usage stats & alerts. @@ -163,7 +169,7 @@ class IntermediatePromptWithOutputUsageAlerts(BaseModel): alert_timestamp: Optional[Any] -class GetPromptWithOutputsRow(BaseModel): +class GetMessagesRow(BaseModel): id: Any timestamp: Any provider: Optional[Any] diff --git a/src/codegate/pipeline/codegate_context_retriever/codegate.py b/src/codegate/pipeline/codegate_context_retriever/codegate.py index e22874a6..96dce865 100644 --- a/src/codegate/pipeline/codegate_context_retriever/codegate.py +++ b/src/codegate/pipeline/codegate_context_retriever/codegate.py @@ -5,7 +5,7 @@ from litellm import ChatCompletionRequest from codegate.clients.clients import ClientType -from codegate.db.models import AlertSeverity +from codegate.db.models import AlertSeverity, AlertTriggerType from codegate.extract_snippets.factory import MessageCodeExtractorFactory from codegate.pipeline.base import ( PipelineContext, @@ -36,7 +36,7 @@ def name(self) -> str: """ Returns the name of this pipeline step. """ - return "codegate-context-retriever" + return AlertTriggerType.CODEGATE_CONTEXT_RETRIEVER.value def generate_context_str( self, objects: list[object], context: PipelineContext, snippet_map: dict diff --git a/src/codegate/pipeline/pii/analyzer.py b/src/codegate/pipeline/pii/analyzer.py index 706deb9b..bed67b17 100644 --- a/src/codegate/pipeline/pii/analyzer.py +++ b/src/codegate/pipeline/pii/analyzer.py @@ -4,6 +4,7 @@ from presidio_analyzer import AnalyzerEngine from presidio_anonymizer import AnonymizerEngine +from codegate.db.models import AlertTriggerType from codegate.pipeline.base import PipelineContext from codegate.pipeline.sensitive_data.session_store import SessionStore @@ -30,7 +31,7 @@ class PiiAnalyzer: """ _instance: Optional["PiiAnalyzer"] = None - _name = "codegate-pii" + _name = AlertTriggerType.CODEGATE_PII.value @classmethod def get_instance(cls) -> "PiiAnalyzer": diff --git a/src/codegate/pipeline/pii/pii.py b/src/codegate/pipeline/pii/pii.py index d7f33d67..c2d8b0ae 100644 --- a/src/codegate/pipeline/pii/pii.py +++ b/src/codegate/pipeline/pii/pii.py @@ -6,7 +6,7 @@ from litellm.types.utils import Delta, StreamingChoices from codegate.config import Config -from codegate.db.models import AlertSeverity +from codegate.db.models import AlertSeverity, AlertTriggerType from codegate.pipeline.base import ( PipelineContext, PipelineResult, @@ -51,7 +51,7 @@ def __init__(self, sensitive_data_manager: SensitiveDataManager): @property def name(self) -> str: - return "codegate-pii" + return AlertTriggerType.CODEGATE_PII.value def _get_redacted_snippet(self, message: str, pii_details: List[Dict[str, Any]]) -> str: # If no PII found, return empty string @@ -419,7 +419,7 @@ async def process_chunk( # TODO: Might want to check these with James! notification_text = ( f"🛡️ [CodeGate protected {redacted_count} instances of PII, including {pii_summary}]" - f"(http://localhost:9090/?search=codegate-pii) from being leaked " + f"(http://localhost:9090/?search={AlertTriggerType.CODEGATE_PII.value}) from being leaked " # noqa: E501 f"by redacting them.\n\n" ) diff --git a/src/codegate/pipeline/secrets/secrets.py b/src/codegate/pipeline/secrets/secrets.py index c299469e..99443340 100644 --- a/src/codegate/pipeline/secrets/secrets.py +++ b/src/codegate/pipeline/secrets/secrets.py @@ -7,7 +7,7 @@ from litellm.types.utils import Delta, StreamingChoices from codegate.config import Config -from codegate.db.models import AlertSeverity +from codegate.db.models import AlertSeverity, AlertTriggerType from codegate.extract_snippets.factory import MessageCodeExtractorFactory from codegate.pipeline.base import ( CodeSnippet, @@ -178,7 +178,7 @@ def __init__( self._sensitive_data_manager = sensitive_data_manager self._session_id = session_id self._context = context - self._name = "codegate-secrets" + self._name = AlertTriggerType.CODEGATE_SECRETS.value super().__init__() @@ -255,7 +255,7 @@ def name(self) -> str: Returns: str: The identifier 'codegate-secrets'. """ - return "codegate-secrets" + return AlertTriggerType.CODEGATE_SECRETS.value def _redact_text( self, @@ -556,7 +556,7 @@ async def process_chunk( notification_chunk = self._create_chunk( chunk, f"\n🛡️ [CodeGate prevented {redacted_count} {secret_text}]" - f"(http://localhost:9090/?search=codegate-secrets) from being leaked " + f"(http://localhost:9090/?search={AlertTriggerType.CODEGATE_SECRETS.value}) from being leaked " # noqa: E501 f"by redacting them.\n\n", ) notification_chunk.choices[0].delta.role = "assistant" @@ -564,7 +564,7 @@ async def process_chunk( notification_chunk = self._create_chunk( chunk, f"\n🛡️ [CodeGate prevented {redacted_count} {secret_text}]" - f"(http://localhost:9090/?search=codegate-secrets) from being leaked " + f"(http://localhost:9090/?search={AlertTriggerType.CODEGATE_SECRETS.value}) from being leaked " # noqa: E501 f"by redacting them.\n\n", ) diff --git a/tests/api/test_v1_processing.py b/tests/api/test_v1_processing.py index ad8ffcbd..fc3ee363 100644 --- a/tests/api/test_v1_processing.py +++ b/tests/api/test_v1_processing.py @@ -13,7 +13,7 @@ parse_request, remove_duplicate_alerts, ) -from codegate.db.models import GetPromptWithOutputsRow +from codegate.db.models import AlertTriggerType, GetMessagesRow @pytest.mark.asyncio @@ -147,7 +147,7 @@ async def test_parse_output(output_dict, expected_str): @pytest.mark.parametrize( "row", [ - GetPromptWithOutputsRow( + GetMessagesRow( id="1", timestamp=timestamp_now, provider="openai", @@ -446,7 +446,7 @@ def test_group_partial_messages(pq_list, expected_group_ids): prompt_id="p1", code_snippet=None, trigger_string="secret1 Context xyz", - trigger_type="codegate-secrets", + trigger_type=AlertTriggerType.CODEGATE_SECRETS.value, trigger_category="critical", timestamp=datetime.datetime(2023, 1, 1, 12, 0, 0), ), @@ -455,7 +455,7 @@ def test_group_partial_messages(pq_list, expected_group_ids): prompt_id="p2", code_snippet=None, trigger_string="secret1 Context abc", - trigger_type="codegate-secrets", + trigger_type=AlertTriggerType.CODEGATE_SECRETS.value, trigger_category="critical", timestamp=datetime.datetime(2023, 1, 1, 12, 0, 3), ), @@ -471,7 +471,7 @@ def test_group_partial_messages(pq_list, expected_group_ids): prompt_id="p1", code_snippet=None, trigger_string="secret1 Context xyz", - trigger_type="codegate-secrets", + trigger_type=AlertTriggerType.CODEGATE_SECRETS.value, trigger_category="critical", timestamp=datetime.datetime(2023, 1, 1, 12, 0, 0), ), @@ -480,7 +480,7 @@ def test_group_partial_messages(pq_list, expected_group_ids): prompt_id="p2", code_snippet=None, trigger_string="secret1 Context abc", - trigger_type="codegate-secrets", + trigger_type=AlertTriggerType.CODEGATE_SECRETS.value, trigger_category="critical", timestamp=datetime.datetime(2023, 1, 1, 12, 0, 6), ), @@ -496,7 +496,7 @@ def test_group_partial_messages(pq_list, expected_group_ids): prompt_id="p1", code_snippet=None, trigger_string="secret1 Context xyz", - trigger_type="codegate-secrets", + trigger_type=AlertTriggerType.CODEGATE_SECRETS.value, trigger_category="critical", timestamp=datetime.datetime(2023, 1, 1, 12, 0, 0), ), @@ -514,7 +514,7 @@ def test_group_partial_messages(pq_list, expected_group_ids): prompt_id="p3", code_snippet=None, trigger_string="secret1 Context abc", - trigger_type="codegate-secrets", + trigger_type=AlertTriggerType.CODEGATE_SECRETS.value, trigger_category="critical", timestamp=datetime.datetime(2023, 1, 1, 12, 0, 3), ), diff --git a/tests/pipeline/pii/test_pi.py b/tests/pipeline/pii/test_pi.py index 06d2881f..4f6fadf5 100644 --- a/tests/pipeline/pii/test_pi.py +++ b/tests/pipeline/pii/test_pi.py @@ -4,6 +4,7 @@ from litellm import ChatCompletionRequest, ModelResponse from litellm.types.utils import Delta, StreamingChoices +from codegate.db.models import AlertTriggerType from codegate.pipeline.base import PipelineContext, PipelineSensitiveData from codegate.pipeline.output import OutputPipelineContext from codegate.pipeline.pii.pii import CodegatePii, PiiRedactionNotifier, PiiUnRedactionStep @@ -25,7 +26,7 @@ def pii_step(self): return CodegatePii(mock_sensitive_data_manager) def test_name(self, pii_step): - assert pii_step.name == "codegate-pii" + assert pii_step.name == AlertTriggerType.CODEGATE_PII.value def test_get_redacted_snippet_no_pii(self, pii_step): message = "Hello world"