Skip to content

Kick off provider endpoint CRUD structure and registration #790

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 117 additions & 63 deletions src/codegate/api/v1.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
from typing import List, Optional
from uuid import UUID

import requests
import structlog
from fastapi import APIRouter, HTTPException, Response
from fastapi import APIRouter, Depends, HTTPException, Response
from fastapi.responses import StreamingResponse
from fastapi.routing import APIRoute
from pydantic import ValidationError
from pydantic import BaseModel, ValidationError

from codegate import __version__
from codegate.api import v1_models, v1_processing
from codegate.db.connection import AlreadyExistsError, DbReader
from codegate.providers import crud as provendcrud
from codegate.workspaces import crud

logger = structlog.get_logger("codegate")

v1 = APIRouter()
wscrud = crud.WorkspaceCrud()
pcrud = provendcrud.ProviderCrud()

# This is a singleton object
dbreader = DbReader()
Expand All @@ -25,38 +28,78 @@ def uniq_name(route: APIRoute):
return f"v1_{route.name}"


class FilterByNameParams(BaseModel):
name: Optional[str] = None


@v1.get("/provider-endpoints", tags=["Providers"], generate_unique_id_function=uniq_name)
async def list_provider_endpoints(name: Optional[str] = None) -> List[v1_models.ProviderEndpoint]:
async def list_provider_endpoints(
filter_query: FilterByNameParams = Depends(),
) -> List[v1_models.ProviderEndpoint]:
"""List all provider endpoints."""
# NOTE: This is a dummy implementation. In the future, we should have a proper
# implementation that fetches the provider endpoints from the database.
return [
v1_models.ProviderEndpoint(
id=1,
name="dummy",
description="Dummy provider endpoint",
endpoint="http://example.com",
provider_type=v1_models.ProviderType.openai,
auth_type=v1_models.ProviderAuthType.none,
)
]
if filter_query.name is None:
try:
return await pcrud.list_endpoints()
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")

try:
provend = await pcrud.get_endpoint_by_name(filter_query.name)
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")

if provend is None:
raise HTTPException(status_code=404, detail="Provider endpoint not found")
return [provend]


# This needs to be above /provider-endpoints/{provider_id} to avoid conflict
@v1.get(
"/provider-endpoints/models",
tags=["Providers"],
generate_unique_id_function=uniq_name,
)
async def list_all_models_for_all_providers() -> List[v1_models.ModelByProvider]:
"""List all models for all providers."""
try:
return await pcrud.get_all_models()
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")


@v1.get(
"/provider-endpoints/{provider_id}/models",
tags=["Providers"],
generate_unique_id_function=uniq_name,
)
async def list_models_by_provider(
provider_id: UUID,
) -> List[v1_models.ModelByProvider]:
"""List models by provider."""

try:
return await pcrud.models_by_provider(provider_id)
except provendcrud.ProviderNotFoundError:
raise HTTPException(status_code=404, detail="Provider not found")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@v1.get(
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
)
async def get_provider_endpoint(provider_id: int) -> v1_models.ProviderEndpoint:
async def get_provider_endpoint(
provider_id: UUID,
) -> v1_models.ProviderEndpoint:
"""Get a provider endpoint by ID."""
# NOTE: This is a dummy implementation. In the future, we should have a proper
# implementation that fetches the provider endpoint from the database.
return v1_models.ProviderEndpoint(
id=provider_id,
name="dummy",
description="Dummy provider endpoint",
endpoint="http://example.com",
provider_type=v1_models.ProviderType.openai,
auth_type=v1_models.ProviderAuthType.none,
)
try:
provend = await pcrud.get_endpoint_by_id(provider_id)
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")

if provend is None:
raise HTTPException(status_code=404, detail="Provider endpoint not found")
return provend


@v1.post(
Expand All @@ -65,59 +108,65 @@ async def get_provider_endpoint(provider_id: int) -> v1_models.ProviderEndpoint:
generate_unique_id_function=uniq_name,
status_code=201,
)
async def add_provider_endpoint(request: v1_models.ProviderEndpoint) -> v1_models.ProviderEndpoint:
async def add_provider_endpoint(
request: v1_models.ProviderEndpoint,
) -> v1_models.ProviderEndpoint:
"""Add a provider endpoint."""
# NOTE: This is a dummy implementation. In the future, we should have a proper
# implementation that adds the provider endpoint to the database.
return request
try:
provend = await pcrud.add_endpoint(request)
except AlreadyExistsError:
raise HTTPException(status_code=409, detail="Provider endpoint already exists")
except ValidationError as e:
# TODO: This should be more specific
raise HTTPException(
status_code=400,
detail=str(e),
)
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")

return provend


@v1.put(
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
)
async def update_provider_endpoint(
provider_id: int, request: v1_models.ProviderEndpoint
provider_id: UUID,
request: v1_models.ProviderEndpoint,
) -> v1_models.ProviderEndpoint:
"""Update a provider endpoint by ID."""
# NOTE: This is a dummy implementation. In the future, we should have a proper
# implementation that updates the provider endpoint in the database.
return request
try:
request.id = provider_id
provend = await pcrud.update_endpoint(request)
except ValidationError as e:
# TODO: This should be more specific
raise HTTPException(
status_code=400,
detail=str(e),
)
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")

return provend


@v1.delete(
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
)
async def delete_provider_endpoint(provider_id: int):
async def delete_provider_endpoint(
provider_id: UUID,
):
"""Delete a provider endpoint by id."""
# NOTE: This is a dummy implementation. In the future, we should have a proper
# implementation that deletes the provider endpoint from the database.
try:
await pcrud.delete_endpoint(provider_id)
except provendcrud.ProviderNotFoundError:
raise HTTPException(status_code=404, detail="Provider endpoint not found")
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")
return Response(status_code=204)


@v1.get(
"/provider-endpoints/{provider_name}/models",
tags=["Providers"],
generate_unique_id_function=uniq_name,
)
async def list_models_by_provider(provider_name: str) -> List[v1_models.ModelByProvider]:
"""List models by provider."""
# NOTE: This is a dummy implementation. In the future, we should have a proper
# implementation that fetches the models by provider from the database.
return [v1_models.ModelByProvider(name="dummy", provider="dummy")]


@v1.get(
"/provider-endpoints/models",
tags=["Providers"],
generate_unique_id_function=uniq_name,
)
async def list_all_models_for_all_providers() -> List[v1_models.ModelByProvider]:
"""List all models for all providers."""
# NOTE: This is a dummy implementation. In the future, we should have a proper
# implementation that fetches all the models for all providers from the database.
return [v1_models.ModelByProvider(name="dummy", provider="dummy")]


@v1.get("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name)
async def list_workspaces() -> v1_models.ListWorkspacesResponse:
"""List all workspaces."""
Expand Down Expand Up @@ -394,7 +443,9 @@ async def delete_workspace_custom_instructions(workspace_name: str):
tags=["Workspaces", "Muxes"],
generate_unique_id_function=uniq_name,
)
async def get_workspace_muxes(workspace_name: str) -> List[v1_models.MuxRule]:
async def get_workspace_muxes(
workspace_name: str,
) -> List[v1_models.MuxRule]:
"""Get the mux rules of a workspace.

The list is ordered in order of priority. That is, the first rule in the list
Expand Down Expand Up @@ -422,7 +473,10 @@ async def get_workspace_muxes(workspace_name: str) -> List[v1_models.MuxRule]:
generate_unique_id_function=uniq_name,
status_code=204,
)
async def set_workspace_muxes(workspace_name: str, request: List[v1_models.MuxRule]):
async def set_workspace_muxes(
workspace_name: str,
request: List[v1_models.MuxRule],
):
"""Set the mux rules of a workspace."""
# TODO: This is a dummy implementation. In the future, we should have a proper
# implementation that sets the mux rules in the database.
Expand Down
36 changes: 33 additions & 3 deletions src/codegate/api/v1_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from codegate.db import models as db_models
from codegate.pipeline.base import CodeSnippet
from codegate.providers.base import BaseProvider
from codegate.providers.registry import ProviderRegistry


class Workspace(pydantic.BaseModel):
Expand Down Expand Up @@ -122,6 +124,8 @@ class ProviderType(str, Enum):
openai = "openai"
anthropic = "anthropic"
vllm = "vllm"
ollama = "ollama"
lm_studio = "lm_studio"


class TokenUsageByModel(pydantic.BaseModel):
Expand Down Expand Up @@ -191,13 +195,38 @@ class ProviderEndpoint(pydantic.BaseModel):
so we can use this for muxing messages.
"""

id: int
# This will be set on creation
id: Optional[str] = ""
name: str
description: str = ""
provider_type: ProviderType
endpoint: str
auth_type: ProviderAuthType

@staticmethod
def from_db_model(db_model: db_models.ProviderEndpoint) -> "ProviderEndpoint":
return ProviderEndpoint(
id=db_model.id,
name=db_model.name,
description=db_model.description,
provider_type=db_model.provider_type,
endpoint=db_model.endpoint,
auth_type=db_model.auth_type,
)

def to_db_model(self) -> db_models.ProviderEndpoint:
return db_models.ProviderEndpoint(
id=self.id,
name=self.name,
description=self.description,
provider_type=self.provider_type,
endpoint=self.endpoint,
auth_type=self.auth_type,
)

def get_from_registry(self, registry: ProviderRegistry) -> Optional[BaseProvider]:
return registry.get_provider(self.provider_type)


class ModelByProvider(pydantic.BaseModel):
"""
Expand All @@ -207,10 +236,11 @@ class ModelByProvider(pydantic.BaseModel):
"""

name: str
provider: str
provider_id: str
provider_name: str

def __str__(self):
return f"{self.provider}/{self.name}"
return f"{self.provider_name} / {self.name}"


class MuxMatcherType(str, Enum):
Expand Down
4 changes: 4 additions & 0 deletions src/codegate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from codegate.db.connection import init_db_sync, init_session_if_not_exists
from codegate.pipeline.factory import PipelineFactory
from codegate.pipeline.secrets.manager import SecretsManager
from codegate.providers import crud as provendcrud
from codegate.providers.copilot.provider import CopilotProvider
from codegate.server import init_app
from codegate.storage.utils import restore_storage_backup
Expand Down Expand Up @@ -338,6 +339,9 @@ def serve( # noqa: C901
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

registry = app.provider_registry
loop.run_until_complete(provendcrud.initialize_provider_endpoints(registry))

# Run the server
try:
loop.run_until_complete(run_servers(cfg, app))
Expand Down
Loading