Skip to content

Commit c59c10e

Browse files
committed
Kick off provider endpoint CRUD structure and registration
This structure will handle all the database operations and turn that into the right models. Note that for provider endpoints we already have a way of setting these via configuration, so this is taken into account to output some sample objects that users can leverage. Each provider will need to implement a `models` function which allows us to auto-discover models for a provider. Signed-off-by: Juan Antonio Osorio <[email protected]>
1 parent 1fbaa1a commit c59c10e

File tree

16 files changed

+606
-69
lines changed

16 files changed

+606
-69
lines changed

src/codegate/api/v1.py

Lines changed: 117 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
11
from typing import List, Optional
2+
from uuid import UUID
23

34
import requests
45
import structlog
5-
from fastapi import APIRouter, HTTPException, Response
6+
from fastapi import APIRouter, Depends, HTTPException, Response
67
from fastapi.responses import StreamingResponse
78
from fastapi.routing import APIRoute
8-
from pydantic import ValidationError
9+
from pydantic import BaseModel, ValidationError
910

1011
from codegate import __version__
1112
from codegate.api import v1_models, v1_processing
1213
from codegate.db.connection import AlreadyExistsError, DbReader
14+
from codegate.providers import crud as provendcrud
1315
from codegate.workspaces import crud
1416

1517
logger = structlog.get_logger("codegate")
1618

1719
v1 = APIRouter()
1820
wscrud = crud.WorkspaceCrud()
21+
pcrud = provendcrud.ProviderCrud()
1922

2023
# This is a singleton object
2124
dbreader = DbReader()
@@ -25,38 +28,78 @@ def uniq_name(route: APIRoute):
2528
return f"v1_{route.name}"
2629

2730

31+
class FilterByNameParams(BaseModel):
32+
name: Optional[str] = None
33+
34+
2835
@v1.get("/provider-endpoints", tags=["Providers"], generate_unique_id_function=uniq_name)
29-
async def list_provider_endpoints(name: Optional[str] = None) -> List[v1_models.ProviderEndpoint]:
36+
async def list_provider_endpoints(
37+
filter_query: FilterByNameParams = Depends(),
38+
) -> List[v1_models.ProviderEndpoint]:
3039
"""List all provider endpoints."""
31-
# NOTE: This is a dummy implementation. In the future, we should have a proper
32-
# implementation that fetches the provider endpoints from the database.
33-
return [
34-
v1_models.ProviderEndpoint(
35-
id=1,
36-
name="dummy",
37-
description="Dummy provider endpoint",
38-
endpoint="http://example.com",
39-
provider_type=v1_models.ProviderType.openai,
40-
auth_type=v1_models.ProviderAuthType.none,
41-
)
42-
]
40+
if filter_query.name is None:
41+
try:
42+
return await pcrud.list_endpoints()
43+
except Exception:
44+
raise HTTPException(status_code=500, detail="Internal server error")
45+
46+
try:
47+
provend = await pcrud.get_endpoint_by_name(filter_query.name)
48+
except Exception:
49+
raise HTTPException(status_code=500, detail="Internal server error")
50+
51+
if provend is None:
52+
raise HTTPException(status_code=404, detail="Provider endpoint not found")
53+
return [provend]
54+
55+
56+
# This needs to be above /provider-endpoints/{provider_id} to avoid conflict
57+
@v1.get(
58+
"/provider-endpoints/models",
59+
tags=["Providers"],
60+
generate_unique_id_function=uniq_name,
61+
)
62+
async def list_all_models_for_all_providers() -> List[v1_models.ModelByProvider]:
63+
"""List all models for all providers."""
64+
try:
65+
return await pcrud.get_all_models()
66+
except Exception:
67+
raise HTTPException(status_code=500, detail="Internal server error")
68+
69+
70+
@v1.get(
71+
"/provider-endpoints/{provider_id}/models",
72+
tags=["Providers"],
73+
generate_unique_id_function=uniq_name,
74+
)
75+
async def list_models_by_provider(
76+
provider_id: UUID,
77+
) -> List[v1_models.ModelByProvider]:
78+
"""List models by provider."""
79+
80+
try:
81+
return await pcrud.models_by_provider(provider_id)
82+
except provendcrud.ProviderNotFoundError:
83+
raise HTTPException(status_code=404, detail="Provider not found")
84+
except Exception as e:
85+
raise HTTPException(status_code=500, detail=str(e))
4386

4487

4588
@v1.get(
4689
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
4790
)
48-
async def get_provider_endpoint(provider_id: int) -> v1_models.ProviderEndpoint:
91+
async def get_provider_endpoint(
92+
provider_id: UUID,
93+
) -> v1_models.ProviderEndpoint:
4994
"""Get a provider endpoint by ID."""
50-
# NOTE: This is a dummy implementation. In the future, we should have a proper
51-
# implementation that fetches the provider endpoint from the database.
52-
return v1_models.ProviderEndpoint(
53-
id=provider_id,
54-
name="dummy",
55-
description="Dummy provider endpoint",
56-
endpoint="http://example.com",
57-
provider_type=v1_models.ProviderType.openai,
58-
auth_type=v1_models.ProviderAuthType.none,
59-
)
95+
try:
96+
provend = await pcrud.get_endpoint_by_id(provider_id)
97+
except Exception:
98+
raise HTTPException(status_code=500, detail="Internal server error")
99+
100+
if provend is None:
101+
raise HTTPException(status_code=404, detail="Provider endpoint not found")
102+
return provend
60103

61104

62105
@v1.post(
@@ -65,59 +108,65 @@ async def get_provider_endpoint(provider_id: int) -> v1_models.ProviderEndpoint:
65108
generate_unique_id_function=uniq_name,
66109
status_code=201,
67110
)
68-
async def add_provider_endpoint(request: v1_models.ProviderEndpoint) -> v1_models.ProviderEndpoint:
111+
async def add_provider_endpoint(
112+
request: v1_models.ProviderEndpoint,
113+
) -> v1_models.ProviderEndpoint:
69114
"""Add a provider endpoint."""
70-
# NOTE: This is a dummy implementation. In the future, we should have a proper
71-
# implementation that adds the provider endpoint to the database.
72-
return request
115+
try:
116+
provend = await pcrud.add_endpoint(request)
117+
except AlreadyExistsError:
118+
raise HTTPException(status_code=409, detail="Provider endpoint already exists")
119+
except ValidationError as e:
120+
# TODO: This should be more specific
121+
raise HTTPException(
122+
status_code=400,
123+
detail=str(e),
124+
)
125+
except Exception:
126+
raise HTTPException(status_code=500, detail="Internal server error")
127+
128+
return provend
73129

74130

75131
@v1.put(
76132
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
77133
)
78134
async def update_provider_endpoint(
79-
provider_id: int, request: v1_models.ProviderEndpoint
135+
provider_id: UUID,
136+
request: v1_models.ProviderEndpoint,
80137
) -> v1_models.ProviderEndpoint:
81138
"""Update a provider endpoint by ID."""
82-
# NOTE: This is a dummy implementation. In the future, we should have a proper
83-
# implementation that updates the provider endpoint in the database.
84-
return request
139+
try:
140+
request.id = provider_id
141+
provend = await pcrud.update_endpoint(request)
142+
except ValidationError as e:
143+
# TODO: This should be more specific
144+
raise HTTPException(
145+
status_code=400,
146+
detail=str(e),
147+
)
148+
except Exception:
149+
raise HTTPException(status_code=500, detail="Internal server error")
150+
151+
return provend
85152

86153

87154
@v1.delete(
88155
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
89156
)
90-
async def delete_provider_endpoint(provider_id: int):
157+
async def delete_provider_endpoint(
158+
provider_id: UUID,
159+
):
91160
"""Delete a provider endpoint by id."""
92-
# NOTE: This is a dummy implementation. In the future, we should have a proper
93-
# implementation that deletes the provider endpoint from the database.
161+
try:
162+
await pcrud.delete_endpoint(provider_id)
163+
except provendcrud.ProviderNotFoundError:
164+
raise HTTPException(status_code=404, detail="Provider endpoint not found")
165+
except Exception:
166+
raise HTTPException(status_code=500, detail="Internal server error")
94167
return Response(status_code=204)
95168

96169

97-
@v1.get(
98-
"/provider-endpoints/{provider_name}/models",
99-
tags=["Providers"],
100-
generate_unique_id_function=uniq_name,
101-
)
102-
async def list_models_by_provider(provider_name: str) -> List[v1_models.ModelByProvider]:
103-
"""List models by provider."""
104-
# NOTE: This is a dummy implementation. In the future, we should have a proper
105-
# implementation that fetches the models by provider from the database.
106-
return [v1_models.ModelByProvider(name="dummy", provider="dummy")]
107-
108-
109-
@v1.get(
110-
"/provider-endpoints/models",
111-
tags=["Providers"],
112-
generate_unique_id_function=uniq_name,
113-
)
114-
async def list_all_models_for_all_providers() -> List[v1_models.ModelByProvider]:
115-
"""List all models for all providers."""
116-
# NOTE: This is a dummy implementation. In the future, we should have a proper
117-
# implementation that fetches all the models for all providers from the database.
118-
return [v1_models.ModelByProvider(name="dummy", provider="dummy")]
119-
120-
121170
@v1.get("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name)
122171
async def list_workspaces() -> v1_models.ListWorkspacesResponse:
123172
"""List all workspaces."""
@@ -394,7 +443,9 @@ async def delete_workspace_custom_instructions(workspace_name: str):
394443
tags=["Workspaces", "Muxes"],
395444
generate_unique_id_function=uniq_name,
396445
)
397-
async def get_workspace_muxes(workspace_name: str) -> List[v1_models.MuxRule]:
446+
async def get_workspace_muxes(
447+
workspace_name: str,
448+
) -> List[v1_models.MuxRule]:
398449
"""Get the mux rules of a workspace.
399450
400451
The list is ordered in order of priority. That is, the first rule in the list
@@ -422,7 +473,10 @@ async def get_workspace_muxes(workspace_name: str) -> List[v1_models.MuxRule]:
422473
generate_unique_id_function=uniq_name,
423474
status_code=204,
424475
)
425-
async def set_workspace_muxes(workspace_name: str, request: List[v1_models.MuxRule]):
476+
async def set_workspace_muxes(
477+
workspace_name: str,
478+
request: List[v1_models.MuxRule],
479+
):
426480
"""Set the mux rules of a workspace."""
427481
# TODO: This is a dummy implementation. In the future, we should have a proper
428482
# implementation that sets the mux rules in the database.

src/codegate/api/v1_models.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from codegate.db import models as db_models
88
from codegate.pipeline.base import CodeSnippet
9+
from codegate.providers.base import BaseProvider
10+
from codegate.providers.registry import ProviderRegistry
911

1012

1113
class Workspace(pydantic.BaseModel):
@@ -122,6 +124,8 @@ class ProviderType(str, Enum):
122124
openai = "openai"
123125
anthropic = "anthropic"
124126
vllm = "vllm"
127+
ollama = "ollama"
128+
lm_studio = "lm_studio"
125129

126130

127131
class TokenUsageByModel(pydantic.BaseModel):
@@ -191,13 +195,38 @@ class ProviderEndpoint(pydantic.BaseModel):
191195
so we can use this for muxing messages.
192196
"""
193197

194-
id: int
198+
# This will be set on creation
199+
id: Optional[str] = ""
195200
name: str
196201
description: str = ""
197202
provider_type: ProviderType
198203
endpoint: str
199204
auth_type: ProviderAuthType
200205

206+
@staticmethod
207+
def from_db_model(db_model: db_models.ProviderEndpoint) -> "ProviderEndpoint":
208+
return ProviderEndpoint(
209+
id=db_model.id,
210+
name=db_model.name,
211+
description=db_model.description,
212+
provider_type=db_model.provider_type,
213+
endpoint=db_model.endpoint,
214+
auth_type=db_model.auth_type,
215+
)
216+
217+
def to_db_model(self) -> db_models.ProviderEndpoint:
218+
return db_models.ProviderEndpoint(
219+
id=self.id,
220+
name=self.name,
221+
description=self.description,
222+
provider_type=self.provider_type,
223+
endpoint=self.endpoint,
224+
auth_type=self.auth_type,
225+
)
226+
227+
def get_from_registry(self, registry: ProviderRegistry) -> Optional[BaseProvider]:
228+
return registry.get_provider(self.provider_type)
229+
201230

202231
class ModelByProvider(pydantic.BaseModel):
203232
"""
@@ -207,10 +236,11 @@ class ModelByProvider(pydantic.BaseModel):
207236
"""
208237

209238
name: str
210-
provider: str
239+
provider_id: str
240+
provider_name: str
211241

212242
def __str__(self):
213-
return f"{self.provider}/{self.name}"
243+
return f"{self.provider_name} / {self.name}"
214244

215245

216246
class MuxMatcherType(str, Enum):

src/codegate/cli.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from codegate.db.connection import init_db_sync, init_session_if_not_exists
1818
from codegate.pipeline.factory import PipelineFactory
1919
from codegate.pipeline.secrets.manager import SecretsManager
20+
from codegate.providers import crud as provendcrud
2021
from codegate.providers.copilot.provider import CopilotProvider
2122
from codegate.server import init_app
2223
from codegate.storage.utils import restore_storage_backup
@@ -329,6 +330,9 @@ def serve(
329330
loop = asyncio.new_event_loop()
330331
asyncio.set_event_loop(loop)
331332

333+
registry = app.provider_registry
334+
loop.run_until_complete(provendcrud.initialize_provider_endpoints(registry))
335+
332336
# Run the server
333337
try:
334338
loop.run_until_complete(run_servers(cfg, app))

0 commit comments

Comments
 (0)