Commit 47b1a85

Fix/long prompts (#3806)
2 parents ded5ebc + ccf093b

2 files changed (+91 −52 lines)

invokeai/backend/stable_diffusion/diffusers_pipeline.py

Lines changed: 8 additions & 48 deletions
@@ -422,7 +422,6 @@ def image_from_embeddings(
         noise: torch.Tensor,
         callback: Callable[[PipelineIntermediateState], None] = None,
         run_id=None,
-        **kwargs,
     ) -> InvokeAIStableDiffusionPipelineOutput:
         r"""
         Function invoked when calling the pipeline for generation.
@@ -443,7 +442,6 @@ def image_from_embeddings(
             noise=noise,
             run_id=run_id,
             callback=callback,
-            **kwargs,
         )
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
@@ -469,7 +467,6 @@ def latents_from_embeddings(
         run_id=None,
         callback: Callable[[PipelineIntermediateState], None] = None,
         control_data: List[ControlNetData] = None,
-        **kwargs,
     ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
         if self.scheduler.config.get("cpu_only", False):
             scheduler_device = torch.device('cpu')
@@ -487,11 +484,11 @@ def latents_from_embeddings(
             timesteps,
             conditioning_data,
             noise=noise,
-            additional_guidance=additional_guidance,
             run_id=run_id,
-            callback=callback,
+            additional_guidance=additional_guidance,
             control_data=control_data,
-            **kwargs,
+
+            callback=callback,
         )
         return result.latents, result.attention_map_saver

@@ -505,42 +502,7 @@ def generate_latents_from_embeddings(
         run_id: str = None,
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
-        **kwargs,
     ):
-        def _pad_conditioning(cond, target_len, encoder_attention_mask):
-            conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
-
-            if cond.shape[1] < max_len:
-                conditioning_attention_mask = torch.cat([
-                    conditioning_attention_mask,
-                    torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
-                ], dim=1)
-
-            cond = torch.cat([
-                cond,
-                torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
-            ], dim=1)
-
-            if encoder_attention_mask is None:
-                encoder_attention_mask = conditioning_attention_mask
-            else:
-                encoder_attention_mask = torch.cat([
-                    encoder_attention_mask,
-                    conditioning_attention_mask,
-                ])
-
-            return cond, encoder_attention_mask
-
-        encoder_attention_mask = None
-        if conditioning_data.unconditioned_embeddings.shape[1] != conditioning_data.text_embeddings.shape[1]:
-            max_len = max(conditioning_data.unconditioned_embeddings.shape[1], conditioning_data.text_embeddings.shape[1])
-            conditioning_data.unconditioned_embeddings, encoder_attention_mask = _pad_conditioning(
-                conditioning_data.unconditioned_embeddings, max_len, encoder_attention_mask
-            )
-            conditioning_data.text_embeddings, encoder_attention_mask = _pad_conditioning(
-                conditioning_data.text_embeddings, max_len, encoder_attention_mask
-            )
-
         self._adjust_memory_efficient_attention(latents)
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
@@ -580,8 +542,6 @@ def _pad_conditioning(cond, target_len, encoder_attention_mask):
                 total_step_count=len(timesteps),
                 additional_guidance=additional_guidance,
                 control_data=control_data,
-                encoder_attention_mask=encoder_attention_mask,
-                **kwargs,
             )
             latents = step_output.prev_sample

@@ -623,7 +583,6 @@ def step(
         total_step_count: int,
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
-        **kwargs,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
@@ -638,8 +597,6 @@ def step(
         down_block_res_samples, mid_block_res_sample = None, None

         if control_data is not None:
-            # TODO: rewrite to pass with conditionings
-            encoder_attention_mask = kwargs.get("encoder_attention_mask", None)
             # control_data should be type List[ControlNetData]
             # this loop covers both ControlNet (one ControlNetData in list)
             # and MultiControlNet (multiple ControlNetData in list)
@@ -669,9 +626,12 @@

                 if cfg_injection:  # only applying ControlNet to conditional instead of in unconditioned
                     encoder_hidden_states = conditioning_data.text_embeddings
+                    encoder_attention_mask = None
                 else:
-                    encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
-                                                       conditioning_data.text_embeddings])
+                    encoder_hidden_states, encoder_attention_mask = self.invokeai_diffuser._concat_conditionings_for_batch(
+                        conditioning_data.unconditioned_embeddings,
+                        conditioning_data.text_embeddings,
+                    )
                 if isinstance(control_datum.weight, list):
                     # if controlnet has multiple weights, use the weight for the current step
                     controlnet_weight = control_datum.weight[step_index]
invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py

Lines changed: 83 additions & 4 deletions
@@ -237,16 +237,53 @@ def do_latent_postprocessing(
         )
         return latents

+    def _concat_conditionings_for_batch(self, unconditioning, conditioning):
+        def _pad_conditioning(cond, target_len, encoder_attention_mask):
+            conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
+
+            if cond.shape[1] < max_len:
+                conditioning_attention_mask = torch.cat([
+                    conditioning_attention_mask,
+                    torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
+                ], dim=1)
+
+            cond = torch.cat([
+                cond,
+                torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
+            ], dim=1)
+
+            if encoder_attention_mask is None:
+                encoder_attention_mask = conditioning_attention_mask
+            else:
+                encoder_attention_mask = torch.cat([
+                    encoder_attention_mask,
+                    conditioning_attention_mask,
+                ])
+
+            return cond, encoder_attention_mask
+
+        encoder_attention_mask = None
+        if unconditioning.shape[1] != conditioning.shape[1]:
+            max_len = max(unconditioning.shape[1], conditioning.shape[1])
+            unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
+            conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
+
+        return torch.cat([unconditioning, conditioning]), encoder_attention_mask
+
     # methods below are called from do_diffusion_step and should be considered private to this class.

     def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
         # fast batched path
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)

-        both_conditionings = torch.cat([unconditioning, conditioning])
+        both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
+            unconditioning, conditioning
+        )
         both_results = self.model_forward_callback(
-            x_twice, sigma_twice, both_conditionings, **kwargs,
+            x_twice, sigma_twice, both_conditionings,
+            encoder_attention_mask=encoder_attention_mask,
+            **kwargs,
         )
         unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
         return unconditioned_next_x, conditioned_next_x
@@ -260,8 +297,32 @@ def _apply_standard_conditioning_sequentially(
         **kwargs,
     ):
         # low-memory sequential path
-        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
-        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
+        uncond_down_block, cond_down_block = None, None
+        down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
+        if down_block_additional_residuals is not None:
+            uncond_down_block, cond_down_block = [], []
+            for down_block in down_block_additional_residuals:
+                _uncond_down, _cond_down = down_block.chunk(2)
+                uncond_down_block.append(_uncond_down)
+                cond_down_block.append(_cond_down)
+
+        uncond_mid_block, cond_mid_block = None, None
+        mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
+        if mid_block_additional_residual is not None:
+            uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
+
+        unconditioned_next_x = self.model_forward_callback(
+            x, sigma, unconditioning,
+            down_block_additional_residuals=uncond_down_block,
+            mid_block_additional_residual=uncond_mid_block,
+            **kwargs,
+        )
+        conditioned_next_x = self.model_forward_callback(
+            x, sigma, conditioning,
+            down_block_additional_residuals=cond_down_block,
+            mid_block_additional_residual=cond_mid_block,
+            **kwargs,
+        )
         return unconditioned_next_x, conditioned_next_x

     # TODO: looks unused
@@ -295,6 +356,20 @@ def _apply_cross_attention_controlled_conditioning(
     ):
         context: Context = self.cross_attention_control_context

+        uncond_down_block, cond_down_block = None, None
+        down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
+        if down_block_additional_residuals is not None:
+            uncond_down_block, cond_down_block = [], []
+            for down_block in down_block_additional_residuals:
+                _uncond_down, _cond_down = down_block.chunk(2)
+                uncond_down_block.append(_uncond_down)
+                cond_down_block.append(_cond_down)
+
+        uncond_mid_block, cond_mid_block = None, None
+        mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
+        if mid_block_additional_residual is not None:
+            uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
+
         cross_attn_processor_context = SwapCrossAttnContext(
             modified_text_embeddings=context.arguments.edited_conditioning,
             index_map=context.cross_attention_index_map,
@@ -307,6 +382,8 @@ def _apply_cross_attention_controlled_conditioning(
             sigma,
             unconditioning,
             {"swap_cross_attn_context": cross_attn_processor_context},
+            down_block_additional_residuals=uncond_down_block,
+            mid_block_additional_residual=uncond_mid_block,
             **kwargs,
         )

@@ -319,6 +396,8 @@ def _apply_cross_attention_controlled_conditioning(
             sigma,
             conditioning,
             {"swap_cross_attn_context": cross_attn_processor_context},
+            down_block_additional_residuals=cond_down_block,
+            mid_block_additional_residual=cond_mid_block,
             **kwargs,
        )
        return unconditioned_next_x, conditioned_next_x
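
The last hunks make the low-memory sequential path and the cross-attention-controlled path ControlNet-aware: residuals that a ControlNet computed for the stacked [unconditioned, conditioned] batch are split back into their two halves with chunk(2), because these paths call the UNet once per half rather than once for the pair. A minimal sketch of that split — the helper name and the example residual shapes are illustrative assumptions, not code from this commit:

import torch

def split_controlnet_residuals(down_block_residuals, mid_block_residual):
    # ControlNet residuals are batched as [uncond, cond] along dim 0;
    # chunk(2) splits each tensor back into its two halves.
    uncond_down, cond_down = None, None
    if down_block_residuals is not None:
        uncond_down, cond_down = [], []
        for res in down_block_residuals:
            u, c = res.chunk(2)
            uncond_down.append(u)
            cond_down.append(c)

    uncond_mid, cond_mid = None, None
    if mid_block_residual is not None:
        uncond_mid, cond_mid = mid_block_residual.chunk(2)

    return (uncond_down, uncond_mid), (cond_down, cond_mid)

# Example: two down-block residuals and one mid-block residual, each batched 2x.
down = [torch.randn(2, 320, 64, 64), torch.randn(2, 640, 32, 32)]
mid = torch.randn(2, 1280, 8, 8)
(u_down, u_mid), (c_down, c_mid) = split_controlnet_residuals(down, mid)
assert u_down[0].shape[0] == 1 and c_mid.shape[0] == 1

Splitting along dim 0 is cheap (chunk typically returns views rather than copies), so the residuals can be computed once for the pair and shared between the two UNet passes.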
