diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index c61c2ae44c8a..5dc6a15bfe01 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -18,6 +18,7 @@
 import math
 import os
 import random
+import warnings
 from pathlib import Path
 from typing import Optional
 
@@ -54,6 +55,9 @@
 from diffusers.utils.import_utils import is_xformers_available
 
 
+if is_wandb_available():
+    import wandb
+
 if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
     PIL_INTERPOLATION = {
         "linear": PIL.Image.Resampling.BILINEAR,
@@ -79,6 +83,65 @@
 logger = get_logger(__name__)
 
 
+def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
+    """Generate validation images with the current weights and log them to the trackers.
+
+    Builds an inference pipeline from the given components, samples `args.num_validation_images`
+    images for `args.validation_prompt`, and sends them to every "tensorboard" / "wandb" tracker
+    in `accelerator.trackers`. The pipeline is deleted and the CUDA cache emptied before returning.
+
+    Args:
+        text_encoder: text encoder under training; unwrapped with `accelerator.unwrap_model`.
+        tokenizer, unet, vae: handed to the pipeline as-is.
+        args: parsed CLI arguments; reads `pretrained_model_name_or_path`, `revision`, `seed`,
+            `validation_prompt` and `num_validation_images`.
+        accelerator: supplies the device and the trackers.
+        weight_dtype: forwarded to `from_pretrained` as `torch_dtype`.
+        epoch: index under which the tensorboard images are recorded.
+    """
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+    # create pipeline around the modules passed in (the unet and vae are reused, not reloaded)
+    pipeline = DiffusionPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        text_encoder=accelerator.unwrap_model(text_encoder),
+        tokenizer=tokenizer,
+        unet=unet,
+        vae=vae,
+        revision=args.revision,
+        torch_dtype=weight_dtype,
+    )
+    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    # run inference
+    generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+    images = []
+    for _ in range(args.num_validation_images):
+        with torch.autocast("cuda"):
+            image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+        images.append(image)
+
+    for tracker in accelerator.trackers:
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+        if tracker.name == "wandb":
+            tracker.log(
+                {
+                    "validation": [
+                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+                    ]
+                }
+            )
+
+    del pipeline
+    torch.cuda.empty_cache()
+
+
 def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
     logger.info("Saving embeddings")
     learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
@@ -268,12 +331,22 @@ def parse_args():
         default=4,
         help="Number of images that should be generated during validation with `validation_prompt`.",
     )
+    parser.add_argument(
+        "--validation_steps",
+        type=int,
+        default=100,
+        help=(
+            "Run validation every X steps. Validation consists of running the prompt"
+            " `args.validation_prompt` multiple times: `args.num_validation_images`"
+            " and logging the images."
+        ),
+    )
     parser.add_argument(
         "--validation_epochs",
         type=int,
-        default=50,
+        default=None,
         help=(
-            "Run validation every X epochs. Validation consists of running the prompt"
+            "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt"
             " `args.validation_prompt` multiple times: `args.num_validation_images`"
             " and logging the images."
         ),
@@ -475,7 +548,6 @@ def main():
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-        import wandb
 
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
@@ -607,6 +679,18 @@ def main():
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
     )
+    if args.validation_epochs is not None:
+        # `validation_steps` is checked against `global_step`, so convert epochs using the number of
+        # batches per epoch, not the number of samples.
+        # NOTE(review): this ignores gradient accumulation / multi-process sharding, if in use.
+        args.validation_steps = args.validation_epochs * len(train_dataloader)
+        warnings.warn(
+            f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}."
+            " Deprecated validation_epochs in favor of `validation_steps`."
+            f" Setting `args.validation_steps` to {args.validation_steps}",
+            FutureWarning,
+            stacklevel=2,
+        )
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
@@ -663,7 +747,6 @@ def main():
     logger.info(f" Total optimization steps = {args.max_train_steps}")
     global_step = 0
     first_epoch = 0
-
     # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint != "latest":
@@ -763,6 +846,8 @@ def main():
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
+                if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+                    log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
@@ -770,53 +855,6 @@ def main():
 
             if global_step >= args.max_train_steps:
                 break
-
-        if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
-            logger.info(
-                f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-                f" {args.validation_prompt}."
-            )
-            # create pipeline (note: unet and vae are loaded again in float32)
-            pipeline = DiffusionPipeline.from_pretrained(
-                args.pretrained_model_name_or_path,
-                text_encoder=accelerator.unwrap_model(text_encoder),
-                tokenizer=tokenizer,
-                unet=unet,
-                vae=vae,
-                revision=args.revision,
-                torch_dtype=weight_dtype,
-            )
-            pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
-            pipeline = pipeline.to(accelerator.device)
-            pipeline.set_progress_bar_config(disable=True)
-
-            # run inference
-            generator = (
-                None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
-            )
-            images = []
-            for _ in range(args.num_validation_images):
-                with torch.autocast("cuda"):
-                    image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
-                images.append(image)
-
-            for tracker in accelerator.trackers:
-                if tracker.name == "tensorboard":
-                    np_images = np.stack([np.asarray(img) for img in images])
-                    tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
-                if tracker.name == "wandb":
-                    tracker.log(
-                        {
-                            "validation": [
-                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                for i, image in enumerate(images)
-                            ]
-                        }
-                    )
-
-            del pipeline
-            torch.cuda.empty_cache()
-
     # Create the pipeline using using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process: