Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 42 additions & 53 deletions examples/community/pipeline_stable_diffusion_pag.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Implementation of StableDiffusionPAGPipeline
# Implementation of StableDiffusionPipeline with PAG
# https://ku-cvlab.github.io/Perturbed-Attention-Guidance

import inspect
from typing import Any, Callable, Dict, List, Optional, Union
Expand Down Expand Up @@ -134,8 +135,8 @@ def __call__(

value = attn.to_v(hidden_states_ptb)

hidden_states_ptb = torch.zeros(value.shape).to(value.get_device())
# hidden_states_ptb = value
# hidden_states_ptb = torch.zeros(value.shape).to(value.get_device())
hidden_states_ptb = value

hidden_states_ptb = hidden_states_ptb.to(query.dtype)

Expand Down Expand Up @@ -1045,7 +1046,7 @@ def pag_scale(self):
return self._pag_scale

@property
def do_adversarial_guidance(self):
def do_perturbed_attention_guidance(self):
return self._pag_scale > 0

@property
Expand All @@ -1056,14 +1057,6 @@ def pag_adaptive_scaling(self):
def do_pag_adaptive_scaling(self):
return self._pag_adaptive_scaling > 0

@property
def pag_drop_rate(self):
return self._pag_drop_rate

@property
def pag_applied_layers(self):
return self._pag_applied_layers

@property
def pag_applied_layers_index(self):
return self._pag_applied_layers_index
Expand All @@ -1080,8 +1073,6 @@ def __call__(
guidance_scale: float = 7.5,
pag_scale: float = 0.0,
pag_adaptive_scaling: float = 0.0,
pag_drop_rate: float = 0.5,
pag_applied_layers: List[str] = ["down"], # ['down', 'mid', 'up']
pag_applied_layers_index: List[str] = ["d4"], # ['d4', 'd5', 'm0']
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
Expand Down Expand Up @@ -1221,8 +1212,6 @@ def __call__(

self._pag_scale = pag_scale
self._pag_adaptive_scaling = pag_adaptive_scaling
self._pag_drop_rate = pag_drop_rate
self._pag_applied_layers = pag_applied_layers
self._pag_applied_layers_index = pag_applied_layers_index

# 2. Define call parameters
Expand Down Expand Up @@ -1257,13 +1246,13 @@ def __call__(
# to avoid doing two forward passes

# cfg
if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# pag
elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
prompt_embeds = torch.cat([prompt_embeds, prompt_embeds])
# both
elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds])

if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
Expand Down Expand Up @@ -1306,7 +1295,7 @@ def __call__(
).to(device=device, dtype=latents.dtype)

# 7. Denoising loop
if self.do_adversarial_guidance:
if self.do_perturbed_attention_guidance:
down_layers = []
mid_layers = []
up_layers = []
Expand All @@ -1322,6 +1311,29 @@ def __call__(
else:
raise ValueError(f"Invalid layer type: {layer_type}")

# change attention layer in UNet if use PAG
if self.do_perturbed_attention_guidance:
if self.do_classifier_free_guidance:
replace_processor = PAGCFGIdentitySelfAttnProcessor()
else:
replace_processor = PAGIdentitySelfAttnProcessor()

drop_layers = self.pag_applied_layers_index
for drop_layer in drop_layers:
try:
if drop_layer[0] == "d":
down_layers[int(drop_layer[1])].processor = replace_processor
elif drop_layer[0] == "m":
mid_layers[int(drop_layer[1])].processor = replace_processor
elif drop_layer[0] == "u":
up_layers[int(drop_layer[1])].processor = replace_processor
else:
raise ValueError(f"Invalid layer type: {drop_layer[0]}")
except IndexError:
raise ValueError(
f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
)

num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
Expand All @@ -1330,41 +1342,18 @@ def __call__(
continue

# cfg
if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
latent_model_input = torch.cat([latents] * 2)
# pag
elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
latent_model_input = torch.cat([latents] * 2)
# both
elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
latent_model_input = torch.cat([latents] * 3)
# no
else:
latent_model_input = latents

# change attention layer in UNet if use PAG
if self.do_adversarial_guidance:
if self.do_classifier_free_guidance:
replace_processor = PAGCFGIdentitySelfAttnProcessor()
else:
replace_processor = PAGIdentitySelfAttnProcessor()

drop_layers = self.pag_applied_layers_index
for drop_layer in drop_layers:
try:
if drop_layer[0] == "d":
down_layers[int(drop_layer[1])].processor = replace_processor
elif drop_layer[0] == "m":
mid_layers[int(drop_layer[1])].processor = replace_processor
elif drop_layer[0] == "u":
up_layers[int(drop_layer[1])].processor = replace_processor
else:
raise ValueError(f"Invalid layer type: {drop_layer[0]}")
except IndexError:
raise ValueError(
f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
)

latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
Expand All @@ -1381,14 +1370,14 @@ def __call__(
# perform guidance

# cfg
if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

delta = noise_pred_text - noise_pred_uncond
noise_pred = noise_pred_uncond + self.guidance_scale * delta

# pag
elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
noise_pred_original, noise_pred_perturb = noise_pred.chunk(2)

signal_scale = self.pag_scale
Expand All @@ -1400,7 +1389,7 @@ def __call__(
noise_pred = noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb)

# both
elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3)

signal_scale = self.pag_scale
Expand Down Expand Up @@ -1458,11 +1447,8 @@ def __call__(
# Offload all models
self.maybe_free_model_hooks()

if not return_dict:
return (image, has_nsfw_concept)

# change attention layer in UNet if use PAG
if self.do_adversarial_guidance:
if self.do_perturbed_attention_guidance:
drop_layers = self.pag_applied_layers_index
for drop_layer in drop_layers:
try:
Expand All @@ -1479,4 +1465,7 @@ def __call__(
f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
)

if not return_dict:
return (image, has_nsfw_concept)

return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)