Skip to content

Commit ae306ef

Browse files
committed
dreambooth: guard against passing keep_fp32_wrapper arg to older versions of accelerate. part of fix for #1566
1 parent c231c62 commit ae306ef

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -680,10 +680,16 @@ def main(args):
680680

681681
if global_step % args.save_steps == 0:
682682
if accelerator.is_main_process:
683+
# newer versions of accelerate allow the 'keep_fp32_wrapper' arg. without passing
684+
# it, the models will be unwrapped, and when they are then used for further training,
685+
# we will crash. pass this, but only to newer versions of accelerate. fixes
686+
# https://github.com/huggingface/diffusers/issues/1566
687+
accepts_keep_fp32_wrapper = 'keep_fp32_wrapper' in set(inspect.signature(accelerator.unwrap_model).parameters.keys())
688+
extra_args = {'keep_fp32_wrapper': True} if accepts_keep_fp32_wrapper else {}
683689
pipeline = DiffusionPipeline.from_pretrained(
684690
args.pretrained_model_name_or_path,
685-
unet=accelerator.unwrap_model(unet, True),
686-
text_encoder=accelerator.unwrap_model(text_encoder, True),
691+
unet=accelerator.unwrap_model(unet, **extra_args),
692+
text_encoder=accelerator.unwrap_model(text_encoder, **extra_args),
687693
revision=args.revision,
688694
)
689695
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")

0 commit comments

Comments
 (0)