Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions examples/community/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3743,3 +3743,80 @@ onestep_image = pipe(prompt, num_inference_steps=1).images[0]
# Multistep sampling
multistep_image = pipe(prompt, num_inference_steps=4).images[0]
```

# Perturbed-Attention Guidance

[Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)

This implementation is based on [Diffusers](https://huggingface.co/docs/diffusers/index). StableDiffusionPAGPipeline is a modification of StableDiffusionPipeline to support Perturbed-Attention Guidance (PAG).

## Example Usage

```
import os
import torch

from accelerate.utils import set_seed

from diffusers import StableDiffusionPipeline
from diffusers.utils import load_image, make_image_grid
from diffusers.utils.torch_utils import randn_tensor

pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
custom_pipeline="hyoungwoncho/sd_perturbed_attention_guidance",
torch_dtype=torch.float16
)

device="cuda"
pipe = pipe.to(device)

pag_scale = 5.0
pag_applied_layers_index = ['m0']

batch_size = 4
seed=10

base_dir = "./results/"
grid_dir = base_dir + "/pag" + str(pag_scale) + "/"

if not os.path.exists(grid_dir):
os.makedirs(grid_dir)

set_seed(seed)

latent_input = randn_tensor(shape=(batch_size,4,64,64),generator=None, device=device, dtype=torch.float16)

output_baseline = pipe(
"",
width=512,
height=512,
num_inference_steps=50,
guidance_scale=0.0,
pag_scale=0.0,
pag_applied_layers_index=pag_applied_layers_index,
num_images_per_prompt=batch_size,
latents=latent_input
).images

output_pag = pipe(
"",
width=512,
height=512,
num_inference_steps=50,
guidance_scale=0.0,
pag_scale=5.0,
pag_applied_layers_index=pag_applied_layers_index,
num_images_per_prompt=batch_size,
latents=latent_input
).images

grid_image = make_image_grid(output_baseline + output_pag, rows=2, cols=batch_size)
grid_image.save(grid_dir + "sample.png")
```

## PAG Parameters

pag_scale : gudiance scale of PAG (ex: 5.0)

pag_applied_layers_index : index of the layer to apply perturbation (ex: ['m0'])
Loading