Skip to content

Commit 405054d

Browse files
feat: Add Embedding Picker to Linear UI (#3654)
2 parents 52498cc + a901a37 commit 405054d

File tree

14 files changed

+429
-75
lines changed

14 files changed

+429
-75
lines changed

invokeai/backend/model_management/lora.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,13 @@
33
import copy
44
from contextlib import contextmanager
55
from pathlib import Path
6-
from typing import Any, Dict, Optional, Tuple
6+
from typing import Any, Dict, Optional, Tuple, Union, List
77

88
import torch
99
from compel.embeddings_provider import BaseTextualInversionManager
1010
from diffusers.models import UNet2DConditionModel
1111
from safetensors.torch import load_file
12-
from torch.utils.hooks import RemovableHandle
13-
from transformers import CLIPTextModel
14-
12+
from transformers import CLIPTextModel, CLIPTokenizer
1513

1614
class LoRALayerBase:
1715
#rank: Optional[int]
@@ -123,8 +121,8 @@ def __init__(
123121

124122
def get_weight(self):
125123
if self.mid is not None:
126-
up = self.up.reshape(up.shape[0], up.shape[1])
127-
down = self.down.reshape(up.shape[0], up.shape[1])
124+
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
125+
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
128126
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
129127
else:
130128
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
@@ -410,7 +408,7 @@ def from_checkpoint(
410408
else:
411409
# TODO: diff/ia3/... format
412410
print(
413-
f">> Encountered unknown lora layer module in {self.name}: {layer_key}"
411+
f">> Encountered unknown lora layer module in {model.name}: {layer_key}"
414412
)
415413
return
416414

invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react';
22
import { MultiSelect, MultiSelectProps } from '@mantine/core';
33
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
4-
import { memo } from 'react';
4+
import { RefObject, memo } from 'react';
55
import { mode } from 'theme/util/mode';
66

77
type IAIMultiSelectProps = MultiSelectProps & {
88
tooltip?: string;
9+
inputRef?: RefObject<HTMLInputElement>;
910
};
1011

1112
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
12-
const { searchable = true, tooltip, ...rest } = props;
13+
const { searchable = true, tooltip, inputRef, ...rest } = props;
1314
const {
1415
base50,
1516
base100,
@@ -33,6 +34,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
3334
return (
3435
<Tooltip label={tooltip} placement="top" hasArrow>
3536
<MultiSelect
37+
ref={inputRef}
3638
searchable={searchable}
3739
styles={() => ({
3840
label: {
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import IAIIconButton from 'common/components/IAIIconButton';
2+
import { memo } from 'react';
3+
import { BiCode } from 'react-icons/bi';
4+
5+
type Props = {
6+
onClick: () => void;
7+
};
8+
9+
const AddEmbeddingButton = (props: Props) => {
10+
const { onClick } = props;
11+
return (
12+
<IAIIconButton
13+
size="sm"
14+
aria-label="Add Embedding"
15+
tooltip="Add Embedding"
16+
icon={<BiCode />}
17+
sx={{
18+
p: 2,
19+
color: 'base.700',
20+
_hover: {
21+
color: 'base.550',
22+
},
23+
_active: {
24+
color: 'base.500',
25+
},
26+
}}
27+
variant="link"
28+
onClick={onClick}
29+
/>
30+
);
31+
};
32+
33+
export default memo(AddEmbeddingButton);
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
import {
2+
Flex,
3+
Popover,
4+
PopoverBody,
5+
PopoverContent,
6+
PopoverTrigger,
7+
Text,
8+
} from '@chakra-ui/react';
9+
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
10+
import { forEach } from 'lodash-es';
11+
import {
12+
PropsWithChildren,
13+
forwardRef,
14+
useCallback,
15+
useMemo,
16+
useRef,
17+
} from 'react';
18+
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
19+
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
20+
21+
type EmbeddingSelectItem = {
22+
label: string;
23+
value: string;
24+
description?: string;
25+
};
26+
27+
type Props = PropsWithChildren & {
28+
onSelect: (v: string) => void;
29+
isOpen: boolean;
30+
onClose: () => void;
31+
};
32+
33+
const ParamEmbeddingPopover = (props: Props) => {
34+
const { onSelect, isOpen, onClose, children } = props;
35+
const { data: embeddingQueryData } = useGetTextualInversionModelsQuery();
36+
const inputRef = useRef<HTMLInputElement>(null);
37+
38+
const data = useMemo(() => {
39+
if (!embeddingQueryData) {
40+
return [];
41+
}
42+
43+
const data: EmbeddingSelectItem[] = [];
44+
45+
forEach(embeddingQueryData.entities, (embedding, _) => {
46+
if (!embedding) return;
47+
48+
data.push({
49+
value: embedding.name,
50+
label: embedding.name,
51+
description: embedding.description,
52+
});
53+
});
54+
55+
return data;
56+
}, [embeddingQueryData]);
57+
58+
const handleChange = useCallback(
59+
(v: string[]) => {
60+
if (v.length === 0) {
61+
return;
62+
}
63+
64+
onSelect(v[0]);
65+
},
66+
[onSelect]
67+
);
68+
69+
return (
70+
<Popover
71+
initialFocusRef={inputRef}
72+
isOpen={isOpen}
73+
onClose={onClose}
74+
placement="bottom"
75+
openDelay={0}
76+
closeDelay={0}
77+
closeOnBlur={true}
78+
returnFocusOnClose={true}
79+
>
80+
<PopoverTrigger>{children}</PopoverTrigger>
81+
<PopoverContent
82+
sx={{
83+
p: 0,
84+
top: -1,
85+
shadow: 'dark-lg',
86+
borderColor: 'accent.300',
87+
borderWidth: '2px',
88+
borderStyle: 'solid',
89+
_dark: { borderColor: 'accent.400' },
90+
}}
91+
>
92+
<PopoverBody
93+
sx={{ p: 0, w: `calc(${PARAMETERS_PANEL_WIDTH} - 2rem )` }}
94+
>
95+
{data.length === 0 ? (
96+
<Flex sx={{ justifyContent: 'center', p: 2 }}>
97+
<Text
98+
sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}
99+
>
100+
No Embeddings Loaded
101+
</Text>
102+
</Flex>
103+
) : (
104+
<IAIMantineMultiSelect
105+
inputRef={inputRef}
106+
placeholder={'Add Embedding'}
107+
value={[]}
108+
data={data}
109+
maxDropdownHeight={400}
110+
nothingFound="No Matching Embeddings"
111+
itemComponent={SelectItem}
112+
disabled={data.length === 0}
113+
filter={(value, selected, item: EmbeddingSelectItem) =>
114+
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
115+
item.value.toLowerCase().includes(value.toLowerCase().trim())
116+
}
117+
onChange={handleChange}
118+
/>
119+
)}
120+
</PopoverBody>
121+
</PopoverContent>
122+
</Popover>
123+
);
124+
};
125+
126+
export default ParamEmbeddingPopover;
127+
128+
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
129+
value: string;
130+
label: string;
131+
description?: string;
132+
}
133+
134+
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
135+
({ label, description, ...others }: ItemProps, ref) => {
136+
return (
137+
<div ref={ref} {...others}>
138+
<div>
139+
<Text>{label}</Text>
140+
{description && (
141+
<Text size="xs" color="base.600">
142+
{description}
143+
</Text>
144+
)}
145+
</div>
146+
</div>
147+
);
148+
}
149+
);
150+
151+
SelectItem.displayName = 'SelectItem';

invokeai/frontend/web/src/features/embedding/store/embeddingSlice.ts

Whitespace-only changes.

invokeai/frontend/web/src/features/lora/components/ParamLora.tsx

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@ import IAIIconButton from 'common/components/IAIIconButton';
44
import IAISlider from 'common/components/IAISlider';
55
import { memo, useCallback } from 'react';
66
import { FaTrash } from 'react-icons/fa';
7-
import { Lora, loraRemoved, loraWeightChanged } from '../store/loraSlice';
7+
import {
8+
Lora,
9+
loraRemoved,
10+
loraWeightChanged,
11+
loraWeightReset,
12+
} from '../store/loraSlice';
813

914
type Props = {
1015
lora: Lora;
@@ -22,7 +27,7 @@ const ParamLora = (props: Props) => {
2227
);
2328

2429
const handleReset = useCallback(() => {
25-
dispatch(loraWeightChanged({ id: lora.id, weight: 1 }));
30+
dispatch(loraWeightReset(lora.id));
2631
}, [dispatch, lora.id]);
2732

2833
const handleRemoveLora = useCallback(() => {

invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { Text } from '@chakra-ui/react';
1+
import { Flex, Text } from '@chakra-ui/react';
22
import { createSelector } from '@reduxjs/toolkit';
33
import { stateSelector } from 'app/store/store';
44
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@@ -61,6 +61,16 @@ const ParamLoraSelect = () => {
6161
[dispatch, lorasQueryData?.entities]
6262
);
6363

64+
if (lorasQueryData?.ids.length === 0) {
65+
return (
66+
<Flex sx={{ justifyContent: 'center', p: 2 }}>
67+
<Text sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}>
68+
No LoRAs Loaded
69+
</Text>
70+
</Flex>
71+
);
72+
}
73+
6474
return (
6575
<IAIMantineMultiSelect
6676
placeholder={data.length === 0 ? 'All LoRAs added' : 'Add LoRA'}

invokeai/frontend/web/src/features/lora/store/loraSlice.ts

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ export type Lora = {
88
};
99

1010
export const defaultLoRAConfig: Omit<Lora, 'id' | 'name'> = {
11-
weight: 1,
11+
weight: 0.75,
1212
};
1313

1414
export type LoraState = {
@@ -38,9 +38,14 @@ export const loraSlice = createSlice({
3838
const { id, weight } = action.payload;
3939
state.loras[id].weight = weight;
4040
},
41+
loraWeightReset: (state, action: PayloadAction<string>) => {
42+
const id = action.payload;
43+
state.loras[id].weight = defaultLoRAConfig.weight;
44+
},
4145
},
4246
});
4347

44-
export const { loraAdded, loraRemoved, loraWeightChanged } = loraSlice.actions;
48+
export const { loraAdded, loraRemoved, loraWeightChanged, loraWeightReset } =
49+
loraSlice.actions;
4550

4651
export default loraSlice.reducer;

0 commit comments

Comments
 (0)