diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 9820e292..15305790 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -119,6 +119,10 @@ async def record_request(self, prompt_params: Optional[Prompt] = None) -> Option """ INSERT INTO prompts (id, timestamp, provider, request, type, workspace_id) VALUES (:id, :timestamp, :provider, :request, :type, :workspace_id) + ON CONFLICT(id) DO UPDATE SET + timestamp = excluded.timestamp, provider = excluded.provider, + request = excluded.request, type = excluded.type, + workspace_id = excluded.workspace_id RETURNING * """ ) @@ -173,6 +177,9 @@ async def record_outputs( """ INSERT INTO outputs (id, prompt_id, timestamp, output) VALUES (:id, :prompt_id, :timestamp, :output) + ON CONFLICT (id) DO UPDATE SET + timestamp = excluded.timestamp, output = excluded.output, + prompt_id = excluded.prompt_id RETURNING * """ ) @@ -192,6 +199,10 @@ async def record_alerts(self, alerts: List[Alert], initial_id: Optional[str]) -> ) VALUES (:id, :prompt_id, :code_snippet, :trigger_string, :trigger_type, :trigger_category, :timestamp) + ON CONFLICT (id) DO UPDATE SET + code_snippet = excluded.code_snippet, trigger_string = excluded.trigger_string, + trigger_type = excluded.trigger_type, trigger_category = excluded.trigger_category, + timestamp = excluded.timestamp, prompt_id = excluded.prompt_id RETURNING * """ ) @@ -219,9 +230,6 @@ async def record_alerts(self, alerts: List[Alert], initial_id: Optional[str]) -> def _should_record_context(self, context: Optional[PipelineContext]) -> tuple: """Check if the context should be recorded in DB and determine the action.""" - if context is None or context.metadata.get("stored_in_db", False): - return False, None, None - if not context.input_request: logger.warning("No input request found. Skipping recording context.") return False, None, None @@ -245,7 +253,6 @@ async def record_context(self, context: Optional[PipelineContext]) -> None: await self.record_request(context.input_request) await self.record_outputs(context.output_responses, None) await self.record_alerts(context.alerts_raised, None) - context.metadata["stored_in_db"] = True logger.info( f"Recorded context in DB. Output chunks: {len(context.output_responses)}. " f"Alerts: {len(context.alerts_raised)}." @@ -255,7 +262,6 @@ async def record_context(self, context: Optional[PipelineContext]) -> None: await self.update_request(initial_id, context.input_request) await self.record_outputs(context.output_responses, initial_id) await self.record_alerts(context.alerts_raised, initial_id) - context.metadata["stored_in_db"] = True logger.info( f"Recorded context in DB. Output chunks: {len(context.output_responses)}. " f"Alerts: {len(context.alerts_raised)}." diff --git a/src/codegate/providers/copilot/provider.py b/src/codegate/providers/copilot/provider.py index 6937ba50..70fb4b7a 100644 --- a/src/codegate/providers/copilot/provider.py +++ b/src/codegate/providers/copilot/provider.py @@ -812,7 +812,7 @@ def _ensure_output_processor(self) -> None: input_context=self.proxy.context_tracking, ) - async def _process_stream(self): + async def _process_stream(self): # noqa: C901 try: async def stream_iterator():