@@ -2137,9 +2137,18 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    converted_state_dict = {}
    keys = list(checkpoint.keys())

+    variant = "chroma" if "distilled_guidance_layer.in_proj.weight" in checkpoint else "flux"
+
    for k in keys:
        if "model.diffusion_model." in k:
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+        if variant == "chroma" and "distilled_guidance_layer." in k:
+            new_key = k
+            if k.startswith("distilled_guidance_layer.norms"):
+                new_key = k.replace(".scale", ".weight")
+            elif k.startswith("distilled_guidance_layer.layer"):
+                new_key = k.replace("in_layer", "linear_1").replace("out_layer", "linear_2")
+            converted_state_dict[new_key] = checkpoint.pop(k)

    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
    num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
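For reference, a minimal sketch of the Chroma key remapping this hunk introduces, run on a toy state dict (the key names are illustrative fragments, not a full checkpoint):

```python
# Minimal sketch of the distilled_guidance_layer renames added above.
checkpoint = {
    "distilled_guidance_layer.norms.0.scale": ...,
    "distilled_guidance_layer.layers.0.in_layer.weight": ...,
    "distilled_guidance_layer.layers.0.out_layer.bias": ...,
}
converted = {}
for k in list(checkpoint):
    new_key = k
    if k.startswith("distilled_guidance_layer.norms"):
        new_key = k.replace(".scale", ".weight")
    elif k.startswith("distilled_guidance_layer.layer"):
        new_key = k.replace("in_layer", "linear_1").replace("out_layer", "linear_2")
    converted[new_key] = checkpoint.pop(k)
# converted now holds:
#   distilled_guidance_layer.norms.0.weight
#   distilled_guidance_layer.layers.0.linear_1.weight
#   distilled_guidance_layer.layers.0.linear_2.bias
```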
@@ -2153,40 +2162,49 @@ def swap_scale_shift(weight):
        new_weight = torch.cat([scale, shift], dim=0)
        return new_weight

-    ## time_text_embed.timestep_embedder <- time_in
-    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
-        "time_in.in_layer.weight"
-    )
-    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
-    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
-        "time_in.out_layer.weight"
-    )
-    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")
-
-    ## time_text_embed.text_embedder <- vector_in
-    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
-    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
-    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
-        "vector_in.out_layer.weight"
-    )
-    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")
-
-    # guidance
-    has_guidance = any("guidance" in k for k in checkpoint)
-    if has_guidance:
-        converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
-            "guidance_in.in_layer.weight"
+    if variant == "flux":
+        ## time_text_embed.timestep_embedder <- time_in
+        converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+            "time_in.in_layer.weight"
        )
-        converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
-            "guidance_in.in_layer.bias"
+        converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop(
+            "time_in.in_layer.bias"
        )
-        converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
-            "guidance_in.out_layer.weight"
+        converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+            "time_in.out_layer.weight"
        )
-        converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
-            "guidance_in.out_layer.bias"
+        converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop(
+            "time_in.out_layer.bias"
        )

+        ## time_text_embed.text_embedder <- vector_in
+        converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop(
+            "vector_in.in_layer.weight"
+        )
+        converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
+        converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
+            "vector_in.out_layer.weight"
+        )
+        converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop(
+            "vector_in.out_layer.bias"
+        )
+
+        # guidance
+        has_guidance = any("guidance" in k for k in checkpoint)
+        if has_guidance:
+            converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
+                "guidance_in.in_layer.weight"
+            )
+            converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
+                "guidance_in.in_layer.bias"
+            )
+            converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
+                "guidance_in.out_layer.weight"
+            )
+            converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
+                "guidance_in.out_layer.bias"
+            )
+
    # context_embedder
    converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
    converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
@@ -2199,20 +2217,21 @@ def swap_scale_shift(weight):
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."
        # norms.
-        ## norm1
-        converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
-            f"double_blocks.{i}.img_mod.lin.weight"
-        )
-        converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
-            f"double_blocks.{i}.img_mod.lin.bias"
-        )
-        ## norm1_context
-        converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
-            f"double_blocks.{i}.txt_mod.lin.weight"
-        )
-        converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
-            f"double_blocks.{i}.txt_mod.lin.bias"
-        )
+        if variant == "flux":
+            ## norm1
+            converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
+                f"double_blocks.{i}.img_mod.lin.weight"
+            )
+            converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
+                f"double_blocks.{i}.img_mod.lin.bias"
+            )
+            ## norm1_context
+            converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
+                f"double_blocks.{i}.txt_mod.lin.weight"
+            )
+            converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
+                f"double_blocks.{i}.txt_mod.lin.bias"
+            )
        # Q, K, V
        sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
        context_q, context_k, context_v = torch.chunk(
@@ -2285,13 +2304,15 @@ def swap_scale_shift(weight):
    # single transformer blocks
    for i in range(num_single_layers):
        block_prefix = f"single_transformer_blocks.{i}."
-        # norm.linear <- single_blocks.0.modulation.lin
-        converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
-            f"single_blocks.{i}.modulation.lin.weight"
-        )
-        converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
-            f"single_blocks.{i}.modulation.lin.bias"
-        )
+
+        if variant == "flux":
+            # norm.linear <- single_blocks.0.modulation.lin
+            converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
+                f"single_blocks.{i}.modulation.lin.weight"
+            )
+            converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
+                f"single_blocks.{i}.modulation.lin.bias"
+            )
        # Q, K, V, mlp
        mlp_hidden_dim = int(inner_dim * mlp_ratio)
        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
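The context lines above show how the single blocks unpack their fused projection: `linear1` packs Q, K, V and the MLP up-projection into one weight, split along dim 0 with `torch.split`. A toy illustration (the dimensions are hypothetical; real Flux checkpoints use an inner dim of 3072 and an MLP ratio of 4.0):

```python
import torch

inner_dim, mlp_ratio = 8, 4.0  # toy values for illustration
mlp_hidden_dim = int(inner_dim * mlp_ratio)
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)

# Fused weight: Q, K, V, and MLP up-projection stacked on dim 0.
fused = torch.randn(3 * inner_dim + mlp_hidden_dim, inner_dim)
q, k, v, mlp = torch.split(fused, split_size, dim=0)
assert q.shape == (inner_dim, inner_dim)
assert mlp.shape == (mlp_hidden_dim, inner_dim)
```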
@@ -2320,12 +2341,14 @@ def swap_scale_shift(weight):

    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
-    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
-        checkpoint.pop("final_layer.adaLN_modulation.1.weight")
-    )
-    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
-        checkpoint.pop("final_layer.adaLN_modulation.1.bias")
-    )
+
+    if variant == "flux":
+        converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+            checkpoint.pop("final_layer.adaLN_modulation.1.weight")
+        )
+        converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+            checkpoint.pop("final_layer.adaLN_modulation.1.bias")
+        )

    return converted_state_dict
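Finally, a hedged end-to-end sketch of how the new variant detection plays out when the converter is called; the import path and checkpoint file name are assumptions (not part of this diff) and may need adjusting to your tree:

```python
# Assumed import location: this converter lives in diffusers' single-file
# loading utilities. The checkpoint file name below is purely illustrative.
from safetensors.torch import load_file
from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers

state_dict = load_file("chroma-unlocked.safetensors")  # hypothetical file
# Presence of this key is what flips the converter into "chroma" mode:
print("distilled_guidance_layer.in_proj.weight" in state_dict)
converted = convert_flux_transformer_checkpoint_to_diffusers(state_dict)
```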