Textual inv make save log both steps #2178
Changes from all commits
82b445f
67bb426
eba7a32
fb42636
561f646
6bfae4f
8f060e0
c03a7f6
78257c3
6cd8ad9
6448d4b
3f7a20f
dd60d79
5a94681
95345b5
fffffce
8a9832b
@@ -18,6 +18,7 @@
 import math
 import os
 import random
+import warnings
 from pathlib import Path
 from typing import Optional
@@ -54,6 +55,9 @@
 from diffusers.utils.import_utils import is_xformers_available

+if is_wandb_available():
+    import wandb
+
 if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
     PIL_INTERPOLATION = {
         "linear": PIL.Image.Resampling.BILINEAR,
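The guarded module-level import added above is what lets the new `log_validation` helper reference `wandb` without receiving it as a parameter, which is the concern raised in the review thread on the argparse hunk further down. A minimal sketch of that pattern, separate from the diff; the helper name `report_images` is hypothetical, and the import path for `is_wandb_available` is assumed to match what the diffusers examples use:

# Illustrative sketch, not part of the PR: a module-scope guarded import keeps wandb
# optional while still making it visible to every function in the file.
from diffusers.utils import is_wandb_available  # assumed import path

if is_wandb_available():
    import wandb  # bound at module scope, so helpers below can use it directly


def report_images(images, prompt):
    # Hypothetical helper: only called when the user opted into wandb logging,
    # so the module-level `wandb` name is guaranteed to be bound here.
    wandb.log({"validation": [wandb.Image(img, caption=prompt) for img in images]})

This is also why the function-local `import wandb` is removed later in the diff.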
@@ -79,6 +83,50 @@
 logger = get_logger(__name__)


+def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+    # create pipeline (note: unet and vae are loaded again in float32)
+    pipeline = DiffusionPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        text_encoder=accelerator.unwrap_model(text_encoder),
+        tokenizer=tokenizer,
+        unet=unet,
+        vae=vae,
+        revision=args.revision,
+        torch_dtype=weight_dtype,
+    )
+    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    # run inference
+    generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+    images = []
+    for _ in range(args.num_validation_images):
+        with torch.autocast("cuda"):
+            image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+        images.append(image)
+
+    for tracker in accelerator.trackers:
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+        if tracker.name == "wandb":
+            tracker.log(
+                {
+                    "validation": [
+                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+                    ]
+                }
+            )
+
+    del pipeline
+    torch.cuda.empty_cache()
+
+
 def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
     logger.info("Saving embeddings")
     learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
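For readers unfamiliar with the tracker branch above: `log_validation` stacks the PIL images into a uint8 array of shape (N, H, W, C) before handing it to TensorBoard. A standalone sketch of just that conversion, separate from the diff; the log directory and the synthetic images are made up for illustration:

# Illustrative sketch, not part of the PR: converting a list of PIL images into an
# NHWC batch that SummaryWriter.add_images accepts.
import numpy as np
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

images = [Image.new("RGB", (64, 64), color=(40 * i, 0, 0)) for i in range(4)]  # stand-in images
np_images = np.stack([np.asarray(img) for img in images])  # shape (4, 64, 64, 3), dtype uint8

writer = SummaryWriter("logs/validation_demo")  # hypothetical log directory
writer.add_images("validation", np_images, global_step=0, dataformats="NHWC")
writer.close()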
@@ -268,12 +316,22 @@ def parse_args():
         default=4,
         help="Number of images that should be generated during validation with `validation_prompt`.",
     )
+    parser.add_argument(
+        "--validation_steps",
+        type=int,
+        default=100,
+        help=(
+            "Run validation every X steps. Validation consists of running the prompt"
+            " `args.validation_prompt` multiple times: `args.num_validation_images`"
+            " and logging the images."
+        ),
+    )
     parser.add_argument(
         "--validation_epochs",
         type=int,
-        default=50,
+        default=None,
         help=(
-            "Run validation every X epochs. Validation consists of running the prompt"
+            "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt"
             " `args.validation_prompt` multiple times: `args.num_validation_images`"
             " and logging the images."
         ),

Review thread on the new `--validation_steps` argument:

patil-suraj: Instead of changing the arg, maybe we could introduce an additional arg called `validation_steps`.

Author: @patil-suraj good point! One point that I have: to do this we might want to make the logging a function (because we call it in two places). But if we do that, we need to pass wandb as an argument, as we aren't doing the import in the global scope. I was thinking of doing this before, but it felt a bit hacky. Happy to show what I mean!

Reviewer: Preferring config via steps over epochs makes sense to me 👍

patrickvonplaten: Steps also make sense to me, but let's deprecate `validation_epochs`.

Author: @patrickvonplaten yup, sounds good! Then I will move the code to a function so it's cleaner.

Author: @patrickvonplaten Should be done!

Review thread on the deprecated `--validation_epochs` help text:

patrickvonplaten: Can we add the deprecation statement right below the argparsing, like @williamberman mentioned below, and do something like:

    if args.validation_epochs is not None:
        warnings.warn(f"Deprecate ..... Please make sure to use `validation_steps` instead in the future. Setting `args.validation_steps` to {args.validation_epochs * num_samples_per_epoch}.")
        args.validation_steps = ...

Reviewer: Once this is done ...

Author: @patrickvonplaten Great point! For the length of the validation steps, I put it after the dataset creation, since that is where we count the number of images in the folder and find out the length of each epoch.
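To make the behavior discussed in these threads concrete, here is a minimal, self-contained sketch of the steps-vs-epochs deprecation pattern, separate from the training script itself; `steps_per_epoch` is a hypothetical stand-in for the quantity the script derives from the dataset:

# Illustrative sketch, not part of the PR: a deprecated epoch-based flag is translated
# into the new step-based flag with a FutureWarning.
import argparse
import warnings

parser = argparse.ArgumentParser()
parser.add_argument("--validation_steps", type=int, default=100)
parser.add_argument("--validation_epochs", type=int, default=None)  # deprecated
args = parser.parse_args(["--validation_epochs", "2"])  # simulate a user still passing epochs

steps_per_epoch = 500  # hypothetical value; the real script computes it after building the dataset
if args.validation_epochs is not None:
    warnings.warn(
        "`validation_epochs` is deprecated; please use `validation_steps` instead.",
        FutureWarning,
        stacklevel=2,
    )
    args.validation_steps = args.validation_epochs * steps_per_epoch

print(args.validation_steps)  # 1000

The PR itself applies the same conversion after the dataset is created, using len(train_dataset) as the per-epoch count, as shown in a later hunk.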
@@ -475,7 +533,6 @@
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-        import wandb

     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
@@ -607,6 +664,15 @@
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
     )
+    if args.validation_epochs is not None:
+        warnings.warn(
+            f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}."
+            " Deprecated validation_epochs in favor of `validation_steps`"
+            f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}",
+            FutureWarning,
+            stacklevel=2,
+        )
+        args.validation_steps = args.validation_epochs * len(train_dataset)

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False

Review thread on this hunk:

patrickvonplaten: super!

Author: @patrickvonplaten thank you!
@@ -663,7 +729,6 @@
     logger.info(f" Total optimization steps = {args.max_train_steps}")
     global_step = 0
     first_epoch = 0
-
     # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint != "latest":
@@ -763,60 +828,15 @@
                     save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                     accelerator.save_state(save_path)
                     logger.info(f"Saved state to {save_path}")
+                if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+                    log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

-        if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
-            logger.info(
-                f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-                f" {args.validation_prompt}."
-            )
-            # create pipeline (note: unet and vae are loaded again in float32)
-            pipeline = DiffusionPipeline.from_pretrained(
-                args.pretrained_model_name_or_path,
-                text_encoder=accelerator.unwrap_model(text_encoder),
-                tokenizer=tokenizer,
-                unet=unet,
-                vae=vae,
-                revision=args.revision,
-                torch_dtype=weight_dtype,
-            )
-            pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
-            pipeline = pipeline.to(accelerator.device)
-            pipeline.set_progress_bar_config(disable=True)
-
-            # run inference
-            generator = (
-                None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
-            )
-            images = []
-            for _ in range(args.num_validation_images):
-                with torch.autocast("cuda"):
-                    image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
-                images.append(image)
-
-            for tracker in accelerator.trackers:
-                if tracker.name == "tensorboard":
-                    np_images = np.stack([np.asarray(img) for img in images])
-                    tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
-                if tracker.name == "wandb":
-                    tracker.log(
-                        {
-                            "validation": [
-                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                for i, image in enumerate(images)
-                            ]
-                        }
-                    )
-
-            del pipeline
-            torch.cuda.empty_cache()

     # Create the pipeline using using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process: