@@ -241,11 +241,45 @@ def do_latent_postprocessing(
241241
242242 def _apply_standard_conditioning (self , x , sigma , unconditioning , conditioning , ** kwargs ):
243243 # fast batched path
244+
245+ def _pad_conditioning (cond , target_len , encoder_attention_mask ):
246+ conditioning_attention_mask = torch .ones ((cond .shape [0 ], cond .shape [1 ]), device = cond .device , dtype = cond .dtype )
247+
248+ if cond .shape [1 ] < max_len :
249+ conditioning_attention_mask = torch .cat ([
250+ conditioning_attention_mask ,
251+ torch .zeros ((cond .shape [0 ], max_len - cond .shape [1 ]), device = cond .device , dtype = cond .dtype ),
252+ ], dim = 1 )
253+
254+ cond = torch .cat ([
255+ cond ,
256+ torch .zeros ((cond .shape [0 ], max_len - cond .shape [1 ], cond .shape [2 ]), device = cond .device , dtype = cond .dtype ),
257+ ], dim = 1 )
258+
259+ if encoder_attention_mask is None :
260+ encoder_attention_mask = conditioning_attention_mask
261+ else :
262+ encoder_attention_mask = torch .cat ([
263+ encoder_attention_mask ,
264+ conditioning_attention_mask ,
265+ ])
266+
267+ return cond , encoder_attention_mask
268+
244269 x_twice = torch .cat ([x ] * 2 )
245270 sigma_twice = torch .cat ([sigma ] * 2 )
271+
272+ encoder_attention_mask = None
273+ if unconditioning .shape [1 ] != conditioning .shape [1 ]:
274+ max_len = max (unconditioning .shape [1 ], conditioning .shape [1 ])
275+ unconditioning , encoder_attention_mask = _pad_conditioning (unconditioning , max_len , encoder_attention_mask )
276+ conditioning , encoder_attention_mask = _pad_conditioning (conditioning , max_len , encoder_attention_mask )
277+
246278 both_conditionings = torch .cat ([unconditioning , conditioning ])
247279 both_results = self .model_forward_callback (
248- x_twice , sigma_twice , both_conditionings , ** kwargs ,
280+ x_twice , sigma_twice , both_conditionings ,
281+ encoder_attention_mask = encoder_attention_mask ,
282+ ** kwargs ,
249283 )
250284 unconditioned_next_x , conditioned_next_x = both_results .chunk (2 )
251285 return unconditioned_next_x , conditioned_next_x
0 commit comments