-
Notifications
You must be signed in to change notification settings - Fork 6.1k
Support views batch for panorama #3632
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
52712e7
support views batch for panorama
Isotr0py 30fadbf
add entry for the new argument
Isotr0py 6c1faab
format entry for the new argument
Isotr0py 2155029
add view_batch_size test
Isotr0py 66bf08e
fix batch test and a boundary condition
Isotr0py 82932a6
add more docstrings
Isotr0py 305ad46
fix a typos
Isotr0py b9f7691
fix typos
Isotr0py a36aeaa
add: entry to the doc about view_batch_size.
sayakpaul f493331
Revert "add: entry to the doc about view_batch_size."
sayakpaul dcd5329
Merge branch 'main' into panorama_batch
sayakpaul 49f7371
add a tip on .
sayakpaul File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -451,10 +451,11 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype | |||||
|
||||||
def get_views(self, panorama_height, panorama_width, window_size=64, stride=8): | ||||||
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113) | ||||||
# if panorama's height/width < window_size, num_blocks of height/width should return 1 | ||||||
panorama_height /= 8 | ||||||
panorama_width /= 8 | ||||||
num_blocks_height = (panorama_height - window_size) // stride + 1 | ||||||
num_blocks_width = (panorama_width - window_size) // stride + 1 | ||||||
num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1 | ||||||
num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_height > window_size else 1 | ||||||
total_num_blocks = int(num_blocks_height * num_blocks_width) | ||||||
views = [] | ||||||
for i in range(total_num_blocks): | ||||||
|
@@ -474,6 +475,7 @@ def __call__( | |||||
width: Optional[int] = 2048, | ||||||
num_inference_steps: int = 50, | ||||||
guidance_scale: float = 7.5, | ||||||
view_batch_size: int = 1, | ||||||
negative_prompt: Optional[Union[str, List[str]]] = None, | ||||||
num_images_per_prompt: Optional[int] = 1, | ||||||
eta: float = 0.0, | ||||||
|
@@ -508,6 +510,9 @@ def __call__( | |||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > | ||||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, | ||||||
usually at the expense of lower image quality. | ||||||
view_batch_size (`int`, *optional*, defaults to 1): | ||||||
The batch size to denoise splited views. For some GPUs with high performance, higher view batch size | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
can speedup the generation and increase the VRAM usage. | ||||||
negative_prompt (`str` or `List[str]`, *optional*): | ||||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass | ||||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is | ||||||
|
@@ -609,8 +614,11 @@ def __call__( | |||||
) | ||||||
|
||||||
# 6. Define panorama grid and initialize views for synthesis. | ||||||
# prepare batch grid | ||||||
views = self.get_views(height, width) | ||||||
views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views) | ||||||
views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)] | ||||||
views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views_batch) | ||||||
|
||||||
count = torch.zeros_like(latents) | ||||||
value = torch.zeros_like(latents) | ||||||
|
||||||
|
@@ -631,42 +639,55 @@ def __call__( | |||||
# denoised (latent) crops are then averaged to produce the final latent | ||||||
# for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the | ||||||
# MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113 | ||||||
for j, (h_start, h_end, w_start, w_end) in enumerate(views): | ||||||
# Batch views denoise | ||||||
for j, batch_view in enumerate(views_batch): | ||||||
vb_size = len(batch_view) | ||||||
# get the latents corresponding to the current view coordinates | ||||||
latents_for_view = latents[:, :, h_start:h_end, w_start:w_end] | ||||||
latents_for_view = torch.cat( | ||||||
[latents[:, :, h_start:h_end, w_start:w_end] for h_start, h_end, w_start, w_end in batch_view] | ||||||
) | ||||||
|
||||||
# rematch block's scheduler status | ||||||
self.scheduler.__dict__.update(views_scheduler_status[j]) | ||||||
|
||||||
# expand the latents if we are doing classifier free guidance | ||||||
latent_model_input = ( | ||||||
torch.cat([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view | ||||||
latents_for_view.repeat_interleave(2, dim=0) | ||||||
if do_classifier_free_guidance | ||||||
else latents_for_view | ||||||
) | ||||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | ||||||
|
||||||
# repeat prompt_embeds for batch | ||||||
prompt_embeds_input = torch.cat([prompt_embeds] * vb_size) | ||||||
|
||||||
# predict the noise residual | ||||||
noise_pred = self.unet( | ||||||
latent_model_input, | ||||||
t, | ||||||
encoder_hidden_states=prompt_embeds, | ||||||
encoder_hidden_states=prompt_embeds_input, | ||||||
cross_attention_kwargs=cross_attention_kwargs, | ||||||
).sample | ||||||
|
||||||
# perform guidance | ||||||
if do_classifier_free_guidance: | ||||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | ||||||
noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2] | ||||||
Isotr0py marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | ||||||
|
||||||
# compute the previous noisy sample x_t -> x_t-1 | ||||||
latents_view_denoised = self.scheduler.step( | ||||||
latents_denoised_batch = self.scheduler.step( | ||||||
noise_pred, t, latents_for_view, **extra_step_kwargs | ||||||
).prev_sample | ||||||
|
||||||
# save views scheduler status after sample | ||||||
views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__) | ||||||
|
||||||
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised | ||||||
count[:, :, h_start:h_end, w_start:w_end] += 1 | ||||||
# extract value from batch | ||||||
for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip( | ||||||
latents_denoised_batch.chunk(vb_size), batch_view | ||||||
): | ||||||
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised | ||||||
count[:, :, h_start:h_end, w_start:w_end] += 1 | ||||||
|
||||||
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113 | ||||||
latents = torch.where(count > 0, value / count, value) | ||||||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's also add an entry for this
arg
in the docstrings?