From 1b8972db4a004d5cec0b4a8bfc72ee0e8ee25121 Mon Sep 17 00:00:00 2001 From: Andrew Ishutin Date: Thu, 25 Jan 2024 15:51:17 +0300 Subject: [PATCH] fix custom diffusion training with concept list --- examples/custom_diffusion/train_custom_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py index e48bd2586299..559430eba177 100644 --- a/examples/custom_diffusion/train_custom_diffusion.py +++ b/examples/custom_diffusion/train_custom_diffusion.py @@ -753,7 +753,7 @@ def main(args): num_new_images = args.num_class_images - cur_class_images logger.info(f"Number of class images to sample: {num_new_images}.") - sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataset = PromptDataset(concept["class_prompt"], num_new_images) sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) sample_dataloader = accelerator.prepare(sample_dataloader)