Skip to content

Commit 8f234ae

Browse files
committed
Implement muxes CRUD
Signed-off-by: Juan Antonio Osorio <[email protected]>
1 parent 64bcf78 commit 8f234ae

File tree

5 files changed

+185
-32
lines changed

5 files changed

+185
-32
lines changed

src/codegate/api/v1.py

+19-18
Original file line numberDiff line numberDiff line change
@@ -481,22 +481,15 @@ async def get_workspace_muxes(
481481
482482
The list is ordered in order of priority. That is, the first rule in the list
483483
has the highest priority."""
484-
# TODO: This is a dummy implementation. In the future, we should have a proper
485-
# implementation that fetches the mux rules from the database.
486-
return [
487-
v1_models.MuxRule(
488-
# Hardcode some UUID just for mocking purposes
489-
provider_id="00000000-0000-0000-0000-000000000001",
490-
model="gpt-3.5-turbo",
491-
matcher_type=v1_models.MuxMatcherType.file_regex,
492-
matcher=".*\\.txt",
493-
),
494-
v1_models.MuxRule(
495-
provider_id="00000000-0000-0000-0000-000000000002",
496-
model="davinci",
497-
matcher_type=v1_models.MuxMatcherType.catch_all,
498-
),
499-
]
484+
try:
485+
muxes = await wscrud.get_muxes(workspace_name)
486+
except crud.WorkspaceDoesNotExistError:
487+
raise HTTPException(status_code=404, detail="Workspace does not exist")
488+
except Exception:
489+
logger.exception("Error while getting workspace")
490+
raise HTTPException(status_code=500, detail="Internal server error")
491+
492+
return muxes
500493

501494

502495
@v1.put(
@@ -510,8 +503,16 @@ async def set_workspace_muxes(
510503
request: List[v1_models.MuxRule],
511504
):
512505
"""Set the mux rules of a workspace."""
513-
# TODO: This is a dummy implementation. In the future, we should have a proper
514-
# implementation that sets the mux rules in the database.
506+
try:
507+
await wscrud.set_muxes(workspace_name, request)
508+
except crud.WorkspaceDoesNotExistError:
509+
raise HTTPException(status_code=404, detail="Workspace does not exist")
510+
except crud.WorkspaceCrudError as e:
511+
raise HTTPException(status_code=400, detail=str(e))
512+
except Exception:
513+
logger.exception("Error while setting muxes")
514+
raise HTTPException(status_code=500, detail="Internal server error")
515+
515516
return Response(status_code=204)
516517

517518

src/codegate/api/v1_models.py

-5
Original file line numberDiff line numberDiff line change
@@ -287,11 +287,6 @@ class MuxMatcherType(str, Enum):
287287
Represents the different types of matchers we support.
288288
"""
289289

290-
# Match a regular expression for a file path
291-
# in the prompt. Note that if no file is found,
292-
# the prompt will be passed through.
293-
file_regex = "file_regex"
294-
295290
# Always match this prompt
296291
catch_all = "catch_all"
297292

src/codegate/db/connection.py

+72-8
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
Alert,
2020
GetPromptWithOutputsRow,
2121
GetWorkspaceByNameConditions,
22+
MuxRule,
2223
Output,
2324
Prompt,
2425
ProviderAuthMaterial,
@@ -112,6 +113,15 @@ async def _execute_update_pydantic_model(
112113
raise e
113114
return None
114115

116+
async def _execute_with_no_return(self, sql_command: TextClause, conditions: dict):
117+
"""Execute a command that doesn't return anything."""
118+
try:
119+
async with self._async_db_engine.begin() as conn:
120+
await conn.execute(sql_command, conditions)
121+
except Exception as e:
122+
logger.error(f"Failed to execute command: {sql_command}.", error=str(e))
123+
raise e
124+
115125
async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
116126
if prompt_params is None:
117127
return None
@@ -459,22 +469,44 @@ async def add_provider_model(self, model: ProviderModel) -> ProviderModel:
459469
added_model = await self._execute_update_pydantic_model(model, sql, should_raise=True)
460470
return added_model
461471

462-
async def delete_provider_models(self, provider_id: str) -> Optional[ProviderModel]:
472+
async def delete_provider_models(self, provider_id: str):
463473
sql = text(
464474
"""
465475
DELETE FROM provider_models
466476
WHERE provider_endpoint_id = :provider_endpoint_id
477+
"""
478+
)
479+
conditions = {"provider_endpoint_id": provider_id}
480+
await self._execute_with_no_return(sql, conditions)
481+
482+
async def delete_muxes_by_workspace(self, workspace_id: str):
483+
sql = text(
484+
"""
485+
DELETE FROM muxes
486+
WHERE workspace_id = :workspace_id
467487
RETURNING *
468488
"""
469489
)
470-
await self._execute_update_pydantic_model(
471-
ProviderModel(
472-
provider_endpoint_id=provider_id,
473-
name="Fake name to respect the signature of the function",
474-
),
475-
sql,
476-
should_raise=True,
490+
491+
conditions = {"workspace_id": workspace_id}
492+
await self._execute_with_no_return(sql, conditions)
493+
494+
async def add_mux(self, mux: MuxRule) -> MuxRule:
495+
sql = text(
496+
"""
497+
INSERT INTO muxes (
498+
id, provider_endpoint_id, provider_model_name, workspace_id, matcher_type,
499+
matcher_blob, priority, created_at, updated_at
500+
)
501+
VALUES (
502+
:id, :provider_endpoint_id, :provider_model_name, :workspace_id,
503+
:matcher_type, :matcher_blob, :priority, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP
504+
)
505+
RETURNING *
506+
"""
477507
)
508+
added_mux = await self._execute_update_pydantic_model(mux, sql, should_raise=True)
509+
return added_mux
478510

479511

480512
class DbReader(DbCodeGate):
@@ -701,6 +733,22 @@ async def get_provider_models_by_provider_id(self, provider_id: str) -> List[Pro
701733
)
702734
return models
703735

736+
async def get_provider_model_by_provider_id_and_name(
737+
self, provider_id: str, model_name: str
738+
) -> Optional[ProviderModel]:
739+
sql = text(
740+
"""
741+
SELECT provider_endpoint_id, name
742+
FROM provider_models
743+
WHERE provider_endpoint_id = :provider_endpoint_id AND name = :name
744+
"""
745+
)
746+
conditions = {"provider_endpoint_id": provider_id, "name": model_name}
747+
models = await self._exec_select_conditions_to_pydantic(
748+
ProviderModel, sql, conditions, should_raise=True
749+
)
750+
return models[0] if models else None
751+
704752
async def get_all_provider_models(self) -> List[ProviderModel]:
705753
sql = text(
706754
"""
@@ -712,6 +760,22 @@ async def get_all_provider_models(self) -> List[ProviderModel]:
712760
models = await self._execute_select_pydantic_model(ProviderModel, sql)
713761
return models
714762

763+
async def get_muxes_by_workspace(self, workspace_id: str) -> List[MuxRule]:
764+
sql = text(
765+
"""
766+
SELECT id, provider_endpoint_id, provider_model_name, workspace_id, matcher_type,
767+
matcher_blob, priority, created_at, updated_at
768+
FROM muxes
769+
WHERE workspace_id = :workspace_id
770+
ORDER BY priority ASC
771+
"""
772+
)
773+
conditions = {"workspace_id": workspace_id}
774+
muxes = await self._exec_select_conditions_to_pydantic(
775+
MuxRule, sql, conditions, should_raise=True
776+
)
777+
return muxes
778+
715779

716780
def init_db_sync(db_path: Optional[str] = None):
717781
"""DB will be initialized in the constructor in case it doesn't exist."""

src/codegate/db/models.py

+12
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,15 @@ class ProviderModel(BaseModel):
173173
provider_endpoint_id: str
174174
provider_endpoint_name: Optional[str] = None
175175
name: str
176+
177+
178+
class MuxRule(BaseModel):
179+
id: str
180+
provider_endpoint_id: str
181+
provider_model_name: str
182+
workspace_id: str
183+
matcher_type: str
184+
matcher_blob: str
185+
priority: int
186+
created_at: Optional[datetime.datetime] = None
187+
updated_at: Optional[datetime.datetime] = None

src/codegate/workspaces/crud.py

+82-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
1+
import asyncio
12
import datetime
23
from typing import List, Optional, Tuple
4+
from uuid import uuid4 as uuid
35

46
from codegate.db.connection import DbReader, DbRecorder
5-
from codegate.db.models import ActiveWorkspace, Session, WorkspaceRow, WorkspaceWithSessionInfo
7+
from codegate.db.models import (
8+
ActiveWorkspace,
9+
MuxRule,
10+
Session,
11+
WorkspaceRow,
12+
WorkspaceWithSessionInfo,
13+
)
614

715

816
class WorkspaceCrudError(Exception):
@@ -17,6 +25,10 @@ class WorkspaceAlreadyActiveError(WorkspaceCrudError):
1725
pass
1826

1927

28+
class WorkspaceMuxRuleDoesNotExistError(WorkspaceCrudError):
29+
pass
30+
31+
2032
DEFAULT_WORKSPACE_NAME = "default"
2133

2234
# These are reserved keywords that cannot be used for workspaces
@@ -202,3 +214,72 @@ async def get_workspace_by_name(self, workspace_name: str) -> WorkspaceRow:
202214
if not workspace:
203215
raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.")
204216
return workspace
217+
218+
# Can't use type hints since the models are not yet defined
219+
# Note that I'm explicitly importing the models here to avoid circular imports.
220+
async def get_muxes(self, workspace_name: str):
221+
from codegate.api import v1_models
222+
223+
# Verify if workspace exists
224+
workspace = await self._db_reader.get_workspace_by_name(workspace_name)
225+
if not workspace:
226+
raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.")
227+
228+
dbmuxes = await self._db_reader.get_muxes_by_workspace(workspace.id)
229+
230+
muxes = []
231+
# These are already sorted by priority
232+
for dbmux in dbmuxes:
233+
muxes.append(
234+
v1_models.MuxRule(
235+
provider_id=dbmux.provider_endpoint_id,
236+
model=dbmux.provider_model_name,
237+
matcher_type=dbmux.matcher_type,
238+
matcher=dbmux.matcher_blob,
239+
)
240+
)
241+
242+
return muxes
243+
244+
# Can't use type hints since the models are not yet defined
245+
async def set_muxes(self, workspace_name: str, muxes):
246+
# Verify if workspace exists
247+
workspace = await self._db_reader.get_workspace_by_name(workspace_name)
248+
if not workspace:
249+
raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.")
250+
251+
# Delete all muxes for the workspace
252+
db_recorder = DbRecorder()
253+
await db_recorder.delete_muxes_by_workspace(workspace.id)
254+
255+
tasks = set()
256+
257+
# Add the new muxes
258+
priority = 0
259+
260+
# Verify all models are valid
261+
for mux in muxes:
262+
dbm = await self._db_reader.get_provider_model_by_provider_id_and_name(
263+
mux.provider_id,
264+
mux.model,
265+
)
266+
if not dbm:
267+
raise WorkspaceCrudError(
268+
f"Model {mux.model} does not exist for provider {mux.provider_id}"
269+
)
270+
271+
for mux in muxes:
272+
new_mux = MuxRule(
273+
id=str(uuid()),
274+
provider_endpoint_id=mux.provider_id,
275+
provider_model_name=mux.model,
276+
workspace_id=workspace.id,
277+
matcher_type=mux.matcher_type,
278+
matcher_blob=mux.matcher if mux.matcher else "",
279+
priority=priority,
280+
)
281+
tasks.add(db_recorder.add_mux(new_mux))
282+
283+
priority += 1
284+
285+
await asyncio.gather(*tasks)

0 commit comments

Comments
 (0)