This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit bcd011d

Fix FIM on Continue. Have specific formatters for chat and FIM
Until now we had a single general formatter on the way out of muxing. This was wrong, since the pipelines sometimes respond in a different format for chat than for FIM; Ollama is one such case. This PR separates the chat and FIM formatters and declares them explicitly so that they're easier to adjust in the future.
1 parent 8cb658d commit bcd011d
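
To make the split concrete: the muxing layer now chooses an output formatter per request type instead of funneling everything through one formatter. A minimal standalone sketch of that idea (the stub classes and the pick_formatter helper below are illustrative only, not codegate's API; the real classes live in src/codegate/muxing/adapter.py in the diff below):

import json
from abc import ABC, abstractmethod
from typing import Dict


class OutputFormatter(ABC):
    @abstractmethod
    def format_chunk(self, chunk: str) -> Dict:
        """Turn one raw provider chunk into an OpenAI-style delta."""


class ChatFormatter(OutputFormatter):
    def format_chunk(self, chunk: str) -> Dict:
        # Ollama chat chunks carry their text under message.content
        content = json.loads(chunk)["message"]["content"]
        return {"choices": [{"delta": {"content": content}}]}


class FimFormatter(OutputFormatter):
    def format_chunk(self, chunk: str) -> Dict:
        # Ollama generate (FIM) chunks carry their text under response
        content = json.loads(chunk)["response"]
        return {"choices": [{"delta": {"content": content}}]}


def pick_formatter(is_fim_request: bool) -> OutputFormatter:
    # one formatter per request type, mirroring the choice the ResponseAdapter now makes
    return FimFormatter() if is_fim_request else ChatFormatter()


print(pick_formatter(True).format_chunk('{"response": "def add(a, b):"}'))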

File tree

4 files changed: 191 additions, 58 deletions

src/codegate/muxing/adapter.py

Lines changed: 128 additions & 42 deletions
@@ -1,13 +1,14 @@
 import copy
 import json
 import uuid
-from typing import Union
+from abc import ABC, abstractmethod
+from typing import Callable, Dict, Union

 import structlog
 from fastapi.responses import JSONResponse, StreamingResponse
 from litellm import ModelResponse
 from litellm.types.utils import Delta, StreamingChoices
-from ollama import ChatResponse
+from ollama import ChatResponse, GenerateResponse

 from codegate.db import models as db_models
 from codegate.muxing import rulematcher
@@ -30,12 +31,13 @@ class BodyAdapter:

     def _get_provider_formatted_url(self, model_route: rulematcher.ModelRoute) -> str:
         """Get the provider formatted URL to use in base_url. Note this value comes from DB"""
+        base_endpoint = model_route.endpoint.endpoint.rstrip("/")
         if model_route.endpoint.provider_type in [
             db_models.ProviderType.openai,
             db_models.ProviderType.openrouter,
         ]:
-            return f"{model_route.endpoint.endpoint}/v1"
-        return model_route.endpoint.endpoint
+            return f"{base_endpoint}/v1"
+        return base_endpoint

     def set_destination_info(self, model_route: rulematcher.ModelRoute, data: dict) -> dict:
         """Set the destination provider info."""
@@ -45,15 +47,91 @@ def set_destination_info(self, model_route: rulematcher.ModelRoute, data: dict)
         return new_data


-class StreamChunkFormatter:
+class OutputFormatter(ABC):
+
+    @property
+    @abstractmethod
+    def provider_format_funcs(self) -> Dict[str, Callable]:
+        """
+        Return the provider specific format functions. All providers format functions should
+        return the chunk in OpenAI format.
+        """
+        pass
+
+    @abstractmethod
+    def format(
+        self, response: Union[StreamingResponse, JSONResponse], dest_prov: db_models.ProviderType
+    ) -> Union[StreamingResponse, JSONResponse]:
+        """Format the response to the client."""
+        pass
+
+
+class StreamChunkFormatter(OutputFormatter):
     """
     Format a single chunk from a stream to OpenAI format.
     We need to configure the client to expect the OpenAI format.
     In Continue this means setting "provider": "openai" in the config json file.
     """

-    def __init__(self):
-        self.provider_to_func = {
+    @property
+    @abstractmethod
+    def provider_format_funcs(self) -> Dict[str, Callable]:
+        """
+        Return the provider specific format functions. All providers format functions should
+        return the chunk in OpenAI format.
+        """
+        pass
+
+    def _format_openai(self, chunk: str) -> str:
+        """
+        The chunk is already in OpenAI format. To standarize remove the "data:" prefix.
+
+        This function is used by both chat and FIM formatters
+        """
+        cleaned_chunk = chunk.split("data:")[1].strip()
+        return cleaned_chunk
+
+    def _format_as_openai_chunk(self, formatted_chunk: str) -> str:
+        """Format the chunk as OpenAI chunk. This is the format how the clients expect the data."""
+        return f"data:{formatted_chunk}\n\n"
+
+    async def _format_streaming_response(
+        self, response: StreamingResponse, dest_prov: db_models.ProviderType
+    ):
+        format_func = self.provider_format_funcs.get(dest_prov)
+        """Format the streaming response to OpenAI format."""
+        async for chunk in response.body_iterator:
+            openai_chunk = format_func(chunk)
+            # Sometimes for Anthropic we couldn't get content from the chunk. Skip it.
+            if not openai_chunk:
+                continue
+            yield self._format_as_openai_chunk(openai_chunk)
+
+    def format(
+        self, response: StreamingResponse, dest_prov: db_models.ProviderType
+    ) -> StreamingResponse:
+        """Format the response to the client."""
+        return StreamingResponse(
+            self._format_streaming_response(response, dest_prov),
+            status_code=response.status_code,
+            headers=response.headers,
+            background=response.background,
+            media_type=response.media_type,
+        )
+
+
+class ChatStreamChunkFormatter(StreamChunkFormatter):
+    """
+    Format a single chunk from a stream to OpenAI format given that the request was a chat.
+    """
+
+    @property
+    def provider_format_funcs(self) -> Dict[str, Callable]:
+        """
+        Return the provider specific format functions. All providers format functions should
+        return the chunk in OpenAI format.
+        """
+        return {
             db_models.ProviderType.ollama: self._format_ollama,
             db_models.ProviderType.openai: self._format_openai,
             db_models.ProviderType.anthropic: self._format_antropic,
@@ -68,7 +146,7 @@ def _format_ollama(self, chunk: str) -> str:
         try:
             chunk_dict = json.loads(chunk)
             ollama_chunk = ChatResponse(**chunk_dict)
-            open_ai_chunk = OLlamaToModel.normalize_chunk(ollama_chunk)
+            open_ai_chunk = OLlamaToModel.normalize_chat_chunk(ollama_chunk)
             return open_ai_chunk.model_dump_json(exclude_none=True, exclude_unset=True)
         except Exception:
             return chunk
@@ -119,46 +197,54 @@ def _format_antropic(self, chunk: str) -> str:
         except Exception:
             return cleaned_chunk.strip()

-    def format(self, chunk: str, dest_prov: db_models.ProviderType) -> ModelResponse:
-        """Format the chunk to OpenAI format."""
-        # Get the format function
-        format_func = self.provider_to_func.get(dest_prov)
-        if format_func is None:
-            raise MuxingAdapterError(f"Provider {dest_prov} not supported.")
-        return format_func(chunk)

+class FimStreamChunkFormatter(StreamChunkFormatter):

-class ResponseAdapter:
+    @property
+    def provider_format_funcs(self) -> Dict[str, Callable]:
+        """
+        Return the provider specific format functions. All providers format functions should
+        return the chunk in OpenAI format.
+        """
+        return {
+            db_models.ProviderType.ollama: self._format_ollama,
+            db_models.ProviderType.openai: self._format_openai,
+            # Our Lllamacpp provider emits OpenAI chunks
+            db_models.ProviderType.llamacpp: self._format_openai,
+            # OpenRouter is a dialect of OpenAI
+            db_models.ProviderType.openrouter: self._format_openai,
+        }

-    def __init__(self):
-        self.stream_formatter = StreamChunkFormatter()
+    def _format_ollama(self, chunk: str) -> str:
+        """Format the Ollama chunk to OpenAI format."""
+        try:
+            chunk_dict = json.loads(chunk)
+            ollama_chunk = GenerateResponse(**chunk_dict)
+            open_ai_chunk = OLlamaToModel.normalize_fim_chunk(ollama_chunk)
+            ai_chunk = open_ai_chunk.model_dump_json(exclude_none=True, exclude_unset=True)
+            return ai_chunk
+        except Exception:
+            return chunk

-    def _format_as_openai_chunk(self, formatted_chunk: str) -> str:
-        """Format the chunk as OpenAI chunk. This is the format how the clients expect the data."""
-        return f"data:{formatted_chunk}\n\n"

-    async def _format_streaming_response(
-        self, response: StreamingResponse, dest_prov: db_models.ProviderType
-    ):
-        """Format the streaming response to OpenAI format."""
-        async for chunk in response.body_iterator:
-            openai_chunk = self.stream_formatter.format(chunk, dest_prov)
-            # Sometimes for Anthropic we couldn't get content from the chunk. Skip it.
-            if not openai_chunk:
-                continue
-            yield self._format_as_openai_chunk(openai_chunk)
+class ResponseAdapter:
+
+    def _get_formatter(
+        self, response: Union[StreamingResponse, JSONResponse], is_fim_request: bool
+    ) -> OutputFormatter:
+        """Get the formatter based on the request type."""
+        if isinstance(response, StreamingResponse):
+            if is_fim_request:
+                return FimStreamChunkFormatter()
+            return ChatStreamChunkFormatter()
+        raise MuxingAdapterError("Only streaming responses are supported.")

     def format_response_to_client(
-        self, response: Union[StreamingResponse, JSONResponse], dest_prov: db_models.ProviderType
+        self,
+        response: Union[StreamingResponse, JSONResponse],
+        dest_prov: db_models.ProviderType,
+        is_fim_request: bool,
     ) -> Union[StreamingResponse, JSONResponse]:
         """Format the response to the client."""
-        if isinstance(response, StreamingResponse):
-            return StreamingResponse(
-                self._format_streaming_response(response, dest_prov),
-                status_code=response.status_code,
-                headers=response.headers,
-                background=response.background,
-                media_type=response.media_type,
-            )
-        else:
-            raise MuxingAdapterError("Only streaming responses are supported.")
+        stream_formatter = self._get_formatter(response, is_fim_request)
+        return stream_formatter.format(response, dest_prov)
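
Both concrete formatters inherit the SSE plumbing from StreamChunkFormatter: each provider chunk is normalized to OpenAI JSON and then re-wrapped in the "data: ...\n\n" framing that clients such as Continue expect. A standalone sketch of that round trip (helper names and the sample payload are invented for illustration; they only mirror _format_openai and _format_as_openai_chunk above):

def strip_sse_prefix(chunk: str) -> str:
    # mirrors StreamChunkFormatter._format_openai: drop the leading "data:" prefix
    return chunk.split("data:")[1].strip()


def wrap_as_sse(payload: str) -> str:
    # mirrors StreamChunkFormatter._format_as_openai_chunk: re-apply SSE framing
    return f"data:{payload}\n\n"


raw = 'data: {"choices": [{"delta": {"content": "hi"}}]}'
print(wrap_as_sse(strip_sse_prefix(raw)))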

src/codegate/muxing/router.py

Lines changed: 2 additions & 1 deletion
@@ -93,6 +93,7 @@ async def route_to_dest_provider(
             model=model_route.model.name,
             provider_type=model_route.endpoint.provider_type,
             provider_name=model_route.endpoint.name,
+            is_fim_request=is_fim_request,
         )

         # 2. Map the request body to the destination provider format.
@@ -108,5 +109,5 @@

         # 4. Transmit the response back to the client in OpenAI format.
         return self._response_adapter.format_response_to_client(
-            response, model_route.endpoint.provider_type
+            response, model_route.endpoint.provider_type, is_fim_request=is_fim_request
         )

src/codegate/providers/normalizer/completion.py

Lines changed: 11 additions & 1 deletion
@@ -19,7 +19,17 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
         if "prompt" in data:
             data["messages"] = [{"content": data.pop("prompt"), "role": "user"}]
             # We can add as many parameters as we like to data. ChatCompletionRequest is not strict.
-            data["had_prompt_before"] = True
+
+            # NOTE: Not adding the flag. This will also skip the denormalize step.
+            # LiteLLM seems to not support anymore having "prompt" as key and only "message" is
+            # supported
+            # data["had_prompt_before"] = True
+
+        # Litelllm says the we need to have max a list of length 4 in stop.
+        stop_list = data.get("stop", [])
+        trimmed_stop_list = stop_list[:4]
+        data["stop"] = trimmed_stop_list
+
         try:
             normalized_data = ChatCompletionRequest(**data)
             if normalized_data.get("stream", False):
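
The stop-list handling added above caps the stop sequences forwarded to LiteLLM at four, which is the maximum the OpenAI-style completion API accepts. A standalone illustration of the same trimming (the stop strings are invented for the example):

data = {
    "prompt": "def fibonacci(",
    "stop": ["<EOT>", "\n\n", "</code>", "<|endoftext|>", "<|fim_pad|>"],
}
data["stop"] = data.get("stop", [])[:4]  # keep at most four stop sequences
print(data["stop"])  # ['<EOT>', '\n\n', '</code>', '<|endoftext|>']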

src/codegate/providers/ollama/adapter.py

Lines changed: 50 additions & 14 deletions
@@ -1,10 +1,10 @@
 import uuid
 from datetime import datetime, timezone
-from typing import Any, AsyncIterator, Dict, Union
+from typing import Any, AsyncIterator, Dict, Optional, Tuple, Union

 from litellm import ChatCompletionRequest, ModelResponse
 from litellm.types.utils import Delta, StreamingChoices
-from ollama import ChatResponse, Message
+from ollama import ChatResponse, GenerateResponse, Message

 from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer

@@ -47,18 +47,27 @@ def __init__(self, ollama_response: AsyncIterator[ChatResponse]):
         self.ollama_response = ollama_response
         self._aiter = ollama_response.__aiter__()

-    @staticmethod
-    def normalize_chunk(chunk: ChatResponse) -> ModelResponse:
-        finish_reason = None
-        role = "assistant"
-
+    @classmethod
+    def _transform_to_int_secs(cls, chunk_created_at) -> int:
         # Convert the datetime object to a timestamp in seconds
-        datetime_obj = datetime.fromisoformat(chunk.created_at)
-        timestamp_seconds = int(datetime_obj.timestamp())
+        datetime_obj = datetime.fromisoformat(chunk_created_at)
+        return int(datetime_obj.timestamp())

-        if chunk.done:
+    @classmethod
+    def _get_finish_reason_assistant(cls, is_chunk_done: bool) -> Tuple[str, Optional[str]]:
+        finish_reason = None
+        role = "assistant"
+        if is_chunk_done:
             finish_reason = "stop"
             role = None
+        return role, finish_reason
+
+    @classmethod
+    def normalize_chat_chunk(cls, chunk: ChatResponse) -> ModelResponse:
+        # Convert the datetime object to a timestamp in seconds
+        timestamp_seconds = cls._transform_to_int_secs(chunk.created_at)
+        # Get role and finish reason
+        role, finish_reason = cls._get_finish_reason_assistant(chunk.done)

         model_response = ModelResponse(
             id=f"ollama-chat-{str(uuid.uuid4())}",
@@ -76,16 +85,43 @@ def normalize_chunk(chunk: ChatResponse) -> ModelResponse:
         )
         return model_response

+    @classmethod
+    def normalize_fim_chunk(cls, chunk: GenerateResponse) -> ModelResponse:
+        """
+        Transform an ollama generation chunk to an OpenAI one
+        """
+        # Convert the datetime object to a timestamp in seconds
+        timestamp_seconds = cls._transform_to_int_secs(chunk.created_at)
+        # Get role and finish reason
+        _, finish_reason = cls._get_finish_reason_assistant(chunk.done)
+
+        model_response = ModelResponse(
+            id=f"ollama-chat-{str(uuid.uuid4())}",
+            created=timestamp_seconds,
+            model=chunk.model,
+            object="chat.completion.chunk",
+            choices=[
+                StreamingChoices(
+                    finish_reason=finish_reason,
+                    index=0,
+                    delta=Delta(content=chunk.response),
+                    logprobs=None,
+                )
+            ],
+        )
+        return model_response
+
     def __aiter__(self):
         return self

     async def __anext__(self):
         try:
             chunk = await self._aiter.__anext__()
-            if not isinstance(chunk, ChatResponse):
-                return chunk
-
-            return self.normalize_chunk(chunk)
+            if isinstance(chunk, ChatResponse):
+                return self.normalize_chat_chunk(chunk)
+            if isinstance(chunk, GenerateResponse):
+                return self.normalize_fim_chunk(chunk)
+            return chunk
         except StopAsyncIteration:
             raise StopAsyncIteration
