Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit cfb343c

Browse files
committed
fix: use the latest user messages block instead of single message - WIP
Depending on the tooling, the user request may be split across several user/assistant blocks. Use the latest block of user messages, instead of just the latest single user message, to identify code snippets and secrets. Closes: #580
1 parent 38230d1 commit cfb343c

File tree

4 files changed

+61
-15
lines changed

4 files changed

+61
-15
lines changed

src/codegate/pipeline/base.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,45 @@ def get_last_user_message(
231231
return None
232232
for i in reversed(range(len(request["messages"]))):
233233
if request["messages"][i]["role"] == "user":
234-
content = request["messages"][i]["content"]
235-
return content, i
234+
content = request["messages"][i]["content"] # type: ignore
235+
return str(content), i
236+
237+
return None
238+
239+
@staticmethod
240+
def get_last_user_message_block(
241+
request: ChatCompletionRequest,
242+
) -> Optional[str]:
243+
"""
244+
Get the last block of consecutive 'user' messages from the request.
245+
246+
Args:
247+
request (ChatCompletionRequest): The chat completion request to process
248+
249+
Returns:
250+
Optional[str]: A string containing all consecutive user messages in the
251+
last user message block, separated by newlines, or None if
252+
no user message block is found.
253+
"""
254+
if request.get("messages") is None:
255+
return None
256+
257+
user_messages = []
258+
messages = request["messages"]
259+
260+
# Iterate in reverse to find the last block of consecutive 'user' messages
261+
for i in reversed(range(len(messages))):
262+
if messages[i]["role"] == "user" or messages[i]["role"] == "assistant":
263+
if messages[i]["role"] == "user":
264+
user_messages.append(messages[i]["content"]) # type: ignore
265+
else:
266+
# Stop when a message with a different role is encountered
267+
if user_messages:
268+
break
269+
270+
# Reverse the collected user messages to preserve the original order
271+
if user_messages:
272+
return "\n".join(reversed(user_messages))
236273

237274
return None
238275

src/codegate/pipeline/codegate_context_retriever/codegate.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,38 +59,36 @@ async def process(
5959
"""
6060
Use RAG DB to add context to the user request
6161
"""
62-
# Get the latest user messages
63-
user_messages = self.get_latest_user_messages(request)
64-
65-
# Nothing to do if the user_messages string is empty
66-
if len(user_messages) == 0:
62+
# Get the latest user message
63+
user_message = self.get_last_user_message_block(request)
64+
if not user_message:
6765
return PipelineResult(request=request)
6866

6967
# Create storage engine object
7068
storage_engine = StorageEngine()
7169

7270
# Extract any code snippets
73-
snippets = extract_snippets(user_messages)
71+
snippets = extract_snippets(user_message)
7472

7573
bad_snippet_packages = []
7674
if len(snippets) > 0:
7775
# Collect all packages referenced in the snippets
7876
snippet_packages = []
7977
for snippet in snippets:
8078
snippet_packages.extend(
81-
PackageExtractor.extract_packages(snippet.code, snippet.language)
79+
PackageExtractor.extract_packages(snippet.code, snippet.language) # type: ignore
8280
)
8381
logger.info(f"Found {len(snippet_packages)} packages in code snippets.")
8482

8583
# Find bad packages in the snippets
8684
bad_snippet_packages = await storage_engine.search(
87-
language=snippets[0].language, packages=snippet_packages
85+
language=snippets[0].language, packages=snippet_packages # type: ignore
8886
)
8987
logger.info(f"Found {len(bad_snippet_packages)} bad packages in code snippets.")
9088

9189
# Remove code snippets from the user messages and search for bad packages
9290
# in the rest of the user query/messages
93-
user_messages = re.sub(r"```.*?```", "", user_messages, flags=re.DOTALL)
91+
user_messages = re.sub(r"```.*?```", "", user_message, flags=re.DOTALL)
9492

9593
# Vector search to find bad packages
9694
bad_packages = await storage_engine.search(query=user_messages, distance=0.5, limit=100)
@@ -119,7 +117,7 @@ async def process(
119117
# Add the context to the last user message
120118
# Format: "Context: {context_str} \n Query: {last user message content}"
121119
message = new_request["messages"][last_user_idx]
122-
context_msg = f'Context: {context_str} \n\n Query: {message["content"]}'
120+
context_msg = f'Context: {context_str} \n\n Query: {message["content"]}' # type: ignore
123121
message["content"] = context_msg
124122

125123
logger.debug("Final context message", context_message=context_msg)

src/codegate/pipeline/extract_snippets/extract_snippets.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def extract_snippets(message: str) -> List[CodeSnippet]:
8585

8686
# Find all code block matches
8787
for match in CODE_BLOCK_PATTERN.finditer(message):
88+
print("i try to extract snippet")
8889
matched_language = match.group("language") if match.group("language") else None
8990
filename = match.group("filename") if match.group("filename") else None
9091
content = match.group("content")
@@ -94,18 +95,24 @@ def extract_snippets(message: str) -> List[CodeSnippet]:
9495
# format ` ```python ` in output snippets
9596
if filename and not matched_language and "." not in filename:
9697
lang = filename
98+
print("lang is")
99+
print(lang)
97100
filename = None
98101
else:
99102
# Determine language from the message, either by the short
100103
# language identifier or by the filename
101104
lang = None
102105
if matched_language:
106+
print("i have a matched language")
103107
lang = ecosystem_from_message(matched_language.strip())
104108
if lang is None and filename:
109+
print("I try to get from filename")
105110
filename = filename.strip()
106111
# Determine language from the filename
107112
lang = ecosystem_from_filepath(filename)
108113

114+
print("language is")
115+
print(lang)
109116
snippets.append(CodeSnippet(filepath=filename, code=content, language=lang))
110117

111118
return snippets
@@ -129,10 +136,9 @@ async def process(
129136
request: ChatCompletionRequest,
130137
context: PipelineContext,
131138
) -> PipelineResult:
132-
last_user_message = self.get_last_user_message(request)
133-
if not last_user_message:
139+
msg_content = self.get_last_user_message_block(request)
140+
if not msg_content:
134141
return PipelineResult(request=request, context=context)
135-
msg_content, _ = last_user_message
136142
snippets = extract_snippets(msg_content)
137143

138144
logger.info(f"Extracted {len(snippets)} code snippets from the user message")

src/codegate/utils/package_extractor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,16 +74,21 @@ class PackageExtractor:
7474

7575
@staticmethod
7676
def extract_packages(code: str, language_name: str) -> list[str]:
77+
print("packages are")
78+
print(code)
79+
print(language_name)
7780
if (code is None) or (language_name is None):
7881
return []
7982

8083
language_name = language_name.lower()
8184

8285
if language_name not in PackageExtractor.__languages.keys():
86+
print("no langauge")
8387
return []
8488

8589
language = PackageExtractor.__languages[language_name]
8690
parser = PackageExtractor.__parsers[language_name]
91+
print("here")
8792

8893
# Create tree
8994
tree = parser.parse(bytes(code, "utf8"))

0 commit comments

Comments
 (0)