diff --git a/api/openapi.json b/api/openapi.json index 24033f88..f924125d 100644 --- a/api/openapi.json +++ b/api/openapi.json @@ -2030,4 +2030,4 @@ } } } -} +} \ No newline at end of file diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index f9d4b00b..bba6ab8e 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -248,22 +248,18 @@ async def activate_workspace(request: v1_models.ActivateWorkspaceRequest, status @v1.post("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name, status_code=201) async def create_workspace( - request: v1_models.CreateOrRenameWorkspaceRequest, -) -> v1_models.Workspace: + request: v1_models.FullWorkspace, +) -> v1_models.FullWorkspace: """Create a new workspace.""" - if request.rename_to is not None: - return await rename_workspace(request) - return await create_new_workspace(request) - - -async def create_new_workspace( - request: v1_models.CreateOrRenameWorkspaceRequest, -) -> v1_models.Workspace: - # Input validation is done in the model try: - _ = await wscrud.add_workspace(request.name) - except AlreadyExistsError: - raise HTTPException(status_code=409, detail="Workspace already exists") + custom_instructions = request.config.custom_instructions if request.config else None + muxing_rules = request.config.muxing_rules if request.config else None + + workspace_row, mux_rules = await wscrud.add_workspace( + request.name, custom_instructions, muxing_rules + ) + except crud.WorkspaceNameAlreadyInUseError: + raise HTTPException(status_code=409, detail="Workspace name already in use") except ValidationError: raise HTTPException( status_code=400, @@ -277,18 +273,40 @@ async def create_new_workspace( except Exception: raise HTTPException(status_code=500, detail="Internal server error") - return v1_models.Workspace(name=request.name, is_active=False) + return v1_models.FullWorkspace( + name=workspace_row.name, + config=v1_models.WorkspaceConfig( + custom_instructions=workspace_row.custom_instructions or "", + muxing_rules=[mux_models.MuxRule.from_db_mux_rule(mux_rule) for mux_rule in mux_rules], + ), + ) -async def rename_workspace( - request: v1_models.CreateOrRenameWorkspaceRequest, -) -> v1_models.Workspace: +@v1.put( + "/workspaces/{workspace_name}", + tags=["Workspaces"], + generate_unique_id_function=uniq_name, + status_code=201, +) +async def update_workspace( + workspace_name: str, + request: v1_models.FullWorkspace, +) -> v1_models.FullWorkspace: + """Update a workspace.""" try: - _ = await wscrud.rename_workspace(request.name, request.rename_to) + custom_instructions = request.config.custom_instructions if request.config else None + muxing_rules = request.config.muxing_rules if request.config else None + + workspace_row, mux_rules = await wscrud.update_workspace( + workspace_name, + request.name, + custom_instructions, + muxing_rules, + ) except crud.WorkspaceDoesNotExistError: raise HTTPException(status_code=404, detail="Workspace does not exist") - except AlreadyExistsError: - raise HTTPException(status_code=409, detail="Workspace already exists") + except crud.WorkspaceNameAlreadyInUseError: + raise HTTPException(status_code=409, detail="Workspace name already in use") except ValidationError: raise HTTPException( status_code=400, @@ -302,7 +320,13 @@ async def rename_workspace( except Exception: raise HTTPException(status_code=500, detail="Internal server error") - return v1_models.Workspace(name=request.rename_to, is_active=False) + return v1_models.FullWorkspace( + name=workspace_row.name, + config=v1_models.WorkspaceConfig( + custom_instructions=workspace_row.custom_instructions or "", + muxing_rules=[mux_models.MuxRule.from_db_mux_rule(mux_rule) for mux_rule in mux_rules], + ), + ) @v1.delete( diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py index 51f65ea9..6cbc2be3 100644 --- a/src/codegate/api/v1_models.py +++ b/src/codegate/api/v1_models.py @@ -61,7 +61,7 @@ def from_db_workspaces( class WorkspaceConfig(pydantic.BaseModel): - system_prompt: str + custom_instructions: str muxing_rules: List[mux_models.MuxRule] @@ -72,13 +72,6 @@ class FullWorkspace(pydantic.BaseModel): config: Optional[WorkspaceConfig] = None -class CreateOrRenameWorkspaceRequest(FullWorkspace): - # If set, rename the workspace to this name. Note that - # the 'name' field is still required and the workspace - # workspace must exist. - rename_to: Optional[str] = None - - class ActivateWorkspaceRequest(pydantic.BaseModel): name: str diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 38bf6010..170cb52e 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -14,7 +14,8 @@ from sqlalchemy import CursorResult, TextClause, event, text from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError, OperationalError -from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker from codegate.db.fim_cache import FimCache from codegate.db.models import ( @@ -1025,6 +1026,34 @@ async def get_distance_to_persona( return persona_distance[0] +class DbTransaction: + def __init__(self): + self._session = None + + async def __aenter__(self): + self._session = sessionmaker( + bind=DbCodeGate()._async_db_engine, + class_=AsyncSession, + expire_on_commit=False, + )() + await self._session.begin() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if exc_type: + await self._session.rollback() + raise exc_val + else: + await self._session.commit() + await self._session.close() + + async def commit(self): + await self._session.commit() + + async def rollback(self): + await self._session.rollback() + + def init_db_sync(db_path: Optional[str] = None): """DB will be initialized in the constructor in case it doesn't exist.""" current_dir = Path(__file__).parent diff --git a/src/codegate/pipeline/cli/commands.py b/src/codegate/pipeline/cli/commands.py index 5b101400..c5655ec3 100644 --- a/src/codegate/pipeline/cli/commands.py +++ b/src/codegate/pipeline/cli/commands.py @@ -98,7 +98,6 @@ def help(self) -> str: class CodegateCommandSubcommand(CodegateCommand): - @property @abstractmethod def subcommands(self) -> Dict[str, Callable[[List[str]], Awaitable[str]]]: @@ -174,7 +173,6 @@ async def run(self, args: List[str]) -> str: class Workspace(CodegateCommandSubcommand): - def __init__(self): self.workspace_crud = crud.WorkspaceCrud() @@ -258,7 +256,7 @@ async def _rename_workspace(self, flags: Dict[str, str], args: List[str]) -> str ) try: - await self.workspace_crud.rename_workspace(old_workspace_name, new_workspace_name) + await self.workspace_crud.update_workspace(old_workspace_name, new_workspace_name) except crud.WorkspaceDoesNotExistError: return f"Workspace **{old_workspace_name}** does not exist" except AlreadyExistsError: @@ -410,7 +408,6 @@ def help(self) -> str: class CustomInstructions(CodegateCommandSubcommand): - def __init__(self): self.workspace_crud = crud.WorkspaceCrud() diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py index a81426a8..fbaf5b99 100644 --- a/src/codegate/workspaces/crud.py +++ b/src/codegate/workspaces/crud.py @@ -3,7 +3,7 @@ from uuid import uuid4 as uuid from codegate.db import models as db_models -from codegate.db.connection import DbReader, DbRecorder +from codegate.db.connection import AlreadyExistsError, DbReader, DbRecorder, DbTransaction from codegate.muxing import models as mux_models from codegate.muxing import rulematcher @@ -16,6 +16,10 @@ class WorkspaceDoesNotExistError(WorkspaceCrudError): pass +class WorkspaceNameAlreadyInUseError(WorkspaceCrudError): + pass + + class WorkspaceAlreadyActiveError(WorkspaceCrudError): pass @@ -31,34 +35,73 @@ class WorkspaceMuxRuleDoesNotExistError(WorkspaceCrudError): class WorkspaceCrud: - def __init__(self): self._db_reader = DbReader() - - async def add_workspace(self, new_workspace_name: str) -> db_models.WorkspaceRow: + self._db_recorder = DbRecorder() + + async def add_workspace( + self, + new_workspace_name: str, + custom_instructions: Optional[str] = None, + muxing_rules: Optional[List[mux_models.MuxRule]] = None, + ) -> Tuple[db_models.WorkspaceRow, List[db_models.MuxRule]]: """ Add a workspace Args: - name (str): The name of the workspace + new_workspace_name (str): The name of the workspace + system_prompt (Optional[str]): The system prompt for the workspace + muxing_rules (Optional[List[mux_models.MuxRule]]): The muxing rules for the workspace """ if new_workspace_name == "": raise WorkspaceCrudError("Workspace name cannot be empty.") if new_workspace_name in RESERVED_WORKSPACE_KEYWORDS: raise WorkspaceCrudError(f"Workspace name {new_workspace_name} is reserved.") - db_recorder = DbRecorder() - workspace_created = await db_recorder.add_workspace(new_workspace_name) - return workspace_created - async def rename_workspace( - self, old_workspace_name: str, new_workspace_name: str - ) -> db_models.WorkspaceRow: + async with DbTransaction() as transaction: + try: + existing_ws = await self._db_reader.get_workspace_by_name(new_workspace_name) + if existing_ws: + raise WorkspaceNameAlreadyInUseError( + f"Workspace name {new_workspace_name} is already in use." + ) + + workspace_created = await self._db_recorder.add_workspace(new_workspace_name) + + if custom_instructions: + workspace_created.custom_instructions = custom_instructions + await self._db_recorder.update_workspace(workspace_created) + + mux_rules = [] + if muxing_rules: + mux_rules = await self.set_muxes(new_workspace_name, muxing_rules) + + await transaction.commit() + return workspace_created, mux_rules + except ( + AlreadyExistsError, + WorkspaceDoesNotExistError, + WorkspaceNameAlreadyInUseError, + ) as e: + raise e + except Exception as e: + raise WorkspaceCrudError(f"Error adding workspace {new_workspace_name}: {str(e)}") + + async def update_workspace( + self, + old_workspace_name: str, + new_workspace_name: str, + custom_instructions: Optional[str] = None, + muxing_rules: Optional[List[mux_models.MuxRule]] = None, + ) -> Tuple[db_models.WorkspaceRow, List[db_models.MuxRule]]: """ - Rename a workspace + Update a workspace Args: - old_name (str): The old name of the workspace - new_name (str): The new name of the workspace + old_workspace_name (str): The old name of the workspace + new_workspace_name (str): The new name of the workspace + system_prompt (Optional[str]): The system prompt for the workspace + muxing_rules (Optional[List[mux_models.MuxRule]]): The muxing rules for the workspace """ if new_workspace_name == "": raise WorkspaceCrudError("Workspace name cannot be empty.") @@ -70,15 +113,40 @@ async def rename_workspace( raise WorkspaceCrudError(f"Workspace name {new_workspace_name} is reserved.") if old_workspace_name == new_workspace_name: raise WorkspaceCrudError("Old and new workspace names are the same.") - ws = await self._db_reader.get_workspace_by_name(old_workspace_name) - if not ws: - raise WorkspaceDoesNotExistError(f"Workspace {old_workspace_name} does not exist.") - db_recorder = DbRecorder() - new_ws = db_models.WorkspaceRow( - id=ws.id, name=new_workspace_name, custom_instructions=ws.custom_instructions - ) - workspace_renamed = await db_recorder.update_workspace(new_ws) - return workspace_renamed + + async with DbTransaction() as transaction: + try: + ws = await self._db_reader.get_workspace_by_name(old_workspace_name) + if not ws: + raise WorkspaceDoesNotExistError( + f"Workspace {old_workspace_name} does not exist." + ) + + existing_ws = await self._db_reader.get_workspace_by_name(new_workspace_name) + if existing_ws: + raise WorkspaceNameAlreadyInUseError( + f"Workspace name {new_workspace_name} is already in use." + ) + + new_ws = db_models.WorkspaceRow( + id=ws.id, name=new_workspace_name, custom_instructions=ws.custom_instructions + ) + workspace_renamed = await self._db_recorder.update_workspace(new_ws) + + if custom_instructions: + workspace_renamed.custom_instructions = custom_instructions + await self._db_recorder.update_workspace(workspace_renamed) + + mux_rules = [] + if muxing_rules: + mux_rules = await self.set_muxes(new_workspace_name, muxing_rules) + + await transaction.commit() + return workspace_renamed, mux_rules + except (WorkspaceNameAlreadyInUseError, WorkspaceDoesNotExistError) as e: + raise e + except Exception as e: + raise WorkspaceCrudError(f"Error updating workspace {old_workspace_name}: {str(e)}") async def get_workspaces(self) -> List[db_models.WorkspaceWithSessionInfo]: """ @@ -128,8 +196,7 @@ async def activate_workspace(self, workspace_name: str): session.active_workspace_id = workspace.id session.last_update = datetime.datetime.now(datetime.timezone.utc) - db_recorder = DbRecorder() - await db_recorder.update_session(session) + await self._db_recorder.update_session(session) # Ensure the mux registry is updated mux_registry = await rulematcher.get_muxing_rules_registry() @@ -144,8 +211,7 @@ async def recover_workspace(self, workspace_name: str): if not selected_workspace: raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.") - db_recorder = DbRecorder() - await db_recorder.recover_workspace(selected_workspace) + await self._db_recorder.recover_workspace(selected_workspace) return async def update_workspace_custom_instructions( @@ -161,8 +227,7 @@ async def update_workspace_custom_instructions( name=selected_workspace.name, custom_instructions=custom_instructions, ) - db_recorder = DbRecorder() - updated_workspace = await db_recorder.update_workspace(workspace_update) + updated_workspace = await self._db_recorder.update_workspace(workspace_update) return updated_workspace async def soft_delete_workspace(self, workspace_name: str): @@ -183,9 +248,8 @@ async def soft_delete_workspace(self, workspace_name: str): if active_workspace and active_workspace.id == selected_workspace.id: raise WorkspaceCrudError("Cannot archive active workspace.") - db_recorder = DbRecorder() try: - _ = await db_recorder.soft_delete_workspace(selected_workspace) + _ = await self._db_recorder.soft_delete_workspace(selected_workspace) except Exception: raise WorkspaceCrudError(f"Error deleting workspace {workspace_name}") @@ -205,9 +269,8 @@ async def hard_delete_workspace(self, workspace_name: str): if not selected_workspace: raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.") - db_recorder = DbRecorder() try: - _ = await db_recorder.hard_delete_workspace(selected_workspace) + _ = await self._db_recorder.hard_delete_workspace(selected_workspace) except Exception: raise WorkspaceCrudError(f"Error deleting workspace {workspace_name}") return @@ -247,15 +310,16 @@ async def get_muxes(self, workspace_name: str) -> List[mux_models.MuxRule]: return muxes - async def set_muxes(self, workspace_name: str, muxes: mux_models.MuxRule) -> None: + async def set_muxes( + self, workspace_name: str, muxes: List[mux_models.MuxRule] + ) -> List[db_models.MuxRule]: # Verify if workspace exists workspace = await self._db_reader.get_workspace_by_name(workspace_name) if not workspace: raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.") # Delete all muxes for the workspace - db_recorder = DbRecorder() - await db_recorder.delete_muxes_by_workspace(workspace.id) + await self._db_recorder.delete_muxes_by_workspace(workspace.id) # Add the new muxes priority = 0 @@ -268,6 +332,7 @@ async def set_muxes(self, workspace_name: str, muxes: mux_models.MuxRule) -> Non muxes_with_routes.append((mux, route)) matchers: List[rulematcher.MuxingRuleMatcher] = [] + dbmuxes: List[db_models.MuxRule] = [] for mux, route in muxes_with_routes: new_mux = db_models.MuxRule( @@ -279,7 +344,8 @@ async def set_muxes(self, workspace_name: str, muxes: mux_models.MuxRule) -> Non matcher_blob=mux.matcher if mux.matcher else "", priority=priority, ) - dbmux = await db_recorder.add_mux(new_mux) + dbmux = await self._db_recorder.add_mux(new_mux) + dbmuxes.append(dbmux) matchers.append(rulematcher.MuxingMatcherFactory.create(dbmux, route)) @@ -289,6 +355,8 @@ async def set_muxes(self, workspace_name: str, muxes: mux_models.MuxRule) -> Non mux_registry = await rulematcher.get_muxing_rules_registry() await mux_registry.set_ws_rules(workspace_name, matchers) + return dbmuxes + async def get_routing_for_mux(self, mux: mux_models.MuxRule) -> rulematcher.ModelRoute: """Get the routing for a mux diff --git a/tests/api/test_v1_workspaces.py b/tests/api/test_v1_workspaces.py new file mode 100644 index 00000000..8bfcbfaf --- /dev/null +++ b/tests/api/test_v1_workspaces.py @@ -0,0 +1,378 @@ +from pathlib import Path +from unittest.mock import MagicMock, patch +from uuid import uuid4 as uuid + +import httpx +import pytest +import structlog +from httpx import AsyncClient + +from codegate.db import connection +from codegate.pipeline.factory import PipelineFactory +from codegate.providers.crud.crud import ProviderCrud +from codegate.server import init_app +from codegate.workspaces.crud import WorkspaceCrud + +logger = structlog.get_logger("codegate") + + +@pytest.fixture +def db_path(): + """Creates a temporary database file path.""" + current_test_dir = Path(__file__).parent + db_filepath = current_test_dir / f"codegate_test_{uuid()}.db" + db_fullpath = db_filepath.absolute() + connection.init_db_sync(str(db_fullpath)) + yield db_fullpath + if db_fullpath.is_file(): + db_fullpath.unlink() + + +@pytest.fixture() +def db_recorder(db_path) -> connection.DbRecorder: + """Creates a DbRecorder instance with test database.""" + return connection.DbRecorder(sqlite_path=db_path, _no_singleton=True) + + +@pytest.fixture() +def db_reader(db_path) -> connection.DbReader: + """Creates a DbReader instance with test database.""" + return connection.DbReader(sqlite_path=db_path, _no_singleton=True) + + +@pytest.fixture() +def mock_workspace_crud(db_recorder, db_reader) -> WorkspaceCrud: + """Creates a WorkspaceCrud instance with test database.""" + ws_crud = WorkspaceCrud() + ws_crud._db_reader = db_reader + ws_crud._db_recorder = db_recorder + return ws_crud + + +@pytest.fixture() +def mock_provider_crud(db_recorder, db_reader, mock_workspace_crud) -> ProviderCrud: + """Creates a ProviderCrud instance with test database.""" + p_crud = ProviderCrud() + p_crud._db_reader = db_reader + p_crud._db_writer = db_recorder + p_crud._ws_crud = mock_workspace_crud + return p_crud + + +@pytest.fixture +def mock_pipeline_factory(): + """Create a mock pipeline factory.""" + mock_factory = MagicMock(spec=PipelineFactory) + mock_factory.create_input_pipeline.return_value = MagicMock() + mock_factory.create_fim_pipeline.return_value = MagicMock() + mock_factory.create_output_pipeline.return_value = MagicMock() + mock_factory.create_fim_output_pipeline.return_value = MagicMock() + return mock_factory + + +@pytest.mark.asyncio +async def test_create_update_workspace_happy_path( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["foo-bar-001", "foo-bar-002"], + ), + ): + """Test creating & updating a workspace (happy path).""" + + app = init_app(mock_pipeline_factory) + + provider_payload_1 = { + "name": "foo", + "description": "", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.openai.com", + "api_key": "sk-proj-foo-bar-123-xzy", + } + + provider_payload_2 = { + "name": "bar", + "description": "", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.openai.com", + "api_key": "sk-proj-foo-bar-123-xzy", + } + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + # Create the first provider + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_1) + assert response.status_code == 201 + provider_1 = response.json() + + # Create the second provider + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_2) + assert response.status_code == 201 + provider_2 = response.json() + + name_1: str = str(uuid()) + custom_instructions_1: str = "Respond to every request in iambic pentameter" + muxing_rules_1 = [ + { + "provider_name": None, # optional & not implemented yet + "provider_id": provider_1["id"], + "model": "foo-bar-001", + "matcher": "*.ts", + "matcher_type": "filename_match", + }, + { + "provider_name": None, # optional & not implemented yet + "provider_id": provider_2["id"], + "model": "foo-bar-002", + "matcher_type": "catch_all", + "matcher": "", + }, + ] + + payload_create = { + "name": name_1, + "config": { + "custom_instructions": custom_instructions_1, + "muxing_rules": muxing_rules_1, + }, + } + + response = await ac.post("/api/v1/workspaces", json=payload_create) + assert response.status_code == 201 + response_body = response.json() + + assert response_body["name"] == name_1 + assert response_body["config"]["custom_instructions"] == custom_instructions_1 + for i, rule in enumerate(response_body["config"]["muxing_rules"]): + assert rule["model"] == muxing_rules_1[i]["model"] + assert rule["matcher"] == muxing_rules_1[i]["matcher"] + assert rule["matcher_type"] == muxing_rules_1[i]["matcher_type"] + + name_2: str = str(uuid()) + custom_instructions_2: str = "Respond to every request in cockney rhyming slang" + muxing_rules_2 = [ + { + "provider_name": None, # optional & not implemented yet + "provider_id": provider_2["id"], + "model": "foo-bar-002", + "matcher": "*.ts", + "matcher_type": "filename_match", + }, + { + "provider_name": None, # optional & not implemented yet + "provider_id": provider_1["id"], + "model": "foo-bar-001", + "matcher_type": "catch_all", + "matcher": "", + }, + ] + + payload_update = { + "name": name_2, + "config": { + "custom_instructions": custom_instructions_2, + "muxing_rules": muxing_rules_2, + }, + } + + response = await ac.put(f"/api/v1/workspaces/{name_1}", json=payload_update) + assert response.status_code == 201 + response_body = response.json() + + assert response_body["name"] == name_2 + assert response_body["config"]["custom_instructions"] == custom_instructions_2 + for i, rule in enumerate(response_body["config"]["muxing_rules"]): + assert rule["model"] == muxing_rules_2[i]["model"] + assert rule["matcher"] == muxing_rules_2[i]["matcher"] + assert rule["matcher_type"] == muxing_rules_2[i]["matcher_type"] + + +@pytest.mark.asyncio +async def test_create_update_workspace_name_only( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["foo-bar-001", "foo-bar-002"], + ), + ): + """Test creating & updating a workspace (happy path).""" + + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + name_1: str = str(uuid()) + + payload_create = { + "name": name_1, + } + + response = await ac.post("/api/v1/workspaces", json=payload_create) + assert response.status_code == 201 + response_body = response.json() + + assert response_body["name"] == name_1 + + name_2: str = str(uuid()) + + payload_update = { + "name": name_2, + } + + response = await ac.put(f"/api/v1/workspaces/{name_1}", json=payload_update) + assert response.status_code == 201 + response_body = response.json() + + assert response_body["name"] == name_2 + + +@pytest.mark.asyncio +async def test_create_workspace_name_already_in_use( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["foo-bar-001", "foo-bar-002"], + ), + ): + """Test creating a workspace when the name is already in use.""" + + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + name: str = str(uuid()) + + payload_create = { + "name": name, + } + + # Create the workspace for the first time + response = await ac.post("/api/v1/workspaces", json=payload_create) + assert response.status_code == 201 + + # Try to create the workspace again with the same name + response = await ac.post("/api/v1/workspaces", json=payload_create) + assert response.status_code == 409 + assert response.json()["detail"] == "Workspace name already in use" + + +@pytest.mark.asyncio +async def test_rename_workspace_name_already_in_use( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["foo-bar-001", "foo-bar-002"], + ), + ): + """Test renaming a workspace when the new name is already in use.""" + + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + name_1: str = str(uuid()) + name_2: str = str(uuid()) + + payload_create_1 = { + "name": name_1, + } + + payload_create_2 = { + "name": name_2, + } + + # Create two workspaces + response = await ac.post("/api/v1/workspaces", json=payload_create_1) + assert response.status_code == 201 + + response = await ac.post("/api/v1/workspaces", json=payload_create_2) + assert response.status_code == 201 + + # Try to rename the first workspace to the name of the second workspace + payload_update = { + "name": name_2, + } + + response = await ac.put(f"/api/v1/workspaces/{name_1}", json=payload_update) + assert response.status_code == 409 + assert response.json()["detail"] == "Workspace name already in use" + + +@pytest.mark.asyncio +async def test_create_workspace_with_nonexistent_model_in_muxing_rule( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["foo-bar-001", "foo-bar-002"], + ), + ): + """Test creating a workspace with a muxing rule that uses a nonexistent model.""" + + app = init_app(mock_pipeline_factory) + + provider_payload = { + "name": "foo", + "description": "", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.openai.com", + "api_key": "sk-proj-foo-bar-123-xzy", + } + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + # Create the first provider + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload) + assert response.status_code == 201 + provider = response.json() + + name: str = str(uuid()) + custom_instructions: str = "Respond to every request in iambic pentameter" + muxing_rules = [ + { + "provider_name": None, + "provider_id": provider["id"], + "model": "nonexistent-model", + "matcher": "*.ts", + "matcher_type": "filename_match", + }, + ] + + payload_create = { + "name": name, + "config": { + "custom_instructions": custom_instructions, + "muxing_rules": muxing_rules, + }, + } + + response = await ac.post("/api/v1/workspaces", json=payload_create) + assert response.status_code == 400 + assert "Model nonexistent-model does not exist" in response.json()["detail"]