diff --git a/package-lock.json b/package-lock.json index fe141cf9..842d5c03 100644 --- a/package-lock.json +++ b/package-lock.json @@ -18,7 +18,7 @@ "@radix-ui/react-dialog": "^1.1.4", "@radix-ui/react-separator": "^1.1.0", "@radix-ui/react-slot": "^1.1.0", - "@stacklok/ui-kit": "^1.0.1-4", + "@stacklok/ui-kit": "^1.0.1-9", "@tanstack/react-query": "^5.64.1", "@tanstack/react-query-devtools": "^5.66.0", "@types/lodash": "^4.17.15", @@ -1185,6 +1185,18 @@ "node": ">=18" } }, + "node_modules/@hookform/resolvers": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/@hookform/resolvers/-/resolvers-4.1.0.tgz", + "integrity": "sha512-fX/uHKb+OOCpACLc6enuTQsf0ZpRrKbeBBPETg5PCPLCIYV6osP2Bw6ezuclM61lH+wBF9eXcuC0+BFh9XOEnQ==", + "license": "MIT", + "dependencies": { + "caniuse-lite": "^1.0.30001698" + }, + "peerDependencies": { + "react-hook-form": "^7.0.0" + } + }, "node_modules/@humanfs/core": { "version": "0.19.1", "resolved": "https://registry.npmjs.org/@humanfs/core/-/core-0.19.1.tgz", @@ -4076,18 +4088,20 @@ } }, "node_modules/@stacklok/ui-kit": { - "version": "1.0.1-4", - "resolved": "https://registry.npmjs.org/@stacklok/ui-kit/-/ui-kit-1.0.1-4.tgz", - "integrity": "sha512-Az5mQmb+0P7pMUKVyhJLEpDCzGGtSFi0H0C2Us32gM4Fsz78FL92MHvwdWMZgcepdwbdQoR/C6JASwSHteAbmA==", + "version": "1.0.1-9", + "resolved": "https://registry.npmjs.org/@stacklok/ui-kit/-/ui-kit-1.0.1-9.tgz", + "integrity": "sha512-TM7ajXb43bCx/b/SKuvkFS2MKBo6xThrt6mmBMMYpbgM1aD74nwSaSXuNdYW7/PaTU2hFnzdC7T7wNEHMAJQNw==", "license": "ISC", "dependencies": { "@fontsource-variable/figtree": "^5.1.1", "@fontsource-variable/inter": "^5.1.0", "@fontsource-variable/source-code-pro": "^5.1.0", + "@hookform/resolvers": "4.1.0", "@untitled-ui/icons-react": "^0.1.4", "postcss": "^8.4.47", "react-aria": "3.36.0", "react-aria-components": "1.5.0", + "react-hook-form": "7.47.0", "react-stately": "3.34.0", "sonner": "^1.7.1", "tailwind-merge": "^2.5.2", @@ -5846,10 +5860,9 @@ } }, "node_modules/caniuse-lite": { - "version": "1.0.30001690", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001690.tgz", - "integrity": "sha512-5ExiE3qQN6oF8Clf8ifIDcMRCRE/dMGcETG/XGMD8/XiXm6HXQgQTh1yZYLXXpSOsEUlJm1Xr7kGULZTuGtP/w==", - "dev": true, + "version": "1.0.30001700", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001700.tgz", + "integrity": "sha512-2S6XIXwaE7K7erT8dY+kLQcpa5ms63XlRkMkReXjle+kf6c5g38vyMl+Z5y8dSxOFDhcFe+nxnn261PLxBSQsQ==", "funding": [ { "type": "opencollective", @@ -11886,6 +11899,22 @@ "react": "^19.0.0" } }, + "node_modules/react-hook-form": { + "version": "7.47.0", + "resolved": "https://registry.npmjs.org/react-hook-form/-/react-hook-form-7.47.0.tgz", + "integrity": "sha512-F/TroLjTICipmHeFlMrLtNLceO2xr1jU3CyiNla5zdwsGUGu2UOxxR4UyJgLlhMwLW/Wzp4cpJ7CPfgJIeKdSg==", + "license": "MIT", + "engines": { + "node": ">=12.22.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/react-hook-form" + }, + "peerDependencies": { + "react": "^16.8.0 || ^17 || ^18" + } + }, "node_modules/react-markdown": { "version": "9.0.1", "resolved": "https://registry.npmjs.org/react-markdown/-/react-markdown-9.0.1.tgz", diff --git a/package.json b/package.json index 4478ea30..65129636 100644 --- a/package.json +++ b/package.json @@ -31,7 +31,7 @@ "@radix-ui/react-dialog": "^1.1.4", "@radix-ui/react-separator": "^1.1.0", "@radix-ui/react-slot": "^1.1.0", - "@stacklok/ui-kit": "^1.0.1-4", + "@stacklok/ui-kit": "^1.0.1-9", "@tanstack/react-query": "^5.64.1", "@tanstack/react-query-devtools": "^5.66.0", "@types/lodash": "^4.17.15", diff --git a/src/api/generated/types.gen.ts b/src/api/generated/types.gen.ts index 50634d20..f10667f5 100644 --- a/src/api/generated/types.gen.ts +++ b/src/api/generated/types.gen.ts @@ -144,11 +144,23 @@ export type ModelByProvider = { /** * Represents the different types of matchers we support. + * + * The 3 rules present match filenames and request types. They're used in conjunction with the + * matcher field in the MuxRule model. + * E.g. + * - catch_all-> Always match + * - filename_match and match: requests.py -> Match the request if the filename is requests.py + * - fim_filename and match: main.py -> Match the request if the request type is fim + * and the filename is main.py + * + * NOTE: Removing or updating fields from this enum will require a migration. + * Adding new fields is safe. */ export enum MuxMatcherType { CATCH_ALL = 'catch_all', FILENAME_MATCH = 'filename_match', - REQUEST_TYPE_MATCH = 'request_type_match', + FIM_FILENAME = 'fim_filename', + CHAT_FILENAME = 'chat_filename', } /** diff --git a/src/api/openapi.json b/src/api/openapi.json index a6d16753..deb5c2de 100644 --- a/src/api/openapi.json +++ b/src/api/openapi.json @@ -1642,10 +1642,11 @@ "enum": [ "catch_all", "filename_match", - "request_type_match" + "fim_filename", + "chat_filename" ], "title": "MuxMatcherType", - "description": "Represents the different types of matchers we support." + "description": "Represents the different types of matchers we support.\n\nThe 3 rules present match filenames and request types. They're used in conjunction with the\nmatcher field in the MuxRule model.\nE.g.\n- catch_all-> Always match\n- filename_match and match: requests.py -> Match the request if the filename is requests.py\n- fim_filename and match: main.py -> Match the request if the request type is fim\nand the filename is main.py\n\nNOTE: Removing or updating fields from this enum will require a migration.\nAdding new fields is safe." }, "MuxRule": { "properties": { diff --git a/src/features/providers/components/workspaces-by-provider.tsx b/src/features/providers/components/workspaces-by-provider.tsx index 63a35df0..07570ca9 100644 --- a/src/features/providers/components/workspaces-by-provider.tsx +++ b/src/features/providers/components/workspaces-by-provider.tsx @@ -10,7 +10,10 @@ export function WorkspacesByProvider({ if (workspaces.length === 0) return null return (
-

The following workspaces will be impacted by this action

+

+ The following workspaces are currently using this provider and will need + to be updated: +

{uniqBy(workspaces, 'name').map((item, index) => { return ( diff --git a/src/features/workspace/components/__tests__/workspace-muxing-model.test.tsx b/src/features/workspace/components/__tests__/workspace-muxing-model.test.tsx index b564dd12..e36c44d0 100644 --- a/src/features/workspace/components/__tests__/workspace-muxing-model.test.tsx +++ b/src/features/workspace/components/__tests__/workspace-muxing-model.test.tsx @@ -7,6 +7,12 @@ test('renders muxing model', async () => { render( ) + + expect( + screen.getByRole('button', { + name: /all types/i, + }) + ).toBeVisible() expect(screen.getByText(/model muxing/i)).toBeVisible() expect( screen.getByText( @@ -46,6 +52,11 @@ test('disabled muxing fields and buttons for archived workspace', async () => { expect(await screen.findByRole('button', { name: /save/i })).toBeDisabled() expect(screen.getByTestId(/workspace-models-dropdown/i)).toBeDisabled() + expect( + screen.getByRole('button', { + name: /all types/i, + }) + ).toBeDisabled() expect( await screen.findByRole('button', { name: /add filter/i }) ).toBeDisabled() @@ -75,12 +86,27 @@ test('submit additional model overrides', async () => { name: /filter by/i, }) expect(textFields.length).toEqual(2) + + const requestTypeSelect = screen.getAllByRole('button', { + name: /fim & chat/i, + })[0] + await userEvent.click(requestTypeSelect as HTMLFormElement) + await userEvent.click( + screen.getByRole('option', { + name: 'FIM', + }) + ) + expect( + screen.getByRole('button', { + name: 'FIM', + }) + ).toBeVisible() const modelsButton = await screen.findAllByTestId( /workspace-models-dropdown/i ) expect(modelsButton.length).toEqual(2) - await userEvent.type(textFields[1] as HTMLFormElement, '.ts') + await userEvent.type(textFields[0] as HTMLFormElement, '.tsx') await userEvent.click( (await screen.findByRole('button', { @@ -94,6 +120,37 @@ test('submit additional model overrides', async () => { }) ) + await userEvent.click(screen.getByRole('button', { name: /add filter/i })) + await userEvent.click( + screen.getAllByRole('button', { + name: /chat/i, + })[1] as HTMLFormElement + ) + + await userEvent.click( + screen.getByRole('option', { + name: 'Chat', + }) + ) + + await userEvent.type( + screen.getAllByRole('textbox', { + name: /filter by/i, + })[1] as HTMLFormElement, + '.ts' + ) + + await userEvent.click( + (await screen.findByRole('button', { + name: /select a model/i, + })) as HTMLFormElement + ) + + await userEvent.click( + screen.getByRole('option', { + name: /chatgpt-4o/i, + }) + ) await userEvent.click(screen.getByRole('button', { name: /save/i })) await waitFor(() => { diff --git a/src/features/workspace/components/workspace-models-dropdown.tsx b/src/features/workspace/components/workspace-models-dropdown.tsx index 059064c5..d924cbed 100644 --- a/src/features/workspace/components/workspace-models-dropdown.tsx +++ b/src/features/workspace/components/workspace-models-dropdown.tsx @@ -37,7 +37,7 @@ function groupModelsByProviderName( id: providerName, textValue: providerName, items: items.map((item) => ({ - id: `${item.provider_id}/${item.name}`, + id: `${item.provider_id}:${item.name}`, textValue: item.name, })), })) @@ -116,7 +116,7 @@ export function WorkspaceModelsDropdown({ const selectedValue = v.values().next().value if (!selectedValue && typeof selectedValue !== 'string') return if (typeof selectedValue === 'string') { - const [provider_id, modelName] = selectedValue.split('/') + const [provider_id, modelName] = selectedValue.split(':') if (!provider_id || !modelName) return onChange({ model: modelName, diff --git a/src/features/workspace/components/workspace-muxing-model.tsx b/src/features/workspace/components/workspace-muxing-model.tsx index 44e8fe6b..feb800cf 100644 --- a/src/features/workspace/components/workspace-muxing-model.tsx +++ b/src/features/workspace/components/workspace-muxing-model.tsx @@ -9,6 +9,8 @@ import { Label, Link, LinkButton, + Select, + SelectButton, Text, TextField, Tooltip, @@ -17,10 +19,7 @@ import { } from '@stacklok/ui-kit' import { twMerge } from 'tailwind-merge' import { useMutationPreferredModelWorkspace } from '../hooks/use-mutation-preferred-model-workspace' -import { - MuxMatcherType, - V1ListAllModelsForAllProvidersResponse, -} from '@/api/generated' +import { V1ListAllModelsForAllProvidersResponse } from '@/api/generated' import { FormEvent } from 'react' import { LayersThree01, @@ -37,6 +36,7 @@ import { useMuxingRulesFormState, } from '../hooks/use-muxing-rules-form-workspace' import { FormButtons } from '@/components/FormButtons' +import { getRuleData, isRequestType } from '../lib/utils' function MissingProviderBanner() { return ( @@ -77,9 +77,31 @@ function SortableItem({ isArchived, isDefaultRule, }: SortableItemProps) { - const placeholder = isDefaultRule ? 'Catch-all' : 'e.g. file type, file name' + const { selectedKey, placeholder, items } = getRuleData({ + isDefaultRule, + matcher_type: rule.matcher_type, + }) + return (
+
+ +
{ void id - return rest.matcher - ? { ...rest, matcher_type: MuxMatcherType.FILENAME_MATCH } - : { ...rest } + return rest }), }, { @@ -200,6 +220,7 @@ export function WorkspaceMuxingModel({
 
+
Request Type
- +
{ const formState = useFormState({ rules: [{ ...DEFAULT_STATE, id: uuidv4() }], - }); - const { values, updateFormValues, setInitialValues } = formState; - const lastValuesRef = useRef(values.rules); + }) + const { values, updateFormValues, setInitialValues } = formState + const lastValuesRef = useRef(values.rules) useEffect(() => { const newValues = initialValues.length === 0 ? [DEFAULT_STATE] - : initialValues.map((item) => ({ ...item, id: uuidv4() })); + : initialValues.map((item) => ({ ...item, id: uuidv4() })) if (!isEqual(lastValuesRef.current, newValues)) { - lastValuesRef.current = newValues; - setInitialValues({ rules: newValues }); + lastValuesRef.current = newValues + setInitialValues({ rules: newValues }) } - }, [initialValues, setInitialValues]); + }, [initialValues, setInitialValues]) const addRule = useCallback(() => { const newRules = [ ...values.rules.slice(0, values.rules.length - 1), - { ...DEFAULT_STATE, id: uuidv4() }, + { + ...DEFAULT_STATE, + matcher_type: MuxMatcherType.FILENAME_MATCH, + id: uuidv4(), + }, ...values.rules.slice(values.rules.length - 1), - ]; + ] updateFormValues({ rules: newRules, - }); - }, [updateFormValues, values.rules]); + }) + }, [updateFormValues, values.rules]) const setRules = useCallback( (rules: PreferredMuxRule[]) => { - updateFormValues({ rules }); + updateFormValues({ rules }) }, - [updateFormValues], - ); + [updateFormValues] + ) const setRuleItem = useCallback( (rule: PreferredMuxRule) => { updateFormValues({ rules: values.rules.map((item) => (item.id === rule.id ? rule : item)), - }); + }) }, - [updateFormValues, values.rules], - ); + [updateFormValues, values.rules] + ) const removeRule = useCallback( (ruleIndex: number) => { updateFormValues({ rules: values.rules.filter((_, index) => index !== ruleIndex), - }); + }) }, - [updateFormValues, values.rules], - ); + [updateFormValues, values.rules] + ) - return { addRule, setRules, setRuleItem, removeRule, values, formState }; -}; + return { addRule, setRules, setRuleItem, removeRule, values, formState } +} diff --git a/src/features/workspace/lib/utils.ts b/src/features/workspace/lib/utils.ts new file mode 100644 index 00000000..06cd6a5a --- /dev/null +++ b/src/features/workspace/lib/utils.ts @@ -0,0 +1,53 @@ +import { MuxMatcherType } from '@/api/generated' + +export const MUX_MATCHER_TYPE_MAP = { + [MuxMatcherType.CHAT_FILENAME]: 'Chat', + [MuxMatcherType.FIM_FILENAME]: 'FIM', + [MuxMatcherType.FILENAME_MATCH]: 'FIM & Chat', + [MuxMatcherType.CATCH_ALL]: 'All types', +} + +export function getRequestType() { + return Object.values(MuxMatcherType) + .filter((item) => item !== MuxMatcherType.CATCH_ALL) + .map((textValue) => ({ + id: textValue, + textValue: MUX_MATCHER_TYPE_MAP[textValue], + })) +} + +export function isRequestType(value: unknown): value is MuxMatcherType { + return Object.values(MuxMatcherType).includes(value as MuxMatcherType) +} + +const DEFAULT_RULE = { + placeholder: 'Catch-all', + selectedKey: MuxMatcherType.CATCH_ALL, + items: [ + { + id: MuxMatcherType.CATCH_ALL, + textValue: MUX_MATCHER_TYPE_MAP[MuxMatcherType.CATCH_ALL], + }, + ], +} + +const CUSTOM_RULE = { + placeholder: 'e.g. file type, file name', + selectedKey: '', + items: getRequestType(), +} + +export function getRuleData({ + isDefaultRule, + matcher_type, +}: { + isDefaultRule: boolean + matcher_type: MuxMatcherType +}) { + return isDefaultRule + ? DEFAULT_RULE + : { + ...CUSTOM_RULE, + selectedKey: matcher_type, + } +}