From 21e8edbdaabaf9a69042db1181a30e144c4f85a2 Mon Sep 17 00:00:00 2001
From: Alejandro Ponce
Date: Thu, 30 Jan 2025 09:58:48 +0200
Subject: [PATCH 1/2] Add LM Studio provider

We were using the OpenAI provider to interface with LM Studio, since the two
are very similar. For muxing we need to clearly distinguish which provider a
request should be routed to, and a dedicated LM Studio provider makes that
disambiguation easier.
---
 src/codegate/providers/lm_studio/provider.py | 58 ++++++++++++++++++++
 src/codegate/providers/openai/provider.py    | 52 +++++++-----------
 src/codegate/server.py                       |  7 +++
 tests/test_server.py                         |  4 +-
 4 files changed, 87 insertions(+), 34 deletions(-)
 create mode 100644 src/codegate/providers/lm_studio/provider.py

diff --git a/src/codegate/providers/lm_studio/provider.py b/src/codegate/providers/lm_studio/provider.py
new file mode 100644
index 00000000..be083ad0
--- /dev/null
+++ b/src/codegate/providers/lm_studio/provider.py
@@ -0,0 +1,58 @@
+import json
+
+from fastapi import Header, HTTPException, Request
+from fastapi.responses import JSONResponse
+
+from codegate.config import Config
+from codegate.pipeline.factory import PipelineFactory
+from codegate.providers.openai.provider import OpenAIProvider
+
+
+class LmStudioProvider(OpenAIProvider):
+    def __init__(
+        self,
+        pipeline_factory: PipelineFactory,
+    ):
+        config = Config.get_config()
+        if config is not None:
+            provided_urls = config.provider_urls
+            self.lm_studio_url = provided_urls.get("lm_studio", "http://localhost:11434/")
+
+        super().__init__(pipeline_factory)
+
+    @property
+    def provider_route_name(self) -> str:
+        return "lm_studio"
+
+    def _setup_routes(self):
+        """
+        Sets up the /chat/completions route for the provider as expected by the
+        LM Studio API. Extracts the API key from the "Authorization" header and
+        passes it to the completion handler.
+ """ + + @self.router.get(f"/{self.provider_route_name}/models") + @self.router.get(f"/{self.provider_route_name}/v1/models") + async def get_models(): + # dummy method for lm studio + return JSONResponse(status_code=200, content=[]) + + @self.router.post(f"/{self.provider_route_name}/chat/completions") + @self.router.post(f"/{self.provider_route_name}/completions") + @self.router.post(f"/{self.provider_route_name}/v1/chat/completions") + async def create_completion( + request: Request, + authorization: str = Header(..., description="Bearer token"), + ): + if not authorization.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Invalid authorization header") + + api_key = authorization.split(" ")[1] + body = await request.body() + data = json.loads(body) + + # if model starts with lm_studio, propagate it + if data.get("model", "").startswith("lm_studio"): + data["base_url"] = self.lm_studio_url + "/v1/" + + return await self.process_request(data, api_key, request) diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py index be9ddabe..95518cf3 100644 --- a/src/codegate/providers/openai/provider.py +++ b/src/codegate/providers/openai/provider.py @@ -4,9 +4,7 @@ import httpx import structlog from fastapi import Header, HTTPException, Request -from fastapi.responses import JSONResponse -from codegate.config import Config from codegate.pipeline.factory import PipelineFactory from codegate.providers.base import BaseProvider, ModelFetchError from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator @@ -19,11 +17,6 @@ def __init__( pipeline_factory: PipelineFactory, ): completion_handler = LiteLLmShim(stream_generator=sse_stream_generator) - config = Config.get_config() - if config is not None: - provided_urls = config.provider_urls - self.lm_studio_url = provided_urls.get("lm_studio", "http://localhost:11434/") - super().__init__( OpenAIInputNormalizer(), OpenAIOutputNormalizer(), @@ -39,8 +32,6 @@ def models(self, endpoint: str = None, api_key: str = None) -> List[str]: headers = {} if api_key: headers["Authorization"] = f"Bearer {api_key}" - if not endpoint: - endpoint = "https://api.openai.com" resp = httpx.get(f"{endpoint}/v1/models", headers=headers) @@ -51,6 +42,25 @@ def models(self, endpoint: str = None, api_key: str = None) -> List[str]: return [model["id"] for model in jsonresp.get("data", [])] + async def process_request(self, data: dict, api_key: str, request: Request): + """ + Process the request and return the completion stream + """ + is_fim_request = self._is_fim_request(request, data) + try: + stream = await self.complete(data, api_key, is_fim_request=is_fim_request) + except Exception as e: + #  check if we have an status code there + if hasattr(e, "status_code"): + logger = structlog.get_logger("codegate") + logger.error("Error in OpenAIProvider completion", error=str(e)) + + raise HTTPException(status_code=e.status_code, detail=str(e)) # type: ignore + else: + # just continue raising the exception + raise e + return self._completion_handler.create_response(stream) + def _setup_routes(self): """ Sets up the /chat/completions route for the provider as expected by the @@ -58,12 +68,6 @@ def _setup_routes(self): passes it to the completion handler. 
""" - @self.router.get(f"/{self.provider_route_name}/models") - @self.router.get(f"/{self.provider_route_name}/v1/models") - async def get_models(): - # dummy method for lm studio - return JSONResponse(status_code=200, content=[]) - @self.router.post(f"/{self.provider_route_name}/chat/completions") @self.router.post(f"/{self.provider_route_name}/completions") @self.router.post(f"/{self.provider_route_name}/v1/chat/completions") @@ -78,20 +82,4 @@ async def create_completion( body = await request.body() data = json.loads(body) - # if model starts with lm_studio, propagate it - if data.get("model", "").startswith("lm_studio"): - data["base_url"] = self.lm_studio_url + "/v1/" - is_fim_request = self._is_fim_request(request, data) - try: - stream = await self.complete(data, api_key, is_fim_request=is_fim_request) - except Exception as e: - #  check if we have an status code there - if hasattr(e, "status_code"): - logger = structlog.get_logger("codegate") - logger.error("Error in OpenAIProvider completion", error=str(e)) - - raise HTTPException(status_code=e.status_code, detail=str(e)) # type: ignore - else: - # just continue raising the exception - raise e - return self._completion_handler.create_response(stream) + return await self.process_request(data, api_key, request) diff --git a/src/codegate/server.py b/src/codegate/server.py index 216ba95e..e0216e8f 100644 --- a/src/codegate/server.py +++ b/src/codegate/server.py @@ -13,6 +13,7 @@ from codegate.pipeline.factory import PipelineFactory from codegate.providers.anthropic.provider import AnthropicProvider from codegate.providers.llamacpp.provider import LlamaCppProvider +from codegate.providers.lm_studio.provider import LmStudioProvider from codegate.providers.ollama.provider import OllamaProvider from codegate.providers.openai.provider import OpenAIProvider from codegate.providers.registry import ProviderRegistry, get_provider_registry @@ -96,6 +97,12 @@ async def log_user_agent(request: Request, call_next): pipeline_factory, ), ) + registry.add_provider( + "lm_studio", + LmStudioProvider( + pipeline_factory, + ), + ) # Create and add system routes system_router = APIRouter(tags=["System"]) diff --git a/tests/test_server.py b/tests/test_server.py index 80bb7cb0..46e2f867 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -108,8 +108,8 @@ def test_provider_registration(mock_registry, mock_secrets_mgr, mock_pipeline_fa # Verify all providers were registered registry_instance = mock_registry.return_value assert ( - registry_instance.add_provider.call_count == 5 - ) # openai, anthropic, llamacpp, vllm, ollama + registry_instance.add_provider.call_count == 6 + ) # openai, anthropic, llamacpp, vllm, ollama, lm_studio # Verify specific providers were registered provider_names = [call.args[0] for call in registry_instance.add_provider.call_args_list] From ce2d8624d46f6230aec52f0b49a904856babc45f Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Thu, 30 Jan 2025 10:38:50 +0200 Subject: [PATCH 2/2] Delete conditional to add lm studio URL --- src/codegate/providers/lm_studio/provider.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/codegate/providers/lm_studio/provider.py b/src/codegate/providers/lm_studio/provider.py index be083ad0..9fc88a3b 100644 --- a/src/codegate/providers/lm_studio/provider.py +++ b/src/codegate/providers/lm_studio/provider.py @@ -51,8 +51,6 @@ async def create_completion( body = await request.body() data = json.loads(body) - # if model starts with lm_studio, propagate it - if 
data.get("model", "").startswith("lm_studio"): - data["base_url"] = self.lm_studio_url + "/v1/" + data["base_url"] = self.lm_studio_url + "/v1/" return await self.process_request(data, api_key, request)