-
Notifications
You must be signed in to change notification settings - Fork 6.6k
Fix StableDiffusionXLPAGInpaintPipeline #9128
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
a5b6529
fix: class name in `AutoPipelineForInpainting`
gumgood 12bf313
fix: batch size of `mask` and `masked_image_latents`
gumgood ab714f6
Merge branch 'huggingface:main' into fix-pag-inpaint-pipeline
gumgood e8a437a
fix: handling loading of DiffusionPipeline in from_pretrained
gumgood 958e184
fix: init_mask in 4-channel UNet
gumgood ff5fd9d
style: apply formatter
gumgood 2d0429e
fix: logic to find pag pipeline
gumgood 6846f30
revert: `prepare_mask_latents`
gumgood cbb7070
fix: batch size of `mask` and `masked_image_latents`
gumgood File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1471,6 +1471,14 @@ def denoising_value_valid(dnv): | |
| generator, | ||
| self.do_classifier_free_guidance, | ||
| ) | ||
| if self.do_perturbed_attention_guidance: | ||
| if self.do_classifier_free_guidance: | ||
| mask, _ = mask.chunk(2) | ||
| masked_image_latents, _ = masked_image_latents.chunk(2) | ||
| mask = self._prepare_perturbed_attention_guidance(mask, mask, self.do_classifier_free_guidance) | ||
| masked_image_latents = self._prepare_perturbed_attention_guidance( | ||
| masked_image_latents, masked_image_latents, self.do_classifier_free_guidance | ||
| ) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated to handle batch size without modifying the |
||
|
|
||
| # 8. Check that sizes of mask, masked image and latents match | ||
| if num_channels_unet == 9: | ||
|
|
@@ -1659,10 +1667,10 @@ def denoising_value_valid(dnv): | |
|
|
||
| if num_channels_unet == 4: | ||
| init_latents_proper = image_latents | ||
| if self.do_classifier_free_guidance: | ||
| init_mask, _ = mask.chunk(2) | ||
| if self.do_perturbed_attention_guidance: | ||
| init_mask, *_ = mask.chunk(3) if self.do_classifier_free_guidance else mask.chunk(2) | ||
| else: | ||
| init_mask = mask | ||
| init_mask, *_ = mask.chunk(2) if self.do_classifier_free_guidance else mask | ||
|
|
||
| if i < len(timesteps) - 1: | ||
| noise_timestep = timesteps[i + 1] | ||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice code. I have appied it and verified that it works well