@@ -237,16 +237,53 @@ def do_latent_postprocessing(
             )
         return latents
 
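+    # Pads the shorter of the two prompt embeddings with zeros so the unconditioned and conditioned
+    # tensors can be concatenated into one batch, and returns an encoder attention mask
+    # (1 for real tokens, 0 for padding), or None when no padding was needed.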
+    def _concat_conditionings_for_batch(self, unconditioning, conditioning):
+        def _pad_conditioning(cond, target_len, encoder_attention_mask):
+            conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
+
+            if cond.shape[1] < target_len:
+                conditioning_attention_mask = torch.cat([
+                    conditioning_attention_mask,
+                    torch.zeros((cond.shape[0], target_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
+                ], dim=1)
+
+                cond = torch.cat([
+                    cond,
+                    torch.zeros((cond.shape[0], target_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
+                ], dim=1)
+
+            if encoder_attention_mask is None:
+                encoder_attention_mask = conditioning_attention_mask
+            else:
+                encoder_attention_mask = torch.cat([
+                    encoder_attention_mask,
+                    conditioning_attention_mask,
+                ])
+
+            return cond, encoder_attention_mask
+
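+        # Padding is only needed when the two embeddings disagree on sequence length; otherwise the mask
+        # stays None and the concatenation below matches the previous plain torch.cat behaviour.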
+        encoder_attention_mask = None
+        if unconditioning.shape[1] != conditioning.shape[1]:
+            max_len = max(unconditioning.shape[1], conditioning.shape[1])
+            unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
+            conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
+
+        return torch.cat([unconditioning, conditioning]), encoder_attention_mask
+
     # methods below are called from do_diffusion_step and should be considered private to this class.
 
     def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
         # fast batched path
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
 
-        both_conditionings = torch.cat([unconditioning, conditioning])
+        both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
+            unconditioning, conditioning
+        )
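+        # encoder_attention_mask is None when both embeddings already share a sequence length; otherwise
+        # it marks which positions are real tokens (1) versus zero padding (0) for the call below.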
         both_results = self.model_forward_callback(
-            x_twice, sigma_twice, both_conditionings, **kwargs,
+            x_twice, sigma_twice, both_conditionings,
+            encoder_attention_mask=encoder_attention_mask,
+            **kwargs,
         )
         unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
         return unconditioned_next_x, conditioned_next_x
@@ -260,8 +297,32 @@ def _apply_standard_conditioning_sequentially(
         **kwargs,
     ):
         # low-memory sequential path
-        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
-        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
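+        # ControlNet residuals, when present, come in batched as [unconditioned, conditioned] halves;
+        # split them so each half accompanies its matching forward pass on this sequential path.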
+        uncond_down_block, cond_down_block = None, None
+        down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
+        if down_block_additional_residuals is not None:
+            uncond_down_block, cond_down_block = [], []
+            for down_block in down_block_additional_residuals:
+                _uncond_down, _cond_down = down_block.chunk(2)
+                uncond_down_block.append(_uncond_down)
+                cond_down_block.append(_cond_down)
+
+        uncond_mid_block, cond_mid_block = None, None
+        mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
+        if mid_block_additional_residual is not None:
+            uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
+
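+        # Run the unconditioned and conditioned prompts through the model one at a time, each with its
+        # own half of the residuals; this costs a second forward pass but keeps peak memory lower.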
+        unconditioned_next_x = self.model_forward_callback(
+            x, sigma, unconditioning,
+            down_block_additional_residuals=uncond_down_block,
+            mid_block_additional_residual=uncond_mid_block,
+            **kwargs,
+        )
+        conditioned_next_x = self.model_forward_callback(
+            x, sigma, conditioning,
+            down_block_additional_residuals=cond_down_block,
+            mid_block_additional_residual=cond_mid_block,
+            **kwargs,
+        )
         return unconditioned_next_x, conditioned_next_x
 
     # TODO: looks unused
@@ -295,6 +356,20 @@ def _apply_cross_attention_controlled_conditioning(
     ):
         context: Context = self.cross_attention_control_context
 
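+        # Same residual split as in the sequential path above: this cross-attention-controlled path also
+        # runs the unconditioned and conditioned passes separately, so the batched halves must be separated.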
+        uncond_down_block, cond_down_block = None, None
+        down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
+        if down_block_additional_residuals is not None:
+            uncond_down_block, cond_down_block = [], []
+            for down_block in down_block_additional_residuals:
+                _uncond_down, _cond_down = down_block.chunk(2)
+                uncond_down_block.append(_uncond_down)
+                cond_down_block.append(_cond_down)
+
+        uncond_mid_block, cond_mid_block = None, None
+        mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
+        if mid_block_additional_residual is not None:
+            uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
+
         cross_attn_processor_context = SwapCrossAttnContext(
             modified_text_embeddings=context.arguments.edited_conditioning,
             index_map=context.cross_attention_index_map,
@@ -307,6 +382,8 @@ def _apply_cross_attention_controlled_conditioning(
             sigma,
             unconditioning,
             {"swap_cross_attn_context": cross_attn_processor_context},
+            down_block_additional_residuals=uncond_down_block,
+            mid_block_additional_residual=uncond_mid_block,
             **kwargs,
         )
 
@@ -319,6 +396,8 @@ def _apply_cross_attention_controlled_conditioning(
             sigma,
             conditioning,
             {"swap_cross_attn_context": cross_attn_processor_context},
+            down_block_additional_residuals=cond_down_block,
+            mid_block_additional_residual=cond_mid_block,
             **kwargs,
         )
         return unconditioned_next_x, conditioned_next_x