diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index b904920f1cd4..5c35760a9370 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -1,5 +1,6 @@ import argparse import hashlib +import inspect import itertools import math import os @@ -680,10 +681,18 @@ def main(args): if global_step % args.save_steps == 0: if accelerator.is_main_process: + # newer versions of accelerate allow the 'keep_fp32_wrapper' arg. without passing + # it, the models will be unwrapped, and when they are then used for further training, + # we will crash. pass this, but only to newer versions of accelerate. fixes + # https://github.com/huggingface/diffusers/issues/1566 + accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set( + inspect.signature(accelerator.unwrap_model).parameters.keys() + ) + extra_args = {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {} pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet), - text_encoder=accelerator.unwrap_model(text_encoder), + unet=accelerator.unwrap_model(unet, **extra_args), + text_encoder=accelerator.unwrap_model(text_encoder, **extra_args), revision=args.revision, ) save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")