From 6ebc21fcddca2ad43be1cbf6190cdaaf971e91ef Mon Sep 17 00:00:00 2001 From: jainalphin Date: Sun, 12 May 2024 15:01:31 +0530 Subject: [PATCH] Fix conditional teacher model check in train_lcm_distill_lora_sdxl_wds.py --- .../consistency_distillation/train_lcm_distill_lora_sdxl_wds.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index 08d6b23d6deb..ce3e7f624843 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -1358,7 +1358,7 @@ def compute_embeddings( # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE # solver timestep. with torch.no_grad(): - if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path: + if torch.backends.mps.is_available() or "playground" in args.pretrained_teacher_model: autocast_ctx = nullcontext() else: autocast_ctx = torch.autocast(accelerator.device.type)