diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 7db7bfeda600..018b57dd07eb 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -928,9 +928,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F hidden_size = self.config.block_out_channels[block_id] if cross_attention_dim is None or "motion_modules" in name: - attn_processor_class = ( - AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor - ) + attn_processor_class = self.attn_processors[name].__class__ attn_procs[name] = attn_processor_class() else: diff --git a/src/diffusers/pipelines/pag_utils.py b/src/diffusers/pipelines/pag_utils.py index bb1b1083235f..c9da87523359 100644 --- a/src/diffusers/pipelines/pag_utils.py +++ b/src/diffusers/pipelines/pag_utils.py @@ -45,7 +45,7 @@ def enable_pag( self._pag_applied_layers = pag_applied_layers self._pag_applied_layers_index = pag_applied_layers_index self._pag_cfg = pag_cfg - + self._is_pag_enabled = True self._set_pag_attn_processor() def _get_self_attn_layers(self): @@ -180,6 +180,7 @@ def disable_pag(self): self._pag_applied_layers = None self._pag_applied_layers_index = None self._pag_cfg = None + self._is_pag_enabled = False @property def pag_adaptive_scaling(self): @@ -191,4 +192,4 @@ def do_pag_adaptive_scaling(self): @property def do_perturbed_attention_guidance(self): - return hasattr(self, "_pag_scale") and self._pag_scale is not None and self._pag_scale > 0 + return self._is_pag_enabled diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index b28ebc334e44..aebf3f0d941e 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -536,7 +536,7 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( - self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance, do_perturbed_attention_guidance ): if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): @@ -560,6 +560,10 @@ def prepare_ip_adapter_image_embeds( [single_negative_image_embeds] * num_images_per_prompt, dim=0 ) + if do_perturbed_attention_guidance: + single_image_embeds = torch.cat([single_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device) + if do_classifier_free_guidance: single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = single_image_embeds.to(device) @@ -577,11 +581,16 @@ def prepare_ip_adapter_image_embeds( single_negative_image_embeds = single_negative_image_embeds.repeat( num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + if do_perturbed_attention_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds, single_image_embeds], dim=0) + else: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: single_image_embeds = single_image_embeds.repeat( num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) ) + if do_perturbed_attention_guidance: + single_image_embeds = torch.cat([single_image_embeds, single_image_embeds], dim=0) image_embeds.append(single_image_embeds) return image_embeds @@ -1170,6 +1179,7 @@ def __call__( device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, + self.do_perturbed_attention_guidance, ) # 8. Denoising loop @@ -1205,7 +1215,7 @@ def __call__( if self.interrupt: continue - # expand the latents if we are doing classifier free guidance + # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)