
Commit ebfd8bf

dreambooth: guard against passing keep_fp32_wrapper arg to older versions of accelerate. part of fix for #1566
1 parent 80486c1 commit ebfd8bf

File tree

1 file changed: +11 −2 lines


examples/dreambooth/train_dreambooth.py

Lines changed: 11 additions & 2 deletions
@@ -1,5 +1,6 @@
 import argparse
 import hashlib
+import inspect
 import itertools
 import math
 import os
@@ -690,10 +691,18 @@ def main(args):
 
                 if global_step % args.save_steps == 0:
                     if accelerator.is_main_process:
+                        # newer versions of accelerate allow the 'keep_fp32_wrapper' arg. without passing
+                        # it, the models will be unwrapped, and when they are then used for further training,
+                        # we will crash. pass this, but only to newer versions of accelerate. fixes
+                        # https://github.com/huggingface/diffusers/issues/1566
+                        accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
+                            inspect.signature(accelerator.unwrap_model).parameters.keys()
+                        )
+                        extra_args = {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
                         pipeline = DiffusionPipeline.from_pretrained(
                             args.pretrained_model_name_or_path,
-                            unet=accelerator.unwrap_model(unet, True),
-                            text_encoder=accelerator.unwrap_model(text_encoder, True),
+                            unet=accelerator.unwrap_model(unet, **extra_args),
+                            text_encoder=accelerator.unwrap_model(text_encoder, **extra_args),
                             revision=args.revision,
                         )
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
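For context, the guard in this diff follows a general feature-detection pattern: inspect the callable's signature at runtime and only forward a keyword argument when the installed library supports it. Below is a minimal, self-contained sketch of that pattern; the _FakeAccelerator stub is hypothetical and stands in for accelerate.Accelerator purely for illustration.

import inspect

class _FakeAccelerator:
    # hypothetical stand-in for accelerate.Accelerator; newer accelerate releases
    # expose unwrap_model(model, keep_fp32_wrapper=...), older ones do not
    def unwrap_model(self, model, keep_fp32_wrapper=False):
        return model

accelerator = _FakeAccelerator()

# check the signature at runtime instead of pinning a minimum accelerate version
accepts_keep_fp32_wrapper = (
    "keep_fp32_wrapper" in inspect.signature(accelerator.unwrap_model).parameters
)

# only forward the keyword when the installed version understands it
extra_args = {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
unwrapped = accelerator.unwrap_model(object(), **extra_args)
print(accepts_keep_fp32_wrapper)  # True for the stub above; False on an older signature

The check degrades gracefully: with an older accelerate, extra_args stays empty and unwrap_model is called exactly as it was before this commit.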
