Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 0599103

Browse files
authored
Add client detection (#832)
* Add client detector interface Adds a decorator that can be added to the FastAPI handlers and detect the client from a fallback mechanism, by the user-agent, by a specific header or by a matching word in the messages. At the moment, the clients are represented as a simple enum, but in follow-up patches they will be represented by classes that can perform the changes by an interface providing callbacks from the pipeline or other places that need client-specific behaviour. Related: #830 * Use the client type when streaming the data to the client, not when executing completion We used to special-case ollama stream generation by passing the client type to the execute_completion. Instead, let's pass the client type to the place that needs special casing using the recently introduced client type enum. Related: #830 * Use the client type when instantiating and running provider pipelines Instead of detecting the client type again when the pipeline is being processed, let's pass the client type on instantiating the pipeline instance as a constant and replace the hardcoded client strings by just using the constants. Related: #830 * Remove get_tool_name_from_messages This was superseded by using the client enum. Related: #830 * Remove the is_copilot flag in favor of using the autodetected client In the copilot provider, we can hardcode the client type to copilot when instantiating the pipelines. Related: #830
1 parent 4e032d9 commit 0599103

File tree

23 files changed

+827
-101
lines changed

23 files changed

+827
-101
lines changed

src/codegate/clients/clients.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from enum import Enum
2+
3+
4+
class ClientType(Enum):
    """
    Enum of supported client types
    """

    # Default client type when no specific client is detected
    GENERIC = "generic"
    # Cline client
    CLINE = "cline"
    # Kodu client
    KODU = "kodu"
    # Copilot client
    COPILOT = "copilot"
    # Open Interpreter client
    OPEN_INTERPRETER = "open_interpreter"

src/codegate/clients/detector.py

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
import re
2+
from abc import ABC, abstractmethod
3+
from functools import wraps
4+
from typing import List, Optional
5+
6+
import structlog
7+
from fastapi import Request
8+
9+
from codegate.clients.clients import ClientType
10+
11+
logger = structlog.get_logger("codegate")
12+
13+
14+
class HeaderDetector:
    """
    Base utility class for header-based detection

    A request matches when the configured header is present and, if an expected
    value was configured, when the header's value is exactly equal to it.
    """

    # Header names whose values are masked before being written to the debug log,
    # so that credentials carried by the request do not end up in log output.
    _REDACTED_HEADERS = frozenset(
        {"authorization", "proxy-authorization", "cookie", "x-api-key", "api-key"}
    )

    def __init__(self, header_name: str, header_value: Optional[str] = None):
        """
        Args:
            header_name: Name of the header to look for.
            header_value: Value the header must equal; None means any value matches.
        """
        self.header_name = header_name
        self.header_value = header_value

    def detect(self, request: Request) -> bool:
        """
        Returns True if the request carries the configured header (and value, if set)
        """
        logger.debug(
            "checking header detection",
            header_name=self.header_name,
            header_value=self.header_value,
            request_headers={
                name: "<redacted>" if name.lower() in self._REDACTED_HEADERS else value
                for name, value in dict(request.headers).items()
            },
        )
        # Check if the header is present, if not we didn't detect the client
        if self.header_name not in request.headers:
            return False
        # now we know that the header is present, if we don't care about the value
        # we detected the client
        if self.header_value is None:
            return True
        # finally, if we care about the value, we need to check if it matches
        return request.headers[self.header_name] == self.header_value
39+
40+
41+
class UserAgentDetector(HeaderDetector):
    """
    HeaderDetector variant that searches the user-agent header for a
    case-insensitive regular expression
    """

    def __init__(self, user_agent_pattern: str):
        super().__init__("user-agent")
        self.pattern = re.compile(user_agent_pattern, re.IGNORECASE)

    def detect(self, request: Request) -> bool:
        agent = request.headers.get(self.header_name)
        # A missing or empty user-agent can never match the pattern
        return bool(agent) and self.pattern.search(agent) is not None
55+
56+
57+
class ContentDetector:
    """
    Detector that looks for a literal substring in the JSON body of a request
    """

    def __init__(self, pattern: str):
        self.pattern = pattern

    async def detect(self, request: Request) -> bool:
        """
        Returns True if the pattern occurs in any message content or in the
        top-level "system" field; any failure while reading the body counts as no match
        """
        try:
            payload = await request.json()
            needle = self.pattern
            if any(
                needle in str(entry.get("content", ""))
                for entry in payload.get("messages", [])
            ):
                return True
            # This is clearly a hack and won't be needed when we get rid of the normalizers and will
            # be able to access the system message directly from the on-wire format
            return needle in str(payload.get("system", ""))
        except Exception as e:
            logger.error(f"Error in content detection: {str(e)}")
            return False
81+
82+
83+
class BaseClientDetector(ABC):
    """
    Common base for client detectors; subclasses configure whichever of the
    user-agent, header and content detectors they need
    """

    def __init__(self):
        self.header_detector: Optional[HeaderDetector] = None
        self.user_agent_detector: Optional[UserAgentDetector] = None
        self.content_detector: Optional[ContentDetector] = None

    @property
    @abstractmethod
    def client_name(self) -> ClientType:
        """
        Returns the name of the client
        """

    async def detect(self, request: Request) -> bool:
        """
        Runs the configured detectors in order: user agent, header, then content
        """
        # The two header-based checks are synchronous and tried first, in this order
        for sync_detector in (self.user_agent_detector, self.header_detector):
            if sync_detector and sync_detector.detect(request):
                return True

        # Content detection is the last resort and needs to read the request body
        if not self.content_detector:
            return False
        return await self.content_detector.detect(request)
118+
119+
120+
class ClineDetector(BaseClientDetector):
    """
    Detects the Cline client by looking for "Cline" in the message content
    """

    @property
    def client_name(self) -> ClientType:
        return ClientType.CLINE

    def __init__(self):
        super().__init__()
        # Only content-based detection is configured for this client
        self.content_detector = ContentDetector("Cline")
132+
133+
134+
class KoduDetector(BaseClientDetector):
    """
    Detects the Kodu client by its user agent or by "Kodu" in the message content
    """

    @property
    def client_name(self) -> ClientType:
        return ClientType.KODU

    def __init__(self):
        super().__init__()
        # Both a user-agent check and a content check are configured
        self.user_agent_detector = UserAgentDetector("Kodu")
        self.content_detector = ContentDetector("Kodu")
147+
148+
149+
class OpenInterpreter(BaseClientDetector):
    """
    Detector for Open Interpreter client based on message content
    """

    def __init__(self):
        super().__init__()
        # Matches requests whose message content contains "Open Interpreter"
        self.content_detector = ContentDetector("Open Interpreter")

    @property
    def client_name(self) -> ClientType:
        return ClientType.OPEN_INTERPRETER
161+
162+
163+
class CopilotDetector(BaseClientDetector):
    """
    Detector for Copilot client based on user agent

    This derives from BaseClientDetector so that it exposes the awaitable
    detect() that DetectClient calls on every detector. The user agent is
    searched for "Copilot" (case-insensitively) instead of being compared for
    exact equality, so user-agent strings that merely contain the word match.
    """

    def __init__(self):
        super().__init__()
        self.user_agent_detector = UserAgentDetector("Copilot")

    @property
    def client_name(self) -> ClientType:
        return ClientType.COPILOT
174+
175+
176+
class DetectClient:
    """
    Decorator class for detecting clients from request system messages

    The detected client is stored on request.state.detected_client before the
    wrapped handler runs.

    Usage:
        @app.post("/v1/chat/completions")
        @DetectClient()
        async def chat_completions(request: Request):
            client = request.state.detected_client
    """

    def __init__(self):
        # Order matters: the first detector that matches wins
        self.detectors: List[BaseClientDetector] = [
            ClineDetector(),
            KoduDetector(),
            OpenInterpreter(),
            CopilotDetector(),
        ]

    def __call__(self, func):
        """
        Wraps an async handler taking the request as first argument so that the
        client is detected before the handler is invoked
        """

        @wraps(func)
        async def wrapper(request: Request, *args, **kwargs):
            try:
                client = await self.detect(request)
                request.state.detected_client = client
            except Exception as e:
                # Detection must never prevent the request from being handled
                logger.error(f"Error in client detection: {str(e)}")
                request.state.detected_client = ClientType.GENERIC

            return await func(request, *args, **kwargs)

        return wrapper

    async def detect(self, request: Request) -> ClientType:
        """
        Detects the client from the request by trying each detector in sequence.
        Returns the name of the first detected client, or GENERIC if no specific client is detected.
        """
        import inspect

        for detector in self.detectors:
            try:
                outcome = detector.detect(request)
                # Some detectors implement detect() synchronously and return a plain
                # bool; awaiting that would raise TypeError, so only await awaitables
                if inspect.isawaitable(outcome):
                    outcome = await outcome
                if outcome:
                    client_name = detector.client_name
                    logger.info(f"{client_name} client detected")
                    return client_name
            except Exception as e:
                # A failing detector is logged and skipped so the others still run
                logger.error(f"Error in {detector.client_name} detection: {str(e)}")
                continue
        logger.info("No particular client detected, using generic client")
        return ClientType.GENERIC

src/codegate/pipeline/base.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
from litellm import ChatCompletionRequest, ModelResponse
1212
from pydantic import BaseModel
1313

14+
from codegate.clients.clients import ClientType
1415
from codegate.db.models import Alert, Output, Prompt
1516
from codegate.pipeline.secrets.manager import SecretsManager
16-
from codegate.utils.utils import get_tool_name_from_messages
1717

1818
logger = structlog.get_logger("codegate")
1919

@@ -81,6 +81,7 @@ class PipelineContext:
8181
shortcut_response: bool = False
8282
bad_packages_found: bool = False
8383
secrets_found: bool = False
84+
client: ClientType = ClientType.GENERIC
8485

8586
def add_code_snippet(self, snippet: CodeSnippet):
8687
self.code_snippets.append(snippet)
@@ -241,12 +242,14 @@ def get_last_user_message(
241242
@staticmethod
242243
def get_last_user_message_block(
243244
request: ChatCompletionRequest,
245+
client: ClientType = ClientType.GENERIC,
244246
) -> Optional[tuple[str, int]]:
245247
"""
246248
Get the last block of consecutive 'user' messages from the request.
247249
248250
Args:
249251
request (ChatCompletionRequest): The chat completion request to process
252+
client (ClientType): The client type to consider when processing the request
250253
251254
Returns:
252255
Optional[tuple[str, int]]: A string containing all consecutive user messages in the
@@ -261,9 +264,8 @@ def get_last_user_message_block(
261264
messages = request["messages"]
262265
block_start_index = None
263266

264-
base_tool = get_tool_name_from_messages(request)
265267
accepted_roles = ["user", "assistant"]
266-
if base_tool == "open interpreter":
268+
if client == ClientType.OPEN_INTERPRETER:
267269
# open interpreter also uses the role "tool"
268270
accepted_roles.append("tool")
269271

@@ -303,12 +305,16 @@ async def process(
303305

304306
class InputPipelineInstance:
305307
def __init__(
306-
self, pipeline_steps: List[PipelineStep], secret_manager: SecretsManager, is_fim: bool
308+
self,
309+
pipeline_steps: List[PipelineStep],
310+
secret_manager: SecretsManager,
311+
is_fim: bool,
312+
client: ClientType = ClientType.GENERIC,
307313
):
308314
self.pipeline_steps = pipeline_steps
309315
self.secret_manager = secret_manager
310316
self.is_fim = is_fim
311-
self.context = PipelineContext()
317+
self.context = PipelineContext(client=client)
312318

313319
# we create the sensitive context here so that it is not shared between individual requests
314320
# TODO: could we get away with just generating the session ID for an instance?
@@ -326,7 +332,6 @@ async def process_request(
326332
api_key: Optional[str] = None,
327333
api_base: Optional[str] = None,
328334
extra_headers: Optional[Dict[str, str]] = None,
329-
is_copilot: bool = False,
330335
) -> PipelineResult:
331336
"""Process a request through all pipeline steps"""
332337
self.context.metadata["extra_headers"] = extra_headers
@@ -338,7 +343,9 @@ async def process_request(
338343
self.context.sensitive.api_base = api_base
339344

340345
# For Copilot provider=openai. Use a flag to not clash with other places that may use that.
341-
provider_db = "copilot" if is_copilot else provider
346+
provider_db = provider
347+
if self.context.client == ClientType.COPILOT:
348+
provider_db = "copilot"
342349

343350
for step in self.pipeline_steps:
344351
result = await step.process(current_request, self.context)
@@ -367,16 +374,25 @@ async def process_request(
367374

368375
class SequentialPipelineProcessor:
369376
def __init__(
370-
self, pipeline_steps: List[PipelineStep], secret_manager: SecretsManager, is_fim: bool
377+
self,
378+
pipeline_steps: List[PipelineStep],
379+
secret_manager: SecretsManager,
380+
client_type: ClientType,
381+
is_fim: bool,
371382
):
372383
self.pipeline_steps = pipeline_steps
373384
self.secret_manager = secret_manager
374385
self.is_fim = is_fim
375-
self.instance = self._create_instance()
386+
self.instance = self._create_instance(client_type)
376387

377-
def _create_instance(self) -> InputPipelineInstance:
388+
def _create_instance(self, client_type: ClientType) -> InputPipelineInstance:
378389
"""Create a new pipeline instance for processing a request"""
379-
return InputPipelineInstance(self.pipeline_steps, self.secret_manager, self.is_fim)
390+
return InputPipelineInstance(
391+
self.pipeline_steps,
392+
self.secret_manager,
393+
self.is_fim,
394+
client_type,
395+
)
380396

381397
async def process_request(
382398
self,
@@ -386,9 +402,13 @@ async def process_request(
386402
api_key: Optional[str] = None,
387403
api_base: Optional[str] = None,
388404
extra_headers: Optional[Dict[str, str]] = None,
389-
is_copilot: bool = False,
390405
) -> PipelineResult:
391406
"""Create a new pipeline instance and process the request"""
392407
return await self.instance.process_request(
393-
request, provider, model, api_key, api_base, extra_headers, is_copilot
408+
request,
409+
provider,
410+
model,
411+
api_key,
412+
api_base,
413+
extra_headers,
394414
)

0 commit comments

Comments
 (0)