diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 5ea077db04ba..b0cf84986c82 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -953,7 +953,8 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): if "enable_pag" in kwargs: enable_pag = kwargs.pop("enable_pag") if enable_pag: - orig_class_name = config["_class_name"].replace("Pipeline", "PAGPipeline") + to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline" + orig_class_name = config["_class_name"].replace(to_replace, "PAG" + to_replace) inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py index 64aff497a594..09c3a7029ceb 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py @@ -1471,6 +1471,14 @@ def denoising_value_valid(dnv): generator, self.do_classifier_free_guidance, ) + if self.do_perturbed_attention_guidance: + if self.do_classifier_free_guidance: + mask, _ = mask.chunk(2) + masked_image_latents, _ = masked_image_latents.chunk(2) + mask = self._prepare_perturbed_attention_guidance(mask, mask, self.do_classifier_free_guidance) + masked_image_latents = self._prepare_perturbed_attention_guidance( + masked_image_latents, masked_image_latents, self.do_classifier_free_guidance + ) # 8. 
Check that sizes of mask, masked image and latents match if num_channels_unet == 9: @@ -1659,10 +1667,10 @@ def denoising_value_valid(dnv): if num_channels_unet == 4: init_latents_proper = image_latents - if self.do_classifier_free_guidance: - init_mask, _ = mask.chunk(2) + if self.do_perturbed_attention_guidance: + init_mask, *_ = mask.chunk(3) if self.do_classifier_free_guidance else mask.chunk(2) else: - init_mask = mask + init_mask = mask.chunk(2)[0] if self.do_classifier_free_guidance else mask if i < len(timesteps) - 1: noise_timestep = timesteps[i + 1]