diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 4d921900c059..9838ba8f7d4e 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -995,13 +995,14 @@ def load_model_hook(models, input_dir): progress_bar.update(1) global_step += 1 - if global_step % args.checkpointing_steps == 0: - if accelerator.is_main_process: + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") - if args.validation_prompt is not None and global_step % args.validation_steps == 0: - log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch) + + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) diff --git a/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py b/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py index 459d9e65c5b6..05f714715fc9 100644 --- a/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py +++ b/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py @@ -864,7 +864,7 @@ def main(): if global_step >= args.max_train_steps: break - if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + if accelerator.is_main_process and args.validation_prompt is not None and epoch % args.validation_epochs == 0: logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." diff --git a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py index 9f3c7a42707b..8d2c4c3c0bd4 100644 --- a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py +++ b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py @@ -790,7 +790,7 @@ def main(): if global_step >= args.max_train_steps: break - if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + if accelerator.is_main_process and args.validation_prompt is not None and epoch % args.validation_epochs == 0: logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 22dee5d1d313..ba4450f7d82c 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -800,20 +800,19 @@ def collate_fn(examples): pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] ) - if accelerator.is_main_process: - for tracker in accelerator.trackers: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "validation": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) del pipeline torch.cuda.empty_cache() diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 9bfbf782acd4..92f3d27d4905 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -843,13 +843,14 @@ def main(): save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin") save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path) - if global_step % args.checkpointing_steps == 0: - if accelerator.is_main_process: + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") - if args.validation_prompt is not None and global_step % args.validation_steps == 0: - log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch) + + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs)