 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
 
+from contextlib import ExitStack
 from typing import List, Literal, Optional, Union
 
 import einops
 
 from invokeai.app.invocations.metadata import CoreMetadata
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
+from invokeai.backend.model_management.models.base import ModelType
 
 from ...backend.model_management.lora import ModelPatcher
 from ...backend.stable_diffusion import PipelineIntermediateState
@@ -71,16 +73,21 @@ def get_scheduler(
     scheduler_name: str,
 ) -> Scheduler:
     scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
-        scheduler_name, SCHEDULER_MAP['ddim'])
+        scheduler_name, SCHEDULER_MAP['ddim']
+    )
     orig_scheduler_info = context.services.model_manager.get_model(
-        **scheduler_info.dict())
+        **scheduler_info.dict()
+    )
     with orig_scheduler_info as orig_scheduler:
         scheduler_config = orig_scheduler.config
 
     if "_backup" in scheduler_config:
         scheduler_config = scheduler_config["_backup"]
-    scheduler_config = {**scheduler_config, **
-                        scheduler_extra_config, "_backup": scheduler_config}
+    scheduler_config = {
+        **scheduler_config,
+        **scheduler_extra_config,
+        "_backup": scheduler_config,
+    }
     scheduler = scheduler_class.from_config(scheduler_config)
 
     # hack copied over from generate.py
@@ -137,8 +144,11 @@ class Config(InvocationConfig):
 
     # TODO: pass this an emitter method or something? or a session for dispatching?
     def dispatch_progress(
-            self, context: InvocationContext, source_node_id: str,
-            intermediate_state: PipelineIntermediateState) -> None:
+        self,
+        context: InvocationContext,
+        source_node_id: str,
+        intermediate_state: PipelineIntermediateState,
+    ) -> None:
         stable_diffusion_step_callback(
             context=context,
             intermediate_state=intermediate_state,
@@ -147,11 +157,16 @@ def dispatch_progress(
         )
 
     def get_conditioning_data(
-            self, context: InvocationContext, scheduler) -> ConditioningData:
+        self,
+        context: InvocationContext,
+        scheduler,
+    ) -> ConditioningData:
         c, extra_conditioning_info = context.services.latents.get(
-            self.positive_conditioning.conditioning_name)
+            self.positive_conditioning.conditioning_name
+        )
         uc, _ = context.services.latents.get(
-            self.negative_conditioning.conditioning_name)
+            self.negative_conditioning.conditioning_name
+        )
 
         conditioning_data = ConditioningData(
             unconditioned_embeddings=uc,
@@ -178,7 +193,10 @@ def get_conditioning_data(
         return conditioning_data
 
     def create_pipeline(
-            self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
+        self,
+        unet,
+        scheduler,
+    ) -> StableDiffusionGeneratorPipeline:
         # TODO:
         # configure_model_padding(
         #    unet,
@@ -213,6 +231,7 @@ def prep_control_data(
         model: StableDiffusionGeneratorPipeline,
         control_input: List[ControlField],
         latents_shape: List[int],
+        exit_stack: ExitStack,
         do_classifier_free_guidance: bool = True,
     ) -> List[ControlNetData]:
 
@@ -238,25 +257,19 @@ def prep_control_data(
             control_data = []
             control_models = []
             for control_info in control_list:
-                # handle control models
-                if ("," in control_info.control_model):
-                    control_model_split = control_info.control_model.split(",")
-                    control_name = control_model_split[0]
-                    control_subfolder = control_model_split[1]
-                    print("Using HF model subfolders")
-                    print(" control_name: ", control_name)
-                    print(" control_subfolder: ", control_subfolder)
-                    control_model = ControlNetModel.from_pretrained(
-                        control_name, subfolder=control_subfolder,
-                        torch_dtype=model.unet.dtype).to(
-                        model.device)
-                else:
-                    control_model = ControlNetModel.from_pretrained(
-                        control_info.control_model, torch_dtype=model.unet.dtype).to(model.device)
+                control_model = exit_stack.enter_context(
+                    context.services.model_manager.get_model(
+                        model_name=control_info.control_model.model_name,
+                        model_type=ModelType.ControlNet,
+                        base_model=control_info.control_model.base_model,
+                    )
+                )
+
                 control_models.append(control_model)
                 control_image_field = control_info.image
                 input_image = context.services.images.get_pil_image(
-                    control_image_field.image_name)
+                    control_image_field.image_name
+                )
                 # self.image.image_type, self.image.image_name
                 # FIXME: still need to test with different widths, heights, devices, dtypes
                 # and add in batch_size, num_images_per_prompt?
@@ -278,7 +291,8 @@ def prep_control_data(
                     weight=control_info.control_weight,
                     begin_step_percent=control_info.begin_step_percent,
                     end_step_percent=control_info.end_step_percent,
-                    control_mode=control_info.control_mode,)
+                    control_mode=control_info.control_mode,
+                )
                 control_data.append(control_item)
             # MultiControlNetModel has been refactored out, just need list[ControlNetData]
             return control_data
@@ -289,7 +303,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
 
         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(
-            context.graph_execution_state_id)
+            context.graph_execution_state_id
+        )
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]
 
         def step_callback(state: PipelineIntermediateState):
@@ -298,14 +313,17 @@ def step_callback(state: PipelineIntermediateState):
         def _lora_loader():
             for lora in self.unet.loras:
                 lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}))
+                    **lora.dict(exclude={"weight"})
+                )
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return
 
         unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict())
-        with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
+            **self.unet.unet.dict()
+        )
+        with ExitStack() as exit_stack,\
+                ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
                 unet_info as unet:
 
             scheduler = get_scheduler(
@@ -322,6 +340,7 @@ def _lora_loader():
                 latents_shape=noise.shape,
                 # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                 do_classifier_free_guidance=True,
+                exit_stack=exit_stack,
             )
 
             # TODO: Verify the noise is the right size
@@ -374,7 +393,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
 
         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(
-            context.graph_execution_state_id)
+            context.graph_execution_state_id
+        )
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]
 
         def step_callback(state: PipelineIntermediateState):
@@ -383,14 +403,17 @@ def step_callback(state: PipelineIntermediateState):
         def _lora_loader():
             for lora in self.unet.loras:
                 lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}))
+                    **lora.dict(exclude={"weight"})
+                )
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return
 
         unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict())
-        with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
+            **self.unet.unet.dict()
+        )
+        with ExitStack() as exit_stack,\
+                ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
                 unet_info as unet:
 
             scheduler = get_scheduler(
@@ -407,11 +430,13 @@ def _lora_loader():
                 latents_shape=noise.shape,
                 # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                 do_classifier_free_guidance=True,
+                exit_stack=exit_stack,
             )
 
             # TODO: Verify the noise is the right size
             initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
-                latent, device=unet.device, dtype=latent.dtype)
+                latent, device=unet.device, dtype=latent.dtype
+            )
 
             timesteps, _ = pipeline.get_img2img_timesteps(
                 self.steps,
@@ -535,7 +560,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
         resized_latents = torch.nn.functional.interpolate(
             latents, size=(self.height // 8, self.width // 8),
             mode=self.mode, antialias=self.antialias
-            if self.mode in ["bilinear", "bicubic"] else False,)
+            if self.mode in ["bilinear", "bicubic"] else False,
+        )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
@@ -569,7 +595,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
         resized_latents = torch.nn.functional.interpolate(
             latents, scale_factor=self.scale_factor, mode=self.mode,
             antialias=self.antialias
-            if self.mode in ["bilinear", "bicubic"] else False,)
+            if self.mode in ["bilinear", "bicubic"] else False,
+        )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
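
Note: the `exit_stack` threaded through `prep_control_data` above is what keeps the ControlNet models obtained from the model manager loaded for the whole denoising run, releasing them only when the `with ExitStack() as exit_stack, ...` block in `invoke` unwinds. A minimal sketch of that pattern, with a hypothetical `load_model` context manager standing in for `context.services.model_manager.get_model(...)` (names here are illustrative, not InvokeAI's actual API):

from contextlib import ExitStack, contextmanager

@contextmanager
def load_model(name: str):
    # Hypothetical stand-in for model_manager.get_model(): loads a model,
    # yields it, and releases it when the context exits.
    print(f"loading {name}")
    try:
        yield f"<model:{name}>"
    finally:
        print(f"releasing {name}")

def prep_control_data(control_list, exit_stack: ExitStack):
    # Entering each model on the caller's ExitStack keeps it loaded after
    # this function returns, rather than only inside a local `with` block.
    return [exit_stack.enter_context(load_model(name)) for name in control_list]

with ExitStack() as exit_stack:
    control_models = prep_control_data(["canny", "depth"], exit_stack)
    # ... run the denoising loop with control_models ...
# every model entered on the stack is released here, in reverse order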