Commit 80486c1

dreambooth: fix #1566: maintain fp32 wrapper when saving a checkpoint to avoid crash when running fp16
1 parent: 6b68afd

File tree: 1 file changed (+2, -2 lines)

examples/dreambooth/train_dreambooth.py

Lines changed: 2 additions & 2 deletions
@@ -692,8 +692,8 @@ def main(args):
         if accelerator.is_main_process:
             pipeline = DiffusionPipeline.from_pretrained(
                 args.pretrained_model_name_or_path,
-                unet=accelerator.unwrap_model(unet),
-                text_encoder=accelerator.unwrap_model(text_encoder),
+                unet=accelerator.unwrap_model(unet, True),
+                text_encoder=accelerator.unwrap_model(text_encoder, True),
                 revision=args.revision,
             )
             save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
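
For context, the True passed here is Accelerate's keep_fp32_wrapper argument to unwrap_model: under fp16 mixed precision, Accelerate wraps each model's forward so its outputs are cast back to fp32, and stripping that wrapper while saving a mid-training checkpoint leaves the models unusable when the loop continues. Below is a minimal sketch of the checkpointing step with the keyword spelled out; accelerator, unet, text_encoder, args, and global_step are taken from the surrounding training script (omitted here), and the final save_pretrained call is assumed rather than shown in the diff above.

import os

from diffusers import DiffusionPipeline

# Runs inside the training loop, on the main process only.
if accelerator.is_main_process:
    # keep_fp32_wrapper=True keeps Accelerate's fp32 output wrapper on the
    # unwrapped models, so training can continue after the checkpoint is
    # written instead of crashing under fp16 mixed precision.
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True),
        text_encoder=accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True),
        revision=args.revision,
    )
    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
    pipeline.save_pretrained(save_path)  # assumed follow-up line; not part of the diff above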
