Skip to content

Commit ebd4495

Browse files
image generation main process checks (#2631)
1 parent e2d9a9b commit ebd4495

File tree

5 files changed

+25 additions, -24 deletions (lines changed)

5 files changed

+25
-24
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1000,13 +1000,14 @@ def load_model_hook(models, input_dir):
10001000
progress_bar.update(1)
10011001
global_step += 1
10021002

1003-
if global_step % args.checkpointing_steps == 0:
1004-
if accelerator.is_main_process:
1003+
if accelerator.is_main_process:
1004+
if global_step % args.checkpointing_steps == 0:
10051005
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
10061006
accelerator.save_state(save_path)
10071007
logger.info(f"Saved state to {save_path}")
1008-
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
1009-
log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
1008+
1009+
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
1010+
log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
10101011

10111012
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
10121013
progress_bar.set_postfix(**logs)

examples/research_projects/mulit_token_textual_inversion/textual_inversion.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -864,7 +864,7 @@ def main():
864864
if global_step >= args.max_train_steps:
865865
break
866866

867-
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
867+
if accelerator.is_main_process and args.validation_prompt is not None and epoch % args.validation_epochs == 0:
868868
logger.info(
869869
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
870870
f" {args.validation_prompt}."

examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -790,7 +790,7 @@ def main():
790790
if global_step >= args.max_train_steps:
791791
break
792792

793-
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
793+
if accelerator.is_main_process and args.validation_prompt is not None and epoch % args.validation_epochs == 0:
794794
logger.info(
795795
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
796796
f" {args.validation_prompt}."

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 13 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -800,20 +800,19 @@ def collate_fn(examples):
800800
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
801801
)
802802

803-
if accelerator.is_main_process:
804-
for tracker in accelerator.trackers:
805-
if tracker.name == "tensorboard":
806-
np_images = np.stack([np.asarray(img) for img in images])
807-
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
808-
if tracker.name == "wandb":
809-
tracker.log(
810-
{
811-
"validation": [
812-
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
813-
for i, image in enumerate(images)
814-
]
815-
}
816-
)
803+
for tracker in accelerator.trackers:
804+
if tracker.name == "tensorboard":
805+
np_images = np.stack([np.asarray(img) for img in images])
806+
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
807+
if tracker.name == "wandb":
808+
tracker.log(
809+
{
810+
"validation": [
811+
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
812+
for i, image in enumerate(images)
813+
]
814+
}
815+
)
817816

818817
del pipeline
819818
torch.cuda.empty_cache()

examples/textual_inversion/textual_inversion.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -843,13 +843,14 @@ def main():
843843
save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
844844
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
845845

846-
if global_step % args.checkpointing_steps == 0:
847-
if accelerator.is_main_process:
846+
if accelerator.is_main_process:
847+
if global_step % args.checkpointing_steps == 0:
848848
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
849849
accelerator.save_state(save_path)
850850
logger.info(f"Saved state to {save_path}")
851-
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
852-
log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
851+
852+
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
853+
log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
853854

854855
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
855856
progress_bar.set_postfix(**logs)

0 commit comments

Comments
 (0)