This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 29f143b

Use the client type when streaming the data to the client, not when executing completion
We used to special-case Ollama stream generation by passing the client type to execute_completion. Instead, let's pass the client type to the place that needs the special casing, using the recently introduced client type enum. Related: #830
1 parent 963348e commit 29f143b
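
For readers outside this codebase: the diff relies on a `ClientType` enum in `codegate.clients.clients` and a `DetectClient` route decorator in `codegate.clients.detector` that stores the detected client on `request.state.detected_client`. A minimal sketch of that plumbing follows; the enum values and the detection heuristic are illustrative assumptions, not the repository's actual implementation.

```python
from enum import Enum
from functools import wraps

from fastapi import Request


class ClientType(Enum):
    # Members referenced in this diff; the string values are assumptions.
    GENERIC = "generic"
    CLINE = "cline"
    KODU = "kodu"


class DetectClient:
    """Sketch of a decorator that tags each request with the calling client."""

    def __call__(self, func):
        @wraps(func)
        async def wrapper(request: Request, *args, **kwargs):
            # The real detector presumably inspects headers/user-agent;
            # defaulting to GENERIC keeps this sketch self-contained.
            request.state.detected_client = ClientType.GENERIC
            return await func(request, *args, **kwargs)

        return wrapper
```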

File tree: 11 files changed (+128 / -37 lines)


src/codegate/providers/anthropic/completion_handler.py

Lines changed: 0 additions & 1 deletion
@@ -16,7 +16,6 @@ async def execute_completion(
         api_key: Optional[str],
         stream: bool = False,
         is_fim_request: bool = False,
-        base_tool: Optional[str] = "",
     ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
         """
         Ensures the model name is prefixed with 'anthropic/' to explicitly route to Anthropic's API.

src/codegate/providers/anthropic/provider.py

Lines changed: 17 additions & 3 deletions
@@ -5,6 +5,8 @@
 import structlog
 from fastapi import Header, HTTPException, Request
 
+from codegate.clients.clients import ClientType
+from codegate.clients.detector import DetectClient
 from codegate.pipeline.factory import PipelineFactory
 from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer
 from codegate.providers.anthropic.completion_handler import AnthropicCompletion
@@ -51,7 +53,13 @@ def models(self, endpoint: str = None, api_key: str = None) -> List[str]:
 
         return [model["id"] for model in respjson.get("data", [])]
 
-    async def process_request(self, data: dict, api_key: str, request_url_path: str):
+    async def process_request(
+        self,
+        data: dict,
+        api_key: str,
+        request_url_path: str,
+        client_type: ClientType,
+    ):
         is_fim_request = self._is_fim_request(request_url_path, data)
         try:
             stream = await self.complete(data, api_key, is_fim_request)
@@ -65,7 +73,7 @@ async def process_request(self, data: dict, api_key: str, request_url_path: str)
             else:
                 # just continue raising the exception
                 raise e
-        return self._completion_handler.create_response(stream)
+        return self._completion_handler.create_response(stream, client_type)
 
     def _setup_routes(self):
         """
@@ -80,6 +88,7 @@ def _setup_routes(self):
 
         @self.router.post(f"/{self.provider_route_name}/messages")
         @self.router.post(f"/{self.provider_route_name}/v1/messages")
+        @DetectClient()
         async def create_message(
             request: Request,
             x_api_key: str = Header(None),
@@ -90,4 +99,9 @@ async def create_message(
             body = await request.body()
             data = json.loads(body)
 
-            return await self.process_request(data, x_api_key, request.url.path)
+            return await self.process_request(
+                data,
+                x_api_key,
+                request.url.path,
+                request.state.detected_client,
+            )

src/codegate/providers/base.py

Lines changed: 8 additions & 5 deletions
@@ -10,6 +10,7 @@
 from litellm import ModelResponse
 from litellm.types.llms.openai import ChatCompletionRequest
 
+from codegate.clients.clients import ClientType
 from codegate.codegate_logging import setup_logging
 from codegate.db.connection import DbRecorder
 from codegate.pipeline.base import (
@@ -22,7 +23,6 @@
 from codegate.providers.formatting.input_pipeline import PipelineResponseFormatter
 from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer
 from codegate.providers.normalizer.completion import CompletionNormalizer
-from codegate.utils.utils import get_tool_name_from_messages
 
 setup_logging()
 logger = structlog.get_logger("codegate")
@@ -74,7 +74,13 @@ def models(self, endpoint, str=None, api_key: str = None) -> List[str]:
         pass
 
     @abstractmethod
-    async def process_request(self, data: dict, api_key: str, request_url_path: str):
+    async def process_request(
+        self,
+        data: dict,
+        api_key: str,
+        request_url_path: str,
+        client_type: ClientType,
+    ):
         pass
 
     @property
@@ -287,14 +293,11 @@ async def complete(
         # Execute the completion and translate the response
         # This gives us either a single response or a stream of responses
         # based on the streaming flag
-        base_tool = get_tool_name_from_messages(data)
-
         model_response = await self._completion_handler.execute_completion(
             provider_request,
             api_key=api_key,
             stream=streaming,
             is_fim_request=is_fim_request,
-            base_tool=base_tool,
         )
         if not streaming:
             normalized_response = self._output_normalizer.normalize(model_response)

src/codegate/providers/completion/base.py

Lines changed: 13 additions & 4 deletions
@@ -6,6 +6,8 @@
 from fastapi.responses import JSONResponse, StreamingResponse
 from litellm import ChatCompletionRequest, ModelResponse
 
+from codegate.clients.clients import ClientType
+
 
 class BaseCompletionHandler(ABC):
     """
@@ -20,20 +22,27 @@ async def execute_completion(
         api_key: Optional[str],
         stream: bool = False,  # TODO: remove this param?
         is_fim_request: bool = False,
-        base_tool: Optional[str] = "",
     ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
         """Execute the completion request"""
         pass
 
     @abstractmethod
-    def _create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
+    def _create_streaming_response(
+        self,
+        stream: AsyncIterator[Any],
+        client_type: ClientType = ClientType.GENERIC,
+    ) -> StreamingResponse:
         pass
 
     @abstractmethod
     def _create_json_response(self, response: Any) -> JSONResponse:
         pass
 
-    def create_response(self, response: Any) -> Union[JSONResponse, StreamingResponse]:
+    def create_response(
+        self,
+        response: Any,
+        client_type: ClientType,
+    ) -> Union[JSONResponse, StreamingResponse]:
         """
         Create a FastAPI response from the completion response.
         """
@@ -42,5 +51,5 @@ def create_response(self, response: Any) -> Union[JSONResponse, StreamingRespons
             or isinstance(response, AsyncIterator)
             or inspect.isasyncgen(response)
         ):
-            return self._create_streaming_response(response)
+            return self._create_streaming_response(response, client_type)
         return self._create_json_response(response)
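
To see what the updated contract asks of a concrete handler, here is a hypothetical minimal subclass (not from the repo; it assumes the module path matches the file path above and that `BaseCompletionHandler` needs no constructor arguments). The point is that `execute_completion` no longer takes `base_tool`, while the client type arrives at `_create_streaming_response`.

```python
from typing import Any, AsyncIterator, Optional, Union

from fastapi.responses import JSONResponse, StreamingResponse
from litellm import ChatCompletionRequest, ModelResponse

from codegate.clients.clients import ClientType
from codegate.providers.completion.base import BaseCompletionHandler


class EchoCompletionHandler(BaseCompletionHandler):
    async def execute_completion(
        self,
        request: ChatCompletionRequest,
        api_key: Optional[str],
        stream: bool = False,
        is_fim_request: bool = False,
    ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
        # No base_tool parameter anymore; client specifics are handled later.
        ...

    def _create_streaming_response(
        self,
        stream: AsyncIterator[Any],
        client_type: ClientType = ClientType.GENERIC,
    ) -> StreamingResponse:
        # The client type is threaded in by create_response(stream, client_type).
        return StreamingResponse(stream, media_type="text/event-stream")

    def _create_json_response(self, response: Any) -> JSONResponse:
        return JSONResponse(content=response)
```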

src/codegate/providers/litellmshim/litellmshim.py

Lines changed: 6 additions & 2 deletions
@@ -9,6 +9,7 @@
     acompletion,
 )
 
+from codegate.clients.clients import ClientType
 from codegate.providers.base import BaseCompletionHandler, StreamGenerator
 
 logger = structlog.get_logger("codegate")
@@ -43,7 +44,6 @@ async def execute_completion(
         api_key: Optional[str],
         stream: bool = False,
         is_fim_request: bool = False,
-        base_tool: Optional[str] = "",
     ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
         """
         Execute the completion request with LiteLLM's API
@@ -53,7 +53,11 @@ async def execute_completion(
             return await self._fim_completion_func(**request)
         return await self._completion_func(**request)
 
-    def _create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
+    def _create_streaming_response(
+        self,
+        stream: AsyncIterator[Any],
+        _: ClientType = ClientType.GENERIC,
+    ) -> StreamingResponse:
         """
         Create a streaming response from a stream generator. The StreamingResponse
         is the format that FastAPI expects for streaming responses.

src/codegate/providers/llamacpp/completion_handler.py

Lines changed: 6 additions & 2 deletions
@@ -8,6 +8,7 @@
     CreateChatCompletionStreamResponse,
 )
 
+from codegate.clients.clients import ClientType
 from codegate.config import Config
 from codegate.inference.inference_engine import LlamaCppInferenceEngine
 from codegate.providers.base import BaseCompletionHandler
@@ -52,7 +53,6 @@ async def execute_completion(
         api_key: Optional[str],
         stream: bool = False,
         is_fim_request: bool = False,
-        base_tool: Optional[str] = "",
     ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
         """
         Execute the completion request with inference engine API
@@ -82,7 +82,11 @@ async def execute_completion(
 
         return convert_to_async_iterator(response) if stream else response
 
-    def _create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
+    def _create_streaming_response(
+        self,
+        stream: AsyncIterator[Any],
+        client_type: ClientType = ClientType.GENERIC,
+    ) -> StreamingResponse:
         """
         Create a streaming response from a stream generator. The StreamingResponse
         is the format that FastAPI expects for streaming responses.

src/codegate/providers/llamacpp/provider.py

Lines changed: 17 additions & 4 deletions
@@ -4,6 +4,8 @@
 import structlog
 from fastapi import HTTPException, Request
 
+from codegate.clients.clients import ClientType
+from codegate.clients.detector import DetectClient
 from codegate.pipeline.factory import PipelineFactory
 from codegate.providers.base import BaseProvider
 from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler
@@ -33,7 +35,13 @@ def models(self, endpoint: str = None, api_key: str = None) -> List[str]:
         # TODO: Implement file fetching
         return []
 
-    async def process_request(self, data: dict, api_key: str, request_url_path: str):
+    async def process_request(
+        self,
+        data: dict,
+        api_key: str,
+        request_url_path: str,
+        client_type: ClientType,
+    ):
         is_fim_request = self._is_fim_request(request_url_path, data)
         try:
             stream = await self.complete(data, None, is_fim_request=is_fim_request)
@@ -51,7 +59,7 @@ async def process_request(self, data: dict, api_key: str, request_url_path: str)
             else:
                 # just continue raising the exception
                 raise e
-        return self._completion_handler.create_response(stream)
+        return self._completion_handler.create_response(stream, client_type)
 
     def _setup_routes(self):
         """
@@ -61,10 +69,15 @@ def _setup_routes(self):
 
         @self.router.post(f"/{self.provider_route_name}/completions")
         @self.router.post(f"/{self.provider_route_name}/chat/completions")
+        @DetectClient()
         async def create_completion(
             request: Request,
         ):
             body = await request.body()
             data = json.loads(body)
-
-            return await self.process_request(data, None, request.url.path)
+            return await self.process_request(
+                data,
+                None,
+                request.url.path,
+                request.state.detected_client,
+            )

src/codegate/providers/ollama/completion_handler.py

Lines changed: 10 additions & 7 deletions
@@ -6,13 +6,15 @@
 from litellm import ChatCompletionRequest
 from ollama import AsyncClient, ChatResponse, GenerateResponse
 
+from codegate.clients.clients import ClientType
 from codegate.providers.base import BaseCompletionHandler
 
 logger = structlog.get_logger("codegate")
 
 
 async def ollama_stream_generator(  # noqa: C901
-    stream: AsyncIterator[ChatResponse], base_tool: str
+    stream: AsyncIterator[ChatResponse],
+    client_type: ClientType,
 ) -> AsyncIterator[str]:
     """OpenAI-style SSE format"""
     try:
@@ -21,7 +23,7 @@ async def ollama_stream_generator(  # noqa: C901
             # TODO We should wire in the client info so we can respond with
             # the correct format and start to handle multiple clients
             # in a more robust way.
-            if base_tool in ["cline", "kodu"]:
+            if client_type in [ClientType.CLINE, ClientType.KODU]:
                 # First get the raw dict from the chunk
                 chunk_dict = chunk.model_dump()
                 # Create response dictionary in OpenAI-like format
@@ -82,18 +84,15 @@ class OllamaShim(BaseCompletionHandler):
 
     def __init__(self, base_url):
         self.client = AsyncClient(host=base_url, timeout=300)
-        self.base_tool = ""
 
     async def execute_completion(
         self,
         request: ChatCompletionRequest,
         api_key: Optional[str],
         stream: bool = False,
        is_fim_request: bool = False,
-        base_tool: Optional[str] = "",
    ) -> Union[ChatResponse, GenerateResponse]:
        """Stream response directly from Ollama API."""
-        self.base_tool = base_tool
        if is_fim_request:
            prompt = ""
            for i in reversed(range(len(request["messages"]))):
@@ -120,13 +119,17 @@ async def execute_completion(
        )  # type: ignore
        return response
 
-    def _create_streaming_response(self, stream: AsyncIterator[ChatResponse]) -> StreamingResponse:
+    def _create_streaming_response(
+        self,
+        stream: AsyncIterator[ChatResponse],
+        client_type: ClientType,
+    ) -> StreamingResponse:
        """
        Create a streaming response from a stream generator. The StreamingResponse
        is the format that FastAPI expects for streaming responses.
        """
        return StreamingResponse(
-            ollama_stream_generator(stream, self.base_tool or ""),
+            ollama_stream_generator(stream, client_type),
            media_type="application/x-ndjson; charset=utf-8",
            headers={
                "Cache-Control": "no-cache",

src/codegate/providers/ollama/provider.py

Lines changed: 17 additions & 3 deletions
@@ -5,6 +5,8 @@
 import structlog
 from fastapi import HTTPException, Request
 
+from codegate.clients.clients import ClientType
+from codegate.clients.detector import DetectClient
 from codegate.config import Config
 from codegate.pipeline.factory import PipelineFactory
 from codegate.providers.base import BaseProvider, ModelFetchError
@@ -55,7 +57,13 @@ def models(self, endpoint: str = None, api_key: str = None) -> List[str]:
 
         return [model["name"] for model in jsonresp.get("models", [])]
 
-    async def process_request(self, data: dict, api_key: str, request_url_path: str):
+    async def process_request(
+        self,
+        data: dict,
+        api_key: str,
+        request_url_path: str,
+        client_type: ClientType,
+    ):
         is_fim_request = self._is_fim_request(request_url_path, data)
         try:
             stream = await self.complete(data, api_key=None, is_fim_request=is_fim_request)
@@ -71,7 +79,7 @@ async def process_request(self, data: dict, api_key: str, request_url_path: str)
             else:
                 # just continue raising the exception
                 raise e
-        return self._completion_handler.create_response(stream)
+        return self._completion_handler.create_response(stream, client_type)
 
     def _setup_routes(self):
         """
@@ -117,6 +125,7 @@ async def show_model(request: Request):
         # Cline API routes
         @self.router.post(f"/{self.provider_route_name}/v1/chat/completions")
         @self.router.post(f"/{self.provider_route_name}/v1/generate")
+        @DetectClient()
         async def create_completion(request: Request):
             body = await request.body()
             data = json.loads(body)
@@ -125,4 +134,9 @@ async def create_completion(request: Request):
             # Force it to be the one that comes in the configuration.
             data["base_url"] = self.base_url
 
-            return await self.process_request(data, None, request.url.path)
+            return await self.process_request(
+                data,
+                None,
+                request.url.path,
+                request.state.detected_client,
+            )

0 commit comments
