Skip to content

Commit 8915f94

Browse files
JAORMX authored and lukehinds committed
Bootstrap provider models on addition & implement mux endpoints (#826)
* Bootstrap provider models on addition

  When adding a provider via the API, this bootstraps the available models for a provider.

  Signed-off-by: Juan Antonio Osorio <[email protected]>

* Implement muxes CRUD

  Signed-off-by: Juan Antonio Osorio <[email protected]>

* Fix linter issues

  Signed-off-by: Juan Antonio Osorio <[email protected]>

* Remove `models` implementation from llama-cpp

  Signed-off-by: Juan Antonio Osorio <[email protected]>

---------

Signed-off-by: Juan Antonio Osorio <[email protected]>
1 parent 17e949d commit 8915f94

File tree

15 files changed

+427
-67
lines changed

15 files changed

+427
-67
lines changed

src/codegate/api/v1.py

+34 -23
Original file line number | Diff line number | Diff line change
@@ -109,20 +109,26 @@ async def get_provider_endpoint(
109109
status_code=201,
110110
)
111111
async def add_provider_endpoint(
112-
request: v1_models.ProviderEndpoint,
112+
request: v1_models.AddProviderEndpointRequest,
113113
) -> v1_models.ProviderEndpoint:
114114
"""Add a provider endpoint."""
115115
try:
116116
provend = await pcrud.add_endpoint(request)
117117
except AlreadyExistsError:
118118
raise HTTPException(status_code=409, detail="Provider endpoint already exists")
119+
except ValueError as e:
120+
raise HTTPException(
121+
status_code=400,
122+
detail=str(e),
123+
)
119124
except ValidationError as e:
120125
# TODO: This should be more specific
121126
raise HTTPException(
122127
status_code=400,
123128
detail=str(e),
124129
)
125130
except Exception:
131+
logger.exception("Error while adding provider endpoint")
126132
raise HTTPException(status_code=500, detail="Internal server error")
127133

128134
return provend
@@ -154,20 +160,24 @@ async def configure_auth_material(
154160
)
155161
async def update_provider_endpoint(
156162
provider_id: UUID,
157-
request: v1_models.ProviderEndpoint,
163+
request: v1_models.AddProviderEndpointRequest,
158164
) -> v1_models.ProviderEndpoint:
159165
"""Update a provider endpoint by ID."""
160166
try:
161-
request.id = provider_id
167+
request.id = str(provider_id)
162168
provend = await pcrud.update_endpoint(request)
169+
except provendcrud.ProviderNotFoundError:
170+
raise HTTPException(status_code=404, detail="Provider endpoint not found")
171+
except ValueError as e:
172+
raise HTTPException(status_code=400, detail=str(e))
163173
except ValidationError as e:
164174
# TODO: This should be more specific
165175
raise HTTPException(
166176
status_code=400,
167177
detail=str(e),
168178
)
169-
except Exception:
170-
raise HTTPException(status_code=500, detail="Internal server error")
179+
except Exception as e:
180+
raise HTTPException(status_code=500, detail=str(e))
171181

172182
return provend
173183

@@ -471,22 +481,15 @@ async def get_workspace_muxes(
471481
472482
The list is ordered in order of priority. That is, the first rule in the list
473483
has the highest priority."""
474-
# TODO: This is a dummy implementation. In the future, we should have a proper
475-
# implementation that fetches the mux rules from the database.
476-
return [
477-
v1_models.MuxRule(
478-
# Hardcode some UUID just for mocking purposes
479-
provider_id="00000000-0000-0000-0000-000000000001",
480-
model="gpt-3.5-turbo",
481-
matcher_type=v1_models.MuxMatcherType.file_regex,
482-
matcher=".*\\.txt",
483-
),
484-
v1_models.MuxRule(
485-
provider_id="00000000-0000-0000-0000-000000000002",
486-
model="davinci",
487-
matcher_type=v1_models.MuxMatcherType.catch_all,
488-
),
489-
]
484+
try:
485+
muxes = await wscrud.get_muxes(workspace_name)
486+
except crud.WorkspaceDoesNotExistError:
487+
raise HTTPException(status_code=404, detail="Workspace does not exist")
488+
except Exception:
489+
logger.exception("Error while getting workspace")
490+
raise HTTPException(status_code=500, detail="Internal server error")
491+
492+
return muxes
490493

491494

492495
@v1.put(
@@ -500,8 +503,16 @@ async def set_workspace_muxes(
500503
request: List[v1_models.MuxRule],
501504
):
502505
"""Set the mux rules of a workspace."""
503-
# TODO: This is a dummy implementation. In the future, we should have a proper
504-
# implementation that sets the mux rules in the database.
506+
try:
507+
await wscrud.set_muxes(workspace_name, request)
508+
except crud.WorkspaceDoesNotExistError:
509+
raise HTTPException(status_code=404, detail="Workspace does not exist")
510+
except crud.WorkspaceCrudError as e:
511+
raise HTTPException(status_code=400, detail=str(e))
512+
except Exception:
513+
logger.exception("Error while setting muxes")
514+
raise HTTPException(status_code=500, detail="Internal server error")
515+
505516
return Response(status_code=204)
506517

507518

src/codegate/api/v1_models.py

+9 -6
Original file line number | Diff line number | Diff line change
@@ -222,7 +222,7 @@ class ProviderEndpoint(pydantic.BaseModel):
222222
name: str
223223
description: str = ""
224224
provider_type: ProviderType
225-
endpoint: str
225+
endpoint: str = "" # Some providers have defaults we can leverage
226226
auth_type: Optional[ProviderAuthType] = ProviderAuthType.none
227227

228228
@staticmethod
@@ -250,6 +250,14 @@ def get_from_registry(self, registry: ProviderRegistry) -> Optional[BaseProvider
250250
return registry.get_provider(self.provider_type)
251251

252252

253+
class AddProviderEndpointRequest(ProviderEndpoint):
254+
"""
255+
Represents a request to add a provider endpoint.
256+
"""
257+
258+
api_key: Optional[str] = None
259+
260+
253261
class ConfigureAuthMaterial(pydantic.BaseModel):
254262
"""
255263
Represents a request to configure auth material for a provider.
@@ -279,11 +287,6 @@ class MuxMatcherType(str, Enum):
279287
Represents the different types of matchers we support.
280288
"""
281289

282-
# Match a regular expression for a file path
283-
# in the prompt. Note that if no file is found,
284-
# the prompt will be passed through.
285-
file_regex = "file_regex"
286-
287290
# Always match this prompt
288291
catch_all = "catch_all"
289292

src/codegate/db/connection.py

+81
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
Alert,
2020
GetPromptWithOutputsRow,
2121
GetWorkspaceByNameConditions,
22+
MuxRule,
2223
Output,
2324
Prompt,
2425
ProviderAuthMaterial,
@@ -112,6 +113,15 @@ async def _execute_update_pydantic_model(
112113
raise e
113114
return None
114115

116+
async def _execute_with_no_return(self, sql_command: TextClause, conditions: dict):
117+
"""Execute a command that doesn't return anything."""
118+
try:
119+
async with self._async_db_engine.begin() as conn:
120+
await conn.execute(sql_command, conditions)
121+
except Exception as e:
122+
logger.error(f"Failed to execute command: {sql_command}.", error=str(e))
123+
raise e
124+
115125
async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
116126
if prompt_params is None:
117127
return None
@@ -459,6 +469,45 @@ async def add_provider_model(self, model: ProviderModel) -> ProviderModel:
459469
added_model = await self._execute_update_pydantic_model(model, sql, should_raise=True)
460470
return added_model
461471

472+
async def delete_provider_models(self, provider_id: str):
473+
sql = text(
474+
"""
475+
DELETE FROM provider_models
476+
WHERE provider_endpoint_id = :provider_endpoint_id
477+
"""
478+
)
479+
conditions = {"provider_endpoint_id": provider_id}
480+
await self._execute_with_no_return(sql, conditions)
481+
482+
async def delete_muxes_by_workspace(self, workspace_id: str):
483+
sql = text(
484+
"""
485+
DELETE FROM muxes
486+
WHERE workspace_id = :workspace_id
487+
RETURNING *
488+
"""
489+
)
490+
491+
conditions = {"workspace_id": workspace_id}
492+
await self._execute_with_no_return(sql, conditions)
493+
494+
async def add_mux(self, mux: MuxRule) -> MuxRule:
495+
sql = text(
496+
"""
497+
INSERT INTO muxes (
498+
id, provider_endpoint_id, provider_model_name, workspace_id, matcher_type,
499+
matcher_blob, priority, created_at, updated_at
500+
)
501+
VALUES (
502+
:id, :provider_endpoint_id, :provider_model_name, :workspace_id,
503+
:matcher_type, :matcher_blob, :priority, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP
504+
)
505+
RETURNING *
506+
"""
507+
)
508+
added_mux = await self._execute_update_pydantic_model(mux, sql, should_raise=True)
509+
return added_mux
510+
462511

463512
class DbReader(DbCodeGate):
464513

@@ -684,6 +733,22 @@ async def get_provider_models_by_provider_id(self, provider_id: str) -> List[Pro
684733
)
685734
return models
686735

736+
async def get_provider_model_by_provider_id_and_name(
737+
self, provider_id: str, model_name: str
738+
) -> Optional[ProviderModel]:
739+
sql = text(
740+
"""
741+
SELECT provider_endpoint_id, name
742+
FROM provider_models
743+
WHERE provider_endpoint_id = :provider_endpoint_id AND name = :name
744+
"""
745+
)
746+
conditions = {"provider_endpoint_id": provider_id, "name": model_name}
747+
models = await self._exec_select_conditions_to_pydantic(
748+
ProviderModel, sql, conditions, should_raise=True
749+
)
750+
return models[0] if models else None
751+
687752
async def get_all_provider_models(self) -> List[ProviderModel]:
688753
sql = text(
689754
"""
@@ -695,6 +760,22 @@ async def get_all_provider_models(self) -> List[ProviderModel]:
695760
models = await self._execute_select_pydantic_model(ProviderModel, sql)
696761
return models
697762

763+
async def get_muxes_by_workspace(self, workspace_id: str) -> List[MuxRule]:
764+
sql = text(
765+
"""
766+
SELECT id, provider_endpoint_id, provider_model_name, workspace_id, matcher_type,
767+
matcher_blob, priority, created_at, updated_at
768+
FROM muxes
769+
WHERE workspace_id = :workspace_id
770+
ORDER BY priority ASC
771+
"""
772+
)
773+
conditions = {"workspace_id": workspace_id}
774+
muxes = await self._exec_select_conditions_to_pydantic(
775+
MuxRule, sql, conditions, should_raise=True
776+
)
777+
return muxes
778+
698779

699780
def init_db_sync(db_path: Optional[str] = None):
700781
"""DB will be initialized in the constructor in case it doesn't exist."""

src/codegate/db/models.py

+12
Original file line number | Diff line number | Diff line change
@@ -173,3 +173,15 @@ class ProviderModel(BaseModel):
173173
provider_endpoint_id: str
174174
provider_endpoint_name: Optional[str] = None
175175
name: str
176+
177+
178+
class MuxRule(BaseModel):
179+
id: str
180+
provider_endpoint_id: str
181+
provider_model_name: str
182+
workspace_id: str
183+
matcher_type: str
184+
matcher_blob: str
185+
priority: int
186+
created_at: Optional[datetime.datetime] = None
187+
updated_at: Optional[datetime.datetime] = None

src/codegate/providers/anthropic/provider.py

+18 -11
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
from codegate.pipeline.factory import PipelineFactory
99
from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer
1010
from codegate.providers.anthropic.completion_handler import AnthropicCompletion
11-
from codegate.providers.base import BaseProvider
11+
from codegate.providers.base import BaseProvider, ModelFetchError
1212
from codegate.providers.litellmshim import anthropic_stream_generator
1313

1414

@@ -29,16 +29,23 @@ def __init__(
2929
def provider_route_name(self) -> str:
3030
return "anthropic"
3131

32-
def models(self) -> List[str]:
33-
# TODO: This won't work since we need an API Key being set.
34-
resp = httpx.get("https://api.anthropic.com/models")
35-
# If Anthropic returned 404, it means it's not accepting our
36-
# requests. We should throw an error.
37-
if resp.status_code == 404:
38-
raise HTTPException(
39-
status_code=404,
40-
detail="The Anthropic API is not accepting requests. Please check your API key.",
41-
)
32+
def models(self, endpoint: str = None, api_key: str = None) -> List[str]:
33+
headers = {
34+
"Content-Type": "application/json",
35+
"anthropic-version": "2023-06-01",
36+
}
37+
if api_key:
38+
headers["x-api-key"] = api_key
39+
if not endpoint:
40+
endpoint = "https://api.anthropic.com"
41+
42+
resp = httpx.get(
43+
f"{endpoint}/v1/models",
44+
headers=headers,
45+
)
46+
47+
if resp.status_code != 200:
48+
raise ModelFetchError(f"Failed to fetch models from Anthropic API: {resp.text}")
4249

4350
respjson = resp.json()
4451

src/codegate/providers/base.py

+5 -1
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,10 @@
2424
StreamGenerator = Callable[[AsyncIterator[Any]], AsyncIterator[str]]
2525

2626

27+
class ModelFetchError(Exception):
28+
pass
29+
30+
2731
class BaseProvider(ABC):
2832
"""
2933
The provider class is responsible for defining the API routes and
@@ -55,7 +59,7 @@ def _setup_routes(self) -> None:
5559
pass
5660

5761
@abstractmethod
58-
def models(self) -> List[str]:
62+
def models(self, endpoint, str=None, api_key: str = None) -> List[str]:
5963
pass
6064

6165
@property

0 commit comments

Comments
 (0)