diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index b256635f..7900d873 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -109,13 +109,18 @@ async def get_provider_endpoint( status_code=201, ) async def add_provider_endpoint( - request: v1_models.ProviderEndpoint, + request: v1_models.AddProviderEndpointRequest, ) -> v1_models.ProviderEndpoint: """Add a provider endpoint.""" try: provend = await pcrud.add_endpoint(request) except AlreadyExistsError: raise HTTPException(status_code=409, detail="Provider endpoint already exists") + except ValueError as e: + raise HTTPException( + status_code=400, + detail=str(e), + ) except ValidationError as e: # TODO: This should be more specific raise HTTPException( @@ -123,6 +128,7 @@ async def add_provider_endpoint( detail=str(e), ) except Exception: + logger.exception("Error while adding provider endpoint") raise HTTPException(status_code=500, detail="Internal server error") return provend @@ -154,20 +160,24 @@ async def configure_auth_material( ) async def update_provider_endpoint( provider_id: UUID, - request: v1_models.ProviderEndpoint, + request: v1_models.AddProviderEndpointRequest, ) -> v1_models.ProviderEndpoint: """Update a provider endpoint by ID.""" try: - request.id = provider_id + request.id = str(provider_id) provend = await pcrud.update_endpoint(request) + except provendcrud.ProviderNotFoundError: + raise HTTPException(status_code=404, detail="Provider endpoint not found") + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) except ValidationError as e: # TODO: This should be more specific raise HTTPException( status_code=400, detail=str(e), ) - except Exception: - raise HTTPException(status_code=500, detail="Internal server error") + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) return provend @@ -471,22 +481,15 @@ async def get_workspace_muxes( The list is ordered in order of priority. That is, the first rule in the list has the highest priority.""" - # TODO: This is a dummy implementation. In the future, we should have a proper - # implementation that fetches the mux rules from the database. - return [ - v1_models.MuxRule( - # Hardcode some UUID just for mocking purposes - provider_id="00000000-0000-0000-0000-000000000001", - model="gpt-3.5-turbo", - matcher_type=v1_models.MuxMatcherType.file_regex, - matcher=".*\\.txt", - ), - v1_models.MuxRule( - provider_id="00000000-0000-0000-0000-000000000002", - model="davinci", - matcher_type=v1_models.MuxMatcherType.catch_all, - ), - ] + try: + muxes = await wscrud.get_muxes(workspace_name) + except crud.WorkspaceDoesNotExistError: + raise HTTPException(status_code=404, detail="Workspace does not exist") + except Exception: + logger.exception("Error while getting workspace") + raise HTTPException(status_code=500, detail="Internal server error") + + return muxes @v1.put( @@ -500,8 +503,16 @@ async def set_workspace_muxes( request: List[v1_models.MuxRule], ): """Set the mux rules of a workspace.""" - # TODO: This is a dummy implementation. In the future, we should have a proper - # implementation that sets the mux rules in the database. + try: + await wscrud.set_muxes(workspace_name, request) + except crud.WorkspaceDoesNotExistError: + raise HTTPException(status_code=404, detail="Workspace does not exist") + except crud.WorkspaceCrudError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception: + logger.exception("Error while setting muxes") + raise HTTPException(status_code=500, detail="Internal server error") + return Response(status_code=204) diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py index 65c4acc1..be7706cc 100644 --- a/src/codegate/api/v1_models.py +++ b/src/codegate/api/v1_models.py @@ -222,7 +222,7 @@ class ProviderEndpoint(pydantic.BaseModel): name: str description: str = "" provider_type: ProviderType - endpoint: str + endpoint: str = "" # Some providers have defaults we can leverage auth_type: Optional[ProviderAuthType] = ProviderAuthType.none @staticmethod @@ -250,6 +250,14 @@ def get_from_registry(self, registry: ProviderRegistry) -> Optional[BaseProvider return registry.get_provider(self.provider_type) +class AddProviderEndpointRequest(ProviderEndpoint): + """ + Represents a request to add a provider endpoint. + """ + + api_key: Optional[str] = None + + class ConfigureAuthMaterial(pydantic.BaseModel): """ Represents a request to configure auth material for a provider. @@ -279,11 +287,6 @@ class MuxMatcherType(str, Enum): Represents the different types of matchers we support. """ - # Match a regular expression for a file path - # in the prompt. Note that if no file is found, - # the prompt will be passed through. - file_regex = "file_regex" - # Always match this prompt catch_all = "catch_all" diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index caed5276..1ea5d736 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -19,6 +19,7 @@ Alert, GetPromptWithOutputsRow, GetWorkspaceByNameConditions, + MuxRule, Output, Prompt, ProviderAuthMaterial, @@ -112,6 +113,15 @@ async def _execute_update_pydantic_model( raise e return None + async def _execute_with_no_return(self, sql_command: TextClause, conditions: dict): + """Execute a command that doesn't return anything.""" + try: + async with self._async_db_engine.begin() as conn: + await conn.execute(sql_command, conditions) + except Exception as e: + logger.error(f"Failed to execute command: {sql_command}.", error=str(e)) + raise e + async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]: if prompt_params is None: return None @@ -459,6 +469,45 @@ async def add_provider_model(self, model: ProviderModel) -> ProviderModel: added_model = await self._execute_update_pydantic_model(model, sql, should_raise=True) return added_model + async def delete_provider_models(self, provider_id: str): + sql = text( + """ + DELETE FROM provider_models + WHERE provider_endpoint_id = :provider_endpoint_id + """ + ) + conditions = {"provider_endpoint_id": provider_id} + await self._execute_with_no_return(sql, conditions) + + async def delete_muxes_by_workspace(self, workspace_id: str): + sql = text( + """ + DELETE FROM muxes + WHERE workspace_id = :workspace_id + RETURNING * + """ + ) + + conditions = {"workspace_id": workspace_id} + await self._execute_with_no_return(sql, conditions) + + async def add_mux(self, mux: MuxRule) -> MuxRule: + sql = text( + """ + INSERT INTO muxes ( + id, provider_endpoint_id, provider_model_name, workspace_id, matcher_type, + matcher_blob, priority, created_at, updated_at + ) + VALUES ( + :id, :provider_endpoint_id, :provider_model_name, :workspace_id, + :matcher_type, :matcher_blob, :priority, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP + ) + RETURNING * + """ + ) + added_mux = await self._execute_update_pydantic_model(mux, sql, should_raise=True) + return added_mux + class DbReader(DbCodeGate): @@ -684,6 +733,22 @@ async def get_provider_models_by_provider_id(self, provider_id: str) -> List[Pro ) return models + async def get_provider_model_by_provider_id_and_name( + self, provider_id: str, model_name: str + ) -> Optional[ProviderModel]: + sql = text( + """ + SELECT provider_endpoint_id, name + FROM provider_models + WHERE provider_endpoint_id = :provider_endpoint_id AND name = :name + """ + ) + conditions = {"provider_endpoint_id": provider_id, "name": model_name} + models = await self._exec_select_conditions_to_pydantic( + ProviderModel, sql, conditions, should_raise=True + ) + return models[0] if models else None + async def get_all_provider_models(self) -> List[ProviderModel]: sql = text( """ @@ -695,6 +760,22 @@ async def get_all_provider_models(self) -> List[ProviderModel]: models = await self._execute_select_pydantic_model(ProviderModel, sql) return models + async def get_muxes_by_workspace(self, workspace_id: str) -> List[MuxRule]: + sql = text( + """ + SELECT id, provider_endpoint_id, provider_model_name, workspace_id, matcher_type, + matcher_blob, priority, created_at, updated_at + FROM muxes + WHERE workspace_id = :workspace_id + ORDER BY priority ASC + """ + ) + conditions = {"workspace_id": workspace_id} + muxes = await self._exec_select_conditions_to_pydantic( + MuxRule, sql, conditions, should_raise=True + ) + return muxes + def init_db_sync(db_path: Optional[str] = None): """DB will be initialized in the constructor in case it doesn't exist.""" diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index c2a5ce8a..06b20b67 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -173,3 +173,15 @@ class ProviderModel(BaseModel): provider_endpoint_id: str provider_endpoint_name: Optional[str] = None name: str + + +class MuxRule(BaseModel): + id: str + provider_endpoint_id: str + provider_model_name: str + workspace_id: str + matcher_type: str + matcher_blob: str + priority: int + created_at: Optional[datetime.datetime] = None + updated_at: Optional[datetime.datetime] = None diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py index 48821de0..9c656794 100644 --- a/src/codegate/providers/anthropic/provider.py +++ b/src/codegate/providers/anthropic/provider.py @@ -8,7 +8,7 @@ from codegate.pipeline.factory import PipelineFactory from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer from codegate.providers.anthropic.completion_handler import AnthropicCompletion -from codegate.providers.base import BaseProvider +from codegate.providers.base import BaseProvider, ModelFetchError from codegate.providers.litellmshim import anthropic_stream_generator @@ -29,16 +29,23 @@ def __init__( def provider_route_name(self) -> str: return "anthropic" - def models(self) -> List[str]: - # TODO: This won't work since we need an API Key being set. - resp = httpx.get("https://api.anthropic.com/models") - # If Anthropic returned 404, it means it's not accepting our - # requests. We should throw an error. - if resp.status_code == 404: - raise HTTPException( - status_code=404, - detail="The Anthropic API is not accepting requests. Please check your API key.", - ) + def models(self, endpoint: str = None, api_key: str = None) -> List[str]: + headers = { + "Content-Type": "application/json", + "anthropic-version": "2023-06-01", + } + if api_key: + headers["x-api-key"] = api_key + if not endpoint: + endpoint = "https://api.anthropic.com" + + resp = httpx.get( + f"{endpoint}/v1/models", + headers=headers, + ) + + if resp.status_code != 200: + raise ModelFetchError(f"Failed to fetch models from Anthropic API: {resp.text}") respjson = resp.json() diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index 8e9a4d40..050cea06 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -24,6 +24,10 @@ StreamGenerator = Callable[[AsyncIterator[Any]], AsyncIterator[str]] +class ModelFetchError(Exception): + pass + + class BaseProvider(ABC): """ The provider class is responsible for defining the API routes and @@ -55,7 +59,7 @@ def _setup_routes(self) -> None: pass @abstractmethod - def models(self) -> List[str]: + def models(self, endpoint, str=None, api_key: str = None) -> List[str]: pass @property diff --git a/src/codegate/providers/crud/crud.py b/src/codegate/providers/crud/crud.py index ebae2b97..5207b7d6 100644 --- a/src/codegate/providers/crud/crud.py +++ b/src/codegate/providers/crud/crud.py @@ -11,7 +11,7 @@ from codegate.db import models as dbmodels from codegate.db.connection import DbReader, DbRecorder from codegate.providers.base import BaseProvider -from codegate.providers.registry import ProviderRegistry +from codegate.providers.registry import ProviderRegistry, get_provider_registry logger = structlog.get_logger("codegate") @@ -62,23 +62,106 @@ async def get_endpoint_by_name(self, name: str) -> Optional[apimodelsv1.Provider return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) async def add_endpoint( - self, endpoint: apimodelsv1.ProviderEndpoint + self, endpoint: apimodelsv1.AddProviderEndpointRequest ) -> apimodelsv1.ProviderEndpoint: """Add an endpoint.""" + + if not endpoint.endpoint: + endpoint.endpoint = provider_default_endpoints(endpoint.provider_type) + + # If we STILL don't have an endpoint, we can't continue + if not endpoint.endpoint: + raise ValueError("No endpoint provided and no default found for provider type") + dbend = endpoint.to_db_model() + provider_registry = get_provider_registry() # We override the ID here, as we want to generate it. dbend.id = str(uuid4()) - dbendpoint = await self._db_writer.add_provider_endpoint() + prov = endpoint.get_from_registry(provider_registry) + if prov is None: + raise ValueError("Unknown provider type: {}".format(endpoint.provider_type)) + + models = [] + if endpoint.auth_type == apimodelsv1.ProviderAuthType.api_key and not endpoint.api_key: + raise ValueError("API key must be provided for API auth type") + if endpoint.auth_type != apimodelsv1.ProviderAuthType.passthrough: + try: + models = prov.models(endpoint=endpoint.endpoint, api_key=endpoint.api_key) + except Exception as err: + raise ValueError("Unable to get models from provider: {}".format(str(err))) + + dbendpoint = await self._db_writer.add_provider_endpoint(dbend) + + await self._db_writer.push_provider_auth_material( + dbmodels.ProviderAuthMaterial( + provider_endpoint_id=dbendpoint.id, + auth_type=endpoint.auth_type, + auth_blob=endpoint.api_key if endpoint.api_key else "", + ) + ) + + for model in models: + await self._db_writer.add_provider_model( + dbmodels.ProviderModel( + provider_endpoint_id=dbendpoint.id, + name=model, + ) + ) return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) async def update_endpoint( - self, endpoint: apimodelsv1.ProviderEndpoint + self, endpoint: apimodelsv1.AddProviderEndpointRequest ) -> apimodelsv1.ProviderEndpoint: """Update an endpoint.""" + if not endpoint.endpoint: + endpoint.endpoint = provider_default_endpoints(endpoint.provider_type) + + # If we STILL don't have an endpoint, we can't continue + if not endpoint.endpoint: + raise ValueError("No endpoint provided and no default found for provider type") + + provider_registry = get_provider_registry() + prov = endpoint.get_from_registry(provider_registry) + if prov is None: + raise ValueError("Unknown provider type: {}".format(endpoint.provider_type)) + + founddbe = await self._db_reader.get_provider_endpoint_by_id(str(endpoint.id)) + if founddbe is None: + raise ProviderNotFoundError("Provider not found") + + models = [] + if endpoint.auth_type == apimodelsv1.ProviderAuthType.api_key and not endpoint.api_key: + raise ValueError("API key must be provided for API auth type") + if endpoint.auth_type != apimodelsv1.ProviderAuthType.passthrough: + try: + models = prov.models(endpoint=endpoint.endpoint, api_key=endpoint.api_key) + except Exception as err: + raise ValueError("Unable to get models from provider: {}".format(str(err))) + + # Reset all provider models. + await self._db_writer.delete_provider_models(str(endpoint.id)) + + for model in models: + await self._db_writer.add_provider_model( + dbmodels.ProviderModel( + provider_endpoint_id=founddbe.id, + name=model, + ) + ) + dbendpoint = await self._db_writer.update_provider_endpoint(endpoint.to_db_model()) + + await self._db_writer.push_provider_auth_material( + dbmodels.ProviderAuthMaterial( + provider_endpoint_id=dbendpoint.id, + auth_type=endpoint.auth_type, + auth_blob=endpoint.api_key if endpoint.api_key else "", + ) + ) + return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) async def configure_auth_material( @@ -175,6 +258,13 @@ async def initialize_provider_endpoints(preg: ProviderRegistry): continue pimpl = provend.get_from_registry(preg) + if pimpl is None: + logger.warning( + "Provider not found in registry", + provider=provend.name, + endpoint=provend.endpoint, + ) + continue await try_initialize_provider_endpoints(provend, pimpl, db_writer) @@ -240,7 +330,7 @@ def __provider_endpoint_from_cfg( description=("Endpoint for the {} provided via the CodeGate configuration.").format( provider_name ), - provider_type=provider_name, + provider_type=provider_overrides(provider_name), auth_type=apimodelsv1.ProviderAuthType.passthrough, ) except ValidationError as err: @@ -251,3 +341,24 @@ def __provider_endpoint_from_cfg( err=str(err), ) return None + + +def provider_default_endpoints(provider_type: str) -> str: + defaults = { + "openai": "https://api.openai.com", + "anthropic": "https://api.anthropic.com", + } + + # If we have a default, we return it + # Otherwise, we return an empty string + return defaults.get(provider_type, "") + + +def provider_overrides(provider_type: str) -> str: + overrides = { + "lm_studio": "openai", + } + + # If we have an override, we return it + # Otherwise, we return the type + return overrides.get(provider_type, provider_type) diff --git a/src/codegate/providers/llamacpp/provider.py b/src/codegate/providers/llamacpp/provider.py index 4478d137..1fd6b27e 100644 --- a/src/codegate/providers/llamacpp/provider.py +++ b/src/codegate/providers/llamacpp/provider.py @@ -1,6 +1,6 @@ import json +from typing import List -import httpx import structlog from fastapi import HTTPException, Request @@ -27,12 +27,9 @@ def __init__( def provider_route_name(self) -> str: return "llamacpp" - def models(self): - # HACK: This is using OpenAI's /v1/models endpoint to get the list of models - resp = httpx.get(f"{self.base_url}/v1/models") - jsonresp = resp.json() - - return [model["id"] for model in jsonresp.get("data", [])] + def models(self, endpoint: str = None, api_key: str = None) -> List[str]: + # TODO: Implement file fetching + return [] def _setup_routes(self): """ diff --git a/src/codegate/providers/ollama/provider.py b/src/codegate/providers/ollama/provider.py index b8e0477b..66ea38ef 100644 --- a/src/codegate/providers/ollama/provider.py +++ b/src/codegate/providers/ollama/provider.py @@ -1,4 +1,5 @@ import json +from typing import List import httpx import structlog @@ -6,7 +7,7 @@ from codegate.config import Config from codegate.pipeline.factory import PipelineFactory -from codegate.providers.base import BaseProvider +from codegate.providers.base import BaseProvider, ModelFetchError from codegate.providers.ollama.adapter import OllamaInputNormalizer, OllamaOutputNormalizer from codegate.providers.ollama.completion_handler import OllamaShim @@ -34,8 +35,20 @@ def __init__( def provider_route_name(self) -> str: return "ollama" - def models(self): - resp = httpx.get(f"{self.base_url}/api/tags") + def models(self, endpoint: str = None, api_key: str = None) -> List[str]: + headers = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + if not endpoint: + endpoint = self.base_url + resp = httpx.get( + f"{endpoint}/api/tags", + headers=headers, + ) + + if resp.status_code != 200: + raise ModelFetchError(f"Failed to fetch models from Ollama API: {resp.text}") + jsonresp = resp.json() return [model["name"] for model in jsonresp.get("models", [])] diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py index 87588265..be9ddabe 100644 --- a/src/codegate/providers/openai/provider.py +++ b/src/codegate/providers/openai/provider.py @@ -8,7 +8,7 @@ from codegate.config import Config from codegate.pipeline.factory import PipelineFactory -from codegate.providers.base import BaseProvider +from codegate.providers.base import BaseProvider, ModelFetchError from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator from codegate.providers.openai.adapter import OpenAIInputNormalizer, OpenAIOutputNormalizer @@ -35,9 +35,18 @@ def __init__( def provider_route_name(self) -> str: return "openai" - def models(self) -> List[str]: - # NOTE: This won't work since we need an API Key being set. - resp = httpx.get(f"{self.lm_studio_url}/v1/models") + def models(self, endpoint: str = None, api_key: str = None) -> List[str]: + headers = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + if not endpoint: + endpoint = "https://api.openai.com" + + resp = httpx.get(f"{endpoint}/v1/models", headers=headers) + + if resp.status_code != 200: + raise ModelFetchError(f"Failed to fetch models from OpenAI API: {resp.text}") + jsonresp = resp.json() return [model["id"] for model in jsonresp.get("data", [])] diff --git a/src/codegate/providers/registry.py b/src/codegate/providers/registry.py index 7450460f..3def8840 100644 --- a/src/codegate/providers/registry.py +++ b/src/codegate/providers/registry.py @@ -1,9 +1,26 @@ +from threading import Lock from typing import Dict, Optional from fastapi import FastAPI from codegate.providers.base import BaseProvider +_provider_registry_lock = Lock() +_provider_registry_singleton: Optional["ProviderRegistry"] = None + + +def get_provider_registry(app: FastAPI = None) -> "ProviderRegistry": + global _provider_registry_singleton + + if _provider_registry_singleton is None: + if app is None: + raise ValueError("Cannot initialize a ProviderRegistry without an app") + with _provider_registry_lock: + if _provider_registry_singleton is None: + _provider_registry_singleton = ProviderRegistry(app) + + return _provider_registry_singleton + class ProviderRegistry: def __init__(self, app: FastAPI): diff --git a/src/codegate/providers/vllm/provider.py b/src/codegate/providers/vllm/provider.py index 70f768c7..0826e5e2 100644 --- a/src/codegate/providers/vllm/provider.py +++ b/src/codegate/providers/vllm/provider.py @@ -1,4 +1,5 @@ import json +from typing import List from urllib.parse import urljoin import httpx @@ -8,7 +9,7 @@ from codegate.config import Config from codegate.pipeline.factory import PipelineFactory -from codegate.providers.base import BaseProvider +from codegate.providers.base import BaseProvider, ModelFetchError from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator from codegate.providers.vllm.adapter import VLLMInputNormalizer, VLLMOutputNormalizer @@ -45,8 +46,21 @@ def _get_base_url(self) -> str: base_url = f"{base_url}/v1" return base_url - def models(self): - resp = httpx.get(f"{self.base_url}/v1/models") + def models(self, endpoint: str = None, api_key: str = None) -> List[str]: + headers = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + if not endpoint: + endpoint = self._get_base_url() + + resp = httpx.get( + f"{endpoint}/v1/models", + headers=headers, + ) + + if resp.status_code != 200: + raise ModelFetchError(f"Failed to fetch models from vLLM API: {resp.text}") + jsonresp = resp.json() return [model["id"] for model in jsonresp.get("data", [])] diff --git a/src/codegate/server.py b/src/codegate/server.py index ece60c0c..216ba95e 100644 --- a/src/codegate/server.py +++ b/src/codegate/server.py @@ -15,7 +15,7 @@ from codegate.providers.llamacpp.provider import LlamaCppProvider from codegate.providers.ollama.provider import OllamaProvider from codegate.providers.openai.provider import OpenAIProvider -from codegate.providers.registry import ProviderRegistry +from codegate.providers.registry import ProviderRegistry, get_provider_registry from codegate.providers.vllm.provider import VLLMProvider logger = structlog.get_logger("codegate") @@ -64,7 +64,7 @@ async def log_user_agent(request: Request, call_next): app.add_middleware(ServerErrorMiddleware, handler=custom_error_handler) # Create provider registry - registry = ProviderRegistry(app) + registry = get_provider_registry(app) app.set_provider_registry(registry) # Register all known providers diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py index 77d8212b..39c23e86 100644 --- a/src/codegate/workspaces/crud.py +++ b/src/codegate/workspaces/crud.py @@ -1,8 +1,16 @@ +import asyncio import datetime from typing import List, Optional, Tuple +from uuid import uuid4 as uuid from codegate.db.connection import DbReader, DbRecorder -from codegate.db.models import ActiveWorkspace, Session, WorkspaceRow, WorkspaceWithSessionInfo +from codegate.db.models import ( + ActiveWorkspace, + MuxRule, + Session, + WorkspaceRow, + WorkspaceWithSessionInfo, +) class WorkspaceCrudError(Exception): @@ -17,6 +25,10 @@ class WorkspaceAlreadyActiveError(WorkspaceCrudError): pass +class WorkspaceMuxRuleDoesNotExistError(WorkspaceCrudError): + pass + + DEFAULT_WORKSPACE_NAME = "default" # These are reserved keywords that cannot be used for workspaces @@ -202,3 +214,72 @@ async def get_workspace_by_name(self, workspace_name: str) -> WorkspaceRow: if not workspace: raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.") return workspace + + # Can't use type hints since the models are not yet defined + # Note that I'm explicitly importing the models here to avoid circular imports. + async def get_muxes(self, workspace_name: str): + from codegate.api import v1_models + + # Verify if workspace exists + workspace = await self._db_reader.get_workspace_by_name(workspace_name) + if not workspace: + raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.") + + dbmuxes = await self._db_reader.get_muxes_by_workspace(workspace.id) + + muxes = [] + # These are already sorted by priority + for dbmux in dbmuxes: + muxes.append( + v1_models.MuxRule( + provider_id=dbmux.provider_endpoint_id, + model=dbmux.provider_model_name, + matcher_type=dbmux.matcher_type, + matcher=dbmux.matcher_blob, + ) + ) + + return muxes + + # Can't use type hints since the models are not yet defined + async def set_muxes(self, workspace_name: str, muxes): + # Verify if workspace exists + workspace = await self._db_reader.get_workspace_by_name(workspace_name) + if not workspace: + raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.") + + # Delete all muxes for the workspace + db_recorder = DbRecorder() + await db_recorder.delete_muxes_by_workspace(workspace.id) + + tasks = set() + + # Add the new muxes + priority = 0 + + # Verify all models are valid + for mux in muxes: + dbm = await self._db_reader.get_provider_model_by_provider_id_and_name( + mux.provider_id, + mux.model, + ) + if not dbm: + raise WorkspaceCrudError( + f"Model {mux.model} does not exist for provider {mux.provider_id}" + ) + + for mux in muxes: + new_mux = MuxRule( + id=str(uuid()), + provider_endpoint_id=mux.provider_id, + provider_model_name=mux.model, + workspace_id=workspace.id, + matcher_type=mux.matcher_type, + matcher_blob=mux.matcher if mux.matcher else "", + priority=priority, + ) + tasks.add(db_recorder.add_mux(new_mux)) + + priority += 1 + + await asyncio.gather(*tasks) diff --git a/tests/test_server.py b/tests/test_server.py index f7b7a12f..80bb7cb0 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -97,7 +97,7 @@ def test_version_endpoint(mock_fetch_latest_version, test_client: TestClient) -> @patch("codegate.pipeline.secrets.manager.SecretsManager") -@patch("codegate.server.ProviderRegistry") +@patch("codegate.server.get_provider_registry") def test_provider_registration(mock_registry, mock_secrets_mgr, mock_pipeline_factory) -> None: """Test that all providers are registered correctly.""" init_app(mock_pipeline_factory)