Commit f258237

Committed by sunovivid, HyoungwonCho, sayakpaul, crepejung00, and yiyixuxu
add PAG support for Stable Diffusion 3 (#8861)
add pag sd3

---------

Co-authored-by: HyoungwonCho <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: crepejung00 <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Aryan <[email protected]>
1 parent 4a91ee8 commit f258237

File tree

9 files changed, +1629 -0 lines changed


docs/source/en/api/pipelines/pag.md

Lines changed: 6 additions & 0 deletions
@@ -74,6 +74,12 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
   - __call__
 
 
+## StableDiffusion3PAGPipeline
+[[autodoc]] StableDiffusion3PAGPipeline
+  - all
+  - __call__
+
+
 ## PixArtSigmaPAGPipeline
 [[autodoc]] PixArtSigmaPAGPipeline
   - all
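
For readers of the new docs entry, here is a minimal sketch of calling the pipeline directly. It assumes `from_pretrained` forwards `pag_applied_layers` to the pipeline's `__init__`; the checkpoint id and layer name are illustrative, not taken from this commit.

```python
import torch
from diffusers import StableDiffusion3PAGPipeline

pipe = StableDiffusion3PAGPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # illustrative checkpoint
    pag_applied_layers=["blocks.13"],  # hypothetical choice of perturbed blocks
    torch_dtype=torch.float16,
).to("cuda")

# pag_scale controls how strongly the perturbed prediction steers sampling.
image = pipe("a photo of an astronaut riding a horse", pag_scale=3.0).images[0]
```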

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -308,6 +308,7 @@
             "StableDiffusion3ControlNetPipeline",
             "StableDiffusion3Img2ImgPipeline",
             "StableDiffusion3InpaintPipeline",
+            "StableDiffusion3PAGPipeline",
             "StableDiffusion3Pipeline",
             "StableDiffusionAdapterPipeline",
             "StableDiffusionAttendAndExcitePipeline",
@@ -741,6 +742,7 @@
             StableDiffusion3ControlNetPipeline,
             StableDiffusion3Img2ImgPipeline,
             StableDiffusion3InpaintPipeline,
+            StableDiffusion3PAGPipeline,
             StableDiffusion3Pipeline,
             StableDiffusionAdapterPipeline,
             StableDiffusionAttendAndExcitePipeline,

src/diffusers/models/attention_processor.py

Lines changed: 320 additions & 0 deletions
@@ -1106,6 +1106,326 @@ def __call__(
         return hidden_states, encoder_hidden_states
 
 
+class PAGJointAttnProcessor2_0:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        context_input_ndim = encoder_hidden_states.ndim
+        if context_input_ndim == 4:
+            batch_size, channel, height, width = encoder_hidden_states.shape
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        # store the length of image patch sequences to create a mask that prevents interaction between patches
+        # similar to making the self-attention map an identity matrix
+        identity_block_size = hidden_states.shape[1]
+
+        # chunk
+        hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
+        encoder_hidden_states_org, encoder_hidden_states_ptb = encoder_hidden_states.chunk(2)
+
+        ################## original path ##################
+        batch_size = encoder_hidden_states_org.shape[0]
+
+        # `sample` projections.
+        query_org = attn.to_q(hidden_states_org)
+        key_org = attn.to_k(hidden_states_org)
+        value_org = attn.to_v(hidden_states_org)
+
+        # `context` projections.
+        encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
+        encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
+        encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
+
+        # attention
+        query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
+        key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
+        value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
+
+        inner_dim = key_org.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        hidden_states_org = F.scaled_dot_product_attention(
+            query_org, key_org, value_org, dropout_p=0.0, is_causal=False
+        )
+        hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states_org = hidden_states_org.to(query_org.dtype)
+
+        # Split the attention outputs.
+        hidden_states_org, encoder_hidden_states_org = (
+            hidden_states_org[:, : residual.shape[1]],
+            hidden_states_org[:, residual.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states_org = attn.to_out[0](hidden_states_org)
+        # dropout
+        hidden_states_org = attn.to_out[1](hidden_states_org)
+        if not attn.context_pre_only:
+            encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
+
+        if input_ndim == 4:
+            hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
+                batch_size, channel, height, width
+            )
+
+        ################## perturbed path ##################
+
+        batch_size = encoder_hidden_states_ptb.shape[0]
+
+        # `sample` projections.
+        query_ptb = attn.to_q(hidden_states_ptb)
+        key_ptb = attn.to_k(hidden_states_ptb)
+        value_ptb = attn.to_v(hidden_states_ptb)
+
+        # `context` projections.
+        encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
+        encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
+        encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
+
+        # attention
+        query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
+        key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
+        value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
+
+        inner_dim = key_ptb.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # create a full mask with all entries set to 0
+        seq_len = query_ptb.size(2)
+        full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
+
+        # set the attention value between image patches to -inf
+        full_mask[:identity_block_size, :identity_block_size] = float("-inf")
+
+        # set the diagonal of the attention value between image patches to 0
+        full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
+
+        # expand the mask to match the attention weights shape
+        full_mask = full_mask.unsqueeze(0).unsqueeze(0)  # Add batch and num_heads dimensions
+
+        hidden_states_ptb = F.scaled_dot_product_attention(
+            query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
+
+        # split the attention outputs.
+        hidden_states_ptb, encoder_hidden_states_ptb = (
+            hidden_states_ptb[:, : residual.shape[1]],
+            hidden_states_ptb[:, residual.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+        # dropout
+        hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+        if not attn.context_pre_only:
+            encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
+
+        if input_ndim == 4:
+            hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
+                batch_size, channel, height, width
+            )
+
+        ################ concat ###############
+        hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+        encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
+
+        return hidden_states, encoder_hidden_states
+
+
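The perturbed path above replaces image-to-image self-attention with an identity map while leaving text tokens fully attended. A standalone sketch of the mask construction, with toy sizes in plain PyTorch (not part of the diff):

```python
import torch

# Toy sizes: 4 image patch tokens followed by 2 text tokens.
identity_block_size, seq_len = 4, 6

full_mask = torch.zeros((seq_len, seq_len))
# Block all image->image attention ...
full_mask[:identity_block_size, :identity_block_size] = float("-inf")
# ... except each patch attending to itself.
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)

logits = torch.randn(seq_len, seq_len)
weights = torch.softmax(logits + full_mask, dim=-1)

# Off-diagonal image->image weights are exactly zero: each image token now
# attends only to itself and to the text tokens; text rows are untouched.
off_diag = weights[:identity_block_size, :identity_block_size] * (1 - torch.eye(identity_block_size))
assert torch.count_nonzero(off_diag) == 0
```

The companion CFG-aware processor follows in the same hunk: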
+class PAGCFGJointAttnProcessor2_0:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        context_input_ndim = encoder_hidden_states.ndim
+        if context_input_ndim == 4:
+            batch_size, channel, height, width = encoder_hidden_states.shape
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        identity_block_size = hidden_states.shape[
+            1
+        ]  # patch embeddings width * height (correspond to self-attention map width or height)
+
+        # chunk
+        hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
+        hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
+
+        (
+            encoder_hidden_states_uncond,
+            encoder_hidden_states_org,
+            encoder_hidden_states_ptb,
+        ) = encoder_hidden_states.chunk(3)
+        encoder_hidden_states_org = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_org])
+
+        ################## original path ##################
+        batch_size = encoder_hidden_states_org.shape[0]
+
+        # `sample` projections.
+        query_org = attn.to_q(hidden_states_org)
+        key_org = attn.to_k(hidden_states_org)
+        value_org = attn.to_v(hidden_states_org)
+
+        # `context` projections.
+        encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
+        encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
+        encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
+
+        # attention
+        query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
+        key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
+        value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
+
+        inner_dim = key_org.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        hidden_states_org = F.scaled_dot_product_attention(
+            query_org, key_org, value_org, dropout_p=0.0, is_causal=False
+        )
+        hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states_org = hidden_states_org.to(query_org.dtype)
+
+        # Split the attention outputs.
+        hidden_states_org, encoder_hidden_states_org = (
+            hidden_states_org[:, : residual.shape[1]],
+            hidden_states_org[:, residual.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states_org = attn.to_out[0](hidden_states_org)
+        # dropout
+        hidden_states_org = attn.to_out[1](hidden_states_org)
+        if not attn.context_pre_only:
+            encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
+
+        if input_ndim == 4:
+            hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
+                batch_size, channel, height, width
+            )
+
+        ################## perturbed path ##################
+
+        batch_size = encoder_hidden_states_ptb.shape[0]
+
+        # `sample` projections.
+        query_ptb = attn.to_q(hidden_states_ptb)
+        key_ptb = attn.to_k(hidden_states_ptb)
+        value_ptb = attn.to_v(hidden_states_ptb)
+
+        # `context` projections.
+        encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
+        encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
+        encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
+
+        # attention
+        query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
+        key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
+        value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
+
+        inner_dim = key_ptb.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # create a full mask with all entries set to 0
+        seq_len = query_ptb.size(2)
+        full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
+
+        # set the attention value between image patches to -inf
+        full_mask[:identity_block_size, :identity_block_size] = float("-inf")
+
+        # set the diagonal of the attention value between image patches to 0
+        full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
+
+        # expand the mask to match the attention weights shape
+        full_mask = full_mask.unsqueeze(0).unsqueeze(0)  # Add batch and num_heads dimensions
+
+        hidden_states_ptb = F.scaled_dot_product_attention(
+            query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
+
+        # split the attention outputs.
+        hidden_states_ptb, encoder_hidden_states_ptb = (
+            hidden_states_ptb[:, : residual.shape[1]],
+            hidden_states_ptb[:, residual.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+        # dropout
+        hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+        if not attn.context_pre_only:
+            encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
+
+        if input_ndim == 4:
+            hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
+                batch_size, channel, height, width
+            )
+
+        ################ concat ###############
+        hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+        encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
+
+        return hidden_states, encoder_hidden_states
+
+
 class FusedJointAttnProcessor2_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
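
`PAGCFGJointAttnProcessor2_0` differs from the non-CFG variant only in its batch bookkeeping: the pipeline is expected to stack unconditional, conditional, and perturbed-conditional copies of each sample, which the processor splits apart with `chunk(3)`. A small sketch of that layout (the stacking order is inferred from the processor; names and shapes are illustrative):

```python
import torch

# Illustrative shapes: 2 prompts, 8 latent tokens, hidden size 16.
uncond = torch.randn(2, 8, 16)
cond = torch.randn(2, 8, 16)

# Assumed pipeline-side stacking, implied by the processor's chunk(3):
# [unconditional, conditional, conditional-to-be-perturbed].
hidden_states = torch.cat([uncond, cond, cond])  # (6, 8, 16)

# Inside PAGCFGJointAttnProcessor2_0 the stack is split back apart,
hs_uncond, hs_org, hs_ptb = hidden_states.chunk(3)
# the uncond/cond thirds are re-joined for ordinary joint attention,
hs_org = torch.cat([hs_uncond, hs_org])  # (4, 8, 16)
# and hs_ptb (2, 8, 16) goes through the masked, perturbed path before the
# outputs are concatenated again for downstream CFG + PAG recombination.
```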

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -147,6 +147,7 @@
     [
         "AnimateDiffPAGPipeline",
         "HunyuanDiTPAGPipeline",
+        "StableDiffusion3PAGPipeline",
         "StableDiffusionPAGPipeline",
         "StableDiffusionControlNetPAGPipeline",
         "StableDiffusionXLPAGPipeline",
@@ -540,6 +541,7 @@
         AnimateDiffPAGPipeline,
         HunyuanDiTPAGPipeline,
         PixArtSigmaPAGPipeline,
+        StableDiffusion3PAGPipeline,
         StableDiffusionControlNetPAGPipeline,
         StableDiffusionPAGPipeline,
         StableDiffusionXLControlNetPAGPipeline,

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 2 additions & 0 deletions
@@ -52,6 +52,7 @@
 from .pag import (
     HunyuanDiTPAGPipeline,
     PixArtSigmaPAGPipeline,
+    StableDiffusion3PAGPipeline,
     StableDiffusionControlNetPAGPipeline,
     StableDiffusionPAGPipeline,
     StableDiffusionXLControlNetPAGPipeline,
@@ -84,6 +85,7 @@
         ("stable-diffusion", StableDiffusionPipeline),
         ("stable-diffusion-xl", StableDiffusionXLPipeline),
         ("stable-diffusion-3", StableDiffusion3Pipeline),
+        ("stable-diffusion-3-pag", StableDiffusion3PAGPipeline),
         ("if", IFPipeline),
         ("hunyuan", HunyuanDiTPipeline),
         ("hunyuan-pag", HunyuanDiTPAGPipeline),
