scripts/convert_original_stable_diffusion_to_diffusers.py (36 additions, 4 deletions)

@@ -285,15 +285,34 @@ def create_ldm_bert_config(original_config):
     return config


-def convert_ldm_unet_checkpoint(checkpoint, config):
+def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
     """
     Takes a state dict and a config, and returns a converted checkpoint.
     """

     # extract state_dict for UNet
     unet_state_dict = {}
-    unet_key = "model.diffusion_model."
     keys = list(checkpoint.keys())
+
+    unet_key = "model.diffusion_model."
+    # at least 100 parameters have to start with `model_ema` for the checkpoint to be EMA
+    if sum(k.startswith("model_ema") for k in keys) > 100:
+        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
+        if extract_ema:
+            print(
+                "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+                " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+            )
+            for key in keys:
+                if key.startswith("model.diffusion_model"):
+                    flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+        else:
+            print(
+                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+            )
+
     for key in keys:
         if key.startswith(unet_key):
             unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
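
A note on the `flat_ema_key` lookup above: in original LDM checkpoints, the EMA copy of each parameter is stored under a flattened name in which the dots after the `model_ema.` prefix are dropped. A minimal sketch of the mapping, using a hypothetical UNet parameter name (not taken from this PR):

key = "model.diffusion_model.input_blocks.0.0.weight"  # hypothetical parameter name
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
# flat_ema_key == "model_ema.diffusion_modelinput_blocks00weight"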

@@ -630,6 +649,15 @@ def convert_ldm_clip_checkpoint(checkpoint):
         type=str,
         help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']",
     )
+    parser.add_argument(
+        "--extract_ema",
+        action="store_true",
+        help=(
+            "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
+            " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
+            " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
+        ),
+    )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

     args = parser.parse_args()
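
With the flag in place, converting a checkpoint for inference would opt into the EMA weights, while omitting the flag keeps the non-EMA weights for continued fine-tuning. A hypothetical invocation, with placeholder file paths (only `--extract_ema` is introduced by this diff; `--checkpoint_path` and `--dump_path` are existing arguments of the script):

python scripts/convert_original_stable_diffusion_to_diffusers.py \
    --checkpoint_path sd-v1-4.ckpt \
    --dump_path ./converted_model \
    --extract_ema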

@@ -641,7 +669,9 @@ def convert_ldm_clip_checkpoint(checkpoint):
         args.original_config_file = "./v1-inference.yaml"

     original_config = OmegaConf.load(args.original_config_file)
-    checkpoint = torch.load(args.checkpoint_path)["state_dict"]
+
+    checkpoint = torch.load(args.checkpoint_path)
+    checkpoint = checkpoint["state_dict"]

     num_train_timesteps = original_config.model.params.timesteps
     beta_start = original_config.model.params.linear_start
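
Splitting the load into two statements makes the intermediate object easy to inspect. As an aside (an assumption, not part of this diff): for `.ckpt` files that store their weights at the top level rather than under a "state_dict" key, a defensive variant could fall back to the raw object:

checkpoint = torch.load(args.checkpoint_path)
checkpoint = checkpoint.get("state_dict", checkpoint)  # assumption: tolerate bare state dicts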

@@ -669,7 +699,9 @@ def convert_ldm_clip_checkpoint(checkpoint):

     # Convert the UNet2DConditionModel model.
     unet_config = create_unet_diffusers_config(original_config)
-    converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config)
+    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
+        checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
+    )

     unet = UNet2DConditionModel(**unet_config)
     unet.load_state_dict(converted_unet_checkpoint)