We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent af04479 commit 2ca48e3Copy full SHA for 2ca48e3
examples/dreambooth/train_dreambooth.py
@@ -676,8 +676,8 @@ def collate_fn(examples):
676
if accelerator.is_main_process:
677
pipeline = DiffusionPipeline.from_pretrained(
678
args.pretrained_model_name_or_path,
679
- unet=accelerator.unwrap_model(unet),
680
- text_encoder=accelerator.unwrap_model(text_encoder),
+ unet=accelerator.unwrap_model(unet, True),
+ text_encoder=accelerator.unwrap_model(text_encoder, True),
681
revision=args.revision,
682
)
683
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
0 commit comments