Skip to content

Add scaffolding of API for provider endpoints and mux rules #761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions src/codegate/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,99 @@ def uniq_name(route: APIRoute):
return f"v1_{route.name}"


@v1.get("/provider-endpoints", tags=["Providers"], generate_unique_id_function=uniq_name)
async def list_provider_endpoints(name: Optional[str] = None) -> List[v1_models.ProviderEndpoint]:
"""List all provider endpoints."""
# NOTE: This is a dummy implementation. In the future, we should have a proper
# implementation that fetches the provider endpoints from the database.
return [
v1_models.ProviderEndpoint(
id=1,
name="dummy",
description="Dummy provider endpoint",
endpoint="http://example.com",
provider_type=v1_models.ProviderType.openai,
auth_type=v1_models.ProviderAuthType.none,
)
]


@v1.get(
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
)
async def get_provider_endpoint(provider_id: int) -> v1_models.ProviderEndpoint:
"""Get a provider endpoint by ID."""
# NOTE: This is a dummy implementation. In the future, we should have a proper
# implementation that fetches the provider endpoint from the database.
return v1_models.ProviderEndpoint(
id=provider_id,
name="dummy",
description="Dummy provider endpoint",
endpoint="http://example.com",
provider_type=v1_models.ProviderType.openai,
auth_type=v1_models.ProviderAuthType.none,
)


@v1.post(
"/provider-endpoints",
tags=["Providers"],
generate_unique_id_function=uniq_name,
status_code=201,
)
async def add_provider_endpoint(request: v1_models.ProviderEndpoint) -> v1_models.ProviderEndpoint:
"""Add a provider endpoint."""
# NOTE: This is a dummy implementation. In the future, we should have a proper
# implementation that adds the provider endpoint to the database.
return request


@v1.put(
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
)
async def update_provider_endpoint(
provider_id: int, request: v1_models.ProviderEndpoint
) -> v1_models.ProviderEndpoint:
"""Update a provider endpoint by ID."""
# NOTE: This is a dummy implementation. In the future, we should have a proper
# implementation that updates the provider endpoint in the database.
return request


@v1.delete(
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
)
async def delete_provider_endpoint(provider_id: int):
"""Delete a provider endpoint by id."""
# NOTE: This is a dummy implementation. In the future, we should have a proper
# implementation that deletes the provider endpoint from the database.
return Response(status_code=204)


@v1.get(
"/provider-endpoints/{provider_name}/models",
tags=["Providers"],
generate_unique_id_function=uniq_name,
)
async def list_models_by_provider(provider_name: str) -> List[v1_models.ModelByProvider]:
"""List models by provider."""
# NOTE: This is a dummy implementation. In the future, we should have a proper
# implementation that fetches the models by provider from the database.
return [v1_models.ModelByProvider(name="dummy", provider="dummy")]


@v1.get(
"/provider-endpoints/models",
tags=["Providers"],
generate_unique_id_function=uniq_name,
)
async def list_all_models_for_all_providers() -> List[v1_models.ModelByProvider]:
"""List all models for all providers."""
# NOTE: This is a dummy implementation. In the future, we should have a proper
# implementation that fetches all the models for all providers from the database.
return [v1_models.ModelByProvider(name="dummy", provider="dummy")]


@v1.get("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name)
async def list_workspaces() -> v1_models.ListWorkspacesResponse:
"""List all workspaces."""
Expand Down Expand Up @@ -296,6 +389,46 @@ async def delete_workspace_custom_instructions(workspace_name: str):
return Response(status_code=204)


@v1.get(
"/workspaces/{workspace_name}/muxes",
tags=["Workspaces", "Muxes"],
generate_unique_id_function=uniq_name,
)
async def get_workspace_muxes(workspace_name: str) -> List[v1_models.MuxRule]:
"""Get the mux rules of a workspace.

The list is ordered in order of priority. That is, the first rule in the list
has the highest priority."""
# TODO: This is a dummy implementation. In the future, we should have a proper
# implementation that fetches the mux rules from the database.
return [
v1_models.MuxRule(
provider="openai",
model="gpt-3.5-turbo",
matcher_type=v1_models.MuxMatcherType.file_regex,
matcher=".*\\.txt",
),
v1_models.MuxRule(
provider="anthropic",
model="davinci",
matcher_type=v1_models.MuxMatcherType.catch_all,
),
]


@v1.put(
"/workspaces/{workspace_name}/muxes",
tags=["Workspaces", "Muxes"],
generate_unique_id_function=uniq_name,
status_code=204,
)
async def set_workspace_muxes(workspace_name: str, request: List[v1_models.MuxRule]):
"""Set the mux rules of a workspace."""
# TODO: This is a dummy implementation. In the future, we should have a proper
# implementation that sets the mux rules in the database.
return Response(status_code=204)


@v1.get("/alerts_notification", tags=["Dashboard"], generate_unique_id_function=uniq_name)
async def stream_sse():
"""
Expand Down
80 changes: 80 additions & 0 deletions src/codegate/api/v1_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,83 @@ class AlertConversation(pydantic.BaseModel):
trigger_type: str
trigger_category: Optional[str]
timestamp: datetime.datetime


class ProviderType(str, Enum):
"""
Represents the different types of providers we support.
"""

openai = "openai"
anthropic = "anthropic"
vllm = "vllm"


class ProviderAuthType(str, Enum):
"""
Represents the different types of auth we support for providers.
"""

# No auth required
none = "none"
# Whatever the user provides is passed through
passthrough = "passthrough"
# API key is required
api_key = "api_key"


class ProviderEndpoint(pydantic.BaseModel):
"""
Represents a provider's endpoint configuration. This
allows us to persist the configuration for each provider,
so we can use this for muxing messages.
"""

id: int
name: str
description: str = ""
provider_type: ProviderType
endpoint: str
auth_type: ProviderAuthType


class ModelByProvider(pydantic.BaseModel):
"""
Represents a model supported by a provider.

Note that these are auto-discovered by the provider.
"""

name: str
provider: str

def __str__(self):
return f"{self.provider}/{self.name}"


class MuxMatcherType(str, Enum):
"""
Represents the different types of matchers we support.
"""

# Match a regular expression for a file path
# in the prompt. Note that if no file is found,
# the prompt will be passed through.
file_regex = "file_regex"

# Always match this prompt
catch_all = "catch_all"


class MuxRule(pydantic.BaseModel):
"""
Represents a mux rule for a provider.
"""

provider: str
model: str
# The type of matcher to use
matcher_type: MuxMatcherType
# The actual matcher to use. Note that
# this depends on the matcher type.
matcher: Optional[str]
Loading