Skip to content

Commit 41442eb

Browse files
feat(ui): convert canvas txt2img & img2img to latents
- Add graph builders for canvas txt2img & img2img - they are mostly copy and paste from the linear graph builders but different in a few ways that are very tricky to work around. Just made totally new functions for them. - Canvas txt2img and img2img support ControlNet (not inpaint/outpaint). There's no way to determine in real-time which mode the canvas is in just yet, so we cannot disable the ControlNet UI when the mode will be inpaint/outpaint - it will always display. It's possible to determine this in near-real-time, will add this at some point. - Canvas inpaint/outpaint migrated to use model loader, though inpaint/outpaint are still using the non-latents nodes.
1 parent 223a679 commit 41442eb

File tree

11 files changed

+889
-184
lines changed

11 files changed

+889
-184
lines changed

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts

Lines changed: 55 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import { startAppListening } from '..';
22
import { sessionCreated } from 'services/thunks/session';
3-
import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
3+
import { buildCanvasGraph } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
44
import { log } from 'app/logging/useLogger';
55
import { canvasGraphBuilt } from 'features/nodes/store/actions';
66
import { imageUpdated, imageUploaded } from 'services/thunks/image';
7-
import { v4 as uuidv4 } from 'uuid';
8-
import { Graph } from 'services/api';
7+
import { ImageDTO } from 'services/api';
98
import {
109
canvasSessionIdChanged,
1110
stagingAreaInitialized,
@@ -67,112 +66,106 @@ export const addUserInvokedCanvasListener = () => {
6766

6867
moduleLog.debug(`Generation mode: ${generationMode}`);
6968

70-
// Build the canvas graph
71-
const graphComponents = await buildCanvasGraphComponents(
72-
state,
73-
generationMode
74-
);
75-
76-
if (!graphComponents) {
77-
moduleLog.error('Problem building graph');
78-
return;
79-
}
80-
81-
const { rangeNode, iterateNode, baseNode, edges } = graphComponents;
82-
83-
// Assemble! Note that this graph *does not have the init or mask image set yet!*
84-
const nodes: Graph['nodes'] = {
85-
[rangeNode.id]: rangeNode,
86-
[iterateNode.id]: iterateNode,
87-
[baseNode.id]: baseNode,
88-
};
89-
90-
const graph = { nodes, edges };
91-
92-
dispatch(canvasGraphBuilt(graph));
93-
94-
moduleLog.debug({ data: graph }, 'Canvas graph built');
69+
// Temp placeholders for the init and mask images
70+
let canvasInitImage: ImageDTO | undefined;
71+
let canvasMaskImage: ImageDTO | undefined;
9572

96-
// If we are generating img2img or inpaint, we need to upload the init images
97-
if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
98-
const baseFilename = `${uuidv4()}.png`;
99-
dispatch(
73+
// For img2img and inpaint/outpaint, we need to upload the init images
74+
if (['img2img', 'inpaint', 'outpaint'].includes(generationMode)) {
75+
// upload the image, saving the request id
76+
const { requestId: initImageUploadedRequestId } = dispatch(
10077
imageUploaded({
10178
formData: {
102-
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
79+
file: new File([baseBlob], 'canvasInitImage.png', {
80+
type: 'image/png',
81+
}),
10382
},
10483
imageCategory: 'general',
10584
isIntermediate: true,
10685
})
10786
);
10887

109-
// Wait for the image to be uploaded
110-
const [{ payload: baseImageDTO }] = await take(
88+
// Wait for the image to be uploaded, matching by request id
89+
const [{ payload }] = await take(
11190
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
11291
imageUploaded.fulfilled.match(action) &&
113-
action.meta.arg.formData.file.name === baseFilename
92+
action.meta.requestId === initImageUploadedRequestId
11493
);
11594

116-
// Update the base node with the image name and type
117-
baseNode.image = {
118-
image_name: baseImageDTO.image_name,
119-
};
95+
canvasInitImage = payload;
12096
}
12197

122-
// For inpaint, we also need to upload the mask layer
123-
if (baseNode.type === 'inpaint') {
124-
const maskFilename = `${uuidv4()}.png`;
125-
dispatch(
98+
// For inpaint/outpaint, we also need to upload the mask layer
99+
if (['inpaint', 'outpaint'].includes(generationMode)) {
100+
// upload the image, saving the request id
101+
const { requestId: maskImageUploadedRequestId } = dispatch(
126102
imageUploaded({
127103
formData: {
128-
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
104+
file: new File([maskBlob], 'canvasMaskImage.png', {
105+
type: 'image/png',
106+
}),
129107
},
130108
imageCategory: 'mask',
131109
isIntermediate: true,
132110
})
133111
);
134112

135-
// Wait for the mask to be uploaded
136-
const [{ payload: maskImageDTO }] = await take(
113+
// Wait for the image to be uploaded, matching by request id
114+
const [{ payload }] = await take(
137115
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
138116
imageUploaded.fulfilled.match(action) &&
139-
action.meta.arg.formData.file.name === maskFilename
117+
action.meta.requestId === maskImageUploadedRequestId
140118
);
141119

142-
// Update the base node with the image name and type
143-
baseNode.mask = {
144-
image_name: maskImageDTO.image_name,
145-
};
120+
canvasMaskImage = payload;
146121
}
147122

148-
// Create the session and wait for response
149-
dispatch(sessionCreated({ graph }));
150-
const [sessionCreatedAction] = await take(sessionCreated.fulfilled.match);
123+
const graph = buildCanvasGraph(
124+
state,
125+
generationMode,
126+
canvasInitImage,
127+
canvasMaskImage
128+
);
129+
130+
moduleLog.debug({ graph }, `Canvas graph built`);
131+
132+
// currently this action is just listened to for logging
133+
dispatch(canvasGraphBuilt(graph));
134+
135+
// Create the session, store the request id
136+
const { requestId: sessionCreatedRequestId } = dispatch(
137+
sessionCreated({ graph })
138+
);
139+
140+
// Take the session created action, matching by its request id
141+
const [sessionCreatedAction] = await take(
142+
(action): action is ReturnType<typeof sessionCreated.fulfilled> =>
143+
sessionCreated.fulfilled.match(action) &&
144+
action.meta.requestId === sessionCreatedRequestId
145+
);
151146
const sessionId = sessionCreatedAction.payload.id;
152147

153148
// Associate the init image with the session, now that we have the session ID
154-
if (
155-
(baseNode.type === 'img2img' || baseNode.type === 'inpaint') &&
156-
baseNode.image
157-
) {
149+
if (['img2img', 'inpaint'].includes(generationMode) && canvasInitImage) {
158150
dispatch(
159151
imageUpdated({
160-
imageName: baseNode.image.image_name,
152+
imageName: canvasInitImage.image_name,
161153
requestBody: { session_id: sessionId },
162154
})
163155
);
164156
}
165157

166158
// Associate the mask image with the session, now that we have the session ID
167-
if (baseNode.type === 'inpaint' && baseNode.mask) {
159+
if (['inpaint'].includes(generationMode) && canvasMaskImage) {
168160
dispatch(
169161
imageUpdated({
170-
imageName: baseNode.mask.image_name,
162+
imageName: canvasMaskImage.image_name,
171163
requestBody: { session_id: sessionId },
172164
})
173165
);
174166
}
175167

168+
// Prep the canvas staging area if it is not yet initialized
176169
if (!state.canvas.layerState.stagingArea.boundingBox) {
177170
dispatch(
178171
stagingAreaInitialized({

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedImageToImage.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import { log } from 'app/logging/useLogger';
44
import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
55
import { userInvoked } from 'app/store/actions';
66
import { sessionReadyToInvoke } from 'features/system/store/actions';
7-
import { buildImageToImageGraph } from 'features/nodes/util/graphBuilders/buildImageToImageGraph';
7+
import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph';
88

99
const moduleLog = log.child({ namespace: 'invoke' });
1010

@@ -15,7 +15,7 @@ export const addUserInvokedImageToImageListener = () => {
1515
effect: async (action, { getState, dispatch, take }) => {
1616
const state = getState();
1717

18-
const graph = buildImageToImageGraph(state);
18+
const graph = buildLinearImageToImageGraph(state);
1919
dispatch(imageToImageGraphBuilt(graph));
2020
moduleLog.debug({ data: graph }, 'Image to Image graph built');
2121

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedTextToImage.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import { log } from 'app/logging/useLogger';
44
import { textToImageGraphBuilt } from 'features/nodes/store/actions';
55
import { userInvoked } from 'app/store/actions';
66
import { sessionReadyToInvoke } from 'features/system/store/actions';
7-
import { buildTextToImageGraph } from 'features/nodes/util/graphBuilders/buildTextToImageGraph';
7+
import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph';
88

99
const moduleLog = log.child({ namespace: 'invoke' });
1010

@@ -15,7 +15,7 @@ export const addUserInvokedTextToImageListener = () => {
1515
effect: async (action, { getState, dispatch, take }) => {
1616
const state = getState();
1717

18-
const graph = buildTextToImageGraph(state);
18+
const graph = buildLinearTextToImageGraph(state);
1919

2020
dispatch(textToImageGraphBuilt(graph));
2121

Lines changed: 27 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,116 +1,39 @@
11
import { RootState } from 'app/store/store';
2-
import {
3-
Edge,
4-
ImageToImageInvocation,
5-
InpaintInvocation,
6-
IterateInvocation,
7-
RandomRangeInvocation,
8-
RangeInvocation,
9-
TextToImageInvocation,
10-
} from 'services/api';
11-
import { buildImg2ImgNode } from '../nodeBuilders/buildImageToImageNode';
12-
import { buildTxt2ImgNode } from '../nodeBuilders/buildTextToImageNode';
13-
import { buildRangeNode } from '../nodeBuilders/buildRangeNode';
14-
import { buildIterateNode } from '../nodeBuilders/buildIterateNode';
15-
import { buildEdges } from '../edgeBuilders/buildEdges';
2+
import { ImageDTO } from 'services/api';
163
import { log } from 'app/logging/useLogger';
17-
import { buildInpaintNode } from '../nodeBuilders/buildInpaintNode';
4+
import { forEach } from 'lodash-es';
5+
import { buildCanvasInpaintGraph } from './buildCanvasInpaintGraph';
6+
import { NonNullableGraph } from 'features/nodes/types/types';
7+
import { buildCanvasImageToImageGraph } from './buildCanvasImageToImageGraph';
8+
import { buildCanvasTextToImageGraph } from './buildCanvasTextToImageGraph';
189

1910
const moduleLog = log.child({ namespace: 'nodes' });
2011

21-
const buildBaseNode = (
22-
nodeType: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint',
23-
state: RootState
24-
):
25-
| TextToImageInvocation
26-
| ImageToImageInvocation
27-
| InpaintInvocation
28-
| undefined => {
29-
const overrides = {
30-
...state.canvas.boundingBoxDimensions,
31-
is_intermediate: true,
32-
};
33-
34-
if (nodeType === 'txt2img') {
35-
return buildTxt2ImgNode(state, overrides);
36-
}
37-
38-
if (nodeType === 'img2img') {
39-
return buildImg2ImgNode(state, overrides);
40-
}
41-
42-
if (nodeType === 'inpaint' || nodeType === 'outpaint') {
43-
return buildInpaintNode(state, overrides);
44-
}
45-
};
46-
47-
/**
48-
* Builds the Canvas workflow graph and image blobs.
49-
*/
50-
export const buildCanvasGraphComponents = async (
12+
export const buildCanvasGraph = (
5113
state: RootState,
52-
generationMode: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint'
53-
): Promise<
54-
| {
55-
rangeNode: RangeInvocation | RandomRangeInvocation;
56-
iterateNode: IterateInvocation;
57-
baseNode:
58-
| TextToImageInvocation
59-
| ImageToImageInvocation
60-
| InpaintInvocation;
61-
edges: Edge[];
14+
generationMode: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint',
15+
canvasInitImage: ImageDTO | undefined,
16+
canvasMaskImage: ImageDTO | undefined
17+
) => {
18+
let graph: NonNullableGraph;
19+
20+
if (generationMode === 'txt2img') {
21+
graph = buildCanvasTextToImageGraph(state);
22+
} else if (generationMode === 'img2img') {
23+
if (!canvasInitImage) {
24+
throw new Error('Missing canvas init image');
6225
}
63-
| undefined
64-
> => {
65-
// The base node is a txt2img, img2img or inpaint node
66-
const baseNode = buildBaseNode(generationMode, state);
67-
68-
if (!baseNode) {
69-
moduleLog.error('Problem building base node');
70-
return;
71-
}
72-
73-
if (baseNode.type === 'inpaint') {
74-
const {
75-
seamSize,
76-
seamBlur,
77-
seamSteps,
78-
seamStrength,
79-
tileSize,
80-
infillMethod,
81-
} = state.generation;
82-
83-
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } =
84-
state.canvas;
85-
86-
if (boundingBoxScaleMethod !== 'none') {
87-
baseNode.inpaint_width = scaledBoundingBoxDimensions.width;
88-
baseNode.inpaint_height = scaledBoundingBoxDimensions.height;
89-
}
90-
91-
baseNode.seam_size = seamSize;
92-
baseNode.seam_blur = seamBlur;
93-
baseNode.seam_strength = seamStrength;
94-
baseNode.seam_steps = seamSteps;
95-
baseNode.infill_method = infillMethod as InpaintInvocation['infill_method'];
96-
97-
if (infillMethod === 'tile') {
98-
baseNode.tile_size = tileSize;
26+
graph = buildCanvasImageToImageGraph(state, canvasInitImage);
27+
} else {
28+
if (!canvasInitImage || !canvasMaskImage) {
29+
throw new Error('Missing canvas init and mask images');
9930
}
31+
graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
10032
}
10133

102-
// We always range and iterate nodes, no matter the iteration count
103-
// This is required to provide the correct seeds to the backend engine
104-
const rangeNode = buildRangeNode(state);
105-
const iterateNode = buildIterateNode();
106-
107-
// Build the edges for the nodes selected.
108-
const edges = buildEdges(baseNode, rangeNode, iterateNode);
34+
forEach(graph.nodes, (node) => {
35+
graph.nodes[node.id].is_intermediate = true;
36+
});
10937

110-
return {
111-
rangeNode,
112-
iterateNode,
113-
baseNode,
114-
edges,
115-
};
38+
return graph;
11639
};

0 commit comments

Comments
 (0)