add PAG support for Stable Diffusion 3 #8861

Merged: 39 commits, Aug 6, 2024
Commits (39)
43e6c20
original sd3 pipeline
HyoungwonCho Jul 8, 2024
d89945b
Merge branch 'main' of https://github.com/huggingface/diffusers into …
HyoungwonCho Jul 8, 2024
24394c5
Add PAG support sd3
sunovivid Jul 13, 2024
f09f09c
Merge commit 'bbd2f9d4e9ae70b04fedf65903fd1fb035437db4' into pag-sd3
sunovivid Jul 13, 2024
34a7814
Update src/diffusers/models/attention_processor.py
sunovivid Jul 17, 2024
ac1ff12
Update src/diffusers/models/attention_processor.py
sunovivid Jul 17, 2024
97b6697
Fix org, ptb; separate SD3PAGMixin from PAGMixin
sunovivid Jul 21, 2024
12a2e28
Merge branch 'pag-sd3' of https://github.com/sunovivid/diffusers into…
sunovivid Jul 21, 2024
5001510
Merge branch 'main' into pag-sd3
sunovivid Jul 21, 2024
efb3ddf
Update EXAMPLE_DOC_STRING
sunovivid Jul 21, 2024
6c5df1d
Merge branch 'pag-sd3' of https://github.com/sunovivid/diffusers into…
sunovivid Jul 21, 2024
67fd57f
fix missed comma
sunovivid Jul 21, 2024
ab8f642
Remove unused sd3 method in PAGMixin
sunovivid Jul 21, 2024
6e1fd0f
Merge remote-tracking branch 'origin/main' into pag-sd3
crepejung00 Aug 5, 2024
8a8fcb7
Merge branch 'huggingface:main' into pag-sd3
sunovivid Aug 5, 2024
795ddaf
remove SD3PAGMixin and adopt refactored PAGMixin
sunovivid Aug 5, 2024
73df50f
add StableDiffusion3PAGPipeline doc
sunovivid Aug 5, 2024
e622328
add parameter description
sunovivid Aug 5, 2024
35349d9
remove *args, **kwargs
sunovivid Aug 5, 2024
27f1100
add import pag pipeline
sunovivid Aug 5, 2024
083360c
Merge branch 'pag-sd3' of https://github.com/sunovivid/diffusers into…
sunovivid Aug 5, 2024
512026f
make style
sunovivid Aug 5, 2024
29d0c7b
Merge branch 'main' into pag-sd3
sayakpaul Aug 6, 2024
c0bcc82
style
sayakpaul Aug 6, 2024
12c49db
apply comments (copied from)
sunovivid Aug 6, 2024
1c32060
apply comments (fix the comment location)
sunovivid Aug 6, 2024
fa90a92
make fix-copies, sync SD3 pipeline update
sunovivid Aug 6, 2024
986bcef
add test
sunovivid Aug 6, 2024
ee863fa
fix test following HunyuanDiT test -> HunyuanDiTPAG test
sunovivid Aug 6, 2024
e243d5d
remove slow test
sunovivid Aug 6, 2024
ca33100
make style
sunovivid Aug 6, 2024
12ec17c
Update src/diffusers/pipelines/__init__.py
sunovivid Aug 6, 2024
78867bb
set num_layers=2 for sd3 models in SD3 PAG test
sunovivid Aug 6, 2024
25a8388
Merge branch 'pag-sd3' of https://github.com/sunovivid/diffusers into…
sunovivid Aug 6, 2024
dc25826
Merge branch 'main' into pag-sd3
a-r-r-o-w Aug 6, 2024
c6affa7
fix attn1 to attn
sunovivid Aug 6, 2024
d80885c
Merge branch 'main' into pag-sd3
sayakpaul Aug 6, 2024
93b0349
remove assert case for nan
sunovivid Aug 6, 2024
792398d
style
yiyixuxu Aug 6, 2024
6 changes: 6 additions & 0 deletions docs/source/en/api/pipelines/pag.md
@@ -74,6 +74,12 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
- __call__


## StableDiffusion3PAGPipeline
[[autodoc]] StableDiffusion3PAGPipeline
- all
- __call__


## PixArtSigmaPAGPipeline
[[autodoc]] PixArtSigmaPAGPipeline
- all
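For orientation, the doc entry added above corresponds to the pipeline introduced in this PR. Below is a minimal usage sketch; the checkpoint id and the choice of perturbed layer are illustrative assumptions, not something this diff prescribes.

```python
import torch
from diffusers import AutoPipelineForText2Image

# Assumed checkpoint and PAG layer choice, for illustration only.
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    enable_pag=True,
    pag_applied_layers=["blocks.13"],
    torch_dtype=torch.float16,
).to("cuda")

image = pipe(
    "a photo of an astronaut riding a horse on mars",
    guidance_scale=5.0,  # classifier-free guidance strength
    pag_scale=0.7,       # perturbed-attention guidance strength
).images[0]
```

Setting `pag_scale=0.0` should effectively disable the perturbed branch and reduce the call to plain CFG sampling.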
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -308,6 +308,7 @@
"StableDiffusion3ControlNetPipeline",
"StableDiffusion3Img2ImgPipeline",
"StableDiffusion3InpaintPipeline",
"StableDiffusion3PAGPipeline",
"StableDiffusion3Pipeline",
"StableDiffusionAdapterPipeline",
"StableDiffusionAttendAndExcitePipeline",
@@ -741,6 +742,7 @@
StableDiffusion3ControlNetPipeline,
StableDiffusion3Img2ImgPipeline,
StableDiffusion3InpaintPipeline,
StableDiffusion3PAGPipeline,
StableDiffusion3Pipeline,
StableDiffusionAdapterPipeline,
StableDiffusionAttendAndExcitePipeline,
320 changes: 320 additions & 0 deletions src/diffusers/models/attention_processor.py
@@ -1106,6 +1106,326 @@ def __call__(
return hidden_states, encoder_hidden_states


class PAGJointAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)

def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
residual = hidden_states

input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

# store the length of image patch sequences to create a mask that prevents interaction between patches
# similar to making the self-attention map an identity matrix
identity_block_size = hidden_states.shape[1]

# chunk
hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
encoder_hidden_states_org, encoder_hidden_states_ptb = encoder_hidden_states.chunk(2)

################## original path ##################
batch_size = encoder_hidden_states_org.shape[0]

# `sample` projections.
query_org = attn.to_q(hidden_states_org)
key_org = attn.to_k(hidden_states_org)
value_org = attn.to_v(hidden_states_org)

# `context` projections.
encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)

# attention
query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)

inner_dim = key_org.shape[-1]
head_dim = inner_dim // attn.heads
query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

hidden_states_org = F.scaled_dot_product_attention(
query_org, key_org, value_org, dropout_p=0.0, is_causal=False
)
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_org = hidden_states_org.to(query_org.dtype)

# Split the attention outputs.
hidden_states_org, encoder_hidden_states_org = (
hidden_states_org[:, : residual.shape[1]],
hidden_states_org[:, residual.shape[1] :],
)

# linear proj
hidden_states_org = attn.to_out[0](hidden_states_org)
# dropout
hidden_states_org = attn.to_out[1](hidden_states_org)
if not attn.context_pre_only:
encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)

if input_ndim == 4:
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
batch_size, channel, height, width
)

################## perturbed path ##################

batch_size = encoder_hidden_states_ptb.shape[0]

# `sample` projections.
query_ptb = attn.to_q(hidden_states_ptb)
key_ptb = attn.to_k(hidden_states_ptb)
value_ptb = attn.to_v(hidden_states_ptb)

# `context` projections.
encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)

# attention
query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)

inner_dim = key_ptb.shape[-1]
head_dim = inner_dim // attn.heads
query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# create a full mask with all entries set to 0
seq_len = query_ptb.size(2)
full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)

# set the attention value between image patches to -inf
full_mask[:identity_block_size, :identity_block_size] = float("-inf")

# set the diagonal of the attention value between image patches to 0
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)

# expand the mask to match the attention weights shape
full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions

hidden_states_ptb = F.scaled_dot_product_attention(
query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
)
hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)

# split the attention outputs.
hidden_states_ptb, encoder_hidden_states_ptb = (
hidden_states_ptb[:, : residual.shape[1]],
hidden_states_ptb[:, residual.shape[1] :],
)

# linear proj
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
# dropout
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
if not attn.context_pre_only:
encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)

if input_ndim == 4:
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
batch_size, channel, height, width
)

################ concat ###############
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])

return hidden_states, encoder_hidden_states
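
To see what the perturbation mask built above actually does, here is a self-contained sketch with hypothetical toy sizes (plain PyTorch): inside the image-patch block, every off-diagonal attention weight becomes exactly zero, while image-to-text and text attention are left untouched.

```python
import torch

num_img, num_txt = 4, 2  # hypothetical patch / text sequence lengths
seq_len = num_img + num_txt

# Same construction as in the processor: -inf between distinct image
# patches, 0 on the diagonal and for all text positions.
mask = torch.zeros(seq_len, seq_len)
mask[:num_img, :num_img] = float("-inf")
mask[:num_img, :num_img].fill_diagonal_(0)

q = k = torch.randn(1, 1, seq_len, 8)  # (batch, heads, seq, head_dim)
scores = q @ k.transpose(-1, -2) / 8**0.5 + mask
weights = scores.softmax(dim=-1)

# exp(-inf) == 0, so an image token places zero weight on every *other*
# image token; it can still attend to itself and to the text tokens.
img_block = weights[0, 0, :num_img, :num_img]
assert torch.all(img_block == torch.diag(img_block.diagonal()))
```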


class PAGCFGJointAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)

def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states

input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

# patch embeddings width * height (corresponds to the self-attention map width or height)
identity_block_size = hidden_states.shape[1]

# chunk
hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])

(
encoder_hidden_states_uncond,
encoder_hidden_states_org,
encoder_hidden_states_ptb,
) = encoder_hidden_states.chunk(3)
encoder_hidden_states_org = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_org])

################## original path ##################
batch_size = encoder_hidden_states_org.shape[0]

# `sample` projections.
query_org = attn.to_q(hidden_states_org)
key_org = attn.to_k(hidden_states_org)
value_org = attn.to_v(hidden_states_org)

# `context` projections.
encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)

# attention
query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)

inner_dim = key_org.shape[-1]
head_dim = inner_dim // attn.heads
query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

hidden_states_org = F.scaled_dot_product_attention(
query_org, key_org, value_org, dropout_p=0.0, is_causal=False
)
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_org = hidden_states_org.to(query_org.dtype)

# Split the attention outputs.
hidden_states_org, encoder_hidden_states_org = (
hidden_states_org[:, : residual.shape[1]],
hidden_states_org[:, residual.shape[1] :],
)

# linear proj
hidden_states_org = attn.to_out[0](hidden_states_org)
# dropout
hidden_states_org = attn.to_out[1](hidden_states_org)
if not attn.context_pre_only:
encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)

if input_ndim == 4:
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
batch_size, channel, height, width
)

################## perturbed path ##################

batch_size = encoder_hidden_states_ptb.shape[0]

# `sample` projections.
query_ptb = attn.to_q(hidden_states_ptb)
key_ptb = attn.to_k(hidden_states_ptb)
value_ptb = attn.to_v(hidden_states_ptb)

# `context` projections.
encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)

# attention
query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)

inner_dim = key_ptb.shape[-1]
head_dim = inner_dim // attn.heads
query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# create a full mask with all entries set to 0
seq_len = query_ptb.size(2)
full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)

# set the attention value between image patches to -inf
full_mask[:identity_block_size, :identity_block_size] = float("-inf")

# set the diagonal of the attention value between image patches to 0
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)

# expand the mask to match the attention weights shape
full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions

hidden_states_ptb = F.scaled_dot_product_attention(
query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
)
hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)

# split the attention outputs.
hidden_states_ptb, encoder_hidden_states_ptb = (
hidden_states_ptb[:, : residual.shape[1]],
hidden_states_ptb[:, residual.shape[1] :],
)

# linear proj
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
# dropout
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
if not attn.context_pre_only:
encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)

if input_ndim == 4:
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
batch_size, channel, height, width
)

################ concat ###############
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])

return hidden_states, encoder_hidden_states
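
The `chunk(3)` above assumes a specific batch layout prepared by the pipeline: [unconditional, conditional, conditional-perturbed]. A small sketch of that contract (shapes and names are illustrative assumptions):

```python
import torch

# Hypothetical latent token sequences for a single prompt.
uncond = torch.randn(1, 16, 64)  # negative / empty-prompt branch
cond = torch.randn(1, 16, 64)    # positive-prompt branch

# The pipeline stacks three copies before each transformer forward:
hidden_states = torch.cat([uncond, cond, cond])  # (3, 16, 64)

# Inside PAGCFGJointAttnProcessor2_0:
hs_uncond, hs_org, hs_ptb = hidden_states.chunk(3)
org_batch = torch.cat([hs_uncond, hs_org])  # plain attention, batch of 2
# hs_ptb (batch of 1) goes through attention with the identity-like image
# mask; the outputs are then concatenated back for CFG + PAG weighting.
assert org_batch.shape[0] == 2 and hs_ptb.shape[0] == 1
```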


class FusedJointAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""

2 changes: 2 additions & 0 deletions src/diffusers/pipelines/__init__.py
@@ -147,6 +147,7 @@
[
"AnimateDiffPAGPipeline",
"HunyuanDiTPAGPipeline",
"StableDiffusion3PAGPipeline",
"StableDiffusionPAGPipeline",
"StableDiffusionControlNetPAGPipeline",
"StableDiffusionXLPAGPipeline",
@@ -540,6 +541,7 @@
AnimateDiffPAGPipeline,
HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline,
StableDiffusion3PAGPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGPipeline,
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/auto_pipeline.py
@@ -52,6 +52,7 @@
from .pag import (
HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline,
StableDiffusion3PAGPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGPipeline,
@@ -84,6 +85,7 @@
("stable-diffusion", StableDiffusionPipeline),
("stable-diffusion-xl", StableDiffusionXLPipeline),
("stable-diffusion-3", StableDiffusion3Pipeline),
("stable-diffusion-3-pag", StableDiffusion3PAGPipeline),
("if", IFPipeline),
("hunyuan", HunyuanDiTPipeline),
("hunyuan-pag", HunyuanDiTPAGPipeline),
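The `"stable-diffusion-3-pag"` registry entry above is what lets the auto pipeline resolve the PAG variant. A hedged sketch of that resolution (the checkpoint id is an assumption):

```python
from diffusers import AutoPipelineForText2Image, StableDiffusion3PAGPipeline

# With enable_pag=True the auto pipeline resolves to the
# "stable-diffusion-3-pag" entry registered above.
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # assumed checkpoint
    enable_pag=True,
)
assert isinstance(pipe, StableDiffusion3PAGPipeline)
```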