diff --git a/src/codegate/pipeline/cli/cli.py b/src/codegate/pipeline/cli/cli.py
index d94af629..a18d32ad 100644
--- a/src/codegate/pipeline/cli/cli.py
+++ b/src/codegate/pipeline/cli/cli.py
@@ -10,6 +10,7 @@
PipelineStep,
)
from codegate.pipeline.cli.commands import CustomInstructions, Version, Workspace
+from codegate.utils.utils import get_tool_name_from_messages
HELP_TEXT = """
## CodeGate CLI\n
@@ -77,27 +78,52 @@ async def process(
if last_user_message is not None:
last_user_message_str, _ = last_user_message
- cleaned_message_str = re.sub(r"<.*?>", "", last_user_message_str).strip()
- splitted_message = cleaned_message_str.lower().split(" ")
- # We expect codegate as the first word in the message
- if splitted_message[0] == "codegate":
- context.shortcut_response = True
- args = shlex.split(cleaned_message_str)
- cmd_out = await codegate_cli(args[1:])
-
- if cleaned_message_str != last_user_message_str:
- # it came from Cline, need to wrap into tags
- cmd_out = (
- f"{cmd_out}\n"
- )
- return PipelineResult(
- response=PipelineResponse(
- step_name=self.name,
- content=cmd_out,
- model=request["model"],
- ),
- context=context,
+ last_user_message_str = last_user_message_str.strip()
+ base_tool = get_tool_name_from_messages(request)
+ codegate_regex = re.compile(r"^codegate(?:\s+(.*))?", re.IGNORECASE)
+
+ if base_tool and base_tool == "cline":
+ # Check if there are or tags
+ tag_match = re.search(
+ r"<(task|feedback)>(.*?)\1>", last_user_message_str, re.DOTALL
)
+ if tag_match:
+ # Extract the content between the tags
+ stripped_message = tag_match.group(2).strip()
+ else:
+ # If no or tags, use the entire message
+ stripped_message = last_user_message_str.strip()
+
+ # Remove all other XML tags and trim whitespace
+ stripped_message = re.sub(r"<[^>]+>", "", stripped_message).strip()
+
+ # Check if "codegate" is the first word
+ match = codegate_regex.match(stripped_message)
+ else:
+ # Check if "codegate" is the first word in the message
+ match = codegate_regex.match(last_user_message_str)
+ if match:
+ command = match.group(1) or ""
+ command = command.strip()
+
+ # Process the command
+ args = shlex.split(f"codegate {command}")
+ if args:
+ context.shortcut_response = True
+ cmd_out = await codegate_cli(args[1:])
+ if base_tool and base_tool == "cline":
+ cmd_out = (
+ f"{cmd_out}\n"
+ )
+
+ return PipelineResult(
+ response=PipelineResponse(
+ step_name=self.name,
+ content=cmd_out,
+ model=request["model"],
+ ),
+ context=context,
+ )
# Fall through
return PipelineResult(request=request, context=context)
diff --git a/src/codegate/providers/copilot/provider.py b/src/codegate/providers/copilot/provider.py
index d1ade810..6937ba50 100644
--- a/src/codegate/providers/copilot/provider.py
+++ b/src/codegate/providers/copilot/provider.py
@@ -291,12 +291,15 @@ async def _forward_data_through_pipeline(self, data: bytes) -> Union[HttpRequest
# we couldn't parse this into an HTTP request, so we just pass through
return data
- body, context = await self._body_through_pipeline(
+ result = await self._body_through_pipeline(
http_request.method,
http_request.path,
http_request.headers,
http_request.body,
)
+ if not result:
+ return data
+ body, context = result
# TODO: it's weird that we're overwriting the context.
# Should we set the context once? Maybe when
# creating the pipeline instance?
diff --git a/src/codegate/utils/utils.py b/src/codegate/utils/utils.py
index e04ad714..6096e051 100644
--- a/src/codegate/utils/utils.py
+++ b/src/codegate/utils/utils.py
@@ -29,3 +29,25 @@ def generate_vector_string(package) -> str:
# add description
vector_str += f" - Package offers this functionality: {package['description']}"
return vector_str
+
+
+def get_tool_name_from_messages(data):
+ """
+ Identifies the tool name based on the content of the messages.
+
+ Args:
+ request (dict): The request object containing messages.
+ tools (list): A list of tool names to search for.
+
+ Returns:
+ str: The name of the tool found in the messages, or None if no match is found.
+ """
+ tools = [
+ "Cline",
+ ]
+ for message in data.get("messages", []):
+ message_content = str(message.get("content", ""))
+ for tool in tools:
+ if tool in message_content:
+ return tool.lower()
+ return None