
Commit ec96a31

Use the client type when instantiating and running provider pipelines
Instead of detecting the client type again while the pipeline is being processed, pass the client type as a constant when the pipeline instance is created, and replace the hardcoded client strings with the ClientType constants. Related: #830
Parent: 29f143b · Commit: ec96a31
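
The effect of the change, in caller terms: a provider already knows which client it is serving, so it hands that ClientType constant to the pipeline factory once, and every pipeline step can read it from the shared context. The snippet below is a minimal sketch of that flow rather than code from this commit; it assumes SecretsManager can be constructed without arguments and uses ClientType.CLINE purely as an example.

    from codegate.clients.clients import ClientType
    from codegate.pipeline.factory import PipelineFactory
    from codegate.pipeline.secrets.manager import SecretsManager

    # Assumption: SecretsManager() takes no constructor arguments here.
    factory = PipelineFactory(SecretsManager())

    # The client type is passed in once, as a constant, when the pipeline is built...
    chat_pipeline = factory.create_input_pipeline(ClientType.CLINE)
    fim_pipeline = factory.create_fim_pipeline(ClientType.CLINE)

    # ...and ends up on the shared PipelineContext, so steps no longer need to
    # call get_tool_name_from_messages() on every request.
    assert chat_pipeline.instance.context.client == ClientType.CLINE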

File tree

14 files changed: +94 -35 lines

src/codegate/pipeline/base.py

Lines changed: 24 additions & 9 deletions
@@ -11,9 +11,9 @@
 from litellm import ChatCompletionRequest, ModelResponse
 from pydantic import BaseModel
 
+from codegate.clients.clients import ClientType
 from codegate.db.models import Alert, Output, Prompt
 from codegate.pipeline.secrets.manager import SecretsManager
-from codegate.utils.utils import get_tool_name_from_messages
 
 logger = structlog.get_logger("codegate")
 
@@ -81,6 +81,7 @@ class PipelineContext:
     shortcut_response: bool = False
     bad_packages_found: bool = False
     secrets_found: bool = False
+    client: ClientType = ClientType.GENERIC
 
     def add_code_snippet(self, snippet: CodeSnippet):
         self.code_snippets.append(snippet)
@@ -241,12 +242,14 @@ def get_last_user_message(
     @staticmethod
     def get_last_user_message_block(
         request: ChatCompletionRequest,
+        client: ClientType = ClientType.GENERIC,
     ) -> Optional[tuple[str, int]]:
         """
         Get the last block of consecutive 'user' messages from the request.
 
         Args:
             request (ChatCompletionRequest): The chat completion request to process
+            client (ClientType): The client type to consider when processing the request
 
         Returns:
             Optional[str, int]: A string containing all consecutive user messages in the
@@ -261,9 +264,8 @@ def get_last_user_message_block(
         messages = request["messages"]
         block_start_index = None
 
-        base_tool = get_tool_name_from_messages(request)
         accepted_roles = ["user", "assistant"]
-        if base_tool == "open interpreter":
+        if client == ClientType.OPEN_INTERPRETER:
            # open interpreter also uses the role "tool"
            accepted_roles.append("tool")
 
@@ -328,12 +330,16 @@ async def process(
 
 class InputPipelineInstance:
     def __init__(
-        self, pipeline_steps: List[PipelineStep], secret_manager: SecretsManager, is_fim: bool
+        self,
+        pipeline_steps: List[PipelineStep],
+        secret_manager: SecretsManager,
+        is_fim: bool,
+        client: ClientType = ClientType.GENERIC,
     ):
         self.pipeline_steps = pipeline_steps
         self.secret_manager = secret_manager
         self.is_fim = is_fim
-        self.context = PipelineContext()
+        self.context = PipelineContext(client=client)
 
         # we create the sesitive context here so that it is not shared between individual requests
         # TODO: could we get away with just generating the session ID for an instance?
@@ -392,16 +398,25 @@ async def process_request(
 
 class SequentialPipelineProcessor:
     def __init__(
-        self, pipeline_steps: List[PipelineStep], secret_manager: SecretsManager, is_fim: bool
+        self,
+        pipeline_steps: List[PipelineStep],
+        secret_manager: SecretsManager,
+        client_type: ClientType,
+        is_fim: bool,
     ):
         self.pipeline_steps = pipeline_steps
         self.secret_manager = secret_manager
         self.is_fim = is_fim
-        self.instance = self._create_instance()
+        self.instance = self._create_instance(client_type)
 
-    def _create_instance(self) -> InputPipelineInstance:
+    def _create_instance(self, client_type: ClientType) -> InputPipelineInstance:
         """Create a new pipeline instance for processing a request"""
-        return InputPipelineInstance(self.pipeline_steps, self.secret_manager, self.is_fim)
+        return InputPipelineInstance(
+            self.pipeline_steps,
+            self.secret_manager,
+            self.is_fim,
+            client_type,
+        )
 
     async def process_request(
         self,
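
To illustrate the new API surface in base.py, here is a minimal sketch, not code from the repository. It assumes PipelineContext's remaining fields all have defaults, that get_last_user_message_block lives on the PipelineStep base class (the subclasses in this commit call it via self), and that a plain dict is acceptable where ChatCompletionRequest is expected.

    from codegate.clients.clients import ClientType
    from codegate.pipeline.base import PipelineContext, PipelineStep

    # The context defaults to the generic client unless one is passed in.
    ctx = PipelineContext()
    assert ctx.client == ClientType.GENERIC

    request = {
        "messages": [
            {"role": "user", "content": "please list the project files"},
            {"role": "tool", "content": "file_a.py file_b.py"},
        ]
    }

    # With OPEN_INTERPRETER the "tool" role is also accepted when collecting the
    # last block of user messages (see the accepted_roles change above).
    block = PipelineStep.get_last_user_message_block(
        request, ClientType.OPEN_INTERPRETER
    )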

src/codegate/pipeline/cli/cli.py

Lines changed: 4 additions & 5 deletions
@@ -4,14 +4,14 @@
 
 from litellm import ChatCompletionRequest
 
+from codegate.clients.clients import ClientType
 from codegate.pipeline.base import (
     PipelineContext,
     PipelineResponse,
     PipelineResult,
     PipelineStep,
 )
 from codegate.pipeline.cli.commands import CustomInstructions, Version, Workspace
-from codegate.utils.utils import get_tool_name_from_messages
 
 HELP_TEXT = """
 ## CodeGate CLI\n
@@ -110,12 +110,11 @@ async def process(
         if last_user_message is not None:
             last_user_message_str, _ = last_user_message
             last_user_message_str = last_user_message_str.strip()
-            base_tool = get_tool_name_from_messages(request)
             codegate_regex = re.compile(r"^codegate(?:\s+(.*))?", re.IGNORECASE)
 
-            if base_tool and base_tool in ["cline", "kodu"]:
+            if context.client in [ClientType.CLINE, ClientType.KODU]:
                 match = _get_cli_from_cline(codegate_regex, last_user_message_str)
-            elif base_tool == "open interpreter":
+            elif context.client in [ClientType.OPEN_INTERPRETER]:
                 match = _get_cli_from_open_interpreter(last_user_message_str)
             else:
                 # Check if "codegate" is the first word in the message
@@ -130,7 +129,7 @@ async def process(
             if args:
                 context.shortcut_response = True
                 cmd_out = await codegate_cli(args[1:])
-                if base_tool in ["cline", "kodu"]:
+                if context.client in [ClientType.CLINE, ClientType.KODU]:
                     cmd_out = (
                         f"<attempt_completion><result>{cmd_out}</result></attempt_completion>\n"
                     )
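
Steps that previously re-derived the tool name from the message contents can now branch on context.client directly. The following is a hypothetical step, not part of this commit, assuming PipelineStep only requires the async process method with the signature seen in the diffs:

    from litellm import ChatCompletionRequest

    from codegate.clients.clients import ClientType
    from codegate.pipeline.base import PipelineContext, PipelineResult, PipelineStep


    class ClientAwareStep(PipelineStep):
        """Hypothetical step that adapts its behaviour to the detected client."""

        async def process(
            self, request: ChatCompletionRequest, context: PipelineContext
        ) -> PipelineResult:
            if context.client in (ClientType.CLINE, ClientType.KODU):
                # Cline and Kodu expect command output wrapped in attempt_completion tags.
                pass
            elif context.client == ClientType.OPEN_INTERPRETER:
                # Open Interpreter also sends messages with the "tool" role.
                pass
            return PipelineResult(request=request, context=context)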

src/codegate/pipeline/codegate_context_retriever/codegate.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ async def process(  # noqa: C901
         Use RAG DB to add context to the user request
         """
         # Get the latest user message
-        last_message = self.get_last_user_message_block(request)
+        last_message = self.get_last_user_message_block(request, context.client)
         if not last_message:
             return PipelineResult(request=request)
         user_message, last_user_idx = last_message

src/codegate/pipeline/extract_snippets/extract_snippets.py

Lines changed: 1 addition & 1 deletion
@@ -150,7 +150,7 @@ async def process(
         request: ChatCompletionRequest,
         context: PipelineContext,
     ) -> PipelineResult:
-        last_message = self.get_last_user_message_block(request)
+        last_message = self.get_last_user_message_block(request, context.client)
         if not last_message:
             return PipelineResult(request=request, context=context)
         msg_content, _ = last_message

src/codegate/pipeline/factory.py

Lines changed: 15 additions & 4 deletions
@@ -1,5 +1,6 @@
 from typing import List
 
+from codegate.clients.clients import ClientType
 from codegate.config import Config
 from codegate.pipeline.base import PipelineStep, SequentialPipelineProcessor
 from codegate.pipeline.cli.cli import CodegateCli
@@ -20,7 +21,7 @@ class PipelineFactory:
     def __init__(self, secrets_manager: SecretsManager):
         self.secrets_manager = secrets_manager
 
-    def create_input_pipeline(self) -> SequentialPipelineProcessor:
+    def create_input_pipeline(self, client_type: ClientType) -> SequentialPipelineProcessor:
         input_steps: List[PipelineStep] = [
             # make sure that this step is always first in the pipeline
             # the other steps might send the request to a LLM for it to be analyzed
@@ -32,13 +33,23 @@ def create_input_pipeline(self) -> SequentialPipelineProcessor:
             CodegateContextRetriever(),
             SystemPrompt(Config.get_config().prompts.default_chat),
         ]
-        return SequentialPipelineProcessor(input_steps, self.secrets_manager, is_fim=False)
+        return SequentialPipelineProcessor(
+            input_steps,
+            self.secrets_manager,
+            client_type,
+            is_fim=False,
+        )
 
-    def create_fim_pipeline(self) -> SequentialPipelineProcessor:
+    def create_fim_pipeline(self, client_type: ClientType) -> SequentialPipelineProcessor:
         fim_steps: List[PipelineStep] = [
             CodegateSecrets(),
         ]
-        return SequentialPipelineProcessor(fim_steps, self.secrets_manager, is_fim=True)
+        return SequentialPipelineProcessor(
+            fim_steps,
+            self.secrets_manager,
+            client_type,
+            is_fim=True,
+        )
 
     def create_output_pipeline(self) -> OutputPipelineProcessor:
         output_steps: List[OutputPipelineStep] = [
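
Callers must now supply the client type to both factory methods; the Copilot pipelines later in this commit pass ClientType.COPILOT explicitly. When constructing SequentialPipelineProcessor directly, client_type is a positional argument that comes before the is_fim keyword. A hedged sketch follows (the import path for CodegateSecrets is inferred from the file layout, and SecretsManager is assumed to need no constructor arguments):

    from codegate.clients.clients import ClientType
    from codegate.pipeline.base import SequentialPipelineProcessor
    from codegate.pipeline.secrets.manager import SecretsManager
    from codegate.pipeline.secrets.secrets import CodegateSecrets

    # Equivalent of create_fim_pipeline(ClientType.COPILOT), spelled out by hand.
    processor = SequentialPipelineProcessor(
        [CodegateSecrets()],
        SecretsManager(),
        ClientType.COPILOT,
        is_fim=True,
    )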

src/codegate/pipeline/secrets/secrets.py

Lines changed: 1 addition & 1 deletion
@@ -272,7 +272,7 @@ async def process(
         total_matches = []
 
         # get last user message block to get index for the first relevant user message
-        last_user_message = self.get_last_user_message_block(new_request)
+        last_user_message = self.get_last_user_message_block(new_request, context.client)
         last_assistant_idx = -1
         if last_user_message:
             _, user_idx = last_user_message

src/codegate/providers/anthropic/provider.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ async def process_request(
     ):
         is_fim_request = self._is_fim_request(request_url_path, data)
         try:
-            stream = await self.complete(data, api_key, is_fim_request)
+            stream = await self.complete(data, api_key, is_fim_request, client_type)
         except Exception as e:
            # check if we have an status code there
            if hasattr(e, "status_code"):

src/codegate/providers/base.py

Lines changed: 12 additions & 3 deletions
@@ -137,15 +137,16 @@ async def _run_input_pipeline(
         normalized_request: ChatCompletionRequest,
         api_key: Optional[str],
         api_base: Optional[str],
+        client_type: ClientType,
         is_fim_request: bool,
     ) -> PipelineResult:
         # Decide which pipeline processor to use
         if is_fim_request:
-            pipeline_processor = self._pipeline_factory.create_fim_pipeline()
+            pipeline_processor = self._pipeline_factory.create_fim_pipeline(client_type)
             logger.info("FIM pipeline selected for execution.")
             normalized_request = self._fim_normalizer.normalize(normalized_request)
         else:
-            pipeline_processor = self._pipeline_factory.create_input_pipeline()
+            pipeline_processor = self._pipeline_factory.create_input_pipeline(client_type)
             logger.info("Chat completion pipeline selected for execution.")
         if pipeline_processor is None:
             return PipelineResult(request=normalized_request)
@@ -253,7 +254,11 @@ def _dump_request_response(self, prefix: str, data: Any) -> None:
             f.write(str(data))
 
     async def complete(
-        self, data: Dict, api_key: Optional[str], is_fim_request: bool
+        self,
+        data: Dict,
+        api_key: Optional[str],
+        is_fim_request: bool,
+        client_type: ClientType,
     ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
         """
         Main completion flow with pipeline integration
@@ -272,12 +277,16 @@ async def complete(
         # Dump the normalized request
         self._dump_request_response("normalized-request", normalized_request)
         streaming = normalized_request.get("stream", False)
+
+        # Get detected client if available
         input_pipeline_result = await self._run_input_pipeline(
             normalized_request,
             api_key,
             data.get("base_url"),
+            client_type,
             is_fim_request,
         )
+
         if input_pipeline_result.response and input_pipeline_result.context:
             return await self._pipeline_response_formatter.handle_pipeline_response(
                 input_pipeline_result.response, streaming, context=input_pipeline_result.context

src/codegate/providers/copilot/pipeline.py

Lines changed: 3 additions & 2 deletions
@@ -8,6 +8,7 @@
 from litellm.types.llms.openai import ChatCompletionRequest
 from litellm.types.utils import Delta, StreamingChoices
 
+from codegate.clients.clients import ClientType
 from codegate.pipeline.base import PipelineContext, PipelineResult, SequentialPipelineProcessor
 from codegate.pipeline.factory import PipelineFactory
 from codegate.providers.normalizer.completion import CompletionNormalizer
@@ -200,7 +201,7 @@ def _create_normalizer(self):
         return CopilotFimNormalizer()
 
     def _create_pipeline(self) -> SequentialPipelineProcessor:
-        return self.pipeline_factory.create_fim_pipeline()
+        return self.pipeline_factory.create_fim_pipeline(ClientType.COPILOT)
 
 
 class CopilotChatPipeline(CopilotPipeline):
@@ -216,4 +217,4 @@ def _create_normalizer(self):
         return CopilotChatNormalizer()
 
     def _create_pipeline(self) -> SequentialPipelineProcessor:
-        return self.pipeline_factory.create_input_pipeline()
+        return self.pipeline_factory.create_input_pipeline(ClientType.COPILOT)

src/codegate/providers/llamacpp/provider.py

Lines changed: 3 additions & 1 deletion
@@ -44,7 +44,9 @@ async def process_request(
     ):
         is_fim_request = self._is_fim_request(request_url_path, data)
         try:
-            stream = await self.complete(data, None, is_fim_request=is_fim_request)
+            stream = await self.complete(
+                data, None, is_fim_request=is_fim_request, client_type=client_type
+            )
         except RuntimeError as e:
             # propagate as error 500
             logger.error("Error in LlamaCppProvider completion", error=str(e))

src/codegate/providers/ollama/provider.py

Lines changed: 6 additions & 1 deletion
@@ -66,7 +66,12 @@ async def process_request(
     ):
         is_fim_request = self._is_fim_request(request_url_path, data)
         try:
-            stream = await self.complete(data, api_key=None, is_fim_request=is_fim_request)
+            stream = await self.complete(
+                data,
+                api_key=None,
+                is_fim_request=is_fim_request,
+                client_type=client_type,
+            )
         except httpx.ConnectError as e:
             logger.error("Error in OllamaProvider completion", error=str(e))
             raise HTTPException(status_code=503, detail="Ollama service is unavailable")

src/codegate/providers/openai/provider.py

Lines changed: 6 additions & 1 deletion
@@ -54,7 +54,12 @@ async def process_request(
         is_fim_request = self._is_fim_request(request_url_path, data)
 
         try:
-            stream = await self.complete(data, api_key, is_fim_request=is_fim_request)
+            stream = await self.complete(
+                data,
+                api_key,
+                is_fim_request=is_fim_request,
+                client_type=client_type,
+            )
         except Exception as e:
             # check if we have an status code there
             if hasattr(e, "status_code"):

src/codegate/providers/vllm/provider.py

Lines changed: 6 additions & 1 deletion
@@ -77,7 +77,12 @@ async def process_request(
         is_fim_request = self._is_fim_request(request_url_path, data)
         try:
             # Pass the potentially None api_key to complete
-            stream = await self.complete(data, api_key, is_fim_request=is_fim_request)
+            stream = await self.complete(
+                data,
+                api_key,
+                is_fim_request=is_fim_request,
+                client_type=client_type,
+            )
         except Exception as e:
             # Check if we have a status code there
             if hasattr(e, "status_code"):
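
All five provider diffs above follow the same pattern: process_request already receives the detected client_type and now forwards it to complete() instead of letting the pipeline re-detect the client from the messages. A hypothetical new provider would do the same; the sketch below assumes the arguments visible in the diffs and that _is_fim_request and complete come from the shared provider base class in providers/base.py.

    from typing import Any, Dict, Optional

    from codegate.clients.clients import ClientType


    class MyProvider:  # hypothetical provider, not part of this commit
        # In real code this would subclass the provider base class in providers/base.py.
        async def process_request(
            self,
            data: Dict[str, Any],
            api_key: Optional[str],
            request_url_path: str,
            client_type: ClientType,
        ):
            is_fim_request = self._is_fim_request(request_url_path, data)
            # Forward the already-detected client type to the completion flow.
            return await self.complete(
                data,
                api_key,
                is_fim_request=is_fim_request,
                client_type=client_type,
            )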
