diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
index a05fb9001c0e..a5d9f06a59b6 100644
--- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
+++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
@@ -100,8 +100,10 @@ def __init__(
         )
         self.register_to_config(latent_dim_scale=latent_dim_scale)
 
-    def prepare_latents(self, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler):
-        batch_size, channels, height, width = image_embeddings.shape
+    def prepare_latents(
+        self, batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler
+    ):
+        _, channels, height, width = image_embeddings.shape
         latents_shape = (
             batch_size * num_images_per_prompt,
             4,
@@ -383,7 +385,19 @@ def __call__(
             )
         if isinstance(image_embeddings, list):
             image_embeddings = torch.cat(image_embeddings, dim=0)
-        batch_size = image_embeddings.shape[0]
+
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        # Compute the effective number of images per prompt
+        # We must account for the fact that the image embeddings from the prior can be generated with num_images_per_prompt > 1
+        # This results in a case where a single prompt is associated with multiple image embeddings
+        # Divide the number of image embeddings by the batch size to determine if this is the case.
+        num_images_per_prompt = num_images_per_prompt * (image_embeddings.shape[0] // batch_size)
 
         # 2. Encode caption
         if prompt_embeds is None and negative_prompt_embeds is None:
@@ -417,7 +431,7 @@ def __call__(
 
         # 5. Prepare latents
         latents = self.prepare_latents(
-            image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
+            batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
         )
 
         # 6. Run denoising loop
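For reviewers, a minimal standalone sketch of the batch-size logic this hunk introduces (plain PyTorch with made-up shapes; `prior_num_images_per_prompt` is a hypothetical name describing how the prior was invoked, not a pipeline argument):

```python
import torch

# Hypothetical scenario: one prompt, and the prior was run with
# num_images_per_prompt=2, so it hands the decoder two embeddings per prompt.
prompt = ["a cat"]
batch_size = len(prompt)  # 1 -- now derived from the prompt, not the embeddings
prior_num_images_per_prompt = 2
image_embeddings = torch.randn(batch_size * prior_num_images_per_prompt, 4, 4, 4)

# What the decoder caller requests per prompt.
num_images_per_prompt = 2

# Effective images per prompt: scale the request by how many embeddings the
# prior produced per prompt (image_embeddings.shape[0] // batch_size == 2 here).
num_images_per_prompt = num_images_per_prompt * (image_embeddings.shape[0] // batch_size)

# prepare_latents then allocates batch_size * num_images_per_prompt latents.
assert num_images_per_prompt == 4
```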
diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py b/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py
index 3ad19dbba1c6..4a8cab77079c 100644
--- a/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py
+++ b/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py
@@ -33,6 +33,7 @@
     slow,
     torch_device,
 )
+from diffusers.utils.torch_utils import randn_tensor
 
 from ..test_pipelines_common import PipelineTesterMixin
 
@@ -246,6 +247,66 @@ def test_stable_cascade_decoder_prompt_embeds(self):
 
         assert np.abs(decoder_output_prompt.images - decoder_output_prompt_embeds.images).max() < 1e-5
 
+    def test_stable_cascade_decoder_single_prompt_multiple_image_embeddings(self):
+        device = "cpu"
+        components = self.get_dummy_components()
+
+        pipe = StableCascadeDecoderPipeline(**components)
+        pipe.set_progress_bar_config(disable=None)
+
+        prior_num_images_per_prompt = 2
+        decoder_num_images_per_prompt = 2
+        prompt = ["a cat"]
+        batch_size = len(prompt)
+
+        generator = torch.Generator(device)
+        image_embeddings = randn_tensor(
+            (batch_size * prior_num_images_per_prompt, 4, 4, 4), generator=generator.manual_seed(0)
+        )
+        decoder_output = pipe(
+            image_embeddings=image_embeddings,
+            prompt=prompt,
+            num_inference_steps=1,
+            output_type="np",
+            guidance_scale=0.0,
+            generator=generator.manual_seed(0),
+            num_images_per_prompt=decoder_num_images_per_prompt,
+        )
+
+        assert decoder_output.images.shape[0] == (
+            batch_size * prior_num_images_per_prompt * decoder_num_images_per_prompt
+        )
+
+    def test_stable_cascade_decoder_single_prompt_multiple_image_embeddings_with_guidance(self):
+        device = "cpu"
+        components = self.get_dummy_components()
+
+        pipe = StableCascadeDecoderPipeline(**components)
+        pipe.set_progress_bar_config(disable=None)
+
+        prior_num_images_per_prompt = 2
+        decoder_num_images_per_prompt = 2
+        prompt = ["a cat"]
+        batch_size = len(prompt)
+
+        generator = torch.Generator(device)
+        image_embeddings = randn_tensor(
+            (batch_size * prior_num_images_per_prompt, 4, 4, 4), generator=generator.manual_seed(0)
+        )
+        decoder_output = pipe(
+            image_embeddings=image_embeddings,
+            prompt=prompt,
+            num_inference_steps=1,
+            output_type="np",
+            guidance_scale=2.0,
+            generator=generator.manual_seed(0),
+            num_images_per_prompt=decoder_num_images_per_prompt,
+        )
+
+        assert decoder_output.images.shape[0] == (
+            batch_size * prior_num_images_per_prompt * decoder_num_images_per_prompt
+        )
+
 
 @slow
 @require_torch_gpu
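The user-facing flow these tests mirror looks roughly like this (a sketch assuming the public `stabilityai/stable-cascade-prior` and `stabilityai/stable-cascade` checkpoints and a CUDA device; not part of the patch):

```python
import torch
from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline

prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
).to("cuda")
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.float16
).to("cuda")

prompt = "a cat"
# The prior produces two image embeddings for the single prompt.
prior_output = prior(prompt=prompt, num_images_per_prompt=2)

# The decoder asks for two images per prompt on top of that; with this patch
# the single prompt is broadcast across both embeddings, yielding
# 1 * 2 * 2 = 4 images instead of a batch-size mismatch.
images = decoder(
    image_embeddings=prior_output.image_embeddings.to(torch.float16),
    prompt=prompt,
    num_images_per_prompt=2,
).images
assert len(images) == 4
```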