Commit 2ca48e3 (parent af04479)

dreambooth: fix #1566: maintain fp32 wrapper when saving a checkpoint to avoid crash when running fp16

File tree: 1 file changed (+2, -2)

examples/dreambooth/train_dreambooth.py (2 additions, 2 deletions)

@@ -676,8 +676,8 @@ def collate_fn(examples):
     if accelerator.is_main_process:
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
-            unet=accelerator.unwrap_model(unet),
-            text_encoder=accelerator.unwrap_model(text_encoder),
+            unet=accelerator.unwrap_model(unet, True),
+            text_encoder=accelerator.unwrap_model(text_encoder, True),
             revision=args.revision,
         )
         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
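Why the change works, in brief: with mixed_precision="fp16", accelerate patches the prepared model's forward with an autocast wrapper that keeps the weights in fp32 while computing in half precision. accelerator.unwrap_model() strips that wrapper by default, so a pipeline built from the unwrapped unet and text_encoder crashes on fp16 inputs; the bare `True` passed here is the second parameter, keep_fp32_wrapper, which preserves the wrapper. Below is a minimal sketch of that behavior, not part of the commit, assuming an accelerate release whose unwrap_model accepts keep_fp32_wrapper and a CUDA device (fp16 mixed precision requires one); the toy Linear model is for illustration only.

    import torch
    from accelerate import Accelerator

    # fp16 mixed precision needs a CUDA device; this sketch assumes one is present.
    accelerator = Accelerator(mixed_precision="fp16")
    model = accelerator.prepare(torch.nn.Linear(4, 4))

    # keep_fp32_wrapper=True (the positional `True` in the commit is this same
    # argument) returns the module with its autocast wrapper intact, so it can
    # still consume fp16 inputs after unwrapping.
    unwrapped = accelerator.unwrap_model(model, keep_fp32_wrapper=True)

    x = torch.randn(1, 4, device=accelerator.device, dtype=torch.float16)
    out = unwrapped(x)  # autocast reconciles the fp16 input with fp32 weights
    print(out.dtype)

    # The default call strips the wrapper in place; the bare fp32 weights would
    # then reject the fp16 input with a dtype-mismatch error:
    #   accelerator.unwrap_model(model)(x)

This is why the fix touches only the checkpoint-saving path: training itself never loses the wrapper, but the DiffusionPipeline assembled for saving did.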
