diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index 31a4a22f..29f2d4d1 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -25,6 +25,99 @@ def uniq_name(route: APIRoute): return f"v1_{route.name}" +@v1.get("/provider-endpoints", tags=["Providers"], generate_unique_id_function=uniq_name) +async def list_provider_endpoints(name: Optional[str] = None) -> List[v1_models.ProviderEndpoint]: + """List all provider endpoints.""" + # NOTE: This is a dummy implementation. In the future, we should have a proper + # implementation that fetches the provider endpoints from the database. + return [ + v1_models.ProviderEndpoint( + id=1, + name="dummy", + description="Dummy provider endpoint", + endpoint="http://example.com", + provider_type=v1_models.ProviderType.openai, + auth_type=v1_models.ProviderAuthType.none, + ) + ] + + +@v1.get( + "/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name +) +async def get_provider_endpoint(provider_id: int) -> v1_models.ProviderEndpoint: + """Get a provider endpoint by ID.""" + # NOTE: This is a dummy implementation. In the future, we should have a proper + # implementation that fetches the provider endpoint from the database. + return v1_models.ProviderEndpoint( + id=provider_id, + name="dummy", + description="Dummy provider endpoint", + endpoint="http://example.com", + provider_type=v1_models.ProviderType.openai, + auth_type=v1_models.ProviderAuthType.none, + ) + + +@v1.post( + "/provider-endpoints", + tags=["Providers"], + generate_unique_id_function=uniq_name, + status_code=201, +) +async def add_provider_endpoint(request: v1_models.ProviderEndpoint) -> v1_models.ProviderEndpoint: + """Add a provider endpoint.""" + # NOTE: This is a dummy implementation. In the future, we should have a proper + # implementation that adds the provider endpoint to the database. + return request + + +@v1.put( + "/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name +) +async def update_provider_endpoint( + provider_id: int, request: v1_models.ProviderEndpoint +) -> v1_models.ProviderEndpoint: + """Update a provider endpoint by ID.""" + # NOTE: This is a dummy implementation. In the future, we should have a proper + # implementation that updates the provider endpoint in the database. + return request + + +@v1.delete( + "/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name +) +async def delete_provider_endpoint(provider_id: int): + """Delete a provider endpoint by id.""" + # NOTE: This is a dummy implementation. In the future, we should have a proper + # implementation that deletes the provider endpoint from the database. + return Response(status_code=204) + + +@v1.get( + "/provider-endpoints/{provider_name}/models", + tags=["Providers"], + generate_unique_id_function=uniq_name, +) +async def list_models_by_provider(provider_name: str) -> List[v1_models.ModelByProvider]: + """List models by provider.""" + # NOTE: This is a dummy implementation. In the future, we should have a proper + # implementation that fetches the models by provider from the database. + return [v1_models.ModelByProvider(name="dummy", provider="dummy")] + + +@v1.get( + "/provider-endpoints/models", + tags=["Providers"], + generate_unique_id_function=uniq_name, +) +async def list_all_models_for_all_providers() -> List[v1_models.ModelByProvider]: + """List all models for all providers.""" + # NOTE: This is a dummy implementation. In the future, we should have a proper + # implementation that fetches all the models for all providers from the database. + return [v1_models.ModelByProvider(name="dummy", provider="dummy")] + + @v1.get("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name) async def list_workspaces() -> v1_models.ListWorkspacesResponse: """List all workspaces.""" @@ -296,6 +389,46 @@ async def delete_workspace_custom_instructions(workspace_name: str): return Response(status_code=204) +@v1.get( + "/workspaces/{workspace_name}/muxes", + tags=["Workspaces", "Muxes"], + generate_unique_id_function=uniq_name, +) +async def get_workspace_muxes(workspace_name: str) -> List[v1_models.MuxRule]: + """Get the mux rules of a workspace. + + The list is ordered in order of priority. That is, the first rule in the list + has the highest priority.""" + # TODO: This is a dummy implementation. In the future, we should have a proper + # implementation that fetches the mux rules from the database. + return [ + v1_models.MuxRule( + provider="openai", + model="gpt-3.5-turbo", + matcher_type=v1_models.MuxMatcherType.file_regex, + matcher=".*\\.txt", + ), + v1_models.MuxRule( + provider="anthropic", + model="davinci", + matcher_type=v1_models.MuxMatcherType.catch_all, + ), + ] + + +@v1.put( + "/workspaces/{workspace_name}/muxes", + tags=["Workspaces", "Muxes"], + generate_unique_id_function=uniq_name, + status_code=204, +) +async def set_workspace_muxes(workspace_name: str, request: List[v1_models.MuxRule]): + """Set the mux rules of a workspace.""" + # TODO: This is a dummy implementation. In the future, we should have a proper + # implementation that sets the mux rules in the database. + return Response(status_code=204) + + @v1.get("/alerts_notification", tags=["Dashboard"], generate_unique_id_function=uniq_name) async def stream_sse(): """ diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py index 2eba1d87..7bf2d3db 100644 --- a/src/codegate/api/v1_models.py +++ b/src/codegate/api/v1_models.py @@ -138,3 +138,83 @@ class AlertConversation(pydantic.BaseModel): trigger_type: str trigger_category: Optional[str] timestamp: datetime.datetime + + +class ProviderType(str, Enum): + """ + Represents the different types of providers we support. + """ + + openai = "openai" + anthropic = "anthropic" + vllm = "vllm" + + +class ProviderAuthType(str, Enum): + """ + Represents the different types of auth we support for providers. + """ + + # No auth required + none = "none" + # Whatever the user provides is passed through + passthrough = "passthrough" + # API key is required + api_key = "api_key" + + +class ProviderEndpoint(pydantic.BaseModel): + """ + Represents a provider's endpoint configuration. This + allows us to persist the configuration for each provider, + so we can use this for muxing messages. + """ + + id: int + name: str + description: str = "" + provider_type: ProviderType + endpoint: str + auth_type: ProviderAuthType + + +class ModelByProvider(pydantic.BaseModel): + """ + Represents a model supported by a provider. + + Note that these are auto-discovered by the provider. + """ + + name: str + provider: str + + def __str__(self): + return f"{self.provider}/{self.name}" + + +class MuxMatcherType(str, Enum): + """ + Represents the different types of matchers we support. + """ + + # Match a regular expression for a file path + # in the prompt. Note that if no file is found, + # the prompt will be passed through. + file_regex = "file_regex" + + # Always match this prompt + catch_all = "catch_all" + + +class MuxRule(pydantic.BaseModel): + """ + Represents a mux rule for a provider. + """ + + provider: str + model: str + # The type of matcher to use + matcher_type: MuxMatcherType + # The actual matcher to use. Note that + # this depends on the matcher type. + matcher: Optional[str]