diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index ca39aeff236b..c9d0a5c452c5 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -6,6 +6,7 @@
 
 import torch
 import torch.nn.functional as F
+from torch.nn import Module
 import torch.utils.checkpoint
 from torch.utils.data import Dataset
 
@@ -185,6 +186,7 @@ def parse_args():
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument("--train_text_encoder", action="store_true", help="Enable text encoder training.")
     args = parser.parse_args()
 
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -406,8 +408,16 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
+    class WrapperModel(Module):
+        def __init__(self, _unet, _text_encoder):
+            super().__init__()
+            self.unet = _unet
+            if args.train_text_encoder:
+                self.text_encoder = _text_encoder
+
+    model = WrapperModel(unet, text_encoder)
     optimizer = optimizer_class(
-        unet.parameters(),  # only optimize unet
+        model.parameters(),
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
@@ -467,8 +477,8 @@ def collate_fn(examples):
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )
 
-    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet, optimizer, train_dataloader, lr_scheduler
+    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        model, optimizer, train_dataloader, lr_scheduler
     )
 
     weight_dtype = torch.float32
@@ -512,9 +522,9 @@ def collate_fn(examples):
 
     global_step = 0
     for epoch in range(args.num_train_epochs):
-        unet.train()
+        model.train()
         for step, batch in enumerate(train_dataloader):
-            with accelerator.accumulate(unet):
+            with accelerator.accumulate(model):
                 # Convert images to latent space
                 with torch.no_grad():
                     latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
@@ -556,7 +566,7 @@ def collate_fn(examples):
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+                    accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
@@ -577,8 +587,11 @@ def collate_fn(examples):
 
     # Create the pipeline using the trained modules and save it.
     if accelerator.is_main_process:
+        unwrapped = accelerator.unwrap_model(model)
         pipeline = StableDiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet)
+            args.pretrained_model_name_or_path,
+            text_encoder=unwrapped.text_encoder if args.train_text_encoder else text_encoder,
+            unet=unwrapped.unet,
         )
         pipeline.save_pretrained(args.output_dir)
 
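
The heart of the change is WrapperModel: assigning unet and, when --train_text_encoder is set, text_encoder as attributes of an nn.Module registers them as submodules, so the single model object carries every trainable parameter through optimizer construction, accelerator.prepare, accelerator.accumulate, model.train(), and gradient clipping. A minimal runnable sketch of that registration pattern follows; ToyUNet and ToyTextEncoder are hypothetical stand-ins for the real UNet and CLIP text encoder, not part of the diff.

import torch
from torch.nn import Linear, Module

class ToyUNet(Module):
    def __init__(self):
        super().__init__()
        self.proj = Linear(4, 4)  # stand-in for the real UNet weights

class ToyTextEncoder(Module):
    def __init__(self):
        super().__init__()
        self.embed = Linear(8, 4)  # stand-in for the real text encoder weights

train_text_encoder = True  # stands in for args.train_text_encoder

class WrapperModel(Module):
    def __init__(self, _unet, _text_encoder):
        super().__init__()
        self.unet = _unet
        if train_text_encoder:
            # Attribute assignment registers the submodule, so its
            # parameters show up in model.parameters() below.
            self.text_encoder = _text_encoder

model = WrapperModel(ToyUNet(), ToyTextEncoder())
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Two tensors (weight, bias) per registered Linear layer.
assert len(list(model.parameters())) == (4 if train_text_encoder else 2)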
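
A note on the design choice: if only the optimizer needed both parameter sets, chaining the two iterators would suffice, as in the hypothetical sketch below. The wrapper is used instead because accelerator.prepare and accelerator.accumulate each expect a module, and a single wrapped model keeps the --train_text_encoder branch out of every call site.

# Hypothetical alternative, not what the diff does; Linear layers again
# stand in for the real models.
import itertools

import torch
from torch.nn import Linear

unet, text_encoder = Linear(4, 4), Linear(8, 4)
optimizer = torch.optim.AdamW(
    itertools.chain(unet.parameters(), text_encoder.parameters()), lr=1e-4
)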