
Commit e95ac9d

Merge pull request #1 from iddl/chroma-fixes
2 parents 373106c + 104e163 commit e95ac9d

4 files changed: +100 -83 lines changed

src/diffusers/loaders/single_file_utils.py

Lines changed: 79 additions & 56 deletions
@@ -2137,9 +2137,18 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())

+    variant = "chroma" if "distilled_guidance_layer.in_proj.weight" in checkpoint else "flux"
+
     for k in keys:
         if "model.diffusion_model." in k:
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+        if variant == "chroma" and "distilled_guidance_layer." in k:
+            new_key = k
+            if k.startswith("distilled_guidance_layer.norms"):
+                new_key = k.replace(".scale", ".weight")
+            elif k.startswith("distilled_guidance_layer.layer"):
+                new_key = k.replace("in_layer", "linear_1").replace("out_layer", "linear_2")
+            converted_state_dict[new_key] = checkpoint.pop(k)

     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
     num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
@@ -2153,40 +2162,49 @@ def swap_scale_shift(weight):
         new_weight = torch.cat([scale, shift], dim=0)
         return new_weight

-    ## time_text_embed.timestep_embedder <- time_in
-    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
-        "time_in.in_layer.weight"
-    )
-    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
-    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
-        "time_in.out_layer.weight"
-    )
-    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")
-
-    ## time_text_embed.text_embedder <- vector_in
-    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
-    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
-    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
-        "vector_in.out_layer.weight"
-    )
-    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")
-
-    # guidance
-    has_guidance = any("guidance" in k for k in checkpoint)
-    if has_guidance:
-        converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
-            "guidance_in.in_layer.weight"
-        )
-        converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
-            "guidance_in.in_layer.bias"
-        )
-        converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
-            "guidance_in.out_layer.weight"
-        )
-        converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
-            "guidance_in.out_layer.bias"
-        )
-
+    if variant == "flux":
+        ## time_text_embed.timestep_embedder <- time_in
+        converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+            "time_in.in_layer.weight"
+        )
+        converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop(
+            "time_in.in_layer.bias"
+        )
+        converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+            "time_in.out_layer.weight"
+        )
+        converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop(
+            "time_in.out_layer.bias"
+        )
+
+        ## time_text_embed.text_embedder <- vector_in
+        converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop(
+            "vector_in.in_layer.weight"
+        )
+        converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
+        converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
+            "vector_in.out_layer.weight"
+        )
+        converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop(
+            "vector_in.out_layer.bias"
+        )
+
+        # guidance
+        has_guidance = any("guidance" in k for k in checkpoint)
+        if has_guidance:
+            converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
+                "guidance_in.in_layer.weight"
+            )
+            converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
+                "guidance_in.in_layer.bias"
+            )
+            converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
+                "guidance_in.out_layer.weight"
+            )
+            converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
+                "guidance_in.out_layer.bias"
+            )
+
     # context_embedder
     converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
     converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
@@ -2199,20 +2217,21 @@ def swap_scale_shift(weight):
     for i in range(num_layers):
         block_prefix = f"transformer_blocks.{i}."
         # norms.
-        ## norm1
-        converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
-            f"double_blocks.{i}.img_mod.lin.weight"
-        )
-        converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
-            f"double_blocks.{i}.img_mod.lin.bias"
-        )
-        ## norm1_context
-        converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
-            f"double_blocks.{i}.txt_mod.lin.weight"
-        )
-        converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
-            f"double_blocks.{i}.txt_mod.lin.bias"
-        )
+        if variant == "flux":
+            ## norm1
+            converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
+                f"double_blocks.{i}.img_mod.lin.weight"
+            )
+            converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
+                f"double_blocks.{i}.img_mod.lin.bias"
+            )
+            ## norm1_context
+            converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
+                f"double_blocks.{i}.txt_mod.lin.weight"
+            )
+            converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
+                f"double_blocks.{i}.txt_mod.lin.bias"
+            )
         # Q, K, V
         sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
         context_q, context_k, context_v = torch.chunk(
@@ -2285,13 +2304,15 @@ def swap_scale_shift(weight):
     # single transformer blocks
     for i in range(num_single_layers):
         block_prefix = f"single_transformer_blocks.{i}."
-        # norm.linear <- single_blocks.0.modulation.lin
-        converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
-            f"single_blocks.{i}.modulation.lin.weight"
-        )
-        converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
-            f"single_blocks.{i}.modulation.lin.bias"
-        )
+
+        if variant == "flux":
+            # norm.linear <- single_blocks.0.modulation.lin
+            converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
+                f"single_blocks.{i}.modulation.lin.weight"
+            )
+            converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
+                f"single_blocks.{i}.modulation.lin.bias"
+            )
         # Q, K, V, mlp
         mlp_hidden_dim = int(inner_dim * mlp_ratio)
         split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
@@ -2320,12 +2341,14 @@ def swap_scale_shift(weight):

     converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
     converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
-    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
-        checkpoint.pop("final_layer.adaLN_modulation.1.weight")
-    )
-    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
-        checkpoint.pop("final_layer.adaLN_modulation.1.bias")
-    )
+
+    if variant == "flux":
+        converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+            checkpoint.pop("final_layer.adaLN_modulation.1.weight")
+        )
+        converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+            checkpoint.pop("final_layer.adaLN_modulation.1.bias")
+        )

     return converted_state_dict
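For illustration, a minimal standalone sketch of the key remapping the new chroma branch performs, run on a hypothetical toy state dict (real Chroma checkpoints contain far more entries; the values here are placeholder strings and the variable names are made up for the sketch):

    toy_checkpoint = {
        "distilled_guidance_layer.in_proj.weight": "w0",
        "distilled_guidance_layer.norms.0.scale": "w1",
        "distilled_guidance_layer.layers.0.in_layer.weight": "w2",
        "distilled_guidance_layer.layers.0.out_layer.bias": "w3",
    }

    # Same detection and renaming logic as the diff above, replayed outside the library.
    variant = "chroma" if "distilled_guidance_layer.in_proj.weight" in toy_checkpoint else "flux"
    converted = {}
    for k in list(toy_checkpoint):
        if variant == "chroma" and "distilled_guidance_layer." in k:
            new_key = k
            if k.startswith("distilled_guidance_layer.norms"):
                new_key = k.replace(".scale", ".weight")  # original checkpoints call the norm parameter "scale"
            elif k.startswith("distilled_guidance_layer.layer"):
                new_key = k.replace("in_layer", "linear_1").replace("out_layer", "linear_2")
            converted[new_key] = toy_checkpoint.pop(k)

    print(sorted(converted))
    # ['distilled_guidance_layer.in_proj.weight',
    #  'distilled_guidance_layer.layers.0.linear_1.weight',
    #  'distilled_guidance_layer.layers.0.linear_2.bias',
    #  'distilled_guidance_layer.norms.0.weight']

The remapping runs in the same pass that strips the "model.diffusion_model." prefix, so by the time the later `if variant == "flux"` guards in this function are reached, the approximator weights have already been moved into converted_state_dict.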

src/diffusers/models/embeddings.py

Lines changed: 5 additions & 20 deletions
@@ -1643,42 +1643,27 @@ def __init__(self, factor: int, hidden_dim: int, out_dim: int, n_layers: int, em

         self.time_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
         self.guidance_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
-        self.embedder = ChromaApproximator(
-            in_dim=factor * 4,
-            out_dim=out_dim,
-            hidden_dim=hidden_dim,
-            n_layers=n_layers,
-        )
-        self.embedding_dim = embedding_dim

         self.register_buffer(
             "mod_proj",
-            get_timestep_embedding(torch.arange(out_dim), 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
+            get_timestep_embedding(torch.arange(out_dim) * 1000, 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
             persistent=False,
         )

     def forward(
         self, timestep: torch.Tensor, guidance: Optional[torch.Tensor], pooled_projections: torch.Tensor
     ) -> torch.Tensor:
         mod_index_length = self.mod_proj.shape[0]
-        timesteps_proj = self.time_proj(timestep) + self.time_proj(pooled_projections)
-        if guidance is not None:
-            guidance_proj = self.guidance_proj(guidance)
-        else:
-            guidance_proj = torch.zeros(
-                (self.embedding_dim, self.guidance_proj.num_channels),
-                dtype=timesteps_proj.dtype,
-                device=timesteps_proj.device,
-            )
+        timesteps_proj = self.time_proj(timestep).to(dtype=timestep.dtype)
+        guidance_proj = self.guidance_proj(torch.tensor([0])).to(dtype=timestep.dtype, device=timestep.device)

         mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device)
         timestep_guidance = (
             torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
         )
-        input_vec = torch.cat([timestep_guidance, mod_proj], dim=-1)
-        conditioning = self.embedder(input_vec)
+        input_vec = torch.cat([timestep_guidance, mod_proj.unsqueeze(0)], dim=-1)

-        return conditioning
+        return input_vec


 class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
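For illustration, a standalone shape sketch of what the reworked forward now returns: the raw approximator input rather than the final conditioning (the ChromaApproximator call moves into the transformer; see transformer_flux.py below). The sizes are assumptions chosen for the sketch, not values taken from this diff.

    import torch

    factor, mod_index_length = 16, 344                    # assumed sizes for the sketch
    timesteps_proj = torch.randn(1, factor)               # stands in for self.time_proj(timestep)
    guidance_proj = torch.randn(1, factor)                # stands in for self.guidance_proj(torch.tensor([0]))
    mod_proj = torch.randn(mod_index_length, 2 * factor)  # stands in for the registered mod_proj buffer

    # Mirrors the tensor plumbing of the new forward(): one (timestep, guidance)
    # projection, repeated across every modulation slot and concatenated with
    # the per-slot mod_proj embedding.
    timestep_guidance = (
        torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
    )
    input_vec = torch.cat([timestep_guidance, mod_proj.unsqueeze(0)], dim=-1)
    print(input_vec.shape)  # torch.Size([1, 344, 64])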

src/diffusers/models/normalization.py

Lines changed: 3 additions & 3 deletions
@@ -206,7 +206,7 @@ def forward(
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         if self.emb is not None:
             emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
-        scale_msa, shift_msa, gate_msa, scale_mlp, shift_mlp, gate_mlp = emb.chunk(6, dim=1)
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.squeeze(0).chunk(6, dim=0)
         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
         return x, gate_msa, shift_mlp, scale_mlp, gate_mlp

@@ -267,7 +267,7 @@ def forward(
         x: torch.Tensor,
         emb: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        scale_msa, shift_msa, gate_msa = emb.chunk(3, dim=1)
+        shift_msa, scale_msa, gate_msa = emb.squeeze(0).chunk(3, dim=0)
         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
         return x, gate_msa

@@ -413,7 +413,7 @@ def __init__(

     def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
         # convert back to the original dtype in case `conditioning_embedding` is upcasted to float32 (needed for hunyuanDiT)
-        shift, scale = torch.chunk(emb.to(x.dtype), 2, dim=1)
+        shift, scale = torch.chunk(emb.squeeze(0).to(x.dtype), 2, dim=0)
         x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
         return x
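For illustration, a minimal sketch of the chunking convention these three changes share: for Chroma the modulation tensor arrives with the per-slot vectors stacked along dim 0 under a singleton leading dim, so it is split with squeeze(0).chunk(..., dim=0) rather than chunk(..., dim=1). The hidden size below is a made-up example.

    import torch

    dim = 8                        # assumed hidden size for the sketch
    emb = torch.randn(1, 6, dim)   # modulation slice handed to the norm layer
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.squeeze(0).chunk(6, dim=0)
    print(shift_msa.shape)         # torch.Size([1, 8]), one vector per modulation slot

Each chunk keeps a leading dim of 1, so the existing broadcasting in x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] applies the same modulation to every sample in the batch.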

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 13 additions & 4 deletions
@@ -37,6 +37,7 @@
     CombinedTimestepGuidanceTextProjEmbeddings,
     CombinedTimestepTextProjChromaEmbeddings,
     CombinedTimestepTextProjEmbeddings,
+    ChromaApproximator,
     FluxPosEmbed,
 )
 from ..modeling_outputs import Transformer2DModelOutput
@@ -308,6 +309,7 @@ def __init__(
                 embedding_dim=self.inner_dim,
                 n_layers=approximator_layers,
             )
+            self.distilled_guidance_layer = ChromaApproximator(in_dim=64, out_dim=3072, hidden_dim=5120, n_layers=5)
         else:
             raise ValueError(INVALID_VARIANT_ERRMSG)

@@ -518,7 +520,8 @@ def forward(
                 else self.time_text_embed(timestep, guidance, pooled_projections)
             )
         else:
-            pooled_temb = self.time_text_embed(timestep, guidance, pooled_projections)
+            input_vec = self.time_text_embed(timestep, guidance, pooled_projections)
+            pooled_temb = self.distilled_guidance_layer(input_vec)

         encoder_hidden_states = self.context_embedder(encoder_hidden_states)

@@ -545,10 +548,16 @@ def forward(

         for index_block, block in enumerate(self.transformer_blocks):
             if is_chroma:
-                start_idx1 = 3 * len(self.single_transformer_blocks) + 6 * index_block
-                start_idx2 = start_idx1 + 6 * len(self.transformer_blocks)
+                img_offset = 3 * len(self.single_transformer_blocks)
+                txt_offset = img_offset + 6 * len(self.transformer_blocks)
+                img_modulation = img_offset + 6 * index_block
+                text_modulation = txt_offset + 6 * index_block
                 temb = torch.cat(
-                    (pooled_temb[:, start_idx1 : start_idx1 + 6], pooled_temb[:, start_idx2 : start_idx2 + 6]), dim=1
+                    (
+                        pooled_temb[:, img_modulation : img_modulation + 6],
+                        pooled_temb[:, text_modulation : text_modulation + 6],
+                    ),
+                    dim=1,
                 )
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
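For illustration, a sketch of the index arithmetic introduced above, assuming Chroma's usual 19 double and 38 single transformer blocks (those counts are assumptions for the sketch, not taken from this diff). pooled_temb is the table of modulation vectors produced by distilled_guidance_layer; the offset formula suggests the first 3 * num_single rows belong to the single blocks, followed by the image and then the text modulations of the double blocks.

    num_single, num_double = 38, 19            # assumed block counts

    img_offset = 3 * num_single                # single blocks consume 3 vectors each
    txt_offset = img_offset + 6 * num_double   # image modulations for all double blocks come next

    slices = []
    for index_block in range(num_double):
        img_modulation = img_offset + 6 * index_block
        text_modulation = txt_offset + 6 * index_block
        # block index_block reads pooled_temb[:, img_modulation : img_modulation + 6]
        # and pooled_temb[:, text_modulation : text_modulation + 6]
        slices.append((img_modulation, text_modulation))

    print(img_offset, txt_offset, slices[0], slices[-1])  # 114 228 (114, 228) (222, 336)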
