Support views batch for panorama #3632

Merged · 12 commits · Jun 6, 2023
docs/source/en/api/pipelines/panorama.mdx (8 additions, 0 deletions)
@@ -52,6 +52,14 @@ image = pipe(prompt).images[0]
image.save("dolomites.png")
```

<Tip>

When calling this pipeline, you can set `view_batch_size` to a value greater than 1.
On some high-performance GPUs, a higher `view_batch_size` can speed up generation at
the cost of increased VRAM usage.

</Tip>
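For instance, a minimal sketch of a batched call (mirroring the `stabilityai/stable-diffusion-2-base` setup from the example above; `view_batch_size=8` is an arbitrary illustrative value):

```python
import torch
from diffusers import DDIMScheduler, StableDiffusionPanoramaPipeline

model_ckpt = "stabilityai/stable-diffusion-2-base"
scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
pipe = StableDiffusionPanoramaPipeline.from_pretrained(
    model_ckpt, scheduler=scheduler, torch_dtype=torch.float16
).to("cuda")

prompt = "a photo of the dolomites"
# Denoise 8 views per UNet forward pass instead of one at a time.
image = pipe(prompt, view_batch_size=8).images[0]
image.save("dolomites.png")
```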

## StableDiffusionPanoramaPipeline
[[autodoc]] StableDiffusionPanoramaPipeline
- __call__
@@ -451,10 +451,11 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype

def get_views(self, panorama_height, panorama_width, window_size=64, stride=8):
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
# if panorama's height/width < window_size, num_blocks of height/width should return 1
panorama_height /= 8
panorama_width /= 8
-num_blocks_height = (panorama_height - window_size) // stride + 1
-num_blocks_width = (panorama_width - window_size) // stride + 1
+num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1
+num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_width > window_size else 1
total_num_blocks = int(num_blocks_height * num_blocks_width)
views = []
for i in range(total_num_blocks):
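As a sanity check on the windowing math (a standalone sketch, not part of the diff): with the defaults `window_size=64` and `stride=8`, a 512×2048 output maps to a 64×256 latent grid and yields 1 × 25 = 25 views, while a 512×512 output fits in a single window:

```python
# Standalone re-implementation of the block count above, for illustration only.
def count_views(height=512, width=2048, window_size=64, stride=8):
    h, w = height // 8, width // 8  # the VAE downsamples by 8x, so work in latent space
    num_blocks_height = (h - window_size) // stride + 1 if h > window_size else 1
    num_blocks_width = (w - window_size) // stride + 1 if w > window_size else 1
    return num_blocks_height * num_blocks_width

print(count_views())          # 25 views for the default 512x2048 panorama
print(count_views(512, 512))  # 1 view: the latent fits in a single 64x64 window
```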
@@ -474,6 +475,7 @@ def __call__(
width: Optional[int] = 2048,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
view_batch_size: int = 1,
Review comment (Member): Let's also add an entry for this arg in the docstrings?

negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
@@ -508,6 +510,9 @@ def __call__(
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
view_batch_size (`int`, *optional*, defaults to 1):
The batch size to denoise splited views. For some GPUs with high performance, higher view batch size
Review comment (Member), suggested change:
-The batch size to denoise splited views. For some GPUs with high performance, higher view batch size
+The batch size to denoise split views. For some GPUs with high performance, higher view batch size

can speedup the generation and increase the VRAM usage.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -609,8 +614,11 @@ def __call__(
)

# 6. Define panorama grid and initialize views for synthesis.
# prepare batch grid
views = self.get_views(height, width)
-views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views)
+views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
+views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views_batch)

count = torch.zeros_like(latents)
value = torch.zeros_like(latents)
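The batching itself is plain list slicing; a minimal illustration with hypothetical numbers (25 views, `view_batch_size=4`):

```python
views = list(range(25))  # stand-ins for the (h_start, h_end, w_start, w_end) tuples
view_batch_size = 4
views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
print([len(b) for b in views_batch])  # [4, 4, 4, 4, 4, 4, 1] -- the last batch may be smaller
```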

@@ -631,42 +639,55 @@
# denoised (latent) crops are then averaged to produce the final latent
# for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the
# MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113
-for j, (h_start, h_end, w_start, w_end) in enumerate(views):
+# Batch views denoise
+for j, batch_view in enumerate(views_batch):
+vb_size = len(batch_view)
# get the latents corresponding to the current view coordinates
-latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
+latents_for_view = torch.cat(
+[latents[:, :, h_start:h_end, w_start:w_end] for h_start, h_end, w_start, w_end in batch_view]
+)

# rematch block's scheduler status
self.scheduler.__dict__.update(views_scheduler_status[j])

# expand the latents if we are doing classifier free guidance
latent_model_input = (
-torch.cat([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view
+latents_for_view.repeat_interleave(2, dim=0)
+if do_classifier_free_guidance
+else latents_for_view
)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# repeat prompt_embeds for batch
prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)

# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
-encoder_hidden_states=prompt_embeds,
+encoder_hidden_states=prompt_embeds_input,
cross_attention_kwargs=cross_attention_kwargs,
).sample

# perform guidance
if do_classifier_free_guidance:
-noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
-latents_view_denoised = self.scheduler.step(
+latents_denoised_batch = self.scheduler.step(
noise_pred, t, latents_for_view, **extra_step_kwargs
).prev_sample

# save views scheduler status after sample
views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__)

-value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
-count[:, :, h_start:h_end, w_start:w_end] += 1
+# extract value from batch
+for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
+latents_denoised_batch.chunk(vb_size), batch_view
+):
+value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
+count[:, :, h_start:h_end, w_start:w_end] += 1

# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
latents = torch.where(count > 0, value / count, value)
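One subtlety worth spelling out (an illustrative sketch, not part of the diff): `repeat_interleave(2, dim=0)` orders the batch as `[v0, v0, v1, v1, ...]`, matching the `[uncond, text, uncond, text, ...]` layout of `prompt_embeds_input`, so the guidance pairs are recovered with the strided slices `noise_pred[::2]` and `noise_pred[1::2]` rather than `chunk(2)`:

```python
import torch

# Two dummy "views" so the ordering is visible.
views = torch.tensor([[10.0], [20.0]])
doubled = views.repeat_interleave(2, dim=0)  # rows: v0, v0, v1, v1
print(doubled.flatten().tolist())            # [10.0, 10.0, 20.0, 20.0]

uncond, text = doubled[::2], doubled[1::2]   # first/second element of each pair
print(uncond.flatten().tolist())             # [10.0, 20.0]
print(text.flatten().tolist())               # [10.0, 20.0]
```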
@@ -131,7 +131,7 @@ def test_inference_batch_consistent(self):

# override to speed the overall test timing up.
def test_inference_batch_single_identical(self):
-super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=3e-3)
+super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=3.25e-3)

def test_stable_diffusion_panorama_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -152,6 +152,24 @@ def test_stable_diffusion_panorama_negative_prompt(self):

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

def test_stable_diffusion_panorama_views_batch(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionPanoramaPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(device)
output = sd_pipe(**inputs, view_batch_size=2)
image = output.images
image_slice = image[0, -3:, -3:, -1]

assert image.shape == (1, 64, 64, 3)

expected_slice = np.array([0.6187, 0.5375, 0.4915, 0.4136, 0.4114, 0.4563, 0.5128, 0.4976, 0.4757])

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

def test_stable_diffusion_panorama_euler(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()