[refactor] Flux/Chroma single file implementation + Attention Dispatcher #11916


Merged · 21 commits · Jul 17, 2025
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -163,6 +163,7 @@
[
"AllegroTransformer3DModel",
"AsymmetricAutoencoderKL",
"AttentionBackendName",
"AuraFlowTransformer2DModel",
"AutoencoderDC",
"AutoencoderKL",
@@ -238,6 +239,7 @@
"VQModel",
"WanTransformer3DModel",
"WanVACETransformer3DModel",
"attention_backend",
]
)
_import_structure["modular_pipelines"].extend(
@@ -815,6 +817,7 @@
from .models import (
AllegroTransformer3DModel,
AsymmetricAutoencoderKL,
AttentionBackendName,
AuraFlowTransformer2DModel,
AutoencoderDC,
AutoencoderKL,
@@ -889,6 +892,7 @@
VQModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
attention_backend,
)
from .modular_pipelines import (
ComponentsManager,
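These two hunks expose the new attention-dispatcher names, `AttentionBackendName` and `attention_backend`, at the top level of the package. The call signature is not visible in this diff; the sketch below assumes `attention_backend` is a context manager that selects the backend used by the dispatcher, and the enum member name is an assumption.

```python
# Usage sketch only; assumes `attention_backend` is a context manager and that the enum
# has a member such as NATIVE -- neither detail is visible in these hunks.
from diffusers import AttentionBackendName, attention_backend

# Inspect which backends the dispatcher knows about.
print([backend.value for backend in AttentionBackendName])

# Any diffusers attention executed inside the block is routed through the chosen backend,
# for example a transformer forward pass or a full pipeline call.
with attention_backend(AttentionBackendName.NATIVE):  # a string value may also be accepted
    ...  # model(...) or pipe(...)
```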
3 changes: 2 additions & 1 deletion src/diffusers/hooks/faster_cache.py
@@ -18,6 +18,7 @@

import torch

from ..models.attention import AttentionModuleMixin
from ..models.attention_processor import Attention, MochiAttention
from ..models.modeling_outputs import Transformer2DModelOutput
from ..utils import logging
@@ -567,7 +568,7 @@ def high_frequency_weight_callback(module: torch.nn.Module) -> float:
_apply_faster_cache_on_denoiser(module, config)

for name, submodule in module.named_modules():
if not isinstance(submodule, _ATTENTION_CLASSES):
if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
continue
if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
_apply_faster_cache_on_attention_class(name, submodule, config)
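The change above widens the module filter so the FasterCache hooks also attach to the new `AttentionModuleMixin`-based attention classes. A self-contained sketch of that dispatch pattern follows; the block-name regexes are illustrative placeholders, not the real `_TRANSFORMER_BLOCK_IDENTIFIERS` values, which are not shown in this hunk.

```python
# Sketch of the generalized isinstance check from this hunk. The identifier patterns
# below are placeholders, not the values used in faster_cache.py.
import re

import torch
from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import Attention, MochiAttention

_ATTENTION_CLASSES = (Attention, MochiAttention)
_TRANSFORMER_BLOCK_IDENTIFIERS = (r"transformer_blocks", r"single_transformer_blocks")  # placeholders


def iter_transformer_attention_layers(model: torch.nn.Module):
    """Yield attention submodules that are either legacy Attention classes or the
    new AttentionModuleMixin-based single-file implementations."""
    for name, submodule in model.named_modules():
        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
            continue
        if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
            yield name, submodule
```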
3 changes: 2 additions & 1 deletion src/diffusers/hooks/pyramid_attention_broadcast.py
@@ -18,6 +18,7 @@

import torch

from ..models.attention import AttentionModuleMixin
from ..models.attention_processor import Attention, MochiAttention
from ..utils import logging
from .hooks import HookRegistry, ModelHook
@@ -227,7 +228,7 @@ def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAt
config.spatial_attention_block_skip_range = 2

for name, submodule in module.named_modules():
if not isinstance(submodule, _ATTENTION_CLASSES):
if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
Member (Author) review comment: Marking as a TODO for myself. We will no longer need `_ATTENTION_CLASSES` once everything uses `AttentionModuleMixin`.

# PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
# cannot be applied to this layer. For custom layers, users can extend this functionality and implement
# their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
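For context, a hedged sketch of applying the broadcast hook via this function. Only `spatial_attention_block_skip_range` is visible in the hunk; the other config fields, the `current_timestep` attribute, and the Flux checkpoint id are assumptions, not details confirmed by this diff.

```python
# Hedged sketch, not taken from this diff: config fields other than
# `spatial_attention_block_skip_range` and the callback wiring are assumptions.
import torch
from diffusers import FluxPipeline
from diffusers.hooks.pyramid_attention_broadcast import (
    PyramidAttentionBroadcastConfig,
    apply_pyramid_attention_broadcast,
)

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,  # matches the default this function applies when unset
    current_timestep_callback=lambda: pipe.current_timestep,  # assumption: pipeline tracks the current timestep
)
apply_pyramid_attention_broadcast(pipe.transformer, config)  # attaches hooks to matching attention layers
image = pipe("a photo of a cat", num_inference_steps=28).images[0]
```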
9 changes: 5 additions & 4 deletions src/diffusers/loaders/ip_adapter.py
@@ -40,8 +40,6 @@
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
FluxAttnProcessor2_0,
FluxIPAdapterJointAttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
@@ -867,6 +865,9 @@ def unload_ip_adapter(self):
>>> ...
```
"""
# TODO: once the 1.0.0 deprecations are in, we can move the imports to top-level
from ..models.transformers.transformer_flux import FluxAttnProcessor, FluxIPAdapterAttnProcessor

# remove CLIP image encoder
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
self.image_encoder = None
@@ -886,9 +887,9 @@ def unload_ip_adapter(self):
# restore original Transformer attention processors layers
attn_procs = {}
for name, value in self.transformer.attn_processors.items():
attn_processor_class = FluxAttnProcessor2_0()
attn_processor_class = FluxAttnProcessor()
attn_procs[name] = (
attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
attn_processor_class if isinstance(value, FluxIPAdapterAttnProcessor) else value.__class__()
)
self.transformer.set_attn_processor(attn_procs)

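The hunks above replace the deprecated `FluxAttnProcessor2_0` / `FluxIPAdapterJointAttnProcessor2_0` names with the new single-file `FluxAttnProcessor` / `FluxIPAdapterAttnProcessor` when restoring processors on unload. The restore pattern in isolation, as a sketch; the helper name is hypothetical and not a diffusers API.

```python
# Sketch of the processor-restore pattern implemented above; `restore_flux_attn_processors`
# is a hypothetical helper name introduced only for illustration.
from diffusers.models.transformers.transformer_flux import (
    FluxAttnProcessor,
    FluxIPAdapterAttnProcessor,
)


def restore_flux_attn_processors(transformer) -> None:
    """Swap IP-Adapter attention processors back to the plain Flux processor and
    re-instantiate every other processor class, mirroring unload_ip_adapter."""
    attn_procs = {
        name: FluxAttnProcessor() if isinstance(proc, FluxIPAdapterAttnProcessor) else proc.__class__()
        for name, proc in transformer.attn_processors.items()
    }
    transformer.set_attn_processor(attn_procs)
```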
6 changes: 2 additions & 4 deletions src/diffusers/loaders/transformer_flux.py
@@ -86,9 +86,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
return image_projection

def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
from ..models.attention_processor import (
FluxIPAdapterJointAttnProcessor2_0,
)
from ..models.transformers.transformer_flux import FluxIPAdapterAttnProcessor

if low_cpu_mem_usage:
if is_accelerate_available():
@@ -120,7 +118,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_
else:
cross_attention_dim = self.config.joint_attention_dim
hidden_size = self.inner_dim
attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
attn_processor_class = FluxIPAdapterAttnProcessor
num_image_text_embeds = []
for state_dict in state_dicts:
if "proj.weight" in state_dict["image_proj"]:
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
@@ -26,6 +26,7 @@

if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
_import_structure["auto_model"] = ["AutoModel"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
@@ -112,6 +113,7 @@
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .adapter import MultiAdapter, T2IAdapter
from .attention_dispatch import AttentionBackendName, attention_backend
from .auto_model import AutoModel
from .autoencoders import (
AsymmetricAutoencoderKL,
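Both hunks follow the package's lazy-import convention: the new names are registered in `_import_structure` for runtime resolution and imported directly only under `TYPE_CHECKING` / `DIFFUSERS_SLOW_IMPORT`. A simplified sketch of that mechanism follows; the real implementation is a `_LazyModule` helper, so this module-level `__getattr__` only illustrates the idea.

```python
# Simplified sketch of the lazy-import pattern these hunks extend. Intended as package
# __init__.py content; diffusers itself delegates this to a _LazyModule helper.
import importlib

_import_structure = {
    "attention_dispatch": ["AttentionBackendName", "attention_backend"],
}


def __getattr__(name):
    # Resolve a public name to its submodule the first time it is accessed.
    for submodule_name, exported_names in _import_structure.items():
        if name in exported_names:
            submodule = importlib.import_module(f".{submodule_name}", __name__)
            return getattr(submodule, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```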