Skip to content

Commit 0134cd7

Browse files
committed
Bootstrap provider models on addition
When adding a provider via the API, this bootstrap the available models for a provider. Signed-off-by: Juan Antonio Osorio <[email protected]>
1 parent 5526ead commit 0134cd7

File tree

13 files changed

+263
-39
lines changed

13 files changed

+263
-39
lines changed

src/codegate/api/v1.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -109,20 +109,26 @@ async def get_provider_endpoint(
109109
status_code=201,
110110
)
111111
async def add_provider_endpoint(
112-
request: v1_models.ProviderEndpoint,
112+
request: v1_models.AddProviderEndpointRequest,
113113
) -> v1_models.ProviderEndpoint:
114114
"""Add a provider endpoint."""
115115
try:
116116
provend = await pcrud.add_endpoint(request)
117117
except AlreadyExistsError:
118118
raise HTTPException(status_code=409, detail="Provider endpoint already exists")
119+
except ValueError as e:
120+
raise HTTPException(
121+
status_code=400,
122+
detail=str(e),
123+
)
119124
except ValidationError as e:
120125
# TODO: This should be more specific
121126
raise HTTPException(
122127
status_code=400,
123128
detail=str(e),
124129
)
125130
except Exception:
131+
logger.exception("Error while adding provider endpoint")
126132
raise HTTPException(status_code=500, detail="Internal server error")
127133

128134
return provend
@@ -154,20 +160,24 @@ async def configure_auth_material(
154160
)
155161
async def update_provider_endpoint(
156162
provider_id: UUID,
157-
request: v1_models.ProviderEndpoint,
163+
request: v1_models.AddProviderEndpointRequest,
158164
) -> v1_models.ProviderEndpoint:
159165
"""Update a provider endpoint by ID."""
160166
try:
161-
request.id = provider_id
167+
request.id = str(provider_id)
162168
provend = await pcrud.update_endpoint(request)
169+
except provendcrud.ProviderNotFoundError:
170+
raise HTTPException(status_code=404, detail="Provider endpoint not found")
171+
except ValueError as e:
172+
raise HTTPException(status_code=400, detail=str(e))
163173
except ValidationError as e:
164174
# TODO: This should be more specific
165175
raise HTTPException(
166176
status_code=400,
167177
detail=str(e),
168178
)
169-
except Exception:
170-
raise HTTPException(status_code=500, detail="Internal server error")
179+
except Exception as e:
180+
raise HTTPException(status_code=500, detail=str(e))
171181

172182
return provend
173183

src/codegate/api/v1_models.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ class ProviderEndpoint(pydantic.BaseModel):
222222
name: str
223223
description: str = ""
224224
provider_type: ProviderType
225-
endpoint: str
225+
endpoint: str = "" # Some providers have defaults we can leverage
226226
auth_type: Optional[ProviderAuthType] = ProviderAuthType.none
227227

228228
@staticmethod
@@ -250,6 +250,14 @@ def get_from_registry(self, registry: ProviderRegistry) -> Optional[BaseProvider
250250
return registry.get_provider(self.provider_type)
251251

252252

253+
class AddProviderEndpointRequest(ProviderEndpoint):
254+
"""
255+
Represents a request to add a provider endpoint.
256+
"""
257+
258+
api_key: Optional[str] = None
259+
260+
253261
class ConfigureAuthMaterial(pydantic.BaseModel):
254262
"""
255263
Represents a request to configure auth material for a provider.

src/codegate/db/connection.py

+17
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,23 @@ async def add_provider_model(self, model: ProviderModel) -> ProviderModel:
459459
added_model = await self._execute_update_pydantic_model(model, sql, should_raise=True)
460460
return added_model
461461

462+
async def delete_provider_models(self, provider_id: str) -> Optional[ProviderModel]:
463+
sql = text(
464+
"""
465+
DELETE FROM provider_models
466+
WHERE provider_endpoint_id = :provider_endpoint_id
467+
RETURNING *
468+
"""
469+
)
470+
await self._execute_update_pydantic_model(
471+
ProviderModel(
472+
provider_endpoint_id=provider_id,
473+
name="Fake name to respect the signature of the function",
474+
),
475+
sql,
476+
should_raise=True,
477+
)
478+
462479

463480
class DbReader(DbCodeGate):
464481

src/codegate/providers/anthropic/provider.py

+18-11
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from codegate.pipeline.factory import PipelineFactory
99
from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer
1010
from codegate.providers.anthropic.completion_handler import AnthropicCompletion
11-
from codegate.providers.base import BaseProvider
11+
from codegate.providers.base import BaseProvider, ModelFetchError
1212
from codegate.providers.litellmshim import anthropic_stream_generator
1313

1414

@@ -29,16 +29,23 @@ def __init__(
2929
def provider_route_name(self) -> str:
3030
return "anthropic"
3131

32-
def models(self) -> List[str]:
33-
# TODO: This won't work since we need an API Key being set.
34-
resp = httpx.get("https://api.anthropic.com/models")
35-
# If Anthropic returned 404, it means it's not accepting our
36-
# requests. We should throw an error.
37-
if resp.status_code == 404:
38-
raise HTTPException(
39-
status_code=404,
40-
detail="The Anthropic API is not accepting requests. Please check your API key.",
41-
)
32+
def models(self, endpoint: str = None, api_key: str = None) -> List[str]:
33+
headers = {
34+
"Content-Type": "application/json",
35+
"anthropic-version": "2023-06-01",
36+
}
37+
if api_key:
38+
headers["x-api-key"] = api_key
39+
if not endpoint:
40+
endpoint = "https://api.anthropic.com"
41+
42+
resp = httpx.get(
43+
f"{endpoint}/v1/models",
44+
headers=headers,
45+
)
46+
47+
if resp.status_code != 200:
48+
raise ModelFetchError(f"Failed to fetch models from Anthropic API: {resp.text}")
4249

4350
respjson = resp.json()
4451

src/codegate/providers/base.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
StreamGenerator = Callable[[AsyncIterator[Any]], AsyncIterator[str]]
2525

2626

27+
class ModelFetchError(Exception):
28+
pass
29+
30+
2731
class BaseProvider(ABC):
2832
"""
2933
The provider class is responsible for defining the API routes and
@@ -55,7 +59,7 @@ def _setup_routes(self) -> None:
5559
pass
5660

5761
@abstractmethod
58-
def models(self) -> List[str]:
62+
def models(self, endpoint, str=None, api_key: str = None) -> List[str]:
5963
pass
6064

6165
@property

src/codegate/providers/crud/crud.py

+116-5
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from codegate.db import models as dbmodels
1212
from codegate.db.connection import DbReader, DbRecorder
1313
from codegate.providers.base import BaseProvider
14-
from codegate.providers.registry import ProviderRegistry
14+
from codegate.providers.registry import ProviderRegistry, get_provider_registry
1515

1616
logger = structlog.get_logger("codegate")
1717

@@ -62,23 +62,106 @@ async def get_endpoint_by_name(self, name: str) -> Optional[apimodelsv1.Provider
6262
return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)
6363

6464
async def add_endpoint(
65-
self, endpoint: apimodelsv1.ProviderEndpoint
65+
self, endpoint: apimodelsv1.AddProviderEndpointRequest
6666
) -> apimodelsv1.ProviderEndpoint:
6767
"""Add an endpoint."""
68+
69+
if not endpoint.endpoint:
70+
endpoint.endpoint = provider_default_endpoints(endpoint.provider_type)
71+
72+
# If we STILL don't have an endpoint, we can't continue
73+
if not endpoint.endpoint:
74+
raise ValueError("No endpoint provided and no default found for provider type")
75+
6876
dbend = endpoint.to_db_model()
77+
provider_registry = get_provider_registry()
6978

7079
# We override the ID here, as we want to generate it.
7180
dbend.id = str(uuid4())
7281

73-
dbendpoint = await self._db_writer.add_provider_endpoint()
82+
prov = endpoint.get_from_registry(provider_registry)
83+
if prov is None:
84+
raise ValueError("Unknown provider type: {}".format(endpoint.provider_type))
85+
86+
models = []
87+
if endpoint.auth_type == apimodelsv1.ProviderAuthType.api_key and not endpoint.api_key:
88+
raise ValueError("API key must be provided for API auth type")
89+
if endpoint.auth_type != apimodelsv1.ProviderAuthType.passthrough:
90+
try:
91+
models = prov.models(endpoint=endpoint.endpoint, api_key=endpoint.api_key)
92+
except Exception as err:
93+
raise ValueError("Unable to get models from provider: {}".format(str(err)))
94+
95+
dbendpoint = await self._db_writer.add_provider_endpoint(dbend)
96+
97+
await self._db_writer.push_provider_auth_material(
98+
dbmodels.ProviderAuthMaterial(
99+
provider_endpoint_id=dbendpoint.id,
100+
auth_type=endpoint.auth_type,
101+
auth_blob=endpoint.api_key if endpoint.api_key else "",
102+
)
103+
)
104+
105+
for model in models:
106+
await self._db_writer.add_provider_model(
107+
dbmodels.ProviderModel(
108+
provider_endpoint_id=dbendpoint.id,
109+
name=model,
110+
)
111+
)
74112
return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)
75113

76114
async def update_endpoint(
77-
self, endpoint: apimodelsv1.ProviderEndpoint
115+
self, endpoint: apimodelsv1.AddProviderEndpointRequest
78116
) -> apimodelsv1.ProviderEndpoint:
79117
"""Update an endpoint."""
80118

119+
if not endpoint.endpoint:
120+
endpoint.endpoint = provider_default_endpoints(endpoint.provider_type)
121+
122+
# If we STILL don't have an endpoint, we can't continue
123+
if not endpoint.endpoint:
124+
raise ValueError("No endpoint provided and no default found for provider type")
125+
126+
provider_registry = get_provider_registry()
127+
prov = endpoint.get_from_registry(provider_registry)
128+
if prov is None:
129+
raise ValueError("Unknown provider type: {}".format(endpoint.provider_type))
130+
131+
founddbe = await self._db_reader.get_provider_endpoint_by_id(str(endpoint.id))
132+
if founddbe is None:
133+
raise ProviderNotFoundError("Provider not found")
134+
135+
models = []
136+
if endpoint.auth_type == apimodelsv1.ProviderAuthType.api_key and not endpoint.api_key:
137+
raise ValueError("API key must be provided for API auth type")
138+
if endpoint.auth_type != apimodelsv1.ProviderAuthType.passthrough:
139+
try:
140+
models = prov.models(endpoint=endpoint.endpoint, api_key=endpoint.api_key)
141+
except Exception as err:
142+
raise ValueError("Unable to get models from provider: {}".format(str(err)))
143+
144+
# Reset all provider models.
145+
await self._db_writer.delete_provider_models(str(endpoint.id))
146+
147+
for model in models:
148+
await self._db_writer.add_provider_model(
149+
dbmodels.ProviderModel(
150+
provider_endpoint_id=founddbe.id,
151+
name=model,
152+
)
153+
)
154+
81155
dbendpoint = await self._db_writer.update_provider_endpoint(endpoint.to_db_model())
156+
157+
await self._db_writer.push_provider_auth_material(
158+
dbmodels.ProviderAuthMaterial(
159+
provider_endpoint_id=dbendpoint.id,
160+
auth_type=endpoint.auth_type,
161+
auth_blob=endpoint.api_key if endpoint.api_key else "",
162+
)
163+
)
164+
82165
return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)
83166

84167
async def configure_auth_material(
@@ -175,6 +258,13 @@ async def initialize_provider_endpoints(preg: ProviderRegistry):
175258
continue
176259

177260
pimpl = provend.get_from_registry(preg)
261+
if pimpl is None:
262+
logger.warning(
263+
"Provider not found in registry",
264+
provider=provend.name,
265+
endpoint=provend.endpoint,
266+
)
267+
continue
178268
await try_initialize_provider_endpoints(provend, pimpl, db_writer)
179269

180270

@@ -240,7 +330,7 @@ def __provider_endpoint_from_cfg(
240330
description=("Endpoint for the {} provided via the CodeGate configuration.").format(
241331
provider_name
242332
),
243-
provider_type=provider_name,
333+
provider_type=provider_overrides(provider_name),
244334
auth_type=apimodelsv1.ProviderAuthType.passthrough,
245335
)
246336
except ValidationError as err:
@@ -251,3 +341,24 @@ def __provider_endpoint_from_cfg(
251341
err=str(err),
252342
)
253343
return None
344+
345+
346+
def provider_default_endpoints(provider_type: str) -> str:
347+
defaults = {
348+
"openai": "https://api.openai.com",
349+
"anthropic": "https://api.anthropic.com",
350+
}
351+
352+
# If we have a default, we return it
353+
# Otherwise, we return an empty string
354+
return defaults.get(provider_type, "")
355+
356+
357+
def provider_overrides(provider_type: str) -> str:
358+
overrides = {
359+
"lm_studio": "openai",
360+
}
361+
362+
# If we have an override, we return it
363+
# Otherwise, we return the type
364+
return overrides.get(provider_type, provider_type)

src/codegate/providers/llamacpp/provider.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import json
2+
from typing import List
23

34
import httpx
45
import structlog
56
from fastapi import HTTPException, Request
67

78
from codegate.pipeline.factory import PipelineFactory
8-
from codegate.providers.base import BaseProvider
9+
from codegate.providers.base import BaseProvider, ModelFetchError
910
from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler
1011
from codegate.providers.llamacpp.normalizer import LLamaCppInputNormalizer, LLamaCppOutputNormalizer
1112

@@ -27,9 +28,22 @@ def __init__(
2728
def provider_route_name(self) -> str:
2829
return "llamacpp"
2930

30-
def models(self):
31+
def models(self, endpoint: str = None, api_key: str = None) -> List[str]:
32+
headers = {}
33+
if api_key:
34+
headers["Authorization"] = f"Bearer {api_key}"
35+
if not endpoint:
36+
endpoint = self.base_url
37+
3138
# HACK: This is using OpenAI's /v1/models endpoint to get the list of models
32-
resp = httpx.get(f"{self.base_url}/v1/models")
39+
resp = httpx.get(
40+
f"{endpoint}/v1/models",
41+
headers=headers,
42+
)
43+
44+
if resp.status_code != 200:
45+
raise ModelFetchError(f"Failed to fetch models from Llama API: {resp.text}")
46+
3347
jsonresp = resp.json()
3448

3549
return [model["id"] for model in jsonresp.get("data", [])]

0 commit comments

Comments
 (0)