diff --git a/src/codegate/pipeline/cli/cli.py b/src/codegate/pipeline/cli/cli.py index d94af629..a18d32ad 100644 --- a/src/codegate/pipeline/cli/cli.py +++ b/src/codegate/pipeline/cli/cli.py @@ -10,6 +10,7 @@ PipelineStep, ) from codegate.pipeline.cli.commands import CustomInstructions, Version, Workspace +from codegate.utils.utils import get_tool_name_from_messages HELP_TEXT = """ ## CodeGate CLI\n @@ -77,27 +78,52 @@ async def process( if last_user_message is not None: last_user_message_str, _ = last_user_message - cleaned_message_str = re.sub(r"<.*?>", "", last_user_message_str).strip() - splitted_message = cleaned_message_str.lower().split(" ") - # We expect codegate as the first word in the message - if splitted_message[0] == "codegate": - context.shortcut_response = True - args = shlex.split(cleaned_message_str) - cmd_out = await codegate_cli(args[1:]) - - if cleaned_message_str != last_user_message_str: - # it came from Cline, need to wrap into tags - cmd_out = ( - f"{cmd_out}\n" - ) - return PipelineResult( - response=PipelineResponse( - step_name=self.name, - content=cmd_out, - model=request["model"], - ), - context=context, + last_user_message_str = last_user_message_str.strip() + base_tool = get_tool_name_from_messages(request) + codegate_regex = re.compile(r"^codegate(?:\s+(.*))?", re.IGNORECASE) + + if base_tool and base_tool == "cline": + # Check if there are or tags + tag_match = re.search( + r"<(task|feedback)>(.*?)", last_user_message_str, re.DOTALL ) + if tag_match: + # Extract the content between the tags + stripped_message = tag_match.group(2).strip() + else: + # If no or tags, use the entire message + stripped_message = last_user_message_str.strip() + + # Remove all other XML tags and trim whitespace + stripped_message = re.sub(r"<[^>]+>", "", stripped_message).strip() + + # Check if "codegate" is the first word + match = codegate_regex.match(stripped_message) + else: + # Check if "codegate" is the first word in the message + match = codegate_regex.match(last_user_message_str) + if match: + command = match.group(1) or "" + command = command.strip() + + # Process the command + args = shlex.split(f"codegate {command}") + if args: + context.shortcut_response = True + cmd_out = await codegate_cli(args[1:]) + if base_tool and base_tool == "cline": + cmd_out = ( + f"{cmd_out}\n" + ) + + return PipelineResult( + response=PipelineResponse( + step_name=self.name, + content=cmd_out, + model=request["model"], + ), + context=context, + ) # Fall through return PipelineResult(request=request, context=context) diff --git a/src/codegate/providers/copilot/provider.py b/src/codegate/providers/copilot/provider.py index d1ade810..6937ba50 100644 --- a/src/codegate/providers/copilot/provider.py +++ b/src/codegate/providers/copilot/provider.py @@ -291,12 +291,15 @@ async def _forward_data_through_pipeline(self, data: bytes) -> Union[HttpRequest # we couldn't parse this into an HTTP request, so we just pass through return data - body, context = await self._body_through_pipeline( + result = await self._body_through_pipeline( http_request.method, http_request.path, http_request.headers, http_request.body, ) + if not result: + return data + body, context = result # TODO: it's weird that we're overwriting the context. # Should we set the context once? Maybe when # creating the pipeline instance? diff --git a/src/codegate/utils/utils.py b/src/codegate/utils/utils.py index e04ad714..6096e051 100644 --- a/src/codegate/utils/utils.py +++ b/src/codegate/utils/utils.py @@ -29,3 +29,25 @@ def generate_vector_string(package) -> str: # add description vector_str += f" - Package offers this functionality: {package['description']}" return vector_str + + +def get_tool_name_from_messages(data): + """ + Identifies the tool name based on the content of the messages. + + Args: + request (dict): The request object containing messages. + tools (list): A list of tool names to search for. + + Returns: + str: The name of the tool found in the messages, or None if no match is found. + """ + tools = [ + "Cline", + ] + for message in data.get("messages", []): + message_content = str(message.get("content", "")) + for tool in tools: + if tool in message_content: + return tool.lower() + return None