diff --git a/api/openapi.json b/api/openapi.json index 8b613d21f..759231de2 100644 --- a/api/openapi.json +++ b/api/openapi.json @@ -148,7 +148,7 @@ } } }, - "/api/v1/provider-endpoints/{provider_id}/models": { + "/api/v1/provider-endpoints/{provider_name}/models": { "get": { "tags": [ "CodeGate API", @@ -159,13 +159,12 @@ "operationId": "v1_list_models_by_provider", "parameters": [ { - "name": "provider_id", + "name": "provider_name", "in": "path", "required": true, "schema": { "type": "string", - "format": "uuid", - "title": "Provider Id" + "title": "Provider Name" } } ], @@ -197,24 +196,23 @@ } } }, - "/api/v1/provider-endpoints/{provider_id}": { + "/api/v1/provider-endpoints/{provider_name}": { "get": { "tags": [ "CodeGate API", "Providers" ], "summary": "Get Provider Endpoint", - "description": "Get a provider endpoint by ID.", + "description": "Get a provider endpoint by name.", "operationId": "v1_get_provider_endpoint", "parameters": [ { - "name": "provider_id", + "name": "provider_name", "in": "path", "required": true, "schema": { "type": "string", - "format": "uuid", - "title": "Provider Id" + "title": "Provider Name" } } ], @@ -247,17 +245,16 @@ "Providers" ], "summary": "Update Provider Endpoint", - "description": "Update a provider endpoint by ID.", + "description": "Update a provider endpoint by name.", "operationId": "v1_update_provider_endpoint", "parameters": [ { - "name": "provider_id", + "name": "provider_name", "in": "path", "required": true, "schema": { "type": "string", - "format": "uuid", - "title": "Provider Id" + "title": "Provider Name" } } ], @@ -300,17 +297,16 @@ "Providers" ], "summary": "Delete Provider Endpoint", - "description": "Delete a provider endpoint by id.", + "description": "Delete a provider endpoint by name.", "operationId": "v1_delete_provider_endpoint", "parameters": [ { - "name": "provider_id", + "name": "provider_name", "in": "path", "required": true, "schema": { "type": "string", - "format": "uuid", - "title": "Provider Id" + "title": "Provider Name" } } ], @@ -336,7 +332,7 @@ } } }, - "/api/v1/provider-endpoints/{provider_id}/auth-material": { + "/api/v1/provider-endpoints/{provider_name}/auth-material": { "put": { "tags": [ "CodeGate API", @@ -347,13 +343,12 @@ "operationId": "v1_configure_auth_material", "parameters": [ { - "name": "provider_id", + "name": "provider_name", "in": "path", "required": true, "schema": { "type": "string", - "format": "uuid", - "title": "Provider Id" + "title": "Provider Name" } } ], @@ -391,8 +386,26 @@ "Workspaces" ], "summary": "List Workspaces", - "description": "List all workspaces.", + "description": "List all workspaces.\n\nArgs:\n provider_name (Optional[str]): Filter workspaces by provider name. If provided,\n will return workspaces where models from the specified provider (e.g., OpenAI,\n Anthropic) have been used in workspace muxing rules.\n\nReturns:\n ListWorkspacesResponse: A response object containing the list of workspaces.", "operationId": "v1_list_workspaces", + "parameters": [ + { + "name": "provider_name", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Provider Name" + } + } + ], "responses": { "200": { "description": "Successful Response", @@ -403,6 +416,16 @@ } } } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } } } }, @@ -415,14 +438,14 @@ "description": "Create a new workspace.", "operationId": "v1_create_workspace", "requestBody": { + "required": true, "content": { "application/json": { "schema": { "$ref": "#/components/schemas/FullWorkspace-Input" } } - }, - "required": true + } }, "responses": { "201": { @@ -552,7 +575,7 @@ } }, "responses": { - "201": { + "200": { "description": "Successful Response", "content": { "application/json": { @@ -613,6 +636,48 @@ } } } + }, + "get": { + "tags": [ + "CodeGate API", + "Workspaces" + ], + "summary": "Get Workspace By Name", + "description": "List workspaces by provider ID.", + "operationId": "v1_get_workspace_by_name", + "parameters": [ + { + "name": "workspace_name", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Workspace Name" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/FullWorkspace-Output" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } } }, "/api/v1/workspaces/archive": { @@ -1195,55 +1260,6 @@ } } }, - "/api/v1/workspaces/{provider_id}": { - "get": { - "tags": [ - "CodeGate API", - "Workspaces" - ], - "summary": "List Workspaces By Provider", - "description": "List workspaces by provider ID.", - "operationId": "v1_list_workspaces_by_provider", - "parameters": [ - { - "name": "provider_id", - "in": "path", - "required": true, - "schema": { - "type": "string", - "format": "uuid", - "title": "Provider Id" - } - } - ], - "responses": { - "200": { - "description": "Successful Response", - "content": { - "application/json": { - "schema": { - "type": "array", - "items": { - "$ref": "#/components/schemas/WorkspaceWithModel" - }, - "title": "Response V1 List Workspaces By Provider" - } - } - } - }, - "422": { - "description": "Validation Error", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/HTTPValidationError" - } - } - } - } - } - } - }, "/api/v1/alerts_notification": { "get": { "tags": [ @@ -2136,9 +2152,8 @@ "type": "string", "title": "Name" }, - "provider_id": { - "type": "string", - "title": "Provider Id" + "provider_type": { + "$ref": "#/components/schemas/ProviderType" }, "provider_name": { "type": "string", @@ -2148,7 +2163,7 @@ "type": "object", "required": [ "name", - "provider_id", + "provider_type", "provider_name" ], "title": "ModelByProvider", @@ -2168,19 +2183,11 @@ "MuxRule": { "properties": { "provider_name": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], + "type": "string", "title": "Provider Name" }, - "provider_id": { - "type": "string", - "title": "Provider Id" + "provider_type": { + "$ref": "#/components/schemas/ProviderType" }, "model": { "type": "string", @@ -2203,7 +2210,8 @@ }, "type": "object", "required": [ - "provider_id", + "provider_name", + "provider_type", "model", "matcher_type" ], @@ -2565,31 +2573,6 @@ "muxing_rules" ], "title": "WorkspaceConfig" - }, - "WorkspaceWithModel": { - "properties": { - "id": { - "type": "string", - "title": "Id" - }, - "name": { - "type": "string", - "pattern": "^[a-zA-Z0-9_-]+$", - "title": "Name" - }, - "provider_model_name": { - "type": "string", - "title": "Provider Model Name" - } - }, - "type": "object", - "required": [ - "id", - "name", - "provider_model_name" - ], - "title": "WorkspaceWithModel", - "description": "Returns a workspace ID with model name" } } } diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index edd6d0a06..c085c4e2b 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -1,5 +1,4 @@ from typing import List, Optional -from uuid import UUID import cachetools.func import requests @@ -14,7 +13,7 @@ from codegate.api import v1_models, v1_processing from codegate.config import API_DEFAULT_PAGE_SIZE, API_MAX_PAGE_SIZE from codegate.db.connection import AlreadyExistsError, DbReader -from codegate.db.models import AlertSeverity, AlertTriggerType, Persona, WorkspaceWithModel +from codegate.db.models import AlertSeverity, AlertTriggerType, Persona from codegate.muxing.persona import ( PersonaDoesNotExistError, PersonaManager, @@ -56,15 +55,14 @@ async def list_provider_endpoints( try: provend = await pcrud.get_endpoint_by_name(filter_query.name) + return [provend] + except provendcrud.ProviderNotFoundError: + raise HTTPException(status_code=404, detail="Provider endpoint not found") except Exception: raise HTTPException(status_code=500, detail="Internal server error") - if provend is None: - raise HTTPException(status_code=404, detail="Provider endpoint not found") - return [provend] - -# This needs to be above /provider-endpoints/{provider_id} to avoid conflict +# This needs to be above /provider-endpoints/{provider_name} to avoid conflict @v1.get( "/provider-endpoints/models", tags=["Providers"], @@ -79,37 +77,38 @@ async def list_all_models_for_all_providers() -> List[v1_models.ModelByProvider] @v1.get( - "/provider-endpoints/{provider_id}/models", + "/provider-endpoints/{provider_name}/models", tags=["Providers"], generate_unique_id_function=uniq_name, ) async def list_models_by_provider( - provider_id: UUID, + provider_name: str, ) -> List[v1_models.ModelByProvider]: """List models by provider.""" try: - return await pcrud.models_by_provider(provider_id) + provider = await pcrud.get_endpoint_by_name(provider_name) + return await pcrud.models_by_provider(provider.id) except provendcrud.ProviderNotFoundError: raise HTTPException(status_code=404, detail="Provider not found") except Exception as e: + logger.exception("Error while listing models by provider") raise HTTPException(status_code=500, detail=str(e)) @v1.get( - "/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name + "/provider-endpoints/{provider_name}", tags=["Providers"], generate_unique_id_function=uniq_name ) async def get_provider_endpoint( - provider_id: UUID, + provider_name: str, ) -> v1_models.ProviderEndpoint: - """Get a provider endpoint by ID.""" + """Get a provider endpoint by name.""" try: - provend = await pcrud.get_endpoint_by_id(provider_id) + provend = await pcrud.get_endpoint_by_name(provider_name) + except provendcrud.ProviderNotFoundError: + raise HTTPException(status_code=404, detail="Provider endpoint not found") except Exception: raise HTTPException(status_code=500, detail="Internal server error") - - if provend is None: - raise HTTPException(status_code=404, detail="Provider endpoint not found") return provend @@ -150,18 +149,19 @@ async def add_provider_endpoint( @v1.put( - "/provider-endpoints/{provider_id}/auth-material", + "/provider-endpoints/{provider_name}/auth-material", tags=["Providers"], generate_unique_id_function=uniq_name, status_code=204, ) async def configure_auth_material( - provider_id: UUID, + provider_name: str, request: v1_models.ConfigureAuthMaterial, ): """Configure auth material for a provider.""" try: - await pcrud.configure_auth_material(provider_id, request) + provider = await pcrud.get_endpoint_by_name(provider_name) + await pcrud.configure_auth_material(provider.id, request) except provendcrud.ProviderNotFoundError: raise HTTPException(status_code=404, detail="Provider endpoint not found") except provendcrud.ProviderModelsNotFoundError: @@ -175,15 +175,16 @@ async def configure_auth_material( @v1.put( - "/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name + "/provider-endpoints/{provider_name}", tags=["Providers"], generate_unique_id_function=uniq_name ) async def update_provider_endpoint( - provider_id: UUID, + provider_name: str, request: v1_models.ProviderEndpoint, ) -> v1_models.ProviderEndpoint: - """Update a provider endpoint by ID.""" + """Update a provider endpoint by name.""" try: - request.id = str(provider_id) + provider = await pcrud.get_endpoint_by_name(provider_name) + request.id = str(provider.id) provend = await pcrud.update_endpoint(request) except provendcrud.ProviderNotFoundError: raise HTTPException(status_code=404, detail="Provider endpoint not found") @@ -196,20 +197,22 @@ async def update_provider_endpoint( detail=str(e), ) except Exception as e: + logger.exception("Error while updating provider endpoint") raise HTTPException(status_code=500, detail=str(e)) return provend @v1.delete( - "/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name + "/provider-endpoints/{provider_name}", tags=["Providers"], generate_unique_id_function=uniq_name ) async def delete_provider_endpoint( - provider_id: UUID, + provider_name: str, ): - """Delete a provider endpoint by id.""" + """Delete a provider endpoint by name.""" try: - await pcrud.delete_endpoint(provider_id) + provider = await pcrud.get_endpoint_by_name(provider_name) + await pcrud.delete_endpoint(provider.id) except provendcrud.ProviderNotFoundError: raise HTTPException(status_code=404, detail="Provider endpoint not found") except Exception: @@ -218,13 +221,34 @@ async def delete_provider_endpoint( @v1.get("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name) -async def list_workspaces() -> v1_models.ListWorkspacesResponse: - """List all workspaces.""" - wslist = await wscrud.get_workspaces() +async def list_workspaces( + provider_name: Optional[str] = Query(None), +) -> v1_models.ListWorkspacesResponse: + """ + List all workspaces. - resp = v1_models.ListWorkspacesResponse.from_db_workspaces_with_sessioninfo(wslist) + Args: + provider_name (Optional[str]): Filter workspaces by provider name. If provided, + will return workspaces where models from the specified provider (e.g., OpenAI, + Anthropic) have been used in workspace muxing rules. - return resp + Returns: + ListWorkspacesResponse: A response object containing the list of workspaces. + """ + try: + if provider_name: + provider = await pcrud.get_endpoint_by_name(provider_name) + wslist = await wscrud.workspaces_by_provider(provider.id) + resp = v1_models.ListWorkspacesResponse.from_db_workspaces(wslist) + return resp + else: + wslist = await wscrud.get_workspaces() + resp = v1_models.ListWorkspacesResponse.from_db_workspaces_with_sessioninfo(wslist) + return resp + except provendcrud.ProviderNotFoundError: + return v1_models.ListWorkspacesResponse(workspaces=[]) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) @v1.get("/workspaces/active", tags=["Workspaces"], generate_unique_id_function=uniq_name) @@ -262,11 +286,20 @@ async def create_workspace( """Create a new workspace.""" try: custom_instructions = request.config.custom_instructions if request.config else None - muxing_rules = request.config.muxing_rules if request.config else None + mux_rules = [] + if request.config and request.config.muxing_rules: + mux_rules = await pcrud.add_provider_ids_to_mux_rule_list(request.config.muxing_rules) - workspace_row, mux_rules = await wscrud.add_workspace( - request.name, custom_instructions, muxing_rules + workspace_row, created_mux_rules = await wscrud.add_workspace( + request.name, custom_instructions, mux_rules ) + + created_muxes_with_name_type = [ + mux_models.MuxRule.from_db_models( + mux_rule, await pcrud.get_endpoint_by_id(mux_rule.provider_endpoint_id) + ) + for mux_rule in created_mux_rules + ] except crud.WorkspaceNameAlreadyInUseError: raise HTTPException(status_code=409, detail="Workspace name already in use") except ValidationError: @@ -277,16 +310,21 @@ async def create_workspace( "Please use only alphanumeric characters, hyphens, or underscores." ), ) + except provendcrud.ProviderNotFoundError as e: + logger.exception("Error matching a provider for a mux rule while creating a workspace") + raise HTTPException(status_code=400, detail=str(e)) except crud.WorkspaceCrudError as e: + logger.exception("Error while create a workspace") raise HTTPException(status_code=400, detail=str(e)) except Exception: + logger.exception("Error while creating workspace") raise HTTPException(status_code=500, detail="Internal server error") return v1_models.FullWorkspace( name=workspace_row.name, config=v1_models.WorkspaceConfig( custom_instructions=workspace_row.custom_instructions or "", - muxing_rules=[mux_models.MuxRule.from_db_mux_rule(mux_rule) for mux_rule in mux_rules], + muxing_rules=created_muxes_with_name_type, ), ) @@ -295,7 +333,7 @@ async def create_workspace( "/workspaces/{workspace_name}", tags=["Workspaces"], generate_unique_id_function=uniq_name, - status_code=201, + status_code=200, ) async def update_workspace( workspace_name: str, @@ -304,14 +342,26 @@ async def update_workspace( """Update a workspace.""" try: custom_instructions = request.config.custom_instructions if request.config else None - muxing_rules = request.config.muxing_rules if request.config else None + mux_rules = [] + if request.config and request.config.muxing_rules: + mux_rules = await pcrud.add_provider_ids_to_mux_rule_list(request.config.muxing_rules) - workspace_row, mux_rules = await wscrud.update_workspace( + workspace_row, updated_muxes = await wscrud.update_workspace( workspace_name, request.name, custom_instructions, - muxing_rules, + mux_rules, ) + + updated_muxes_with_name_type = [ + mux_models.MuxRule.from_db_models( + mux_rule, await pcrud.get_endpoint_by_id(mux_rule.provider_endpoint_id) + ) + for mux_rule in updated_muxes + ] + + except provendcrud.ProviderNotFoundError as e: + raise HTTPException(status_code=400, detail=str(e)) except crud.WorkspaceDoesNotExistError: raise HTTPException(status_code=404, detail="Workspace does not exist") except crud.WorkspaceNameAlreadyInUseError: @@ -325,6 +375,7 @@ async def update_workspace( ), ) except crud.WorkspaceCrudError as e: + logger.exception("Error while updating workspace") raise HTTPException(status_code=400, detail=str(e)) except Exception: raise HTTPException(status_code=500, detail="Internal server error") @@ -333,7 +384,7 @@ async def update_workspace( name=workspace_row.name, config=v1_models.WorkspaceConfig( custom_instructions=workspace_row.custom_instructions or "", - muxing_rules=[mux_models.MuxRule.from_db_mux_rule(mux_rule) for mux_rule in mux_rules], + muxing_rules=updated_muxes_with_name_type, ), ) @@ -351,7 +402,11 @@ async def delete_workspace(workspace_name: str): raise HTTPException(status_code=404, detail="Workspace does not exist") except crud.WorkspaceCrudError as e: raise HTTPException(status_code=400, detail=str(e)) + except crud.DeleteMuxesFromRegistryError: + logger.exception("Error deleting muxes while deleting workspace") + raise HTTPException(status_code=500, detail="Internal server error") except Exception: + logger.exception("Error while deleting workspace") raise HTTPException(status_code=500, detail="Internal server error") return Response(status_code=204) @@ -667,14 +722,20 @@ async def get_workspace_muxes( The list is ordered in order of priority. That is, the first rule in the list has the highest priority.""" try: - muxes = await wscrud.get_muxes(workspace_name) + db_muxes = await wscrud.get_muxes(workspace_name) + + muxes = [] + for db_mux in db_muxes: + db_endpoint = await pcrud.get_endpoint_by_id(db_mux.provider_endpoint_id) + mux_rule = mux_models.MuxRule.from_db_models(db_mux, db_endpoint) + muxes.append(mux_rule) except crud.WorkspaceDoesNotExistError: raise HTTPException(status_code=404, detail="Workspace does not exist") except Exception: logger.exception("Error while getting workspace") raise HTTPException(status_code=500, detail="Internal server error") - return muxes + return [mux_models.MuxRule.from_mux_rule_with_provider_id(mux) for mux in muxes] @v1.put( @@ -689,31 +750,52 @@ async def set_workspace_muxes( ): """Set the mux rules of a workspace.""" try: - await wscrud.set_muxes(workspace_name, request) + mux_rules = await pcrud.add_provider_ids_to_mux_rule_list(request) + await wscrud.set_muxes(workspace_name, mux_rules) + except provendcrud.ProviderNotFoundError as e: + raise HTTPException(status_code=400, detail=str(e)) except crud.WorkspaceDoesNotExistError: raise HTTPException(status_code=404, detail="Workspace does not exist") except crud.WorkspaceCrudError as e: raise HTTPException(status_code=400, detail=str(e)) - except Exception: - logger.exception("Error while setting muxes") + except Exception as e: + logger.exception(f"Error while setting muxes {e}") raise HTTPException(status_code=500, detail="Internal server error") return Response(status_code=204) @v1.get( - "/workspaces/{provider_id}", + "/workspaces/{workspace_name}", tags=["Workspaces"], generate_unique_id_function=uniq_name, ) -async def list_workspaces_by_provider( - provider_id: UUID, -) -> List[WorkspaceWithModel]: +async def get_workspace_by_name( + workspace_name: str, +) -> v1_models.FullWorkspace: """List workspaces by provider ID.""" try: - return await wscrud.workspaces_by_provider(provider_id) + ws = await wscrud.get_workspace_by_name(workspace_name) + db_muxes = await wscrud.get_muxes(workspace_name) + + muxes = [] + for db_mux in db_muxes: + db_endpoint = await pcrud.get_endpoint_by_id(db_mux.provider_endpoint_id) + mux_rule = mux_models.MuxRule.from_db_models(db_mux, db_endpoint) + muxes.append(mux_rule) + + return v1_models.FullWorkspace( + name=ws.name, + config=v1_models.WorkspaceConfig( + custom_instructions=ws.custom_instructions or "", + muxing_rules=muxes, + ), + ) + except crud.WorkspaceDoesNotExistError: + raise HTTPException(status_code=404, detail="Workspace does not exist") except Exception as e: + logger.exception(f"Error while getting workspace {e}") raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py index 6489f96d6..97ece660e 100644 --- a/src/codegate/api/v1_models.py +++ b/src/codegate/api/v1_models.py @@ -276,13 +276,18 @@ class ProviderEndpoint(pydantic.BaseModel): @staticmethod def from_db_model(db_model: db_models.ProviderEndpoint) -> "ProviderEndpoint": + auth_type = ( + ProviderAuthType.none + if not db_model.auth_type + else ProviderAuthType(db_model.auth_type) + ) return ProviderEndpoint( id=db_model.id, name=db_model.name, description=db_model.description, provider_type=db_model.provider_type, endpoint=db_model.endpoint, - auth_type=db_model.auth_type, + auth_type=auth_type, ) def to_db_model(self) -> db_models.ProviderEndpoint: @@ -324,7 +329,7 @@ class ModelByProvider(pydantic.BaseModel): """ name: str - provider_id: str + provider_type: db_models.ProviderType provider_name: str def __str__(self): diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index a828b6a37..973a4a1b3 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -37,9 +37,9 @@ ProviderAuthMaterial, ProviderEndpoint, ProviderModel, + ProviderModelIntermediate, Session, WorkspaceRow, - WorkspaceWithModel, WorkspaceWithSessionInfo, ) from codegate.db.token_usage import TokenUsageParser @@ -468,6 +468,7 @@ async def update_provider_endpoint(self, provider: ProviderEndpoint) -> Provider updated_provider = await self._execute_update_pydantic_model( provider, sql, should_raise=True ) + return updated_provider async def delete_provider_endpoint( @@ -499,7 +500,9 @@ async def push_provider_auth_material(self, auth_material: ProviderAuthMaterial) _ = await self._execute_update_pydantic_model(auth_material, sql, should_raise=True) return - async def add_provider_model(self, model: ProviderModel) -> ProviderModel: + async def add_provider_model( + self, model: ProviderModelIntermediate + ) -> ProviderModelIntermediate: sql = text( """ INSERT INTO provider_models (provider_endpoint_id, name) @@ -1006,11 +1009,13 @@ async def get_workspace_by_name(self, name: str) -> Optional[WorkspaceRow]: ) return workspaces[0] if workspaces else None - async def get_workspaces_by_provider(self, provider_id: str) -> List[WorkspaceWithModel]: + async def get_workspaces_by_provider(self, provider_id: str) -> List[WorkspaceRow]: sql = text( """ - SELECT - w.id, w.name, m.provider_model_name + SELECT DISTINCT + w.id, + w.name, + w.custom_instructions FROM workspaces w JOIN muxes m ON w.id = m.workspace_id WHERE m.provider_endpoint_id = :provider_id @@ -1019,7 +1024,7 @@ async def get_workspaces_by_provider(self, provider_id: str) -> List[WorkspaceWi ) conditions = {"provider_id": provider_id} workspaces = await self._exec_select_conditions_to_pydantic( - WorkspaceWithModel, sql, conditions, should_raise=True + WorkspaceRow, sql, conditions, should_raise=True ) return workspaces @@ -1075,11 +1080,63 @@ async def get_provider_endpoint_by_name(self, provider_name: str) -> Optional[Pr ) return provider[0] if provider else None - async def get_provider_endpoint_by_id(self, provider_id: str) -> Optional[ProviderEndpoint]: + async def try_get_provider_endpoint_by_name_and_type( + self, provider_name: str, provider_type: Optional[str] + ) -> Optional[ProviderEndpoint]: + """ + Best effort attempt to find a provider endpoint matching name and type. + + With shareable workspaces, a user may share a workspace with mux rules + that refer to a provider name & type. + + Another user may want to consume those rules, but may not have the exact + same provider names configured. + + This makes the shareable workspace feature a little more robust. + """ + # First try exact match on both name and type sql = text( """ SELECT id, name, description, provider_type, endpoint, auth_type, created_at, updated_at FROM provider_endpoints + WHERE name = :name AND provider_type = :provider_type + """ + ) + conditions = {"name": provider_name, "provider_type": provider_type} + provider = await self._exec_select_conditions_to_pydantic( + ProviderEndpoint, sql, conditions, should_raise=True + ) + if provider: + logger.debug( + f'Found provider "{provider[0].name}" by name "{provider_name}" and type "{provider_type}"' # noqa: E501 + ) + return provider[0] + + # If no exact match, try matching just provider_type + sql = text( + """ + SELECT id, name, description, provider_type, endpoint, auth_type, created_at, updated_at + FROM provider_endpoints + WHERE provider_type = :provider_type + LIMIT 1 + """ + ) + conditions = {"provider_type": provider_type} + provider = await self._exec_select_conditions_to_pydantic( + ProviderEndpoint, sql, conditions, should_raise=True + ) + if provider: + logger.debug( + f'Found provider "{provider[0].name}" by type {provider_type}. Name "{provider_name}" did not match any providers.' # noqa: E501 + ) + return provider[0] + return None + + async def get_provider_endpoint_by_id(self, provider_id: str) -> Optional[ProviderEndpoint]: + sql = text( + """ + SELECT id, name, description, provider_type, endpoint, auth_type + FROM provider_endpoints WHERE id = :id """ ) @@ -1118,10 +1175,11 @@ async def get_provider_endpoints(self) -> List[ProviderEndpoint]: async def get_provider_models_by_provider_id(self, provider_id: str) -> List[ProviderModel]: sql = text( """ - SELECT provider_endpoint_id, name - FROM provider_models - WHERE provider_endpoint_id = :provider_endpoint_id - """ + SELECT pm.provider_endpoint_id, pm.name, pe.name as provider_endpoint_name, pe.provider_type as provider_endpoint_type + FROM provider_models pm + INNER JOIN provider_endpoints pe ON pm.provider_endpoint_id = pe.id + WHERE pm.provider_endpoint_id = :provider_endpoint_id + """ # noqa: E501 ) conditions = {"provider_endpoint_id": provider_id} models = await self._exec_select_conditions_to_pydantic( @@ -1134,10 +1192,11 @@ async def get_provider_model_by_provider_id_and_name( ) -> Optional[ProviderModel]: sql = text( """ - SELECT provider_endpoint_id, name - FROM provider_models - WHERE provider_endpoint_id = :provider_endpoint_id AND name = :name - """ + SELECT pm.provider_endpoint_id, pm.name, pe.name as provider_endpoint_name, pe.provider_type as provider_endpoint_type + FROM provider_models pm + INNER JOIN provider_endpoints pe ON pm.provider_endpoint_id = pe.id + WHERE pm.provider_endpoint_id = :provider_endpoint_id AND pm.name = :name + """ # noqa: E501 ) conditions = {"provider_endpoint_id": provider_id, "name": model_name} models = await self._exec_select_conditions_to_pydantic( @@ -1148,7 +1207,8 @@ async def get_provider_model_by_provider_id_and_name( async def get_all_provider_models(self) -> List[ProviderModel]: sql = text( """ - SELECT pm.provider_endpoint_id, pm.name, pe.name as provider_endpoint_name + SELECT pm.provider_endpoint_id, pm.name, pe.name as + provider_endpoint_name, pe.provider_type as provider_endpoint_type FROM provider_models pm INNER JOIN provider_endpoints pe ON pm.provider_endpoint_id = pe.id """ diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index 7f8ef4348..5b3b95e2f 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -253,8 +253,14 @@ class ProviderAuthMaterial(BaseModel): auth_blob: str +class ProviderModelIntermediate(BaseModel): + provider_endpoint_id: str + name: str + + class ProviderModel(BaseModel): provider_endpoint_id: str + provider_endpoint_type: str provider_endpoint_name: Optional[str] = None name: str diff --git a/src/codegate/muxing/models.py b/src/codegate/muxing/models.py index 5637c5b8c..5e74db2e2 100644 --- a/src/codegate/muxing/models.py +++ b/src/codegate/muxing/models.py @@ -5,6 +5,8 @@ from codegate.clients.clients import ClientType from codegate.db.models import MuxRule as DBMuxRule +from codegate.db.models import ProviderEndpoint as DBProviderEndpoint +from codegate.db.models import ProviderType class MuxMatcherType(str, Enum): @@ -39,9 +41,8 @@ class MuxRule(pydantic.BaseModel): Represents a mux rule for a provider. """ - # Used for exportable workspaces - provider_name: Optional[str] = None - provider_id: str + provider_name: str + provider_type: ProviderType model: str # The type of matcher to use matcher_type: MuxMatcherType @@ -50,17 +51,54 @@ class MuxRule(pydantic.BaseModel): matcher: Optional[str] = None @classmethod - def from_db_mux_rule(cls, db_mux_rule: DBMuxRule) -> Self: + def from_db_models( + cls, db_mux_rule: DBMuxRule, db_provider_endpoint: DBProviderEndpoint + ) -> Self: """ - Convert a DBMuxRule to a MuxRule. + Convert a DBMuxRule and DBProviderEndpoint to a MuxRule. """ - return MuxRule( - provider_id=db_mux_rule.id, + return cls( + provider_name=db_provider_endpoint.name, + provider_type=db_provider_endpoint.provider_type, model=db_mux_rule.provider_model_name, - matcher_type=db_mux_rule.matcher_type, + matcher_type=MuxMatcherType(db_mux_rule.matcher_type), matcher=db_mux_rule.matcher_blob, ) + @classmethod + def from_mux_rule_with_provider_id(cls, rule: "MuxRuleWithProviderId") -> Self: + """ + Convert a MuxRuleWithProviderId to a MuxRule. + """ + return cls( + provider_name=rule.provider_name, + provider_type=rule.provider_type, + model=rule.model, + matcher_type=rule.matcher_type, + matcher=rule.matcher, + ) + + +class MuxRuleWithProviderId(MuxRule): + """ + Represents a mux rule for a provider with provider ID. + Used internally for referring to a mux rule. + """ + + provider_id: str + + @classmethod + def from_db_models( + cls, db_mux_rule: DBMuxRule, db_provider_endpoint: DBProviderEndpoint + ) -> Self: + """ + Convert a DBMuxRule and DBProviderEndpoint to a MuxRuleWithProviderId. + """ + return cls( + **MuxRule.from_db_models(db_mux_rule, db_provider_endpoint).model_dump(), + provider_id=db_mux_rule.provider_endpoint_id, + ) + class ThingToMatchMux(pydantic.BaseModel): """ diff --git a/src/codegate/muxing/rulematcher.py b/src/codegate/muxing/rulematcher.py index d41eb2ce0..7f154df7a 100644 --- a/src/codegate/muxing/rulematcher.py +++ b/src/codegate/muxing/rulematcher.py @@ -74,7 +74,11 @@ class MuxingMatcherFactory: """Factory for creating muxing matchers.""" @staticmethod - def create(db_mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatcher: + def create( + db_mux_rule: db_models.MuxRule, + db_provider_endpoint: db_models.ProviderEndpoint, + route: ModelRoute, + ) -> MuxingRuleMatcher: """Create a muxing matcher for the given endpoint and model.""" factory: Dict[mux_models.MuxMatcherType, MuxingRuleMatcher] = { @@ -86,7 +90,7 @@ def create(db_mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatch try: # Initialize the MuxingRuleMatcher - mux_rule = mux_models.MuxRule.from_db_mux_rule(db_mux_rule) + mux_rule = mux_models.MuxRule.from_db_models(db_mux_rule, db_provider_endpoint) return factory[mux_rule.matcher_type](route, mux_rule) except KeyError: raise ValueError(f"Unknown matcher type: {mux_rule.matcher_type}") @@ -193,7 +197,8 @@ async def set_ws_rules(self, workspace_name: str, rules: List[MuxingRuleMatcher] async def delete_ws_rules(self, workspace_name: str) -> None: """Delete the rules for the given workspace.""" async with self._lock: - del self._ws_rules[workspace_name] + if workspace_name in self._ws_rules: + del self._ws_rules[workspace_name] async def set_active_workspace(self, workspace_name: str) -> None: """Set the active workspace.""" diff --git a/src/codegate/providers/crud/crud.py b/src/codegate/providers/crud/crud.py index 8bba52b87..56ba63089 100644 --- a/src/codegate/providers/crud/crud.py +++ b/src/codegate/providers/crud/crud.py @@ -10,6 +10,7 @@ from codegate.config import Config from codegate.db import models as dbmodels from codegate.db.connection import DbReader, DbRecorder +from codegate.muxing import models as mux_models from codegate.providers.base import BaseProvider from codegate.providers.registry import ProviderRegistry, get_provider_registry from codegate.workspaces import crud as workspace_crud @@ -67,10 +68,47 @@ async def get_endpoint_by_name(self, name: str) -> Optional[apimodelsv1.Provider dbendpoint = await self._db_reader.get_provider_endpoint_by_name(name) if dbendpoint is None: - return None + raise ProviderNotFoundError(f'Provider "{name}" not found') return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) + async def _try_get_endpoint_by_name_and_type( + self, name: str, type: Optional[str] + ) -> Optional[apimodelsv1.ProviderEndpoint]: + """ + Try to get an endpoint by name & type, + falling back to a "best effort" match by type. + """ + + dbendpoint = await self._db_reader.try_get_provider_endpoint_by_name_and_type(name, type) + if dbendpoint is None: + raise ProviderNotFoundError(f'Provider "{name}" not found') + + return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) + + async def add_provider_id_to_mux_rule( + self, rule: mux_models.MuxRule + ) -> mux_models.MuxRuleWithProviderId: + endpoint = await self._try_get_endpoint_by_name_and_type( + rule.provider_name, rule.provider_type + ) + return mux_models.MuxRuleWithProviderId( + model=rule.model, + matcher=rule.matcher, + matcher_type=rule.matcher_type, + provider_name=endpoint.name, + provider_type=endpoint.provider_type, + provider_id=endpoint.id, + ) + + async def add_provider_ids_to_mux_rule_list( + self, rules: List[mux_models.MuxRule] + ) -> List[mux_models.MuxRuleWithProviderId]: + rules_with_ids = [] + for rule in rules: + rules_with_ids.append(await self.add_provider_id_to_mux_rule(rule)) + return rules_with_ids + async def add_endpoint( self, endpoint: apimodelsv1.AddProviderEndpointRequest ) -> apimodelsv1.ProviderEndpoint: @@ -114,9 +152,9 @@ async def add_endpoint( for model in models: await self._db_writer.add_provider_model( - dbmodels.ProviderModel( - provider_endpoint_id=dbendpoint.id, + dbmodels.ProviderModelIntermediate( name=model, + provider_endpoint_id=dbendpoint.id, ) ) return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) @@ -236,9 +274,9 @@ async def _update_models_for_provider( # Add the models that are in the provider but not in the DB for model in models_set - models_in_db_set: await self._db_writer.add_provider_model( - dbmodels.ProviderModel( - provider_endpoint_id=dbendpoint.id, + dbmodels.ProviderModelIntermediate( name=model, + provider_endpoint_id=dbendpoint.id, ) ) @@ -274,8 +312,8 @@ async def models_by_provider(self, provider_id: UUID) -> List[apimodelsv1.ModelB outmodels.append( apimodelsv1.ModelByProvider( name=dbmodel.name, - provider_id=dbmodel.provider_endpoint_id, provider_name=dbendpoint.name, + provider_type=dbendpoint.provider_type, ) ) @@ -291,8 +329,8 @@ async def get_all_models(self) -> List[apimodelsv1.ModelByProvider]: outmodels.append( apimodelsv1.ModelByProvider( name=dbmodel.name, - provider_id=dbmodel.provider_endpoint_id, provider_name=ename, + provider_type=dbmodel.provider_endpoint_type, ) ) @@ -383,6 +421,8 @@ async def try_initialize_provider_endpoints( dbmodels.ProviderModel( provider_endpoint_id=provend.id, name=model, + provider_endpoint_type=provend.provider_type, + provider_endpoint_name=provend.name, ) ) ) @@ -393,7 +433,6 @@ async def try_initialize_provider_endpoints( async def try_update_to_provider( provcrud: ProviderCrud, prov: BaseProvider, dbprovend: dbmodels.ProviderEndpoint ): - authm = await provcrud._db_reader.get_auth_material_by_provider_id(str(dbprovend.id)) try: diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py index fbaf5b994..1dba3a871 100644 --- a/src/codegate/workspaces/crud.py +++ b/src/codegate/workspaces/crud.py @@ -2,11 +2,15 @@ from typing import List, Optional, Tuple from uuid import uuid4 as uuid +import structlog + from codegate.db import models as db_models from codegate.db.connection import AlreadyExistsError, DbReader, DbRecorder, DbTransaction from codegate.muxing import models as mux_models from codegate.muxing import rulematcher +logger = structlog.get_logger("codegate") + class WorkspaceCrudError(Exception): pass @@ -28,6 +32,10 @@ class WorkspaceMuxRuleDoesNotExistError(WorkspaceCrudError): pass +class DeleteMuxesFromRegistryError(WorkspaceCrudError): + pass + + DEFAULT_WORKSPACE_NAME = "default" # These are reserved keywords that cannot be used for workspaces @@ -43,7 +51,7 @@ async def add_workspace( self, new_workspace_name: str, custom_instructions: Optional[str] = None, - muxing_rules: Optional[List[mux_models.MuxRule]] = None, + muxing_rules: Optional[List[mux_models.MuxRuleWithProviderId]] = None, ) -> Tuple[db_models.WorkspaceRow, List[db_models.MuxRule]]: """ Add a workspace @@ -51,8 +59,8 @@ async def add_workspace( Args: new_workspace_name (str): The name of the workspace system_prompt (Optional[str]): The system prompt for the workspace - muxing_rules (Optional[List[mux_models.MuxRule]]): The muxing rules for the workspace - """ + muxing_rules (Optional[List[mux_models.MuxRuleWithProviderId]]): The muxing rules for the workspace + """ # noqa: E501 if new_workspace_name == "": raise WorkspaceCrudError("Workspace name cannot be empty.") if new_workspace_name in RESERVED_WORKSPACE_KEYWORDS: @@ -92,7 +100,7 @@ async def update_workspace( old_workspace_name: str, new_workspace_name: str, custom_instructions: Optional[str] = None, - muxing_rules: Optional[List[mux_models.MuxRule]] = None, + muxing_rules: Optional[List[mux_models.MuxRuleWithProviderId]] = None, ) -> Tuple[db_models.WorkspaceRow, List[db_models.MuxRule]]: """ Update a workspace @@ -101,8 +109,8 @@ async def update_workspace( old_workspace_name (str): The old name of the workspace new_workspace_name (str): The new name of the workspace system_prompt (Optional[str]): The system prompt for the workspace - muxing_rules (Optional[List[mux_models.MuxRule]]): The muxing rules for the workspace - """ + muxing_rules (Optional[List[mux_models.MuxRuleWithProviderId]]): The muxing rules for the workspace + """ # noqa: E501 if new_workspace_name == "": raise WorkspaceCrudError("Workspace name cannot be empty.") if old_workspace_name == "": @@ -111,8 +119,6 @@ async def update_workspace( raise WorkspaceCrudError("Cannot rename default workspace.") if new_workspace_name in RESERVED_WORKSPACE_KEYWORDS: raise WorkspaceCrudError(f"Workspace name {new_workspace_name} is reserved.") - if old_workspace_name == new_workspace_name: - raise WorkspaceCrudError("Old and new workspace names are the same.") async with DbTransaction() as transaction: try: @@ -122,11 +128,12 @@ async def update_workspace( f"Workspace {old_workspace_name} does not exist." ) - existing_ws = await self._db_reader.get_workspace_by_name(new_workspace_name) - if existing_ws: - raise WorkspaceNameAlreadyInUseError( - f"Workspace name {new_workspace_name} is already in use." - ) + if old_workspace_name != new_workspace_name: + existing_ws = await self._db_reader.get_workspace_by_name(new_workspace_name) + if existing_ws: + raise WorkspaceNameAlreadyInUseError( + f"Workspace name {new_workspace_name} is already in use." + ) new_ws = db_models.WorkspaceRow( id=ws.id, name=new_workspace_name, custom_instructions=ws.custom_instructions @@ -143,7 +150,7 @@ async def update_workspace( await transaction.commit() return workspace_renamed, mux_rules - except (WorkspaceNameAlreadyInUseError, WorkspaceDoesNotExistError) as e: + except (WorkspaceDoesNotExistError, WorkspaceNameAlreadyInUseError) as e: raise e except Exception as e: raise WorkspaceCrudError(f"Error updating workspace {old_workspace_name}: {str(e)}") @@ -234,6 +241,7 @@ async def soft_delete_workspace(self, workspace_name: str): """ Soft delete a workspace """ + if workspace_name == "": raise WorkspaceCrudError("Workspace name cannot be empty.") if workspace_name == DEFAULT_WORKSPACE_NAME: @@ -281,14 +289,16 @@ async def get_workspace_by_name(self, workspace_name: str) -> db_models.Workspac raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.") return workspace - async def workspaces_by_provider(self, provider_id: uuid) -> List[db_models.WorkspaceWithModel]: + async def workspaces_by_provider( + self, provider_id: uuid + ) -> List[db_models.WorkspaceWithSessionInfo]: """Get the workspaces by provider.""" workspaces = await self._db_reader.get_workspaces_by_provider(str(provider_id)) return workspaces - async def get_muxes(self, workspace_name: str) -> List[mux_models.MuxRule]: + async def get_muxes(self, workspace_name: str) -> List[db_models.MuxRule]: # Verify if workspace exists workspace = await self._db_reader.get_workspace_by_name(workspace_name) if not workspace: @@ -296,22 +306,10 @@ async def get_muxes(self, workspace_name: str) -> List[mux_models.MuxRule]: dbmuxes = await self._db_reader.get_muxes_by_workspace(workspace.id) - muxes = [] - # These are already sorted by priority - for dbmux in dbmuxes: - muxes.append( - mux_models.MuxRule( - provider_id=dbmux.provider_endpoint_id, - model=dbmux.provider_model_name, - matcher_type=dbmux.matcher_type, - matcher=dbmux.matcher_blob, - ) - ) - - return muxes + return dbmuxes async def set_muxes( - self, workspace_name: str, muxes: List[mux_models.MuxRule] + self, workspace_name: str, muxes: List[mux_models.MuxRuleWithProviderId] ) -> List[db_models.MuxRule]: # Verify if workspace exists workspace = await self._db_reader.get_workspace_by_name(workspace_name) @@ -324,7 +322,9 @@ async def set_muxes( # Add the new muxes priority = 0 - muxes_with_routes: List[Tuple[mux_models.MuxRule, rulematcher.ModelRoute]] = [] + muxes_with_routes: List[Tuple[mux_models.MuxRuleWithProviderId, rulematcher.ModelRoute]] = ( + [] + ) # Verify all models are valid for mux in muxes: @@ -347,7 +347,8 @@ async def set_muxes( dbmux = await self._db_recorder.add_mux(new_mux) dbmuxes.append(dbmux) - matchers.append(rulematcher.MuxingMatcherFactory.create(dbmux, route)) + provider = await self._db_reader.get_provider_endpoint_by_id(mux.provider_id) + matchers.append(rulematcher.MuxingMatcherFactory.create(dbmux, provider, route)) priority += 1 @@ -357,7 +358,9 @@ async def set_muxes( return dbmuxes - async def get_routing_for_mux(self, mux: mux_models.MuxRule) -> rulematcher.ModelRoute: + async def get_routing_for_mux( + self, mux: mux_models.MuxRuleWithProviderId + ) -> rulematcher.ModelRoute: """Get the routing for a mux Note that this particular mux object is the API model, not the database model. @@ -365,7 +368,7 @@ async def get_routing_for_mux(self, mux: mux_models.MuxRule) -> rulematcher.Mode """ dbprov = await self._db_reader.get_provider_endpoint_by_id(mux.provider_id) if not dbprov: - raise WorkspaceCrudError(f"Provider {mux.provider_id} does not exist") + raise WorkspaceCrudError(f'Provider "{mux.provider_name}" does not exist') dbm = await self._db_reader.get_provider_model_by_provider_id_and_name( mux.provider_id, @@ -373,11 +376,13 @@ async def get_routing_for_mux(self, mux: mux_models.MuxRule) -> rulematcher.Mode ) if not dbm: raise WorkspaceCrudError( - f"Model {mux.model} does not exist for provider {mux.provider_id}" + f'Model "{mux.model}" does not exist for provider "{mux.provider_name}"' ) dbauth = await self._db_reader.get_auth_material_by_provider_id(mux.provider_id) if not dbauth: - raise WorkspaceCrudError(f"Auth material for provider {mux.provider_id} does not exist") + raise WorkspaceCrudError( + f'Auth material for provider "{mux.provider_name}" does not exist' + ) return rulematcher.ModelRoute( endpoint=dbprov, @@ -393,7 +398,7 @@ async def get_routing_for_db_mux(self, mux: db_models.MuxRule) -> rulematcher.Mo """ dbprov = await self._db_reader.get_provider_endpoint_by_id(mux.provider_endpoint_id) if not dbprov: - raise WorkspaceCrudError(f"Provider {mux.provider_endpoint_id} does not exist") + raise WorkspaceCrudError(f'Provider "{mux.provider_endpoint_name}" does not exist') dbm = await self._db_reader.get_provider_model_by_provider_id_and_name( mux.provider_endpoint_id, @@ -407,7 +412,7 @@ async def get_routing_for_db_mux(self, mux: db_models.MuxRule) -> rulematcher.Mo dbauth = await self._db_reader.get_auth_material_by_provider_id(mux.provider_endpoint_id) if not dbauth: raise WorkspaceCrudError( - f"Auth material for provider {mux.provider_endpoint_id} does not exist" + f'Auth material for provider "{mux.provider_endpoint_name}" does not exist' ) return rulematcher.ModelRoute( @@ -448,7 +453,10 @@ async def repopulate_mux_cache(self) -> None: matchers: List[rulematcher.MuxingRuleMatcher] = [] for mux in muxes: + provider = await self._db_reader.get_provider_endpoint_by_id( + mux.provider_endpoint_id + ) route = await self.get_routing_for_db_mux(mux) - matchers.append(rulematcher.MuxingMatcherFactory.create(mux, route)) + matchers.append(rulematcher.MuxingMatcherFactory.create(mux, provider, route)) await mux_registry.set_ws_rules(ws.name, matchers) diff --git a/tests/api/test_v1_providers.py b/tests/api/test_v1_providers.py new file mode 100644 index 000000000..fc0ef6ace --- /dev/null +++ b/tests/api/test_v1_providers.py @@ -0,0 +1,535 @@ +from pathlib import Path +from unittest.mock import MagicMock, patch +from uuid import uuid4 as uuid + +import httpx +import pytest +import structlog +from httpx import AsyncClient + +from codegate.db import connection +from codegate.pipeline.factory import PipelineFactory +from codegate.providers.crud.crud import ProviderCrud +from codegate.server import init_app +from codegate.workspaces.crud import WorkspaceCrud + +logger = structlog.get_logger("codegate") + +# TODO: Abstract the mock DB setup + + +@pytest.fixture +def db_path(): + """Creates a temporary database file path.""" + current_test_dir = Path(__file__).parent + db_filepath = current_test_dir / f"codegate_test_{uuid()}.db" + db_fullpath = db_filepath.absolute() + connection.init_db_sync(str(db_fullpath)) + yield db_fullpath + if db_fullpath.is_file(): + db_fullpath.unlink() + + +@pytest.fixture() +def db_recorder(db_path) -> connection.DbRecorder: + """Creates a DbRecorder instance with test database.""" + return connection.DbRecorder(sqlite_path=db_path, _no_singleton=True) + + +@pytest.fixture() +def db_reader(db_path) -> connection.DbReader: + """Creates a DbReader instance with test database.""" + return connection.DbReader(sqlite_path=db_path, _no_singleton=True) + + +@pytest.fixture() +def mock_workspace_crud(db_recorder, db_reader) -> WorkspaceCrud: + """Creates a WorkspaceCrud instance with test database.""" + ws_crud = WorkspaceCrud() + ws_crud._db_reader = db_reader + ws_crud._db_recorder = db_recorder + return ws_crud + + +@pytest.fixture() +def mock_provider_crud(db_recorder, db_reader, mock_workspace_crud) -> ProviderCrud: + """Creates a ProviderCrud instance with test database.""" + p_crud = ProviderCrud() + p_crud._db_reader = db_reader + p_crud._db_writer = db_recorder + p_crud._ws_crud = mock_workspace_crud + return p_crud + + +@pytest.fixture +def mock_pipeline_factory(): + """Create a mock pipeline factory.""" + mock_factory = MagicMock(spec=PipelineFactory) + mock_factory.create_input_pipeline.return_value = MagicMock() + mock_factory.create_fim_pipeline.return_value = MagicMock() + mock_factory.create_output_pipeline.return_value = MagicMock() + mock_factory.create_fim_output_pipeline.return_value = MagicMock() + return mock_factory + + +@pytest.mark.asyncio +async def test_providers_crud( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + patch( + "codegate.providers.openrouter.provider.OpenRouterProvider.models", + return_value=["anthropic/claude-2", "deepseek/deepseek-r1"], + ), + ): + """Test creating multiple providers and listing them.""" + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + # Create first provider (OpenAI) + provider_payload_1 = { + "name": "openai-provider", + "description": "OpenAI provider description", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.openai.com", + "api_key": "sk-proj-foo-bar-123-xyz", + } + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_1) + assert response.status_code == 201 + provider1_response = response.json() + assert provider1_response["name"] == provider_payload_1["name"] + assert provider1_response["description"] == provider_payload_1["description"] + assert provider1_response["auth_type"] == provider_payload_1["auth_type"] + assert provider1_response["provider_type"] == provider_payload_1["provider_type"] + assert provider1_response["endpoint"] == provider_payload_1["endpoint"] + assert isinstance(provider1_response.get("id", ""), str) and provider1_response["id"] + + # Create second provider (OpenRouter) + provider_payload_2 = { + "name": "openrouter-provider", + "description": "OpenRouter provider description", + "auth_type": "none", + "provider_type": "openrouter", + "endpoint": "https://openrouter.ai/api", + "api_key": "sk-or-foo-bar-456-xyz", + } + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_2) + assert response.status_code == 201 + provider2_response = response.json() + assert provider2_response["name"] == provider_payload_2["name"] + assert provider2_response["description"] == provider_payload_2["description"] + assert provider2_response["auth_type"] == provider_payload_2["auth_type"] + assert provider2_response["provider_type"] == provider_payload_2["provider_type"] + assert provider2_response["endpoint"] == provider_payload_2["endpoint"] + assert isinstance(provider2_response.get("id", ""), str) and provider2_response["id"] + + # List all providers + response = await ac.get("/api/v1/provider-endpoints") + assert response.status_code == 200 + providers = response.json() + + # Verify both providers exist in the list + assert isinstance(providers, list) + assert len(providers) == 2 + + # Verify fields for first provider + provider1 = next(p for p in providers if p["name"] == "openai-provider") + assert provider1["description"] == provider_payload_1["description"] + assert provider1["auth_type"] == provider_payload_1["auth_type"] + assert provider1["provider_type"] == provider_payload_1["provider_type"] + assert provider1["endpoint"] == provider_payload_1["endpoint"] + assert isinstance(provider1.get("id", ""), str) and provider1["id"] + + # Verify fields for second provider + provider2 = next(p for p in providers if p["name"] == "openrouter-provider") + assert provider2["description"] == provider_payload_2["description"] + assert provider2["auth_type"] == provider_payload_2["auth_type"] + assert provider2["provider_type"] == provider_payload_2["provider_type"] + assert provider2["endpoint"] == provider_payload_2["endpoint"] + assert isinstance(provider2.get("id", ""), str) and provider2["id"] + + # Get OpenAI provider by name + response = await ac.get("/api/v1/provider-endpoints/openai-provider") + assert response.status_code == 200 + provider = response.json() + assert provider["name"] == provider_payload_1["name"] + assert provider["description"] == provider_payload_1["description"] + assert provider["auth_type"] == provider_payload_1["auth_type"] + assert provider["provider_type"] == provider_payload_1["provider_type"] + assert provider["endpoint"] == provider_payload_1["endpoint"] + assert isinstance(provider["id"], str) and provider["id"] + + # Get OpenRouter provider by name + response = await ac.get("/api/v1/provider-endpoints/openrouter-provider") + assert response.status_code == 200 + provider = response.json() + assert provider["name"] == provider_payload_2["name"] + assert provider["description"] == provider_payload_2["description"] + assert provider["auth_type"] == provider_payload_2["auth_type"] + assert provider["provider_type"] == provider_payload_2["provider_type"] + assert provider["endpoint"] == provider_payload_2["endpoint"] + assert isinstance(provider["id"], str) and provider["id"] + + # Test getting non-existent provider + response = await ac.get("/api/v1/provider-endpoints/non-existent") + assert response.status_code == 404 + assert response.json()["detail"] == "Provider endpoint not found" + + # Test deleting providers + response = await ac.delete("/api/v1/provider-endpoints/openai-provider") + assert response.status_code == 204 + + # Verify provider was deleted by trying to get it + response = await ac.get("/api/v1/provider-endpoints/openai-provider") + assert response.status_code == 404 + assert response.json()["detail"] == "Provider endpoint not found" + + # Delete second provider + response = await ac.delete("/api/v1/provider-endpoints/openrouter-provider") + assert response.status_code == 204 + + # Verify second provider was deleted + response = await ac.get("/api/v1/provider-endpoints/openrouter-provider") + assert response.status_code == 404 + assert response.json()["detail"] == "Provider endpoint not found" + + # Test deleting non-existent provider + response = await ac.delete("/api/v1/provider-endpoints/non-existent") + assert response.status_code == 404 + assert response.json()["detail"] == "Provider endpoint not found" + + # Verify providers list is empty + response = await ac.get("/api/v1/provider-endpoints") + assert response.status_code == 200 + providers = response.json() + assert len(providers) == 0 + + +@pytest.mark.asyncio +async def test_update_provider_endpoint( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + ): + """Test updating a provider endpoint.""" + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + # Create initial provider + provider_payload = { + "name": "test-provider", + "description": "Initial description", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.initial.com", + "api_key": "initial-key", + } + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload) + assert response.status_code == 201 + initial_provider = response.json() + + # Update the provider + updated_payload = { + "name": "test-provider-updated", + "description": "Updated description", + "auth_type": "api_key", + "provider_type": "openai", + "endpoint": "https://api.updated.com", + "api_key": "updated-key", + } + + response = await ac.put( + "/api/v1/provider-endpoints/test-provider", json=updated_payload + ) + assert response.status_code == 200 + updated_provider = response.json() + + # Verify fields were updated + assert updated_provider["name"] == updated_payload["name"] + assert updated_provider["description"] == updated_payload["description"] + assert updated_provider["auth_type"] == updated_payload["auth_type"] + assert updated_provider["provider_type"] == updated_payload["provider_type"] + assert updated_provider["endpoint"] == updated_payload["endpoint"] + assert updated_provider["id"] == initial_provider["id"] + + # Get OpenRouter provider by name + response = await ac.get("/api/v1/provider-endpoints/test-provider-updated") + assert response.status_code == 200 + provider = response.json() + assert provider["name"] == updated_payload["name"] + assert provider["description"] == updated_payload["description"] + assert provider["auth_type"] == updated_payload["auth_type"] + assert provider["provider_type"] == updated_payload["provider_type"] + assert provider["endpoint"] == updated_payload["endpoint"] + assert isinstance(provider["id"], str) and provider["id"] + + # Test updating non-existent provider + response = await ac.put( + "/api/v1/provider-endpoints/fake-provider", json=updated_payload + ) + assert response.status_code == 404 + assert response.json()["detail"] == "Provider endpoint not found" + + +@pytest.mark.asyncio +async def test_list_providers_by_name( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + patch( + "codegate.providers.openrouter.provider.OpenRouterProvider.models", + return_value=["anthropic/claude-2", "deepseek/deepseek-r1"], + ), + ): + """Test creating multiple providers and listing them by name.""" + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + # Create first provider (OpenAI) + provider_payload_1 = { + "name": "openai-provider", + "description": "OpenAI provider description", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.openai.com", + "api_key": "sk-proj-foo-bar-123-xyz", + } + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_1) + assert response.status_code == 201 + + # Create second provider (OpenRouter) + provider_payload_2 = { + "name": "openrouter-provider", + "description": "OpenRouter provider description", + "auth_type": "none", + "provider_type": "openrouter", + "endpoint": "https://openrouter.ai/api", + "api_key": "sk-or-foo-bar-456-xyz", + } + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_2) + assert response.status_code == 201 + + # Test querying providers by name + response = await ac.get("/api/v1/provider-endpoints?name=openai-provider") + assert response.status_code == 200 + providers = response.json() + assert len(providers) == 1 + assert providers[0]["name"] == "openai-provider" + assert isinstance(providers[0]["id"], str) and providers[0]["id"] + + response = await ac.get("/api/v1/provider-endpoints?name=openrouter-provider") + assert response.status_code == 200 + providers = response.json() + assert len(providers) == 1 + assert providers[0]["name"] == "openrouter-provider" + assert isinstance(providers[0]["id"], str) and providers[0]["id"] + + +@pytest.mark.asyncio +async def test_list_all_provider_models( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + patch( + "codegate.providers.openrouter.provider.OpenRouterProvider.models", + return_value=["anthropic/claude-2", "deepseek/deepseek-r1"], + ), + ): + """Test listing all models from all providers.""" + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + # Create OpenAI provider + provider_payload_1 = { + "name": "openai-provider", + "description": "OpenAI provider description", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.openai.com", + "api_key": "sk-proj-foo-bar-123-xyz", + } + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_1) + assert response.status_code == 201 + + # Create OpenRouter provider + provider_payload_2 = { + "name": "openrouter-provider", + "description": "OpenRouter provider description", + "auth_type": "none", + "provider_type": "openrouter", + "endpoint": "https://openrouter.ai/api", + "api_key": "sk-or-foo-bar-456-xyz", + } + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_2) + assert response.status_code == 201 + + # Get all models + response = await ac.get("/api/v1/provider-endpoints/models") + assert response.status_code == 200 + models = response.json() + + # Verify response structure and content + assert isinstance(models, list) + assert len(models) == 4 + + # Verify models list structure + assert all(isinstance(model, dict) for model in models) + assert all("name" in model for model in models) + assert all("provider_type" in model for model in models) + assert all("provider_name" in model for model in models) + + # Verify OpenAI provider models + openai_models = [m for m in models if m["provider_name"] == "openai-provider"] + assert len(openai_models) == 2 + assert all(m["provider_type"] == "openai" for m in openai_models) + + # Verify OpenRouter provider models + openrouter_models = [m for m in models if m["provider_name"] == "openrouter-provider"] + assert len(openrouter_models) == 2 + assert all(m["provider_type"] == "openrouter" for m in openrouter_models) + + +@pytest.mark.asyncio +async def test_list_models_by_provider( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + patch( + "codegate.providers.openrouter.provider.OpenRouterProvider.models", + return_value=["anthropic/claude-2", "deepseek/deepseek-r1"], + ), + ): + """Test listing models for a specific provider.""" + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + # Create OpenAI provider + provider_payload = { + "name": "openai-provider", + "description": "OpenAI provider description", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.openai.com", + "api_key": "sk-proj-foo-bar-123-xyz", + } + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload) + assert response.status_code == 201 + provider = response.json() + provider_name = provider["name"] + + # Get models for the provider + response = await ac.get(f"/api/v1/provider-endpoints/{provider_name}/models") + assert response.status_code == 200 + models = response.json() + + # Verify response structure and content + assert isinstance(models, list) + assert len(models) == 2 + assert all(isinstance(model, dict) for model in models) + assert all("name" in model for model in models) + assert all("provider_type" in model for model in models) + assert all("provider_name" in model for model in models) + assert all(model["provider_type"] == "openai" for model in models) + assert all(model["provider_name"] == "openai-provider" for model in models) + + # Test with non-existent provider ID + fake_name = "foo-bar" + response = await ac.get(f"/api/v1/provider-endpoints/{fake_name}/models") + assert response.status_code == 404 + assert response.json()["detail"] == "Provider not found" + + +@pytest.mark.asyncio +async def test_configure_auth_material( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + ): + """Test configuring auth material for a provider.""" + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + # Create provider + provider_payload = { + "name": "test-provider", + "description": "Test provider", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.test.com", + "api_key": "test-key", + } + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload) + assert response.status_code == 201 + + # Configure auth material + auth_material = {"api_key": "sk-proj-foo-bar-123-xyz", "auth_type": "api_key"} + + response = await ac.put( + "/api/v1/provider-endpoints/test-provider/auth-material", json=auth_material + ) + assert response.status_code == 204 + + # Test with non-existent provider + response = await ac.put( + "/api/v1/provider-endpoints/fake-provider/auth-material", json=auth_material + ) + assert response.status_code == 404 + assert response.json()["detail"] == "Provider endpoint not found" diff --git a/tests/api/test_v1_workspaces.py b/tests/api/test_v1_workspaces.py index 8bfcbfaf3..24db9f238 100644 --- a/tests/api/test_v1_workspaces.py +++ b/tests/api/test_v1_workspaces.py @@ -1,5 +1,5 @@ from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch from uuid import uuid4 as uuid import httpx @@ -8,6 +8,7 @@ from httpx import AsyncClient from codegate.db import connection +from codegate.muxing.rulematcher import MuxingRulesinWorkspaces from codegate.pipeline.factory import PipelineFactory from codegate.providers.crud.crud import ProviderCrud from codegate.server import init_app @@ -70,67 +71,250 @@ def mock_pipeline_factory(): return mock_factory +@pytest.fixture +def mock_muxing_rules_registry(): + """Creates a mock for the muxing rules registry.""" + mock_registry = AsyncMock(spec=MuxingRulesinWorkspaces) + return mock_registry + + @pytest.mark.asyncio -async def test_create_update_workspace_happy_path( +async def test_workspace_crud_name_only( mock_pipeline_factory, mock_workspace_crud, mock_provider_crud ) -> None: with ( patch("codegate.api.v1.wscrud", mock_workspace_crud), patch("codegate.api.v1.pcrud", mock_provider_crud), + ): + """Test creating and deleting a workspace by name only.""" + + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + name: str = str(uuid()) + + # Create workspace + payload_create = {"name": name} + response = await ac.post("/api/v1/workspaces", json=payload_create) + assert response.status_code == 201 + + # Verify workspace exists + response = await ac.get(f"/api/v1/workspaces/{name}") + assert response.status_code == 200 + assert response.json()["name"] == name + + # Delete workspace + response = await ac.delete(f"/api/v1/workspaces/{name}") + assert response.status_code == 204 + + # Verify workspace no longer exists + response = await ac.get(f"/api/v1/workspaces/{name}") + assert response.status_code == 404 + + +@pytest.mark.asyncio +async def test_muxes_crud( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud, db_reader +) -> None: + with ( + patch("codegate.api.v1.dbreader", db_reader), + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), patch( "codegate.providers.openai.provider.OpenAIProvider.models", - return_value=["foo-bar-001", "foo-bar-002"], + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + patch( + "codegate.providers.openrouter.provider.OpenRouterProvider.models", + return_value=["anthropic/claude-2", "deepseek/deepseek-r1"], ), ): - """Test creating & updating a workspace (happy path).""" + """Test creating and validating mux rules on a workspace.""" app = init_app(mock_pipeline_factory) provider_payload_1 = { - "name": "foo", - "description": "", + "name": "openai-provider", + "description": "OpenAI provider description", "auth_type": "none", "provider_type": "openai", "endpoint": "https://api.openai.com", - "api_key": "sk-proj-foo-bar-123-xzy", + "api_key": "sk-proj-foo-bar-123-xyz", } provider_payload_2 = { - "name": "bar", - "description": "", + "name": "openrouter-provider", + "description": "OpenRouter provider description", + "auth_type": "none", + "provider_type": "openrouter", + "endpoint": "https://openrouter.ai/api", + "api_key": "sk-or-foo-bar-456-xyz", + } + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_1) + assert response.status_code == 201 + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_2) + assert response.status_code == 201 + + # Create workspace + workspace_name: str = str(uuid()) + custom_instructions: str = "Respond to every request in iambic pentameter" + payload_create = { + "name": workspace_name, + "config": { + "custom_instructions": custom_instructions, + "muxing_rules": [], + }, + } + + response = await ac.post("/api/v1/workspaces", json=payload_create) + assert response.status_code == 201 + + # Set mux rules + muxing_rules = [ + { + "provider_name": "openai-provider", + "provider_type": "openai", + "model": "gpt-4", + "matcher": "*.ts", + "matcher_type": "filename_match", + }, + { + "provider_name": "openrouter-provider", + "provider_type": "openrouter", + "model": "anthropic/claude-2", + "matcher_type": "catch_all", + "matcher": "", + }, + ] + + response = await ac.put(f"/api/v1/workspaces/{workspace_name}/muxes", json=muxing_rules) + assert response.status_code == 204 + + # Verify mux rules + response = await ac.get(f"/api/v1/workspaces/{workspace_name}") + assert response.status_code == 200 + response_body = response.json() + for i, rule in enumerate(response_body["config"]["muxing_rules"]): + assert rule["provider_name"] == muxing_rules[i]["provider_name"] + assert rule["provider_type"] == muxing_rules[i]["provider_type"] + assert rule["model"] == muxing_rules[i]["model"] + assert rule["matcher"] == muxing_rules[i]["matcher"] + assert rule["matcher_type"] == muxing_rules[i]["matcher_type"] + + +@pytest.mark.asyncio +async def test_create_workspace_and_add_custom_instructions( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + ): + """Test creating a workspace, adding custom + instructions, and validating them.""" + + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + name: str = str(uuid()) + + # Create workspace + payload_create = {"name": name} + response = await ac.post("/api/v1/workspaces", json=payload_create) + assert response.status_code == 201 + + # Add custom instructions + custom_instructions = "Respond to every request in iambic pentameter" + payload_instructions = {"prompt": custom_instructions} + response = await ac.put( + f"/api/v1/workspaces/{name}/custom-instructions", json=payload_instructions + ) + assert response.status_code == 204 + + # Validate custom instructions by getting the workspace + response = await ac.get(f"/api/v1/workspaces/{name}") + assert response.status_code == 200 + assert response.json()["config"]["custom_instructions"] == custom_instructions + + # Validate custom instructions by getting the custom instructions endpoint + response = await ac.get(f"/api/v1/workspaces/{name}/custom-instructions") + assert response.status_code == 200 + assert response.json()["prompt"] == custom_instructions + + +@pytest.mark.asyncio +async def test_workspace_crud_full_workspace( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud, db_reader +) -> None: + with ( + patch("codegate.api.v1.dbreader", db_reader), + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + patch( + "codegate.providers.openrouter.provider.OpenRouterProvider.models", + return_value=["anthropic/claude-2", "deepseek/deepseek-r1"], + ), + ): + """Test creating, updating and reading a workspace.""" + + app = init_app(mock_pipeline_factory) + + provider_payload_1 = { + "name": "openai-provider", + "description": "OpenAI provider description", "auth_type": "none", "provider_type": "openai", "endpoint": "https://api.openai.com", - "api_key": "sk-proj-foo-bar-123-xzy", + "api_key": "sk-proj-foo-bar-123-xyz", + } + + provider_payload_2 = { + "name": "openrouter-provider", + "description": "OpenRouter provider description", + "auth_type": "none", + "provider_type": "openrouter", + "endpoint": "https://openrouter.ai/api", + "api_key": "sk-or-foo-bar-456-xyz", } async with AsyncClient( transport=httpx.ASGITransport(app=app), base_url="http://test" ) as ac: - # Create the first provider response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_1) assert response.status_code == 201 - provider_1 = response.json() - # Create the second provider response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_2) assert response.status_code == 201 - provider_2 = response.json() + + # Create workspace name_1: str = str(uuid()) custom_instructions_1: str = "Respond to every request in iambic pentameter" muxing_rules_1 = [ { - "provider_name": None, # optional & not implemented yet - "provider_id": provider_1["id"], - "model": "foo-bar-001", + "provider_name": "openai-provider", + "provider_type": "openai", + "model": "gpt-4", "matcher": "*.ts", "matcher_type": "filename_match", }, { - "provider_name": None, # optional & not implemented yet - "provider_id": provider_2["id"], - "model": "foo-bar-002", + "provider_name": "openai-provider", + "provider_type": "openai", + "model": "gpt-3.5-turbo", "matcher_type": "catch_all", "matcher": "", }, @@ -146,11 +330,17 @@ async def test_create_update_workspace_happy_path( response = await ac.post("/api/v1/workspaces", json=payload_create) assert response.status_code == 201 + + # Verify created workspace + response = await ac.get(f"/api/v1/workspaces/{name_1}") + assert response.status_code == 200 response_body = response.json() assert response_body["name"] == name_1 assert response_body["config"]["custom_instructions"] == custom_instructions_1 for i, rule in enumerate(response_body["config"]["muxing_rules"]): + assert rule["provider_name"] == muxing_rules_1[i]["provider_name"] + assert rule["provider_type"] == muxing_rules_1[i]["provider_type"] assert rule["model"] == muxing_rules_1[i]["model"] assert rule["matcher"] == muxing_rules_1[i]["matcher"] assert rule["matcher_type"] == muxing_rules_1[i]["matcher_type"] @@ -159,16 +349,16 @@ async def test_create_update_workspace_happy_path( custom_instructions_2: str = "Respond to every request in cockney rhyming slang" muxing_rules_2 = [ { - "provider_name": None, # optional & not implemented yet - "provider_id": provider_2["id"], - "model": "foo-bar-002", + "provider_name": "openrouter-provider", + "provider_type": "openrouter", + "model": "anthropic/claude-2", "matcher": "*.ts", "matcher_type": "filename_match", }, { - "provider_name": None, # optional & not implemented yet - "provider_id": provider_1["id"], - "model": "foo-bar-001", + "provider_name": "openrouter-provider", + "provider_type": "openrouter", + "model": "deepseek/deepseek-r1", "matcher_type": "catch_all", "matcher": "", }, @@ -183,46 +373,249 @@ async def test_create_update_workspace_happy_path( } response = await ac.put(f"/api/v1/workspaces/{name_1}", json=payload_update) - assert response.status_code == 201 + assert response.status_code == 200 + + # Verify updated workspace + response = await ac.get(f"/api/v1/workspaces/{name_2}") + assert response.status_code == 200 response_body = response.json() assert response_body["name"] == name_2 assert response_body["config"]["custom_instructions"] == custom_instructions_2 for i, rule in enumerate(response_body["config"]["muxing_rules"]): + assert rule["provider_name"] == muxing_rules_2[i]["provider_name"] + assert rule["provider_type"] == muxing_rules_2[i]["provider_type"] assert rule["model"] == muxing_rules_2[i]["model"] assert rule["matcher"] == muxing_rules_2[i]["matcher"] assert rule["matcher_type"] == muxing_rules_2[i]["matcher_type"] @pytest.mark.asyncio -async def test_create_update_workspace_name_only( - mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +async def test_create_workspace_with_mux_different_provider_name( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud, db_reader ) -> None: with ( + patch("codegate.api.v1.dbreader", db_reader), patch("codegate.api.v1.wscrud", mock_workspace_crud), patch("codegate.api.v1.pcrud", mock_provider_crud), patch( "codegate.providers.openai.provider.OpenAIProvider.models", - return_value=["foo-bar-001", "foo-bar-002"], + return_value=["gpt-4", "gpt-3.5-turbo"], ), ): - """Test creating & updating a workspace (happy path).""" + """ + Test creating a workspace with mux rules, then recreating it after + renaming the provider. + """ + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + # Create initial provider + provider_payload = { + "name": "test-provider-1", + "description": "Test provider", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.test.com", + "api_key": "test-key", + } + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload) + assert response.status_code == 201 + + # Create workspace with mux rules + workspace_name = str(uuid()) + muxing_rules = [ + { + "provider_name": "test-provider-1", + "provider_type": "openai", + "model": "gpt-4", + "matcher": "*.ts", + "matcher_type": "filename_match", + }, + { + "provider_name": "test-provider-1", + "provider_type": "openai", + "model": "gpt-3.5-turbo", + "matcher_type": "catch_all", + "matcher": "", + }, + ] + + workspace_payload = { + "name": workspace_name, + "config": { + "custom_instructions": "Test instructions", + "muxing_rules": muxing_rules, + }, + } + + response = await ac.post("/api/v1/workspaces", json=workspace_payload) + assert response.status_code == 201 + + # Get workspace config as JSON blob + response = await ac.get(f"/api/v1/workspaces/{workspace_name}") + assert response.status_code == 200 + workspace_blob = response.json() + + # Delete workspace + response = await ac.delete(f"/api/v1/workspaces/{workspace_name}") + assert response.status_code == 204 + response = await ac.delete(f"/api/v1/workspaces/archive/{workspace_name}") + assert response.status_code == 204 + + # Verify workspace is deleted + response = await ac.get(f"/api/v1/workspaces/{workspace_name}") + assert response.status_code == 404 + + # Update provider name + rename_provider_payload = { + "name": "test-provider-2", + "description": "Test provider", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.test.com", + "api_key": "test-key", + } + + response = await ac.put( + "/api/v1/provider-endpoints/test-provider-1", json=rename_provider_payload + ) + assert response.status_code == 200 + + # Verify old provider name no longer exists + response = await ac.get("/api/v1/provider-endpoints/test-provider-1") + assert response.status_code == 404 + + # Verify provider exists under new name + response = await ac.get("/api/v1/provider-endpoints/test-provider-2") + assert response.status_code == 200 + provider = response.json() + assert provider["name"] == "test-provider-2" + assert provider["description"] == "Test provider" + assert provider["auth_type"] == "none" + assert provider["provider_type"] == "openai" + assert provider["endpoint"] == "https://api.test.com" + + # re-upload the workspace that we have previously downloaded + + response = await ac.post("/api/v1/workspaces", json=workspace_blob) + assert response.status_code == 201 + + # Verify new workspace config + response = await ac.get(f"/api/v1/workspaces/{workspace_name}") + assert response.status_code == 200 + new_workspace = response.json() + + assert new_workspace["name"] == workspace_name + assert ( + new_workspace["config"]["custom_instructions"] + == workspace_blob["config"]["custom_instructions"] + ) + + # Verify muxing rules are correct with updated provider name + for i, rule in enumerate(new_workspace["config"]["muxing_rules"]): + assert rule["provider_name"] == "test-provider-2" + assert ( + rule["provider_type"] + == workspace_blob["config"]["muxing_rules"][i]["provider_type"] + ) + assert rule["model"] == workspace_blob["config"]["muxing_rules"][i]["model"] + assert rule["matcher"] == workspace_blob["config"]["muxing_rules"][i]["matcher"] + assert ( + rule["matcher_type"] + == workspace_blob["config"]["muxing_rules"][i]["matcher_type"] + ) + + +@pytest.mark.asyncio +async def test_rename_workspace( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud, db_reader +) -> None: + with ( + patch("codegate.api.v1.dbreader", db_reader), + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + patch( + "codegate.providers.openrouter.provider.OpenRouterProvider.models", + return_value=["anthropic/claude-2", "deepseek/deepseek-r1"], + ), + ): + """Test renaming a workspace.""" app = init_app(mock_pipeline_factory) + provider_payload_1 = { + "name": "openai-provider", + "description": "OpenAI provider description", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.openai.com", + "api_key": "sk-proj-foo-bar-123-xyz", + } + + provider_payload_2 = { + "name": "openrouter-provider", + "description": "OpenRouter provider description", + "auth_type": "none", + "provider_type": "openrouter", + "endpoint": "https://openrouter.ai/api", + "api_key": "sk-or-foo-bar-456-xyz", + } + async with AsyncClient( transport=httpx.ASGITransport(app=app), base_url="http://test" ) as ac: + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_1) + assert response.status_code == 201 + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_2) + assert response.status_code == 201 + + # Create workspace + name_1: str = str(uuid()) + custom_instructions: str = "Respond to every request in iambic pentameter" + muxing_rules = [ + { + "provider_name": "openai-provider", + "provider_type": "openai", + "model": "gpt-4", + "matcher": "*.ts", + "matcher_type": "filename_match", + }, + { + "provider_name": "openai-provider", + "provider_type": "openai", + "model": "gpt-3.5-turbo", + "matcher_type": "catch_all", + "matcher": "", + }, + ] payload_create = { "name": name_1, + "config": { + "custom_instructions": custom_instructions, + "muxing_rules": muxing_rules, + }, } response = await ac.post("/api/v1/workspaces", json=payload_create) assert response.status_code == 201 response_body = response.json() + assert response_body["name"] == name_1 + # Verify created workspace + response = await ac.get(f"/api/v1/workspaces/{name_1}") + assert response.status_code == 200 + response_body = response.json() assert response_body["name"] == name_1 name_2: str = str(uuid()) @@ -232,9 +625,23 @@ async def test_create_update_workspace_name_only( } response = await ac.put(f"/api/v1/workspaces/{name_1}", json=payload_update) - assert response.status_code == 201 + assert response.status_code == 200 response_body = response.json() + assert response_body["name"] == name_2 + # other fields shouldn't have been touched + assert response_body["config"]["custom_instructions"] == custom_instructions + for i, rule in enumerate(response_body["config"]["muxing_rules"]): + assert rule["provider_name"] == muxing_rules[i]["provider_name"] + assert rule["provider_type"] == muxing_rules[i]["provider_type"] + assert rule["model"] == muxing_rules[i]["model"] + assert rule["matcher"] == muxing_rules[i]["matcher"] + assert rule["matcher_type"] == muxing_rules[i]["matcher_type"] + + # Verify updated workspace + response = await ac.get(f"/api/v1/workspaces/{name_2}") + assert response.status_code == 200 + response_body = response.json() assert response_body["name"] == name_2 @@ -247,7 +654,11 @@ async def test_create_workspace_name_already_in_use( patch("codegate.api.v1.pcrud", mock_provider_crud), patch( "codegate.providers.openai.provider.OpenAIProvider.models", - return_value=["foo-bar-001", "foo-bar-002"], + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + patch( + "codegate.providers.openrouter.provider.OpenRouterProvider.models", + return_value=["anthropic/claude-2", "deepseek/deepseek-r1"], ), ): """Test creating a workspace when the name is already in use.""" @@ -282,7 +693,11 @@ async def test_rename_workspace_name_already_in_use( patch("codegate.api.v1.pcrud", mock_provider_crud), patch( "codegate.providers.openai.provider.OpenAIProvider.models", - return_value=["foo-bar-001", "foo-bar-002"], + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + patch( + "codegate.providers.openrouter.provider.OpenRouterProvider.models", + return_value=["anthropic/claude-2", "deepseek/deepseek-r1"], ), ): """Test renaming a workspace when the new name is already in use.""" @@ -322,14 +737,19 @@ async def test_rename_workspace_name_already_in_use( @pytest.mark.asyncio async def test_create_workspace_with_nonexistent_model_in_muxing_rule( - mock_pipeline_factory, mock_workspace_crud, mock_provider_crud + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud, db_reader ) -> None: with ( + patch("codegate.api.v1.dbreader", db_reader), patch("codegate.api.v1.wscrud", mock_workspace_crud), patch("codegate.api.v1.pcrud", mock_provider_crud), patch( "codegate.providers.openai.provider.OpenAIProvider.models", - return_value=["foo-bar-001", "foo-bar-002"], + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + patch( + "codegate.providers.openrouter.provider.OpenRouterProvider.models", + return_value=["anthropic/claude-2", "deepseek/deepseek-r1"], ), ): """Test creating a workspace with a muxing rule that uses a nonexistent model.""" @@ -337,28 +757,26 @@ async def test_create_workspace_with_nonexistent_model_in_muxing_rule( app = init_app(mock_pipeline_factory) provider_payload = { - "name": "foo", - "description": "", + "name": "openai-provider", + "description": "OpenAI provider description", "auth_type": "none", "provider_type": "openai", "endpoint": "https://api.openai.com", - "api_key": "sk-proj-foo-bar-123-xzy", + "api_key": "sk-proj-foo-bar-123-xyz", } async with AsyncClient( transport=httpx.ASGITransport(app=app), base_url="http://test" ) as ac: - # Create the first provider response = await ac.post("/api/v1/provider-endpoints", json=provider_payload) assert response.status_code == 201 - provider = response.json() name: str = str(uuid()) custom_instructions: str = "Respond to every request in iambic pentameter" muxing_rules = [ { - "provider_name": None, - "provider_id": provider["id"], + "provider_name": "openai-provider", + "provider_type": "openai", "model": "nonexistent-model", "matcher": "*.ts", "matcher_type": "filename_match", @@ -375,4 +793,189 @@ async def test_create_workspace_with_nonexistent_model_in_muxing_rule( response = await ac.post("/api/v1/workspaces", json=payload_create) assert response.status_code == 400 - assert "Model nonexistent-model does not exist" in response.json()["detail"] + assert "does not exist" in response.json()["detail"] + + +@pytest.mark.asyncio +async def test_list_workspaces_by_provider_name( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud, db_reader +) -> None: + with ( + patch("codegate.api.v1.dbreader", db_reader), + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + patch( + "codegate.providers.openrouter.provider.OpenRouterProvider.models", + return_value=["anthropic/claude-2", "deepseek/deepseek-r1"], + ), + ): + """Test listing workspaces filtered by provider name.""" + + app = init_app(mock_pipeline_factory) + + provider_payload_1 = { + "name": "openai-provider", + "description": "OpenAI provider description", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.openai.com", + "api_key": "sk-proj-foo-bar-123-xyz", + } + + provider_payload_2 = { + "name": "openrouter-provider", + "description": "OpenRouter provider description", + "auth_type": "none", + "provider_type": "openrouter", + "endpoint": "https://openrouter.ai/api", + "api_key": "sk-or-foo-bar-456-xyz", + } + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + # Create providers + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_1) + assert response.status_code == 201 + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_2) + assert response.status_code == 201 + + # Create workspace + + name_1: str = str(uuid()) + custom_instructions_1: str = "Respond to every request in iambic pentameter" + muxing_rules_1 = [ + { + "provider_name": "openai-provider", + "provider_type": "openai", + "model": "gpt-4", + "matcher": "*.ts", + "matcher_type": "filename_match", + }, + { + "provider_name": "openai-provider", + "provider_type": "openai", + "model": "gpt-3.5-turbo", + "matcher_type": "catch_all", + "matcher": "", + }, + ] + + payload_create_1 = { + "name": name_1, + "config": { + "custom_instructions": custom_instructions_1, + "muxing_rules": muxing_rules_1, + }, + } + + response = await ac.post("/api/v1/workspaces", json=payload_create_1) + assert response.status_code == 201 + + name_2: str = str(uuid()) + custom_instructions_2: str = "Respond to every request in cockney rhyming slang" + muxing_rules_2 = [ + { + "provider_name": "openrouter-provider", + "provider_type": "openrouter", + "model": "anthropic/claude-2", + "matcher": "*.ts", + "matcher_type": "filename_match", + }, + { + "provider_name": "openrouter-provider", + "provider_type": "openrouter", + "model": "deepseek/deepseek-r1", + "matcher_type": "catch_all", + "matcher": "", + }, + ] + + payload_create_2 = { + "name": name_2, + "config": { + "custom_instructions": custom_instructions_2, + "muxing_rules": muxing_rules_2, + }, + } + + response = await ac.post("/api/v1/workspaces", json=payload_create_2) + assert response.status_code == 201 + + # List workspaces filtered by openai provider + response = await ac.get("/api/v1/workspaces?provider_name=openai-provider") + assert response.status_code == 200 + response_body = response.json() + assert len(response_body["workspaces"]) == 1 + assert response_body["workspaces"][0]["name"] == name_1 + + # List workspaces filtered by openrouter provider + response = await ac.get("/api/v1/workspaces?provider_name=openrouter-provider") + assert response.status_code == 200 + response_body = response.json() + assert len(response_body["workspaces"]) == 1 + assert response_body["workspaces"][0]["name"] == name_2 + + # List workspaces filtered by non-existent provider + response = await ac.get("/api/v1/workspaces?provider_name=foo-bar-123") + assert response.status_code == 200 + response_body = response.json() + assert len(response_body["workspaces"]) == 0 + + # List workspaces unfiltered + response = await ac.get("/api/v1/workspaces") + assert response.status_code == 200 + response_body = response.json() + assert len(response_body["workspaces"]) == 3 # 2 created in test + default + + +@pytest.mark.asyncio +async def test_delete_workspace( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud, mock_muxing_rules_registry +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.muxing.rulematcher.get_muxing_rules_registry", + return_value=mock_muxing_rules_registry, + ), + ): + """Test deleting a workspace.""" + + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + name: str = str(uuid()) + payload_create = { + "name": name, + } + + # Create workspace + response = await ac.post("/api/v1/workspaces", json=payload_create) + assert response.status_code == 201 + + # Verify workspace exists + response = await ac.get(f"/api/v1/workspaces/{name}") + assert response.status_code == 200 + assert response.json()["name"] == name + + # Delete workspace + response = await ac.delete(f"/api/v1/workspaces/{name}") + assert response.status_code == 204 + + # Verify workspace no longer exists + response = await ac.get(f"/api/v1/workspaces/{name}") + assert response.status_code == 404 + + # Try to delete non-existent workspace + response = await ac.delete("/api/v1/workspaces/nonexistent") + assert response.status_code == 404 + assert response.json()["detail"] == "Workspace does not exist" diff --git a/tests/integration/anthropic/testcases.yaml b/tests/integration/anthropic/testcases.yaml index 03f8f6667..1b50ea79d 100644 --- a/tests/integration/anthropic/testcases.yaml +++ b/tests/integration/anthropic/testcases.yaml @@ -24,6 +24,8 @@ muxing: Content-Type: application/json rules: - model: claude-3-5-haiku-20241022 + provider_name: anthropic_muxing + provider_type: anthropic matcher_type: catch_all matcher: "" diff --git a/tests/integration/llamacpp/testcases.yaml b/tests/integration/llamacpp/testcases.yaml index 69ec72df6..f7422991d 100644 --- a/tests/integration/llamacpp/testcases.yaml +++ b/tests/integration/llamacpp/testcases.yaml @@ -23,6 +23,8 @@ muxing: Content-Type: application/json rules: - model: qwen2.5-coder-0.5b-instruct-q5_k_m + provider_name: llamacpp_muxing + provider_type: llamacpp matcher_type: catch_all matcher: "" diff --git a/tests/integration/ollama/testcases.yaml b/tests/integration/ollama/testcases.yaml index 56a13b571..691fe4faf 100644 --- a/tests/integration/ollama/testcases.yaml +++ b/tests/integration/ollama/testcases.yaml @@ -24,6 +24,8 @@ muxing: rules: - model: qwen2.5-coder:1.5b matcher_type: catch_all + provider_name: ollama_muxing + provider_type: ollama matcher: "" testcases: diff --git a/tests/integration/openai/testcases.yaml b/tests/integration/openai/testcases.yaml index 452dcce6f..fb3730798 100644 --- a/tests/integration/openai/testcases.yaml +++ b/tests/integration/openai/testcases.yaml @@ -24,6 +24,8 @@ muxing: Content-Type: application/json rules: - model: gpt-4o-mini + provider_name: openai_muxing + provider_type: openai matcher_type: catch_all matcher: "" diff --git a/tests/integration/openrouter/testcases.yaml b/tests/integration/openrouter/testcases.yaml index d64e0266a..818acd6a5 100644 --- a/tests/integration/openrouter/testcases.yaml +++ b/tests/integration/openrouter/testcases.yaml @@ -24,6 +24,8 @@ muxing: Content-Type: application/json rules: - model: anthropic/claude-3.5-haiku + provider_name: openrouter_muxing + provider_type: openrouter matcher_type: catch_all matcher: "" diff --git a/tests/integration/vllm/testcases.yaml b/tests/integration/vllm/testcases.yaml index 009783e50..eea2c61d6 100644 --- a/tests/integration/vllm/testcases.yaml +++ b/tests/integration/vllm/testcases.yaml @@ -23,6 +23,8 @@ muxing: Content-Type: application/json rules: - model: Qwen/Qwen2.5-Coder-0.5B-Instruct + provider_name: vllm_muxing + provider_type: vllm matcher_type: catch_all matcher: "" diff --git a/tests/muxing/test_rulematcher.py b/tests/muxing/test_rulematcher.py index 7e551525c..2edd1f975 100644 --- a/tests/muxing/test_rulematcher.py +++ b/tests/muxing/test_rulematcher.py @@ -8,7 +8,10 @@ mocked_route_openai = rulematcher.ModelRoute( db_models.ProviderModel( - provider_endpoint_id="1", provider_endpoint_name="fake-openai", name="fake-gpt" + provider_endpoint_id="1", + provider_endpoint_name="fake-openai", + provider_endpoint_type=db_models.ProviderType.openai, + name="fake-gpt", ), db_models.ProviderEndpoint( id="1", @@ -70,6 +73,8 @@ def test_file_matcher( model="fake-gpt", matcher_type="filename_match", matcher=matcher, + provider_name="fake-openai", + provider_type=db_models.ProviderType.openai, ) muxing_rule_matcher = rulematcher.FileMuxingRuleMatcher(mocked_route_openai, mux_rule) # We mock the _extract_request_filenames method to return a list of filenames @@ -120,6 +125,8 @@ def test_request_file_matcher( model="fake-gpt", matcher_type=matcher_type, matcher=matcher, + provider_name="fake-openai", + provider_type=db_models.ProviderType.openai, ) muxing_rule_matcher = rulematcher.RequestTypeAndFileMuxingRuleMatcher( mocked_route_openai, mux_rule @@ -168,10 +175,23 @@ def test_muxing_matcher_factory(matcher_type, expected_class): matcher_blob="fake-matcher", priority=1, ) + provider_endpoint = db_models.ProviderEndpoint( + id="1", + auth_type="none", + description="", + endpoint="http://localhost:11434", + name="fake-openai", + provider_type="openai", + ) if expected_class: assert isinstance( - rulematcher.MuxingMatcherFactory.create(mux_rule, mocked_route_openai), expected_class + rulematcher.MuxingMatcherFactory.create( + mux_rule, provider_endpoint, mocked_route_openai + ), + expected_class, ) else: with pytest.raises(ValueError): - rulematcher.MuxingMatcherFactory.create(mux_rule, mocked_route_openai) + rulematcher.MuxingMatcherFactory.create( + mux_rule, provider_endpoint, mocked_route_openai + )