From 76d736dc64bb8aa519bbb9abcacfb8a141e957b5 Mon Sep 17 00:00:00 2001 From: HyoungwonCho Date: Tue, 7 May 2024 16:12:17 +0900 Subject: [PATCH 1/2] edited_pag_implementation --- .../pipeline_stable_diffusion_pag.py | 97 ++++++++----------- 1 file changed, 43 insertions(+), 54 deletions(-) diff --git a/examples/community/pipeline_stable_diffusion_pag.py b/examples/community/pipeline_stable_diffusion_pag.py index 04f38a888460..2d627512d4c4 100644 --- a/examples/community/pipeline_stable_diffusion_pag.py +++ b/examples/community/pipeline_stable_diffusion_pag.py @@ -1,4 +1,5 @@ -# Implementation of StableDiffusionPAGPipeline +# Implementation of StableDiffusionPipeline with PAG +# https://ku-cvlab.github.io/Perturbed-Attention-Guidance import inspect from typing import Any, Callable, Dict, List, Optional, Union @@ -134,8 +135,8 @@ def __call__( value = attn.to_v(hidden_states_ptb) - hidden_states_ptb = torch.zeros(value.shape).to(value.get_device()) - # hidden_states_ptb = value + # hidden_states_ptb = torch.zeros(value.shape).to(value.get_device()) + hidden_states_ptb = value hidden_states_ptb = hidden_states_ptb.to(query.dtype) @@ -1045,7 +1046,7 @@ def pag_scale(self): return self._pag_scale @property - def do_adversarial_guidance(self): + def do_perturbed_attention_guidance(self): return self._pag_scale > 0 @property @@ -1056,14 +1057,6 @@ def pag_adaptive_scaling(self): def do_pag_adaptive_scaling(self): return self._pag_adaptive_scaling > 0 - @property - def pag_drop_rate(self): - return self._pag_drop_rate - - @property - def pag_applied_layers(self): - return self._pag_applied_layers - @property def pag_applied_layers_index(self): return self._pag_applied_layers_index @@ -1080,8 +1073,6 @@ def __call__( guidance_scale: float = 7.5, pag_scale: float = 0.0, pag_adaptive_scaling: float = 0.0, - pag_drop_rate: float = 0.5, - pag_applied_layers: List[str] = ["down"], # ['down', 'mid', 'up'] pag_applied_layers_index: List[str] = ["d4"], # ['d4', 'd5', 'm0'] negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -1221,8 +1212,6 @@ def __call__( self._pag_scale = pag_scale self._pag_adaptive_scaling = pag_adaptive_scaling - self._pag_drop_rate = pag_drop_rate - self._pag_applied_layers = pag_applied_layers self._pag_applied_layers_index = pag_applied_layers_index # 2. Define call parameters @@ -1257,13 +1246,13 @@ def __call__( # to avoid doing two forward passes # cfg - if self.do_classifier_free_guidance and not self.do_adversarial_guidance: + if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # pag - elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: + elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance: prompt_embeds = torch.cat([prompt_embeds, prompt_embeds]) # both - elif self.do_classifier_free_guidance and self.do_adversarial_guidance: + elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds]) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: @@ -1306,7 +1295,7 @@ def __call__( ).to(device=device, dtype=latents.dtype) # 7. Denoising loop - if self.do_adversarial_guidance: + if self.do_perturbed_attention_guidance: down_layers = [] mid_layers = [] up_layers = [] @@ -1322,6 +1311,29 @@ def __call__( else: raise ValueError(f"Invalid layer type: {layer_type}") + # change attention layer in UNet if use PAG + if self.do_perturbed_attention_guidance: + if self.do_classifier_free_guidance: + replace_processor = PAGCFGIdentitySelfAttnProcessor() + else: + replace_processor = PAGIdentitySelfAttnProcessor() + + drop_layers = self.pag_applied_layers_index + for drop_layer in drop_layers: + try: + if drop_layer[0] == "d": + down_layers[int(drop_layer[1])].processor = replace_processor + elif drop_layer[0] == "m": + mid_layers[int(drop_layer[1])].processor = replace_processor + elif drop_layer[0] == "u": + up_layers[int(drop_layer[1])].processor = replace_processor + else: + raise ValueError(f"Invalid layer type: {drop_layer[0]}") + except IndexError: + raise ValueError( + f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers." + ) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -1330,41 +1342,18 @@ def __call__( continue # cfg - if self.do_classifier_free_guidance and not self.do_adversarial_guidance: + if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance: latent_model_input = torch.cat([latents] * 2) # pag - elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: + elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance: latent_model_input = torch.cat([latents] * 2) # both - elif self.do_classifier_free_guidance and self.do_adversarial_guidance: + elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance: latent_model_input = torch.cat([latents] * 3) # no else: latent_model_input = latents - # change attention layer in UNet if use PAG - if self.do_adversarial_guidance: - if self.do_classifier_free_guidance: - replace_processor = PAGCFGIdentitySelfAttnProcessor() - else: - replace_processor = PAGIdentitySelfAttnProcessor() - - drop_layers = self.pag_applied_layers_index - for drop_layer in drop_layers: - try: - if drop_layer[0] == "d": - down_layers[int(drop_layer[1])].processor = replace_processor - elif drop_layer[0] == "m": - mid_layers[int(drop_layer[1])].processor = replace_processor - elif drop_layer[0] == "u": - up_layers[int(drop_layer[1])].processor = replace_processor - else: - raise ValueError(f"Invalid layer type: {drop_layer[0]}") - except IndexError: - raise ValueError( - f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers." - ) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual @@ -1381,14 +1370,14 @@ def __call__( # perform guidance # cfg - if self.do_classifier_free_guidance and not self.do_adversarial_guidance: + if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) delta = noise_pred_text - noise_pred_uncond noise_pred = noise_pred_uncond + self.guidance_scale * delta # pag - elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: + elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance: noise_pred_original, noise_pred_perturb = noise_pred.chunk(2) signal_scale = self.pag_scale @@ -1400,7 +1389,7 @@ def __call__( noise_pred = noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb) # both - elif self.do_classifier_free_guidance and self.do_adversarial_guidance: + elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance: noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3) signal_scale = self.pag_scale @@ -1458,11 +1447,8 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - if not return_dict: - return (image, has_nsfw_concept) - # change attention layer in UNet if use PAG - if self.do_adversarial_guidance: + if self.do_perturbed_attention_guidance: drop_layers = self.pag_applied_layers_index for drop_layer in drop_layers: try: @@ -1479,4 +1465,7 @@ def __call__( f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers." ) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) \ No newline at end of file From 128cb6900f90946a31dce6ebe537f09cbffbf2cd Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 8 May 2024 04:12:43 +0200 Subject: [PATCH 2/2] update --- examples/community/pipeline_stable_diffusion_pag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/community/pipeline_stable_diffusion_pag.py b/examples/community/pipeline_stable_diffusion_pag.py index 2d627512d4c4..cdb7bd99cb29 100644 --- a/examples/community/pipeline_stable_diffusion_pag.py +++ b/examples/community/pipeline_stable_diffusion_pag.py @@ -1468,4 +1468,4 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) \ No newline at end of file + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)