@@ -275,3 +275,185 @@ def pag_attn_processors(self):
             if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0):
                 processors[name] = proc
         return processors
+
+
+class PixArtPAGMixin:
+    @staticmethod
+    def _check_input_pag_applied_layer(layer):
+        r"""
+        Check that each layer input in `pag_applied_layers` is valid: it should be a block index given as an `int`
+        or a string of digits.
+        """
+
+        # Check if the layer index is valid (should be int or str of int)
+        if isinstance(layer, int):
+            return  # Valid layer index
+
+        if isinstance(layer, str):
+            if layer.isdigit():
+                return  # Valid layer index
+
+        # If it is not a valid layer index, raise a ValueError
+        raise ValueError(f"A PAG layer should only contain a block index. Accepts an int or a number string like '3', got {layer}.")
+
+    def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance):
+        r"""
+        Set the attention processor for the PAG layers.
+        """
+        if do_classifier_free_guidance:
+            pag_attn_proc = PAGCFGIdentitySelfAttnProcessor2_0()
+        else:
+            pag_attn_proc = PAGIdentitySelfAttnProcessor2_0()
+
+        def is_self_attn(module_name):
+            r"""
+            Check if the module is a self-attention module based on its name.
+            """
+            return (
+                "attn1" in module_name and len(module_name.split(".")) == 3
+            )  # include transformer_blocks.1.attn1, exclude transformer_blocks.18.attn1.to_q, transformer_blocks.1.attn1.add_q_proj, ...
+
+        def get_block_index(module_name):
+            r"""
+            Get the block index from the module name, e.g. "transformer_blocks.23.attn1" -> "23".
+            """
+            # transformer_blocks.23.attn -> "23"
+            return module_name.split(".")[1]
+
+        for pag_layer_input in pag_applied_layers:
+            # for each PAG layer input, we find the corresponding self-attention layers in the transformer model
+            target_modules = []
+
+            block_index = str(pag_layer_input)
+
+            for name, module in self.transformer.named_modules():
+                if is_self_attn(name) and get_block_index(name) == block_index:
+                    target_modules.append(module)
+
+            if len(target_modules) == 0:
+                raise ValueError(f"Cannot find PAG layer to set attention processor for: {pag_layer_input}")
+
+            for module in target_modules:
+                module.processor = pag_attn_proc
+
+    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.set_pag_applied_layers
+    def set_pag_applied_layers(self, pag_applied_layers):
+        r"""
+        Set the self-attention layers to apply PAG. Raises a ValueError if the input is invalid.
+        """
+
+        if not isinstance(pag_applied_layers, list):
+            pag_applied_layers = [pag_applied_layers]
+
+        for pag_layer in pag_applied_layers:
+            self._check_input_pag_applied_layer(pag_layer)
+
+        self.pag_applied_layers = pag_applied_layers
+
+    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._get_pag_scale
+    def _get_pag_scale(self, t):
+        r"""
+        Get the scale factor for the perturbed attention guidance at timestep `t`.
+        """
+
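+        # With adaptive scaling the PAG weight decays linearly from `pag_scale` at t=1000 toward zero
+        # as denoising progresses, and is clamped at zero once the schedule crosses it.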
+        if self.do_pag_adaptive_scaling:
+            signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t)
+            if signal_scale < 0:
+                signal_scale = 0
+            return signal_scale
+        else:
+            return self.pag_scale
+
+    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._apply_perturbed_attention_guidance
+    def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t):
+        r"""
+        Apply perturbed attention guidance to the noise prediction.
+
+        Args:
+            noise_pred (torch.Tensor): The noise prediction tensor.
+            do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
+            guidance_scale (float): The scale factor for the guidance term.
+            t (int): The current time step.
+
+        Returns:
+            torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance.
+        """
+        pag_scale = self._get_pag_scale(t)
+        if do_classifier_free_guidance:
+            noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
+            noise_pred = (
+                noise_pred_uncond
+                + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                + pag_scale * (noise_pred_text - noise_pred_perturb)
+            )
+        else:
+            noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
+            noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
+        return noise_pred
+
+    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._prepare_perturbed_attention_guidance
+    def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
+        """
+        Prepares the perturbed attention guidance for the PAG model.
+
+        Args:
+            cond (torch.Tensor): The conditional input tensor.
+            uncond (torch.Tensor): The unconditional input tensor.
+            do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance.
+
+        Returns:
+            torch.Tensor: The prepared perturbed attention guidance tensor.
+        """
+
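+        # Duplicate the conditional embedding so the second copy drives the perturbed (PAG) branch;
+        # with CFG the unconditional embedding is prepended, giving [uncond, cond, perturbed].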
+        cond = torch.cat([cond] * 2, dim=0)
+
+        if do_classifier_free_guidance:
+            cond = torch.cat([uncond, cond], dim=0)
+        return cond
+
+    @property
+    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_scale
+    def pag_scale(self):
+        """
+        Get the scale factor for the perturbed attention guidance.
+        """
+        return self._pag_scale
+
+    @property
+    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_adaptive_scale
+    def pag_adaptive_scale(self):
+        """
+        Get the adaptive scale factor for the perturbed attention guidance.
+        """
+        return self._pag_adaptive_scale
+
+    @property
+    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.do_pag_adaptive_scaling
+    def do_pag_adaptive_scaling(self):
+        """
+        Check if adaptive scaling is enabled for the perturbed attention guidance.
+        """
+        return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0
+
+    @property
+    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.do_perturbed_attention_guidance
+    def do_perturbed_attention_guidance(self):
+        """
+        Check if perturbed attention guidance is enabled.
+        """
+        return self._pag_scale > 0 and len(self.pag_applied_layers) > 0
+
+    @property
+    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_attn_processors with unet->transformer
+    def pag_attn_processors(self):
+        r"""
+        Returns:
+            `dict` of PAG attention processors: A dictionary containing all PAG attention processors used in the
+            model, with the layer name as the key.
+        """
+
+        processors = {}
+        for name, proc in self.transformer.attn_processors.items():
+            if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0):
+                processors[name] = proc
+        return processors
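
Below is a minimal, self-contained sketch (not part of this commit) of the tensor flow the helpers above implement. `DummyPAG`, the tensor shapes, and the scale values are illustrative assumptions, and `PixArtPAGMixin` is assumed to be importable from this module; real pipelines set the private attributes from their call arguments and run the transformer on the concatenated batch.

import torch

class DummyPAG(PixArtPAGMixin):
    def __init__(self, pag_scale, pag_adaptive_scale, pag_applied_layers):
        self._pag_scale = pag_scale
        self._pag_adaptive_scale = pag_adaptive_scale
        self.set_pag_applied_layers(pag_applied_layers)  # validates and stores the block indices

pag = DummyPAG(pag_scale=2.0, pag_adaptive_scale=0.0, pag_applied_layers=[14])

cond = torch.randn(1, 120, 4096)   # stand-in for prompt embeddings
uncond = torch.zeros_like(cond)    # stand-in for negative prompt embeddings

# batch layout becomes [uncond, cond, perturbed] when CFG is on
embeds = pag._prepare_perturbed_attention_guidance(cond, uncond, do_classifier_free_guidance=True)
assert embeds.shape[0] == 3

# the transformer would be run on a matching latent batch; fake its output here
noise_pred = torch.randn(3, 4, 128, 128)
noise_pred = pag._apply_perturbed_attention_guidance(
    noise_pred, do_classifier_free_guidance=True, guidance_scale=4.5, t=500
)
assert noise_pred.shape[0] == 1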