This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 4e032d9

Support llamacpp provider in muxing (#889)
Closes: #883

There were a couple of nits that prevented support for llamacpp:
1. The `models` method was not implemented in the provider.
2. There was no way of specifying the model outside `CompletionHandler`.

This PR takes care of both.
1 parent 06a5bc8 commit 4e032d9

File tree

5 files changed: +22 -4 lines changed

src/codegate/config.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@
     "vllm": "http://localhost:8000", # Base URL without /v1 path
     "ollama": "http://localhost:11434", # Default Ollama server URL
     "lm_studio": "http://localhost:1234",
+    "llamacpp": "./codegate_volume/models", # Default LlamaCpp model path
 }


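Note that, unlike the other defaults, the llamacpp entry is a local directory rather than an HTTP endpoint: the provider resolves `<base>/<model>.gguf` on disk. A minimal sketch of that resolution (the helper and model names are made up, not the project's code):

```python
# Illustrative only: how a directory-style "endpoint" combines with a model
# name to locate a GGUF file (helper and model names are hypothetical).
from pathlib import Path

def resolve_gguf(base: str, model: str) -> Path:
    """Join the configured model directory with '<model>.gguf'."""
    return Path(base) / f"{model}.gguf"

print(resolve_gguf("./codegate_volume/models", "qwen2.5-coder"))
# ./codegate_volume/models/qwen2.5-coder.gguf
```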
src/codegate/muxing/adapter.py

Lines changed: 2 additions & 0 deletions
@@ -104,6 +104,8 @@ def __init__(self):
             db_models.ProviderType.ollama: self._format_ollama,
             db_models.ProviderType.openai: self._format_openai,
             db_models.ProviderType.anthropic: self._format_antropic,
+            # Our llamacpp provider emits OpenAI chunks
+            db_models.ProviderType.llamacpp: self._format_openai,
         }

     def _format_ollama(self, chunk: str) -> str:
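Because the llamacpp provider already yields OpenAI-shaped streaming chunks, the mux adapter can simply reuse `_format_openai` for it. A minimal sketch of this dispatch pattern (the names here are illustrative, not the project's classes):

```python
# Sketch of per-provider chunk formatting with llamacpp mapped to the
# OpenAI formatter; the formatter bodies are placeholders.
from typing import Callable, Dict

def format_openai(chunk: str) -> str:
    return chunk  # OpenAI-style SSE chunks pass through unchanged

def format_ollama(chunk: str) -> str:
    return chunk  # placeholder: real code would translate Ollama's format

FORMATTERS: Dict[str, Callable[[str], str]] = {
    "openai": format_openai,
    "ollama": format_ollama,
    "llamacpp": format_openai,  # llamacpp emits OpenAI chunks already
}

def format_chunk(provider: str, chunk: str) -> str:
    return FORMATTERS[provider](chunk)
```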

src/codegate/providers/crud/crud.py

Lines changed: 1 addition & 0 deletions
@@ -365,6 +365,7 @@ def __provider_endpoint_from_cfg(


 def provider_default_endpoints(provider_type: str) -> str:
+    # TODO: These providers' default endpoints should come from config.py
     defaults = {
         "openai": "https://api.openai.com",
         "anthropic": "https://api.anthropic.com",

src/codegate/providers/llamacpp/completion_handler.py

Lines changed: 4 additions & 1 deletion
@@ -57,13 +57,16 @@ async def execute_completion(
         """
         Execute the completion request with inference engine API
         """
-        model_path = f"{Config.get_config().model_base_path}/{request['model']}.gguf"
+        model_path = f"{request['base_url']}/{request['model']}.gguf"

         # Create a copy of the request dict and remove stream_options
         # Reason - Request error as JSON:
         # {'error': "Llama.create_completion() got an unexpected keyword argument 'stream_options'"}
         request_dict = dict(request)
         request_dict.pop("stream_options", None)
+        # Remove base_url from the request dict. We use this field as a standard across
+        # all providers to specify the base URL of the model.
+        request_dict.pop("base_url", None)

         if is_fim_request:
             response = await self.inference_engine.complete(
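The handler now resolves the model path from the request's `base_url` field and strips fields the local engine would reject. A standalone sketch of that sanitation step (the function name and sample values are made up):

```python
# Illustrative sketch: build the on-disk model path and drop fields that the
# local llama.cpp completion call does not accept.
from typing import Any, Dict, Tuple

def prepare_llamacpp_request(request: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
    # base_url is the local model directory injected by the provider
    model_path = f"{request['base_url']}/{request['model']}.gguf"

    request_dict = dict(request)
    request_dict.pop("stream_options", None)  # rejected by create_completion()
    request_dict.pop("base_url", None)        # internal routing field, not an inference arg
    return model_path, request_dict

model_path, clean = prepare_llamacpp_request(
    {"model": "qwen2.5-coder", "prompt": "hi", "base_url": "./codegate_volume/models"}
)
print(model_path)  # ./codegate_volume/models/qwen2.5-coder.gguf
```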

src/codegate/providers/llamacpp/provider.py

Lines changed: 14 additions & 3 deletions
@@ -1,11 +1,13 @@
 import json
+from pathlib import Path
 from typing import List

 import structlog
 from fastapi import HTTPException, Request

+from codegate.config import Config
 from codegate.pipeline.factory import PipelineFactory
-from codegate.providers.base import BaseProvider
+from codegate.providers.base import BaseProvider, ModelFetchError
 from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler
 from codegate.providers.llamacpp.normalizer import LLamaCppInputNormalizer, LLamaCppOutputNormalizer

@@ -30,8 +32,16 @@ def provider_route_name(self) -> str:
         return "llamacpp"

     def models(self, endpoint: str = None, api_key: str = None) -> List[str]:
-        # TODO: Implement file fetching
-        return []
+        models_path = Path(Config.get_config().model_base_path)
+        if not models_path.is_dir():
+            raise ModelFetchError(f"llamacpp model path does not exist: {models_path}")
+
+        # return all models except the all-minilm-L6-v2-q5_k_m model which we use for embeddings
+        return [
+            model.stem
+            for model in models_path.glob("*.gguf")
+            if model.is_file() and model.stem != "all-minilm-L6-v2-q5_k_m"
+        ]

     async def process_request(self, data: dict, api_key: str, request_url_path: str):
         is_fim_request = self._is_fim_request(request_url_path, data)

@@ -66,5 +76,6 @@ async def create_completion(
     ):
         body = await request.body()
         data = json.loads(body)
+        data["base_url"] = Config.get_config().model_base_path

         return await self.process_request(data, None, request.url.path)
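Together, `create_completion` injecting `base_url` and `models` enumerating the GGUF files are what make the provider muxable: the mux can both list llamacpp models and route requests to them. A quick standalone restatement of the listing rule (directory and file names are hypothetical):

```python
# Standalone version of the listing rule: every *.gguf stem in the model
# directory, except the embedding model used internally.
from pathlib import Path
from typing import List

def list_gguf_models(models_dir: str) -> List[str]:
    models_path = Path(models_dir)
    if not models_path.is_dir():
        raise FileNotFoundError(f"llamacpp model path does not exist: {models_path}")
    return [
        model.stem
        for model in models_path.glob("*.gguf")
        if model.is_file() and model.stem != "all-minilm-L6-v2-q5_k_m"
    ]

# With ./codegate_volume/models containing qwen2.5-coder.gguf and
# all-minilm-L6-v2-q5_k_m.gguf, this would return ["qwen2.5-coder"].
```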
