Skip to content

Chroma Follow Up #11725

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Jun 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"ChromaImg2ImgPipeline",
"ChromaPipeline",
"CLIPImageProjection",
"CogVideoXFunControlPipeline",
Expand Down Expand Up @@ -945,6 +946,7 @@
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
AuraFlowPipeline,
ChromaImg2ImgPipeline,
ChromaPipeline,
CLIPImageProjection,
CogVideoXFunControlPipeline,
Expand Down
8 changes: 6 additions & 2 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2543,7 +2543,9 @@ def __call__(
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
Expand Down Expand Up @@ -2776,7 +2778,9 @@ def __call__(
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

Expand Down
20 changes: 15 additions & 5 deletions src/diffusers/models/transformers/transformer_chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,15 +250,21 @@ def forward(
hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
joint_attention_kwargs = joint_attention_kwargs or {}

if attention_mask is not None:
attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]

attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
**joint_attention_kwargs,
)

Expand Down Expand Up @@ -312,6 +318,7 @@ def forward(
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
temb_img, temb_txt = temb[:, :6], temb[:, 6:]
Expand All @@ -321,11 +328,15 @@ def forward(
encoder_hidden_states, emb=temb_txt
)
joint_attention_kwargs = joint_attention_kwargs or {}
if attention_mask is not None:
attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]

# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
**joint_attention_kwargs,
)

Expand Down Expand Up @@ -570,6 +581,7 @@ def forward(
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
attention_mask: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
Expand Down Expand Up @@ -659,11 +671,7 @@ def forward(
)
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask
)

else:
Expand All @@ -672,6 +680,7 @@ def forward(
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
joint_attention_kwargs=joint_attention_kwargs,
)

Expand Down Expand Up @@ -704,6 +713,7 @@ def forward(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
joint_attention_kwargs=joint_attention_kwargs,
)

Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@
"AudioLDM2UNet2DConditionModel",
]
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
_import_structure["chroma"] = ["ChromaPipeline"]
_import_structure["chroma"] = ["ChromaPipeline", "ChromaImg2ImgPipeline"]
_import_structure["cogvideo"] = [
"CogVideoXPipeline",
"CogVideoXImageToVideoPipeline",
Expand Down Expand Up @@ -537,7 +537,7 @@
)
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .chroma import ChromaPipeline
from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
from .cogvideo import (
CogVideoXFunControlPipeline,
CogVideoXImageToVideoPipeline,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/chroma/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_chroma"] = ["ChromaPipeline"]
_import_structure["pipeline_chroma_img2img"] = ["ChromaImg2ImgPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
Expand All @@ -31,6 +32,7 @@
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_chroma import ChromaPipeline
from .pipeline_chroma_img2img import ChromaImg2ImgPipeline
else:
import sys

Expand Down
Loading