diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py index d2652606b551..84493b1d9484 100644 --- a/examples/dreambooth/train_dreambooth_flax.py +++ b/examples/dreambooth/train_dreambooth_flax.py @@ -361,7 +361,8 @@ def main(): logger.info(f"Number of class images to sample: {num_new_images}.") sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + total_sample_batch_size = args.sample_batch_size * jax.local_device_count() + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=total_sample_batch_size) for example in tqdm( sample_dataloader, desc="Generating class images", disable=not jax.process_index() == 0