
Commit 80ff4ba

DN6 and sayakpaul authored
Fix issue with prompt embeds and latents in SD Cascade Decoder with multiple image embeddings for a single prompt. (#7381)
* fix
* update
* update

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent b09a2aa commit 80ff4ba
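
For context, the situation this commit addresses arises when the prior pipeline is run with num_images_per_prompt > 1, so a single prompt maps to several image embeddings that are then handed to the decoder. A minimal sketch of that flow, assuming the usual two-stage Stable Cascade setup (checkpoint names, dtypes, and generation settings below are illustrative, not taken from this commit):

import torch
from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline

prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
).to("cuda")
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.float16
).to("cuda")

prompt = "a cat"  # a single prompt, i.e. batch_size == 1

# The prior produces two image embeddings for the one prompt.
prior_output = prior(prompt=prompt, num_images_per_prompt=2, num_inference_steps=20)

# Before this fix, the decoder derived batch_size from image_embeddings.shape[0] (here 2),
# which broke the pairing between prompt embeddings and latents. With the fix, batch_size
# comes from the prompt and num_images_per_prompt is scaled up instead.
images = decoder(
    image_embeddings=prior_output.image_embeddings.to(torch.float16),
    prompt=prompt,
    num_images_per_prompt=2,
    guidance_scale=0.0,
    num_inference_steps=10,
    output_type="pil",
).images  # 1 prompt x 2 prior embeddings x 2 decoder images per prompt = 4 images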

File tree

2 files changed: +79 -4 lines changed


src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py

Lines changed: 18 additions & 4 deletions
@@ -100,8 +100,10 @@ def __init__(
         )
         self.register_to_config(latent_dim_scale=latent_dim_scale)
 
-    def prepare_latents(self, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler):
-        batch_size, channels, height, width = image_embeddings.shape
+    def prepare_latents(
+        self, batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler
+    ):
+        _, channels, height, width = image_embeddings.shape
         latents_shape = (
             batch_size * num_images_per_prompt,
             4,

@@ -383,7 +385,19 @@ def __call__(
         )
         if isinstance(image_embeddings, list):
             image_embeddings = torch.cat(image_embeddings, dim=0)
-        batch_size = image_embeddings.shape[0]
+
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        # Compute the effective number of images per prompt
+        # We must account for the fact that the image embeddings from the prior can be generated with num_images_per_prompt > 1
+        # This results in a case where a single prompt is associated with multiple image embeddings
+        # Divide the number of image embeddings by the batch size to determine if this is the case.
+        num_images_per_prompt = num_images_per_prompt * (image_embeddings.shape[0] // batch_size)
 
         # 2. Encode caption
         if prompt_embeds is None and negative_prompt_embeds is None:

@@ -417,7 +431,7 @@ def __call__(
 
         # 5. Prepare latents
         latents = self.prepare_latents(
-            image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
+            batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
         )
 
         # 6. Run denoising loop
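
The comment block in the second hunk spells out the intent of the effective num_images_per_prompt computation; a small worked example of the arithmetic, with values mirroring the new tests (variable names here are illustrative only):

# Illustrative values: one prompt, prior run with num_images_per_prompt=2.
batch_size = 1                                 # derived from the prompt, not from image_embeddings
prior_num_images_per_prompt = 2                # how the prior was invoked
image_embeddings_count = batch_size * prior_num_images_per_prompt  # image_embeddings.shape[0] == 2

decoder_num_images_per_prompt = 2
# Effective value computed by the pipeline after this change:
num_images_per_prompt = decoder_num_images_per_prompt * (image_embeddings_count // batch_size)  # == 4

# prepare_latents() then allocates batch_size * num_images_per_prompt latents,
# matching the assertion in the new tests (1 * 2 * 2 == 4 decoded images).
latents_in_batch = batch_size * num_images_per_prompt  # == 4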

tests/pipelines/stable_cascade/test_stable_cascade_decoder.py

Lines changed: 61 additions & 0 deletions
@@ -33,6 +33,7 @@
     slow,
     torch_device,
 )
+from diffusers.utils.torch_utils import randn_tensor
 
 from ..test_pipelines_common import PipelineTesterMixin
 

@@ -246,6 +247,66 @@ def test_stable_cascade_decoder_prompt_embeds(self):
 
         assert np.abs(decoder_output_prompt.images - decoder_output_prompt_embeds.images).max() < 1e-5
 
+    def test_stable_cascade_decoder_single_prompt_multiple_image_embeddings(self):
+        device = "cpu"
+        components = self.get_dummy_components()
+
+        pipe = StableCascadeDecoderPipeline(**components)
+        pipe.set_progress_bar_config(disable=None)
+
+        prior_num_images_per_prompt = 2
+        decoder_num_images_per_prompt = 2
+        prompt = ["a cat"]
+        batch_size = len(prompt)
+
+        generator = torch.Generator(device)
+        image_embeddings = randn_tensor(
+            (batch_size * prior_num_images_per_prompt, 4, 4, 4), generator=generator.manual_seed(0)
+        )
+        decoder_output = pipe(
+            image_embeddings=image_embeddings,
+            prompt=prompt,
+            num_inference_steps=1,
+            output_type="np",
+            guidance_scale=0.0,
+            generator=generator.manual_seed(0),
+            num_images_per_prompt=decoder_num_images_per_prompt,
+        )
+
+        assert decoder_output.images.shape[0] == (
+            batch_size * prior_num_images_per_prompt * decoder_num_images_per_prompt
+        )
+
+    def test_stable_cascade_decoder_single_prompt_multiple_image_embeddings_with_guidance(self):
+        device = "cpu"
+        components = self.get_dummy_components()
+
+        pipe = StableCascadeDecoderPipeline(**components)
+        pipe.set_progress_bar_config(disable=None)
+
+        prior_num_images_per_prompt = 2
+        decoder_num_images_per_prompt = 2
+        prompt = ["a cat"]
+        batch_size = len(prompt)
+
+        generator = torch.Generator(device)
+        image_embeddings = randn_tensor(
+            (batch_size * prior_num_images_per_prompt, 4, 4, 4), generator=generator.manual_seed(0)
+        )
+        decoder_output = pipe(
+            image_embeddings=image_embeddings,
+            prompt=prompt,
+            num_inference_steps=1,
+            output_type="np",
+            guidance_scale=2.0,
+            generator=generator.manual_seed(0),
+            num_images_per_prompt=decoder_num_images_per_prompt,
+        )
+
+        assert decoder_output.images.shape[0] == (
+            batch_size * prior_num_images_per_prompt * decoder_num_images_per_prompt
+        )
+
 
 @slow
 @require_torch_gpu