Skip to content

Commit e496dd1

Browse files
feat(ui): persist socket session ids and re-sub on connect
1 parent 5c3b3fe commit e496dd1

File tree

4 files changed

+88
-72
lines changed

4 files changed

+88
-72
lines changed

invokeai/frontend/web/src/app/store.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ export const store = configureStore({
114114
'canvas/setBoundingBoxDimensions',
115115
'canvas/setIsDrawing',
116116
'canvas/addPointToCurrentLine',
117+
'socket/generatorProgress',
117118
],
118119
},
119120
});

invokeai/frontend/web/src/services/events/middleware.ts

Lines changed: 33 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import { receivedModels } from 'services/thunks/model';
3232
import { receivedOpenAPISchema } from 'services/thunks/schema';
3333
import { isImageOutput } from 'services/types/guards';
3434
import { imageReceived, thumbnailReceived } from 'services/thunks/image';
35+
import { setEventListeners } from './util/setEventListeners';
3536

3637
export const socketMiddleware = () => {
3738
let areListenersSet = false;
@@ -66,33 +67,34 @@ export const socketMiddleware = () => {
6667
(store: MiddlewareAPI<AppDispatch, RootState>) => (next) => (action) => {
6768
const { dispatch, getState } = store;
6869

69-
// Nothing dispatches `socketReset` actions yet, so this is a noop, but including anyways
70-
if (socketReset.match(action)) {
71-
const { sessionId } = getState().system;
70+
// Nothing dispatches `socketReset` actions yet
71+
// if (socketReset.match(action)) {
72+
// const { sessionId } = getState().system;
7273

73-
if (sessionId) {
74-
socket.emit('unsubscribe', { session: sessionId });
75-
dispatch(
76-
socketUnsubscribed({ sessionId, timestamp: getTimestamp() })
77-
);
78-
}
74+
// if (sessionId) {
75+
// socket.emit('unsubscribe', { session: sessionId });
76+
// dispatch(
77+
// socketUnsubscribed({ sessionId, timestamp: getTimestamp() })
78+
// );
79+
// }
7980

80-
if (socket.connected) {
81-
socket.disconnect();
82-
dispatch(socketDisconnected({ timestamp: getTimestamp() }));
83-
}
81+
// if (socket.connected) {
82+
// socket.disconnect();
83+
// dispatch(socketDisconnected({ timestamp: getTimestamp() }));
84+
// }
8485

85-
socket.removeAllListeners();
86-
areListenersSet = false;
87-
}
86+
// socket.removeAllListeners();
87+
// areListenersSet = false;
88+
// }
8889

8990
// Set listeners for `connect` and `disconnect` events once
9091
// Must happen in middleware to get access to `dispatch`
9192
if (!areListenersSet) {
9293
socket.on('connect', () => {
9394
dispatch(socketConnected({ timestamp: getTimestamp() }));
9495

95-
const { results, uploads, models, nodes, config } = getState();
96+
const { results, uploads, models, nodes, config, system } =
97+
getState();
9698

9799
const { disabledTabs } = config;
98100

@@ -112,6 +114,18 @@ export const socketMiddleware = () => {
112114
if (!nodes.schema && !disabledTabs.includes('nodes')) {
113115
dispatch(receivedOpenAPISchema());
114116
}
117+
118+
if (system.sessionId) {
119+
console.log(`Re-subscribing to session ${system.sessionId}`);
120+
socket.emit('subscribe', { session: system.sessionId });
121+
dispatch(
122+
socketSubscribed({
123+
sessionId: system.sessionId,
124+
timestamp: getTimestamp(),
125+
})
126+
);
127+
setEventListeners({ socket, store });
128+
}
115129
});
116130

117131
socket.on('disconnect', () => {
@@ -128,9 +142,6 @@ export const socketMiddleware = () => {
128142
if (isFulfilledSessionCreatedAction(action)) {
129143
const oldSessionId = getState().system.sessionId;
130144

131-
// temp disable event subscription
132-
const shouldHandleEvent = (id: string): boolean => true;
133-
134145
// const subscribedNodeIds = getState().system.subscribedNodeIds;
135146
// const shouldHandleEvent = (id: string): boolean => {
136147
// if (subscribedNodeIds.length === 1 && subscribedNodeIds[0] === '*') {
@@ -152,7 +163,6 @@ export const socketMiddleware = () => {
152163
timestamp: getTimestamp(),
153164
})
154165
);
155-
156166
const listenersToRemove: (keyof ServerToClientEvents)[] = [
157167
'invocation_started',
158168
'generator_progress',
@@ -168,57 +178,14 @@ export const socketMiddleware = () => {
168178

169179
const sessionId = action.payload.id;
170180

171-
// After a session is created, we immediately subscribe to events and then invoke the session
172181
socket.emit('subscribe', { session: sessionId });
173-
174-
// Always dispatch the event actions for other consumers who want to know when we subscribed
175182
dispatch(
176183
socketSubscribed({
177-
sessionId,
184+
sessionId: sessionId,
178185
timestamp: getTimestamp(),
179186
})
180187
);
181-
182-
// Set up listeners for the present subscription
183-
socket.on('invocation_started', (data) => {
184-
if (shouldHandleEvent(data.node.id)) {
185-
dispatch(invocationStarted({ data, timestamp: getTimestamp() }));
186-
}
187-
});
188-
189-
socket.on('generator_progress', (data) => {
190-
if (shouldHandleEvent(data.node.id)) {
191-
dispatch(generatorProgress({ data, timestamp: getTimestamp() }));
192-
}
193-
});
194-
195-
socket.on('invocation_error', (data) => {
196-
if (shouldHandleEvent(data.node.id)) {
197-
dispatch(invocationError({ data, timestamp: getTimestamp() }));
198-
}
199-
});
200-
201-
socket.on('invocation_complete', (data) => {
202-
if (shouldHandleEvent(data.node.id)) {
203-
const sessionId = data.graph_execution_state_id;
204-
205-
const { cancelType, isCancelScheduled } = getState().system;
206-
const { shouldFetchImages } = getState().config;
207-
208-
// Handle scheduled cancelation
209-
if (cancelType === 'scheduled' && isCancelScheduled) {
210-
dispatch(sessionCanceled({ sessionId }));
211-
}
212-
213-
dispatch(
214-
invocationComplete({
215-
data,
216-
timestamp: getTimestamp(),
217-
shouldFetchImages,
218-
})
219-
);
220-
}
221-
});
188+
setEventListeners({ socket, store });
222189

223190
// Finally we actually invoke the session, starting processing
224191
dispatch(sessionInvoked({ sessionId }));

invokeai/frontend/web/src/services/events/types.ts

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,6 @@ export type AnyInvocationType = NonNullable<
1515

1616
export type AnyInvocation = NonNullable<Graph['nodes']>[string];
1717

18-
// export type AnyInvocation = {
19-
// id: string;
20-
// type: AnyInvocationType | string;
21-
// [key: string]: any;
22-
// };
23-
2418
export type AnyResult = GraphExecutionState['results'][string];
2519

2620
/**
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import { MiddlewareAPI } from '@reduxjs/toolkit';
2+
import { AppDispatch, RootState } from 'app/store';
3+
import { getTimestamp } from 'common/util/getTimestamp';
4+
import { sessionCanceled } from 'services/thunks/session';
5+
import { Socket } from 'socket.io-client';
6+
import {
7+
generatorProgress,
8+
invocationComplete,
9+
invocationError,
10+
invocationStarted,
11+
} from '../actions';
12+
import { ClientToServerEvents, ServerToClientEvents } from '../types';
13+
14+
type SetEventListenersArg = {
15+
socket: Socket<ServerToClientEvents, ClientToServerEvents>;
16+
store: MiddlewareAPI<AppDispatch, RootState>;
17+
};
18+
19+
export const setEventListeners = (arg: SetEventListenersArg) => {
20+
const { socket, store } = arg;
21+
const { dispatch, getState } = store;
22+
// Set up listeners for the present subscription
23+
socket.on('invocation_started', (data) => {
24+
dispatch(invocationStarted({ data, timestamp: getTimestamp() }));
25+
});
26+
27+
socket.on('generator_progress', (data) => {
28+
dispatch(generatorProgress({ data, timestamp: getTimestamp() }));
29+
});
30+
31+
socket.on('invocation_error', (data) => {
32+
dispatch(invocationError({ data, timestamp: getTimestamp() }));
33+
});
34+
35+
socket.on('invocation_complete', (data) => {
36+
const sessionId = data.graph_execution_state_id;
37+
38+
const { cancelType, isCancelScheduled } = getState().system;
39+
const { shouldFetchImages } = getState().config;
40+
41+
// Handle scheduled cancelation
42+
if (cancelType === 'scheduled' && isCancelScheduled) {
43+
dispatch(sessionCanceled({ sessionId }));
44+
}
45+
46+
dispatch(
47+
invocationComplete({
48+
data,
49+
timestamp: getTimestamp(),
50+
shouldFetchImages,
51+
})
52+
);
53+
});
54+
};

0 commit comments

Comments
 (0)