examples/text_to_image/train_text_to_image.py (86 changes: 46 additions, 40 deletions)
@@ -1,11 +1,10 @@
import argparse
-import copy
import logging
import math
import os
import random
from pathlib import Path
-from typing import Optional
+from typing import Iterable, Optional

import numpy as np
import torch
@@ -234,25 +233,17 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
}


# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
"""
Exponential Moving Average of models weights
"""

-    def __init__(
-        self,
-        model,
-        decay=0.9999,
-        device=None,
-    ):
-        self.averaged_model = copy.deepcopy(model).eval()
-        self.averaged_model.requires_grad_(False)
+    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+        parameters = list(parameters)
+        self.shadow_params = [p.clone().detach() for p in parameters]
[Inline review comment, Member]: Never heard them being called shadow, but pretty creative :D
self.decay = decay

-        if device is not None:
-            self.averaged_model = self.averaged_model.to(device=device)

self.optimization_step = 0

def get_decay(self, optimization_step):
@@ -263,34 +254,47 @@ def get_decay(self, optimization_step):
return 1 - min(self.decay, value)

@torch.no_grad()
-    def step(self, new_model):
-        ema_state_dict = self.averaged_model.state_dict()
+    def step(self, parameters):
+        parameters = list(parameters)

        self.optimization_step += 1
        self.decay = self.get_decay(self.optimization_step)

-        for key, param in new_model.named_parameters():
-            if isinstance(param, dict):
-                continue
-            try:
-                ema_param = ema_state_dict[key]
-            except KeyError:
-                ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
-                ema_state_dict[key] = ema_param
-
-            param = param.clone().detach().to(ema_param.dtype).to(ema_param.device)
-
+        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
-                ema_state_dict[key].sub_(self.decay * (ema_param - param))
+                tmp = self.decay * (s_param - param)
+                s_param.sub_(tmp)
            else:
-                ema_state_dict[key].copy_(param)
-
-        for key, param in new_model.named_buffers():
-            ema_state_dict[key] = param
+                s_param.copy_(param)

-        self.averaged_model.load_state_dict(ema_state_dict, strict=False)
        torch.cuda.empty_cache()

+    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        """
+        Copy current averaged parameters into given collection of parameters.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored moving averages. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
+        """
+        parameters = list(parameters)
+        for s_param, param in zip(self.shadow_params, parameters):
+            param.data.copy_(s_param.data)
+
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+
+        Args:
+            device: like `device` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+            for p in self.shadow_params
+        ]


def main():
args = parse_args()
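For intuition, here is a small standalone sketch (not part of the PR) of the arithmetic the new step() performs on each shadow parameter: s_param.sub_(decay * (s_param - param)) is an in-place interpolation, equivalent to shadow <- (1 - decay) * shadow + decay * param. The decay value and tensors below are made-up illustrations, not values from the script.

import torch

# Toy illustration of the in-place EMA update used in EMAModel.step():
#     s_param.sub_(decay * (s_param - param))
# which is the same as: s_param <- (1 - decay) * s_param + decay * param
decay = 0.1  # hypothetical value of self.decay after get_decay(); not taken from the PR
s_param = torch.tensor([1.0, 2.0, 3.0])  # shadow (EMA) copy of a parameter
param = torch.tensor([2.0, 2.0, 0.0])    # current value of that parameter

s_param.sub_(decay * (s_param - param))
print(s_param)  # tensor([1.1000, 2.0000, 2.7000]) == 0.9 * old_shadow + 0.1 * param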
@@ -336,9 +340,6 @@ def main():
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

-    if args.use_ema:
-        ema_unet = EMAModel(unet)

# Freeze vae and text_encoder
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
@@ -510,8 +511,9 @@ def collate_fn(examples):
text_encoder.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)

-    # Move the ema_unet to gpu.
-    ema_unet.averaged_model.to(accelerator.device)
+    # Create EMA for the unet.
+    if args.use_ema:
+        ema_unet = EMAModel(unet.parameters())

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -583,7 +585,7 @@ def collate_fn(examples):
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
if args.use_ema:
-                    ema_unet.step(unet)
+                    ema_unet.step(unet.parameters())
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
@@ -598,10 +600,14 @@ def collate_fn(examples):
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
+        unet = accelerator.unwrap_model(unet)
+        if args.use_ema:
+            ema_unet.copy_to(unet.parameters())

pipeline = StableDiffusionPipeline(
text_encoder=text_encoder,
vae=vae,
-            unet=accelerator.unwrap_model(ema_unet.averaged_model if args.use_ema else unet),
+            unet=unet,
tokenizer=tokenizer,
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
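Taken together, the PR moves the training script to a parameter-based EMA workflow: build the EMA from unet.parameters(), update it after each synchronized optimizer step, and copy the averaged weights back before exporting the pipeline. Below is a minimal sketch of that workflow, assuming the EMAModel class defined above has been copied into scope; the Linear model, optimizer, and loss are placeholders standing in for the UNet and the real training objective.

import torch

model = torch.nn.Linear(4, 4)                      # stand-in for the UNet
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
ema = EMAModel(model.parameters())                 # shadow copies of the initial weights
ema.to(device="cpu")                               # optionally move the shadow params

for _ in range(10):
    loss = model(torch.randn(8, 4)).pow(2).mean()  # placeholder objective
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.step(model.parameters())                   # update shadow params after each optimizer step

ema.copy_to(model.parameters())                    # load the averaged weights back before saving

This mirrors the calls introduced in the diff: EMAModel(unet.parameters()) at setup, ema_unet.step(unet.parameters()) under accelerator.sync_gradients, and ema_unet.copy_to(unet.parameters()) before constructing the StableDiffusionPipeline.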