
Commit 462956b
small tweaks for parsing thibaud's controlnet checkpoints (#3657)
1 parent: 5990014

2 files changed: 87 additions, 30 deletions

scripts/convert_original_controlnet_to_diffusers.py
18 additions, 0 deletions

@@ -75,6 +75,22 @@
     )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
     parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
+
+    # small workaround to get argparser to parse a boolean input as either true _or_ false
+    def parse_bool(string):
+        if string == "True":
+            return True
+        elif string == "False":
+            return False
+        else:
+            raise ValueError(f"could not parse string as bool {string}")
+
+    parser.add_argument(
+        "--use_linear_projection", help="Override for use linear projection", required=False, type=parse_bool
+    )
+
+    parser.add_argument("--cross_attention_dim", help="Override for cross attention_dim", required=False, type=int)
+
     args = parser.parse_args()
 
     controlnet = download_controlnet_from_original_ckpt(
@@ -86,6 +102,8 @@
         upcast_attention=args.upcast_attention,
         from_safetensors=args.from_safetensors,
         device=args.device,
+        use_linear_projection=args.use_linear_projection,
+        cross_attention_dim=args.cross_attention_dim,
     )
 
     controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
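
With type=bool, argparse would treat any non-empty string (including "False") as True, which is why the script defines the parse_bool helper above instead. A conversion run using the new flags might look like the following; the paths are placeholders, --checkpoint_path and --original_config_file are assumed to be the script's pre-existing flags, and True/1024 are the usual SD 2.x attention settings:

    python scripts/convert_original_controlnet_to_diffusers.py \
        --checkpoint_path ./controlnet-sd21-canny.ckpt \
        --original_config_file ./cldm_v21.yaml \
        --dump_path ./controlnet-sd21-canny-diffusers \
        --use_linear_projection True \
        --cross_attention_dim 1024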

src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
69 additions, 30 deletions
@@ -339,41 +339,46 @@ def create_ldm_bert_config(original_config):
     return config
 
 
-def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
+def convert_ldm_unet_checkpoint(
+    checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
+):
     """
     Takes a state dict and a config, and returns a converted checkpoint.
     """
 
-    # extract state_dict for UNet
-    unet_state_dict = {}
-    keys = list(checkpoint.keys())
-
-    if controlnet:
-        unet_key = "control_model."
+    if skip_extract_state_dict:
+        unet_state_dict = checkpoint
     else:
-        unet_key = "model.diffusion_model."
-
-    # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
-    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
-        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
-        print(
-            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
-            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
-        )
-        for key in keys:
-            if key.startswith("model.diffusion_model"):
-                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
-                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
-    else:
-        if sum(k.startswith("model_ema") for k in keys) > 100:
+        # extract state_dict for UNet
+        unet_state_dict = {}
+        keys = list(checkpoint.keys())
+
+        if controlnet:
+            unet_key = "control_model."
+        else:
+            unet_key = "model.diffusion_model."
+
+        # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
+        if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
+            print(f"Checkpoint {path} has both EMA and non-EMA weights.")
             print(
-                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
-                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+                "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+                " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
             )
+            for key in keys:
+                if key.startswith("model.diffusion_model"):
+                    flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+        else:
+            if sum(k.startswith("model_ema") for k in keys) > 100:
+                print(
+                    "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+                    " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+                )
 
-    for key in keys:
-        if key.startswith(unet_key):
-            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+        for key in keys:
+            if key.startswith(unet_key):
+                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
 
     new_checkpoint = {}
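The new skip_extract_state_dict switch exists because standalone controlnet checkpoints already store their weights under bare UNet keys, while full ldm-style checkpoints nest them under a prefix. A minimal sketch of the two paths; extract is a hypothetical stand-in for the extraction loop above (EMA branch ignored), and the dict values are toys, not tensors:

    # "control_model." is the prefix used by full ldm-style controlnet checkpoints;
    # standalone checkpoints (e.g. thibaud/controlnet-sd21) ship without it.
    prefixed = {"control_model.time_embed.0.weight": 0}
    standalone = {"time_embed.0.weight": 0}

    def extract(checkpoint, skip_extract_state_dict=False, unet_key="control_model."):
        if skip_extract_state_dict:
            return checkpoint  # keys are already bare, use them as-is
        # keep only the prefixed keys and strip the prefix
        return {k[len(unet_key):]: v for k, v in checkpoint.items() if k.startswith(unet_key)}

    assert extract(prefixed) == extract(standalone, skip_extract_state_dict=True)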

@@ -956,17 +961,42 @@ def stable_unclip_image_noising_components(
 
 
 def convert_controlnet_checkpoint(
-    checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
+    checkpoint,
+    original_config,
+    checkpoint_path,
+    image_size,
+    upcast_attention,
+    extract_ema,
+    use_linear_projection=None,
+    cross_attention_dim=None,
 ):
     ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
     ctrlnet_config["upcast_attention"] = upcast_attention
 
     ctrlnet_config.pop("sample_size")
 
+    if use_linear_projection is not None:
+        ctrlnet_config["use_linear_projection"] = use_linear_projection
+
+    if cross_attention_dim is not None:
+        ctrlnet_config["cross_attention_dim"] = cross_attention_dim
+
     controlnet_model = ControlNetModel(**ctrlnet_config)
 
+    # Some controlnet ckpt files are distributed independently from the rest of the
+    # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
+    if "time_embed.0.weight" in checkpoint:
+        skip_extract_state_dict = True
+    else:
+        skip_extract_state_dict = False
+
     converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
-        checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
+        checkpoint,
+        ctrlnet_config,
+        path=checkpoint_path,
+        extract_ema=extract_ema,
+        controlnet=True,
+        skip_extract_state_dict=skip_extract_state_dict,
     )
 
     controlnet_model.load_state_dict(converted_ctrl_checkpoint)
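
The two config overrides are relevant for checkpoints like the SD 2.1 controlnets linked above: SD 2.x UNets use linear attention projections and a 1024-dim text-encoder hidden state, while SD 1.x uses conv projections and 768. A hypothetical before/after of the override logic; the starting values stand in for SD 1.x-style defaults and are not taken from the commit:

    # Assumed SD 1.x-style settings derived from the original config.
    ctrlnet_config = {"use_linear_projection": False, "cross_attention_dim": 768}

    # Values a user would pass for an SD 2.x checkpoint via the new CLI flags.
    use_linear_projection, cross_attention_dim = True, 1024

    if use_linear_projection is not None:
        ctrlnet_config["use_linear_projection"] = use_linear_projection
    if cross_attention_dim is not None:
        ctrlnet_config["cross_attention_dim"] = cross_attention_dim

    assert ctrlnet_config == {"use_linear_projection": True, "cross_attention_dim": 1024}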
@@ -1344,6 +1374,8 @@ def download_controlnet_from_original_ckpt(
     upcast_attention: Optional[bool] = None,
     device: str = None,
     from_safetensors: bool = False,
+    use_linear_projection: Optional[bool] = None,
+    cross_attention_dim: Optional[bool] = None,
 ) -> DiffusionPipeline:
     if not is_omegaconf_available():
         raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
@@ -1381,7 +1413,14 @@
         raise ValueError("`control_stage_config` not present in original config")
 
     controlnet_model = convert_controlnet_checkpoint(
-        checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
+        checkpoint,
+        original_config,
+        checkpoint_path,
+        image_size,
+        upcast_attention,
+        extract_ema,
+        use_linear_projection=use_linear_projection,
+        cross_attention_dim=cross_attention_dim,
    )
 
     return controlnet_model
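
End to end, the new keyword arguments flow from the CLI through download_controlnet_from_original_ckpt into the controlnet config. A hypothetical direct call; the paths are placeholders, checkpoint_path and original_config_file are assumed to be the function's existing leading parameters, and the override values are the usual SD 2.x settings:

    from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
        download_controlnet_from_original_ckpt,
    )

    controlnet = download_controlnet_from_original_ckpt(
        checkpoint_path="./controlnet-sd21-canny.ckpt",  # placeholder path
        original_config_file="./cldm_v21.yaml",  # placeholder path
        use_linear_projection=True,  # SD 2.x attention projection
        cross_attention_dim=1024,  # SD 2.x text-encoder width
    )
    controlnet.save_pretrained("./controlnet-sd21-canny-diffusers")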
