From 74ca157b5a3f2abd960ce5cf51607ba6a5898a60 Mon Sep 17 00:00:00 2001
From: Yolanda Robla
Date: Mon, 3 Mar 2025 11:24:14 +0100
Subject: [PATCH 01/14] feat: add pagination to alerts and messages endpoints

Modify the API to add pagination to these endpoints, so that the
browser can render the results faster

Closes: #1020
---
 src/codegate/api/v1.py        | 47 +++++++++++++++++++++++++++--------
 src/codegate/config.py        |  3 +++
 src/codegate/db/connection.py | 42 ++++++++++++++++++++++++-------
 3 files changed, 73 insertions(+), 19 deletions(-)

diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py
index ebd9be79..aae8c41f 100644
--- a/src/codegate/api/v1.py
+++ b/src/codegate/api/v1.py
@@ -1,13 +1,14 @@
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 from uuid import UUID
 
 import requests
 import structlog
-from fastapi import APIRouter, Depends, HTTPException, Response
+from fastapi import APIRouter, Depends, HTTPException, Query, Response
 from fastapi.responses import StreamingResponse
 from fastapi.routing import APIRoute
 from pydantic import BaseModel, ValidationError
 
+from codegate.config import API_DEFAULT_PAGE_SIZE, API_MAX_PAGE_SIZE
 import codegate.muxing.models as mux_models
 from codegate import __version__
 from codegate.api import v1_models, v1_processing
@@ -378,7 +379,11 @@ async def hard_delete_workspace(workspace_name: str):
     tags=["Workspaces"],
     generate_unique_id_function=uniq_name,
 )
-async def get_workspace_alerts(workspace_name: str) -> List[Optional[v1_models.AlertConversation]]:
+async def get_workspace_alerts(
+    workspace_name: str,
+    page: int = Query(1, ge=1),
+    page_size: int = Query(API_DEFAULT_PAGE_SIZE, ge=1, le=API_MAX_PAGE_SIZE),
+) -> Dict[str, Any]:
     """Get alerts for a workspace."""
     try:
         ws = await wscrud.get_workspace_by_name(workspace_name)
@@ -388,13 +393,35 @@ async def get_workspace_alerts(workspace_name: str) -> List[Optional[v1_models.A
         logger.exception("Error while getting workspace")
         raise HTTPException(status_code=500, detail="Internal server error")
 
-    try:
-        alerts = await dbreader.get_alerts_by_workspace(ws.id, AlertSeverity.CRITICAL.value)
-        prompts_outputs = await dbreader.get_prompts_with_output(ws.id)
-        return await v1_processing.parse_get_alert_conversation(alerts, prompts_outputs)
-    except Exception:
-        logger.exception("Error while getting alerts and messages")
-        raise HTTPException(status_code=500, detail="Internal server error")
+    total_alerts = 0
+    fetched_alerts = []
+    offset = (page - 1) * page_size
+    batch_size = page_size * 2  # fetch more alerts per batch to allow deduplication
+
+    while len(fetched_alerts) < page_size:
+        alerts_batch, total_alerts = await dbreader.get_alerts_by_workspace(
+            ws.id, AlertSeverity.CRITICAL.value, page_size, offset
+        )
+        if not alerts_batch:
+            break
+
+        dedup_alerts = await v1_processing.remove_duplicate_alerts(alerts_batch)
+        fetched_alerts.extend(dedup_alerts)
+        offset += batch_size
+
+    final_alerts = fetched_alerts[:page_size]
+    prompt_ids = list({alert.prompt_id for alert in final_alerts if alert.prompt_id})
+    prompts_outputs = await dbreader.get_prompts_with_output(prompt_ids)
+    alert_conversations = await v1_processing.parse_get_alert_conversation(
+        final_alerts, prompts_outputs
+    )
+    return {
+        "page": page,
+        "page_size": page_size,
+        "total_alerts": total_alerts,
+        "total_pages": (total_alerts + page_size - 1) // page_size,
+        "alerts": alert_conversations,
+    }
 
 
 @v1.get(
diff --git a/src/codegate/config.py b/src/codegate/config.py
index 11cd96bf..8f9a15c5 100644
--- a/src/codegate/config.py
+++ b/src/codegate/config.py
@@ -25,6 +25,9 @@
     "llamacpp": "./codegate_volume/models",  # Default LlamaCpp model path
 }
 
+API_DEFAULT_PAGE_SIZE = 50
+API_MAX_PAGE_SIZE = 100
+
 
 @dataclass
 class Config:
diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py
index 2d56fccd..a6a43418 100644
--- a/src/codegate/db/connection.py
+++ b/src/codegate/db/connection.py
@@ -2,7 +2,7 @@
 import json
 import uuid
 from pathlib import Path
-from typing import Dict, List, Optional, Type
+from typing import Dict, List, Optional, Tuple, Type
 
 import structlog
 from alembic import command as alembic_command
@@ -13,6 +13,7 @@
 from sqlalchemy.exc import IntegrityError, OperationalError
 from sqlalchemy.ext.asyncio import create_async_engine
 
+from codegate.config import API_DEFAULT_PAGE_SIZE
 from codegate.db.fim_cache import FimCache
 from codegate.db.models import (
     ActiveWorkspace,
@@ -569,7 +570,10 @@ async def _exec_select_conditions_to_pydantic(
                 raise e
             return None
 
-    async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithOutputsRow]:
+    async def get_prompts_with_output(self, prompt_ids: List[str]) -> List[GetPromptWithOutputsRow]:
+        if not prompt_ids:
+            return []
+
         sql = text(
             """
             SELECT
@@ -583,11 +587,11 @@ async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithO
                 o.output_cost
             FROM prompts p
             LEFT JOIN outputs o ON p.id = o.prompt_id
-            WHERE p.workspace_id = :workspace_id
+            WHERE p.id IN :prompt_ids
             ORDER BY o.timestamp DESC
             """
         )
-        conditions = {"workspace_id": workpace_id}
+        conditions = {"prompt_ids": tuple(prompt_ids)}
         prompts = await self._exec_select_conditions_to_pydantic(
             GetPromptWithOutputsRow, sql, conditions, should_raise=True
         )
@@ -656,8 +660,12 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id(
         return list(prompts_dict.values())
 
     async def get_alerts_by_workspace(
-        self, workspace_id: str, trigger_category: Optional[str] = None
-    ) -> List[Alert]:
+        self,
+        workspace_id: str,
+        trigger_category: Optional[str] = None,
+        limit: int = API_DEFAULT_PAGE_SIZE,
+        offset: int = 0,
+    ) -> Tuple[List[Alert], int]:
         sql = text(
             """
             SELECT
@@ -679,12 +687,28 @@ async def get_alerts_by_workspace(
             sql = text(sql.text + " AND a.trigger_category = :trigger_category")
             conditions["trigger_category"] = trigger_category
 
-        sql = text(sql.text + " ORDER BY a.timestamp DESC")
+        sql = text(sql.text + " ORDER BY a.timestamp DESC LIMIT :limit OFFSET :offset")
+        conditions["limit"] = limit
+        conditions["offset"] = offset
 
-        prompts = await self._exec_select_conditions_to_pydantic(
+        alerts = await self._exec_select_conditions_to_pydantic(
             Alert, sql, conditions, should_raise=True
         )
-        return prompts
+
+        # Count total alerts for pagination
+        count_sql = text(
+            """
+            SELECT COUNT(*)
+            FROM alerts a
+            INNER JOIN prompts p ON p.id = a.prompt_id
+            WHERE p.workspace_id = :workspace_id
+            """
+        )
+        if trigger_category:
+            count_sql = text(count_sql.text + " AND a.trigger_category = :trigger_category")
+
+        total_alerts = await self._exec_select_count(count_sql, conditions)
+        return alerts, total_alerts
 
     async def get_workspaces(self) -> List[WorkspaceWithSessionInfo]:
         sql = text(
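NOTE: the pagination contract introduced above can be exercised end to end with FastAPI's
TestClient. This is an illustrative sketch only: it assumes the v1 router is mounted at the
application root (the tests added in the next patch do the same) and that a workspace named
"default" exists.

    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    from codegate.api.v1 import v1

    app = FastAPI()
    app.include_router(v1)
    client = TestClient(app)

    # Request the first page of critical alerts, 10 per page.
    resp = client.get("/workspaces/default/alerts?page=1&page_size=10")
    body = resp.json()
    # The endpoint now returns pagination metadata alongside the alerts.
    assert set(body) == {"page", "page_size", "total_alerts", "total_pages", "alerts"}
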
From bba154c4b5312758da3c9df50e03882eecb50d39 Mon Sep 17 00:00:00 2001
From: Yolanda Robla
Date: Mon, 3 Mar 2025 12:07:26 +0100
Subject: [PATCH 02/14] add tests

---
 src/codegate/workspaces/crud.py |   4 +
 tests/api/test_v1_api.py        | 173 ++++++++++++++++++++++++++++++++
 2 files changed, 177 insertions(+)
 create mode 100644 tests/api/test_v1_api.py

diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py
index a81426a8..0d0248ef 100644
--- a/src/codegate/workspaces/crud.py
+++ b/src/codegate/workspaces/crud.py
@@ -213,8 +213,12 @@ async def hard_delete_workspace(self, workspace_name: str):
         return
 
     async def get_workspace_by_name(self, workspace_name: str) -> db_models.WorkspaceRow:
+        print("i get by name")
         workspace = await self._db_reader.get_workspace_by_name(workspace_name)
+        print("workspace is")
+        print(workspace)
         if not workspace:
+            print("in not exist")
             raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.")
         return workspace
 
diff --git a/tests/api/test_v1_api.py b/tests/api/test_v1_api.py
new file mode 100644
index 00000000..93a6fa74
--- /dev/null
+++ b/tests/api/test_v1_api.py
@@ -0,0 +1,173 @@
+import pytest
+from unittest.mock import AsyncMock, patch
+from fastapi.testclient import TestClient
+from fastapi import FastAPI
+from codegate.api.v1 import v1  # Import the APIRouter instance
+from codegate.db.models import Alert, AlertSeverity, GetPromptWithOutputsRow
+from codegate.workspaces.crud import WorkspaceDoesNotExistError
+
+# Create a FastAPI test app and include the APIRouter
+app = FastAPI()
+app.include_router(v1)
+client = TestClient(app)
+
+
+@pytest.fixture
+def mock_ws():
+    """Mock workspace object"""
+    ws = AsyncMock()
+    ws.id = "test_workspace_id"
+    return ws
+
+
+@pytest.fixture
+def mock_alerts():
+    """Mock alerts list"""
+    return [
+        Alert(
+            id="1",
+            prompt_id="p1",
+            code_snippet="code",
+            trigger_string="error",
+            trigger_type="type",
+            trigger_category=AlertSeverity.CRITICAL.value,
+            timestamp="2024-03-03T12:34:56Z",
+        ),
+        Alert(
+            id="2",
+            prompt_id="p2",
+            code_snippet="code2",
+            trigger_string="error2",
+            trigger_type="type2",
+            trigger_category=AlertSeverity.CRITICAL.value,
+            timestamp="2024-03-03T12:35:56Z",
+        ),
+    ]
+
+
+@pytest.fixture
+def mock_prompts():
+    """Mock prompts output list"""
+    return [
+        GetPromptWithOutputsRow(
+            id="p1",
+            timestamp="2024-03-03T12:34:56Z",
+            provider="provider",
+            request="req",
+            type="type",
+            output_id="o1",
+            output="output",
+            output_timestamp="2024-03-03T12:35:56Z",
+            input_tokens=10,
+            output_tokens=15,
+            input_cost=0.01,
+            output_cost=0.02,
+        ),
+        GetPromptWithOutputsRow(
+            id="p2",
+            timestamp="2024-03-03T12:36:56Z",
+            provider="provider2",
+            request="req2",
+            type="type2",
+            output_id="o2",
+            output="output2",
+            output_timestamp="2024-03-03T12:37:56Z",
+            input_tokens=20,
+            output_tokens=25,
+            input_cost=0.02,
+            output_cost=0.03,
+        ),
+    ]
+
+
+@pytest.mark.asyncio
+async def test_get_workspace_alerts_not_found():
+    """Test when workspace does not exist (404 error)"""
+    with patch(
+        "codegate.workspaces.crud.WorkspaceCrud.get_workspace_by_name",
+        side_effect=WorkspaceDoesNotExistError("Workspace does not exist"),
+    ):
+        response = client.get("/workspaces/non_existent_workspace/alerts")
+        assert response.status_code == 404
+        assert response.json()["detail"] == "Workspace does not exist"
+
+
+@pytest.mark.asyncio
+async def test_get_workspace_alerts_internal_server_error():
+    """Test when an internal error occurs (500 error)"""
+    with patch(
+        "codegate.workspaces.crud.WorkspaceCrud.get_workspace_by_name",
+        side_effect=Exception("Unexpected error"),
+    ):
+        response = client.get("/workspaces/test_workspace/alerts")
+        assert response.status_code == 500
+        assert response.json()["detail"] == "Internal server error"
+
+
+@pytest.mark.asyncio
+async def test_get_workspace_alerts_empty(mock_ws):
+    """Test when no alerts are found (empty list)"""
+    with (
+        patch("codegate.workspaces.crud.WorkspaceCrud.get_workspace_by_name", return_value=mock_ws),
+        patch("codegate.db.connection.DbReader.get_alerts_by_workspace", return_value=([], 0)),
+    ):
+
+        response = client.get("/workspaces/test_workspace/alerts?page=1&page_size=10")
+        assert response.status_code == 200
+        assert response.json() == {
+            "page": 1,
+            "page_size": 10,
+            "total_alerts": 0,
+            "total_pages": 0,
+            "alerts": [],
+        }
+
+
+@pytest.mark.asyncio
+async def test_get_workspace_alerts_with_results(mock_ws, mock_alerts, mock_prompts):
+    """Test when valid alerts are retrieved with pagination"""
+    with (
+        patch("codegate.workspaces.crud.WorkspaceCrud.get_workspace_by_name", return_value=mock_ws),
+        patch(
+            "codegate.db.connection.DbReader.get_alerts_by_workspace",
+            return_value=(mock_alerts, len(mock_alerts)),
+        ),
+        patch("codegate.db.connection.DbReader.get_prompts_with_output", return_value=mock_prompts),
+        patch("codegate.api.v1_processing.remove_duplicate_alerts", return_value=mock_alerts),
+        patch("codegate.api.v1_processing.parse_get_alert_conversation", return_value=mock_alerts),
+    ):
+
+        response = client.get("/workspaces/test_workspace/alerts?page=1&page_size=2")
+        assert response.status_code == 200
+        data = response.json()
+        assert data["page"] == 1
+        assert data["page_size"] == 2
+        assert data["total_alerts"] == 2
+        assert data["total_pages"] == 1
+        assert len(data["alerts"]) == 2
+
+
+@pytest.mark.asyncio
+async def test_get_workspace_alerts_deduplication(mock_ws, mock_alerts, mock_prompts):
+    """Test that alerts are fetched iteratively when deduplication reduces results"""
+    dedup_alerts = [mock_alerts[0]]  # Simulate deduplication removing one alert
+
+    with (
+        patch("codegate.workspaces.crud.WorkspaceCrud.get_workspace_by_name", return_value=mock_ws),
+        patch(
+            "codegate.db.connection.DbReader.get_alerts_by_workspace",
+            side_effect=[(mock_alerts, 2), (mock_alerts, 2)],
+        ),
+        patch("codegate.db.connection.DbReader.get_prompts_with_output", return_value=mock_prompts),
+        patch("codegate.api.v1_processing.remove_duplicate_alerts", return_value=dedup_alerts),
+        patch("codegate.api.v1_processing.parse_get_alert_conversation", return_value=dedup_alerts),
+    ):
+
+        response = client.get("/workspaces/test_workspace/alerts?page=1&page_size=2")
+        assert response.status_code == 200
+        data = response.json()
+        assert data["page"] == 1
+        assert data["page_size"] == 2
+        assert data["total_alerts"] == 2  # Total alerts remain the same
+        assert data["total_pages"] == 1
+        assert len(data["alerts"]) == 1  # Only one alert left after deduplication

From 3c3eaa8dc012ca49ce61a635f89e6315c7bac55f Mon Sep 17 00:00:00 2001
From: Yolanda Robla
Date: Mon, 3 Mar 2025 12:58:10 +0100
Subject: [PATCH 03/14] fix database queries

---
 src/codegate/api/v1.py          | 15 ++++++-------
 src/codegate/db/connection.py   | 42 ++++++++++++++++++------------------
 src/codegate/workspaces/crud.py |  4 ----
 tests/api/test_v1_api.py        | 11 +---------
 4 files changed, 24 insertions(+), 44 deletions(-)

diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py
index aae8c41f..c83e4c42 100644
--- a/src/codegate/api/v1.py
+++ b/src/codegate/api/v1.py
@@ -392,14 +392,12 @@ async def get_workspace_alerts(
     except Exception:
         logger.exception("Error while getting workspace")
         raise HTTPException(status_code=500, detail="Internal server error")
-
-    total_alerts = 0
-    fetched_alerts = []
+
     offset = (page - 1) * page_size
-    batch_size = page_size * 2  # fetch more alerts per batch to allow deduplication
+    fetched_alerts = []
 
     while len(fetched_alerts) < page_size:
-        alerts_batch, total_alerts = await dbreader.get_alerts_by_workspace(
+        alerts_batch = await dbreader.get_alerts_by_workspace(
             ws.id, AlertSeverity.CRITICAL.value, page_size, offset
         )
         if not alerts_batch:
@@ -407,9 +405,11 @@ async def get_workspace_alerts(
 
         dedup_alerts = await v1_processing.remove_duplicate_alerts(alerts_batch)
         fetched_alerts.extend(dedup_alerts)
-        offset += batch_size
+        offset += page_size
 
     final_alerts = fetched_alerts[:page_size]
+    total_alerts = len(fetched_alerts)
+
     prompt_ids = list({alert.prompt_id for alert in final_alerts if alert.prompt_id})
     prompts_outputs = await dbreader.get_prompts_with_output(prompt_ids)
     alert_conversations = await v1_processing.parse_get_alert_conversation(
@@ -417,9 +417,6 @@ async def get_workspace_alerts(
     )
     return {
         "page": page,
-        "page_size": page_size,
-        "total_alerts": total_alerts,
-        "total_pages": (total_alerts + page_size - 1) // page_size,
         "alerts": alert_conversations,
     }
 
diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py
index a6a43418..8c853872 100644
--- a/src/codegate/db/connection.py
+++ b/src/codegate/db/connection.py
@@ -8,7 +8,7 @@
 from alembic import command as alembic_command
 from alembic.config import Config as AlembicConfig
 from pydantic import BaseModel
-from sqlalchemy import CursorResult, TextClause, event, text
+from sqlalchemy import CursorResult, TextClause, bindparam, event, text
 from sqlalchemy.engine import Engine
 from sqlalchemy.exc import IntegrityError, OperationalError
 from sqlalchemy.ext.asyncio import create_async_engine
@@ -587,11 +587,12 @@ async def get_prompts_with_output(self, prompt_ids: List[str]) -> List[GetPrompt
                 o.output_cost
             FROM prompts p
             LEFT JOIN outputs o ON p.id = o.prompt_id
-            WHERE p.id IN :prompt_ids
+            WHERE (p.id IN :prompt_ids)
             ORDER BY o.timestamp DESC
             """
-        )
-        conditions = {"prompt_ids": tuple(prompt_ids)}
+        ).bindparams(bindparam("prompt_ids", expanding=True))
+
+        conditions = {"prompt_ids": prompt_ids if prompt_ids else None}
         prompts = await self._exec_select_conditions_to_pydantic(
             GetPromptWithOutputsRow, sql, conditions, should_raise=True
         )
@@ -659,13 +660,23 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id(
 
         return list(prompts_dict.values())
 
+    async def _exec_select_count(self, sql_command: str, conditions: dict) -> int:
+        """Executes a COUNT SQL command and returns an integer result."""
+        async with self._async_db_engine.begin() as conn:
+            try:
+                result = await conn.execute(text(sql_command), conditions)
+                return result.scalar_one()  # Ensures it returns exactly one integer value
+            except Exception as e:
+                logger.error(f"Failed to execute COUNT query.", error=str(e))
+                return 0  # Return 0 in case of failure to avoid crashes
+
     async def get_alerts_by_workspace(
         self,
         workspace_id: str,
         trigger_category: Optional[str] = None,
         limit: int = API_DEFAULT_PAGE_SIZE,
         offset: int = 0,
-    ) -> Tuple[List[Alert], int]:
+    ) -> List[Alert]:
         sql = text(
             """
             SELECT
@@ -691,25 +702,10 @@ async def get_alerts_by_workspace(
             sql = text(sql.text + " AND a.trigger_category = :trigger_category")
             conditions["trigger_category"] = trigger_category
 
         sql = text(sql.text + " ORDER BY a.timestamp DESC LIMIT :limit OFFSET :offset")
         conditions["limit"] = limit
         conditions["offset"] = offset
 
-        alerts = await self._exec_select_conditions_to_pydantic(
+        return await self._exec_select_conditions_to_pydantic(
             Alert, sql, conditions, should_raise=True
         )
-
-        # Count total alerts for pagination
-        count_sql = text(
-            """
-            SELECT COUNT(*)
-            FROM alerts a
-            INNER JOIN prompts p ON p.id = a.prompt_id
-            WHERE p.workspace_id = :workspace_id
-            """
-        )
-        if trigger_category:
-            count_sql = text(count_sql.text + " AND a.trigger_category = :trigger_category")
-
-        total_alerts = await self._exec_select_count(count_sql, conditions)
-        return alerts, total_alerts
 
     async def get_workspaces(self) -> List[WorkspaceWithSessionInfo]:
         sql = text(
             """
diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py
index 0d0248ef..a81426a8 100644
--- a/src/codegate/workspaces/crud.py
+++ b/src/codegate/workspaces/crud.py
@@ -213,12 +213,8 @@ async def hard_delete_workspace(self, workspace_name: str):
         return
 
     async def get_workspace_by_name(self, workspace_name: str) -> db_models.WorkspaceRow:
-        print("i get by name")
         workspace = await self._db_reader.get_workspace_by_name(workspace_name)
-        print("workspace is")
-        print(workspace)
         if not workspace:
-            print("in not exist")
             raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.")
         return workspace
 
diff --git a/tests/api/test_v1_api.py b/tests/api/test_v1_api.py
index 93a6fa74..dab1fffe 100644
--- a/tests/api/test_v1_api.py
+++ b/tests/api/test_v1_api.py
@@ -109,16 +109,13 @@ async def test_get_workspace_alerts_empty(mock_ws):
     """Test when no alerts are found (empty list)"""
     with (
         patch("codegate.workspaces.crud.WorkspaceCrud.get_workspace_by_name", return_value=mock_ws),
-        patch("codegate.db.connection.DbReader.get_alerts_by_workspace", return_value=([], 0)),
+        patch("codegate.db.connection.DbReader.get_alerts_by_workspace", return_value=[]),
     ):
 
         response = client.get("/workspaces/test_workspace/alerts?page=1&page_size=10")
         assert response.status_code == 200
         assert response.json() == {
             "page": 1,
-            "page_size": 10,
-            "total_alerts": 0,
-            "total_pages": 0,
             "alerts": [],
         }
 
@@ -141,9 +138,6 @@ async def test_get_workspace_alerts_with_results(mock_ws, mock_alerts, mock_prom
         assert response.status_code == 200
         data = response.json()
         assert data["page"] == 1
-        assert data["page_size"] == 2
-        assert data["total_alerts"] == 2
-        assert data["total_pages"] == 1
         assert len(data["alerts"]) == 2
 
@@ -167,7 +161,4 @@ async def test_get_workspace_alerts_deduplication(mock_ws, mock_alerts, mock_prom
         assert response.status_code == 200
         data = response.json()
         assert data["page"] == 1
-        assert data["page_size"] == 2
-        assert data["total_alerts"] == 2  # Total alerts remain the same
-        assert data["total_pages"] == 1
         assert len(data["alerts"]) == 1  # Only one alert left after deduplication
From 887eb6c304fe02c1dd8485c0f65e70a45a71521e Mon Sep 17 00:00:00 2001
From: Yolanda Robla
Date: Mon, 3 Mar 2025 17:00:17 +0100
Subject: [PATCH 04/14] add pagination for messages

---
 src/codegate/api/v1.py        |  44 ++++-----
 src/codegate/db/connection.py |  31 ++++---
 tests/api/test_v1_api.py      | 164 ----------------------------------
 3 files changed, 44 insertions(+), 195 deletions(-)
 delete mode 100644 tests/api/test_v1_api.py

diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py
index c83e4c42..eaec1735 100644
--- a/src/codegate/api/v1.py
+++ b/src/codegate/api/v1.py
@@ -383,7 +383,7 @@ async def get_workspace_alerts(
     workspace_name: str,
     page: int = Query(1, ge=1),
     page_size: int = Query(API_DEFAULT_PAGE_SIZE, ge=1, le=API_MAX_PAGE_SIZE),
-) -> Dict[str, Any]:
+) -> List[v1_models.AlertConversation]:
     """Get alerts for a workspace."""
     try:
         ws = await wscrud.get_workspace_by_name(workspace_name)
@@ -392,12 +392,12 @@ async def get_workspace_alerts(
         logger.exception("Error while getting workspace")
         raise HTTPException(status_code=500, detail="Internal server error")
-
+
     offset = (page - 1) * page_size
     fetched_alerts = []
 
     while len(fetched_alerts) < page_size:
-        alerts_batch = await dbreader.get_alerts_by_workspace(
+        alerts_batch = await dbreader.get_alerts_by_workspace(
             ws.id, AlertSeverity.CRITICAL.value, page_size, offset
         )
         if not alerts_batch:
             break
 
         dedup_alerts = await v1_processing.remove_duplicate_alerts(alerts_batch)
         fetched_alerts.extend(dedup_alerts)
         offset += page_size
 
     final_alerts = fetched_alerts[:page_size]
-    total_alerts = len(fetched_alerts)
 
     prompt_ids = list({alert.prompt_id for alert in final_alerts if alert.prompt_id})
     prompts_outputs = await dbreader.get_prompts_with_output(prompt_ids)
     alert_conversations = await v1_processing.parse_get_alert_conversation(
         final_alerts, prompts_outputs
     )
-    return {
-        "page": page,
-        "alerts": alert_conversations,
-    }
+    return alert_conversations
 
 
 @v1.get(
@@ -426,7 +422,11 @@ async def get_workspace_alerts(
     tags=["Workspaces"],
     generate_unique_id_function=uniq_name,
 )
-async def get_workspace_messages(workspace_name: str) -> List[v1_models.Conversation]:
+async def get_workspace_messages(
+    workspace_name: str,
+    page: int = Query(1, ge=1),
+    page_size: int = Query(API_DEFAULT_PAGE_SIZE, ge=1, le=API_MAX_PAGE_SIZE),
+) -> List[v1_models.Conversation]:
     """Get messages for a workspace."""
     try:
         ws = await wscrud.get_workspace_by_name(workspace_name)
@@ -436,19 +436,23 @@ async def get_workspace_messages(
         logger.exception("Error while getting workspace")
         raise HTTPException(status_code=500, detail="Internal server error")
 
-    try:
-        prompts_with_output_alerts_usage = (
-            await dbreader.get_prompts_with_output_alerts_usage_by_workspace_id(
-                ws.id, AlertSeverity.CRITICAL.value
-            )
+    offset = (page - 1) * page_size
+    fetched_messages = []
+
+    while len(fetched_messages) < page_size:
+        messages_batch = await dbreader.get_prompts_with_output_alerts_usage_by_workspace_id(
+            ws.id, AlertSeverity.CRITICAL.value, page_size, offset
         )
-        conversations, _ = await v1_processing.parse_messages_in_conversations(
-            prompts_with_output_alerts_usage
+        if not messages_batch:
+            break
+        parsed_conversations, _ = await v1_processing.parse_messages_in_conversations(
+            messages_batch
         )
-        return conversations
-    except Exception:
-        logger.exception("Error while getting messages")
-        raise HTTPException(status_code=500, detail="Internal server error")
+        fetched_messages.extend(parsed_conversations)
+        offset += page_size
+
+    final_messages = fetched_messages[:page_size]
+    return final_messages
 
 
 @v1.get(
diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py
index 8c853872..ddc61a6c 100644
--- a/src/codegate/db/connection.py
+++ b/src/codegate/db/connection.py
@@ -590,7 +590,7 @@ async def get_prompts_with_output(self, prompt_ids: List[str]) -> List[GetPrompt
             WHERE (p.id IN :prompt_ids)
             ORDER BY o.timestamp DESC
             """
-        ).bindparams(bindparam("prompt_ids", expanding=True)) 
+        ).bindparams(bindparam("prompt_ids", expanding=True))
 
         conditions = {"prompt_ids": prompt_ids if prompt_ids else None}
         prompts = await self._exec_select_conditions_to_pydantic(
@@ -599,12 +599,15 @@ async def get_prompts_with_output(self, prompt_ids: List[str]) -> List[GetPrompt
         return prompts
 
     async def get_prompts_with_output_alerts_usage_by_workspace_id(
-        self, workspace_id: str, trigger_category: Optional[str] = None
+        self,
+        workspace_id: str,
+        trigger_category: Optional[str] = None,
+        limit: int = API_DEFAULT_PAGE_SIZE,
+        offset: int = 0,
     ) -> List[GetPromptWithOutputsRow]:
         """
         Get all prompts with their outputs, alerts and token usage by workspace_id.
         """
-
         sql = text(
             """
             SELECT
@@ -615,20 +618,26 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id(
             LEFT JOIN outputs o ON p.id = o.prompt_id
             LEFT JOIN alerts a ON p.id = a.prompt_id
             WHERE p.workspace_id = :workspace_id
-            AND (a.trigger_category = :trigger_category OR a.trigger_category is NULL)
-            ORDER BY o.timestamp DESC, a.timestamp DESC
             """  # noqa: E501
         )
-        # If trigger category is None we want to get all alerts
-        trigger_category = trigger_category if trigger_category else "%"
-        conditions = {"workspace_id": workspace_id, "trigger_category": trigger_category}
-        rows: List[IntermediatePromptWithOutputUsageAlerts] = (
+        conditions = {"workspace_id": workspace_id}
+        if trigger_category:
+            sql = text(sql.text + " AND a.trigger_category = :trigger_category")
+            conditions["trigger_category"] = trigger_category
+
+        sql = text(
+            sql.text + " ORDER BY o.timestamp DESC, a.timestamp DESC LIMIT :limit OFFSET :offset"
+        )
+        conditions["limit"] = limit
+        conditions["offset"] = offset
+
+        fetched_rows: List[IntermediatePromptWithOutputUsageAlerts] = (
             await self._exec_select_conditions_to_pydantic(
                 IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True
             )
         )
         prompts_dict: Dict[str, GetPromptWithOutputsRow] = {}
-        for row in rows:
+        for row in fetched_rows:
             prompt_id = row.prompt_id
             if prompt_id not in prompts_dict:
                 prompts_dict[prompt_id] = GetPromptWithOutputsRow(
diff --git a/tests/api/test_v1_api.py b/tests/api/test_v1_api.py
deleted file mode 100644
index dab1fffe..00000000
--- a/tests/api/test_v1_api.py
+++ /dev/null
@@ -1,164 +0,0 @@
-import pytest
-from unittest.mock import AsyncMock, patch
-from fastapi.testclient import TestClient
-from fastapi import FastAPI
-from codegate.api.v1 import v1  # Import the APIRouter instance
-from codegate.db.models import Alert, AlertSeverity, GetPromptWithOutputsRow
-from codegate.workspaces.crud import WorkspaceDoesNotExistError
-
-# Create a FastAPI test app and include the APIRouter
-app = FastAPI()
-app.include_router(v1)
-client = TestClient(app)
-
-
-@pytest.fixture
-def mock_ws():
-    """Mock workspace object"""
-    ws = AsyncMock()
-    ws.id = "test_workspace_id"
-    return ws
-
-
-@pytest.fixture
-def mock_alerts():
-    """Mock alerts list"""
-    return [
-        Alert(
-            id="1",
-            prompt_id="p1",
-            code_snippet="code",
-            trigger_string="error",
-            trigger_type="type",
-            trigger_category=AlertSeverity.CRITICAL.value,
-            timestamp="2024-03-03T12:34:56Z",
-        ),
-        Alert(
-            id="2",
-            prompt_id="p2",
-            code_snippet="code2",
-            trigger_string="error2",
-            trigger_type="type2",
-            trigger_category=AlertSeverity.CRITICAL.value,
-            timestamp="2024-03-03T12:35:56Z",
-        ),
-    ]
-
-
-@pytest.fixture
-def mock_prompts():
-    """Mock prompts output list"""
-    return [
-        GetPromptWithOutputsRow(
-            id="p1",
-            timestamp="2024-03-03T12:34:56Z",
-            provider="provider",
-            request="req",
-            type="type",
-            output_id="o1",
-            output="output",
-            output_timestamp="2024-03-03T12:35:56Z",
-            input_tokens=10,
-            output_tokens=15,
-            input_cost=0.01,
-            output_cost=0.02,
-        ),
-        GetPromptWithOutputsRow(
-            id="p2",
-            timestamp="2024-03-03T12:36:56Z",
-            provider="provider2",
-            request="req2",
-            type="type2",
-            output_id="o2",
-            output="output2",
-            output_timestamp="2024-03-03T12:37:56Z",
-            input_tokens=20,
-            output_tokens=25,
-            input_cost=0.02,
-            output_cost=0.03,
-        ),
-    ]
-
-
-@pytest.mark.asyncio
-async def test_get_workspace_alerts_not_found():
-    """Test when workspace does not exist (404 error)"""
-    with patch(
-        "codegate.workspaces.crud.WorkspaceCrud.get_workspace_by_name",
-        side_effect=WorkspaceDoesNotExistError("Workspace does not exist"),
-    ):
-        response = client.get("/workspaces/non_existent_workspace/alerts")
-        assert response.status_code == 404
-        assert response.json()["detail"] == "Workspace does not exist"
-
-
-@pytest.mark.asyncio
-async def test_get_workspace_alerts_internal_server_error():
-    """Test when an internal error occurs (500 error)"""
-    with patch(
-        "codegate.workspaces.crud.WorkspaceCrud.get_workspace_by_name",
-        side_effect=Exception("Unexpected error"),
-    ):
-        response = client.get("/workspaces/test_workspace/alerts")
-        assert response.status_code == 500
-        assert response.json()["detail"] == "Internal server error"
-
-
-@pytest.mark.asyncio
-async def test_get_workspace_alerts_empty(mock_ws):
-    """Test when no alerts are found (empty list)"""
-    with (
-        patch("codegate.workspaces.crud.WorkspaceCrud.get_workspace_by_name", return_value=mock_ws),
-        patch("codegate.db.connection.DbReader.get_alerts_by_workspace", return_value=[]),
-    ):
-
-        response = client.get("/workspaces/test_workspace/alerts?page=1&page_size=10")
-        assert response.status_code == 200
-        assert response.json() == {
-            "page": 1,
-            "alerts": [],
-        }
-
-
-@pytest.mark.asyncio
-async def test_get_workspace_alerts_with_results(mock_ws, mock_alerts, mock_prompts):
-    """Test when valid alerts are retrieved with pagination"""
-    with (
-        patch("codegate.workspaces.crud.WorkspaceCrud.get_workspace_by_name", return_value=mock_ws),
-        patch(
-            "codegate.db.connection.DbReader.get_alerts_by_workspace",
-            return_value=(mock_alerts, len(mock_alerts)),
-        ),
-        patch("codegate.db.connection.DbReader.get_prompts_with_output", return_value=mock_prompts),
-        patch("codegate.api.v1_processing.remove_duplicate_alerts", return_value=mock_alerts),
-        patch("codegate.api.v1_processing.parse_get_alert_conversation", return_value=mock_alerts),
-    ):
-
-        response = client.get("/workspaces/test_workspace/alerts?page=1&page_size=2")
-        assert response.status_code == 200
-        data = response.json()
-        assert data["page"] == 1
-        assert len(data["alerts"]) == 2
-
-
-@pytest.mark.asyncio
-async def test_get_workspace_alerts_deduplication(mock_ws, mock_alerts, mock_prompts):
-    """Test that alerts are fetched iteratively when deduplication reduces results"""
-    dedup_alerts = [mock_alerts[0]]  # Simulate deduplication removing one alert
-
-    with (
-        patch("codegate.workspaces.crud.WorkspaceCrud.get_workspace_by_name", return_value=mock_ws),
-        patch(
-            "codegate.db.connection.DbReader.get_alerts_by_workspace",
-            side_effect=[(mock_alerts, 2), (mock_alerts, 2)],
-        ),
-        patch("codegate.db.connection.DbReader.get_prompts_with_output", return_value=mock_prompts),
-        patch("codegate.api.v1_processing.remove_duplicate_alerts", return_value=dedup_alerts),
-        patch("codegate.api.v1_processing.parse_get_alert_conversation", return_value=dedup_alerts),
-    ):
-
-        response = client.get("/workspaces/test_workspace/alerts?page=1&page_size=2")
-        assert response.status_code == 200
-        data = response.json()
-        assert data["page"] == 1
-        assert len(data["alerts"]) == 1  # Only one alert left after deduplication
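NOTE: after this patch both endpoints share the same pagination shape: fetch page-sized
batches, deduplicate, and keep going until a full page is accumulated or the source runs
dry. The pattern in isolation (a sketch with synchronous stand-ins; `fetch_batch` and
`dedupe` are hypothetical callables, not codegate APIs):

    def paginate_with_dedup(fetch_batch, dedupe, page, page_size):
        """Collect one page of results, re-fetching to compensate for
        items dropped by deduplication."""
        offset = (page - 1) * page_size
        collected = []
        while len(collected) < page_size:
            batch = fetch_batch(limit=page_size, offset=offset)
            if not batch:
                break  # no more rows in the source
            collected.extend(dedupe(batch))
            offset += page_size
        return collected[:page_size]
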
From 4627bca557a481ed36fb904d11beedf0de2c5b24 Mon Sep 17 00:00:00 2001
From: Yolanda Robla
Date: Tue, 4 Mar 2025 10:10:38 +0100
Subject: [PATCH 05/14] revert changes in alerts

---
 src/codegate/api/v1.py        | 35 ++++++++---------------------------
 src/codegate/db/connection.py | 25 ++++++++-----------------
 2 files changed, 16 insertions(+), 44 deletions(-)

diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py
index eaec1735..08825354 100644
--- a/src/codegate/api/v1.py
+++ b/src/codegate/api/v1.py
@@ -379,11 +379,7 @@ async def hard_delete_workspace(workspace_name: str):
     tags=["Workspaces"],
     generate_unique_id_function=uniq_name,
 )
-async def get_workspace_alerts(
-    workspace_name: str,
-    page: int = Query(1, ge=1),
-    page_size: int = Query(API_DEFAULT_PAGE_SIZE, ge=1, le=API_MAX_PAGE_SIZE),
-) -> List[v1_models.AlertConversation]:
+async def get_workspace_alerts(workspace_name: str) -> List[Optional[v1_models.AlertConversation]]:
     """Get alerts for a workspace."""
     try:
         ws = await wscrud.get_workspace_by_name(workspace_name)
@@ -393,28 +389,13 @@ async def get_workspace_alerts(
         logger.exception("Error while getting workspace")
         raise HTTPException(status_code=500, detail="Internal server error")
 
-    offset = (page - 1) * page_size
-    fetched_alerts = []
-
-    while len(fetched_alerts) < page_size:
-        alerts_batch = await dbreader.get_alerts_by_workspace(
-            ws.id, AlertSeverity.CRITICAL.value, page_size, offset
-        )
-        if not alerts_batch:
-            break
-
-        dedup_alerts = await v1_processing.remove_duplicate_alerts(alerts_batch)
-        fetched_alerts.extend(dedup_alerts)
-        offset += page_size
-
-    final_alerts = fetched_alerts[:page_size]
-
-    prompt_ids = list({alert.prompt_id for alert in final_alerts if alert.prompt_id})
-    prompts_outputs = await dbreader.get_prompts_with_output(prompt_ids)
-    alert_conversations = await v1_processing.parse_get_alert_conversation(
-        final_alerts, prompts_outputs
-    )
-    return alert_conversations
+    try:
+        alerts = await dbreader.get_alerts_by_workspace(ws.id, AlertSeverity.CRITICAL.value)
+        prompts_outputs = await dbreader.get_prompts_with_output(ws.id)
+        return await v1_processing.parse_get_alert_conversation(alerts, prompts_outputs)
+    except Exception:
+        logger.exception("Error while getting alerts and messages")
+        raise HTTPException(status_code=500, detail="Internal server error")
 
 
 @v1.get(
diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py
index ddc61a6c..e23e7ffc 100644
--- a/src/codegate/db/connection.py
+++ b/src/codegate/db/connection.py
@@ -570,10 +570,7 @@ async def _exec_select_conditions_to_pydantic(
                 raise e
             return None
 
-    async def get_prompts_with_output(self, prompt_ids: List[str]) -> List[GetPromptWithOutputsRow]:
-        if not prompt_ids:
-            return []
-
+    async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithOutputsRow]:
         sql = text(
             """
             SELECT
@@ -587,12 +584,11 @@ async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithO
                 o.output_cost
             FROM prompts p
             LEFT JOIN outputs o ON p.id = o.prompt_id
-            WHERE (p.id IN :prompt_ids)
+            WHERE p.workspace_id = :workspace_id
             ORDER BY o.timestamp DESC
             """
-        ).bindparams(bindparam("prompt_ids", expanding=True))
-
-        conditions = {"prompt_ids": prompt_ids if prompt_ids else None}
+        )
+        conditions = {"workspace_id": workpace_id}
         prompts = await self._exec_select_conditions_to_pydantic(
             GetPromptWithOutputsRow, sql, conditions, should_raise=True
         )
@@ -680,11 +676,7 @@ async def _exec_select_count(self, sql_command: str, conditions: dict) -> int:
                 return 0  # Return 0 in case of failure to avoid crashes
 
     async def get_alerts_by_workspace(
-        self,
-        workspace_id: str,
-        trigger_category: Optional[str] = None,
-        limit: int = API_DEFAULT_PAGE_SIZE,
-        offset: int = 0,
+        self, workspace_id: str, trigger_category: Optional[str] = None
     ) -> List[Alert]:
         sql = text(
             """
             SELECT
@@ -707,13 +699,12 @@ async def get_alerts_by_workspace(
             sql = text(sql.text + " AND a.trigger_category = :trigger_category")
             conditions["trigger_category"] = trigger_category
 
-        sql = text(sql.text + " ORDER BY a.timestamp DESC LIMIT :limit OFFSET :offset")
-        conditions["limit"] = limit
-        conditions["offset"] = offset
+        sql = text(sql.text + " ORDER BY a.timestamp DESC")
 
-        return await self._exec_select_conditions_to_pydantic(
+        prompts = await self._exec_select_conditions_to_pydantic(
             Alert, sql, conditions, should_raise=True
         )
+        return prompts
 
     async def get_workspaces(self) -> List[WorkspaceWithSessionInfo]:
         sql = text(

From 0fd96dc75fc6711ffa94533b25137cc2eab27b41 Mon Sep 17 00:00:00 2001
From: Yolanda Robla
Date: Tue, 4 Mar 2025 11:12:10 +0100
Subject: [PATCH 06/14] return object in messages endpoint

---
 src/codegate/api/v1.py        | 14 ++++++++++++--
 src/codegate/api/v1_models.py |  7 +++++++
 src/codegate/db/connection.py | 29 +++++++++++++++++++++++------
 3 files changed, 42 insertions(+), 8 deletions(-)

diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py
index 08825354..43ec24fb 100644
--- a/src/codegate/api/v1.py
+++ b/src/codegate/api/v1.py
@@ -407,7 +407,7 @@ async def get_workspace_messages(
     workspace_name: str,
     page: int = Query(1, ge=1),
     page_size: int = Query(API_DEFAULT_PAGE_SIZE, ge=1, le=API_MAX_PAGE_SIZE),
-) -> List[v1_models.Conversation]:
+) -> v1_models.PaginatedMessagesResponse:
     """Get messages for a workspace."""
     try:
         ws = await wscrud.get_workspace_by_name(workspace_name)
@@ -433,7 +433,17 @@ async def get_workspace_messages(
         offset += page_size
 
     final_messages = fetched_messages[:page_size]
-    return final_messages
+
+    # Fetch total message count
+    total_count = await dbreader.get_total_messages_count_by_workspace_id(
+        ws.id, AlertSeverity.CRITICAL.value
+    )
+    return v1_models.PaginatedMessagesResponse(
+        data=final_messages,
+        limit=page_size,
+        offset=(page - 1) * page_size,
+        total=total_count,
+    )
 
 
 @v1.get(
diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py
index c608484c..8ce9e2bc 100644
--- a/src/codegate/api/v1_models.py
+++ b/src/codegate/api/v1_models.py
@@ -312,3 +312,10 @@ class ModelByProvider(pydantic.BaseModel):
 
     def __str__(self):
         return f"{self.provider_name} / {self.name}"
+
+
+class PaginatedMessagesResponse(pydantic.BaseModel):
+    data: List[Conversation]
+    limit: int
+    offset: int
+    total: int
diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py
index e23e7ffc..2584bc9b 100644
--- a/src/codegate/db/connection.py
+++ b/src/codegate/db/connection.py
@@ -665,15 +665,32 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id(
 
         return list(prompts_dict.values())
 
-    async def _exec_select_count(self, sql_command: str, conditions: dict) -> int:
-        """Executes a COUNT SQL command and returns an integer result."""
+    async def get_total_messages_count_by_workspace_id(
+        self, workspace_id: str, trigger_category: Optional[str] = None
+    ) -> int:
+        """Get total count of messages for a given workspace_id, considering trigger_category if provided."""
+        sql = text(
+            """
+            SELECT COUNT(*)
+            FROM prompts p
+            LEFT JOIN alerts a ON p.id = a.prompt_id
+            WHERE p.workspace_id = :workspace_id
+            """
+        )
+        conditions = {"workspace_id": workspace_id}
+
+        if trigger_category:
a.trigger_category = :trigger_category") + conditions["trigger_category"] = trigger_category + async with self._async_db_engine.begin() as conn: try: - result = await conn.execute(text(sql_command), conditions) - return result.scalar_one() # Ensures it returns exactly one integer value + result = await conn.execute(sql, conditions) + count = result.scalar() # Fetches the integer result directly + return count or 0 # Ensure it returns an integer except Exception as e: - logger.error(f"Failed to execute COUNT query.", error=str(e)) - return 0 # Return 0 in case of failure to avoid crashes + logger.error(f"Failed to fetch message count. Error: {e}") + return 0 # Return 0 in case of failure async def get_alerts_by_workspace( self, workspace_id: str, trigger_category: Optional[str] = None From c41b4ad447159a9407c663613a57c588607fa834 Mon Sep 17 00:00:00 2001 From: Yolanda Robla Date: Tue, 4 Mar 2025 11:22:08 +0100 Subject: [PATCH 07/14] fix lint --- src/codegate/db/connection.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 087c9766..32d6e24a 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -627,10 +627,10 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id( conditions["limit"] = limit conditions["offset"] = offset - fetched_rows: List[ - IntermediatePromptWithOutputUsageAlerts - ] = await self._exec_select_conditions_to_pydantic( - IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True + fetched_rows: List[IntermediatePromptWithOutputUsageAlerts] = ( + await self._exec_select_conditions_to_pydantic( + IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True + ) ) prompts_dict: Dict[str, GetPromptWithOutputsRow] = {} for row in fetched_rows: @@ -671,7 +671,7 @@ async def get_total_messages_count_by_workspace_id( """Get total count of messages for a given workspace_id, considering trigger_category.""" sql = text( """ - SELECT COUNT(*) + SELECT COUNT(*) FROM prompts p LEFT JOIN alerts a ON p.id = a.prompt_id WHERE p.workspace_id = :workspace_id From a6c9e76bb8b4c77852ee2f8f105ddd4056e59dda Mon Sep 17 00:00:00 2001 From: Alex McGovern <58784948+alex-mcgovern@users.noreply.github.com> Date: Tue, 4 Mar 2025 15:51:00 +0100 Subject: [PATCH 08/14] feat: filter messages by ID (#1206) * feat: filter messages by ID * lint fix * fix: use `.bindparams` for `filter_by_ids` --- src/codegate/api/v1.py | 3 +- src/codegate/db/connection.py | 55 ++++++++++++++++++++++------------- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index 96672615..00fe45c2 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -407,6 +407,7 @@ async def get_workspace_messages( workspace_name: str, page: int = Query(1, ge=1), page_size: int = Query(API_DEFAULT_PAGE_SIZE, ge=1, le=API_MAX_PAGE_SIZE), + filter_by_ids: Optional[List[str]] = Query(None), ) -> v1_models.PaginatedMessagesResponse: """Get messages for a workspace.""" try: @@ -422,7 +423,7 @@ async def get_workspace_messages( while len(fetched_messages) < page_size: messages_batch = await dbreader.get_prompts_with_output_alerts_usage_by_workspace_id( - ws.id, AlertSeverity.CRITICAL.value, page_size, offset + ws.id, AlertSeverity.CRITICAL.value, page_size, offset, filter_by_ids ) if not messages_batch: break diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 32d6e24a..9dc67b8b 100644 --- 
a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -8,7 +8,7 @@ from alembic import command as alembic_command from alembic.config import Config as AlembicConfig from pydantic import BaseModel -from sqlalchemy import CursorResult, TextClause, event, text +from sqlalchemy import CursorResult, TextClause, bindparam, event, text from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.ext.asyncio import create_async_engine @@ -600,38 +600,51 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id( trigger_category: Optional[str] = None, limit: int = API_DEFAULT_PAGE_SIZE, offset: int = 0, + filter_by_ids: Optional[List[str]] = None, ) -> List[GetPromptWithOutputsRow]: """ Get all prompts with their outputs, alerts and token usage by workspace_id. """ - sql = text( - """ + + base_query = """ SELECT - p.id as prompt_id, p.timestamp as prompt_timestamp, p.provider, p.request, p.type, - o.id as output_id, o.output, o.timestamp as output_timestamp, o.input_tokens, o.output_tokens, o.input_cost, o.output_cost, - a.id as alert_id, a.code_snippet, a.trigger_string, a.trigger_type, a.trigger_category, a.timestamp as alert_timestamp + p.id as prompt_id, p.timestamp as prompt_timestamp, p.provider, p.request, p.type, + o.id as output_id, o.output, o.timestamp as output_timestamp, o.input_tokens, o.output_tokens, o.input_cost, o.output_cost, + a.id as alert_id, a.code_snippet, a.trigger_string, a.trigger_type, a.trigger_category, a.timestamp as alert_timestamp FROM prompts p LEFT JOIN outputs o ON p.id = o.prompt_id LEFT JOIN alerts a ON p.id = a.prompt_id WHERE p.workspace_id = :workspace_id - """ # noqa: E501 - ) - conditions = {"workspace_id": workspace_id} - if trigger_category: - sql = text(sql.text + " AND a.trigger_category = :trigger_category") - conditions["trigger_category"] = trigger_category + AND (:trigger_category IS NULL OR a.trigger_category = :trigger_category) + """ # noqa: E501 - sql = text( - sql.text + " ORDER BY o.timestamp DESC, a.timestamp DESC LIMIT :limit OFFSET :offset" - ) - conditions["limit"] = limit - conditions["offset"] = offset + if filter_by_ids: + base_query += " AND p.id IN :filter_ids" - fetched_rows: List[IntermediatePromptWithOutputUsageAlerts] = ( - await self._exec_select_conditions_to_pydantic( - IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True - ) + base_query += """ + ORDER BY o.timestamp DESC, a.timestamp DESC + LIMIT :limit OFFSET :offset + """ + + sql = text(base_query) + + conditions = { + "workspace_id": workspace_id, + "trigger_category": trigger_category, + "limit": limit, + "offset": offset, + } + + if filter_by_ids: + sql = sql.bindparams(bindparam("filter_ids", expanding=True)) + conditions["filter_ids"] = filter_by_ids + + fetched_rows: List[ + IntermediatePromptWithOutputUsageAlerts + ] = await self._exec_select_conditions_to_pydantic( + IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True ) + prompts_dict: Dict[str, GetPromptWithOutputsRow] = {} for row in fetched_rows: prompt_id = row.prompt_id From 204c0a4b74448fb0b3e91a248894bf4fb6266566 Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Tue, 4 Mar 2025 15:21:35 +0000 Subject: [PATCH 09/14] fix: deduplicate response --- src/codegate/api/v1.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index 00fe45c2..ef019473 100644 --- a/src/codegate/api/v1.py +++ 
b/src/codegate/api/v1.py @@ -419,19 +419,36 @@ async def get_workspace_messages( raise HTTPException(status_code=500, detail="Internal server error") offset = (page - 1) * page_size - fetched_messages = [] + fetched_messages: List[v1_models.Conversation] = [] while len(fetched_messages) < page_size: messages_batch = await dbreader.get_prompts_with_output_alerts_usage_by_workspace_id( - ws.id, AlertSeverity.CRITICAL.value, page_size, offset, filter_by_ids + ws.id, + AlertSeverity.CRITICAL.value, + page_size - len(fetched_messages), + offset, + filter_by_ids, ) if not messages_batch: break parsed_conversations, _ = await v1_processing.parse_messages_in_conversations( messages_batch ) - fetched_messages.extend(parsed_conversations) - offset += page_size + + for conversation in parsed_conversations: + existing_conversation = next( + (msg for msg in fetched_messages if msg.chat_id == conversation.chat_id), None + ) + if existing_conversation: + existing_conversation.alerts.extend( + alert + for alert in conversation.alerts + if alert not in existing_conversation.alerts + ) + else: + fetched_messages.append(conversation) + + offset += len(messages_batch) final_messages = fetched_messages[:page_size] From 96f6eb15c505ffadf240e2f75443d27f125e9518 Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Tue, 4 Mar 2025 15:26:13 +0000 Subject: [PATCH 10/14] lint fix --- src/codegate/db/connection.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 9dc67b8b..8b301404 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -639,10 +639,10 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id( sql = sql.bindparams(bindparam("filter_ids", expanding=True)) conditions["filter_ids"] = filter_by_ids - fetched_rows: List[ - IntermediatePromptWithOutputUsageAlerts - ] = await self._exec_select_conditions_to_pydantic( - IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True + fetched_rows: List[IntermediatePromptWithOutputUsageAlerts] = ( + await self._exec_select_conditions_to_pydantic( + IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True + ) ) prompts_dict: Dict[str, GetPromptWithOutputsRow] = {} From 8d58b65f945f2b8fc4b5c4f2bd2a4c6eb1568def Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Tue, 4 Mar 2025 18:20:56 +0000 Subject: [PATCH 11/14] feat(list messages): refactor pagination & add filtering --- src/codegate/api/v1.py | 74 ++++----- src/codegate/api/v1_processing.py | 13 +- src/codegate/db/connection.py | 157 ++++++++++++------ src/codegate/db/models.py | 10 +- .../codegate_context_retriever/codegate.py | 4 +- src/codegate/pipeline/pii/analyzer.py | 6 +- src/codegate/pipeline/pii/pii.py | 7 +- src/codegate/pipeline/secrets/secrets.py | 10 +- .../pipeline/sensitive_data/manager.py | 3 +- .../pipeline/sensitive_data/session_store.py | 2 +- tests/api/test_v1_processing.py | 16 +- tests/pipeline/pii/test_analyzer.py | 1 - tests/pipeline/pii/test_pi.py | 3 +- tests/pipeline/sensitive_data/test_manager.py | 3 +- .../sensitive_data/test_session_store.py | 2 +- tests/test_server.py | 1 - 16 files changed, 184 insertions(+), 128 deletions(-) diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index 5942610c..ad349e56 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -13,7 +13,7 @@ from codegate.api import v1_models, v1_processing from codegate.config import API_DEFAULT_PAGE_SIZE, API_MAX_PAGE_SIZE from 
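NOTE: the refactor in the next patch paginates over distinct prompts first and only then
joins outputs and alerts. Applying LIMIT/OFFSET to the joined rows directly would count
(prompt x alert) combinations instead of conversations. The shape of the query, reduced to
a sketch (the constant name is illustrative; table and column names follow the schema used
throughout this series):

    # Page over prompts, not over joined rows, so that page_size counts
    # conversations rather than join rows.
    PAGINATED_MESSAGES_SQL = """
    WITH filtered_prompts AS (
        SELECT p.*
        FROM prompts p
        WHERE p.workspace_id = :workspace_id
        ORDER BY p.timestamp DESC
        LIMIT :page_size OFFSET :offset
    )
    SELECT fp.id, o.output, a.trigger_category
    FROM filtered_prompts fp
    LEFT JOIN outputs o ON fp.id = o.prompt_id
    LEFT JOIN alerts a ON fp.id = a.prompt_id
    """
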
From 8d58b65f945f2b8fc4b5c4f2bd2a4c6eb1568def Mon Sep 17 00:00:00 2001
From: alex-mcgovern
Date: Tue, 4 Mar 2025 18:20:56 +0000
Subject: [PATCH 11/14] feat(list messages): refactor pagination & add filtering

---
 src/codegate/api/v1.py                        |  74 ++++-----
 src/codegate/api/v1_processing.py             |  13 +-
 src/codegate/db/connection.py                 | 157 ++++++++++++------
 src/codegate/db/models.py                     |  10 +-
 .../codegate_context_retriever/codegate.py    |   4 +-
 src/codegate/pipeline/pii/analyzer.py         |   6 +-
 src/codegate/pipeline/pii/pii.py              |   7 +-
 src/codegate/pipeline/secrets/secrets.py      |  10 +-
 .../pipeline/sensitive_data/manager.py        |   3 +-
 .../pipeline/sensitive_data/session_store.py  |   2 +-
 tests/api/test_v1_processing.py               |  16 +-
 tests/pipeline/pii/test_analyzer.py           |   1 -
 tests/pipeline/pii/test_pi.py                 |   3 +-
 tests/pipeline/sensitive_data/test_manager.py |   3 +-
 .../sensitive_data/test_session_store.py      |   2 +-
 tests/test_server.py                          |   1 -
 16 files changed, 184 insertions(+), 128 deletions(-)

diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py
index 5942610c..ad349e56 100644
--- a/src/codegate/api/v1.py
+++ b/src/codegate/api/v1.py
@@ -13,7 +13,7 @@
 from codegate.api import v1_models, v1_processing
 from codegate.config import API_DEFAULT_PAGE_SIZE, API_MAX_PAGE_SIZE
 from codegate.db.connection import AlreadyExistsError, DbReader
-from codegate.db.models import AlertSeverity, WorkspaceWithModel
+from codegate.db.models import AlertSeverity, AlertTriggerType, WorkspaceWithModel
 from codegate.providers import crud as provendcrud
 from codegate.workspaces import crud
 
@@ -435,6 +435,7 @@ async def get_workspace_messages(
     page: int = Query(1, ge=1),
     page_size: int = Query(API_DEFAULT_PAGE_SIZE, ge=1, le=API_MAX_PAGE_SIZE),
     filter_by_ids: Optional[List[str]] = Query(None),
+    filter_by_alert_trigger_types: Optional[List[AlertTriggerType]] = Query(None),
 ) -> v1_models.PaginatedMessagesResponse:
     """Get messages for a workspace."""
     try:
@@ -448,47 +449,40 @@ async def get_workspace_messages(
     offset = (page - 1) * page_size
     fetched_messages: List[v1_models.Conversation] = []
 
-    while len(fetched_messages) < page_size:
-        messages_batch = await dbreader.get_prompts_with_output_alerts_usage_by_workspace_id(
-            ws.id,
-            AlertSeverity.CRITICAL.value,
-            page_size - len(fetched_messages),
-            offset,
-            filter_by_ids,
+    try:
+        while len(fetched_messages) < page_size:
+            messages_batch = await dbreader.get_messages(
+                ws.id,
+                offset,
+                page_size,
+                filter_by_ids,
+                list([AlertSeverity.CRITICAL.value]),  # TODO: Configurable severity
+                filter_by_alert_trigger_types,
+            )
+            if not messages_batch:
+                break
+            parsed_conversations, _ = await v1_processing.parse_messages_in_conversations(
+                messages_batch
+            )
+            fetched_messages.extend(parsed_conversations)
+
+            offset += len(messages_batch)
+
+        final_messages = fetched_messages[:page_size]
+
+        # Fetch total message count
+        total_count = await dbreader.get_total_messages_count_by_workspace_id(
+            ws.id, AlertSeverity.CRITICAL.value
         )
-        if not messages_batch:
-            break
-        parsed_conversations, _ = await v1_processing.parse_messages_in_conversations(
-            messages_batch
+        return v1_models.PaginatedMessagesResponse(
+            data=final_messages,
+            limit=page_size,
+            offset=(page - 1) * page_size,
+            total=total_count,
         )
-
-        for conversation in parsed_conversations:
-            existing_conversation = next(
-                (msg for msg in fetched_messages if msg.chat_id == conversation.chat_id), None
-            )
-            if existing_conversation:
-                existing_conversation.alerts.extend(
-                    alert
-                    for alert in conversation.alerts
-                    if alert not in existing_conversation.alerts
-                )
-            else:
-                fetched_messages.append(conversation)
-
-        offset += len(messages_batch)
-
-    final_messages = fetched_messages[:page_size]
-
-    # Fetch total message count
-    total_count = await dbreader.get_total_messages_count_by_workspace_id(
-        ws.id, AlertSeverity.CRITICAL.value
-    )
-    return v1_models.PaginatedMessagesResponse(
-        data=final_messages,
-        limit=page_size,
-        offset=(page - 1) * page_size,
-        total=total_count,
-    )
+    except Exception:
+        logger.exception("Error while getting messages")
+        raise HTTPException(status_code=500, detail="Internal server error")
 
 
 @v1.get(
diff --git a/src/codegate/api/v1_processing.py b/src/codegate/api/v1_processing.py
index 10f42075..0f5f013d 100644
--- a/src/codegate/api/v1_processing.py
+++ b/src/codegate/api/v1_processing.py
@@ -21,7 +21,7 @@
     TokenUsageByModel,
 )
 from codegate.db.connection import alert_queue
-from codegate.db.models import Alert, GetPromptWithOutputsRow, TokenUsage
+from codegate.db.models import Alert, GetMessagesRow, TokenUsage
 
 logger = structlog.get_logger("codegate")
 
@@ -152,7 +152,7 @@ def _parse_single_output(single_output: dict) -> str:
 
 
 async def _get_partial_question_answer(
-    row: GetPromptWithOutputsRow,
+    row: GetMessagesRow,
 ) -> Optional[PartialQuestionAnswer]:
     """
     Parse a row from the get_prompt_with_outputs query and return a PartialConversation
@@ -423,7 +423,7 @@ async def match_conversations(
 
 
 async def _process_prompt_output_to_partial_qa(
-    prompts_outputs: List[GetPromptWithOutputsRow],
+    prompts_outputs: List[GetMessagesRow],
 ) -> List[PartialQuestionAnswer]:
     """
     Process the prompts and outputs to PartialQuestionAnswer objects.
@@ -435,7 +435,7 @@ async def _process_prompt_output_to_partial_qa(
 
 
 async def parse_messages_in_conversations(
-    prompts_outputs: List[GetPromptWithOutputsRow],
+    prompts_outputs: List[GetMessagesRow],
 ) -> Tuple[List[Conversation], Dict[str, Conversation]]:
     """
     Get all the messages from the database and return them as a list of conversations.
@@ -477,7 +477,7 @@ async def parse_row_alert_conversation(
 
 async def parse_get_alert_conversation(
     alerts: List[Alert],
-    prompts_outputs: List[GetPromptWithOutputsRow],
+    prompts_outputs: List[GetMessagesRow],
 ) -> List[AlertConversation]:
     """
     Parse a list of rows from the get_alerts_with_prompt_and_output query and return a list of
@@ -496,7 +496,7 @@ async def parse_get_alert_conversation(
 
 
 async def parse_workspace_token_usage(
-    prompts_outputs: List[GetPromptWithOutputsRow],
+    prompts_outputs: List[GetMessagesRow],
 ) -> TokenUsageAggregate:
     """
     Parse the token usage from the workspace.
@@ -515,7 +515,6 @@ async def remove_duplicate_alerts(alerts: List[v1_models.Alert]) -> List[v1_mode
     for alert in sorted(
         alerts, key=lambda x: x.timestamp, reverse=True
     ):  # Sort alerts by timestamp descending
-
         # Handle trigger_string based on its type
         trigger_string_content = ""
         if isinstance(alert.trigger_string, dict):
diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py
index 7d6e3e1d..21394d21 100644
--- a/src/codegate/db/connection.py
+++ b/src/codegate/db/connection.py
@@ -16,14 +16,14 @@
 from sqlalchemy.exc import IntegrityError, OperationalError
 from sqlalchemy.ext.asyncio import create_async_engine
 
-from codegate.config import API_DEFAULT_PAGE_SIZE
 from codegate.db.fim_cache import FimCache
 from codegate.db.models import (
     ActiveWorkspace,
     Alert,
-    GetPromptWithOutputsRow,
+    AlertTriggerType,
+    GetMessagesRow,
     GetWorkspaceByNameConditions,
-    IntermediatePromptWithOutputUsageAlerts,
+    IntermediateMessagesRow,
     MuxRule,
     Output,
     Persona,
@@ -630,7 +630,7 @@ async def _exec_vec_db_query_to_pydantic(
             conn.close()
         return results
 
-    async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithOutputsRow]:
+    async def get_prompts_with_output(self, workpace_id: str) -> List[GetMessagesRow]:
         sql = text(
             """
             SELECT
@@ -650,66 +650,121 @@ async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithO
         )
         conditions = {"workspace_id": workpace_id}
         prompts = await self._exec_select_conditions_to_pydantic(
-            GetPromptWithOutputsRow, sql, conditions, should_raise=True
+            GetMessagesRow, sql, conditions, should_raise=True
         )
         return prompts
 
-    async def get_prompts_with_output_alerts_usage_by_workspace_id(
+    async def get_messages(
         self,
         workspace_id: str,
-        trigger_category: Optional[str] = None,
-        limit: int = API_DEFAULT_PAGE_SIZE,
         offset: int = 0,
+        page_size: int = 20,
         filter_by_ids: Optional[List[str]] = None,
-    ) -> List[GetPromptWithOutputsRow]:
+        filter_by_alert_trigger_categories: Optional[List[str]] = None,
+        filter_by_alert_trigger_types: Optional[List[str]] = None,
+    ) -> List[GetMessagesRow]:
         """
-        Get all prompts with their outputs, alerts and token usage by workspace_id.
+        Retrieve prompts with their associated outputs and alerts, with filtering and pagination.
+
+        Args:
+            workspace_id: The ID of the workspace to fetch prompts from
+            offset: Number of records to skip (for pagination)
+            page_size: Number of records per page
+            filter_by_ids: Optional list of prompt IDs to filter by
+            filter_by_alert_trigger_categories: Optional list of alert categories to filter by
+            filter_by_alert_trigger_types: Optional list of alert trigger types to filter by
+
+        Returns:
+            List of GetPromptWithOutputsRow containing prompts, outputs, and alerts
         """
-
+        # Build base query
         base_query = """
+            WITH filtered_prompts AS (
+                SELECT DISTINCT p.*
+                FROM prompts p
+                LEFT JOIN alerts a ON p.id = a.prompt_id
+                WHERE p.workspace_id = :workspace_id
+                {filter_conditions}
+                ORDER BY p.timestamp DESC
+                LIMIT :page_size OFFSET :offset
+            )
             SELECT
-            p.id as prompt_id, p.timestamp as prompt_timestamp, p.provider, p.request, p.type,
-            o.id as output_id, o.output, o.timestamp as output_timestamp, o.input_tokens, o.output_tokens, o.input_cost, o.output_cost,
-            a.id as alert_id, a.code_snippet, a.trigger_string, a.trigger_type, a.trigger_category, a.timestamp as alert_timestamp
-            FROM prompts p
+                p.id as prompt_id,
+                p.timestamp as prompt_timestamp,
+                p.provider,
+                p.request,
+                p.type,
+                o.id as output_id,
+                o.output,
+                o.timestamp as output_timestamp,
+                o.input_tokens,
+                o.output_tokens,
+                o.input_cost,
+                o.output_cost,
+                a.id as alert_id,
+                a.code_snippet,
+                a.trigger_string,
+                a.trigger_type,
+                a.trigger_category,
+                a.timestamp as alert_timestamp
+            FROM filtered_prompts p
             LEFT JOIN outputs o ON p.id = o.prompt_id
             LEFT JOIN alerts a ON p.id = a.prompt_id
-            WHERE p.workspace_id = :workspace_id
-            AND (:trigger_category IS NULL OR a.trigger_category = :trigger_category)
-            """  # noqa: E501
-
-        if filter_by_ids:
-            base_query += " AND p.id IN :filter_ids"
-
-        base_query += """
-            ORDER BY o.timestamp DESC, a.timestamp DESC
-            LIMIT :limit OFFSET :offset
+            ORDER BY p.timestamp DESC, o.timestamp ASC, a.timestamp ASC
         """
 
-        sql = text(base_query)
-
+        # Build conditions and filters
         conditions = {
             "workspace_id": workspace_id,
-            "trigger_category": trigger_category,
-            "limit": limit,
+            "page_size": page_size,
             "offset": offset,
         }
 
-        if filter_by_ids:
-            sql = sql.bindparams(bindparam("filter_ids", expanding=True))
-            conditions["filter_ids"] = filter_by_ids
+        # Conditionally add filter clauses and conditions
+
+        filter_conditions = []
+
+        if filter_by_alert_trigger_categories:
+            filter_conditions.append(
+                "AND a.trigger_category IN :filter_by_alert_trigger_categories"
+            )
+            conditions["filter_by_alert_trigger_categories"] = filter_by_alert_trigger_categories
 
-        fetched_rows: List[IntermediatePromptWithOutputUsageAlerts] = (
-            await self._exec_select_conditions_to_pydantic(
-                IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True
+        if filter_by_alert_trigger_types:
+            filter_conditions.append(
+                "AND EXISTS (SELECT 1 FROM alerts a2 WHERE a2.prompt_id = p.id AND a2.trigger_type IN :filter_by_alert_trigger_types)"  # noqa: E501
             )
+            conditions["filter_by_alert_trigger_types"] = filter_by_alert_trigger_types
+
+        if filter_by_ids:
+            filter_conditions.append("AND p.id IN :filter_by_ids")
+            conditions["filter_by_ids"] = filter_by_ids
+
+        filter_clause = " ".join(filter_conditions)
+        query = base_query.format(filter_conditions=filter_clause)
+
+        sql = text(query)
+
+        # Bind optional params
+
+        if filter_by_alert_trigger_categories:
+            sql = sql.bindparams(bindparam("filter_by_alert_trigger_categories", expanding=True))
+        if filter_by_alert_trigger_types:
+            sql = sql.bindparams(bindparam("filter_by_alert_trigger_types", expanding=True))
+        if filter_by_ids:
+            sql = sql.bindparams(bindparam("filter_by_ids", expanding=True))
+
+        # Execute query
+        rows = await self._exec_select_conditions_to_pydantic(
+            IntermediateMessagesRow, sql, conditions, should_raise=True
         )
 
-        prompts_dict: Dict[str, GetPromptWithOutputsRow] = {}
-        for row in fetched_rows:
-            prompt_id = row.prompt_id
-            if prompt_id not in prompts_dict:
-                prompts_dict[prompt_id] = GetPromptWithOutputsRow(
+        # Process results
+        prompts_dict: Dict[str, GetMessagesRow] = {}
+
+        for row in rows:
+            if row.prompt_id not in prompts_dict:
+                prompts_dict[row.prompt_id] = GetMessagesRow(
                     id=row.prompt_id,
                     timestamp=row.prompt_timestamp,
                     provider=row.provider,
@@ -724,6 +779,7 @@ async def get_messages(
                     output_cost=row.output_cost,
                     alerts=[],
                 )
+
             if row.alert_id:
                 alert = Alert(
                     id=row.alert_id,
@@ -734,7 +790,8 @@ async def get_messages(
                     trigger_category=row.trigger_category,
                     timestamp=row.alert_timestamp,
                 )
-                prompts_dict[prompt_id].alerts.append(alert)
+                if alert not in prompts_dict[row.prompt_id].alerts:
+                    prompts_dict[row.prompt_id].alerts.append(alert)
 
         return list(prompts_dict.values())
 
@@ -799,19 +856,19 @@ async def get_alerts_by_workspace(
     async def get_alerts_summary_by_workspace(self, workspace_id: str) -> dict:
         """Get aggregated alert summary counts for a given workspace_id."""
         sql = text(
-            """
+            f"""
             SELECT
-            COUNT(*) AS total_alerts,
-            SUM(CASE WHEN a.trigger_type = 'codegate-secrets' THEN 1 ELSE 0 END)
-            AS codegate_secrets_count,
-            SUM(CASE WHEN a.trigger_type = 'codegate-context-retriever' THEN 1 ELSE 0 END)
-            AS codegate_context_retriever_count,
-            SUM(CASE WHEN a.trigger_type = 'codegate-pii' THEN 1 ELSE 0 END)
-            AS codegate_pii_count
+            COUNT(*) AS total_alerts,
+            SUM(CASE WHEN a.trigger_type = '{AlertTriggerType.CODEGATE_SECRETS.value}' THEN 1 ELSE 0 END)
+            AS codegate_secrets_count,
+            SUM(CASE WHEN a.trigger_type = '{AlertTriggerType.CODEGATE_CONTEXT_RETRIEVER.value}' THEN 1 ELSE 0 END)
+            AS codegate_context_retriever_count,
+            SUM(CASE WHEN a.trigger_type = '{AlertTriggerType.CODEGATE_PII.value}' THEN 1 ELSE 0 END)
+            AS codegate_pii_count
             FROM alerts a
             INNER JOIN prompts p ON p.id = a.prompt_id
             WHERE p.workspace_id = :workspace_id
-            """
+            """  # noqa: E501 # nosec
         )
         conditions = {"workspace_id": workspace_id}
diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py
index a5941e96..63422d96 100644
--- a/src/codegate/db/models.py
+++ b/src/codegate/db/models.py
@@ -11,6 +11,12 @@ class AlertSeverity(str, Enum):
     CRITICAL = "critical"
 
 
+class AlertTriggerType(str, Enum):
+    CODEGATE_PII = "codegate-pii"
+    CODEGATE_CONTEXT_RETRIEVER = "codegate-context-retriever"
+    CODEGATE_SECRETS = "codegate-secrets"
+
+
 class Alert(BaseModel):
     id: str
     prompt_id: str
@@ -137,7 +143,7 @@ class ProviderType(str, Enum):
     openrouter = "openrouter"
 
 
-class IntermediatePromptWithOutputUsageAlerts(BaseModel):
+class IntermediateMessagesRow(BaseModel):
     """
     An intermediate model to represent the result of a query for a prompt and related
     outputs, usage stats & alerts.
@@ -163,7 +169,7 @@ class IntermediateMessagesRow(BaseModel):
     alert_timestamp: Optional[Any]
 
 
-class GetPromptWithOutputsRow(BaseModel):
+class GetMessagesRow(BaseModel):
     id: Any
     timestamp: Any
     provider: Optional[Any]
diff --git a/src/codegate/pipeline/codegate_context_retriever/codegate.py b/src/codegate/pipeline/codegate_context_retriever/codegate.py
index e22874a6..96dce865 100644
--- a/src/codegate/pipeline/codegate_context_retriever/codegate.py
+++ b/src/codegate/pipeline/codegate_context_retriever/codegate.py
@@ -5,7 +5,7 @@
 from litellm import ChatCompletionRequest
 
 from codegate.clients.clients import ClientType
-from codegate.db.models import AlertSeverity
+from codegate.db.models import AlertSeverity, AlertTriggerType
 from codegate.extract_snippets.factory import MessageCodeExtractorFactory
 from codegate.pipeline.base import (
     PipelineContext,
@@ -36,7 +36,7 @@ def name(self) -> str:
         """
         Returns the name of this pipeline step.
         """
-        return "codegate-context-retriever"
+        return AlertTriggerType.CODEGATE_CONTEXT_RETRIEVER.value
 
     def generate_context_str(
         self, objects: list[object], context: PipelineContext, snippet_map: dict
diff --git a/src/codegate/pipeline/pii/analyzer.py b/src/codegate/pipeline/pii/analyzer.py
index 96442824..bed67b17 100644
--- a/src/codegate/pipeline/pii/analyzer.py
+++ b/src/codegate/pipeline/pii/analyzer.py
@@ -1,10 +1,10 @@
-from typing import Any, List, Optional
+from typing import List, Optional
 
 import structlog
 from presidio_analyzer import AnalyzerEngine
 from presidio_anonymizer import AnonymizerEngine
 
-from codegate.db.models import AlertSeverity
+from codegate.db.models import AlertTriggerType
 from codegate.pipeline.base import PipelineContext
 from codegate.pipeline.sensitive_data.session_store import SessionStore
 
@@ -31,7 +31,7 @@ class PiiAnalyzer:
     """
 
     _instance: Optional["PiiAnalyzer"] = None
-    _name = "codegate-pii"
+    _name = AlertTriggerType.CODEGATE_PII.value
 
     @classmethod
     def get_instance(cls) -> "PiiAnalyzer":
diff --git a/src/codegate/pipeline/pii/pii.py b/src/codegate/pipeline/pii/pii.py
index fde89428..c2d8b0ae 100644
--- a/src/codegate/pipeline/pii/pii.py
+++ b/src/codegate/pipeline/pii/pii.py
@@ -1,5 +1,4 @@
 from typing import Any, Dict, List, Optional, Tuple
-import uuid
 
 import regex as re
 import structlog
@@ -7,7 +6,7 @@
 from litellm.types.utils import Delta, StreamingChoices
 
 from codegate.config import Config
-from codegate.db.models import AlertSeverity
+from codegate.db.models import AlertSeverity, AlertTriggerType
 from codegate.pipeline.base import (
     PipelineContext,
     PipelineResult,
@@ -52,7 +51,7 @@ def __init__(self, sensitive_data_manager: SensitiveDataManager):
 
     @property
     def name(self) -> str:
-        return "codegate-pii"
+        return AlertTriggerType.CODEGATE_PII.value
 
     def _get_redacted_snippet(self, message: str, pii_details: List[Dict[str, Any]]) -> str:
         # If no PII found, return empty string
@@ -420,7 +419,7 @@ async def process_chunk(
         # TODO: Might want to check these with James!
notification_text = ( f"🛡️ [CodeGate protected {redacted_count} instances of PII, including {pii_summary}]" - f"(http://localhost:9090/?search=codegate-pii) from being leaked " + f"(http://localhost:9090/?search={AlertTriggerType.CODEGATE_PII.value}) from being leaked " # noqa: E501 f"by redacting them.\n\n" ) diff --git a/src/codegate/pipeline/secrets/secrets.py b/src/codegate/pipeline/secrets/secrets.py index 527c817f..535aea4b 100644 --- a/src/codegate/pipeline/secrets/secrets.py +++ b/src/codegate/pipeline/secrets/secrets.py @@ -7,7 +7,7 @@ from litellm.types.utils import Delta, StreamingChoices from codegate.config import Config -from codegate.db.models import AlertSeverity +from codegate.db.models import AlertSeverity, AlertTriggerType from codegate.extract_snippets.factory import MessageCodeExtractorFactory from codegate.pipeline.base import ( CodeSnippet, @@ -178,7 +178,7 @@ def __init__( self._sensitive_data_manager = sensitive_data_manager self._session_id = session_id self._context = context - self._name = "codegate-secrets" + self._name = AlertTriggerType.CODEGATE_SECRETS.value super().__init__() @@ -255,7 +255,7 @@ def name(self) -> str: Returns: str: The identifier 'codegate-secrets'. """ - return "codegate-secrets" + return AlertTriggerType.CODEGATE_SECRETS.value def _redact_text( self, @@ -547,7 +547,7 @@ async def process_chunk( notification_chunk = self._create_chunk( chunk, f"\n🛡️ [CodeGate prevented {redacted_count} {secret_text}]" - f"(http://localhost:9090/?search=codegate-secrets) from being leaked " + f"(http://localhost:9090/?search={AlertTriggerType.CODEGATE_SECRETS.value}) from being leaked " # noqa: E501 f"by redacting them.\n\n", ) notification_chunk.choices[0].delta.role = "assistant" @@ -555,7 +555,7 @@ async def process_chunk( notification_chunk = self._create_chunk( chunk, f"\n🛡️ [CodeGate prevented {redacted_count} {secret_text}]" - f"(http://localhost:9090/?search=codegate-secrets) from being leaked " + f"(http://localhost:9090/?search={AlertTriggerType.CODEGATE_SECRETS.value}) from being leaked " # noqa: E501 f"by redacting them.\n\n", ) diff --git a/src/codegate/pipeline/sensitive_data/manager.py b/src/codegate/pipeline/sensitive_data/manager.py index 89506d15..bf467878 100644 --- a/src/codegate/pipeline/sensitive_data/manager.py +++ b/src/codegate/pipeline/sensitive_data/manager.py @@ -1,7 +1,8 @@ -import json from typing import Dict, Optional + import pydantic import structlog + from codegate.pipeline.sensitive_data.session_store import SessionStore logger = structlog.get_logger("codegate") diff --git a/src/codegate/pipeline/sensitive_data/session_store.py b/src/codegate/pipeline/sensitive_data/session_store.py index 5e508847..7a33abd2 100644 --- a/src/codegate/pipeline/sensitive_data/session_store.py +++ b/src/codegate/pipeline/sensitive_data/session_store.py @@ -1,5 +1,5 @@ -from typing import Dict, Optional import uuid +from typing import Dict, Optional class SessionStore: diff --git a/tests/api/test_v1_processing.py b/tests/api/test_v1_processing.py index ad8ffcbd..fc3ee363 100644 --- a/tests/api/test_v1_processing.py +++ b/tests/api/test_v1_processing.py @@ -13,7 +13,7 @@ parse_request, remove_duplicate_alerts, ) -from codegate.db.models import GetPromptWithOutputsRow +from codegate.db.models import AlertTriggerType, GetMessagesRow @pytest.mark.asyncio @@ -147,7 +147,7 @@ async def test_parse_output(output_dict, expected_str): @pytest.mark.parametrize( "row", [ - GetPromptWithOutputsRow( + GetMessagesRow( id="1", timestamp=timestamp_now, 
provider="openai", @@ -446,7 +446,7 @@ def test_group_partial_messages(pq_list, expected_group_ids): prompt_id="p1", code_snippet=None, trigger_string="secret1 Context xyz", - trigger_type="codegate-secrets", + trigger_type=AlertTriggerType.CODEGATE_SECRETS.value, trigger_category="critical", timestamp=datetime.datetime(2023, 1, 1, 12, 0, 0), ), @@ -455,7 +455,7 @@ def test_group_partial_messages(pq_list, expected_group_ids): prompt_id="p2", code_snippet=None, trigger_string="secret1 Context abc", - trigger_type="codegate-secrets", + trigger_type=AlertTriggerType.CODEGATE_SECRETS.value, trigger_category="critical", timestamp=datetime.datetime(2023, 1, 1, 12, 0, 3), ), @@ -471,7 +471,7 @@ def test_group_partial_messages(pq_list, expected_group_ids): prompt_id="p1", code_snippet=None, trigger_string="secret1 Context xyz", - trigger_type="codegate-secrets", + trigger_type=AlertTriggerType.CODEGATE_SECRETS.value, trigger_category="critical", timestamp=datetime.datetime(2023, 1, 1, 12, 0, 0), ), @@ -480,7 +480,7 @@ def test_group_partial_messages(pq_list, expected_group_ids): prompt_id="p2", code_snippet=None, trigger_string="secret1 Context abc", - trigger_type="codegate-secrets", + trigger_type=AlertTriggerType.CODEGATE_SECRETS.value, trigger_category="critical", timestamp=datetime.datetime(2023, 1, 1, 12, 0, 6), ), @@ -496,7 +496,7 @@ def test_group_partial_messages(pq_list, expected_group_ids): prompt_id="p1", code_snippet=None, trigger_string="secret1 Context xyz", - trigger_type="codegate-secrets", + trigger_type=AlertTriggerType.CODEGATE_SECRETS.value, trigger_category="critical", timestamp=datetime.datetime(2023, 1, 1, 12, 0, 0), ), @@ -514,7 +514,7 @@ def test_group_partial_messages(pq_list, expected_group_ids): prompt_id="p3", code_snippet=None, trigger_string="secret1 Context abc", - trigger_type="codegate-secrets", + trigger_type=AlertTriggerType.CODEGATE_SECRETS.value, trigger_category="critical", timestamp=datetime.datetime(2023, 1, 1, 12, 0, 3), ), diff --git a/tests/pipeline/pii/test_analyzer.py b/tests/pipeline/pii/test_analyzer.py index d626b8cf..e856653c 100644 --- a/tests/pipeline/pii/test_analyzer.py +++ b/tests/pipeline/pii/test_analyzer.py @@ -1,7 +1,6 @@ from unittest.mock import MagicMock, patch import pytest -from presidio_analyzer import RecognizerResult from codegate.pipeline.pii.analyzer import PiiAnalyzer diff --git a/tests/pipeline/pii/test_pi.py b/tests/pipeline/pii/test_pi.py index 06d2881f..4f6fadf5 100644 --- a/tests/pipeline/pii/test_pi.py +++ b/tests/pipeline/pii/test_pi.py @@ -4,6 +4,7 @@ from litellm import ChatCompletionRequest, ModelResponse from litellm.types.utils import Delta, StreamingChoices +from codegate.db.models import AlertTriggerType from codegate.pipeline.base import PipelineContext, PipelineSensitiveData from codegate.pipeline.output import OutputPipelineContext from codegate.pipeline.pii.pii import CodegatePii, PiiRedactionNotifier, PiiUnRedactionStep @@ -25,7 +26,7 @@ def pii_step(self): return CodegatePii(mock_sensitive_data_manager) def test_name(self, pii_step): - assert pii_step.name == "codegate-pii" + assert pii_step.name == AlertTriggerType.CODEGATE_PII.value def test_get_redacted_snippet_no_pii(self, pii_step): message = "Hello world" diff --git a/tests/pipeline/sensitive_data/test_manager.py b/tests/pipeline/sensitive_data/test_manager.py index 6115ad14..66305388 100644 --- a/tests/pipeline/sensitive_data/test_manager.py +++ b/tests/pipeline/sensitive_data/test_manager.py @@ -1,6 +1,7 @@ -import json from unittest.mock import 
MagicMock, patch + import pytest + from codegate.pipeline.sensitive_data.manager import SensitiveData, SensitiveDataManager from codegate.pipeline.sensitive_data.session_store import SessionStore diff --git a/tests/pipeline/sensitive_data/test_session_store.py b/tests/pipeline/sensitive_data/test_session_store.py index b9ab64fe..e90b953e 100644 --- a/tests/pipeline/sensitive_data/test_session_store.py +++ b/tests/pipeline/sensitive_data/test_session_store.py @@ -1,5 +1,5 @@ -import uuid import pytest + from codegate.pipeline.sensitive_data.session_store import SessionStore diff --git a/tests/test_server.py b/tests/test_server.py index aa549810..bcf55e7e 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -14,7 +14,6 @@ from codegate import __version__ from codegate.pipeline.factory import PipelineFactory -from codegate.pipeline.sensitive_data.manager import SensitiveDataManager from codegate.providers.registry import ProviderRegistry from codegate.server import init_app from src.codegate.cli import UvicornServer, cli From f4305613f89f19eaffaa597e4af395e8498f12f8 Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Tue, 4 Mar 2025 23:22:33 +0000 Subject: [PATCH 12/14] fix: duplicate messages in total field in paginated response --- src/codegate/db/connection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 21394d21..2b3ffaf9 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -798,10 +798,10 @@ async def get_messages( async def get_total_messages_count_by_workspace_id( self, workspace_id: str, trigger_category: Optional[str] = None ) -> int: - """Get total count of messages for a given workspace_id, considering trigger_category.""" + """Get total count of unique messages for a given workspace_id, considering trigger_category.""" sql = text( """ - SELECT COUNT(*) + SELECT COUNT(DISTINCT p.id) FROM prompts p LEFT JOIN alerts a ON p.id = a.prompt_id WHERE p.workspace_id = :workspace_id From 0a9a81688d4715506d0f7649bf7e18377b3073c5 Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Tue, 4 Mar 2025 23:23:27 +0000 Subject: [PATCH 13/14] lint fix --- src/codegate/db/connection.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 2b3ffaf9..bc1930f6 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -798,7 +798,10 @@ async def get_messages( async def get_total_messages_count_by_workspace_id( self, workspace_id: str, trigger_category: Optional[str] = None ) -> int: - """Get total count of unique messages for a given workspace_id, considering trigger_category.""" + """ + Get total count of unique messages for a given workspace_id, + considering trigger_category. 
+        """
         sql = text(
             """
             SELECT COUNT(DISTINCT p.id)

From 581d3b352a0e07b4d25b9e57450e07f4387b2019 Mon Sep 17 00:00:00 2001
From: alex-mcgovern
Date: Wed, 5 Mar 2025 09:18:35 +0000
Subject: [PATCH 14/14] slight query optimizations

---
 src/codegate/db/connection.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py
index bc1930f6..56ba16bb 100644
--- a/src/codegate/db/connection.py
+++ b/src/codegate/db/connection.py
@@ -680,7 +680,7 @@ async def get_messages(
         # Build base query
         base_query = """
         WITH filtered_prompts AS (
-            SELECT DISTINCT p.*
+            SELECT DISTINCT p.id, p.timestamp, p.provider, p.request, p.type
             FROM prompts p
             LEFT JOIN alerts a ON p.id = a.prompt_id
             WHERE p.workspace_id = :workspace_id
@@ -710,7 +710,6 @@
         FROM filtered_prompts p
         LEFT JOIN outputs o ON p.id = o.prompt_id
         LEFT JOIN alerts a ON p.id = a.prompt_id
-        ORDER BY p.timestamp DESC, o.timestamp ASC, a.timestamp ASC
     """
         # Build conditions and filters
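
A closing illustration of the two query decisions above, since both are easy to regress: PATCH 12 counts DISTINCT prompt IDs because the LEFT JOIN to alerts yields one row per (prompt, alert) pair, and get_messages paginates inside the filtered_prompts CTE because a LIMIT applied after the join would split a prompt's rows across pages. The sketch below is not part of the patch series; it reproduces both effects against an invented two-table SQLite schema, so only the query shapes come from the diffs and every name and row is illustrative.

import sqlite3

# Throwaway schema mirroring just the columns the queries touch.
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE prompts (id TEXT PRIMARY KEY, workspace_id TEXT, timestamp INTEGER);
    CREATE TABLE alerts (id TEXT PRIMARY KEY, prompt_id TEXT REFERENCES prompts(id));
    INSERT INTO prompts VALUES ('p1', 'ws1', 2), ('p2', 'ws1', 1);
    -- p1 triggered two alerts, p2 triggered none.
    INSERT INTO alerts VALUES ('a1', 'p1'), ('a2', 'p1');
    """
)
params = {"ws": "ws1"}

# 1. COUNT(*) over the LEFT JOIN counts joined rows, not prompts: p1
#    contributes two rows (one per alert) and p2 one row, giving 3 instead
#    of 2. COUNT(DISTINCT p.id) is the figure the total_pages math needs.
count_sql = """
    SELECT COUNT({expr})
    FROM prompts p
    LEFT JOIN alerts a ON p.id = a.prompt_id
    WHERE p.workspace_id = :ws
"""
print(conn.execute(count_sql.format(expr="*"), params).fetchone()[0])              # 3
print(conn.execute(count_sql.format(expr="DISTINCT p.id"), params).fetchone()[0])  # 2

# 2. LIMIT applied to the joined rows: a "page" of 2 rows is consumed
#    entirely by p1's two alert rows, so p2 never appears on page 1.
naive = conn.execute(
    """
    SELECT p.id, a.id
    FROM prompts p
    LEFT JOIN alerts a ON p.id = a.prompt_id
    WHERE p.workspace_id = :ws
    ORDER BY p.timestamp DESC, a.id
    LIMIT 2 OFFSET 0
    """,
    params,
).fetchall()
print(naive)  # [('p1', 'a1'), ('p1', 'a2')]

# 3. Paginating distinct prompts first (the filtered_prompts CTE) and only
#    then joining keeps each conversation intact on its page.
paged = conn.execute(
    """
    WITH filtered_prompts AS (
        SELECT DISTINCT p.id, p.timestamp
        FROM prompts p
        WHERE p.workspace_id = :ws
        ORDER BY p.timestamp DESC
        LIMIT 2 OFFSET 0
    )
    SELECT p.id, a.id
    FROM filtered_prompts p
    LEFT JOIN alerts a ON p.id = a.prompt_id
    ORDER BY p.timestamp DESC, a.id
    """,
    params,
).fetchall()
print(paged)  # [('p1', 'a1'), ('p1', 'a2'), ('p2', None)]

The explicit column list PATCH 14 swaps in for p.* serves the same goal: DISTINCT then only has to compare the columns a page actually renders rather than every column of the prompts row.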