diff --git a/src/codegate/muxing/adapter.py b/src/codegate/muxing/adapter.py
index f076df5e..e4ac3dc2 100644
--- a/src/codegate/muxing/adapter.py
+++ b/src/codegate/muxing/adapter.py
@@ -1,13 +1,15 @@
 import copy
 import json
 import uuid
-from typing import Union
+from abc import ABC, abstractmethod
+from typing import Callable, Dict, Union
+from urllib.parse import urljoin

 import structlog
 from fastapi.responses import JSONResponse, StreamingResponse
 from litellm import ModelResponse
 from litellm.types.utils import Delta, StreamingChoices
-from ollama import ChatResponse
+from ollama import ChatResponse, GenerateResponse

 from codegate.db import models as db_models
 from codegate.muxing import rulematcher
@@ -34,7 +36,7 @@ def _get_provider_formatted_url(self, model_route: rulematcher.ModelRoute) -> st
             db_models.ProviderType.openai,
             db_models.ProviderType.openrouter,
         ]:
-            return f"{model_route.endpoint.endpoint}/v1"
+            return urljoin(model_route.endpoint.endpoint, "/v1")
         return model_route.endpoint.endpoint

     def set_destination_info(self, model_route: rulematcher.ModelRoute, data: dict) -> dict:
@@ -45,15 +47,101 @@ def set_destination_info(self, model_route: rulematcher.ModelRoute, data: dict)
         return new_data


-class StreamChunkFormatter:
+class OutputFormatter(ABC):
+
+    @property
+    @abstractmethod
+    def provider_format_funcs(self) -> Dict[str, Callable]:
+        """
+        Return the provider specific format functions. All provider format functions should
+        return the chunk in OpenAI format.
+        """
+        pass
+
+    @abstractmethod
+    def format(
+        self, response: Union[StreamingResponse, JSONResponse], dest_prov: db_models.ProviderType
+    ) -> Union[StreamingResponse, JSONResponse]:
+        """Format the response to the client."""
+        pass
+
+
+class StreamChunkFormatter(OutputFormatter):
     """
     Format a single chunk from a stream to OpenAI format.
     We need to configure the client to expect the OpenAI format.
     In Continue this means setting "provider": "openai" in the config json file.
     """

-    def __init__(self):
-        self.provider_to_func = {
+    @property
+    @abstractmethod
+    def provider_format_funcs(self) -> Dict[str, Callable]:
+        """
+        Return the provider specific format functions. All provider format functions should
+        return the chunk in OpenAI format.
+        """
+        pass
+
+    def _format_openai(self, chunk: str) -> str:
+        """
+        The chunk is already in OpenAI format. To standardize, remove the "data:" prefix.
+
+        This function is used by both chat and FIM formatters.
+        """
+        cleaned_chunk = chunk.split("data:")[1].strip()
+        return cleaned_chunk
+
+    def _format_as_openai_chunk(self, formatted_chunk: str) -> str:
+        """Format the chunk as an OpenAI chunk. This is the format the clients expect."""
+        chunk_to_send = f"data:{formatted_chunk}\n\n"
+        return chunk_to_send
+
+    async def _format_streaming_response(
+        self, response: StreamingResponse, dest_prov: db_models.ProviderType
+    ):
+        """Format the streaming response to OpenAI format."""
+        format_func = self.provider_format_funcs.get(dest_prov)
+        openai_chunk = None
+        try:
+            async for chunk in response.body_iterator:
+                openai_chunk = format_func(chunk)
+                # Sometimes for Anthropic we couldn't get content from the chunk. Skip it.
+                if not openai_chunk:
+                    continue
+                yield self._format_as_openai_chunk(openai_chunk)
+        except Exception as e:
+            logger.error(f"Error sending chunk in muxing: {e}")
+            yield self._format_as_openai_chunk(str(e))
+        finally:
+            # Make sure the last chunk is always [DONE]
+            if openai_chunk and "[DONE]" not in openai_chunk:
+                yield self._format_as_openai_chunk("[DONE]")
+
+    def format(
+        self, response: StreamingResponse, dest_prov: db_models.ProviderType
+    ) -> StreamingResponse:
+        """Format the response to the client."""
+        return StreamingResponse(
+            self._format_streaming_response(response, dest_prov),
+            status_code=response.status_code,
+            headers=response.headers,
+            background=response.background,
+            media_type=response.media_type,
+        )
+
+
+class ChatStreamChunkFormatter(StreamChunkFormatter):
+    """
+    Format a single chunk from a stream to OpenAI format given that the request was a chat.
+    """
+
+    @property
+    def provider_format_funcs(self) -> Dict[str, Callable]:
+        """
+        Return the provider specific format functions. All provider format functions should
+        return the chunk in OpenAI format.
+        """
+        return {
             db_models.ProviderType.ollama: self._format_ollama,
             db_models.ProviderType.openai: self._format_openai,
             db_models.ProviderType.anthropic: self._format_antropic,
@@ -68,21 +156,11 @@ def _format_ollama(self, chunk: str) -> str:
         try:
             chunk_dict = json.loads(chunk)
             ollama_chunk = ChatResponse(**chunk_dict)
-            open_ai_chunk = OLlamaToModel.normalize_chunk(ollama_chunk)
+            open_ai_chunk = OLlamaToModel.normalize_chat_chunk(ollama_chunk)
             return open_ai_chunk.model_dump_json(exclude_none=True, exclude_unset=True)
         except Exception:
             return chunk

-    def _format_openai(self, chunk: str) -> str:
-        """The chunk is already in OpenAI format. To standarize remove the "data:" prefix."""
-        cleaned_chunk = chunk.split("data:")[1].strip()
-        try:
-            chunk_dict = json.loads(cleaned_chunk)
-            open_ai_chunk = ModelResponse(**chunk_dict)
-            return open_ai_chunk.model_dump_json(exclude_none=True, exclude_unset=True)
-        except Exception:
-            return cleaned_chunk
-
     def _format_antropic(self, chunk: str) -> str:
         """Format the Anthropic chunk to OpenAI format."""
         cleaned_chunk = chunk.split("data:")[1].strip()
@@ -119,46 +197,53 @@ def _format_antropic(self, chunk: str) -> str:
         except Exception:
             return cleaned_chunk.strip()

-    def format(self, chunk: str, dest_prov: db_models.ProviderType) -> ModelResponse:
-        """Format the chunk to OpenAI format."""
-        # Get the format function
-        format_func = self.provider_to_func.get(dest_prov)
-        if format_func is None:
-            raise MuxingAdapterError(f"Provider {dest_prov} not supported.")
-        return format_func(chunk)

+class FimStreamChunkFormatter(StreamChunkFormatter):

-class ResponseAdapter:
+    @property
+    def provider_format_funcs(self) -> Dict[str, Callable]:
+        """
+        Return the provider specific format functions. All provider format functions should
+        return the chunk in OpenAI format.
+        """
+        return {
+            db_models.ProviderType.ollama: self._format_ollama,
+            db_models.ProviderType.openai: self._format_openai,
+            # Our llamacpp provider emits OpenAI chunks
+            db_models.ProviderType.llamacpp: self._format_openai,
+            # OpenRouter is a dialect of OpenAI
+            db_models.ProviderType.openrouter: self._format_openai,
+        }
+
+    def _format_ollama(self, chunk: str) -> str:
+        """Format the Ollama chunk to OpenAI format."""
+        try:
+            chunk_dict = json.loads(chunk)
+            ollama_chunk = GenerateResponse(**chunk_dict)
+            open_ai_chunk = OLlamaToModel.normalize_fim_chunk(ollama_chunk)
+            return json.dumps(open_ai_chunk, separators=(",", ":"), indent=None)
+        except Exception:
+            return chunk

-    def __init__(self):
-        self.stream_formatter = StreamChunkFormatter()

-    def _format_as_openai_chunk(self, formatted_chunk: str) -> str:
-        """Format the chunk as OpenAI chunk. This is the format how the clients expect the data."""
-        return f"data:{formatted_chunk}\n\n"
+class ResponseAdapter:

-    async def _format_streaming_response(
-        self, response: StreamingResponse, dest_prov: db_models.ProviderType
-    ):
-        """Format the streaming response to OpenAI format."""
-        async for chunk in response.body_iterator:
-            openai_chunk = self.stream_formatter.format(chunk, dest_prov)
-            # Sometimes for Anthropic we couldn't get content from the chunk. Skip it.
-            if not openai_chunk:
-                continue
-            yield self._format_as_openai_chunk(openai_chunk)
+    def _get_formatter(
+        self, response: Union[StreamingResponse, JSONResponse], is_fim_request: bool
+    ) -> OutputFormatter:
+        """Get the formatter based on the request type."""
+        if isinstance(response, StreamingResponse):
+            if is_fim_request:
+                return FimStreamChunkFormatter()
+            return ChatStreamChunkFormatter()
+        raise MuxingAdapterError("Only streaming responses are supported.")

     def format_response_to_client(
-        self, response: Union[StreamingResponse, JSONResponse], dest_prov: db_models.ProviderType
+        self,
+        response: Union[StreamingResponse, JSONResponse],
+        dest_prov: db_models.ProviderType,
+        is_fim_request: bool,
     ) -> Union[StreamingResponse, JSONResponse]:
         """Format the response to the client."""
-        if isinstance(response, StreamingResponse):
-            return StreamingResponse(
-                self._format_streaming_response(response, dest_prov),
-                status_code=response.status_code,
-                headers=response.headers,
-                background=response.background,
-                media_type=response.media_type,
-            )
-        else:
-            raise MuxingAdapterError("Only streaming responses are supported.")
+        stream_formatter = self._get_formatter(response, is_fim_request)
+        return stream_formatter.format(response, dest_prov)
diff --git a/src/codegate/muxing/router.py b/src/codegate/muxing/router.py
index df3a9d39..4231e8e7 100644
--- a/src/codegate/muxing/router.py
+++ b/src/codegate/muxing/router.py
@@ -93,6 +93,7 @@ async def route_to_dest_provider(
                 model=model_route.model.name,
                 provider_type=model_route.endpoint.provider_type,
                 provider_name=model_route.endpoint.name,
+                is_fim_request=is_fim_request,
             )

             # 2. Map the request body to the destination provider format.
@@ -108,5 +109,5 @@

             # 4. Transmit the response back to the client in OpenAI format.
             return self._response_adapter.format_response_to_client(
-                response, model_route.endpoint.provider_type
+                response, model_route.endpoint.provider_type, is_fim_request=is_fim_request
             )
diff --git a/src/codegate/providers/litellmshim/litellmshim.py b/src/codegate/providers/litellmshim/litellmshim.py
index 37693f1d..eab6fc54 100644
--- a/src/codegate/providers/litellmshim/litellmshim.py
+++ b/src/codegate/providers/litellmshim/litellmshim.py
@@ -3,11 +3,7 @@
 import litellm
 import structlog
 from fastapi.responses import JSONResponse, StreamingResponse
-from litellm import (
-    ChatCompletionRequest,
-    ModelResponse,
-    acompletion,
-)
+from litellm import ChatCompletionRequest, ModelResponse, acompletion, atext_completion

 from codegate.clients.clients import ClientType
 from codegate.providers.base import BaseCompletionHandler, StreamGenerator
@@ -52,6 +48,11 @@ async def execute_completion(
         request["api_key"] = api_key
         request["base_url"] = base_url
         if is_fim_request:
+            # We need to force atext_completion if there is "prompt" in the request.
+            # The default function acompletion can only handle "messages" in the request.
+            if "prompt" in request:
+                logger.debug("Forcing atext_completion in FIM")
+                return await atext_completion(**request)
             return await self._fim_completion_func(**request)
         return await self._completion_func(**request)

diff --git a/src/codegate/providers/normalizer/completion.py b/src/codegate/providers/normalizer/completion.py
index c4cc6306..04227bbd 100644
--- a/src/codegate/providers/normalizer/completion.py
+++ b/src/codegate/providers/normalizer/completion.py
@@ -20,6 +20,12 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
             data["messages"] = [{"content": data.pop("prompt"), "role": "user"}]
         # We can add as many parameters as we like to data. ChatCompletionRequest is not strict.
         data["had_prompt_before"] = True
+
+        # LiteLLM says stop can contain at most 4 entries. Trim the list to enforce it.
+        stop_list = data.get("stop", [])
+        trimmed_stop_list = stop_list[:4]
+        data["stop"] = trimmed_stop_list
+
         try:
             normalized_data = ChatCompletionRequest(**data)
             if normalized_data.get("stream", False):
diff --git a/src/codegate/providers/ollama/adapter.py b/src/codegate/providers/ollama/adapter.py
index e64ec81b..46fc13d1 100644
--- a/src/codegate/providers/ollama/adapter.py
+++ b/src/codegate/providers/ollama/adapter.py
@@ -1,10 +1,9 @@
-import uuid
 from datetime import datetime, timezone
-from typing import Any, AsyncIterator, Dict, Union
+from typing import Any, AsyncIterator, Dict, Optional, Tuple, Union

 from litellm import ChatCompletionRequest, ModelResponse
 from litellm.types.utils import Delta, StreamingChoices
-from ollama import ChatResponse, Message
+from ollama import ChatResponse, GenerateResponse, Message

 from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer

@@ -47,21 +46,47 @@ def __init__(self, ollama_response: AsyncIterator[ChatResponse]):
         self.ollama_response = ollama_response
         self._aiter = ollama_response.__aiter__()

-    @staticmethod
-    def normalize_chunk(chunk: ChatResponse) -> ModelResponse:
+    @classmethod
+    def _transform_to_int_secs(cls, chunk_created_at: str) -> int:
+        """
+        Convert the datetime to a timestamp in seconds.
+        """
+        datetime_obj = datetime.fromisoformat(chunk_created_at)
+        return int(datetime_obj.timestamp())
+
+    @classmethod
+    def _get_finish_reason_assistant(cls, is_chunk_done: bool) -> Tuple[str, Optional[str]]:
+        """
+        Get the role and finish reason for the assistant based on the chunk done status.
+        """
         finish_reason = None
         role = "assistant"
-
-        # Convert the datetime object to a timestamp in seconds
-        datetime_obj = datetime.fromisoformat(chunk.created_at)
-        timestamp_seconds = int(datetime_obj.timestamp())
-
-        if chunk.done:
+        if is_chunk_done:
             finish_reason = "stop"
             role = None
+        return role, finish_reason
+
+    @classmethod
+    def _get_chat_id_from_timestamp(cls, timestamp_seconds: int) -> str:
+        """
+        Get a string representation of the timestamp in seconds to use as the chat id.
+
+        This needs to be done so that all chunks of a chat have the same id.
+        """
+        timestamp_str = str(timestamp_seconds)
+        return timestamp_str[:9]
+
+    @classmethod
+    def normalize_chat_chunk(cls, chunk: ChatResponse) -> ModelResponse:
+        """
+        Transform an Ollama chat chunk to an OpenAI one.
+        """
+        timestamp_seconds = cls._transform_to_int_secs(chunk.created_at)
+        role, finish_reason = cls._get_finish_reason_assistant(chunk.done)
+        chat_id = cls._get_chat_id_from_timestamp(timestamp_seconds)

         model_response = ModelResponse(
-            id=f"ollama-chat-{str(uuid.uuid4())}",
+            id=f"ollama-chat-{chat_id}",
             created=timestamp_seconds,
             model=chunk.model,
             object="chat.completion.chunk",
@@ -76,16 +101,37 @@
         )
         return model_response

+    @classmethod
+    def normalize_fim_chunk(cls, chunk: GenerateResponse) -> Dict:
+        """
+        Transform an Ollama generation chunk to an OpenAI one.
+        """
+        timestamp_seconds = cls._transform_to_int_secs(chunk.created_at)
+        _, finish_reason = cls._get_finish_reason_assistant(chunk.done)
+        chat_id = cls._get_chat_id_from_timestamp(timestamp_seconds)
+
+        model_response = {
+            "id": f"chatcmpl-{chat_id}",
+            "object": "text_completion",
+            "created": timestamp_seconds,
+            "model": chunk.model,
+            "choices": [{"index": 0, "text": chunk.response}],
+            "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
+        }
+        if finish_reason:
+            model_response["choices"][0]["finish_reason"] = finish_reason
+            del model_response["choices"][0]["text"]
+        return model_response
+
     def __aiter__(self):
         return self

     async def __anext__(self):
         try:
             chunk = await self._aiter.__anext__()
-            if not isinstance(chunk, ChatResponse):
-                return chunk
-
-            return self.normalize_chunk(chunk)
+            if isinstance(chunk, ChatResponse):
+                return self.normalize_chat_chunk(chunk)
+            return chunk
         except StopAsyncIteration:
             raise StopAsyncIteration
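Below is a minimal, hedged sketch (not part of the patch) of how the new FIM path could be exercised in isolation: it feeds a hand-built Ollama GenerateResponse chunk through OLlamaToModel.normalize_fim_chunk and prints the resulting OpenAI-style text_completion dict. It assumes the ollama package's GenerateResponse model accepts these fields as keyword arguments and that created_at carries an explicit UTC offset so datetime.fromisoformat can parse it; the model name is made up.

# sketch_fim_normalization.py -- illustrative only, not shipped in this diff
from ollama import GenerateResponse

from codegate.providers.ollama.adapter import OLlamaToModel

# A fabricated final FIM chunk, shaped like what ollama's /api/generate streams.
chunk = GenerateResponse(
    model="qwen2.5-coder",  # hypothetical model name
    created_at="2025-01-01T00:00:00.000000+00:00",
    response="return a + b",
    done=True,
)

openai_chunk = OLlamaToModel.normalize_fim_chunk(chunk)

# done=True maps to finish_reason="stop", and the adapter drops the "text"
# field from that final choice, mirroring OpenAI's terminating chunk.
print(openai_chunk)
# Expected shape (id derived from the first 9 digits of the timestamp):
# {"id": "chatcmpl-173568960", "object": "text_completion", "created": 1735689600,
#  "model": "qwen2.5-coder", "choices": [{"index": 0, "finish_reason": "stop"}],
#  "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0}}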