diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index 08d6b23d6deb..ce3e7f624843 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -1358,7 +1358,7 @@ def compute_embeddings( # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE # solver timestep. with torch.no_grad(): - if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path: + if torch.backends.mps.is_available() or "playground" in args.pretrained_teacher_model: autocast_ctx = nullcontext() else: autocast_ctx = torch.autocast(accelerator.device.type)