We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 6b68afd commit 80486c1Copy full SHA for 80486c1
examples/dreambooth/train_dreambooth.py
@@ -692,8 +692,8 @@ def main(args):
692
if accelerator.is_main_process:
693
pipeline = DiffusionPipeline.from_pretrained(
694
args.pretrained_model_name_or_path,
695
- unet=accelerator.unwrap_model(unet),
696
- text_encoder=accelerator.unwrap_model(text_encoder),
+ unet=accelerator.unwrap_model(unet, True),
+ text_encoder=accelerator.unwrap_model(text_encoder, True),
697
revision=args.revision,
698
)
699
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
0 commit comments