Skip to content

Commit 0edb31f

Browse files
feat: model events (#3786)
[feat(nodes): emit model loading events](7b6159f) - remove the dependency on having access to a `node` during emits; supporting that would need a few additional args passed through the system, and I don't think it's necessary at this point. This also allowed us to drop an extraneous fetch/parse of the session from the db. - provide the invocation context to all `get_model()` calls, so the events are able to be emitted - test all model loading events in the app and confirm socket events are received [feat(ui): add listeners for model load events](c487166) - currently only exposed as DEBUG-level logs --- One change I missed in the commit messages is that the `ModelInfo` class is not serializable, so I split out the pieces of information we didn't already have (hash, location, precision) and added them to the event payload directly.
2 parents 179455e + a137f7f commit 0edb31f

File tree

13 files changed

+174
-39
lines changed

13 files changed

+174
-39
lines changed

invokeai/app/invocations/compel.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ class Config(InvocationConfig):
5757
@torch.no_grad()
5858
def invoke(self, context: InvocationContext) -> CompelOutput:
5959
tokenizer_info = context.services.model_manager.get_model(
60-
**self.clip.tokenizer.dict(),
60+
**self.clip.tokenizer.dict(), context=context,
6161
)
6262
text_encoder_info = context.services.model_manager.get_model(
63-
**self.clip.text_encoder.dict(),
63+
**self.clip.text_encoder.dict(), context=context,
6464
)
6565

6666
def _lora_loader():
@@ -82,6 +82,7 @@ def _lora_loader():
8282
model_name=name,
8383
base_model=self.clip.text_encoder.base_model,
8484
model_type=ModelType.TextualInversion,
85+
context=context,
8586
).context.model
8687
)
8788
except ModelNotFoundException:

invokeai/app/invocations/generate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,13 +157,13 @@ def load_model_old_way(self, context, scheduler):
157157
def _lora_loader():
158158
for lora in self.unet.loras:
159159
lora_info = context.services.model_manager.get_model(
160-
**lora.dict(exclude={"weight"}))
160+
**lora.dict(exclude={"weight"}), context=context,)
161161
yield (lora_info.context.model, lora.weight)
162162
del lora_info
163163
return
164164

165-
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
166-
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
165+
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context,)
166+
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict(), context=context,)
167167

168168
with vae_info as vae,\
169169
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\

invokeai/app/invocations/latent.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def get_scheduler(
7676
scheduler_name, SCHEDULER_MAP['ddim']
7777
)
7878
orig_scheduler_info = context.services.model_manager.get_model(
79-
**scheduler_info.dict()
79+
**scheduler_info.dict(), context=context,
8080
)
8181
with orig_scheduler_info as orig_scheduler:
8282
scheduler_config = orig_scheduler.config
@@ -262,6 +262,7 @@ def prep_control_data(
262262
model_name=control_info.control_model.model_name,
263263
model_type=ModelType.ControlNet,
264264
base_model=control_info.control_model.base_model,
265+
context=context,
265266
)
266267
)
267268

@@ -313,14 +314,14 @@ def step_callback(state: PipelineIntermediateState):
313314
def _lora_loader():
314315
for lora in self.unet.loras:
315316
lora_info = context.services.model_manager.get_model(
316-
**lora.dict(exclude={"weight"})
317+
**lora.dict(exclude={"weight"}), context=context,
317318
)
318319
yield (lora_info.context.model, lora.weight)
319320
del lora_info
320321
return
321322

322323
unet_info = context.services.model_manager.get_model(
323-
**self.unet.unet.dict()
324+
**self.unet.unet.dict(), context=context,
324325
)
325326
with ExitStack() as exit_stack,\
326327
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
@@ -403,14 +404,14 @@ def step_callback(state: PipelineIntermediateState):
403404
def _lora_loader():
404405
for lora in self.unet.loras:
405406
lora_info = context.services.model_manager.get_model(
406-
**lora.dict(exclude={"weight"})
407+
**lora.dict(exclude={"weight"}), context=context,
407408
)
408409
yield (lora_info.context.model, lora.weight)
409410
del lora_info
410411
return
411412

412413
unet_info = context.services.model_manager.get_model(
413-
**self.unet.unet.dict()
414+
**self.unet.unet.dict(), context=context,
414415
)
415416
with ExitStack() as exit_stack,\
416417
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
@@ -491,7 +492,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
491492
latents = context.services.latents.get(self.latents.latents_name)
492493

493494
vae_info = context.services.model_manager.get_model(
494-
**self.vae.vae.dict(),
495+
**self.vae.vae.dict(), context=context,
495496
)
496497

497498
with vae_info as vae:
@@ -636,7 +637,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
636637

637638
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
638639
vae_info = context.services.model_manager.get_model(
639-
**self.vae.vae.dict(),
640+
**self.vae.vae.dict(), context=context,
640641
)
641642

642643
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))

invokeai/app/services/events.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,6 @@ def emit_graph_execution_complete(self, graph_execution_state_id: str) -> None:
105105
def emit_model_load_started (
106106
self,
107107
graph_execution_state_id: str,
108-
node: dict,
109-
source_node_id: str,
110108
model_name: str,
111109
base_model: BaseModelType,
112110
model_type: ModelType,
@@ -117,8 +115,6 @@ def emit_model_load_started (
117115
event_name="model_load_started",
118116
payload=dict(
119117
graph_execution_state_id=graph_execution_state_id,
120-
node=node,
121-
source_node_id=source_node_id,
122118
model_name=model_name,
123119
base_model=base_model,
124120
model_type=model_type,
@@ -129,8 +125,6 @@ def emit_model_load_started (
129125
def emit_model_load_completed(
130126
self,
131127
graph_execution_state_id: str,
132-
node: dict,
133-
source_node_id: str,
134128
model_name: str,
135129
base_model: BaseModelType,
136130
model_type: ModelType,
@@ -142,12 +136,12 @@ def emit_model_load_completed(
142136
event_name="model_load_completed",
143137
payload=dict(
144138
graph_execution_state_id=graph_execution_state_id,
145-
node=node,
146-
source_node_id=source_node_id,
147139
model_name=model_name,
148140
base_model=base_model,
149141
model_type=model_type,
150142
submodel=submodel,
151-
model_info=model_info,
143+
hash=model_info.hash,
144+
location=model_info.location,
145+
precision=str(model_info.precision),
152146
),
153147
)

invokeai/app/services/model_manager_service.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -339,19 +339,16 @@ def get_model(
339339
base_model: BaseModelType,
340340
model_type: ModelType,
341341
submodel: Optional[SubModelType] = None,
342-
node: Optional[BaseInvocation] = None,
343342
context: Optional[InvocationContext] = None,
344343
) -> ModelInfo:
345344
"""
346345
Retrieve the indicated model. submodel can be used to get a
347346
part (such as the vae) of a diffusers mode.
348347
"""
349348

350-
# if we are called from within a node, then we get to emit
351-
# load start and complete events
352-
if node and context:
349+
# we can emit model loading events if we are executing with access to the invocation context
350+
if context:
353351
self._emit_load_event(
354-
node=node,
355352
context=context,
356353
model_name=model_name,
357354
base_model=base_model,
@@ -366,9 +363,8 @@ def get_model(
366363
submodel,
367364
)
368365

369-
if node and context:
366+
if context:
370367
self._emit_load_event(
371-
node=node,
372368
context=context,
373369
model_name=model_name,
374370
base_model=base_model,
@@ -510,23 +506,19 @@ def commit(self, conf_file: Optional[Path]=None):
510506

511507
def _emit_load_event(
512508
self,
513-
node,
514509
context,
515510
model_name: str,
516511
base_model: BaseModelType,
517512
model_type: ModelType,
518-
submodel: SubModelType,
513+
submodel: Optional[SubModelType] = None,
519514
model_info: Optional[ModelInfo] = None,
520515
):
521516
if context.services.queue.is_canceled(context.graph_execution_state_id):
522517
raise CanceledException()
523-
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
524-
source_node_id = graph_execution_state.prepared_source_mapping[node.id]
518+
525519
if model_info:
526520
context.services.events.emit_model_load_completed(
527521
graph_execution_state_id=context.graph_execution_state_id,
528-
node=node.dict(),
529-
source_node_id=source_node_id,
530522
model_name=model_name,
531523
base_model=base_model,
532524
model_type=model_type,
@@ -536,8 +528,6 @@ def _emit_load_event(
536528
else:
537529
context.services.events.emit_model_load_started(
538530
graph_execution_state_id=context.graph_execution_state_id,
539-
node=node.dict(),
540-
source_node_id=source_node_id,
541531
model_name=model_name,
542532
base_model=base_model,
543533
model_type=model_type,

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
8888
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
8989
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
9090
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
91+
import { addModelLoadStartedEventListener } from './listeners/socketio/socketModelLoadStarted';
92+
import { addModelLoadCompletedEventListener } from './listeners/socketio/socketModelLoadCompleted';
9193

9294
export const listenerMiddleware = createListenerMiddleware();
9395

@@ -177,6 +179,8 @@ addSocketConnectedListener();
177179
addSocketDisconnectedListener();
178180
addSocketSubscribedListener();
179181
addSocketUnsubscribedListener();
182+
addModelLoadStartedEventListener();
183+
addModelLoadCompletedEventListener();
180184

181185
// Session Created
182186
addSessionCreatedPendingListener();
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import { log } from 'app/logging/useLogger';
2+
import {
3+
appSocketModelLoadCompleted,
4+
socketModelLoadCompleted,
5+
} from 'services/events/actions';
6+
import { startAppListening } from '../..';
7+
8+
const moduleLog = log.child({ namespace: 'socketio' });
9+
10+
export const addModelLoadCompletedEventListener = () => {
11+
startAppListening({
12+
actionCreator: socketModelLoadCompleted,
13+
effect: (action, { dispatch, getState }) => {
14+
const { model_name, model_type, submodel } = action.payload.data;
15+
16+
let modelString = `${model_type} model: ${model_name}`;
17+
18+
if (submodel) {
19+
modelString = modelString.concat(`, submodel: ${submodel}`);
20+
}
21+
22+
moduleLog.debug(action.payload, `Model load completed (${modelString})`);
23+
24+
// pass along the socket event as an application action
25+
dispatch(appSocketModelLoadCompleted(action.payload));
26+
},
27+
});
28+
};
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import { log } from 'app/logging/useLogger';
2+
import {
3+
appSocketModelLoadStarted,
4+
socketModelLoadStarted,
5+
} from 'services/events/actions';
6+
import { startAppListening } from '../..';
7+
8+
const moduleLog = log.child({ namespace: 'socketio' });
9+
10+
export const addModelLoadStartedEventListener = () => {
11+
startAppListening({
12+
actionCreator: socketModelLoadStarted,
13+
effect: (action, { dispatch, getState }) => {
14+
const { model_name, model_type, submodel } = action.payload.data;
15+
16+
let modelString = `${model_type} model: ${model_name}`;
17+
18+
if (submodel) {
19+
modelString = modelString.concat(`, submodel: ${submodel}`);
20+
}
21+
22+
moduleLog.debug(action.payload, `Model load started (${modelString})`);
23+
24+
// pass along the socket event as an application action
25+
dispatch(appSocketModelLoadStarted(action.payload));
26+
},
27+
});
28+
};

invokeai/frontend/web/src/common/util/getTimestamp.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ import dateFormat from 'dateformat';
33
/**
44
* Get a `now` timestamp with 1s precision, formatted as ISO datetime.
55
*/
6-
export const getTimestamp = () => dateFormat(new Date(), 'isoDateTime');
6+
export const getTimestamp = () =>
7+
dateFormat(new Date(), `yyyy-mm-dd'T'HH:MM:ss:lo`);

invokeai/frontend/web/src/services/api/types.d.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ export type OffsetPaginatedResults_ImageDTO_ =
2828

2929
// Models
3030
export type ModelType = components['schemas']['ModelType'];
31+
export type SubModelType = components['schemas']['SubModelType'];
3132
export type BaseModelType = components['schemas']['BaseModelType'];
3233
export type MainModelField = components['schemas']['MainModelField'];
3334
export type VAEModelField = components['schemas']['VAEModelField'];

0 commit comments

Comments
 (0)