 
 from invokeai.app.invocations.metadata import CoreMetadata
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
-from invokeai.backend.model_management.models.base import ModelType
+from invokeai.backend.model_management.models import ModelType, SilenceWarnings
 
 from ...backend.model_management.lora import ModelPatcher
 from ...backend.stable_diffusion import PipelineIntermediateState
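The only import change swaps the deep models.base path for the package-level export so that SilenceWarnings can be imported alongside ModelType. SilenceWarnings is used below as a context manager around the invoke() bodies. The sketch that follows shows the assumed behavior (mute diffusers/transformers warning output, including the NSFW-checker nag mentioned in the new comment further down, and restore the previous verbosity on exit); it is an illustration, not the verbatim InvokeAI implementation.

import warnings

from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging


class SilenceWarnings:
    # Sketch of the assumed behavior: silence transformers/diffusers log output and
    # Python warnings inside the block, then put the log verbosity back the way it was.
    def __enter__(self):
        self._transformers_verbosity = transformers_logging.get_verbosity()
        self._diffusers_verbosity = diffusers_logging.get_verbosity()
        transformers_logging.set_verbosity_error()
        diffusers_logging.set_verbosity_error()
        warnings.simplefilter("ignore")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        transformers_logging.set_verbosity(self._transformers_verbosity)
        diffusers_logging.set_verbosity(self._diffusers_verbosity)
        # Warning filters are reset to the defaults rather than restored exactly.
        warnings.simplefilter("default")
        return False

Because warnings are suppressed for everything inside the block, wrapping the whole invoke() body (as the hunks below do) also hides any unrelated warnings raised during denoising; that is the trade-off of this approach.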
@@ -311,70 +311,71 @@ def prep_control_data(
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        noise = context.services.latents.get(self.noise.latents_name)
+        with SilenceWarnings():
+            noise = context.services.latents.get(self.noise.latents_name)
 
-        # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
-        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
+            # Get the source node id (we are invoking the prepared node)
+            graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
+            source_node_id = graph_execution_state.prepared_source_mapping[self.id]
 
-        def step_callback(state: PipelineIntermediateState):
-            self.dispatch_progress(context, source_node_id, state)
+            def step_callback(state: PipelineIntermediateState):
+                self.dispatch_progress(context, source_node_id, state)
 
-        def _lora_loader():
-            for lora in self.unet.loras:
-                lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}),
-                    context=context,
-                )
-                yield (lora_info.context.model, lora.weight)
-                del lora_info
-            return
-
-        unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict(),
-            context=context,
-        )
-        with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
-            unet_info.context.model, _lora_loader()
-        ), unet_info as unet:
-            noise = noise.to(device=unet.device, dtype=unet.dtype)
+            def _lora_loader():
+                for lora in self.unet.loras:
+                    lora_info = context.services.model_manager.get_model(
+                        **lora.dict(exclude={"weight"}),
+                        context=context,
+                    )
+                    yield (lora_info.context.model, lora.weight)
+                    del lora_info
+                return
 
-            scheduler = get_scheduler(
+            unet_info = context.services.model_manager.get_model(
+                **self.unet.unet.dict(),
                 context=context,
-                scheduler_info=self.unet.scheduler,
-                scheduler_name=self.scheduler,
             )
+            with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
+                unet_info.context.model, _lora_loader()
+            ), unet_info as unet:
+                noise = noise.to(device=unet.device, dtype=unet.dtype)
 
-            pipeline = self.create_pipeline(unet, scheduler)
-            conditioning_data = self.get_conditioning_data(context, scheduler, unet)
+                scheduler = get_scheduler(
+                    context=context,
+                    scheduler_info=self.unet.scheduler,
+                    scheduler_name=self.scheduler,
+                )
 
-            control_data = self.prep_control_data(
-                model=pipeline,
-                context=context,
-                control_input=self.control,
-                latents_shape=noise.shape,
-                # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
-                do_classifier_free_guidance=True,
-                exit_stack=exit_stack,
-            )
+                pipeline = self.create_pipeline(unet, scheduler)
+                conditioning_data = self.get_conditioning_data(context, scheduler, unet)
 
-            # TODO: Verify the noise is the right size
-            result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
-                latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
-                noise=noise,
-                num_inference_steps=self.steps,
-                conditioning_data=conditioning_data,
-                control_data=control_data,  # list[ControlNetData]
-                callback=step_callback,
-            )
+                control_data = self.prep_control_data(
+                    model=pipeline,
+                    context=context,
+                    control_input=self.control,
+                    latents_shape=noise.shape,
+                    # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
+                    do_classifier_free_guidance=True,
+                    exit_stack=exit_stack,
+                )
 
-            # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-            result_latents = result_latents.to("cpu")
-            torch.cuda.empty_cache()
+                # TODO: Verify the noise is the right size
+                result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
+                    latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
+                    noise=noise,
+                    num_inference_steps=self.steps,
+                    conditioning_data=conditioning_data,
+                    control_data=control_data,  # list[ControlNetData]
+                    callback=step_callback,
+                )
 
-            name = f"{context.graph_execution_state_id}__{self.id}"
-            context.services.latents.save(name, result_latents)
-            return build_latents_output(latents_name=name, latents=result_latents)
+                # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+                result_latents = result_latents.to("cpu")
+                torch.cuda.empty_cache()
+
+                name = f"{context.graph_execution_state_id}__{self.id}"
+                context.services.latents.save(name, result_latents)
+                return build_latents_output(latents_name=name, latents=result_latents)
 
 
 class LatentsToLatentsInvocation(TextToLatentsInvocation):
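The next hunk applies the identical wrapper to LatentsToLatentsInvocation.invoke(). As above, the denoising logic is unchanged; the body simply moves one indentation level to the right under with SilenceWarnings():, which is why nearly every line is reported as modified. A quick, self-contained way to observe the wrapper's effect, reusing the SilenceWarnings sketch above and a hypothetical noisy_call() stand-in for the diffusers safety-checker warning:

import warnings


def noisy_call() -> None:
    # Hypothetical stand-in for a diffusers call that emits the NSFW-checker nag.
    warnings.warn("Potential NSFW content was detected in one or more images.")


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    noisy_call()                 # recorded: the warning reaches the caller
    with SilenceWarnings():
        noisy_call()             # filtered out while the context manager is active
    print(len(caught))           # expected output: 1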
@@ -402,82 +403,83 @@ class Config(InvocationConfig):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        noise = context.services.latents.get(self.noise.latents_name)
-        latent = context.services.latents.get(self.latents.latents_name)
+        with SilenceWarnings():  # this quenches NSFW nag from diffusers
+            noise = context.services.latents.get(self.noise.latents_name)
+            latent = context.services.latents.get(self.latents.latents_name)
 
-        # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
-        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
+            # Get the source node id (we are invoking the prepared node)
+            graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
+            source_node_id = graph_execution_state.prepared_source_mapping[self.id]
 
-        def step_callback(state: PipelineIntermediateState):
-            self.dispatch_progress(context, source_node_id, state)
+            def step_callback(state: PipelineIntermediateState):
+                self.dispatch_progress(context, source_node_id, state)
 
-        def _lora_loader():
-            for lora in self.unet.loras:
-                lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}),
-                    context=context,
-                )
-                yield (lora_info.context.model, lora.weight)
-                del lora_info
-            return
-
-        unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict(),
-            context=context,
-        )
-        with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
-            unet_info.context.model, _lora_loader()
-        ), unet_info as unet:
-            noise = noise.to(device=unet.device, dtype=unet.dtype)
-            latent = latent.to(device=unet.device, dtype=unet.dtype)
+            def _lora_loader():
+                for lora in self.unet.loras:
+                    lora_info = context.services.model_manager.get_model(
+                        **lora.dict(exclude={"weight"}),
+                        context=context,
+                    )
+                    yield (lora_info.context.model, lora.weight)
+                    del lora_info
+                return
 
-            scheduler = get_scheduler(
+            unet_info = context.services.model_manager.get_model(
+                **self.unet.unet.dict(),
                 context=context,
-                scheduler_info=self.unet.scheduler,
-                scheduler_name=self.scheduler,
             )
+            with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
+                unet_info.context.model, _lora_loader()
+            ), unet_info as unet:
+                noise = noise.to(device=unet.device, dtype=unet.dtype)
+                latent = latent.to(device=unet.device, dtype=unet.dtype)
 
-            pipeline = self.create_pipeline(unet, scheduler)
-            conditioning_data = self.get_conditioning_data(context, scheduler, unet)
+                scheduler = get_scheduler(
+                    context=context,
+                    scheduler_info=self.unet.scheduler,
+                    scheduler_name=self.scheduler,
+                )
 
-            control_data = self.prep_control_data(
-                model=pipeline,
-                context=context,
-                control_input=self.control,
-                latents_shape=noise.shape,
-                # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
-                do_classifier_free_guidance=True,
-                exit_stack=exit_stack,
-            )
+                pipeline = self.create_pipeline(unet, scheduler)
+                conditioning_data = self.get_conditioning_data(context, scheduler, unet)
 
-            # TODO: Verify the noise is the right size
-            initial_latents = (
-                latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
-            )
+                control_data = self.prep_control_data(
+                    model=pipeline,
+                    context=context,
+                    control_input=self.control,
+                    latents_shape=noise.shape,
+                    # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
+                    do_classifier_free_guidance=True,
+                    exit_stack=exit_stack,
+                )
 
-            timesteps, _ = pipeline.get_img2img_timesteps(
-                self.steps,
-                self.strength,
-                device=unet.device,
-            )
+                # TODO: Verify the noise is the right size
+                initial_latents = (
+                    latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
+                )
 
-            result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
-                latents=initial_latents,
-                timesteps=timesteps,
-                noise=noise,
-                num_inference_steps=self.steps,
-                conditioning_data=conditioning_data,
-                control_data=control_data,  # list[ControlNetData]
-                callback=step_callback,
-            )
+                timesteps, _ = pipeline.get_img2img_timesteps(
+                    self.steps,
+                    self.strength,
+                    device=unet.device,
+                )
 
-            # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-            result_latents = result_latents.to("cpu")
-            torch.cuda.empty_cache()
+                result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
+                    latents=initial_latents,
+                    timesteps=timesteps,
+                    noise=noise,
+                    num_inference_steps=self.steps,
+                    conditioning_data=conditioning_data,
+                    control_data=control_data,  # list[ControlNetData]
+                    callback=step_callback,
+                )
 
-            name = f"{context.graph_execution_state_id}__{self.id}"
-            context.services.latents.save(name, result_latents)
+                # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+                result_latents = result_latents.to("cpu")
+                torch.cuda.empty_cache()
+
+                name = f"{context.graph_execution_state_id}__{self.id}"
+                context.services.latents.save(name, result_latents)
             return build_latents_output(latents_name=name, latents=result_latents)
 
 
@@ -490,7 +492,7 @@ class LatentsToImageInvocation(BaseInvocation):
     # Inputs
     latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
     vae: VaeField = Field(default=None, description="Vae submodel")
-    tiled: bool = Field(default=False, description="Decode latents by overlapping tiles(less memory consumption)")
+    tiled: bool = Field(default=False, description="Decode latents by overlapping tiles (less memory consumption)")
     fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
     metadata: Optional[CoreMetadata] = Field(
         default=None, description="Optional core metadata to be written to the image"