diff --git a/src/api/generated/@tanstack/react-query.gen.ts b/src/api/generated/@tanstack/react-query.gen.ts index 2ca69ce6..28e8ffbb 100644 --- a/src/api/generated/@tanstack/react-query.gen.ts +++ b/src/api/generated/@tanstack/react-query.gen.ts @@ -7,11 +7,12 @@ import { healthCheckHealthGet, v1ListProviderEndpoints, v1AddProviderEndpoint, + v1ListAllModelsForAllProviders, + v1ListModelsByProvider, v1GetProviderEndpoint, v1UpdateProviderEndpoint, v1DeleteProviderEndpoint, - v1ListModelsByProvider, - v1ListAllModelsForAllProviders, + v1ConfigureAuthMaterial, v1ListWorkspaces, v1CreateWorkspace, v1ListActiveWorkspaces, @@ -36,6 +37,7 @@ import type { V1AddProviderEndpointData, V1AddProviderEndpointError, V1AddProviderEndpointResponse, + V1ListModelsByProviderData, V1GetProviderEndpointData, V1UpdateProviderEndpointData, V1UpdateProviderEndpointError, @@ -43,7 +45,9 @@ import type { V1DeleteProviderEndpointData, V1DeleteProviderEndpointError, V1DeleteProviderEndpointResponse, - V1ListModelsByProviderData, + V1ConfigureAuthMaterialData, + V1ConfigureAuthMaterialError, + V1ConfigureAuthMaterialResponse, V1CreateWorkspaceData, V1CreateWorkspaceError, V1CreateWorkspaceResponse, @@ -190,6 +194,48 @@ export const v1AddProviderEndpointMutation = ( return mutationOptions; }; +export const v1ListAllModelsForAllProvidersQueryKey = ( + options?: OptionsLegacyParser, +) => [createQueryKey("v1ListAllModelsForAllProviders", options)]; + +export const v1ListAllModelsForAllProvidersOptions = ( + options?: OptionsLegacyParser, +) => { + return queryOptions({ + queryFn: async ({ queryKey, signal }) => { + const { data } = await v1ListAllModelsForAllProviders({ + ...options, + ...queryKey[0], + signal, + throwOnError: true, + }); + return data; + }, + queryKey: v1ListAllModelsForAllProvidersQueryKey(options), + }); +}; + +export const v1ListModelsByProviderQueryKey = ( + options: OptionsLegacyParser, +) => [createQueryKey("v1ListModelsByProvider", options)]; + +export const v1ListModelsByProviderOptions = ( + options: OptionsLegacyParser, +) => { + return queryOptions({ + queryFn: async ({ queryKey, signal }) => { + const { data } = await v1ListModelsByProvider({ + ...options, + ...queryKey[0], + signal, + throwOnError: true, + }); + return data; + }, + queryKey: v1ListModelsByProviderQueryKey(options), + }); +}; + export const v1GetProviderEndpointQueryKey = ( options: OptionsLegacyParser, ) => [createQueryKey("v1GetProviderEndpoint", options)]; @@ -251,46 +297,24 @@ export const v1DeleteProviderEndpointMutation = ( return mutationOptions; }; -export const v1ListModelsByProviderQueryKey = ( - options: OptionsLegacyParser, -) => [createQueryKey("v1ListModelsByProvider", options)]; - -export const v1ListModelsByProviderOptions = ( - options: OptionsLegacyParser, -) => { - return queryOptions({ - queryFn: async ({ queryKey, signal }) => { - const { data } = await v1ListModelsByProvider({ - ...options, - ...queryKey[0], - signal, - throwOnError: true, - }); - return data; - }, - queryKey: v1ListModelsByProviderQueryKey(options), - }); -}; - -export const v1ListAllModelsForAllProvidersQueryKey = ( - options?: OptionsLegacyParser, -) => [createQueryKey("v1ListAllModelsForAllProviders", options)]; - -export const v1ListAllModelsForAllProvidersOptions = ( - options?: OptionsLegacyParser, +export const v1ConfigureAuthMaterialMutation = ( + options?: Partial>, ) => { - return queryOptions({ - queryFn: async ({ queryKey, signal }) => { - const { data } = await v1ListAllModelsForAllProviders({ + const mutationOptions: UseMutationOptions< + V1ConfigureAuthMaterialResponse, + V1ConfigureAuthMaterialError, + OptionsLegacyParser + > = { + mutationFn: async (localOptions) => { + const { data } = await v1ConfigureAuthMaterial({ ...options, - ...queryKey[0], - signal, + ...localOptions, throwOnError: true, }); return data; }, - queryKey: v1ListAllModelsForAllProvidersQueryKey(options), - }); + }; + return mutationOptions; }; export const v1ListWorkspacesQueryKey = (options?: OptionsLegacyParser) => [ diff --git a/src/api/generated/sdk.gen.ts b/src/api/generated/sdk.gen.ts index 37fe6c10..8069e38d 100644 --- a/src/api/generated/sdk.gen.ts +++ b/src/api/generated/sdk.gen.ts @@ -14,6 +14,11 @@ import type { V1AddProviderEndpointData, V1AddProviderEndpointError, V1AddProviderEndpointResponse, + V1ListAllModelsForAllProvidersError, + V1ListAllModelsForAllProvidersResponse, + V1ListModelsByProviderData, + V1ListModelsByProviderError, + V1ListModelsByProviderResponse, V1GetProviderEndpointData, V1GetProviderEndpointError, V1GetProviderEndpointResponse, @@ -23,11 +28,9 @@ import type { V1DeleteProviderEndpointData, V1DeleteProviderEndpointError, V1DeleteProviderEndpointResponse, - V1ListModelsByProviderData, - V1ListModelsByProviderError, - V1ListModelsByProviderResponse, - V1ListAllModelsForAllProvidersError, - V1ListAllModelsForAllProvidersResponse, + V1ConfigureAuthMaterialData, + V1ConfigureAuthMaterialError, + V1ConfigureAuthMaterialResponse, V1ListWorkspacesError, V1ListWorkspacesResponse, V1CreateWorkspaceData, @@ -131,6 +134,42 @@ export const v1AddProviderEndpoint = ( }); }; +/** + * List All Models For All Providers + * List all models for all providers. + */ +export const v1ListAllModelsForAllProviders = < + ThrowOnError extends boolean = false, +>( + options?: OptionsLegacyParser, +) => { + return (options?.client ?? client).get< + V1ListAllModelsForAllProvidersResponse, + V1ListAllModelsForAllProvidersError, + ThrowOnError + >({ + ...options, + url: "/api/v1/provider-endpoints/models", + }); +}; + +/** + * List Models By Provider + * List models by provider. + */ +export const v1ListModelsByProvider = ( + options: OptionsLegacyParser, +) => { + return (options?.client ?? client).get< + V1ListModelsByProviderResponse, + V1ListModelsByProviderError, + ThrowOnError + >({ + ...options, + url: "/api/v1/provider-endpoints/{provider_id}/models", + }); +}; + /** * Get Provider Endpoint * Get a provider endpoint by ID. @@ -183,38 +222,19 @@ export const v1DeleteProviderEndpoint = ( }; /** - * List Models By Provider - * List models by provider. - */ -export const v1ListModelsByProvider = ( - options: OptionsLegacyParser, -) => { - return (options?.client ?? client).get< - V1ListModelsByProviderResponse, - V1ListModelsByProviderError, - ThrowOnError - >({ - ...options, - url: "/api/v1/provider-endpoints/{provider_name}/models", - }); -}; - -/** - * List All Models For All Providers - * List all models for all providers. + * Configure Auth Material + * Configure auth material for a provider. */ -export const v1ListAllModelsForAllProviders = < - ThrowOnError extends boolean = false, ->( - options?: OptionsLegacyParser, +export const v1ConfigureAuthMaterial = ( + options: OptionsLegacyParser, ) => { - return (options?.client ?? client).get< - V1ListAllModelsForAllProvidersResponse, - V1ListAllModelsForAllProvidersError, + return (options?.client ?? client).put< + V1ConfigureAuthMaterialResponse, + V1ConfigureAuthMaterialError, ThrowOnError >({ ...options, - url: "/api/v1/provider-endpoints/models", + url: "/api/v1/provider-endpoints/{provider_id}/auth-material", }); }; diff --git a/src/api/generated/types.gen.ts b/src/api/generated/types.gen.ts index c45aecb6..ec403510 100644 --- a/src/api/generated/types.gen.ts +++ b/src/api/generated/types.gen.ts @@ -44,6 +44,14 @@ export type CodeSnippet = { libraries?: Array; }; +/** + * Represents a request to configure auth material for a provider. + */ +export type ConfigureAuthMaterial = { + auth_type: ProviderAuthType; + api_key?: string | null; +}; + /** * Represents a conversation. */ @@ -84,7 +92,8 @@ export type ListWorkspacesResponse = { */ export type ModelByProvider = { name: string; - provider: string; + provider_id: string; + provider_name: string; }; /** @@ -99,10 +108,10 @@ export enum MuxMatcherType { * Represents a mux rule for a provider. */ export type MuxRule = { - provider: string; + provider_id: string; model: string; matcher_type: MuxMatcherType; - matcher: string | null; + matcher?: string | null; }; /** @@ -120,12 +129,12 @@ export enum ProviderAuthType { * so we can use this for muxing messages. */ export type ProviderEndpoint = { - id: number; + id?: string | null; name: string; description?: string; provider_type: ProviderType; endpoint: string; - auth_type: ProviderAuthType; + auth_type?: ProviderAuthType | null; }; /** @@ -135,8 +144,9 @@ export enum ProviderType { OPENAI = "openai", ANTHROPIC = "anthropic", VLLM = "vllm", - LLAMACPP = "llamacpp", OLLAMA = "ollama", + LM_STUDIO = "lm_studio", + LLAMACPP = "llamacpp", } /** @@ -216,9 +226,23 @@ export type V1AddProviderEndpointResponse = ProviderEndpoint; export type V1AddProviderEndpointError = HTTPValidationError; +export type V1ListAllModelsForAllProvidersResponse = Array; + +export type V1ListAllModelsForAllProvidersError = unknown; + +export type V1ListModelsByProviderData = { + path: { + provider_id: string; + }; +}; + +export type V1ListModelsByProviderResponse = Array; + +export type V1ListModelsByProviderError = HTTPValidationError; + export type V1GetProviderEndpointData = { path: { - provider_id: number; + provider_id: string; }; }; @@ -229,7 +253,7 @@ export type V1GetProviderEndpointError = HTTPValidationError; export type V1UpdateProviderEndpointData = { body: ProviderEndpoint; path: { - provider_id: number; + provider_id: string; }; }; @@ -239,7 +263,7 @@ export type V1UpdateProviderEndpointError = HTTPValidationError; export type V1DeleteProviderEndpointData = { path: { - provider_id: number; + provider_id: string; }; }; @@ -247,19 +271,16 @@ export type V1DeleteProviderEndpointResponse = unknown; export type V1DeleteProviderEndpointError = HTTPValidationError; -export type V1ListModelsByProviderData = { +export type V1ConfigureAuthMaterialData = { + body: ConfigureAuthMaterial; path: { - provider_name: string; + provider_id: string; }; }; -export type V1ListModelsByProviderResponse = Array; - -export type V1ListModelsByProviderError = HTTPValidationError; +export type V1ConfigureAuthMaterialResponse = void; -export type V1ListAllModelsForAllProvidersResponse = Array; - -export type V1ListAllModelsForAllProvidersError = unknown; +export type V1ConfigureAuthMaterialError = HTTPValidationError; export type V1ListWorkspacesResponse = ListWorkspacesResponse; diff --git a/src/api/openapi.json b/src/api/openapi.json index 7c11f8c8..67d67ecb 100644 --- a/src/api/openapi.json +++ b/src/api/openapi.json @@ -8,7 +8,9 @@ "paths": { "/health": { "get": { - "tags": ["System"], + "tags": [ + "System" + ], "summary": "Health Check", "operationId": "health_check_health_get", "responses": { @@ -25,7 +27,10 @@ }, "/api/v1/provider-endpoints": { "get": { - "tags": ["CodeGate API", "Providers"], + "tags": [ + "CodeGate API", + "Providers" + ], "summary": "List Provider Endpoints", "description": "List all provider endpoints.", "operationId": "v1_list_provider_endpoints", @@ -75,7 +80,10 @@ } }, "post": { - "tags": ["CodeGate API", "Providers"], + "tags": [ + "CodeGate API", + "Providers" + ], "summary": "Add Provider Endpoint", "description": "Add a provider endpoint.", "operationId": "v1_add_provider_endpoint", @@ -113,9 +121,88 @@ } } }, + "/api/v1/provider-endpoints/models": { + "get": { + "tags": [ + "CodeGate API", + "Providers" + ], + "summary": "List All Models For All Providers", + "description": "List all models for all providers.", + "operationId": "v1_list_all_models_for_all_providers", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "items": { + "$ref": "#/components/schemas/ModelByProvider" + }, + "type": "array", + "title": "Response V1 List All Models For All Providers" + } + } + } + } + } + } + }, + "/api/v1/provider-endpoints/{provider_id}/models": { + "get": { + "tags": [ + "CodeGate API", + "Providers" + ], + "summary": "List Models By Provider", + "description": "List models by provider.", + "operationId": "v1_list_models_by_provider", + "parameters": [ + { + "name": "provider_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "format": "uuid", + "title": "Provider Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ModelByProvider" + }, + "title": "Response V1 List Models By Provider" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, "/api/v1/provider-endpoints/{provider_id}": { "get": { - "tags": ["CodeGate API", "Providers"], + "tags": [ + "CodeGate API", + "Providers" + ], "summary": "Get Provider Endpoint", "description": "Get a provider endpoint by ID.", "operationId": "v1_get_provider_endpoint", @@ -125,7 +212,8 @@ "in": "path", "required": true, "schema": { - "type": "integer", + "type": "string", + "format": "uuid", "title": "Provider Id" } } @@ -154,7 +242,10 @@ } }, "put": { - "tags": ["CodeGate API", "Providers"], + "tags": [ + "CodeGate API", + "Providers" + ], "summary": "Update Provider Endpoint", "description": "Update a provider endpoint by ID.", "operationId": "v1_update_provider_endpoint", @@ -164,7 +255,8 @@ "in": "path", "required": true, "schema": { - "type": "integer", + "type": "string", + "format": "uuid", "title": "Provider Id" } } @@ -203,7 +295,10 @@ } }, "delete": { - "tags": ["CodeGate API", "Providers"], + "tags": [ + "CodeGate API", + "Providers" + ], "summary": "Delete Provider Endpoint", "description": "Delete a provider endpoint by id.", "operationId": "v1_delete_provider_endpoint", @@ -213,7 +308,8 @@ "in": "path", "required": true, "schema": { - "type": "integer", + "type": "string", + "format": "uuid", "title": "Provider Id" } } @@ -240,37 +336,40 @@ } } }, - "/api/v1/provider-endpoints/{provider_name}/models": { - "get": { - "tags": ["CodeGate API", "Providers"], - "summary": "List Models By Provider", - "description": "List models by provider.", - "operationId": "v1_list_models_by_provider", + "/api/v1/provider-endpoints/{provider_id}/auth-material": { + "put": { + "tags": [ + "CodeGate API", + "Providers" + ], + "summary": "Configure Auth Material", + "description": "Configure auth material for a provider.", + "operationId": "v1_configure_auth_material", "parameters": [ { - "name": "provider_name", + "name": "provider_id", "in": "path", "required": true, "schema": { "type": "string", - "title": "Provider Name" + "format": "uuid", + "title": "Provider Id" } } ], - "responses": { - "200": { - "description": "Successful Response", - "content": { - "application/json": { - "schema": { - "type": "array", - "items": { - "$ref": "#/components/schemas/ModelByProvider" - }, - "title": "Response V1 List Models By Provider" - } + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ConfigureAuthMaterial" } } + } + }, + "responses": { + "204": { + "description": "Successful Response" }, "422": { "description": "Validation Error", @@ -285,33 +384,12 @@ } } }, - "/api/v1/provider-endpoints/models": { - "get": { - "tags": ["CodeGate API", "Providers"], - "summary": "List All Models For All Providers", - "description": "List all models for all providers.", - "operationId": "v1_list_all_models_for_all_providers", - "responses": { - "200": { - "description": "Successful Response", - "content": { - "application/json": { - "schema": { - "items": { - "$ref": "#/components/schemas/ModelByProvider" - }, - "type": "array", - "title": "Response V1 List All Models For All Providers" - } - } - } - } - } - } - }, "/api/v1/workspaces": { "get": { - "tags": ["CodeGate API", "Workspaces"], + "tags": [ + "CodeGate API", + "Workspaces" + ], "summary": "List Workspaces", "description": "List all workspaces.", "operationId": "v1_list_workspaces", @@ -329,7 +407,10 @@ } }, "post": { - "tags": ["CodeGate API", "Workspaces"], + "tags": [ + "CodeGate API", + "Workspaces" + ], "summary": "Create Workspace", "description": "Create a new workspace.", "operationId": "v1_create_workspace", @@ -369,7 +450,10 @@ }, "/api/v1/workspaces/active": { "get": { - "tags": ["CodeGate API", "Workspaces"], + "tags": [ + "CodeGate API", + "Workspaces" + ], "summary": "List Active Workspaces", "description": "List all active workspaces.\n\nIn it's current form, this function will only return one workspace. That is,\nthe globally active workspace.", "operationId": "v1_list_active_workspaces", @@ -387,7 +471,10 @@ } }, "post": { - "tags": ["CodeGate API", "Workspaces"], + "tags": [ + "CodeGate API", + "Workspaces" + ], "summary": "Activate Workspace", "description": "Activate a workspace by name.", "operationId": "v1_activate_workspace", @@ -436,7 +523,10 @@ }, "/api/v1/workspaces/{workspace_name}": { "delete": { - "tags": ["CodeGate API", "Workspaces"], + "tags": [ + "CodeGate API", + "Workspaces" + ], "summary": "Delete Workspace", "description": "Delete a workspace by name.", "operationId": "v1_delete_workspace", @@ -475,7 +565,10 @@ }, "/api/v1/workspaces/archive": { "get": { - "tags": ["CodeGate API", "Workspaces"], + "tags": [ + "CodeGate API", + "Workspaces" + ], "summary": "List Archived Workspaces", "description": "List all archived workspaces.", "operationId": "v1_list_archived_workspaces", @@ -495,7 +588,10 @@ }, "/api/v1/workspaces/archive/{workspace_name}/recover": { "post": { - "tags": ["CodeGate API", "Workspaces"], + "tags": [ + "CodeGate API", + "Workspaces" + ], "summary": "Recover Workspace", "description": "Recover an archived workspace by name.", "operationId": "v1_recover_workspace", @@ -529,7 +625,10 @@ }, "/api/v1/workspaces/archive/{workspace_name}": { "delete": { - "tags": ["CodeGate API", "Workspaces"], + "tags": [ + "CodeGate API", + "Workspaces" + ], "summary": "Hard Delete Workspace", "description": "Hard delete an archived workspace by name.", "operationId": "v1_hard_delete_workspace", @@ -568,7 +667,10 @@ }, "/api/v1/workspaces/{workspace_name}/alerts": { "get": { - "tags": ["CodeGate API", "Workspaces"], + "tags": [ + "CodeGate API", + "Workspaces" + ], "summary": "Get Workspace Alerts", "description": "Get alerts for a workspace.", "operationId": "v1_get_workspace_alerts", @@ -620,7 +722,10 @@ }, "/api/v1/workspaces/{workspace_name}/messages": { "get": { - "tags": ["CodeGate API", "Workspaces"], + "tags": [ + "CodeGate API", + "Workspaces" + ], "summary": "Get Workspace Messages", "description": "Get messages for a workspace.", "operationId": "v1_get_workspace_messages", @@ -665,7 +770,10 @@ }, "/api/v1/workspaces/{workspace_name}/custom-instructions": { "get": { - "tags": ["CodeGate API", "Workspaces"], + "tags": [ + "CodeGate API", + "Workspaces" + ], "summary": "Get Workspace Custom Instructions", "description": "Get the custom instructions of a workspace.", "operationId": "v1_get_workspace_custom_instructions", @@ -704,7 +812,10 @@ } }, "put": { - "tags": ["CodeGate API", "Workspaces"], + "tags": [ + "CodeGate API", + "Workspaces" + ], "summary": "Set Workspace Custom Instructions", "operationId": "v1_set_workspace_custom_instructions", "parameters": [ @@ -745,7 +856,10 @@ } }, "delete": { - "tags": ["CodeGate API", "Workspaces"], + "tags": [ + "CodeGate API", + "Workspaces" + ], "summary": "Delete Workspace Custom Instructions", "operationId": "v1_delete_workspace_custom_instructions", "parameters": [ @@ -778,7 +892,11 @@ }, "/api/v1/workspaces/{workspace_name}/muxes": { "get": { - "tags": ["CodeGate API", "Workspaces", "Muxes"], + "tags": [ + "CodeGate API", + "Workspaces", + "Muxes" + ], "summary": "Get Workspace Muxes", "description": "Get the mux rules of a workspace.\n\nThe list is ordered in order of priority. That is, the first rule in the list\nhas the highest priority.", "operationId": "v1_get_workspace_muxes", @@ -821,7 +939,11 @@ } }, "put": { - "tags": ["CodeGate API", "Workspaces", "Muxes"], + "tags": [ + "CodeGate API", + "Workspaces", + "Muxes" + ], "summary": "Set Workspace Muxes", "description": "Set the mux rules of a workspace.", "operationId": "v1_set_workspace_muxes", @@ -869,7 +991,10 @@ }, "/api/v1/alerts_notification": { "get": { - "tags": ["CodeGate API", "Dashboard"], + "tags": [ + "CodeGate API", + "Dashboard" + ], "summary": "Stream Sse", "description": "Send alerts event", "operationId": "v1_stream_sse", @@ -887,7 +1012,10 @@ }, "/api/v1/version": { "get": { - "tags": ["CodeGate API", "Dashboard"], + "tags": [ + "CodeGate API", + "Dashboard" + ], "summary": "Version Check", "operationId": "v1_version_check", "responses": { @@ -904,7 +1032,11 @@ }, "/api/v1/workspaces/{workspace_name}/token-usage": { "get": { - "tags": ["CodeGate API", "Workspaces", "Token Usage"], + "tags": [ + "CodeGate API", + "Workspaces", + "Token Usage" + ], "summary": "Get Workspace Token Usage", "description": "Get the token usage of a workspace.", "operationId": "v1_get_workspace_token_usage", @@ -954,7 +1086,9 @@ } }, "type": "object", - "required": ["name"], + "required": [ + "name" + ], "title": "ActivateWorkspaceRequest" }, "ActiveWorkspace": { @@ -972,7 +1106,11 @@ } }, "type": "object", - "required": ["name", "is_active", "last_updated"], + "required": [ + "name", + "is_active", + "last_updated" + ], "title": "ActiveWorkspace" }, "AlertConversation": { @@ -1059,7 +1197,11 @@ } }, "type": "object", - "required": ["message", "timestamp", "message_id"], + "required": [ + "message", + "timestamp", + "message_id" + ], "title": "ChatMessage", "description": "Represents a chat message." }, @@ -1100,9 +1242,37 @@ } }, "type": "object", - "required": ["code", "language", "filepath"], + "required": [ + "code", + "language", + "filepath" + ], "title": "CodeSnippet" }, + "ConfigureAuthMaterial": { + "properties": { + "auth_type": { + "$ref": "#/components/schemas/ProviderAuthType" + }, + "api_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Api Key" + } + }, + "type": "object", + "required": [ + "auth_type" + ], + "title": "ConfigureAuthMaterial", + "description": "Represents a request to configure auth material for a provider." + }, "Conversation": { "properties": { "question_answers": { @@ -1177,7 +1347,9 @@ } }, "type": "object", - "required": ["name"], + "required": [ + "name" + ], "title": "CreateOrRenameWorkspaceRequest" }, "CustomInstructions": { @@ -1188,7 +1360,9 @@ } }, "type": "object", - "required": ["prompt"], + "required": [ + "prompt" + ], "title": "CustomInstructions" }, "HTTPValidationError": { @@ -1215,7 +1389,9 @@ } }, "type": "object", - "required": ["workspaces"], + "required": [ + "workspaces" + ], "title": "ListActiveWorkspacesResponse" }, "ListWorkspacesResponse": { @@ -1229,7 +1405,9 @@ } }, "type": "object", - "required": ["workspaces"], + "required": [ + "workspaces" + ], "title": "ListWorkspacesResponse" }, "ModelByProvider": { @@ -1238,27 +1416,38 @@ "type": "string", "title": "Name" }, - "provider": { + "provider_id": { "type": "string", - "title": "Provider" + "title": "Provider Id" + }, + "provider_name": { + "type": "string", + "title": "Provider Name" } }, "type": "object", - "required": ["name", "provider"], + "required": [ + "name", + "provider_id", + "provider_name" + ], "title": "ModelByProvider", "description": "Represents a model supported by a provider.\n\nNote that these are auto-discovered by the provider." }, "MuxMatcherType": { "type": "string", - "enum": ["file_regex", "catch_all"], + "enum": [ + "file_regex", + "catch_all" + ], "title": "MuxMatcherType", "description": "Represents the different types of matchers we support." }, "MuxRule": { "properties": { - "provider": { + "provider_id": { "type": "string", - "title": "Provider" + "title": "Provider Id" }, "model": { "type": "string", @@ -1280,21 +1469,37 @@ } }, "type": "object", - "required": ["provider", "model", "matcher_type", "matcher"], + "required": [ + "provider_id", + "model", + "matcher_type" + ], "title": "MuxRule", "description": "Represents a mux rule for a provider." }, "ProviderAuthType": { "type": "string", - "enum": ["none", "passthrough", "api_key"], + "enum": [ + "none", + "passthrough", + "api_key" + ], "title": "ProviderAuthType", "description": "Represents the different types of auth we support for providers." }, "ProviderEndpoint": { "properties": { "id": { - "type": "integer", - "title": "Id" + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Id", + "default": "" }, "name": { "type": "string", @@ -1313,17 +1518,36 @@ "title": "Endpoint" }, "auth_type": { - "$ref": "#/components/schemas/ProviderAuthType" + "anyOf": [ + { + "$ref": "#/components/schemas/ProviderAuthType" + }, + { + "type": "null" + } + ], + "default": "none" } }, "type": "object", - "required": ["id", "name", "provider_type", "endpoint", "auth_type"], + "required": [ + "name", + "provider_type", + "endpoint" + ], "title": "ProviderEndpoint", "description": "Represents a provider's endpoint configuration. This\nallows us to persist the configuration for each provider,\nso we can use this for muxing messages." }, "ProviderType": { "type": "string", - "enum": ["openai", "anthropic", "vllm", "llamacpp", "ollama"], + "enum": [ + "openai", + "anthropic", + "vllm", + "ollama", + "lm_studio", + "llamacpp" + ], "title": "ProviderType", "description": "Represents the different types of providers we support." }, @@ -1344,13 +1568,19 @@ } }, "type": "object", - "required": ["question", "answer"], + "required": [ + "question", + "answer" + ], "title": "QuestionAnswer", "description": "Represents a question and answer pair." }, "QuestionType": { "type": "string", - "enum": ["chat", "fim"], + "enum": [ + "chat", + "fim" + ], "title": "QuestionType" }, "TokenUsage": { @@ -1394,7 +1624,10 @@ } }, "type": "object", - "required": ["tokens_by_model", "token_usage"], + "required": [ + "tokens_by_model", + "token_usage" + ], "title": "TokenUsageAggregate", "description": "Represents the tokens used. Includes the information of the tokens used by model.\n`used_tokens` are the total tokens used in the `tokens_by_model` list." }, @@ -1412,7 +1645,11 @@ } }, "type": "object", - "required": ["provider_type", "model", "token_usage"], + "required": [ + "provider_type", + "model", + "token_usage" + ], "title": "TokenUsageByModel", "description": "Represents the tokens used by a model." }, @@ -1442,7 +1679,11 @@ } }, "type": "object", - "required": ["loc", "msg", "type"], + "required": [ + "loc", + "msg", + "type" + ], "title": "ValidationError" }, "Workspace": { @@ -1457,7 +1698,10 @@ } }, "type": "object", - "required": ["name", "is_active"], + "required": [ + "name", + "is_active" + ], "title": "Workspace" } } diff --git a/src/features/workspace/components/__tests__/workspace-preferred-model.test.tsx b/src/features/workspace/components/__tests__/workspace-preferred-model.test.tsx new file mode 100644 index 00000000..e7ac8314 --- /dev/null +++ b/src/features/workspace/components/__tests__/workspace-preferred-model.test.tsx @@ -0,0 +1,48 @@ +import { render } from "@/lib/test-utils"; +import { screen, waitFor } from "@testing-library/react"; +import { WorkspacePreferredModel } from "../workspace-preferred-model"; +import userEvent from "@testing-library/user-event"; + +test("render model overrides", () => { + render( + , + ); + expect(screen.getByText(/preferred model/i)).toBeVisible(); + expect( + screen.getByText( + /select the model you would like to use in this workspace./i, + ), + ).toBeVisible(); + expect( + screen.getByRole("button", { name: /select the model/i }), + ).toBeVisible(); + expect(screen.getByRole("button", { name: /save/i })).toBeVisible(); +}); + +test("submit preferred model", async () => { + render( + , + ); + + await userEvent.click( + screen.getByRole("button", { name: /select the model/i }), + ); + + await userEvent.click( + screen.getByRole("option", { + name: "anthropic/claude-3.5", + }), + ); + + await userEvent.click(screen.getByRole("button", { name: /save/i })); + + await waitFor(() => { + expect(screen.getByText(/preferred model for fake-workspace updated/i)); + }); +}); diff --git a/src/features/workspace/components/workspace-custom-instructions.tsx b/src/features/workspace/components/workspace-custom-instructions.tsx index c6f462df..bd2dda72 100644 --- a/src/features/workspace/components/workspace-custom-instructions.tsx +++ b/src/features/workspace/components/workspace-custom-instructions.tsx @@ -345,7 +345,6 @@ export function WorkspaceCustomInstructions({ isPending={isMutationPending} isDisabled={Boolean(isArchived ?? isCustomInstructionsPending)} onPress={() => handleSubmit(value)} - variant="secondary" > Save diff --git a/src/features/workspace/components/workspace-name.tsx b/src/features/workspace/components/workspace-name.tsx index 68cf8c7c..513ac47b 100644 --- a/src/features/workspace/components/workspace-name.tsx +++ b/src/features/workspace/components/workspace-name.tsx @@ -71,7 +71,6 @@ export function WorkspaceName({ isDisabled={isArchived || name === ""} isPending={isPending} type="submit" - variant="secondary" > Save diff --git a/src/features/workspace/components/workspace-preferred-model.tsx b/src/features/workspace/components/workspace-preferred-model.tsx new file mode 100644 index 00000000..f3a1c75b --- /dev/null +++ b/src/features/workspace/components/workspace-preferred-model.tsx @@ -0,0 +1,95 @@ +import { + Button, + Card, + CardBody, + CardFooter, + Form, + Text, +} from "@stacklok/ui-kit"; +import { twMerge } from "tailwind-merge"; +import { useMutationPreferredModelWorkspace } from "../hooks/use-mutation-preferred-model-workspace"; +import { MuxMatcherType } from "@/api/generated"; +import { FormEvent } from "react"; +import { usePreferredModelWorkspace } from "../hooks/use-preferred-preferred-model"; +import { Select, SelectButton } from "@stacklok/ui-kit"; +import { useModelsData } from "@/hooks/useModelsData"; + +export function WorkspacePreferredModel({ + className, + workspaceName, + isArchived, +}: { + className?: string; + workspaceName: string; + isArchived: boolean | undefined; +}) { + const { preferredModel, setPreferredModel } = usePreferredModelWorkspace(); + const { mutateAsync } = useMutationPreferredModelWorkspace(); + const { data: providerModels = [] } = useModelsData(); + const { model, provider_id } = preferredModel; + + const handleSubmit = (event: FormEvent) => { + event.preventDefault(); + mutateAsync({ + path: { workspace_name: workspaceName }, + body: [ + { + matcher: "", + provider_id, + model, + matcher_type: MuxMatcherType.CATCH_ALL, + }, + ], + }); + }; + + return ( +
+ + +
+ Preferred Model + + Select the model you would like to use in this workspace. + +
+
+
+ +
+
+
+ + + +
+
+ ); +} diff --git a/src/features/workspace/hooks/use-mutation-preferred-model-workspace.ts b/src/features/workspace/hooks/use-mutation-preferred-model-workspace.ts new file mode 100644 index 00000000..3558b37b --- /dev/null +++ b/src/features/workspace/hooks/use-mutation-preferred-model-workspace.ts @@ -0,0 +1,15 @@ +import { useToastMutation } from "@/hooks/use-toast-mutation"; +import { useInvalidateWorkspaceQueries } from "./use-invalidate-workspace-queries"; +import { v1SetWorkspaceMuxesMutation } from "@/api/generated/@tanstack/react-query.gen"; + +export function useMutationPreferredModelWorkspace() { + const invalidate = useInvalidateWorkspaceQueries(); + return useToastMutation({ + ...v1SetWorkspaceMuxesMutation(), + onSuccess: async () => { + await invalidate(); + }, + successMsg: (variables) => + `Preferred model for ${variables.path.workspace_name} updated`, + }); +} diff --git a/src/features/workspace/hooks/use-preferred-preferred-model.ts b/src/features/workspace/hooks/use-preferred-preferred-model.ts new file mode 100644 index 00000000..555a2a23 --- /dev/null +++ b/src/features/workspace/hooks/use-preferred-preferred-model.ts @@ -0,0 +1,19 @@ +import { MuxRule } from "@/api/generated"; +import { create } from "zustand"; + +export type ModelRule = Omit & {}; + +type State = { + setPreferredModel: (model: ModelRule) => void; + preferredModel: ModelRule; +}; + +export const usePreferredModelWorkspace = create((set) => ({ + preferredModel: { + provider_id: "", + model: "", + }, + setPreferredModel: ({ model, provider_id }: ModelRule) => { + set({ preferredModel: { provider_id, model } }); + }, +})); diff --git a/src/hooks/useModelsData.ts b/src/hooks/useModelsData.ts new file mode 100644 index 00000000..b58fda4e --- /dev/null +++ b/src/hooks/useModelsData.ts @@ -0,0 +1,32 @@ +import { useQuery } from "@tanstack/react-query"; +import { v1ListAllModelsForAllProvidersOptions } from "@/api/generated/@tanstack/react-query.gen"; +import { V1ListAllModelsForAllProvidersResponse } from "@/api/generated"; + +export const useModelsData = () => { + return useQuery({ + ...v1ListAllModelsForAllProvidersOptions(), + queryFn: async () => { + const response: V1ListAllModelsForAllProvidersResponse = [ + { + name: "claude-3.5", + provider_name: "anthropic", + provider_id: "anthropic", + }, + { + name: "claude-3.6", + provider_name: "anthropic", + provider_id: "anthropic", + }, + { + name: "claude-3.7", + provider_name: "anthropic", + provider_id: "anthropic", + }, + { name: "chatgpt-4o", provider_name: "openai", provider_id: "openai" }, + { name: "chatgpt-4p", provider_name: "openai", provider_id: "openai" }, + ]; + + return response; + }, + }); +}; diff --git a/src/mocks/msw/handlers.ts b/src/mocks/msw/handlers.ts index df9d64b8..be2bd963 100644 --- a/src/mocks/msw/handlers.ts +++ b/src/mocks/msw/handlers.ts @@ -96,13 +96,13 @@ export const handlers = [ http.get("*/api/v1/workspaces/:workspace_name/muxes", () => HttpResponse.json([ { - provider: "openai", + provider_id: "openai", model: "gpt-3.5-turbo", matcher_type: "file_regex", matcher: ".*\\.txt", }, { - provider: "anthropic", + provider_id: "anthropic", model: "davinci", matcher_type: "catch_all", }, @@ -131,9 +131,21 @@ export const handlers = [ () => new HttpResponse(null, { status: 204 }), ), http.get("*/api/v1/provider-endpoints/:provider_name/models", () => - HttpResponse.json({ name: "dummy", provider: "dummy" }), + HttpResponse.json([ + { name: "claude-3.5", provider: "anthropic" }, + { name: "claude-3.6", provider: "anthropic" }, + { name: "claude-3.7", provider: "anthropic" }, + { name: "chatgpt-4o", provider: "openai" }, + { name: "chatgpt-4p", provider: "openai" }, + ]), ), http.get("*/api/v1/provider-endpoints/models", () => - HttpResponse.json({ name: "dummy", provider: "dummy" }), + HttpResponse.json([ + { name: "claude-3.5", provider: "anthropic" }, + { name: "claude-3.6", provider: "anthropic" }, + { name: "claude-3.7", provider: "anthropic" }, + { name: "chatgpt-4o", provider: "openai" }, + { name: "chatgpt-4p", provider: "openai" }, + ]), ), ]; diff --git a/src/routes/route-workspace.tsx b/src/routes/route-workspace.tsx index 968a3609..ba265320 100644 --- a/src/routes/route-workspace.tsx +++ b/src/routes/route-workspace.tsx @@ -8,6 +8,7 @@ import { useParams } from "react-router-dom"; import { useArchivedWorkspaces } from "@/features/workspace/hooks/use-archived-workspaces"; import { useRestoreWorkspaceButton } from "@/features/workspace/hooks/use-restore-workspace-button"; import { WorkspaceCustomInstructions } from "@/features/workspace/components/workspace-custom-instructions"; +import { WorkspacePreferredModel } from "@/features/workspace/components/workspace-preferred-model"; function WorkspaceArchivedBanner({ name }: { name: string }) { const restoreButtonProps = useRestoreWorkspaceButton({ workspaceName: name }); @@ -52,6 +53,11 @@ export function RouteWorkspace() { className="mb-4" workspaceName={name} /> +