Skip to content

Commit c62b3a2

Browse files
authored
[Flax] Fix sample batch size DreamBooth (#1129)
fix sample batch size
1 parent bde4880 commit c62b3a2

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

examples/dreambooth/train_dreambooth_flax.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,8 @@ def main():
361361
logger.info(f"Number of class images to sample: {num_new_images}.")
362362

363363
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
364-
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
364+
total_sample_batch_size = args.sample_batch_size * jax.local_device_count()
365+
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=total_sample_batch_size)
365366

366367
for example in tqdm(
367368
sample_dataloader, desc="Generating class images", disable=not jax.process_index() == 0

0 commit comments

Comments
 (0)