Skip to content

Bootstrap provider models on addition & implement mux endpoints #826

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 34 additions & 23 deletions src/codegate/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,20 +109,26 @@ async def get_provider_endpoint(
status_code=201,
)
async def add_provider_endpoint(
request: v1_models.ProviderEndpoint,
request: v1_models.AddProviderEndpointRequest,
) -> v1_models.ProviderEndpoint:
"""Add a provider endpoint."""
try:
provend = await pcrud.add_endpoint(request)
except AlreadyExistsError:
raise HTTPException(status_code=409, detail="Provider endpoint already exists")
except ValueError as e:
raise HTTPException(
status_code=400,
detail=str(e),
)
except ValidationError as e:
# TODO: This should be more specific
raise HTTPException(
status_code=400,
detail=str(e),
)
except Exception:
logger.exception("Error while adding provider endpoint")
raise HTTPException(status_code=500, detail="Internal server error")

return provend
Expand Down Expand Up @@ -154,20 +160,24 @@ async def configure_auth_material(
)
async def update_provider_endpoint(
provider_id: UUID,
request: v1_models.ProviderEndpoint,
request: v1_models.AddProviderEndpointRequest,
) -> v1_models.ProviderEndpoint:
"""Update a provider endpoint by ID."""
try:
request.id = provider_id
request.id = str(provider_id)
provend = await pcrud.update_endpoint(request)
except provendcrud.ProviderNotFoundError:
raise HTTPException(status_code=404, detail="Provider endpoint not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except ValidationError as e:
# TODO: This should be more specific
raise HTTPException(
status_code=400,
detail=str(e),
)
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

return provend

Expand Down Expand Up @@ -471,22 +481,15 @@ async def get_workspace_muxes(

The list is ordered in order of priority. That is, the first rule in the list
has the highest priority."""
# TODO: This is a dummy implementation. In the future, we should have a proper
# implementation that fetches the mux rules from the database.
return [
v1_models.MuxRule(
# Hardcode some UUID just for mocking purposes
provider_id="00000000-0000-0000-0000-000000000001",
model="gpt-3.5-turbo",
matcher_type=v1_models.MuxMatcherType.file_regex,
matcher=".*\\.txt",
),
v1_models.MuxRule(
provider_id="00000000-0000-0000-0000-000000000002",
model="davinci",
matcher_type=v1_models.MuxMatcherType.catch_all,
),
]
try:
muxes = await wscrud.get_muxes(workspace_name)
except crud.WorkspaceDoesNotExistError:
raise HTTPException(status_code=404, detail="Workspace does not exist")
except Exception:
logger.exception("Error while getting workspace")
raise HTTPException(status_code=500, detail="Internal server error")

return muxes


@v1.put(
Expand All @@ -500,8 +503,16 @@ async def set_workspace_muxes(
request: List[v1_models.MuxRule],
):
"""Set the mux rules of a workspace."""
# TODO: This is a dummy implementation. In the future, we should have a proper
# implementation that sets the mux rules in the database.
try:
await wscrud.set_muxes(workspace_name, request)
except crud.WorkspaceDoesNotExistError:
raise HTTPException(status_code=404, detail="Workspace does not exist")
except crud.WorkspaceCrudError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception:
logger.exception("Error while setting muxes")
raise HTTPException(status_code=500, detail="Internal server error")

return Response(status_code=204)


Expand Down
15 changes: 9 additions & 6 deletions src/codegate/api/v1_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ class ProviderEndpoint(pydantic.BaseModel):
name: str
description: str = ""
provider_type: ProviderType
endpoint: str
endpoint: str = "" # Some providers have defaults we can leverage
auth_type: Optional[ProviderAuthType] = ProviderAuthType.none

@staticmethod
Expand Down Expand Up @@ -250,6 +250,14 @@ def get_from_registry(self, registry: ProviderRegistry) -> Optional[BaseProvider
return registry.get_provider(self.provider_type)


class AddProviderEndpointRequest(ProviderEndpoint):
"""
Represents a request to add a provider endpoint.
"""

api_key: Optional[str] = None


class ConfigureAuthMaterial(pydantic.BaseModel):
"""
Represents a request to configure auth material for a provider.
Expand Down Expand Up @@ -279,11 +287,6 @@ class MuxMatcherType(str, Enum):
Represents the different types of matchers we support.
"""

# Match a regular expression for a file path
# in the prompt. Note that if no file is found,
# the prompt will be passed through.
file_regex = "file_regex"

# Always match this prompt
catch_all = "catch_all"

Expand Down
81 changes: 81 additions & 0 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Alert,
GetPromptWithOutputsRow,
GetWorkspaceByNameConditions,
MuxRule,
Output,
Prompt,
ProviderAuthMaterial,
Expand Down Expand Up @@ -112,6 +113,15 @@ async def _execute_update_pydantic_model(
raise e
return None

async def _execute_with_no_return(self, sql_command: TextClause, conditions: dict):
"""Execute a command that doesn't return anything."""
try:
async with self._async_db_engine.begin() as conn:
await conn.execute(sql_command, conditions)
except Exception as e:
logger.error(f"Failed to execute command: {sql_command}.", error=str(e))
raise e

async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
if prompt_params is None:
return None
Expand Down Expand Up @@ -459,6 +469,45 @@ async def add_provider_model(self, model: ProviderModel) -> ProviderModel:
added_model = await self._execute_update_pydantic_model(model, sql, should_raise=True)
return added_model

async def delete_provider_models(self, provider_id: str):
sql = text(
"""
DELETE FROM provider_models
WHERE provider_endpoint_id = :provider_endpoint_id
"""
)
conditions = {"provider_endpoint_id": provider_id}
await self._execute_with_no_return(sql, conditions)

async def delete_muxes_by_workspace(self, workspace_id: str):
sql = text(
"""
DELETE FROM muxes
WHERE workspace_id = :workspace_id
RETURNING *
"""
)

conditions = {"workspace_id": workspace_id}
await self._execute_with_no_return(sql, conditions)

async def add_mux(self, mux: MuxRule) -> MuxRule:
sql = text(
"""
INSERT INTO muxes (
id, provider_endpoint_id, provider_model_name, workspace_id, matcher_type,
matcher_blob, priority, created_at, updated_at
)
VALUES (
:id, :provider_endpoint_id, :provider_model_name, :workspace_id,
:matcher_type, :matcher_blob, :priority, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP
)
RETURNING *
"""
)
added_mux = await self._execute_update_pydantic_model(mux, sql, should_raise=True)
return added_mux


class DbReader(DbCodeGate):

Expand Down Expand Up @@ -684,6 +733,22 @@ async def get_provider_models_by_provider_id(self, provider_id: str) -> List[Pro
)
return models

async def get_provider_model_by_provider_id_and_name(
self, provider_id: str, model_name: str
) -> Optional[ProviderModel]:
sql = text(
"""
SELECT provider_endpoint_id, name
FROM provider_models
WHERE provider_endpoint_id = :provider_endpoint_id AND name = :name
"""
)
conditions = {"provider_endpoint_id": provider_id, "name": model_name}
models = await self._exec_select_conditions_to_pydantic(
ProviderModel, sql, conditions, should_raise=True
)
return models[0] if models else None

async def get_all_provider_models(self) -> List[ProviderModel]:
sql = text(
"""
Expand All @@ -695,6 +760,22 @@ async def get_all_provider_models(self) -> List[ProviderModel]:
models = await self._execute_select_pydantic_model(ProviderModel, sql)
return models

async def get_muxes_by_workspace(self, workspace_id: str) -> List[MuxRule]:
sql = text(
"""
SELECT id, provider_endpoint_id, provider_model_name, workspace_id, matcher_type,
matcher_blob, priority, created_at, updated_at
FROM muxes
WHERE workspace_id = :workspace_id
ORDER BY priority ASC
"""
)
conditions = {"workspace_id": workspace_id}
muxes = await self._exec_select_conditions_to_pydantic(
MuxRule, sql, conditions, should_raise=True
)
return muxes


def init_db_sync(db_path: Optional[str] = None):
"""DB will be initialized in the constructor in case it doesn't exist."""
Expand Down
12 changes: 12 additions & 0 deletions src/codegate/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,15 @@ class ProviderModel(BaseModel):
provider_endpoint_id: str
provider_endpoint_name: Optional[str] = None
name: str


class MuxRule(BaseModel):
id: str
provider_endpoint_id: str
provider_model_name: str
workspace_id: str
matcher_type: str
matcher_blob: str
priority: int
created_at: Optional[datetime.datetime] = None
updated_at: Optional[datetime.datetime] = None
29 changes: 18 additions & 11 deletions src/codegate/providers/anthropic/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from codegate.pipeline.factory import PipelineFactory
from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer
from codegate.providers.anthropic.completion_handler import AnthropicCompletion
from codegate.providers.base import BaseProvider
from codegate.providers.base import BaseProvider, ModelFetchError
from codegate.providers.litellmshim import anthropic_stream_generator


Expand All @@ -29,16 +29,23 @@ def __init__(
def provider_route_name(self) -> str:
return "anthropic"

def models(self) -> List[str]:
# TODO: This won't work since we need an API Key being set.
resp = httpx.get("https://api.anthropic.com/models")
# If Anthropic returned 404, it means it's not accepting our
# requests. We should throw an error.
if resp.status_code == 404:
raise HTTPException(
status_code=404,
detail="The Anthropic API is not accepting requests. Please check your API key.",
)
def models(self, endpoint: str = None, api_key: str = None) -> List[str]:
headers = {
"Content-Type": "application/json",
"anthropic-version": "2023-06-01",
}
if api_key:
headers["x-api-key"] = api_key
if not endpoint:
endpoint = "https://api.anthropic.com"

resp = httpx.get(
f"{endpoint}/v1/models",
headers=headers,
)

if resp.status_code != 200:
raise ModelFetchError(f"Failed to fetch models from Anthropic API: {resp.text}")

respjson = resp.json()

Expand Down
6 changes: 5 additions & 1 deletion src/codegate/providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
StreamGenerator = Callable[[AsyncIterator[Any]], AsyncIterator[str]]


class ModelFetchError(Exception):
pass


class BaseProvider(ABC):
"""
The provider class is responsible for defining the API routes and
Expand Down Expand Up @@ -55,7 +59,7 @@ def _setup_routes(self) -> None:
pass

@abstractmethod
def models(self) -> List[str]:
def models(self, endpoint, str=None, api_key: str = None) -> List[str]:
pass

@property
Expand Down
Loading