diff --git a/docs/source/en/api/pipelines/panorama.mdx b/docs/source/en/api/pipelines/panorama.mdx index e0c7747a0193..044901f24bf3 100644 --- a/docs/source/en/api/pipelines/panorama.mdx +++ b/docs/source/en/api/pipelines/panorama.mdx @@ -52,6 +52,14 @@ image = pipe(prompt).images[0] image.save("dolomites.png") ``` + + +While calling this pipeline, it's possible to specify the `view_batch_size` to have a >1 value. +For some GPUs with high performance, higher a `view_batch_size`, can speedup the generation +and increase the VRAM usage. + + + ## StableDiffusionPanoramaPipeline [[autodoc]] StableDiffusionPanoramaPipeline - __call__ diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py index 66706c806a81..35d57d048907 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -451,10 +451,11 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype def get_views(self, panorama_height, panorama_width, window_size=64, stride=8): # Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113) + # if panorama's height/width < window_size, num_blocks of height/width should return 1 panorama_height /= 8 panorama_width /= 8 - num_blocks_height = (panorama_height - window_size) // stride + 1 - num_blocks_width = (panorama_width - window_size) // stride + 1 + num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1 + num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_height > window_size else 1 total_num_blocks = int(num_blocks_height * num_blocks_width) views = [] for i in range(total_num_blocks): @@ -474,6 +475,7 @@ def __call__( width: Optional[int] = 2048, num_inference_steps: int = 50, guidance_scale: float = 7.5, + view_batch_size: int = 1, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, @@ -508,6 +510,9 @@ def __call__( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + view_batch_size (`int`, *optional*, defaults to 1): + The batch size to denoise splited views. For some GPUs with high performance, higher view batch size + can speedup the generation and increase the VRAM usage. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is @@ -609,8 +614,11 @@ def __call__( ) # 6. Define panorama grid and initialize views for synthesis. + # prepare batch grid views = self.get_views(height, width) - views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views) + views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)] + views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views_batch) + count = torch.zeros_like(latents) value = torch.zeros_like(latents) @@ -631,42 +639,55 @@ def __call__( # denoised (latent) crops are then averaged to produce the final latent # for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113 - for j, (h_start, h_end, w_start, w_end) in enumerate(views): + # Batch views denoise + for j, batch_view in enumerate(views_batch): + vb_size = len(batch_view) # get the latents corresponding to the current view coordinates - latents_for_view = latents[:, :, h_start:h_end, w_start:w_end] + latents_for_view = torch.cat( + [latents[:, :, h_start:h_end, w_start:w_end] for h_start, h_end, w_start, w_end in batch_view] + ) # rematch block's scheduler status self.scheduler.__dict__.update(views_scheduler_status[j]) # expand the latents if we are doing classifier free guidance latent_model_input = ( - torch.cat([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view + latents_for_view.repeat_interleave(2, dim=0) + if do_classifier_free_guidance + else latents_for_view ) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # repeat prompt_embeds for batch + prompt_embeds_input = torch.cat([prompt_embeds] * vb_size) + # predict the noise residual noise_pred = self.unet( latent_model_input, t, - encoder_hidden_states=prompt_embeds, + encoder_hidden_states=prompt_embeds_input, cross_attention_kwargs=cross_attention_kwargs, ).sample # perform guidance if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2] noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents_view_denoised = self.scheduler.step( + latents_denoised_batch = self.scheduler.step( noise_pred, t, latents_for_view, **extra_step_kwargs ).prev_sample # save views scheduler status after sample views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__) - value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised - count[:, :, h_start:h_end, w_start:w_end] += 1 + # extract value from batch + for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip( + latents_denoised_batch.chunk(vb_size), batch_view + ): + value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised + count[:, :, h_start:h_end, w_start:w_end] += 1 # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113 latents = torch.where(count > 0, value / count, value) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py index c8d2bfa8c59d..32541c980a15 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py @@ -131,7 +131,7 @@ def test_inference_batch_consistent(self): # override to speed the overall test timing up. def test_inference_batch_single_identical(self): - super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=3e-3) + super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=3.25e-3) def test_stable_diffusion_panorama_negative_prompt(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator @@ -152,6 +152,24 @@ def test_stable_diffusion_panorama_negative_prompt(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_panorama_views_batch(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionPanoramaPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + output = sd_pipe(**inputs, view_batch_size=2) + image = output.images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + + expected_slice = np.array([0.6187, 0.5375, 0.4915, 0.4136, 0.4114, 0.4563, 0.5128, 0.4976, 0.4757]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_panorama_euler(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components()