27 changes: 20 additions & 7 deletions examples/dreambooth/train_dreambooth.py
@@ -6,6 +6,7 @@

import torch
import torch.nn.functional as F
+from torch.nn import Module
import torch.utils.checkpoint
from torch.utils.data import Dataset

@@ -185,6 +186,7 @@ def parse_args():
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument("--train_text_encoder", action="store_true", help="Enable text encoder training.")

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -406,8 +408,16 @@ def main():
    else:
        optimizer_class = torch.optim.AdamW

+    class WrapperModel(Module):
+        def __init__(self, _unet, _text_encoder):
+            super().__init__()
+            self.unet = _unet
+            if args.train_text_encoder:
+                self.text_encoder = _text_encoder
+
+    model = WrapperModel(unet, text_encoder)
    optimizer = optimizer_class(
-        unet.parameters(),  # only optimize unet
+        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
@@ -467,8 +477,8 @@ def collate_fn(examples):
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

-    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet, optimizer, train_dataloader, lr_scheduler
+    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        model, optimizer, train_dataloader, lr_scheduler
    )

    weight_dtype = torch.float32
@@ -512,9 +522,9 @@ def collate_fn(examples):
    global_step = 0

    for epoch in range(args.num_train_epochs):
-        unet.train()
+        model.train()
        for step, batch in enumerate(train_dataloader):
-            with accelerator.accumulate(unet):
+            with accelerator.accumulate(model):
                # Convert images to latent space
                with torch.no_grad():
                    latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
@@ -556,7 +566,7 @@ def collate_fn(examples):

                accelerator.backward(loss)
                if accelerator.sync_gradients:
-                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+                    accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
@@ -577,8 +587,11 @@ def collate_fn(examples):

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
+        unwrapped = accelerator.unwrap_model(model)
        pipeline = StableDiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet)
+            args.pretrained_model_name_or_path,
+            text_encoder=unwrapped.text_encoder if args.train_text_encoder else text_encoder,
+            unet=unwrapped.unet,
        )
        pipeline.save_pretrained(args.output_dir)

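The core of this change is the WrapperModel pattern: the UNet and, only when --train_text_encoder is set, the text encoder are registered as sub-modules of one nn.Module, so a single optimizer and a single accelerator.prepare call cover both parameter sets, and model.train() and gradient clipping apply to both at once. Below is a minimal standalone sketch of the same pattern; the module names are illustrative and not taken from the script.

import torch
from torch.nn import Linear, Module

class Wrapper(Module):
    def __init__(self, backbone, head, train_head):
        super().__init__()
        self.backbone = backbone
        # Register the second sub-module only when it should be trained, mirroring the
        # --train_text_encoder toggle; otherwise its parameters never appear in
        # wrapper.parameters() and the optimizer leaves it untouched.
        if train_head:
            self.head = head

backbone, head = Linear(4, 4), Linear(4, 2)
wrapper = Wrapper(backbone, head, train_head=True)
optimizer = torch.optim.AdamW(wrapper.parameters(), lr=1e-4)  # updates both sub-modules

Because the text encoder is an attribute of the wrapper only when the flag is on, the pipeline-creation code above falls back to the original text_encoder in the untrained case.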