@@ -1106,6 +1106,326 @@ def __call__(
        return hidden_states, encoder_hidden_states


+class PAGJointAttnProcessor2_0:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
1111
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        context_input_ndim = encoder_hidden_states.ndim
+        if context_input_ndim == 4:
+            batch_size, channel, height, width = encoder_hidden_states.shape
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        # store the length of image patch sequences to create a mask that prevents interaction between patches
+        # similar to making the self-attention map an identity matrix
+        identity_block_size = hidden_states.shape[1]
+
+        # chunk
+        hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
+        encoder_hidden_states_org, encoder_hidden_states_ptb = encoder_hidden_states.chunk(2)
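+        # the incoming batch is expected to hold the original and the PAG-perturbed samples
+        # concatenated along dim 0; the chunk(2) above separates the two guidance branches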
+
+        ################## original path ##################
+        batch_size = encoder_hidden_states_org.shape[0]
+
+        # `sample` projections.
+        query_org = attn.to_q(hidden_states_org)
+        key_org = attn.to_k(hidden_states_org)
+        value_org = attn.to_v(hidden_states_org)
+
+        # `context` projections.
+        encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
+        encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
+        encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
+
+        # attention
+        query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
+        key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
+        value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
+
+        inner_dim = key_org.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        hidden_states_org = F.scaled_dot_product_attention(
+            query_org, key_org, value_org, dropout_p=0.0, is_causal=False
+        )
+        hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states_org = hidden_states_org.to(query_org.dtype)
+
+        # Split the attention outputs.
+        hidden_states_org, encoder_hidden_states_org = (
+            hidden_states_org[:, : residual.shape[1]],
+            hidden_states_org[:, residual.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states_org = attn.to_out[0](hidden_states_org)
+        # dropout
+        hidden_states_org = attn.to_out[1](hidden_states_org)
+        if not attn.context_pre_only:
+            encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
+
+        if input_ndim == 4:
+            hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
+                batch_size, channel, height, width
+            )
+
+        ################## perturbed path ##################
+
+        batch_size = encoder_hidden_states_ptb.shape[0]
+
+        # `sample` projections.
+        query_ptb = attn.to_q(hidden_states_ptb)
+        key_ptb = attn.to_k(hidden_states_ptb)
+        value_ptb = attn.to_v(hidden_states_ptb)
+
+        # `context` projections.
+        encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
+        encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
+        encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
+
+        # attention
+        query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
+        key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
+        value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
+
+        inner_dim = key_ptb.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # create a full mask with all entries set to 0
+        seq_len = query_ptb.size(2)
+        full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
+
+        # set the attention value between image patches to -inf
+        full_mask[:identity_block_size, :identity_block_size] = float("-inf")
+
+        # set the diagonal of the attention value between image patches to 0
+        full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
+
+        # expand the mask to match the attention weights shape
+        full_mask = full_mask.unsqueeze(0).unsqueeze(0)  # Add batch and num_heads dimensions
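+        # with this mask each image-patch query attends only to itself among the image patches
+        # (the image-to-image block of the attention map becomes an identity), while
+        # image-to-text and text-to-anything attention are left unmasked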
+
+        hidden_states_ptb = F.scaled_dot_product_attention(
+            query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
+
+        # split the attention outputs.
+        hidden_states_ptb, encoder_hidden_states_ptb = (
+            hidden_states_ptb[:, : residual.shape[1]],
+            hidden_states_ptb[:, residual.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+        # dropout
+        hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+        if not attn.context_pre_only:
+            encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
+
+        if input_ndim == 4:
+            hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
+                batch_size, channel, height, width
+            )
+
+        ################ concat ###############
+        hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+        encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
+
+        return hidden_states, encoder_hidden_states
+
+
+class PAGCFGJointAttnProcessor2_0:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
1266
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        context_input_ndim = encoder_hidden_states.ndim
+        if context_input_ndim == 4:
+            batch_size, channel, height, width = encoder_hidden_states.shape
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        identity_block_size = hidden_states.shape[
+            1
+        ]  # patch embeddings width * height (correspond to self-attention map width or height)
+
+        # chunk
+        hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
+        hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
+
+        (
+            encoder_hidden_states_uncond,
+            encoder_hidden_states_org,
+            encoder_hidden_states_ptb,
+        ) = encoder_hidden_states.chunk(3)
+        encoder_hidden_states_org = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_org])
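+        # the incoming batch is expected to be [unconditional, conditional, perturbed] along dim 0;
+        # unconditional and conditional are regrouped into the regular attention path, while the
+        # perturbed samples go through the masked (identity self-attention) path below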
+
+        ################## original path ##################
+        batch_size = encoder_hidden_states_org.shape[0]
+
+        # `sample` projections.
+        query_org = attn.to_q(hidden_states_org)
+        key_org = attn.to_k(hidden_states_org)
+        value_org = attn.to_v(hidden_states_org)
+
+        # `context` projections.
+        encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
+        encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
+        encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
+
+        # attention
+        query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
+        key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
+        value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
+
+        inner_dim = key_org.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        hidden_states_org = F.scaled_dot_product_attention(
+            query_org, key_org, value_org, dropout_p=0.0, is_causal=False
+        )
+        hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states_org = hidden_states_org.to(query_org.dtype)
+
+        # Split the attention outputs.
+        hidden_states_org, encoder_hidden_states_org = (
+            hidden_states_org[:, : residual.shape[1]],
+            hidden_states_org[:, residual.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states_org = attn.to_out[0](hidden_states_org)
+        # dropout
+        hidden_states_org = attn.to_out[1](hidden_states_org)
+        if not attn.context_pre_only:
+            encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
+
+        if input_ndim == 4:
+            hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
+                batch_size, channel, height, width
+            )
+
+        ################## perturbed path ##################
+
+        batch_size = encoder_hidden_states_ptb.shape[0]
+
+        # `sample` projections.
+        query_ptb = attn.to_q(hidden_states_ptb)
+        key_ptb = attn.to_k(hidden_states_ptb)
+        value_ptb = attn.to_v(hidden_states_ptb)
+
+        # `context` projections.
+        encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
+        encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
+        encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
+
+        # attention
+        query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
+        key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
+        value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
+
+        inner_dim = key_ptb.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # create a full mask with all entries set to 0
+        seq_len = query_ptb.size(2)
+        full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
+
+        # set the attention value between image patches to -inf
+        full_mask[:identity_block_size, :identity_block_size] = float("-inf")
+
+        # set the diagonal of the attention value between image patches to 0
+        full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
+
+        # expand the mask to match the attention weights shape
+        full_mask = full_mask.unsqueeze(0).unsqueeze(0)  # Add batch and num_heads dimensions
+
+        hidden_states_ptb = F.scaled_dot_product_attention(
+            query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
+
+        # split the attention outputs.
+        hidden_states_ptb, encoder_hidden_states_ptb = (
+            hidden_states_ptb[:, : residual.shape[1]],
+            hidden_states_ptb[:, residual.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+        # dropout
+        hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+        if not attn.context_pre_only:
+            encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
+
+        if input_ndim == 4:
+            hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
+                batch_size, channel, height, width
+            )
+
+        ################ concat ###############
+        hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+        encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
+
+        return hidden_states, encoder_hidden_states
+
+
class FusedJointAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections."""