We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent bde4880 commit c62b3a2Copy full SHA for c62b3a2
examples/dreambooth/train_dreambooth_flax.py
@@ -361,7 +361,8 @@ def main():
361
logger.info(f"Number of class images to sample: {num_new_images}.")
362
363
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
364
- sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+ total_sample_batch_size = args.sample_batch_size * jax.local_device_count()
365
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=total_sample_batch_size)
366
367
for example in tqdm(
368
sample_dataloader, desc="Generating class images", disable=not jax.process_index() == 0
0 commit comments