Textual inv make save log both steps #2178

Merged
17 commits merged on Mar 2, 2023
122 changes: 71 additions & 51 deletions examples/textual_inversion/textual_inversion.py
@@ -18,6 +18,7 @@
import math
import os
import random
import warnings
from pathlib import Path
from typing import Optional

@@ -54,6 +55,9 @@
from diffusers.utils.import_utils import is_xformers_available


if is_wandb_available():
import wandb

if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
PIL_INTERPOLATION = {
"linear": PIL.Image.Resampling.BILINEAR,
@@ -79,6 +83,50 @@
logger = get_logger(__name__)


def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
# create pipeline (note: unet and vae are loaded again in float32)
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=accelerator.unwrap_model(text_encoder),
tokenizer=tokenizer,
unet=unet,
vae=vae,
revision=args.revision,
torch_dtype=weight_dtype,
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)

# run inference
generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
images = []
for _ in range(args.num_validation_images):
with torch.autocast("cuda"):
image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
images.append(image)

for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
"validation": [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
]
}
)

del pipeline
torch.cuda.empty_cache()


def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
logger.info("Saving embeddings")
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
@@ -268,12 +316,22 @@ def parse_args():
default=4,
help="Number of images that should be generated during validation with `validation_prompt`.",
)
parser.add_argument(
"--validation_steps",
Contributor:

Using steps here (and in the dreambooth script) makes sense to me, and it also aligns well with save_steps. But removing the existing arg would be a breaking change.

Instead of changing the arg, maybe we could introduce an additional arg called validation_steps and allow only one of them to be set. Inside the script we would always use validation_steps, which we can compute easily from the epochs value. And to align saving and validating, we could prefer the validation_steps argument in the docs. (A rough sketch of this idea follows below.)

wdyt @patrickvonplaten @williamberman @pcuenca?
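A minimal sketch of that suggestion, for illustration only; the helper name resolve_validation_steps and the num_update_steps_per_epoch argument are hypothetical and not part of the merged script (which keeps a default for validation_steps instead of requiring exactly one flag):

def resolve_validation_steps(args, num_update_steps_per_epoch):
    # Hypothetical helper: allow only one of the two flags and normalize everything to steps.
    if args.validation_steps is not None and args.validation_epochs is not None:
        raise ValueError("Only one of `--validation_steps` and `--validation_epochs` may be set.")
    if args.validation_epochs is not None:
        # Convert the epoch-based setting into steps using the per-epoch step count.
        return args.validation_epochs * num_update_steps_per_epoch
    return args.validation_steps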

Contributor Author:

@patil-suraj good point! One thing I'd note: to do this we might want to make the logging a function (because it is called in 2 places).

But if we do that, we need to pass wandb as an argument, since we aren't doing the import in the global scope. I was thinking of doing this before, but it felt a bit hacky. Happy to show what I mean!
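For reference, one way the "pass wandb as an argument" idea could look; this is a hypothetical sketch (the helper name log_validation_images and the wandb_module parameter are illustrative), and the merged PR instead imports wandb at module level behind is_wandb_available(), so no extra parameter is needed:

import numpy as np

def log_validation_images(accelerator, images, prompt, epoch, wandb_module=None):
    # Hypothetical sketch: the wandb module is injected instead of imported globally.
    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
        if tracker.name == "wandb" and wandb_module is not None:
            tracker.log(
                {"validation": [wandb_module.Image(img, caption=f"{i}: {prompt}") for i, img in enumerate(images)]}
            )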

Contributor:

Preferring config via steps over epochs makes sense to me 👍

Contributor:

Steps also make sense to me, but let's deprecate validation_epochs nicely. Could we keep validation_epochs and just add a new validation_steps? Then we would convert validation_epochs to validation_steps if it is used and throw a deprecation warning. Would that be ok for you, @isamu-isozaki?

Contributor Author:

@patrickvonplaten yup, sounds good! I'll move the code to a function so it's cleaner.

Contributor Author:

@patrickvonplaten Should be done!

type=int,
default=100,
help=(
"Run validation every X steps. Validation consists of running the prompt"
" `args.validation_prompt` multiple times: `args.num_validation_images`"
" and logging the images."
),
)
parser.add_argument(
"--validation_epochs",
type=int,
default=50,
default=None,
help=(
"Run validation every X epochs. Validation consists of running the prompt"
"Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt"
Contributor:

Can we add the deprecation statement right below the arg parsing, like @williamberman mentioned below, and do something like:

if args.validation_epochs is not None:
    warnings.warn(f"Deprecate ..... Please make sure to use `validation_steps` instead in the future. Setting `args.validation_steps` to {args.validation_epochs * num_samples_per_epoch}.")
    args.validation_steps = args.validation_epochs * num_samples_per_epoch

Contributor:

Once this is done, args.validation_epochs doesn't have to be used anymore, and it will also be much easier to remove in the future.

Contributor Author:

@patrickvonplaten Great point! For the length of the validation steps, I put it after the dataset creation, since that is where we count the number of images in the folder and find out the length of each epoch.

" `args.validation_prompt` multiple times: `args.num_validation_images`"
" and logging the images."
),
@@ -475,7 +533,6 @@ def main():
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb

# Make one log on every process with the configuration for debugging.
logging.basicConfig(
@@ -607,6 +664,15 @@ def main():
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
)
if args.validation_epochs is not None:
warnings.warn(
f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}."
" Deprecated validation_epochs in favor of `validation_steps`"
f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}",
FutureWarning,
stacklevel=2,
)
args.validation_steps = args.validation_epochs * len(train_dataset)
Contributor:

super!

Contributor Author:

@patrickvonplaten thank you!


# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
@@ -663,7 +729,6 @@ def main():
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0

# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
@@ -763,60 +828,15 @@ def main():
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)

logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)

if global_step >= args.max_train_steps:
break

if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
# create pipeline (note: unet and vae are loaded again in float32)
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=accelerator.unwrap_model(text_encoder),
tokenizer=tokenizer,
unet=unet,
vae=vae,
revision=args.revision,
torch_dtype=weight_dtype,
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)

# run inference
generator = (
None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
)
images = []
for _ in range(args.num_validation_images):
with torch.autocast("cuda"):
image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
images.append(image)

for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
"validation": [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
for i, image in enumerate(images)
]
}
)

del pipeline
torch.cuda.empty_cache()

# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process: