Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 90 additions & 37 deletions examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def __call__(self, x):
return False


class Text2ImageDataset:
class SDText2ImageDataset:
def __init__(
self,
train_shards_path_or_url: Union[str, List[str]],
Expand Down Expand Up @@ -359,19 +359,43 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=


# Compare LCMScheduler.step, Step 4
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
if prediction_type == "epsilon":
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
pred_x_0 = (sample - sigmas * model_output) / alphas
elif prediction_type == "sample":
pred_x_0 = model_output
elif prediction_type == "v_prediction":
pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
pred_x_0 = alphas * sample - sigmas * model_output
else:
raise ValueError(f"Prediction type {prediction_type} currently not supported.")
raise ValueError(
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
f" are supported."
)

return pred_x_0


# Based on step 4 in DDIMScheduler.step
def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
if prediction_type == "epsilon":
pred_epsilon = model_output
elif prediction_type == "sample":
pred_epsilon = (sample - alphas * model_output) / sigmas
elif prediction_type == "v_prediction":
pred_epsilon = alphas * model_output + sigmas * sample
else:
raise ValueError(
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
f" are supported."
)

return pred_epsilon


def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
Expand Down Expand Up @@ -835,34 +859,35 @@ def main(args):
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
)

# The scheduler calculates the alpha and sigma schedule for us
# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
# Initialize the DDIM ODE solver for distillation.
solver = DDIMSolver(
noise_scheduler.alphas_cumprod.numpy(),
timesteps=noise_scheduler.config.num_train_timesteps,
ddim_timesteps=args.num_ddim_timesteps,
)

# 2. Load tokenizers from SD-XL checkpoint.
# 2. Load tokenizers from SD 1.X/2.X checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
)

# 3. Load text encoders from SD-1.5 checkpoint.
# 3. Load text encoders from SD 1.X/2.X checkpoint.
# import correct text encoder classes
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
)

# 4. Load VAE from SD-XL checkpoint (or more stable VAE)
# 4. Load VAE from SD 1.X/2.X checkpoint
vae = AutoencoderKL.from_pretrained(
args.pretrained_teacher_model,
subfolder="vae",
revision=args.teacher_revision,
)

# 5. Load teacher U-Net from SD-XL checkpoint
# 5. Load teacher U-Net from SD 1.X/2.X checkpoint
teacher_unet = UNet2DConditionModel.from_pretrained(
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
)
Expand All @@ -872,7 +897,7 @@ def main(args):
text_encoder.requires_grad_(False)
teacher_unet.requires_grad_(False)

# 7. Create online (`unet`) student U-Nets.
# 7. Create online student U-Net.
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
)
Expand Down Expand Up @@ -935,6 +960,7 @@ def main(args):
# Also move the alpha and sigma noise schedules to accelerator.device.
alpha_schedule = alpha_schedule.to(accelerator.device)
sigma_schedule = sigma_schedule.to(accelerator.device)
# Move the ODE solver to accelerator.device.
solver = solver.to(accelerator.device)

# 10. Handle saving and loading of checkpoints
Expand Down Expand Up @@ -1011,13 +1037,14 @@ def load_model_hook(models, input_dir):
eps=args.adam_epsilon,
)

# 13. Dataset creation and data processing
# Here, we compute not just the text embeddings but also the additional embeddings
# needed for the SD XL UNet to operate.
def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
return {"prompt_embeds": prompt_embeds}

dataset = Text2ImageDataset(
dataset = SDText2ImageDataset(
train_shards_path_or_url=args.train_shards_path_or_url,
num_train_examples=args.max_train_samples,
per_gpu_batch_size=args.train_batch_size,
Expand All @@ -1037,6 +1064,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
tokenizer=tokenizer,
)

# 14. LR Scheduler creation
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
Expand All @@ -1051,6 +1079,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
num_training_steps=args.max_train_steps,
)

# 15. Prepare for training
# Prepare everything with our `accelerator`.
unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)

Expand All @@ -1072,7 +1101,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
).input_ids.to(accelerator.device)
uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]

# Train!
# 16. Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

logger.info("***** Running training *****")
Expand Down Expand Up @@ -1123,6 +1152,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
for epoch in range(first_epoch, args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
# 1. Load and process the image and text conditioning
image, text = batch

image = image.to(accelerator.device, non_blocking=True)
Expand All @@ -1140,37 +1170,37 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok

latents = latents * vae.config.scaling_factor
latents = latents.to(weight_dtype)

# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]

# Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
# 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
# For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
start_timesteps = solver.ddim_timesteps[index]
timesteps = start_timesteps - topk
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)

# 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]

# 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
# timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
noise = torch.randn_like(latents)
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)

# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
# 5. Sample a random guidance scale w from U[w_min, w_max]
# Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w = w.reshape(bsz, 1, 1, 1)
w = w.to(device=latents.device, dtype=latents.dtype)

# 20.4.8. Prepare prompt embeds and unet_added_conditions
# 6. Prepare prompt embeds and unet_added_conditions
prompt_embeds = encoded_text.pop("prompt_embeds")

# 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
# 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
noise_pred = unet(
noisy_model_input,
start_timesteps,
Expand All @@ -1179,7 +1209,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
added_cond_kwargs=encoded_text,
).sample

pred_x_0 = predicted_origin(
pred_x_0 = get_predicted_original_sample(
noise_pred,
start_timesteps,
noisy_model_input,
Expand All @@ -1190,17 +1220,27 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok

model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0

# 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
# noisy_latents with both the conditioning embedding c and unconditional embedding 0
# Get teacher model prediction on noisy_latents and conditional embedding
# 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
# predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
# solver timestep.
with torch.no_grad():
with torch.autocast("cuda"):
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
cond_teacher_output = teacher_unet(
noisy_model_input.to(weight_dtype),
start_timesteps,
encoder_hidden_states=prompt_embeds.to(weight_dtype),
).sample
cond_pred_x0 = predicted_origin(
cond_pred_x0 = get_predicted_original_sample(
cond_teacher_output,
start_timesteps,
noisy_model_input,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule,
)
cond_pred_noise = get_predicted_noise(
cond_teacher_output,
start_timesteps,
noisy_model_input,
Expand All @@ -1209,13 +1249,21 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
sigma_schedule,
)

# Get teacher model prediction on noisy_latents and unconditional embedding
# 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
uncond_teacher_output = teacher_unet(
noisy_model_input.to(weight_dtype),
start_timesteps,
encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
).sample
uncond_pred_x0 = predicted_origin(
uncond_pred_x0 = get_predicted_original_sample(
uncond_teacher_output,
start_timesteps,
noisy_model_input,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule,
)
uncond_pred_noise = get_predicted_noise(
uncond_teacher_output,
start_timesteps,
noisy_model_input,
Expand All @@ -1224,12 +1272,17 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
sigma_schedule,
)

# 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
# 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
# Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
# 4. Run one step of the ODE solver to estimate the next point x_prev on the
# augmented PF-ODE trajectory (solving backward in time)
# Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
x_prev = solver.ddim_step(pred_x0, pred_noise, index)

# 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
# Note that we do not use a separate target network for LCM-LoRA distillation.
with torch.no_grad():
with torch.autocast("cuda", dtype=weight_dtype):
target_noise_pred = unet(
Expand All @@ -1238,7 +1291,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
timestep_cond=None,
encoder_hidden_states=prompt_embeds.float(),
).sample
pred_x_0 = predicted_origin(
pred_x_0 = get_predicted_original_sample(
target_noise_pred,
timesteps,
x_prev,
Expand All @@ -1248,15 +1301,15 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
)
target = c_skip * x_prev + c_out * pred_x_0

# 20.4.13. Calculate loss
# 10. Calculate loss
if args.loss_type == "l2":
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
elif args.loss_type == "huber":
loss = torch.mean(
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
)

# 20.4.14. Backpropagate on the online student model (`unet`)
# 11. Backpropagate on the online student model (`unet`)
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
Expand Down
Loading