diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py index c96733f0425e..05689b71fa04 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py @@ -156,7 +156,7 @@ def __call__(self, x): return False -class Text2ImageDataset: +class SDText2ImageDataset: def __init__( self, train_shards_path_or_url: Union[str, List[str]], @@ -359,19 +359,43 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling= # Compare LCMScheduler.step, Step 4 -def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas): +def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) if prediction_type == "epsilon": - sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) - alphas = extract_into_tensor(alphas, timesteps, sample.shape) pred_x_0 = (sample - sigmas * model_output) / alphas + elif prediction_type == "sample": + pred_x_0 = model_output elif prediction_type == "v_prediction": - pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output + pred_x_0 = alphas * sample - sigmas * model_output else: - raise ValueError(f"Prediction type {prediction_type} currently not supported.") + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." + ) return pred_x_0 +# Based on step 4 in DDIMScheduler.step +def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + if prediction_type == "epsilon": + pred_epsilon = model_output + elif prediction_type == "sample": + pred_epsilon = (sample - alphas * model_output) / sigmas + elif prediction_type == "v_prediction": + pred_epsilon = alphas * model_output + sigmas * sample + else: + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." + ) + + return pred_epsilon + + def extract_into_tensor(a, t, x_shape): b, *_ = t.shape out = a.gather(-1, t) @@ -835,34 +859,35 @@ def main(args): args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) - # The scheduler calculates the alpha and sigma schedule for us + # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) + # Initialize the DDIM ODE solver for distillation. solver = DDIMSolver( noise_scheduler.alphas_cumprod.numpy(), timesteps=noise_scheduler.config.num_train_timesteps, ddim_timesteps=args.num_ddim_timesteps, ) - # 2. Load tokenizers from SD-XL checkpoint. + # 2. Load tokenizers from SD 1.X/2.X checkpoint. tokenizer = AutoTokenizer.from_pretrained( args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False ) - # 3. Load text encoders from SD-1.5 checkpoint. + # 3. Load text encoders from SD 1.X/2.X checkpoint. # import correct text encoder classes text_encoder = CLIPTextModel.from_pretrained( args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision ) - # 4. Load VAE from SD-XL checkpoint (or more stable VAE) + # 4. Load VAE from SD 1.X/2.X checkpoint vae = AutoencoderKL.from_pretrained( args.pretrained_teacher_model, subfolder="vae", revision=args.teacher_revision, ) - # 5. Load teacher U-Net from SD-XL checkpoint + # 5. Load teacher U-Net from SD 1.X/2.X checkpoint teacher_unet = UNet2DConditionModel.from_pretrained( args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision ) @@ -872,7 +897,7 @@ def main(args): text_encoder.requires_grad_(False) teacher_unet.requires_grad_(False) - # 7. Create online (`unet`) student U-Nets. + # 7. Create online student U-Net. unet = UNet2DConditionModel.from_pretrained( args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision ) @@ -935,6 +960,7 @@ def main(args): # Also move the alpha and sigma noise schedules to accelerator.device. alpha_schedule = alpha_schedule.to(accelerator.device) sigma_schedule = sigma_schedule.to(accelerator.device) + # Move the ODE solver to accelerator.device. solver = solver.to(accelerator.device) # 10. Handle saving and loading of checkpoints @@ -1011,13 +1037,14 @@ def load_model_hook(models, input_dir): eps=args.adam_epsilon, ) + # 13. Dataset creation and data processing # Here, we compute not just the text embeddings but also the additional embeddings # needed for the SD XL UNet to operate. def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True): prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train) return {"prompt_embeds": prompt_embeds} - dataset = Text2ImageDataset( + dataset = SDText2ImageDataset( train_shards_path_or_url=args.train_shards_path_or_url, num_train_examples=args.max_train_samples, per_gpu_batch_size=args.train_batch_size, @@ -1037,6 +1064,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok tokenizer=tokenizer, ) + # 14. LR Scheduler creation # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) @@ -1051,6 +1079,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok num_training_steps=args.max_train_steps, ) + # 15. Prepare for training # Prepare everything with our `accelerator`. unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler) @@ -1072,7 +1101,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok ).input_ids.to(accelerator.device) uncond_prompt_embeds = text_encoder(uncond_input_ids)[0] - # Train! + # 16. Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") @@ -1123,6 +1152,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok for epoch in range(first_epoch, args.num_train_epochs): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): + # 1. Load and process the image and text conditioning image, text = batch image = image.to(accelerator.device, non_blocking=True) @@ -1140,37 +1170,37 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok latents = latents * vae.config.scaling_factor latents = latents.to(weight_dtype) - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) bsz = latents.shape[0] - # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias. + # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias. + # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...] topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long() start_timesteps = solver.ddim_timesteps[index] timesteps = start_timesteps - topk timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) - # 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps. + # 3. Get boundary scalings for start_timesteps and (end) timesteps. c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] c_skip, c_out = scalings_for_boundary_conditions(timesteps) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] - # 20.4.5. Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each + # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + noise = torch.randn_like(latents) noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps) - # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it + # 5. Sample a random guidance scale w from U[w_min, w_max] + # Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min w = w.reshape(bsz, 1, 1, 1) w = w.to(device=latents.device, dtype=latents.dtype) - # 20.4.8. Prepare prompt embeds and unet_added_conditions + # 6. Prepare prompt embeds and unet_added_conditions prompt_embeds = encoded_text.pop("prompt_embeds") - # 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k} + # 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps) noise_pred = unet( noisy_model_input, start_timesteps, @@ -1179,7 +1209,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok added_cond_kwargs=encoded_text, ).sample - pred_x_0 = predicted_origin( + pred_x_0 = get_predicted_original_sample( noise_pred, start_timesteps, noisy_model_input, @@ -1190,17 +1220,27 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0 - # 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after - # noisy_latents with both the conditioning embedding c and unconditional embedding 0 - # Get teacher model prediction on noisy_latents and conditional embedding + # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the + # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these + # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE + # solver timestep. with torch.no_grad(): with torch.autocast("cuda"): + # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c cond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), start_timesteps, encoder_hidden_states=prompt_embeds.to(weight_dtype), ).sample - cond_pred_x0 = predicted_origin( + cond_pred_x0 = get_predicted_original_sample( + cond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + cond_pred_noise = get_predicted_noise( cond_teacher_output, start_timesteps, noisy_model_input, @@ -1209,13 +1249,21 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok sigma_schedule, ) - # Get teacher model prediction on noisy_latents and unconditional embedding + # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0 uncond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), start_timesteps, encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype), ).sample - uncond_pred_x0 = predicted_origin( + uncond_pred_x0 = get_predicted_original_sample( + uncond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + uncond_pred_noise = get_predicted_noise( uncond_teacher_output, start_timesteps, noisy_model_input, @@ -1224,12 +1272,17 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok sigma_schedule, ) - # 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation) + # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise) + # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) - pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) + pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise) + # 4. Run one step of the ODE solver to estimate the next point x_prev on the + # augmented PF-ODE trajectory (solving backward in time) + # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0. x_prev = solver.ddim_step(pred_x0, pred_noise, index) - # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n + # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps) + # Note that we do not use a separate target network for LCM-LoRA distillation. with torch.no_grad(): with torch.autocast("cuda", dtype=weight_dtype): target_noise_pred = unet( @@ -1238,7 +1291,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok timestep_cond=None, encoder_hidden_states=prompt_embeds.float(), ).sample - pred_x_0 = predicted_origin( + pred_x_0 = get_predicted_original_sample( target_noise_pred, timesteps, x_prev, @@ -1248,7 +1301,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok ) target = c_skip * x_prev + c_out * pred_x_0 - # 20.4.13. Calculate loss + # 10. Calculate loss if args.loss_type == "l2": loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") elif args.loss_type == "huber": @@ -1256,7 +1309,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c ) - # 20.4.14. Backpropagate on the online student model (`unet`) + # 11. Backpropagate on the online student model (`unet`) accelerator.backward(loss) if accelerator.sync_gradients: accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index 2ecd6f43dcde..014a770fa0ba 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -162,7 +162,7 @@ def __call__(self, x): return False -class Text2ImageDataset: +class SDXLText2ImageDataset: def __init__( self, train_shards_path_or_url: Union[str, List[str]], @@ -346,19 +346,43 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling= # Compare LCMScheduler.step, Step 4 -def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas): +def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) if prediction_type == "epsilon": - sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) - alphas = extract_into_tensor(alphas, timesteps, sample.shape) pred_x_0 = (sample - sigmas * model_output) / alphas + elif prediction_type == "sample": + pred_x_0 = model_output elif prediction_type == "v_prediction": - pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output + pred_x_0 = alphas * sample - sigmas * model_output else: - raise ValueError(f"Prediction type {prediction_type} currently not supported.") + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." + ) return pred_x_0 +# Based on step 4 in DDIMScheduler.step +def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + if prediction_type == "epsilon": + pred_epsilon = model_output + elif prediction_type == "sample": + pred_epsilon = (sample - alphas * model_output) / sigmas + elif prediction_type == "v_prediction": + pred_epsilon = alphas * model_output + sigmas * sample + else: + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." + ) + + return pred_epsilon + + def extract_into_tensor(a, t, x_shape): b, *_ = t.shape out = a.gather(-1, t) @@ -830,9 +854,10 @@ def main(args): args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) - # The scheduler calculates the alpha and sigma schedule for us + # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) + # Initialize the DDIM ODE solver for distillation. solver = DDIMSolver( noise_scheduler.alphas_cumprod.numpy(), timesteps=noise_scheduler.config.num_train_timesteps, @@ -886,7 +911,7 @@ def main(args): text_encoder_two.requires_grad_(False) teacher_unet.requires_grad_(False) - # 7. Create online (`unet`) student U-Nets. + # 7. Create online student U-Net. unet = UNet2DConditionModel.from_pretrained( args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision ) @@ -950,6 +975,7 @@ def main(args): # Also move the alpha and sigma noise schedules to accelerator.device. alpha_schedule = alpha_schedule.to(accelerator.device) sigma_schedule = sigma_schedule.to(accelerator.device) + # Move the ODE solver to accelerator.device. solver = solver.to(accelerator.device) # 10. Handle saving and loading of checkpoints @@ -1057,7 +1083,7 @@ def compute_embeddings( return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs} - dataset = Text2ImageDataset( + dataset = SDXLText2ImageDataset( train_shards_path_or_url=args.train_shards_path_or_url, num_train_examples=args.max_train_samples, per_gpu_batch_size=args.train_batch_size, @@ -1175,6 +1201,7 @@ def compute_embeddings( for epoch in range(first_epoch, args.num_train_epochs): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): + # 1. Load and process the image, text, and micro-conditioning (original image size, crop coordinates) image, text, orig_size, crop_coords = batch image = image.to(accelerator.device, non_blocking=True) @@ -1196,37 +1223,37 @@ def compute_embeddings( latents = latents * vae.config.scaling_factor if args.pretrained_vae_model_name_or_path is None: latents = latents.to(weight_dtype) - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) bsz = latents.shape[0] - # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias. + # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias. + # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...] topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long() start_timesteps = solver.ddim_timesteps[index] timesteps = start_timesteps - topk timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) - # 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps. + # 3. Get boundary scalings for start_timesteps and (end) timesteps. c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] c_skip, c_out = scalings_for_boundary_conditions(timesteps) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] - # 20.4.5. Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each + # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + noise = torch.randn_like(latents) noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps) - # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it + # 5. Sample a random guidance scale w from U[w_min, w_max] + # Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min w = w.reshape(bsz, 1, 1, 1) w = w.to(device=latents.device, dtype=latents.dtype) - # 20.4.8. Prepare prompt embeds and unet_added_conditions + # 6. Prepare prompt embeds and unet_added_conditions prompt_embeds = encoded_text.pop("prompt_embeds") - # 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k} + # 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps) noise_pred = unet( noisy_model_input, start_timesteps, @@ -1235,7 +1262,7 @@ def compute_embeddings( added_cond_kwargs=encoded_text, ).sample - pred_x_0 = predicted_origin( + pred_x_0 = get_predicted_original_sample( noise_pred, start_timesteps, noisy_model_input, @@ -1246,18 +1273,28 @@ def compute_embeddings( model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0 - # 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after - # noisy_latents with both the conditioning embedding c and unconditional embedding 0 - # Get teacher model prediction on noisy_latents and conditional embedding + # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the + # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these + # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE + # solver timestep. with torch.no_grad(): with torch.autocast("cuda"): + # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c cond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), start_timesteps, encoder_hidden_states=prompt_embeds.to(weight_dtype), added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, ).sample - cond_pred_x0 = predicted_origin( + cond_pred_x0 = get_predicted_original_sample( + cond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + cond_pred_noise = get_predicted_noise( cond_teacher_output, start_timesteps, noisy_model_input, @@ -1266,7 +1303,7 @@ def compute_embeddings( sigma_schedule, ) - # Get teacher model prediction on noisy_latents and unconditional embedding + # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0 uncond_added_conditions = copy.deepcopy(encoded_text) uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds uncond_teacher_output = teacher_unet( @@ -1275,7 +1312,15 @@ def compute_embeddings( encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype), added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()}, ).sample - uncond_pred_x0 = predicted_origin( + uncond_pred_x0 = get_predicted_original_sample( + uncond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + uncond_pred_noise = get_predicted_noise( uncond_teacher_output, start_timesteps, noisy_model_input, @@ -1284,12 +1329,17 @@ def compute_embeddings( sigma_schedule, ) - # 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation) + # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise) + # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) - pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) + pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise) + # 4. Run one step of the ODE solver to estimate the next point x_prev on the + # augmented PF-ODE trajectory (solving backward in time) + # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0. x_prev = solver.ddim_step(pred_x0, pred_noise, index) - # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n + # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps) + # Note that we do not use a separate target network for LCM-LoRA distillation. with torch.no_grad(): with torch.autocast("cuda", enabled=True, dtype=weight_dtype): target_noise_pred = unet( @@ -1299,7 +1349,7 @@ def compute_embeddings( encoder_hidden_states=prompt_embeds.float(), added_cond_kwargs=encoded_text, ).sample - pred_x_0 = predicted_origin( + pred_x_0 = get_predicted_original_sample( target_noise_pred, timesteps, x_prev, @@ -1309,7 +1359,7 @@ def compute_embeddings( ) target = c_skip * x_prev + c_out * pred_x_0 - # 20.4.13. Calculate loss + # 10. Calculate loss if args.loss_type == "l2": loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") elif args.loss_type == "huber": @@ -1317,7 +1367,7 @@ def compute_embeddings( torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c ) - # 20.4.14. Backpropagate on the online student model (`unet`) + # 11. Backpropagate on the online student model (`unet`) accelerator.backward(loss) if accelerator.sync_gradients: accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py index 1dfac0464271..54d05bb5ea26 100644 --- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py @@ -138,7 +138,7 @@ def __call__(self, x): return False -class Text2ImageDataset: +class SDText2ImageDataset: def __init__( self, train_shards_path_or_url: Union[str, List[str]], @@ -336,19 +336,43 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling= # Compare LCMScheduler.step, Step 4 -def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas): +def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) if prediction_type == "epsilon": - sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) - alphas = extract_into_tensor(alphas, timesteps, sample.shape) pred_x_0 = (sample - sigmas * model_output) / alphas + elif prediction_type == "sample": + pred_x_0 = model_output elif prediction_type == "v_prediction": - pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output + pred_x_0 = alphas * sample - sigmas * model_output else: - raise ValueError(f"Prediction type {prediction_type} currently not supported.") + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." + ) return pred_x_0 +# Based on step 4 in DDIMScheduler.step +def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + if prediction_type == "epsilon": + pred_epsilon = model_output + elif prediction_type == "sample": + pred_epsilon = (sample - alphas * model_output) / sigmas + elif prediction_type == "v_prediction": + pred_epsilon = alphas * model_output + sigmas * sample + else: + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." + ) + + return pred_epsilon + + def extract_into_tensor(a, t, x_shape): b, *_ = t.shape out = a.gather(-1, t) @@ -823,34 +847,35 @@ def main(args): args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) - # The scheduler calculates the alpha and sigma schedule for us + # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) + # Initialize the DDIM ODE solver for distillation. solver = DDIMSolver( noise_scheduler.alphas_cumprod.numpy(), timesteps=noise_scheduler.config.num_train_timesteps, ddim_timesteps=args.num_ddim_timesteps, ) - # 2. Load tokenizers from SD-XL checkpoint. + # 2. Load tokenizers from SD 1.X/2.X checkpoint. tokenizer = AutoTokenizer.from_pretrained( args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False ) - # 3. Load text encoders from SD-1.5 checkpoint. + # 3. Load text encoders from SD 1.X/2.X checkpoint. # import correct text encoder classes text_encoder = CLIPTextModel.from_pretrained( args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision ) - # 4. Load VAE from SD-XL checkpoint (or more stable VAE) + # 4. Load VAE from SD 1.X/2.X checkpoint vae = AutoencoderKL.from_pretrained( args.pretrained_teacher_model, subfolder="vae", revision=args.teacher_revision, ) - # 5. Load teacher U-Net from SD-XL checkpoint + # 5. Load teacher U-Net from SD 1.X/2.X checkpoint teacher_unet = UNet2DConditionModel.from_pretrained( args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision ) @@ -860,7 +885,7 @@ def main(args): text_encoder.requires_grad_(False) teacher_unet.requires_grad_(False) - # 8. Create online (`unet`) student U-Nets. This will be updated by the optimizer (e.g. via backpropagation.) + # 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.) # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None if teacher_unet.config.time_cond_proj_dim is None: teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim @@ -869,8 +894,8 @@ def main(args): unet.load_state_dict(teacher_unet.state_dict(), strict=False) unet.train() - # 9. Create target (`ema_unet`) student U-Net parameters. This will be updated via EMA updates (polyak averaging). - # Initialize from unet + # 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging). + # Initialize from (online) unet target_unet = UNet2DConditionModel(**teacher_unet.config) target_unet.load_state_dict(unet.state_dict()) target_unet.train() @@ -887,7 +912,7 @@ def main(args): f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" ) - # 10. Handle mixed precision and device placement + # 9. Handle mixed precision and device placement # For mixed precision training we cast all non-trainable weigths to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 @@ -914,7 +939,7 @@ def main(args): sigma_schedule = sigma_schedule.to(accelerator.device) solver = solver.to(accelerator.device) - # 11. Handle saving and loading of checkpoints + # 10. Handle saving and loading of checkpoints # `accelerate` 0.16.0 will have better support for customized saving if version.parse(accelerate.__version__) >= version.parse("0.16.0"): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format @@ -948,7 +973,7 @@ def load_model_hook(models, input_dir): accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) - # 12. Enable optimizations + # 11. Enable optimizations if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): import xformers @@ -994,13 +1019,14 @@ def load_model_hook(models, input_dir): eps=args.adam_epsilon, ) + # 13. Dataset creation and data processing # Here, we compute not just the text embeddings but also the additional embeddings # needed for the SD XL UNet to operate. def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True): prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train) return {"prompt_embeds": prompt_embeds} - dataset = Text2ImageDataset( + dataset = SDText2ImageDataset( train_shards_path_or_url=args.train_shards_path_or_url, num_train_examples=args.max_train_samples, per_gpu_batch_size=args.train_batch_size, @@ -1020,6 +1046,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok tokenizer=tokenizer, ) + # 14. LR Scheduler creation # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) @@ -1034,6 +1061,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok num_training_steps=args.max_train_steps, ) + # 15. Prepare for training # Prepare everything with our `accelerator`. unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler) @@ -1055,7 +1083,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok ).input_ids.to(accelerator.device) uncond_prompt_embeds = text_encoder(uncond_input_ids)[0] - # Train! + # 16. Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") @@ -1106,6 +1134,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok for epoch in range(first_epoch, args.num_train_epochs): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): + # 1. Load and process the image and text conditioning image, text = batch image = image.to(accelerator.device, non_blocking=True) @@ -1123,29 +1152,28 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok latents = latents * vae.config.scaling_factor latents = latents.to(weight_dtype) - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) bsz = latents.shape[0] - # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias. + # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias. + # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...] topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long() start_timesteps = solver.ddim_timesteps[index] timesteps = start_timesteps - topk timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) - # 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps. + # 3. Get boundary scalings for start_timesteps and (end) timesteps. c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] c_skip, c_out = scalings_for_boundary_conditions(timesteps) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] - # 20.4.5. Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each + # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + noise = torch.randn_like(latents) noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps) - # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it + # 5. Sample a random guidance scale w from U[w_min, w_max] and embed it w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim) w = w.reshape(bsz, 1, 1, 1) @@ -1153,10 +1181,10 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok w = w.to(device=latents.device, dtype=latents.dtype) w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype) - # 20.4.8. Prepare prompt embeds and unet_added_conditions + # 6. Prepare prompt embeds and unet_added_conditions prompt_embeds = encoded_text.pop("prompt_embeds") - # 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k} + # 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps) noise_pred = unet( noisy_model_input, start_timesteps, @@ -1165,7 +1193,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok added_cond_kwargs=encoded_text, ).sample - pred_x_0 = predicted_origin( + pred_x_0 = get_predicted_original_sample( noise_pred, start_timesteps, noisy_model_input, @@ -1176,17 +1204,27 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0 - # 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after - # noisy_latents with both the conditioning embedding c and unconditional embedding 0 - # Get teacher model prediction on noisy_latents and conditional embedding + # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the + # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these + # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE + # solver timestep. with torch.no_grad(): with torch.autocast("cuda"): + # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c cond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), start_timesteps, encoder_hidden_states=prompt_embeds.to(weight_dtype), ).sample - cond_pred_x0 = predicted_origin( + cond_pred_x0 = get_predicted_original_sample( + cond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + cond_pred_noise = get_predicted_noise( cond_teacher_output, start_timesteps, noisy_model_input, @@ -1195,13 +1233,21 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok sigma_schedule, ) - # Get teacher model prediction on noisy_latents and unconditional embedding + # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0 uncond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), start_timesteps, encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype), ).sample - uncond_pred_x0 = predicted_origin( + uncond_pred_x0 = get_predicted_original_sample( + uncond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + uncond_pred_noise = get_predicted_noise( uncond_teacher_output, start_timesteps, noisy_model_input, @@ -1210,12 +1256,16 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok sigma_schedule, ) - # 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation) + # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise) + # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) - pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) + pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise) + # 4. Run one step of the ODE solver to estimate the next point x_prev on the + # augmented PF-ODE trajectory (solving backward in time) + # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0. x_prev = solver.ddim_step(pred_x0, pred_noise, index) - # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n + # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps) with torch.no_grad(): with torch.autocast("cuda", dtype=weight_dtype): target_noise_pred = target_unet( @@ -1224,7 +1274,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok timestep_cond=w_embedding, encoder_hidden_states=prompt_embeds.float(), ).sample - pred_x_0 = predicted_origin( + pred_x_0 = get_predicted_original_sample( target_noise_pred, timesteps, x_prev, @@ -1234,7 +1284,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok ) target = c_skip * x_prev + c_out * pred_x_0 - # 20.4.13. Calculate loss + # 10. Calculate loss if args.loss_type == "l2": loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") elif args.loss_type == "huber": @@ -1242,7 +1292,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c ) - # 20.4.14. Backpropagate on the online student model (`unet`) + # 11. Backpropagate on the online student model (`unet`) accelerator.backward(loss) if accelerator.sync_gradients: accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) @@ -1252,7 +1302,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: - # 20.4.15. Make EMA update to target student model parameters + # 12. Make EMA update to target student model parameters (`target_unet`) update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay) progress_bar.update(1) global_step += 1 diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py index 952bec67d148..e58db46c9811 100644 --- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py @@ -144,7 +144,7 @@ def __call__(self, x): return False -class Text2ImageDataset: +class SDXLText2ImageDataset: def __init__( self, train_shards_path_or_url: Union[str, List[str]], @@ -324,19 +324,43 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling= # Compare LCMScheduler.step, Step 4 -def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas): +def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) if prediction_type == "epsilon": - sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) - alphas = extract_into_tensor(alphas, timesteps, sample.shape) pred_x_0 = (sample - sigmas * model_output) / alphas + elif prediction_type == "sample": + pred_x_0 = model_output elif prediction_type == "v_prediction": - pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output + pred_x_0 = alphas * sample - sigmas * model_output else: - raise ValueError(f"Prediction type {prediction_type} currently not supported.") + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." + ) return pred_x_0 +# Based on step 4 in DDIMScheduler.step +def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + if prediction_type == "epsilon": + pred_epsilon = model_output + elif prediction_type == "sample": + pred_epsilon = (sample - alphas * model_output) / sigmas + elif prediction_type == "v_prediction": + pred_epsilon = alphas * model_output + sigmas * sample + else: + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." + ) + + return pred_epsilon + + def extract_into_tensor(a, t, x_shape): b, *_ = t.shape out = a.gather(-1, t) @@ -863,9 +887,10 @@ def main(args): args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) - # The scheduler calculates the alpha and sigma schedule for us + # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) + # Initialize the DDIM ODE solver for distillation. solver = DDIMSolver( noise_scheduler.alphas_cumprod.numpy(), timesteps=noise_scheduler.config.num_train_timesteps, @@ -919,7 +944,7 @@ def main(args): text_encoder_two.requires_grad_(False) teacher_unet.requires_grad_(False) - # 8. Create online (`unet`) student U-Nets. This will be updated by the optimizer (e.g. via backpropagation.) + # 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.) # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None if teacher_unet.config.time_cond_proj_dim is None: teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim @@ -928,8 +953,8 @@ def main(args): unet.load_state_dict(teacher_unet.state_dict(), strict=False) unet.train() - # 9. Create target (`ema_unet`) student U-Net parameters. This will be updated via EMA updates (polyak averaging). - # Initialize from unet + # 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging). + # Initialize from (online) unet target_unet = UNet2DConditionModel(**teacher_unet.config) target_unet.load_state_dict(unet.state_dict()) target_unet.train() @@ -971,6 +996,7 @@ def main(args): # Also move the alpha and sigma noise schedules to accelerator.device. alpha_schedule = alpha_schedule.to(accelerator.device) sigma_schedule = sigma_schedule.to(accelerator.device) + # Move the ODE solver to accelerator.device. solver = solver.to(accelerator.device) # 10. Handle saving and loading of checkpoints @@ -1084,7 +1110,7 @@ def compute_embeddings( return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs} - dataset = Text2ImageDataset( + dataset = SDXLText2ImageDataset( train_shards_path_or_url=args.train_shards_path_or_url, num_train_examples=args.max_train_samples, per_gpu_batch_size=args.train_batch_size, @@ -1202,6 +1228,7 @@ def compute_embeddings( for epoch in range(first_epoch, args.num_train_epochs): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): + # 1. Load and process the image, text, and micro-conditioning (original image size, crop coordinates) image, text, orig_size, crop_coords = batch image = image.to(accelerator.device, non_blocking=True) @@ -1223,38 +1250,39 @@ def compute_embeddings( latents = latents * vae.config.scaling_factor if args.pretrained_vae_model_name_or_path is None: latents = latents.to(weight_dtype) - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) bsz = latents.shape[0] - # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias. + # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias. + # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...] topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long() start_timesteps = solver.ddim_timesteps[index] timesteps = start_timesteps - topk timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) - # 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps. + # 3. Get boundary scalings for start_timesteps and (end) timesteps. c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] c_skip, c_out = scalings_for_boundary_conditions(timesteps) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] - # 20.4.5. Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each + # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + noise = torch.randn_like(latents) noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps) - # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it + # 5. Sample a random guidance scale w from U[w_min, w_max] and embed it w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim) w = w.reshape(bsz, 1, 1, 1) + # Move to U-Net device and dtype w = w.to(device=latents.device, dtype=latents.dtype) + w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype) - # 20.4.8. Prepare prompt embeds and unet_added_conditions + # 6. Prepare prompt embeds and unet_added_conditions prompt_embeds = encoded_text.pop("prompt_embeds") - # 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k} + # 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps) noise_pred = unet( noisy_model_input, start_timesteps, @@ -1263,7 +1291,7 @@ def compute_embeddings( added_cond_kwargs=encoded_text, ).sample - pred_x_0 = predicted_origin( + pred_x_0 = get_predicted_original_sample( noise_pred, start_timesteps, noisy_model_input, @@ -1274,18 +1302,28 @@ def compute_embeddings( model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0 - # 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after - # noisy_latents with both the conditioning embedding c and unconditional embedding 0 - # Get teacher model prediction on noisy_latents and conditional embedding + # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the + # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these + # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE + # solver timestep. with torch.no_grad(): with torch.autocast("cuda"): + # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c cond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), start_timesteps, encoder_hidden_states=prompt_embeds.to(weight_dtype), added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, ).sample - cond_pred_x0 = predicted_origin( + cond_pred_x0 = get_predicted_original_sample( + cond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + cond_pred_noise = get_predicted_noise( cond_teacher_output, start_timesteps, noisy_model_input, @@ -1294,7 +1332,7 @@ def compute_embeddings( sigma_schedule, ) - # Get teacher model prediction on noisy_latents and unconditional embedding + # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0 uncond_added_conditions = copy.deepcopy(encoded_text) uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds uncond_teacher_output = teacher_unet( @@ -1303,7 +1341,15 @@ def compute_embeddings( encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype), added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()}, ).sample - uncond_pred_x0 = predicted_origin( + uncond_pred_x0 = get_predicted_original_sample( + uncond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + uncond_pred_noise = get_predicted_noise( uncond_teacher_output, start_timesteps, noisy_model_input, @@ -1312,12 +1358,16 @@ def compute_embeddings( sigma_schedule, ) - # 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation) + # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise) + # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) - pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) + pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise) + # 4. Run one step of the ODE solver to estimate the next point x_prev on the + # augmented PF-ODE trajectory (solving backward in time) + # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0. x_prev = solver.ddim_step(pred_x0, pred_noise, index) - # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n + # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps) with torch.no_grad(): with torch.autocast("cuda", dtype=weight_dtype): target_noise_pred = target_unet( @@ -1327,7 +1377,7 @@ def compute_embeddings( encoder_hidden_states=prompt_embeds.float(), added_cond_kwargs=encoded_text, ).sample - pred_x_0 = predicted_origin( + pred_x_0 = get_predicted_original_sample( target_noise_pred, timesteps, x_prev, @@ -1337,7 +1387,7 @@ def compute_embeddings( ) target = c_skip * x_prev + c_out * pred_x_0 - # 20.4.13. Calculate loss + # 10. Calculate loss if args.loss_type == "l2": loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") elif args.loss_type == "huber": @@ -1345,7 +1395,7 @@ def compute_embeddings( torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c ) - # 20.4.14. Backpropagate on the online student model (`unet`) + # 11. Backpropagate on the online student model (`unet`) accelerator.backward(loss) if accelerator.sync_gradients: accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) @@ -1355,7 +1405,7 @@ def compute_embeddings( # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: - # 20.4.15. Make EMA update to target student model parameters + # 12. Make EMA update to target student model parameters (`target_unet`) update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay) progress_bar.update(1) global_step += 1