Skip to content

Commit 7b98c4c

Browse files
authored
[Core] Add PAG support for PixArtSigma (#8921)
* feat: add pixart sigma pag. * inits. * fixes * fix * remove print. * copy paste methods to the pixart pag mixin * fix-copies * add documentation. * add tests. * remove correction file. * remove pag_applied_layers * empty
1 parent 27637a5 commit 7b98c4c

File tree

11 files changed

+1569
-2
lines changed

11 files changed

+1569
-2
lines changed

docs/source/en/api/pipelines/pag.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,9 @@ The abstract from the paper is:
5454
[[autodoc]] StableDiffusionXLControlNetPAGPipeline
5555
- all
5656
- __call__
57+
58+
59+
## PixArtSigmaPAGPipeline
60+
[[autodoc]] PixArtSigmaPAGPipeline
61+
- all
62+
- __call__

docs/source/en/using-diffusers/pag.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ This guide will show you how to use PAG for various tasks and use cases.
2222
You can apply PAG to the [`StableDiffusionXLPipeline`] for tasks such as text-to-image, image-to-image, and inpainting. To enable PAG for a specific task, load the pipeline using the [AutoPipeline](../api/pipelines/auto_pipeline) API with the `enable_pag=True` flag and the `pag_applied_layers` argument.
2323

2424
> [!TIP]
25-
> 🤗 Diffusers currently only supports using PAG with selected SDXL pipelines, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you want to add PAG support to a new pipeline!
25+
> 🤗 Diffusers currently only supports using PAG with selected SDXL pipelines and [`PixArtSigmaPAGPipeline`], but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you want to add PAG support to a new pipeline!
2626
2727
<hfoptions id="tasks">
2828
<hfoption id="Text-to-image">

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@
295295
"PaintByExamplePipeline",
296296
"PIAPipeline",
297297
"PixArtAlphaPipeline",
298+
"PixArtSigmaPAGPipeline",
298299
"PixArtSigmaPipeline",
299300
"SemanticStableDiffusionPipeline",
300301
"ShapEImg2ImgPipeline",
@@ -717,6 +718,7 @@
717718
PaintByExamplePipeline,
718719
PIAPipeline,
719720
PixArtAlphaPipeline,
721+
PixArtSigmaPAGPipeline,
720722
PixArtSigmaPipeline,
721723
SemanticStableDiffusionPipeline,
722724
ShapEImg2ImgPipeline,

src/diffusers/models/transformers/pixart_transformer_2d.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Any, Dict, Optional
14+
from typing import Any, Dict, Optional, Union
1515

1616
import torch
1717
from torch import nn
1818

1919
from ...configuration_utils import ConfigMixin, register_to_config
2020
from ...utils import is_torch_version, logging
2121
from ..attention import BasicTransformerBlock
22+
from ..attention_processor import AttentionProcessor
2223
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
2324
from ..modeling_outputs import Transformer2DModelOutput
2425
from ..modeling_utils import ModelMixin
@@ -186,6 +187,66 @@ def _set_gradient_checkpointing(self, module, value=False):
186187
if hasattr(module, "gradient_checkpointing"):
187188
module.gradient_checkpointing = value
188189

190+
@property
# Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Collect the attention processor of every sub-module that exposes `get_processor`.

    Returns:
        `dict` of attention processors: maps `"<dotted module path>.processor"` to the object returned by that
        module's `get_processor()`. Entries are inserted in depth-first (pre-order) traversal order, starting
        from the direct children of this module.
    """
    collected = {}

    # Explicit stack instead of a recursive closure. Children are pushed in reverse so that they are popped,
    # and therefore visited, in registration order.
    pending = list(self.named_children())[::-1]
    while pending:
        path, module = pending.pop()

        if hasattr(module, "get_processor"):
            collected[f"{path}.processor"] = module.get_processor()

        children = list(module.named_children())
        pending.extend((f"{path}.{child_name}", child) for child_name, child in reversed(children))

    return collected
214+
215+
# Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors. The dict is
            only read, never modified, so the same mapping can be passed again afterwards.

    Raises:
        ValueError: If a dict is passed whose length differs from the number of entries in
            `self.attn_processors`.
        KeyError: If a dict is passed that has no `"<module path>.processor"` entry for a module exposing
            `set_processor`.
    """
    count = len(self.attn_processors.keys())

    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor)
            else:
                # Plain lookup rather than `pop`: popping drained the caller's dict as a side effect.
                module.set_processor(processor[f"{name}.processor"])

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module, processor)
249+
189250
def forward(
190251
self,
191252
hidden_states: torch.Tensor,

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@
151151
"StableDiffusionXLPAGInpaintPipeline",
152152
"StableDiffusionXLControlNetPAGPipeline",
153153
"StableDiffusionXLPAGImg2ImgPipeline",
154+
"PixArtSigmaPAGPipeline",
154155
]
155156
)
156157
_import_structure["controlnet_xs"].extend(
@@ -531,6 +532,7 @@
531532
from .musicldm import MusicLDMPipeline
532533
from .pag import (
533534
AnimateDiffPAGPipeline,
535+
PixArtSigmaPAGPipeline,
534536
StableDiffusionControlNetPAGPipeline,
535537
StableDiffusionPAGPipeline,
536538
StableDiffusionXLControlNetPAGPipeline,

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from .kolors import KolorsImg2ImgPipeline, KolorsPipeline
5151
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
5252
from .pag import (
53+
PixArtSigmaPAGPipeline,
5354
StableDiffusionControlNetPAGPipeline,
5455
StableDiffusionPAGPipeline,
5556
StableDiffusionXLControlNetPAGPipeline,
@@ -98,6 +99,7 @@
9899
("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGPipeline),
99100
("stable-diffusion-xl-pag", StableDiffusionXLPAGPipeline),
100101
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGPipeline),
102+
("pixart-sigma-pag", PixArtSigmaPAGPipeline),
101103
("auraflow", AuraFlowPipeline),
102104
("kolors", KolorsPipeline),
103105
("flux", FluxPipeline),

src/diffusers/pipelines/pag/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
else:
2525
_import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
2626
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
27+
_import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
2728
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
2829
_import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
2930
_import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
@@ -40,6 +41,7 @@
4041
else:
4142
from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
4243
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
44+
from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
4345
from .pipeline_pag_sd import StableDiffusionPAGPipeline
4446
from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
4547
from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline

src/diffusers/pipelines/pag/pag_utils.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,3 +275,185 @@ def pag_attn_processors(self):
275275
if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0):
276276
processors[name] = proc
277277
return processors
278+
279+
280+
class PixArtPAGMixin:
    r"""
    Mixin with the perturbed attention guidance (PAG) helpers for a pipeline that owns a `transformer`.

    The methods read `self.transformer`, `self._pag_scale`, `self._pag_adaptive_scale` and
    `self.pag_applied_layers` (the latter is written by `set_pag_applied_layers`); the host class is expected to
    provide the others.
    """

    @staticmethod
    def _check_input_pag_applied_layer(layer):
        r"""
        Check if each layer input in `applied_pag_layers` is valid. It should be the block index: {block_index}.

        A valid index is either a non-negative `int` or a `str` consisting only of ASCII digits (e.g. `3` or
        `"3"`).

        Raises:
            ValueError: If `layer` is anything else.
        """

        # `bool` is a subclass of `int`, and neither `str(True)` nor a negative number can ever equal a block
        # index, so reject them here instead of failing later with a "cannot find pag layer" error.
        if isinstance(layer, int) and not isinstance(layer, bool) and layer >= 0:
            return  # Valid layer index

        # `str.isdigit` alone also accepts non-ASCII digits such as "²", which never match a module path.
        if isinstance(layer, str) and layer.isascii() and layer.isdigit():
            return  # Valid layer index

        # If it is not a valid layer index, raise a ValueError
        raise ValueError(f"Pag layer should only contain block index. Accept number string like '3', got {layer}.")

    def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance):
        r"""
        Set the attention processor for the PAG layers.

        Args:
            pag_applied_layers: Iterable of block indices (`int` or digit `str`) whose self-attention module
                should use the PAG processor.
            do_classifier_free_guidance (`bool`): Selects `PAGCFGIdentitySelfAttnProcessor2_0` when truthy,
                `PAGIdentitySelfAttnProcessor2_0` otherwise.

        Raises:
            ValueError: If a requested block index has no matching self-attention module in `self.transformer`.
                Layers listed before the failing one have already been updated at that point.
        """
        if do_classifier_free_guidance:
            pag_attn_proc = PAGCFGIdentitySelfAttnProcessor2_0()
        else:
            pag_attn_proc = PAGIdentitySelfAttnProcessor2_0()

        def is_self_attn(module_name):
            r"""
            Check if the module is self-attention module based on its name.
            """
            return (
                "attn1" in module_name and len(module_name.split(".")) == 3
            )  # include transformer_blocks.1.attn1, exclude transformer_blocks.18.attn1.to_q, transformer_blocks.1.attn1.add_q_proj, ...

        def get_block_index(module_name):
            r"""
            Get the block index from the module name, i.e. its second dot-separated component.
            """
            # transformer_blocks.23.attn1 -> "23"
            return module_name.split(".")[1]

        # Walk the transformer once and group the self-attention modules by block index, instead of re-walking
        # every module for each requested layer.
        self_attn_by_block = {}
        for name, module in self.transformer.named_modules():
            if is_self_attn(name):
                self_attn_by_block.setdefault(get_block_index(name), []).append(module)

        for pag_layer_input in pag_applied_layers:
            # module paths carry the index as text, so compare on the string form
            target_modules = self_attn_by_block.get(str(pag_layer_input), [])

            if len(target_modules) == 0:
                raise ValueError(f"Cannot find pag layer to set attention processor for: {pag_layer_input}")

            for module in target_modules:
                module.processor = pag_attn_proc

    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.set_pag_applied_layers
    def set_pag_applied_layers(self, pag_applied_layers):
        r"""
        set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid.
        """

        if not isinstance(pag_applied_layers, list):
            pag_applied_layers = [pag_applied_layers]

        for pag_layer in pag_applied_layers:
            self._check_input_pag_applied_layer(pag_layer)

        self.pag_applied_layers = pag_applied_layers

    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._get_pag_scale
    def _get_pag_scale(self, t):
        r"""
        Get the scale factor for the perturbed attention guidance at timestep `t`.
        """

        if self.do_pag_adaptive_scaling:
            signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t)
            if signal_scale < 0:
                signal_scale = 0
            return signal_scale
        else:
            return self.pag_scale

    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._apply_perturbed_attention_guidance
    def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t):
        r"""
        Apply perturbed attention guidance to the noise prediction.

        Args:
            noise_pred (torch.Tensor): The noise prediction tensor.
            do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
            guidance_scale (float): The scale factor for the guidance term.
            t (int): The current time step.

        Returns:
            torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance.
        """
        pag_scale = self._get_pag_scale(t)
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
            noise_pred = (
                noise_pred_uncond
                + guidance_scale * (noise_pred_text - noise_pred_uncond)
                + pag_scale * (noise_pred_text - noise_pred_perturb)
            )
        else:
            noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
            noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
        return noise_pred

    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._prepare_perturbed_attention_guidance
    def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
        """
        Prepares the perturbed attention guidance for the PAG model.

        Args:
            cond (torch.Tensor): The conditional input tensor.
            uncond (torch.Tensor): The unconditional input tensor.
            do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance.

        Returns:
            torch.Tensor: The prepared perturbed attention guidance tensor.
        """

        cond = torch.cat([cond] * 2, dim=0)

        if do_classifier_free_guidance:
            cond = torch.cat([uncond, cond], dim=0)
        return cond

    @property
    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_scale
    def pag_scale(self):
        """
        Get the scale factor for the perturbed attention guidance.
        """
        return self._pag_scale

    @property
    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_adaptive_scale
    def pag_adaptive_scale(self):
        """
        Get the adaptive scale factor for the perturbed attention guidance.
        """
        return self._pag_adaptive_scale

    @property
    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.do_pag_adaptive_scaling
    def do_pag_adaptive_scaling(self):
        """
        Check if the adaptive scaling is enabled for the perturbed attention guidance.
        """
        return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0

    @property
    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.do_perturbed_attention_guidance
    def do_perturbed_attention_guidance(self):
        """
        Check if the perturbed attention guidance is enabled.
        """
        return self._pag_scale > 0 and len(self.pag_applied_layers) > 0

    @property
    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_attn_processors with unet->transformer
    def pag_attn_processors(self):
        r"""
        Returns:
            `dict` of PAG attention processors: A dictionary contains all PAG attention processors used in the model
            with the key as the name of the layer.
        """

        processors = {}
        for name, proc in self.transformer.attn_processors.items():
            if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0):
                processors[name] = proc
        return processors

0 commit comments

Comments
 (0)