@@ -339,41 +339,46 @@ def create_ldm_bert_config(original_config):
     return config


-def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
+def convert_ldm_unet_checkpoint(
+    checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
+):
     """
     Takes a state dict and a config, and returns a converted checkpoint.
     """

-    # extract state_dict for UNet
-    unet_state_dict = {}
-    keys = list(checkpoint.keys())
-
-    if controlnet:
-        unet_key = "control_model."
+    if skip_extract_state_dict:
+        unet_state_dict = checkpoint
     else:
-        unet_key = "model.diffusion_model."
-
-    # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
-    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
-        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
-        print(
-            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
-            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
-        )
-        for key in keys:
-            if key.startswith("model.diffusion_model"):
-                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
-                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
-    else:
-        if sum(k.startswith("model_ema") for k in keys) > 100:
-            print(
-                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
-                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
-            )
+        # extract state_dict for UNet
+        unet_state_dict = {}
+        keys = list(checkpoint.keys())
+
+        if controlnet:
+            unet_key = "control_model."
+        else:
+            unet_key = "model.diffusion_model."
+
+        # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
+        if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
+            print(f"Checkpoint {path} has both EMA and non-EMA weights.")
+            print(
+                "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+                " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+            )
+            for key in keys:
+                if key.startswith("model.diffusion_model"):
+                    flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+        else:
+            if sum(k.startswith("model_ema") for k in keys) > 100:
+                print(
+                    "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+                    " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+                )

-    for key in keys:
-        if key.startswith(unet_key):
-            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+            for key in keys:
+                if key.startswith(unet_key):
+                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

     new_checkpoint = {}

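Two details in this hunk are worth unpacking. CompVis-style checkpoints store EMA weights under flattened keys (the leading `model.` segment is dropped and the remaining dots removed, which is what the `flat_ema_key` expression reconstructs), and the `> 100` key count is only a heuristic for deciding whether EMA weights are present. A minimal sketch with a fabricated state dict (the toy keys below are illustrative, not from any real checkpoint):

```python
# Fabricated toy state dict, for illustration only.
checkpoint = {f"model_ema.diffusion_modelblock{i}": None for i in range(150)}
checkpoint.update({f"model.diffusion_model.block.{i}": None for i in range(150)})
keys = list(checkpoint.keys())

# The EMA heuristic from the hunk: more than 100 keys starting with "model_ema".
print(sum(k.startswith("model_ema") for k in keys) > 100)  # True

# How a regular UNet key maps to its flattened EMA twin:
key = "model.diffusion_model.input_blocks.0.0.weight"
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
print(flat_ema_key)  # model_ema.diffusion_modelinput_blocks00weight
```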
@@ -956,17 +961,42 @@ def stable_unclip_image_noising_components(


 def convert_controlnet_checkpoint(
-    checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
+    checkpoint,
+    original_config,
+    checkpoint_path,
+    image_size,
+    upcast_attention,
+    extract_ema,
+    use_linear_projection=None,
+    cross_attention_dim=None,
 ):
     ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
     ctrlnet_config["upcast_attention"] = upcast_attention

     ctrlnet_config.pop("sample_size")

+    if use_linear_projection is not None:
+        ctrlnet_config["use_linear_projection"] = use_linear_projection
+
+    if cross_attention_dim is not None:
+        ctrlnet_config["cross_attention_dim"] = cross_attention_dim
+
     controlnet_model = ControlNetModel(**ctrlnet_config)

+    # Some controlnet ckpt files are distributed independently from the rest of the
+    # model components, e.g. https://huggingface.co/thibaud/controlnet-sd21/
+    if "time_embed.0.weight" in checkpoint:
+        skip_extract_state_dict = True
+    else:
+        skip_extract_state_dict = False
+
     converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
-        checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
+        checkpoint,
+        ctrlnet_config,
+        path=checkpoint_path,
+        extract_ema=extract_ema,
+        controlnet=True,
+        skip_extract_state_dict=skip_extract_state_dict,
     )

     controlnet_model.load_state_dict(converted_ctrl_checkpoint)
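The `time_embed.0.weight` probe distinguishes a standalone controlnet file, whose keys sit at the top level, from a full Stable Diffusion checkpoint with an embedded controlnet, whose keys carry the `control_model.` prefix and therefore fail the top-level lookup. A sketch of the two layouts, with fabricated keys for illustration:

```python
# Key layouts are illustrative; real checkpoints contain hundreds of entries.
bundled = {
    "control_model.time_embed.0.weight": ...,  # controlnet nested in a full SD ckpt
    "model.diffusion_model.time_embed.0.weight": ...,
}
standalone = {
    "time_embed.0.weight": ...,  # e.g. the thibaud/controlnet-sd21 style files
    "input_blocks.0.0.weight": ...,
}

for ckpt in (bundled, standalone):
    skip_extract_state_dict = "time_embed.0.weight" in ckpt
    print(skip_extract_state_dict)  # False, then True
```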
@@ -1344,6 +1374,8 @@ def download_controlnet_from_original_ckpt(
     upcast_attention: Optional[bool] = None,
     device: str = None,
     from_safetensors: bool = False,
+    use_linear_projection: Optional[bool] = None,
+    cross_attention_dim: Optional[int] = None,
 ) -> DiffusionPipeline:
     if not is_omegaconf_available():
         raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
@@ -1381,7 +1413,14 @@ def download_controlnet_from_original_ckpt(
         raise ValueError("`control_stage_config` not present in original config")

     controlnet_model = convert_controlnet_checkpoint(
-        checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
+        checkpoint,
+        original_config,
+        checkpoint_path,
+        image_size,
+        upcast_attention,
+        extract_ema,
+        use_linear_projection=use_linear_projection,
+        cross_attention_dim=cross_attention_dim,
     )

     return controlnet_model
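End to end, the new keyword arguments flow from `download_controlnet_from_original_ckpt` through `convert_controlnet_checkpoint` into the `ControlNetModel` config. A hedged usage sketch: the import path, the `original_config_file` parameter name, and both file names are assumptions, since this diff shows only the tail of the signature. Note also that the function returns the `ControlNetModel` itself despite the `-> DiffusionPipeline` annotation.

```python
# Sketch only: the import path, the original_config_file parameter name, and
# the file names are assumptions not confirmed by this diff.
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_controlnet_from_original_ckpt,
)

controlnet = download_controlnet_from_original_ckpt(
    checkpoint_path="controlnet-sd21.safetensors",  # hypothetical file
    original_config_file="cldm_v21.yaml",           # hypothetical config
    from_safetensors=True,
    extract_ema=False,
    # New overrides for checkpoints (e.g. SD 2.1 controlnets) whose values
    # cannot be inferred from the original LDM config:
    use_linear_projection=True,
    cross_attention_dim=1024,
)
controlnet.save_pretrained("converted-controlnet")  # standard ModelMixin API
```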