1919 Alert ,
2020 GetPromptWithOutputsRow ,
2121 GetWorkspaceByNameConditions ,
22+ MuxRule ,
2223 Output ,
2324 Prompt ,
2425 ProviderAuthMaterial ,
@@ -112,6 +113,15 @@ async def _execute_update_pydantic_model(
112113 raise e
113114 return None
114115
116+ async def _execute_with_no_return (self , sql_command : TextClause , conditions : dict ):
117+ """Execute a command that doesn't return anything."""
118+ try :
119+ async with self ._async_db_engine .begin () as conn :
120+ await conn .execute (sql_command , conditions )
121+ except Exception as e :
122+ logger .error (f"Failed to execute command: { sql_command } ." , error = str (e ))
123+ raise e
124+
115125 async def record_request (self , prompt_params : Optional [Prompt ] = None ) -> Optional [Prompt ]:
116126 if prompt_params is None :
117127 return None
@@ -459,6 +469,45 @@ async def add_provider_model(self, model: ProviderModel) -> ProviderModel:
459469 added_model = await self ._execute_update_pydantic_model (model , sql , should_raise = True )
460470 return added_model
461471
472+ async def delete_provider_models (self , provider_id : str ):
473+ sql = text (
474+ """
475+ DELETE FROM provider_models
476+ WHERE provider_endpoint_id = :provider_endpoint_id
477+ """
478+ )
479+ conditions = {"provider_endpoint_id" : provider_id }
480+ await self ._execute_with_no_return (sql , conditions )
481+
482+ async def delete_muxes_by_workspace (self , workspace_id : str ):
483+ sql = text (
484+ """
485+ DELETE FROM muxes
486+ WHERE workspace_id = :workspace_id
487+ RETURNING *
488+ """
489+ )
490+
491+ conditions = {"workspace_id" : workspace_id }
492+ await self ._execute_with_no_return (sql , conditions )
493+
494+ async def add_mux (self , mux : MuxRule ) -> MuxRule :
495+ sql = text (
496+ """
497+ INSERT INTO muxes (
498+ id, provider_endpoint_id, provider_model_name, workspace_id, matcher_type,
499+ matcher_blob, priority, created_at, updated_at
500+ )
501+ VALUES (
502+ :id, :provider_endpoint_id, :provider_model_name, :workspace_id,
503+ :matcher_type, :matcher_blob, :priority, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP
504+ )
505+ RETURNING *
506+ """
507+ )
508+ added_mux = await self ._execute_update_pydantic_model (mux , sql , should_raise = True )
509+ return added_mux
510+
462511
463512class DbReader (DbCodeGate ):
464513
@@ -684,6 +733,22 @@ async def get_provider_models_by_provider_id(self, provider_id: str) -> List[Pro
684733 )
685734 return models
686735
736+ async def get_provider_model_by_provider_id_and_name (
737+ self , provider_id : str , model_name : str
738+ ) -> Optional [ProviderModel ]:
739+ sql = text (
740+ """
741+ SELECT provider_endpoint_id, name
742+ FROM provider_models
743+ WHERE provider_endpoint_id = :provider_endpoint_id AND name = :name
744+ """
745+ )
746+ conditions = {"provider_endpoint_id" : provider_id , "name" : model_name }
747+ models = await self ._exec_select_conditions_to_pydantic (
748+ ProviderModel , sql , conditions , should_raise = True
749+ )
750+ return models [0 ] if models else None
751+
687752 async def get_all_provider_models (self ) -> List [ProviderModel ]:
688753 sql = text (
689754 """
@@ -695,6 +760,22 @@ async def get_all_provider_models(self) -> List[ProviderModel]:
695760 models = await self ._execute_select_pydantic_model (ProviderModel , sql )
696761 return models
697762
763+ async def get_muxes_by_workspace (self , workspace_id : str ) -> List [MuxRule ]:
764+ sql = text (
765+ """
766+ SELECT id, provider_endpoint_id, provider_model_name, workspace_id, matcher_type,
767+ matcher_blob, priority, created_at, updated_at
768+ FROM muxes
769+ WHERE workspace_id = :workspace_id
770+ ORDER BY priority ASC
771+ """
772+ )
773+ conditions = {"workspace_id" : workspace_id }
774+ muxes = await self ._exec_select_conditions_to_pydantic (
775+ MuxRule , sql , conditions , should_raise = True
776+ )
777+ return muxes
778+
698779
699780def init_db_sync (db_path : Optional [str ] = None ):
700781 """DB will be initialized in the constructor in case it doesn't exist."""
0 commit comments