From 538492a89b0190df07a083e768abf3991f012ca4 Mon Sep 17 00:00:00 2001 From: Juan Antonio Osorio Date: Mon, 27 Jan 2025 10:43:06 +0200 Subject: [PATCH] Kick off provider endpoint CRUD structure and registration This structure will handle all the database operations and turn that into the right models. Note that for provider endpoints we already have a way of setting these via configuration, so this is taken into account to output some sample objects that users can leverage. Each provider will need to implement a `models` function which allows us to auto-discover models for a provider. Signed-off-by: Juan Antonio Osorio --- src/codegate/api/v1.py | 180 ++++++++++----- src/codegate/api/v1_models.py | 36 ++- src/codegate/cli.py | 4 + src/codegate/db/connection.py | 132 +++++++++++ src/codegate/db/models.py | 21 ++ src/codegate/providers/anthropic/provider.py | 17 ++ src/codegate/providers/base.py | 6 +- src/codegate/providers/crud/__init__.py | 3 + src/codegate/providers/crud/crud.py | 229 +++++++++++++++++++ src/codegate/providers/llamacpp/provider.py | 8 + src/codegate/providers/ollama/provider.py | 6 + src/codegate/providers/openai/provider.py | 9 + src/codegate/providers/vllm/provider.py | 6 + src/codegate/server.py | 12 +- tests/providers/test_registry.py | 3 + tests/test_provider.py | 3 + 16 files changed, 606 insertions(+), 69 deletions(-) create mode 100644 src/codegate/providers/crud/__init__.py create mode 100644 src/codegate/providers/crud/crud.py diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index 2a160b9e..d4695ffd 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -1,21 +1,24 @@ from typing import List, Optional +from uuid import UUID import requests import structlog -from fastapi import APIRouter, HTTPException, Response +from fastapi import APIRouter, Depends, HTTPException, Response from fastapi.responses import StreamingResponse from fastapi.routing import APIRoute -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError from codegate import __version__ from codegate.api import v1_models, v1_processing from codegate.db.connection import AlreadyExistsError, DbReader +from codegate.providers import crud as provendcrud from codegate.workspaces import crud logger = structlog.get_logger("codegate") v1 = APIRouter() wscrud = crud.WorkspaceCrud() +pcrud = provendcrud.ProviderCrud() # This is a singleton object dbreader = DbReader() @@ -25,38 +28,78 @@ def uniq_name(route: APIRoute): return f"v1_{route.name}" +class FilterByNameParams(BaseModel): + name: Optional[str] = None + + @v1.get("/provider-endpoints", tags=["Providers"], generate_unique_id_function=uniq_name) -async def list_provider_endpoints(name: Optional[str] = None) -> List[v1_models.ProviderEndpoint]: +async def list_provider_endpoints( + filter_query: FilterByNameParams = Depends(), +) -> List[v1_models.ProviderEndpoint]: """List all provider endpoints.""" - # NOTE: This is a dummy implementation. In the future, we should have a proper - # implementation that fetches the provider endpoints from the database. - return [ - v1_models.ProviderEndpoint( - id=1, - name="dummy", - description="Dummy provider endpoint", - endpoint="http://example.com", - provider_type=v1_models.ProviderType.openai, - auth_type=v1_models.ProviderAuthType.none, - ) - ] + if filter_query.name is None: + try: + return await pcrud.list_endpoints() + except Exception: + raise HTTPException(status_code=500, detail="Internal server error") + + try: + provend = await pcrud.get_endpoint_by_name(filter_query.name) + except Exception: + raise HTTPException(status_code=500, detail="Internal server error") + + if provend is None: + raise HTTPException(status_code=404, detail="Provider endpoint not found") + return [provend] + + +# This needs to be above /provider-endpoints/{provider_id} to avoid conflict +@v1.get( + "/provider-endpoints/models", + tags=["Providers"], + generate_unique_id_function=uniq_name, +) +async def list_all_models_for_all_providers() -> List[v1_models.ModelByProvider]: + """List all models for all providers.""" + try: + return await pcrud.get_all_models() + except Exception: + raise HTTPException(status_code=500, detail="Internal server error") + + +@v1.get( + "/provider-endpoints/{provider_id}/models", + tags=["Providers"], + generate_unique_id_function=uniq_name, +) +async def list_models_by_provider( + provider_id: UUID, +) -> List[v1_models.ModelByProvider]: + """List models by provider.""" + + try: + return await pcrud.models_by_provider(provider_id) + except provendcrud.ProviderNotFoundError: + raise HTTPException(status_code=404, detail="Provider not found") + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) @v1.get( "/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name ) -async def get_provider_endpoint(provider_id: int) -> v1_models.ProviderEndpoint: +async def get_provider_endpoint( + provider_id: UUID, +) -> v1_models.ProviderEndpoint: """Get a provider endpoint by ID.""" - # NOTE: This is a dummy implementation. In the future, we should have a proper - # implementation that fetches the provider endpoint from the database. - return v1_models.ProviderEndpoint( - id=provider_id, - name="dummy", - description="Dummy provider endpoint", - endpoint="http://example.com", - provider_type=v1_models.ProviderType.openai, - auth_type=v1_models.ProviderAuthType.none, - ) + try: + provend = await pcrud.get_endpoint_by_id(provider_id) + except Exception: + raise HTTPException(status_code=500, detail="Internal server error") + + if provend is None: + raise HTTPException(status_code=404, detail="Provider endpoint not found") + return provend @v1.post( @@ -65,59 +108,65 @@ async def get_provider_endpoint(provider_id: int) -> v1_models.ProviderEndpoint: generate_unique_id_function=uniq_name, status_code=201, ) -async def add_provider_endpoint(request: v1_models.ProviderEndpoint) -> v1_models.ProviderEndpoint: +async def add_provider_endpoint( + request: v1_models.ProviderEndpoint, +) -> v1_models.ProviderEndpoint: """Add a provider endpoint.""" - # NOTE: This is a dummy implementation. In the future, we should have a proper - # implementation that adds the provider endpoint to the database. - return request + try: + provend = await pcrud.add_endpoint(request) + except AlreadyExistsError: + raise HTTPException(status_code=409, detail="Provider endpoint already exists") + except ValidationError as e: + # TODO: This should be more specific + raise HTTPException( + status_code=400, + detail=str(e), + ) + except Exception: + raise HTTPException(status_code=500, detail="Internal server error") + + return provend @v1.put( "/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name ) async def update_provider_endpoint( - provider_id: int, request: v1_models.ProviderEndpoint + provider_id: UUID, + request: v1_models.ProviderEndpoint, ) -> v1_models.ProviderEndpoint: """Update a provider endpoint by ID.""" - # NOTE: This is a dummy implementation. In the future, we should have a proper - # implementation that updates the provider endpoint in the database. - return request + try: + request.id = provider_id + provend = await pcrud.update_endpoint(request) + except ValidationError as e: + # TODO: This should be more specific + raise HTTPException( + status_code=400, + detail=str(e), + ) + except Exception: + raise HTTPException(status_code=500, detail="Internal server error") + + return provend @v1.delete( "/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name ) -async def delete_provider_endpoint(provider_id: int): +async def delete_provider_endpoint( + provider_id: UUID, +): """Delete a provider endpoint by id.""" - # NOTE: This is a dummy implementation. In the future, we should have a proper - # implementation that deletes the provider endpoint from the database. + try: + await pcrud.delete_endpoint(provider_id) + except provendcrud.ProviderNotFoundError: + raise HTTPException(status_code=404, detail="Provider endpoint not found") + except Exception: + raise HTTPException(status_code=500, detail="Internal server error") return Response(status_code=204) -@v1.get( - "/provider-endpoints/{provider_name}/models", - tags=["Providers"], - generate_unique_id_function=uniq_name, -) -async def list_models_by_provider(provider_name: str) -> List[v1_models.ModelByProvider]: - """List models by provider.""" - # NOTE: This is a dummy implementation. In the future, we should have a proper - # implementation that fetches the models by provider from the database. - return [v1_models.ModelByProvider(name="dummy", provider="dummy")] - - -@v1.get( - "/provider-endpoints/models", - tags=["Providers"], - generate_unique_id_function=uniq_name, -) -async def list_all_models_for_all_providers() -> List[v1_models.ModelByProvider]: - """List all models for all providers.""" - # NOTE: This is a dummy implementation. In the future, we should have a proper - # implementation that fetches all the models for all providers from the database. - return [v1_models.ModelByProvider(name="dummy", provider="dummy")] - - @v1.get("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name) async def list_workspaces() -> v1_models.ListWorkspacesResponse: """List all workspaces.""" @@ -394,7 +443,9 @@ async def delete_workspace_custom_instructions(workspace_name: str): tags=["Workspaces", "Muxes"], generate_unique_id_function=uniq_name, ) -async def get_workspace_muxes(workspace_name: str) -> List[v1_models.MuxRule]: +async def get_workspace_muxes( + workspace_name: str, +) -> List[v1_models.MuxRule]: """Get the mux rules of a workspace. The list is ordered in order of priority. That is, the first rule in the list @@ -422,7 +473,10 @@ async def get_workspace_muxes(workspace_name: str) -> List[v1_models.MuxRule]: generate_unique_id_function=uniq_name, status_code=204, ) -async def set_workspace_muxes(workspace_name: str, request: List[v1_models.MuxRule]): +async def set_workspace_muxes( + workspace_name: str, + request: List[v1_models.MuxRule], +): """Set the mux rules of a workspace.""" # TODO: This is a dummy implementation. In the future, we should have a proper # implementation that sets the mux rules in the database. diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py index fb4e90d3..3f1f37a6 100644 --- a/src/codegate/api/v1_models.py +++ b/src/codegate/api/v1_models.py @@ -6,6 +6,8 @@ from codegate.db import models as db_models from codegate.pipeline.base import CodeSnippet +from codegate.providers.base import BaseProvider +from codegate.providers.registry import ProviderRegistry class Workspace(pydantic.BaseModel): @@ -122,6 +124,8 @@ class ProviderType(str, Enum): openai = "openai" anthropic = "anthropic" vllm = "vllm" + ollama = "ollama" + lm_studio = "lm_studio" class TokenUsageByModel(pydantic.BaseModel): @@ -191,13 +195,38 @@ class ProviderEndpoint(pydantic.BaseModel): so we can use this for muxing messages. """ - id: int + # This will be set on creation + id: Optional[str] = "" name: str description: str = "" provider_type: ProviderType endpoint: str auth_type: ProviderAuthType + @staticmethod + def from_db_model(db_model: db_models.ProviderEndpoint) -> "ProviderEndpoint": + return ProviderEndpoint( + id=db_model.id, + name=db_model.name, + description=db_model.description, + provider_type=db_model.provider_type, + endpoint=db_model.endpoint, + auth_type=db_model.auth_type, + ) + + def to_db_model(self) -> db_models.ProviderEndpoint: + return db_models.ProviderEndpoint( + id=self.id, + name=self.name, + description=self.description, + provider_type=self.provider_type, + endpoint=self.endpoint, + auth_type=self.auth_type, + ) + + def get_from_registry(self, registry: ProviderRegistry) -> Optional[BaseProvider]: + return registry.get_provider(self.provider_type) + class ModelByProvider(pydantic.BaseModel): """ @@ -207,10 +236,11 @@ class ModelByProvider(pydantic.BaseModel): """ name: str - provider: str + provider_id: str + provider_name: str def __str__(self): - return f"{self.provider}/{self.name}" + return f"{self.provider_name} / {self.name}" class MuxMatcherType(str, Enum): diff --git a/src/codegate/cli.py b/src/codegate/cli.py index dc05ed25..ba3016eb 100644 --- a/src/codegate/cli.py +++ b/src/codegate/cli.py @@ -17,6 +17,7 @@ from codegate.db.connection import init_db_sync, init_session_if_not_exists from codegate.pipeline.factory import PipelineFactory from codegate.pipeline.secrets.manager import SecretsManager +from codegate.providers import crud as provendcrud from codegate.providers.copilot.provider import CopilotProvider from codegate.server import init_app from codegate.storage.utils import restore_storage_backup @@ -338,6 +339,9 @@ def serve( # noqa: C901 loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) + registry = app.provider_registry + loop.run_until_complete(provendcrud.initialize_provider_endpoints(registry)) + # Run the server try: loop.run_until_complete(run_servers(cfg, app)) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 15305790..10c1c81f 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -21,6 +21,9 @@ GetWorkspaceByNameConditions, Output, Prompt, + ProviderAuthMaterial, + ProviderEndpoint, + ProviderModel, Session, WorkspaceRow, WorkspaceWithSessionInfo, @@ -368,6 +371,72 @@ async def recover_workspace(self, workspace: WorkspaceRow) -> Optional[Workspace ) return recovered_workspace + async def add_provider_endpoint(self, provider: ProviderEndpoint) -> ProviderEndpoint: + sql = text( + """ + INSERT INTO provider_endpoints ( + id, name, description, provider_type, endpoint, auth_type, auth_blob + ) + VALUES (:id, :name, :description, :provider_type, :endpoint, :auth_type, "") + RETURNING * + """ + ) + added_provider = await self._execute_update_pydantic_model(provider, sql, should_raise=True) + return added_provider + + async def update_provider_endpoint(self, provider: ProviderEndpoint) -> ProviderEndpoint: + sql = text( + """ + UPDATE provider_endpoints + SET name = :name, description = :description, provider_type = :provider_type, + endpoint = :endpoint, auth_type = :auth_type + WHERE id = :id + RETURNING * + """ + ) + updated_provider = await self._execute_update_pydantic_model( + provider, sql, should_raise=True + ) + return updated_provider + + async def delete_provider_endpoint( + self, + provider: ProviderEndpoint, + ) -> Optional[ProviderEndpoint]: + sql = text( + """ + DELETE FROM provider_endpoints + WHERE id = :id + RETURNING * + """ + ) + deleted_provider = await self._execute_update_pydantic_model( + provider, sql, should_raise=True + ) + return deleted_provider + + async def push_provider_auth_material(self, auth_material: ProviderAuthMaterial): + sql = text( + """ + UPDATE provider_endpoints + SET auth_type = :auth_type, auth_blob = :auth_blob + WHERE id = :provider_endpoint_id + """ + ) + _ = await self._execute_update_pydantic_model(auth_material, sql, should_raise=True) + return + + async def add_provider_model(self, model: ProviderModel) -> ProviderModel: + sql = text( + """ + INSERT INTO provider_models (provider_endpoint_id, name) + VALUES (:provider_endpoint_id, :name) + RETURNING * + """ + ) + added_model = await self._execute_update_pydantic_model(model, sql, should_raise=True) + return added_model + class DbReader(DbCodeGate): @@ -537,6 +606,69 @@ async def get_active_workspace(self) -> Optional[ActiveWorkspace]: active_workspace = await self._execute_select_pydantic_model(ActiveWorkspace, sql) return active_workspace[0] if active_workspace else None + async def get_provider_endpoint_by_name(self, provider_name: str) -> Optional[ProviderEndpoint]: + sql = text( + """ + SELECT id, name, description, provider_type, endpoint, auth_type, created_at, updated_at + FROM provider_endpoints + WHERE name = :name + """ + ) + conditions = {"name": provider_name} + provider = await self._exec_select_conditions_to_pydantic( + ProviderEndpoint, sql, conditions, should_raise=True + ) + return provider[0] if provider else None + + async def get_provider_endpoint_by_id(self, provider_id: str) -> Optional[ProviderEndpoint]: + sql = text( + """ + SELECT id, name, description, provider_type, endpoint, auth_type, created_at, updated_at + FROM provider_endpoints + WHERE id = :id + """ + ) + conditions = {"id": provider_id} + provider = await self._exec_select_conditions_to_pydantic( + ProviderEndpoint, sql, conditions, should_raise=True + ) + return provider[0] if provider else None + + async def get_provider_endpoints(self) -> List[ProviderEndpoint]: + sql = text( + """ + SELECT id, name, description, provider_type, endpoint, auth_type, created_at, updated_at + FROM provider_endpoints + """ + ) + providers = await self._execute_select_pydantic_model(ProviderEndpoint, sql) + return providers + + async def get_provider_models_by_provider_id(self, provider_id: str) -> List[ProviderModel]: + sql = text( + """ + SELECT provider_endpoint_id, name + FROM provider_models + WHERE provider_endpoint_id = :provider_endpoint_id + """ + ) + conditions = {"provider_endpoint_id": provider_id} + models = await self._exec_select_conditions_to_pydantic( + ProviderModel, sql, conditions, should_raise=True + ) + return models + + async def get_all_provider_models(self) -> List[ProviderModel]: + sql = text( + """ + SELECT pm.provider_endpoint_id, pm.name, pe.name as provider_endpoint_name + FROM provider_models pm + INNER JOIN provider_endpoints pe ON pm.provider_endpoint_id = pe.id + """ + ) + models = await self._execute_select_pydantic_model(ProviderModel, sql) + return models + def init_db_sync(db_path: Optional[str] = None): """DB will be initialized in the constructor in case it doesn't exist.""" diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index 23cbea5d..2a6434ef 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -99,3 +99,24 @@ class ActiveWorkspace(BaseModel): custom_instructions: Optional[str] session_id: str last_update: datetime.datetime + + +class ProviderEndpoint(BaseModel): + id: str + name: str + description: str + provider_type: str + endpoint: str + auth_type: str + + +class ProviderAuthMaterial(BaseModel): + provider_endpoint_id: str + auth_type: str + auth_blob: str + + +class ProviderModel(BaseModel): + provider_endpoint_id: str + provider_endpoint_name: Optional[str] = None + name: str diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py index 10215c9e..48821de0 100644 --- a/src/codegate/providers/anthropic/provider.py +++ b/src/codegate/providers/anthropic/provider.py @@ -1,5 +1,7 @@ import json +from typing import List +import httpx import structlog from fastapi import Header, HTTPException, Request @@ -27,6 +29,21 @@ def __init__( def provider_route_name(self) -> str: return "anthropic" + def models(self) -> List[str]: + # TODO: This won't work since we need an API Key being set. + resp = httpx.get("https://api.anthropic.com/models") + # If Anthropic returned 404, it means it's not accepting our + # requests. We should throw an error. + if resp.status_code == 404: + raise HTTPException( + status_code=404, + detail="The Anthropic API is not accepting requests. Please check your API key.", + ) + + respjson = resp.json() + + return [model["id"] for model in respjson.get("data", [])] + def _setup_routes(self): """ Sets up the /messages route for the provider as expected by the Anthropic diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index 515be531..1ab055ea 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, AsyncIterator, Callable, Dict, Optional, Union +from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union import structlog from fastapi import APIRouter, Request @@ -54,6 +54,10 @@ def __init__( def _setup_routes(self) -> None: pass + @abstractmethod + def models(self) -> List[str]: + pass + @property @abstractmethod def provider_route_name(self) -> str: diff --git a/src/codegate/providers/crud/__init__.py b/src/codegate/providers/crud/__init__.py new file mode 100644 index 00000000..58adb943 --- /dev/null +++ b/src/codegate/providers/crud/__init__.py @@ -0,0 +1,3 @@ +from .crud import ProviderCrud, ProviderNotFoundError, initialize_provider_endpoints + +__all__ = ["ProviderCrud", "initialize_provider_endpoints", "ProviderNotFoundError"] diff --git a/src/codegate/providers/crud/crud.py b/src/codegate/providers/crud/crud.py new file mode 100644 index 00000000..637375e8 --- /dev/null +++ b/src/codegate/providers/crud/crud.py @@ -0,0 +1,229 @@ +import asyncio +from typing import List, Optional +from urllib.parse import urlparse +from uuid import UUID, uuid4 + +import structlog +from pydantic import ValidationError + +from codegate.api import v1_models as apimodelsv1 +from codegate.config import Config +from codegate.db import models as dbmodels +from codegate.db.connection import DbReader, DbRecorder +from codegate.providers.base import BaseProvider +from codegate.providers.registry import ProviderRegistry + +logger = structlog.get_logger("codegate") + + +class ProviderNotFoundError(Exception): + pass + + +class ProviderCrud: + """The CRUD operations for the provider endpoint references within + Codegate. + + This is meant to handle all the transformations in between the + database and the API, as well as other sources of information. All + operations should result in the API models being returned. + """ + + def __init__(self): + self._db_reader = DbReader() + self._db_writer = DbRecorder() + + async def list_endpoints(self) -> List[apimodelsv1.ProviderEndpoint]: + """List all the endpoints.""" + + outendpoints = [] + dbendpoints = await self._db_reader.get_provider_endpoints() + for dbendpoint in dbendpoints: + outendpoints.append(apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)) + + return outendpoints + + async def get_endpoint_by_id(self, id: UUID) -> Optional[apimodelsv1.ProviderEndpoint]: + """Get an endpoint by ID.""" + + dbendpoint = await self._db_reader.get_provider_endpoint_by_id(str(id)) + if dbendpoint is None: + return None + + return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) + + async def get_endpoint_by_name(self, name: str) -> Optional[apimodelsv1.ProviderEndpoint]: + """Get an endpoint by name.""" + + dbendpoint = await self._db_reader.get_provider_endpoint_by_name(name) + if dbendpoint is None: + return None + + return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) + + async def add_endpoint( + self, endpoint: apimodelsv1.ProviderEndpoint + ) -> apimodelsv1.ProviderEndpoint: + """Add an endpoint.""" + dbend = endpoint.to_db_model() + + # We override the ID here, as we want to generate it. + dbend.id = str(uuid4()) + + dbendpoint = await self._db_writer.add_provider_endpoint() + return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) + + async def update_endpoint( + self, endpoint: apimodelsv1.ProviderEndpoint + ) -> apimodelsv1.ProviderEndpoint: + """Update an endpoint.""" + + dbendpoint = await self._db_writer.update_provider_endpoint(endpoint.to_db_model()) + return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) + + async def delete_endpoint(self, provider_id: UUID): + """Delete an endpoint.""" + + dbendpoint = await self._db_reader.get_provider_endpoint_by_id(str(provider_id)) + if dbendpoint is None: + raise ProviderNotFoundError("Provider not found") + + await self._db_writer.delete_provider_endpoint(dbendpoint) + + async def models_by_provider(self, provider_id: UUID) -> List[apimodelsv1.ModelByProvider]: + """Get the models by provider.""" + + # First we try to get the provider + dbendpoint = await self._db_reader.get_provider_endpoint_by_id(str(provider_id)) + if dbendpoint is None: + raise ProviderNotFoundError("Provider not found") + + outmodels = [] + dbmodels = await self._db_reader.get_provider_models_by_provider_id(str(provider_id)) + for dbmodel in dbmodels: + outmodels.append( + apimodelsv1.ModelByProvider( + name=dbmodel.name, + provider_id=dbmodel.provider_endpoint_id, + provider_name=dbendpoint.name, + ) + ) + + return outmodels + + async def get_all_models(self) -> List[apimodelsv1.ModelByProvider]: + """Get all the models.""" + + outmodels = [] + dbmodels = await self._db_reader.get_all_provider_models() + for dbmodel in dbmodels: + ename = dbmodel.provider_endpoint_name if dbmodel.provider_endpoint_name else "" + outmodels.append( + apimodelsv1.ModelByProvider( + name=dbmodel.name, + provider_id=dbmodel.provider_endpoint_id, + provider_name=ename, + ) + ) + + return outmodels + + +async def initialize_provider_endpoints(preg: ProviderRegistry): + db_writer = DbRecorder() + db_reader = DbReader() + config = Config.get_config() + if config is None: + provided_urls = {} + else: + provided_urls = config.provider_urls + + for provider_name, provider_url in provided_urls.items(): + provend = __provider_endpoint_from_cfg(provider_name, provider_url) + if provend is None: + continue + + # Check if the provider is already in the db + dbprovend = await db_reader.get_provider_endpoint_by_name(provend.name) + if dbprovend is not None: + logger.debug( + "Provider already in DB. Not re-adding.", + provider=provend.name, + endpoint=provend.endpoint, + ) + continue + + pimpl = provend.get_from_registry(preg) + await try_initialize_provider_endpoints(provend, pimpl, db_writer) + + +async def try_initialize_provider_endpoints( + provend: apimodelsv1.ProviderEndpoint, + pimpl: BaseProvider, + db_writer: DbRecorder, +): + try: + models = pimpl.models() + except Exception as err: + logger.debug( + "Unable to get models from provider", + provider=provend.name, + err=str(err), + ) + return + + logger.info( + "initializing provider to DB", + provider=provend.name, + endpoint=provend.endpoint, + models=models, + ) + # We only try to add the provider if we have models + await db_writer.add_provider_endpoint(provend.to_db_model()) + + tasks = set() + for model in models: + tasks.add( + db_writer.add_provider_model( + dbmodels.ProviderModel( + provider_endpoint_id=provend.id, + name=model, + ) + ) + ) + + await asyncio.gather(*tasks) + + +def __provider_endpoint_from_cfg( + provider_name: str, provider_url: str +) -> Optional[apimodelsv1.ProviderEndpoint]: + """Create a provider endpoint from the config entry.""" + + try: + _ = urlparse(provider_url) + except Exception: + logger.warning( + "Invalid provider URL", provider_name=provider_name, provider_url=provider_url + ) + return None + + try: + return apimodelsv1.ProviderEndpoint( + id=str(uuid4()), + name=provider_name, + endpoint=provider_url, + description=("Endpoint for the {} provided via the CodeGate configuration.").format( + provider_name + ), + provider_type=provider_name, + auth_type=apimodelsv1.ProviderAuthType.passthrough, + ) + except ValidationError as err: + logger.warning( + "Invalid provider name", + provider_name=provider_name, + provider_url=provider_url, + err=str(err), + ) + return None diff --git a/src/codegate/providers/llamacpp/provider.py b/src/codegate/providers/llamacpp/provider.py index 7f90619e..4478d137 100644 --- a/src/codegate/providers/llamacpp/provider.py +++ b/src/codegate/providers/llamacpp/provider.py @@ -1,5 +1,6 @@ import json +import httpx import structlog from fastapi import HTTPException, Request @@ -26,6 +27,13 @@ def __init__( def provider_route_name(self) -> str: return "llamacpp" + def models(self): + # HACK: This is using OpenAI's /v1/models endpoint to get the list of models + resp = httpx.get(f"{self.base_url}/v1/models") + jsonresp = resp.json() + + return [model["id"] for model in jsonresp.get("data", [])] + def _setup_routes(self): """ Sets up the /completions and /chat/completions routes for the diff --git a/src/codegate/providers/ollama/provider.py b/src/codegate/providers/ollama/provider.py index ac8013b9..b8e0477b 100644 --- a/src/codegate/providers/ollama/provider.py +++ b/src/codegate/providers/ollama/provider.py @@ -34,6 +34,12 @@ def __init__( def provider_route_name(self) -> str: return "ollama" + def models(self): + resp = httpx.get(f"{self.base_url}/api/tags") + jsonresp = resp.json() + + return [model["name"] for model in jsonresp.get("models", [])] + def _setup_routes(self): """ Sets up Ollama API routes. diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py index 8a00c68c..87588265 100644 --- a/src/codegate/providers/openai/provider.py +++ b/src/codegate/providers/openai/provider.py @@ -1,5 +1,7 @@ import json +from typing import List +import httpx import structlog from fastapi import Header, HTTPException, Request from fastapi.responses import JSONResponse @@ -33,6 +35,13 @@ def __init__( def provider_route_name(self) -> str: return "openai" + def models(self) -> List[str]: + # NOTE: This won't work since we need an API Key being set. + resp = httpx.get(f"{self.lm_studio_url}/v1/models") + jsonresp = resp.json() + + return [model["id"] for model in jsonresp.get("data", [])] + def _setup_routes(self): """ Sets up the /chat/completions route for the provider as expected by the diff --git a/src/codegate/providers/vllm/provider.py b/src/codegate/providers/vllm/provider.py index f39ed8d6..303b907b 100644 --- a/src/codegate/providers/vllm/provider.py +++ b/src/codegate/providers/vllm/provider.py @@ -31,6 +31,12 @@ def __init__( def provider_route_name(self) -> str: return "vllm" + def models(self): + resp = httpx.get(f"{self.base_url}/v1/models") + jsonresp = resp.json() + + return [model["id"] for model in jsonresp.get("data", [])] + def _setup_routes(self): """ Sets up the /chat/completions route for the provider as expected by the diff --git a/src/codegate/server.py b/src/codegate/server.py index 857fb064..ece60c0c 100644 --- a/src/codegate/server.py +++ b/src/codegate/server.py @@ -30,9 +30,16 @@ async def custom_error_handler(request, exc: Exception): return JSONResponse({"error": str(exc)}, status_code=500) -def init_app(pipeline_factory: PipelineFactory) -> FastAPI: +class CodeGateServer(FastAPI): + provider_registry: ProviderRegistry = None + + def set_provider_registry(self, registry: ProviderRegistry): + self.provider_registry = registry + + +def init_app(pipeline_factory: PipelineFactory) -> CodeGateServer: """Create the FastAPI application.""" - app = FastAPI( + app = CodeGateServer( title="CodeGate", description=__description__, version=__version__, @@ -58,6 +65,7 @@ async def log_user_agent(request: Request, call_next): # Create provider registry registry = ProviderRegistry(app) + app.set_provider_registry(registry) # Register all known providers registry.add_provider( diff --git a/tests/providers/test_registry.py b/tests/providers/test_registry.py index d7c97da9..4922a5ef 100644 --- a/tests/providers/test_registry.py +++ b/tests/providers/test_registry.py @@ -93,6 +93,9 @@ def __init__( def provider_route_name(self) -> str: return "mock_provider" + def models(self): + return [] + def _setup_routes(self) -> None: @self.router.get(f"/{self.provider_route_name}/test") def test_route(): diff --git a/tests/test_provider.py b/tests/test_provider.py index 95361c97..3539b942 100644 --- a/tests/test_provider.py +++ b/tests/test_provider.py @@ -19,6 +19,9 @@ def __init__(self): mocked_factory, ) + def models(self): + return [] + def _setup_routes(self) -> None: pass