diff --git a/src/codegate/pipeline/codegate_context_retriever/codegate.py b/src/codegate/pipeline/codegate_context_retriever/codegate.py index d5b33faf..291b8790 100644 --- a/src/codegate/pipeline/codegate_context_retriever/codegate.py +++ b/src/codegate/pipeline/codegate_context_retriever/codegate.py @@ -73,17 +73,21 @@ async def process( # Extract any code snippets snippets = extract_snippets(user_messages) - # Collect all packages referenced in the snippets - snippet_packages = [] - for snippet in snippets: - snippet_packages.extend( - PackageExtractor.extract_packages(snippet.code, snippet.language) + bad_snippet_packages = [] + if len(snippets) > 0: + # Collect all packages referenced in the snippets + snippet_packages = [] + for snippet in snippets: + snippet_packages.extend( + PackageExtractor.extract_packages(snippet.code, snippet.language) + ) + logger.info(f"Found {len(snippet_packages)} packages in code snippets.") + + # Find bad packages in the snippets + bad_snippet_packages = await storage_engine.search( + language=snippets[0].language, packages=snippet_packages ) - logger.info(f"Found {len(snippet_packages)} packages in code snippets.") - - # Find bad packages in the snippets - bad_snippet_packages = await storage_engine.search_by_property("name", snippet_packages) - logger.info(f"Found {len(bad_snippet_packages)} bad packages in code snippets.") + logger.info(f"Found {len(bad_snippet_packages)} bad packages in code snippets.") # Remove code snippets from the user messages and search for bad packages # in the rest of the user query/messsages diff --git a/src/codegate/pipeline/extract_snippets/output.py b/src/codegate/pipeline/extract_snippets/output.py index e580c7d8..f385b71d 100644 --- a/src/codegate/pipeline/extract_snippets/output.py +++ b/src/codegate/pipeline/extract_snippets/output.py @@ -52,7 +52,9 @@ async def _snippet_comment(self, snippet: CodeSnippet, context: PipelineContext) # Check if any of the snippet libraries is a bad package storage_engine = StorageEngine() - libobjects = await storage_engine.search_by_property("name", snippet.libraries) + libobjects = await storage_engine.search( + language=snippet.language, packages=snippet.libraries + ) logger.info(f"Found {len(libobjects)} libraries in the storage engine") # If no bad packages are found, just return empty comment diff --git a/src/codegate/storage/storage_engine.py b/src/codegate/storage/storage_engine.py index 7f4f22b7..afd5cbd7 100644 --- a/src/codegate/storage/storage_engine.py +++ b/src/codegate/storage/storage_engine.py @@ -12,6 +12,13 @@ logger = structlog.get_logger("codegate") VALID_ECOSYSTEMS = ["npm", "pypi", "crates", "maven", "go"] +LANGUAGE_TO_ECOSYSTEM = { + "javascript": "npm", + "go": "go", + "python": "pypi", + "java": "maven", + "rust": "crates", +} class StorageEngine: @@ -125,6 +132,7 @@ async def search_by_property(self, name: str, properties: List[str]) -> list[dic async def search( self, query: str = None, + language: str = None, ecosystem: str = None, packages: List[str] = None, limit: int = 50, @@ -136,6 +144,9 @@ async def search( try: cursor = self.conn.cursor() + if language and language in LANGUAGE_TO_ECOSYSTEM.keys(): + ecosystem = LANGUAGE_TO_ECOSYSTEM[language] + if packages and ecosystem and ecosystem in VALID_ECOSYSTEMS: placeholders = ",".join("?" * len(packages)) query_sql = f"""