From e9f273a5292c1389462d2e54e919ac035dc78d6f Mon Sep 17 00:00:00 2001
From: Yolanda Robla
Date: Wed, 5 Mar 2025 13:38:00 +0100
Subject: [PATCH 1/9] feat: update messages endpoint to return a conversation summary

Modify the messages endpoint to return just a conversation summary,
which will simplify the current queries.

Create a separate endpoint that returns a list of conversations for a
given prompt ID.
---
 src/codegate/api/v1.py        | 110 +++++++++++++++---
 src/codegate/api/v1_models.py |  21 ++++
 src/codegate/config.py        |   3 +
 src/codegate/db/connection.py | 204 ++++++++++++++++++++--------------
 src/codegate/db/models.py     |  15 +++
 5 files changed, 254 insertions(+), 99 deletions(-)

diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py
index 33efea33e..8e3f374ac 100644
--- a/src/codegate/api/v1.py
+++ b/src/codegate/api/v1.py
@@ -3,16 +3,17 @@

 import requests
 import structlog
-from fastapi import APIRouter, Depends, HTTPException, Response
+from fastapi import APIRouter, Depends, HTTPException, Query, Response
 from fastapi.responses import StreamingResponse
 from fastapi.routing import APIRoute
 from pydantic import BaseModel, ValidationError

+from codegate.config import API_DEFAULT_PAGE_SIZE, API_MAX_PAGE_SIZE
 import codegate.muxing.models as mux_models
 from codegate import __version__
 from codegate.api import v1_models, v1_processing
 from codegate.db.connection import AlreadyExistsError, DbReader
-from codegate.db.models import AlertSeverity, Persona, WorkspaceWithModel
+from codegate.db.models import AlertSeverity, AlertTriggerType, Persona, WorkspaceWithModel
 from codegate.muxing.persona import (
     PersonaDoesNotExistError,
     PersonaManager,
@@ -443,11 +444,11 @@ async def get_workspace_alerts_summary(workspace_name: str) -> v1_models.AlertSu
         raise HTTPException(status_code=500, detail="Internal server error")

     try:
-        summary = await dbreader.get_alerts_summary_by_workspace(ws.id)
+        summary = await dbreader.get_alerts_summary(workspace_id=ws.id)
         return v1_models.AlertSummary(
-            malicious_packages=summary["codegate_context_retriever_count"],
-            pii=summary["codegate_pii_count"],
-            secrets=summary["codegate_secrets_count"],
+            malicious_packages=summary.total_packages_count,
+            pii=summary.total_pii_count,
+            secrets=summary.total_secrets_count,
         )
     except Exception:
         logger.exception("Error while getting alerts summary")
@@ -459,7 +460,13 @@ async def get_workspace_alerts_summary(workspace_name: str) -> v1_models.AlertSu
     tags=["Workspaces"],
     generate_unique_id_function=uniq_name,
 )
-async def get_workspace_messages(workspace_name: str) -> List[v1_models.Conversation]:
+async def get_workspace_messages(
+    workspace_name: str,
+    page: int = Query(1, ge=1),
+    page_size: int = Query(API_DEFAULT_PAGE_SIZE, ge=1, le=API_MAX_PAGE_SIZE),
+    filter_by_ids: Optional[List[str]] = Query(None),
+    filter_by_alert_trigger_types: Optional[List[AlertTriggerType]] = Query(None),
+) -> v1_models.PaginatedMessagesResponse:
     """Get messages for a workspace."""
     try:
         ws = await wscrud.get_workspace_by_name(workspace_name)
@@ -469,20 +476,89 @@ async def get_workspace_messages(workspace_name: str) -> List[v1_models.Conversa
         logger.exception("Error while getting workspace")
         raise HTTPException(status_code=500, detail="Internal server error")

-    try:
-        prompts_with_output_alerts_usage = (
-            await dbreader.get_prompts_with_output_alerts_usage_by_workspace_id(
-                ws.id, AlertSeverity.CRITICAL.value
-            )
+    offset = (page - 1) * page_size
+
+    prompts = await dbreader.get_prompts(
+        ws.id,
+        offset,
+        page_size,
+        filter_by_ids,
+        
list([AlertSeverity.CRITICAL.value]), # TODO: Configurable severity + filter_by_alert_trigger_types, + ) + # Fetch total message count + total_count = await dbreader.get_total_messages_count_by_workspace_id( + ws.id, AlertSeverity.CRITICAL.value + ) + + # iterate for all prompts to compose the conversation summary + conversation_summaries: List[v1_models.ConversationSummary] = [] + for prompt in prompts: + if not prompt.request: + logger.warning(f"Skipping prompt {prompt.id}. Empty request field") + continue + + messages, _ = await v1_processing.parse_request(prompt.request) + if not messages or len(messages) == 0: + logger.warning(f"Skipping prompt {prompt.id}. No messages found") + continue + + # message is just the first entry in the request + message_obj = v1_models.ChatMessage( + message=messages[0], timestamp=prompt.timestamp, message_id=prompt.id ) - conversations, _ = await v1_processing.parse_messages_in_conversations( - prompts_with_output_alerts_usage + + # count total alerts for the prompt + total_alerts_row = await dbreader.get_alerts_summary(prompt_id=prompt.id) + + # get token usage for the prompt + prompts_outputs = await dbreader.get_prompts_with_output(prompt_id=prompt.id) + ws_token_usage = await v1_processing.parse_workspace_token_usage(prompts_outputs) + + conversation_summary = v1_models.ConversationSummary( + chat_id=prompt.id, + prompt=message_obj, + provider=prompt.provider, + type=prompt.type, + conversation_timestamp=prompt.timestamp, + total_alerts=total_alerts_row.total_alerts, + token_usage_agg=ws_token_usage, ) - return conversations + + conversation_summaries.append(conversation_summary) + + return v1_models.PaginatedMessagesResponse( + data=conversation_summaries, + limit=page_size, + offset=(page - 1) * page_size, + total=total_count, + ) + + +@v1.get( + "/workspaces/{workspace_name}/messages/{prompt_id}", + tags=["Workspaces"], + generate_unique_id_function=uniq_name, +) +async def get_messages_by_prompt_id( + workspace_name: str, + prompt_id: str, +) -> List[v1_models.Conversation]: + """Get messages for a workspace.""" + try: + ws = await wscrud.get_workspace_by_name(workspace_name) + except crud.WorkspaceDoesNotExistError: + raise HTTPException(status_code=404, detail="Workspace does not exist") except Exception: - logger.exception("Error while getting messages") + logger.exception("Error while getting workspace") raise HTTPException(status_code=500, detail="Internal server error") + prompts_outputs = await dbreader.get_prompts_with_output( + workspace_id=ws.id, prompt_id=prompt_id + ) + conversations, _ = await v1_processing.parse_messages_in_conversations(prompts_outputs) + return conversations + @v1.get( "/workspaces/{workspace_name}/custom-instructions", @@ -665,7 +741,7 @@ async def get_workspace_token_usage(workspace_name: str) -> v1_models.TokenUsage raise HTTPException(status_code=500, detail="Internal server error") try: - prompts_outputs = await dbreader.get_prompts_with_output(ws.id) + prompts_outputs = await dbreader.get_prompts_with_output(worskpace_id=ws.id) ws_token_usage = await v1_processing.parse_workspace_token_usage(prompts_outputs) return ws_token_usage except Exception: diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py index dff26489e..20091f092 100644 --- a/src/codegate/api/v1_models.py +++ b/src/codegate/api/v1_models.py @@ -218,6 +218,20 @@ class Conversation(pydantic.BaseModel): alerts: List[Alert] = [] +class ConversationSummary(pydantic.BaseModel): + """ + Represents a conversation summary. 
+ """ + + chat_id: str + prompt: ChatMessage + total_alerts: int + token_usage_agg: Optional[TokenUsageAggregate] + provider: Optional[str] + type: QuestionType + conversation_timestamp: datetime.datetime + + class AlertConversation(pydantic.BaseModel): """ Represents an alert with it's respective conversation. @@ -333,3 +347,10 @@ class PersonaUpdateRequest(pydantic.BaseModel): new_name: str new_description: str + + +class PaginatedMessagesResponse(pydantic.BaseModel): + data: List[ConversationSummary] + limit: int + offset: int + total: int diff --git a/src/codegate/config.py b/src/codegate/config.py index 179ec4d34..ee5cb1689 100644 --- a/src/codegate/config.py +++ b/src/codegate/config.py @@ -25,6 +25,9 @@ "llamacpp": "./codegate_volume/models", # Default LlamaCpp model path } +API_DEFAULT_PAGE_SIZE = 50 +API_MAX_PAGE_SIZE = 100 + @dataclass class Config: diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 3f439aeaa..8d2d240a1 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -12,17 +12,20 @@ from alembic import command as alembic_command from alembic.config import Config as AlembicConfig from pydantic import BaseModel -from sqlalchemy import CursorResult, TextClause, event, text +from sqlalchemy import CursorResult, TextClause, bindparam, event, text from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker +from codegate.config import API_DEFAULT_PAGE_SIZE from codegate.db.fim_cache import FimCache from codegate.db.models import ( ActiveWorkspace, Alert, - GetPromptWithOutputsRow, + AlertSummaryRow, + AlertTriggerType, + GetMessagesRow, GetWorkspaceByNameConditions, Instance, IntermediatePromptWithOutputUsageAlerts, @@ -685,7 +688,11 @@ async def _exec_vec_db_query_to_pydantic( conn.close() return results - async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithOutputsRow]: + async def get_prompts_with_output( + self, workspace_id: Optional[str] = None, prompt_id: Optional[str] = None + ) -> List[GetMessagesRow]: + if not workspace_id and not prompt_id: + raise ValueError("Either workspace_id or prompt_id must be provided.") sql = text( """ SELECT @@ -699,77 +706,94 @@ async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithO o.output_cost FROM prompts p LEFT JOIN outputs o ON p.id = o.prompt_id - WHERE p.workspace_id = :workspace_id + WHERE (:workspace_id IS NULL OR p.workspace_id = :workspace_id) + AND (:prompt_id IS NULL OR p.id = :prompt_id) ORDER BY o.timestamp DESC """ ) - conditions = {"workspace_id": workpace_id} + conditions = {"workspace_id": workspace_id, "prompt_id": prompt_id} prompts = await self._exec_select_conditions_to_pydantic( - GetPromptWithOutputsRow, sql, conditions, should_raise=True + GetMessagesRow, sql, conditions, should_raise=True ) return prompts - async def get_prompts_with_output_alerts_usage_by_workspace_id( - self, workspace_id: str, trigger_category: Optional[str] = None - ) -> List[GetPromptWithOutputsRow]: + async def get_prompts( + self, + workspace_id: str, + offset: int = 0, + page_size: int = API_DEFAULT_PAGE_SIZE, + filter_by_ids: Optional[List[str]] = None, + filter_by_alert_trigger_categories: Optional[List[str]] = None, + filter_by_alert_trigger_types: Optional[List[str]] = None, + ) -> List[Prompt]: """ - Get all prompts with their outputs, alerts and token usage by workspace_id. 
+ Retrieve prompts with filtering and pagination. + + Args: + workspace_id: The ID of the workspace to fetch prompts from + offset: Number of records to skip (for pagination) + page_size: Number of records per page + filter_by_ids: Optional list of prompt IDs to filter by + filter_by_alert_trigger_categories: Optional list of alert categories to filter by + filter_by_alert_trigger_types: Optional list of alert trigger types to filter by + + Returns: + List of Prompt containing prompt details """ - - sql = text( - """ - SELECT - p.id as prompt_id, p.timestamp as prompt_timestamp, p.provider, p.request, p.type, - o.id as output_id, o.output, o.timestamp as output_timestamp, o.input_tokens, o.output_tokens, o.input_cost, o.output_cost, - a.id as alert_id, a.code_snippet, a.trigger_string, a.trigger_type, a.trigger_category, a.timestamp as alert_timestamp - FROM prompts p - LEFT JOIN outputs o ON p.id = o.prompt_id + # Build base query + base_query = """ + SELECT DISTINCT p.id, p.timestamp, p.provider, p.request, p.type, p.workspace_id FROM prompts p LEFT JOIN alerts a ON p.id = a.prompt_id WHERE p.workspace_id = :workspace_id - AND (a.trigger_category = :trigger_category OR a.trigger_category is NULL) - ORDER BY o.timestamp DESC, a.timestamp DESC - """ # noqa: E501 - ) - # If trigger category is None we want to get all alerts - trigger_category = trigger_category if trigger_category else "%" - conditions = {"workspace_id": workspace_id, "trigger_category": trigger_category} - rows: List[IntermediatePromptWithOutputUsageAlerts] = ( - await self._exec_select_conditions_to_pydantic( - IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True + {filter_conditions} + ORDER BY p.timestamp DESC + LIMIT :page_size OFFSET :offset + """ + # Build conditions and filters + conditions = { + "workspace_id": workspace_id, + "page_size": page_size, + "offset": offset, + } + + # Conditionally add filter clauses and conditions + filter_conditions = [] + + if filter_by_alert_trigger_categories: + filter_conditions.append( + "AND a.trigger_category IN :filter_by_alert_trigger_categories" ) - ) - prompts_dict: Dict[str, GetPromptWithOutputsRow] = {} - for row in rows: - prompt_id = row.prompt_id - if prompt_id not in prompts_dict: - prompts_dict[prompt_id] = GetPromptWithOutputsRow( - id=row.prompt_id, - timestamp=row.prompt_timestamp, - provider=row.provider, - request=row.request, - type=row.type, - output_id=row.output_id, - output=row.output, - output_timestamp=row.output_timestamp, - input_tokens=row.input_tokens, - output_tokens=row.output_tokens, - input_cost=row.input_cost, - output_cost=row.output_cost, - alerts=[], - ) - if row.alert_id: - alert = Alert( - id=row.alert_id, - prompt_id=row.prompt_id, - code_snippet=row.code_snippet, - trigger_string=row.trigger_string, - trigger_type=row.trigger_type, - trigger_category=row.trigger_category, - timestamp=row.alert_timestamp, - ) - prompts_dict[prompt_id].alerts.append(alert) + conditions["filter_by_alert_trigger_categories"] = filter_by_alert_trigger_categories + + if filter_by_alert_trigger_types: + filter_conditions.append( + "AND EXISTS (SELECT 1 FROM alerts a2 WHERE a2.prompt_id = p.id AND a2.trigger_type IN :filter_by_alert_trigger_types)" # noqa: E501 + ) + conditions["filter_by_alert_trigger_types"] = filter_by_alert_trigger_types + + if filter_by_ids: + filter_conditions.append("AND p.id IN :filter_by_ids") + conditions["filter_by_ids"] = filter_by_ids - return list(prompts_dict.values()) + filter_clause = " 
".join(filter_conditions) + query = base_query.format(filter_conditions=filter_clause) + + sql = text(query) + + # Bind optional params + + if filter_by_alert_trigger_categories: + sql = sql.bindparams(bindparam("filter_by_alert_trigger_categories", expanding=True)) + if filter_by_alert_trigger_types: + sql = sql.bindparams(bindparam("filter_by_alert_trigger_types", expanding=True)) + if filter_by_ids: + sql = sql.bindparams(bindparam("filter_by_ids", expanding=True)) + + # Execute query + rows = await self._exec_select_conditions_to_pydantic( + Prompt, sql, conditions, should_raise=True + ) + return rows async def get_alerts_by_workspace( self, workspace_id: str, trigger_category: Optional[str] = None @@ -802,37 +826,53 @@ async def get_alerts_by_workspace( ) return prompts - async def get_alerts_summary_by_workspace(self, workspace_id: str) -> dict: - """Get aggregated alert summary counts for a given workspace_id.""" + async def get_alerts_summary( + self, workspace_id: str = None, prompt_id: str = None + ) -> AlertSummaryRow: + """Get aggregated alert summary counts for a given workspace_id or prompt id.""" + if not workspace_id and not prompt_id: + raise ValueError("Either workspace_id or prompt_id must be provided.") + + filters = [] + conditions = {} + + if workspace_id: + filters.append("p.workspace_id = :workspace_id") + conditions["workspace_id"] = workspace_id + + if prompt_id: + filters.append("a.prompt_id = :prompt_id") + conditions["prompt_id"] = prompt_id + + filter_clause = " AND ".join(filters) + sql = text( - """ + f""" SELECT - COUNT(*) AS total_alerts, - SUM(CASE WHEN a.trigger_type = 'codegate-secrets' THEN 1 ELSE 0 END) - AS codegate_secrets_count, - SUM(CASE WHEN a.trigger_type = 'codegate-context-retriever' THEN 1 ELSE 0 END) - AS codegate_context_retriever_count, - SUM(CASE WHEN a.trigger_type = 'codegate-pii' THEN 1 ELSE 0 END) - AS codegate_pii_count + COUNT(*) AS total_alerts, + SUM(CASE WHEN a.trigger_type = '{AlertTriggerType.CODEGATE_SECRETS.value}' THEN 1 ELSE 0 END) + AS codegate_secrets_count, + SUM(CASE WHEN a.trigger_type = '{AlertTriggerType.CODEGATE_CONTEXT_RETRIEVER.value}' THEN 1 ELSE 0 END) + AS codegate_context_retriever_count, + SUM(CASE WHEN a.trigger_type = '{AlertTriggerType.CODEGATE_PII.value}' THEN 1 ELSE 0 END) + AS codegate_pii_count FROM alerts a INNER JOIN prompts p ON p.id = a.prompt_id - WHERE p.workspace_id = :workspace_id - """ + WHERE {filter_clause} + """ # noqa: E501 # nosec ) - conditions = {"workspace_id": workspace_id} - async with self._async_db_engine.begin() as conn: result = await conn.execute(sql, conditions) row = result.fetchone() # Return a dictionary with counts (handling None values safely) - return { - "codegate_secrets_count": row.codegate_secrets_count or 0 if row else 0, - "codegate_context_retriever_count": ( - row.codegate_context_retriever_count or 0 if row else 0 - ), - "codegate_pii_count": row.codegate_pii_count or 0 if row else 0, - } + + return AlertSummaryRow( + total_alerts=row.total_alerts or 0 if row else 0, + total_secrets_count=row.codegate_secrets_count or 0 if row else 0, + total_packages_count=row.codegate_context_retriever_count or 0 if row else 0, + total_pii_count=row.codegate_pii_count or 0 if row else 0, + ) async def get_workspaces(self) -> List[WorkspaceWithSessionInfo]: sql = text( diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index 07c4c8edf..a76994faf 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -115,6 +115,21 @@ class 
WorkspaceRow(BaseModel): custom_instructions: Optional[str] +class AlertSummaryRow(BaseModel): + """An alert summary row entry""" + + total_alerts: int + total_secrets_count: int + total_packages_count: int + total_pii_count: int + + +class AlertTriggerType(str, Enum): + CODEGATE_PII = "codegate-pii" + CODEGATE_CONTEXT_RETRIEVER = "codegate-context-retriever" + CODEGATE_SECRETS = "codegate-secrets" + + class GetWorkspaceByNameConditions(BaseModel): name: WorkspaceNameStr From 3798440d2142bd70a49facbf8aea92374f76a661 Mon Sep 17 00:00:00 2001 From: Yolanda Robla Date: Fri, 7 Mar 2025 10:43:47 +0100 Subject: [PATCH 2/9] fixes from rebase --- src/codegate/db/connection.py | 33 +++++++++++++++++++++++++++++++-- src/codegate/db/models.py | 16 ++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 8d2d240a1..7c5739f68 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -4,7 +4,7 @@ import sqlite3 import uuid from pathlib import Path -from typing import Dict, List, Optional, Type +from typing import List, Optional, Type import numpy as np import sqlite_vec_sl_tmp @@ -28,7 +28,6 @@ GetMessagesRow, GetWorkspaceByNameConditions, Instance, - IntermediatePromptWithOutputUsageAlerts, MuxRule, Output, Persona, @@ -795,6 +794,36 @@ async def get_prompts( ) return rows + async def get_total_messages_count_by_workspace_id( + self, workspace_id: str, trigger_category: Optional[str] = None + ) -> int: + """ + Get total count of unique messages for a given workspace_id, + considering trigger_category. + """ + sql = text( + """ + SELECT COUNT(DISTINCT p.id) + FROM prompts p + LEFT JOIN alerts a ON p.id = a.prompt_id + WHERE p.workspace_id = :workspace_id + """ + ) + conditions = {"workspace_id": workspace_id} + + if trigger_category: + sql = text(sql.text + " AND a.trigger_category = :trigger_category") + conditions["trigger_category"] = trigger_category + + async with self._async_db_engine.begin() as conn: + try: + result = await conn.execute(sql, conditions) + count = result.scalar() # Fetches the integer result directly + return count or 0 # Ensure it returns an integer + except Exception as e: + logger.error(f"Failed to fetch message count. 
Error: {e}") + return 0 # Return 0 in case of failure + async def get_alerts_by_workspace( self, workspace_id: str, trigger_category: Optional[str] = None ) -> List[Alert]: diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index a76994faf..beeef8718 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -337,3 +337,19 @@ class PersonaDistance(Persona): """ distance: float + + +class GetMessagesRow(BaseModel): + id: Any + timestamp: Any + provider: Optional[Any] + request: Any + type: Any + output_id: Optional[Any] + output: Optional[Any] + output_timestamp: Optional[Any] + input_tokens: Optional[int] + output_tokens: Optional[int] + input_cost: Optional[float] + output_cost: Optional[float] + alerts: List[Alert] = [] From 28f37dc9397accb0ad6c05bf0e60e50324f0e16d Mon Sep 17 00:00:00 2001 From: Yolanda Robla Date: Fri, 7 Mar 2025 10:49:05 +0100 Subject: [PATCH 3/9] fix lint --- src/codegate/db/connection.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 7c5739f68..1bce74511 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -741,12 +741,13 @@ async def get_prompts( """ # Build base query base_query = """ - SELECT DISTINCT p.id, p.timestamp, p.provider, p.request, p.type, p.workspace_id FROM prompts p + SELECT DISTINCT p.id, p.timestamp, p.provider, p.request, p.type, + p.workspace_id FROM prompts p LEFT JOIN alerts a ON p.id = a.prompt_id WHERE p.workspace_id = :workspace_id {filter_conditions} ORDER BY p.timestamp DESC - LIMIT :page_size OFFSET :offset + LIMIT :page_size OFFSET :offset """ # Build conditions and filters conditions = { From 1917726eacace9450c1c8a6cae05b9466d748e52 Mon Sep 17 00:00:00 2001 From: Yolanda Robla Date: Mon, 10 Mar 2025 09:36:30 +0100 Subject: [PATCH 4/9] changes from review --- src/codegate/api/v1.py | 114 +++++++++++++++++++--------------- src/codegate/api/v1_models.py | 3 +- 2 files changed, 67 insertions(+), 50 deletions(-) diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index 8e3f374ac..ddd6e5adf 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -449,6 +449,7 @@ async def get_workspace_alerts_summary(workspace_name: str) -> v1_models.AlertSu malicious_packages=summary.total_packages_count, pii=summary.total_pii_count, secrets=summary.total_secrets_count, + total_alerts=summary.total_alerts, ) except Exception: logger.exception("Error while getting alerts summary") @@ -477,60 +478,74 @@ async def get_workspace_messages( raise HTTPException(status_code=500, detail="Internal server error") offset = (page - 1) * page_size + valid_conversations: List[v1_models.ConversationSummary] = [] + fetched_prompts = 0 + + while len(valid_conversations) < page_size: + batch_size = page_size * 2 # Fetch more prompts to compensate for potential skips + + prompts = await dbreader.get_prompts( + ws.id, + offset + fetched_prompts, + batch_size, + filter_by_ids, + list([AlertSeverity.CRITICAL.value]), # TODO: Configurable severity + filter_by_alert_trigger_types, + ) + + # iterate for all prompts to compose the conversation summary + for prompt in prompts: + fetched_prompts += 1 + if not prompt.request: + logger.warning(f"Skipping prompt {prompt.id}. Empty request field") + continue + + messages, _ = await v1_processing.parse_request(prompt.request) + if not messages or len(messages) == 0: + logger.warning(f"Skipping prompt {prompt.id}. 
No messages found") + continue + + # message is just the first entry in the request + message_obj = v1_models.ChatMessage( + message=messages[0], timestamp=prompt.timestamp, message_id=prompt.id + ) + + # count total alerts for the prompt + total_alerts_row = await dbreader.get_alerts_summary(prompt_id=prompt.id) + + # get token usage for the prompt + prompts_outputs = await dbreader.get_prompts_with_output(prompt_id=prompt.id) + ws_token_usage = await v1_processing.parse_workspace_token_usage(prompts_outputs) + + conversation_summary = v1_models.ConversationSummary( + chat_id=prompt.id, + prompt=message_obj, + provider=prompt.provider, + type=prompt.type, + conversation_timestamp=prompt.timestamp, + alerts_summary=v1_models.AlertSummary( + malicious_packages=total_alerts_row.total_packages_count, + pii=total_alerts_row.total_pii_count, + secrets=total_alerts_row.total_secrets_count, + total_alerts=total_alerts_row.total_alerts, + ), + total_alerts=total_alerts_row.total_alerts, + token_usage_agg=ws_token_usage, + ) + + valid_conversations.append(conversation_summary) + if len(valid_conversations) >= page_size: + break - prompts = await dbreader.get_prompts( - ws.id, - offset, - page_size, - filter_by_ids, - list([AlertSeverity.CRITICAL.value]), # TODO: Configurable severity - filter_by_alert_trigger_types, - ) # Fetch total message count total_count = await dbreader.get_total_messages_count_by_workspace_id( ws.id, AlertSeverity.CRITICAL.value ) - # iterate for all prompts to compose the conversation summary - conversation_summaries: List[v1_models.ConversationSummary] = [] - for prompt in prompts: - if not prompt.request: - logger.warning(f"Skipping prompt {prompt.id}. Empty request field") - continue - - messages, _ = await v1_processing.parse_request(prompt.request) - if not messages or len(messages) == 0: - logger.warning(f"Skipping prompt {prompt.id}. 
No messages found") - continue - - # message is just the first entry in the request - message_obj = v1_models.ChatMessage( - message=messages[0], timestamp=prompt.timestamp, message_id=prompt.id - ) - - # count total alerts for the prompt - total_alerts_row = await dbreader.get_alerts_summary(prompt_id=prompt.id) - - # get token usage for the prompt - prompts_outputs = await dbreader.get_prompts_with_output(prompt_id=prompt.id) - ws_token_usage = await v1_processing.parse_workspace_token_usage(prompts_outputs) - - conversation_summary = v1_models.ConversationSummary( - chat_id=prompt.id, - prompt=message_obj, - provider=prompt.provider, - type=prompt.type, - conversation_timestamp=prompt.timestamp, - total_alerts=total_alerts_row.total_alerts, - token_usage_agg=ws_token_usage, - ) - - conversation_summaries.append(conversation_summary) - return v1_models.PaginatedMessagesResponse( - data=conversation_summaries, + data=valid_conversations, limit=page_size, - offset=(page - 1) * page_size, + offset=offset, total=total_count, ) @@ -543,7 +558,7 @@ async def get_workspace_messages( async def get_messages_by_prompt_id( workspace_name: str, prompt_id: str, -) -> List[v1_models.Conversation]: +) -> v1_models.Conversation: """Get messages for a workspace.""" try: ws = await wscrud.get_workspace_by_name(workspace_name) @@ -552,12 +567,13 @@ async def get_messages_by_prompt_id( except Exception: logger.exception("Error while getting workspace") raise HTTPException(status_code=500, detail="Internal server error") - prompts_outputs = await dbreader.get_prompts_with_output( workspace_id=ws.id, prompt_id=prompt_id ) conversations, _ = await v1_processing.parse_messages_in_conversations(prompts_outputs) - return conversations + if not conversations: + raise HTTPException(status_code=404, detail="Conversation not found") + return conversations[0] @v1.get( diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py index 20091f092..7eb594c20 100644 --- a/src/codegate/api/v1_models.py +++ b/src/codegate/api/v1_models.py @@ -191,6 +191,7 @@ class AlertSummary(pydantic.BaseModel): malicious_packages: int pii: int secrets: int + total_alerts: int class PartialQuestionAnswer(pydantic.BaseModel): @@ -225,7 +226,7 @@ class ConversationSummary(pydantic.BaseModel): chat_id: str prompt: ChatMessage - total_alerts: int + alerts_summary: AlertSummary token_usage_agg: Optional[TokenUsageAggregate] provider: Optional[str] type: QuestionType From 9227ac8421e8d140a9f655ca8bae50e0e43037d6 Mon Sep 17 00:00:00 2001 From: Yolanda Robla Date: Mon, 10 Mar 2025 10:21:44 +0100 Subject: [PATCH 5/9] fixes from review --- src/codegate/api/v1.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index ddd6e5adf..0cbedd624 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -493,6 +493,9 @@ async def get_workspace_messages( filter_by_alert_trigger_types, ) + if not prompts or len(prompts) == 0: + break + # iterate for all prompts to compose the conversation summary for prompt in prompts: fetched_prompts += 1 @@ -757,7 +760,7 @@ async def get_workspace_token_usage(workspace_name: str) -> v1_models.TokenUsage raise HTTPException(status_code=500, detail="Internal server error") try: - prompts_outputs = await dbreader.get_prompts_with_output(worskpace_id=ws.id) + prompts_outputs = await dbreader.get_prompts_with_output(workspace_id=ws.id) ws_token_usage = await v1_processing.parse_workspace_token_usage(prompts_outputs) return ws_token_usage 
except Exception: From a1a3efe995188d9f67b256d47e237223c07437ec Mon Sep 17 00:00:00 2001 From: Yolanda Robla Date: Mon, 10 Mar 2025 12:22:04 +0100 Subject: [PATCH 6/9] fix pagination --- src/codegate/api/v1.py | 7 ++++-- src/codegate/db/connection.py | 43 +++++++++++++++++++++++++++++------ 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index 0cbedd624..a23f53839 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -489,7 +489,7 @@ async def get_workspace_messages( offset + fetched_prompts, batch_size, filter_by_ids, - list([AlertSeverity.CRITICAL.value]), # TODO: Configurable severity + list([AlertSeverity.CRITICAL.value]), filter_by_alert_trigger_types, ) @@ -542,7 +542,10 @@ async def get_workspace_messages( # Fetch total message count total_count = await dbreader.get_total_messages_count_by_workspace_id( - ws.id, AlertSeverity.CRITICAL.value + ws.id, + filter_by_ids, + list([AlertSeverity.CRITICAL.value]), + filter_by_alert_trigger_types, ) return v1_models.PaginatedMessagesResponse( diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 1bce74511..a24866683 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -796,25 +796,54 @@ async def get_prompts( return rows async def get_total_messages_count_by_workspace_id( - self, workspace_id: str, trigger_category: Optional[str] = None + self, + workspace_id: str, + filter_by_ids: Optional[List[str]] = None, + filter_by_alert_trigger_categories: Optional[List[str]] = None, + filter_by_alert_trigger_types: Optional[List[str]] = None, ) -> int: """ Get total count of unique messages for a given workspace_id, considering trigger_category. """ - sql = text( - """ + base_query = """ SELECT COUNT(DISTINCT p.id) FROM prompts p LEFT JOIN alerts a ON p.id = a.prompt_id WHERE p.workspace_id = :workspace_id + {filter_conditions} """ - ) conditions = {"workspace_id": workspace_id} + filter_conditions = [] - if trigger_category: - sql = text(sql.text + " AND a.trigger_category = :trigger_category") - conditions["trigger_category"] = trigger_category + if filter_by_alert_trigger_categories: + filter_conditions.append( + "AND a.trigger_category IN :filter_by_alert_trigger_categories" + ) + conditions["filter_by_alert_trigger_categories"] = filter_by_alert_trigger_categories + + if filter_by_alert_trigger_types: + filter_conditions.append( + "AND EXISTS (SELECT 1 FROM alerts a2 WHERE " + "a2.prompt_id = p.id AND a2.trigger_type IN :filter_by_alert_trigger_types)" + ) + conditions["filter_by_alert_trigger_types"] = filter_by_alert_trigger_types + + if filter_by_ids: + filter_conditions.append("AND p.id IN :filter_by_ids") + conditions["filter_by_ids"] = filter_by_ids + + filter_clause = " ".join(filter_conditions) + query = base_query.format(filter_conditions=filter_clause) + sql = text(query) + + # Bind optional params + if filter_by_alert_trigger_categories: + sql = sql.bindparams(bindparam("filter_by_alert_trigger_categories", expanding=True)) + if filter_by_alert_trigger_types: + sql = sql.bindparams(bindparam("filter_by_alert_trigger_types", expanding=True)) + if filter_by_ids: + sql = sql.bindparams(bindparam("filter_by_ids", expanding=True)) async with self._async_db_engine.begin() as conn: try: From 0403f14c414cb544a5a682ace1cd6b1f406fff06 Mon Sep 17 00:00:00 2001 From: Yolanda Robla Date: Mon, 10 Mar 2025 15:11:05 +0100 Subject: [PATCH 7/9] decouple alerts from question/answer --- src/codegate/api/v1.py | 15 
+++++++++++++-- src/codegate/api/v1_models.py | 3 +-- src/codegate/api/v1_processing.py | 10 +--------- src/codegate/db/connection.py | 11 +++++++++-- src/codegate/db/models.py | 1 - 5 files changed, 24 insertions(+), 16 deletions(-) diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index a23f53839..f1742f421 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -420,7 +420,9 @@ async def get_workspace_alerts(workspace_name: str) -> List[Optional[v1_models.A raise HTTPException(status_code=500, detail="Internal server error") try: - alerts = await dbreader.get_alerts_by_workspace(ws.id, AlertSeverity.CRITICAL.value) + alerts = await dbreader.get_alerts_by_workspace_or_prompt_id( + workspace_id=ws.id, trigger_category=AlertSeverity.CRITICAL.value + ) prompts_outputs = await dbreader.get_prompts_with_output(ws.id) return await v1_processing.parse_get_alert_conversation(alerts, prompts_outputs) except Exception: @@ -576,10 +578,19 @@ async def get_messages_by_prompt_id( prompts_outputs = await dbreader.get_prompts_with_output( workspace_id=ws.id, prompt_id=prompt_id ) + + # get all alerts for the prompt + alerts = await dbreader.get_alerts_by_workspace_or_prompt_id( + workspace_id=ws.id, prompt_id=prompt_id, trigger_category=AlertSeverity.CRITICAL.value + ) + deduped_alerts = await v1_processing.remove_duplicate_alerts(alerts) conversations, _ = await v1_processing.parse_messages_in_conversations(prompts_outputs) if not conversations: raise HTTPException(status_code=404, detail="Conversation not found") - return conversations[0] + + conversation = conversations[0] + conversation.alerts = deduped_alerts + return conversation @v1.get( diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py index 7eb594c20..6489f96d6 100644 --- a/src/codegate/api/v1_models.py +++ b/src/codegate/api/v1_models.py @@ -202,7 +202,6 @@ class PartialQuestionAnswer(pydantic.BaseModel): partial_questions: PartialQuestions answer: Optional[ChatMessage] model_token_usage: TokenUsageByModel - alerts: List[Alert] = [] class Conversation(pydantic.BaseModel): @@ -216,7 +215,7 @@ class Conversation(pydantic.BaseModel): chat_id: str conversation_timestamp: datetime.datetime token_usage_agg: Optional[TokenUsageAggregate] - alerts: List[Alert] = [] + alerts: Optional[List[Alert]] = [] class ConversationSummary(pydantic.BaseModel): diff --git a/src/codegate/api/v1_processing.py b/src/codegate/api/v1_processing.py index 10f42075b..07e5acaa6 100644 --- a/src/codegate/api/v1_processing.py +++ b/src/codegate/api/v1_processing.py @@ -202,15 +202,10 @@ async def _get_partial_question_answer( model=model, token_usage=token_usage, provider_type=provider ) - alerts: List[v1_models.Alert] = [ - v1_models.Alert.from_db_model(db_alert) for db_alert in row.alerts - ] - return PartialQuestionAnswer( partial_questions=request_message, answer=output_message, model_token_usage=model_token_usage, - alerts=alerts, ) @@ -374,7 +369,7 @@ async def match_conversations( for group in grouped_partial_questions: questions_answers: List[QuestionAnswer] = [] token_usage_agg = TokenUsageAggregate(tokens_by_model={}, token_usage=TokenUsage()) - alerts: List[v1_models.Alert] = [] + first_partial_qa = None for partial_question in sorted(group, key=lambda x: x.timestamp): # Partial questions don't contain the answer, so we need to find the corresponding @@ -398,8 +393,6 @@ async def match_conversations( qa = _get_question_answer_from_partial(selected_partial_qa) qa.question.message = 
parse_question_answer(qa.question.message) questions_answers.append(qa) - deduped_alerts = await remove_duplicate_alerts(selected_partial_qa.alerts) - alerts.extend(deduped_alerts) token_usage_agg.add_model_token_usage(selected_partial_qa.model_token_usage) # if we have a conversation with at least one question and answer @@ -413,7 +406,6 @@ async def match_conversations( chat_id=first_partial_qa.partial_questions.message_id, conversation_timestamp=first_partial_qa.partial_questions.timestamp, token_usage_agg=token_usage_agg, - alerts=alerts, ) for qa in questions_answers: map_q_id_to_conversation[qa.question.message_id] = conversation diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index a24866683..fb102766f 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -854,8 +854,11 @@ async def get_total_messages_count_by_workspace_id( logger.error(f"Failed to fetch message count. Error: {e}") return 0 # Return 0 in case of failure - async def get_alerts_by_workspace( - self, workspace_id: str, trigger_category: Optional[str] = None + async def get_alerts_by_workspace_or_prompt_id( + self, + workspace_id: str, + prompt_id: Optional[str] = None, + trigger_category: Optional[str] = None, ) -> List[Alert]: sql = text( """ @@ -874,6 +877,10 @@ async def get_alerts_by_workspace( ) conditions = {"workspace_id": workspace_id} + if prompt_id: + sql = text(sql.text + " AND a.prompt_id = :prompt_id") + conditions["prompt_id"] = prompt_id + if trigger_category: sql = text(sql.text + " AND a.trigger_category = :trigger_category") conditions["trigger_category"] = trigger_category diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index beeef8718..7f8ef4348 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -352,4 +352,3 @@ class GetMessagesRow(BaseModel): output_tokens: Optional[int] input_cost: Optional[float] output_cost: Optional[float] - alerts: List[Alert] = [] From e96638982f4242ab786c071be49ca9fe8f4d95e7 Mon Sep 17 00:00:00 2001 From: Yolanda Robla Date: Tue, 11 Mar 2025 11:12:24 +0100 Subject: [PATCH 8/9] fix querying prompts without alerts --- src/codegate/db/connection.py | 123 ++++++++++++++++++++-------------- 1 file changed, 72 insertions(+), 51 deletions(-) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index fb102766f..02a587e84 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -4,7 +4,7 @@ import sqlite3 import uuid from pathlib import Path -from typing import List, Optional, Type +from typing import List, Optional, Tuple, Type import numpy as np import sqlite_vec_sl_tmp @@ -716,6 +716,61 @@ async def get_prompts_with_output( ) return prompts + def _build_prompt_query( + self, + base_query: str, + workspace_id: str, + filter_by_ids: Optional[List[str]] = None, + filter_by_alert_trigger_categories: Optional[List[str]] = None, + filter_by_alert_trigger_types: Optional[List[str]] = None, + offset: Optional[int] = None, + page_size: Optional[int] = None, + ) -> Tuple[str, dict]: + """ + Helper method to construct SQL query and conditions for prompts based on filters. + + Args: + base_query: The base SQL query string with a placeholder for filter conditions. + workspace_id: The ID of the workspace to fetch prompts from. + filter_by_ids: Optional list of prompt IDs to filter by. + filter_by_alert_trigger_categories: Optional list of alert categories to filter by. + filter_by_alert_trigger_types: Optional list of alert trigger types to filter by. 
+ offset: Number of records to skip (for pagination). + page_size: Number of records per page. + + Returns: + A tuple containing the formatted SQL query string and a dictionary of conditions. + """ + conditions = {"workspace_id": workspace_id} + filter_conditions = [] + + if filter_by_alert_trigger_categories: + filter_conditions.append( + "AND (a.trigger_category IN :filter_by_alert_trigger_categories OR a.trigger_category IS NULL)" + ) + conditions["filter_by_alert_trigger_categories"] = filter_by_alert_trigger_categories + + if filter_by_alert_trigger_types: + filter_conditions.append( + "AND EXISTS (SELECT 1 FROM alerts a2 WHERE a2.prompt_id = p.id AND a2.trigger_type IN :filter_by_alert_trigger_types)" + ) + conditions["filter_by_alert_trigger_types"] = filter_by_alert_trigger_types + + if filter_by_ids: + filter_conditions.append("AND p.id IN :filter_by_ids") + conditions["filter_by_ids"] = filter_by_ids + + if offset is not None: + conditions["offset"] = offset + + if page_size is not None: + conditions["page_size"] = page_size + + filter_clause = " ".join(filter_conditions) + query = base_query.format(filter_conditions=filter_clause) + + return query, conditions + async def get_prompts( self, workspace_id: str, @@ -749,39 +804,19 @@ async def get_prompts( ORDER BY p.timestamp DESC LIMIT :page_size OFFSET :offset """ - # Build conditions and filters - conditions = { - "workspace_id": workspace_id, - "page_size": page_size, - "offset": offset, - } - - # Conditionally add filter clauses and conditions - filter_conditions = [] - - if filter_by_alert_trigger_categories: - filter_conditions.append( - "AND a.trigger_category IN :filter_by_alert_trigger_categories" - ) - conditions["filter_by_alert_trigger_categories"] = filter_by_alert_trigger_categories - - if filter_by_alert_trigger_types: - filter_conditions.append( - "AND EXISTS (SELECT 1 FROM alerts a2 WHERE a2.prompt_id = p.id AND a2.trigger_type IN :filter_by_alert_trigger_types)" # noqa: E501 - ) - conditions["filter_by_alert_trigger_types"] = filter_by_alert_trigger_types - - if filter_by_ids: - filter_conditions.append("AND p.id IN :filter_by_ids") - conditions["filter_by_ids"] = filter_by_ids - - filter_clause = " ".join(filter_conditions) - query = base_query.format(filter_conditions=filter_clause) + query, conditions = self._build_prompt_query( + base_query, + workspace_id, + filter_by_ids, + filter_by_alert_trigger_categories, + filter_by_alert_trigger_types, + offset, + page_size, + ) sql = text(query) # Bind optional params - if filter_by_alert_trigger_categories: sql = sql.bindparams(bindparam("filter_by_alert_trigger_categories", expanding=True)) if filter_by_alert_trigger_types: @@ -813,28 +848,14 @@ async def get_total_messages_count_by_workspace_id( WHERE p.workspace_id = :workspace_id {filter_conditions} """ - conditions = {"workspace_id": workspace_id} - filter_conditions = [] - - if filter_by_alert_trigger_categories: - filter_conditions.append( - "AND a.trigger_category IN :filter_by_alert_trigger_categories" - ) - conditions["filter_by_alert_trigger_categories"] = filter_by_alert_trigger_categories - - if filter_by_alert_trigger_types: - filter_conditions.append( - "AND EXISTS (SELECT 1 FROM alerts a2 WHERE " - "a2.prompt_id = p.id AND a2.trigger_type IN :filter_by_alert_trigger_types)" - ) - conditions["filter_by_alert_trigger_types"] = filter_by_alert_trigger_types - if filter_by_ids: - filter_conditions.append("AND p.id IN :filter_by_ids") - conditions["filter_by_ids"] = filter_by_ids - - filter_clause = " 
".join(filter_conditions) - query = base_query.format(filter_conditions=filter_clause) + query, conditions = self._build_prompt_query( + base_query, + workspace_id, + filter_by_ids, + filter_by_alert_trigger_categories, + filter_by_alert_trigger_types, + ) sql = text(query) # Bind optional params From 90bb9e2a954dc7a4a67d2f2b9904560ee02bdf9a Mon Sep 17 00:00:00 2001 From: Yolanda Robla Date: Tue, 11 Mar 2025 11:37:15 +0100 Subject: [PATCH 9/9] clean message in list --- src/codegate/api/v1.py | 5 +++-- src/codegate/db/connection.py | 6 ++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index f1742f421..5db3ae549 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -510,9 +510,10 @@ async def get_workspace_messages( logger.warning(f"Skipping prompt {prompt.id}. No messages found") continue - # message is just the first entry in the request + # message is just the first entry in the request, cleaned properly + message = v1_processing.parse_question_answer(messages[0]) message_obj = v1_models.ChatMessage( - message=messages[0], timestamp=prompt.timestamp, message_id=prompt.id + message=message, timestamp=prompt.timestamp, message_id=prompt.id ) # count total alerts for the prompt diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 02a587e84..915c4251c 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -746,13 +746,15 @@ def _build_prompt_query( if filter_by_alert_trigger_categories: filter_conditions.append( - "AND (a.trigger_category IN :filter_by_alert_trigger_categories OR a.trigger_category IS NULL)" + """AND (a.trigger_category IN :filter_by_alert_trigger_categories + OR a.trigger_category IS NULL)""" ) conditions["filter_by_alert_trigger_categories"] = filter_by_alert_trigger_categories if filter_by_alert_trigger_types: filter_conditions.append( - "AND EXISTS (SELECT 1 FROM alerts a2 WHERE a2.prompt_id = p.id AND a2.trigger_type IN :filter_by_alert_trigger_types)" + """AND EXISTS (SELECT 1 FROM alerts a2 WHERE + a2.prompt_id = p.id AND a2.trigger_type IN :filter_by_alert_trigger_types)""" ) conditions["filter_by_alert_trigger_types"] = filter_by_alert_trigger_types