From eb1ee61407b67ddb077da76606e17d709592750f Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Sat, 2 Mar 2024 11:35:36 +0100 Subject: [PATCH 01/10] add edm style training --- .../train_dreambooth_lora_sdxl_advanced.py | 105 ++++++++++++++++-- 1 file changed, 93 insertions(+), 12 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 94a32bcc07f8..7b4508ffe4a3 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -14,9 +14,11 @@ # See the License for the specific language governing permissions and import argparse +import contextlib import gc import hashlib import itertools +import json import logging import math import os @@ -37,7 +39,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed -from huggingface_hub import create_repo, upload_folder +from huggingface_hub import create_repo, hf_hub_download, upload_folder from packaging import version from peft import LoraConfig, set_peft_model_state_dict from peft.utils import get_peft_model_state_dict @@ -55,6 +57,8 @@ AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, + EDMEulerScheduler, + EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel, ) @@ -78,6 +82,18 @@ logger = get_logger(__name__) +def determine_scheduler_type(pretrained_model_name_or_path, revision): + model_index_filename = "model_index.json" + if os.path.isdir(pretrained_model_name_or_path): + model_index = os.path.join(pretrained_model_name_or_path, model_index_filename) + else: + model_index = hf_hub_download( + repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision + ) + + with open(model_index, "r") as f: + scheduler_type = 
json.load(f)["scheduler"][1] + return scheduler_type def save_model_card( repo_id: str, @@ -1776,6 +1792,18 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): disable=not accelerator.is_local_main_process, ) + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + if args.train_text_encoder: num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs) elif args.train_text_encoder_ti: # args.train_text_encoder_ti @@ -1827,9 +1855,16 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): pixel_values = batch["pixel_values"].to(dtype=vae.dtype) model_input = vae.encode(pixel_values).latent_dist.sample() - model_input = model_input * vae_scaling_factor - if args.pretrained_vae_model_name_or_path is None: - model_input = model_input.to(weight_dtype) + if latents_mean is None and latents_std is None: + model_input = model_input * vae.config.scaling_factor + if args.pretrained_vae_model_name_or_path is None: + model_input = model_input.to(weight_dtype) + else: + latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype) + latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype) + model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std + model_input = model_input.to(dtype=weight_dtype) + # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) @@ -1840,15 +1875,33 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) bsz = model_input.shape[0] + # Sample a random timestep for each image - timesteps = 
torch.randint( - 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device - ) - timesteps = timesteps.long() + if not args.do_edm_style_training: + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + ) + timesteps = timesteps.long() + else: + # in EDM formulation, the model is conditioned on the pre-conditioned noise levels + # instead of discrete timesteps, so here we sample indices to get the noise levels + # from `scheduler.timesteps` + indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,)) + timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device) + # Add noise to the model input according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + # For EDM-style training, we first obtain the sigmas based on the continuous timesteps. + # We then precondition the final model inputs based on these sigmas instead of the timesteps. + # Follow: Section 5 of https://arxiv.org/abs/2206.00364. 
+ if args.do_edm_style_training: + sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype) + if "EDM" in scheduler_type: + inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas) + else: + inp_noisy_latents = noisy_model_input / ((sigmas ** 2 + 1) ** 0.5) # time ids add_time_ids = torch.cat( @@ -1874,7 +1927,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): } prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) model_pred = unet( - noisy_model_input, + inp_noisy_latents if args.do_edm_style_training else noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions, @@ -1892,14 +1945,42 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) model_pred = unet( - noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions + inp_noisy_latents if args.do_edm_style_training else noisy_model_input, + timesteps, + prompt_embeds_input, + added_cond_kwargs=unet_added_conditions ).sample + weighting = None + if args.do_edm_style_training: + # Similar to the input preconditioning, the model predictions are also preconditioned + # on noised model inputs (before preconditioning) and the sigmas. + # Follow: Section 5 of https://arxiv.org/abs/2206.00364. + if "EDM" in scheduler_type: + model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas) + else: + if noise_scheduler.config.prediction_type == "epsilon": + model_pred = model_pred * (-sigmas) + noisy_model_input + elif noise_scheduler.config.prediction_type == "v_prediction": + model_pred = model_pred * (-sigmas / (sigmas ** 2 + 1) ** 0.5) + ( + noisy_model_input / (sigmas ** 2 + 1) + ) + # We are not doing weighting here because it tends result in numerical problems. 
+ # See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051 + # There might be other alternatives for weighting as well: + # https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686 + if "EDM" not in scheduler_type: + weighting = (sigmas ** -2.0).float() + # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": - target = noise + target = model_input if args.do_edm_style_training else noise elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(model_input, noise, timesteps) + target = ( + model_input + if args.do_edm_style_training + else noise_scheduler.get_velocity(model_input, noise, timesteps) + ) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") From 60a35c92c5ea509833c0fab81d7e8d4e36c7e335 Mon Sep 17 00:00:00 2001 From: Linoy Date: Sat, 2 Mar 2024 10:42:22 +0000 Subject: [PATCH 02/10] style --- .../train_dreambooth_lora_sdxl_advanced.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 7b4508ffe4a3..ee495f58d395 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and import argparse -import contextlib import gc import hashlib import itertools @@ -57,8 +56,6 @@ AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, - EDMEulerScheduler, - EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel, ) @@ -82,6 +79,7 @@ logger = get_logger(__name__) + def determine_scheduler_type(pretrained_model_name_or_path, revision): model_index_filename = "model_index.json" if 
os.path.isdir(pretrained_model_name_or_path): @@ -95,6 +93,7 @@ def determine_scheduler_type(pretrained_model_name_or_path, revision): scheduler_type = json.load(f)["scheduler"][1] return scheduler_type + def save_model_card( repo_id: str, images=None, @@ -1865,7 +1864,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std model_input = model_input.to(dtype=weight_dtype) - # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) if args.noise_offset: @@ -1889,7 +1887,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,)) timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device) - # Add noise to the model input according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) @@ -1901,7 +1898,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if "EDM" in scheduler_type: inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas) else: - inp_noisy_latents = noisy_model_input / ((sigmas ** 2 + 1) ** 0.5) + inp_noisy_latents = noisy_model_input / ((sigmas**2 + 1) ** 0.5) # time ids add_time_ids = torch.cat( @@ -1948,7 +1945,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): inp_noisy_latents if args.do_edm_style_training else noisy_model_input, timesteps, prompt_embeds_input, - added_cond_kwargs=unet_added_conditions + added_cond_kwargs=unet_added_conditions, ).sample weighting = None @@ -1962,15 +1959,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if noise_scheduler.config.prediction_type == "epsilon": model_pred = model_pred * (-sigmas) + noisy_model_input elif noise_scheduler.config.prediction_type == "v_prediction": - model_pred = model_pred * (-sigmas / (sigmas ** 2 + 1) ** 0.5) + ( - 
noisy_model_input / (sigmas ** 2 + 1) + model_pred = model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + ( + noisy_model_input / (sigmas**2 + 1) ) # We are not doing weighting here because it tends result in numerical problems. # See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051 # There might be other alternatives for weighting as well: # https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686 if "EDM" not in scheduler_type: - weighting = (sigmas ** -2.0).float() + weighting = (sigmas**-2.0).float() # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": From 60838f8d0571daedf363f2037ff222c86fdf74a8 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 7 Mar 2024 14:58:26 +0200 Subject: [PATCH 03/10] finish adding edm training feature --- .../train_dreambooth_lora_sdxl_advanced.py | 44 +++++++++++++++++-- 1 file changed, 41 insertions(+), 3 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 469dd72f1129..01331c3ced7e 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -385,6 +385,12 @@ def parse_args(input_args=None): " `args.validation_prompt` multiple times: `args.num_validation_images`." ), ) + parser.add_argument( + "--do_edm_style_training", + default=False, + action="store_true", + help="Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364.", + ) parser.add_argument( "--with_prior_preservation", default=False, @@ -1132,6 +1138,8 @@ def main(args): "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." " Please use `huggingface-cli login` to authenticate with the Hub." 
) + if args.do_edm_style_training and args.snr_gamma is not None: + raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.") logging_dir = Path(args.output_dir, args.logging_dir) @@ -1249,7 +1257,19 @@ def main(args): ) # Load scheduler and models - noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + scheduler_type = determine_scheduler_type(args.pretrained_model_name_or_path, args.revision) + if "EDM" in scheduler_type: + args.do_edm_style_training = True + noise_scheduler = EDMEulerScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + logger.info("Performing EDM-style training!") + elif args.do_edm_style_training: + noise_scheduler = EulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) + logger.info("Performing EDM-style training!") + else: + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder_one = text_encoder_cls_one.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) @@ -2001,10 +2021,28 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): target, target_prior = torch.chunk(target, 2, dim=0) # Compute prior loss - prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + if weighting is not None: + prior_loss = torch.mean( + (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + target_prior.shape[0], -1 + ), + 1, + ) + prior_loss = prior_loss.mean() + else: + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") if args.snr_gamma is None: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + if weighting is not None: + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape( + 
target.shape[0], -1 + ), + 1, + ) + loss = loss.mean() + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") else: # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. # Since we predict the noise instead of x_0, the original formulation is slightly changed. From 7ec00642b0c099b1ffa6e63b3bc43a5a6dda5afb Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 7 Mar 2024 15:10:09 +0200 Subject: [PATCH 04/10] import fix --- .../train_dreambooth_lora_sdxl_advanced.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 01331c3ced7e..af18ac5b83cd 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -56,6 +56,8 @@ AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, + EDMEulerScheduler, + EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel, ) From 1b0d2e269bdad3d718570d5422f03569eb0d22cd Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 7 Mar 2024 22:43:51 +0200 Subject: [PATCH 05/10] fix latents mean --- .../train_dreambooth_lora_sdxl_advanced.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index af18ac5b83cd..0116d3f1c6dd 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -1289,7 +1289,12 @@ def main(args): revision=args.revision, variant=args.variant, ) - vae_scaling_factor = vae.config.scaling_factor + latents_mean = latents_std = None + if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not 
None: + latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None: + latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1) + unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) From 98fb9bc7527e3dae3ee021b0c8482f5d8e2508ed Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 11 Mar 2024 17:08:03 +0200 Subject: [PATCH 06/10] minor adjustments --- .../train_dreambooth_lora_sdxl_advanced.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 0116d3f1c6dd..37c42471145b 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -389,7 +389,6 @@ def parse_args(input_args=None): ) parser.add_argument( "--do_edm_style_training", - default=False, action="store_true", help="Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364.", ) @@ -1833,6 +1832,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + #TODO: revisit other sampling algorithms sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype) schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device) timesteps = timesteps.to(accelerator.device) From dc3cc978679bd03d418b61d91b9af8ac2fe5e600 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 11 Mar 2024 17:16:21 +0200 Subject: [PATCH 07/10] add edm to readme --- .../advanced_diffusion_training/README.md | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/examples/advanced_diffusion_training/README.md 
b/examples/advanced_diffusion_training/README.md index d1c2ff71e639..b77e625b41d1 100644 --- a/examples/advanced_diffusion_training/README.md +++ b/examples/advanced_diffusion_training/README.md @@ -259,6 +259,50 @@ pip install git+https://github.com/huggingface/peft.git **Inference** The inference is the same as if you train a regular LoRA 🤗 +## Conducting EDM-style training + +It's now possible to perform EDM-style training as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364). + +Simply set: + +```diff ++ --do_edm_style_training \ +``` + +Other SDXL-like models that use the EDM formulation, such as [playgroundai/playground-v2.5-1024px-aesthetic](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic), can also be DreamBooth'd with the script. Below is an example command: + +```bash +accelerate launch train_dreambooth_lora_sdxl_advanced.py \ + --pretrained_model_name_or_path="playgroundai/playground-v2.5-1024px-aesthetic" \ + --dataset_name="linoyts/3d_icon" \ + --instance_prompt="3d icon in the style of TOK" \ + --validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \ + --output_dir="3d-icon-SDXL-LoRA" \ + --do_edm_style_training \ + --caption_column="prompt" \ + --mixed_precision="bf16" \ + --resolution=1024 \ + --train_batch_size=3 \ + --repeats=1 \ + --report_to="wandb"\ + --gradient_accumulation_steps=1 \ + --gradient_checkpointing \ + --learning_rate=1.0 \ + --text_encoder_lr=1.0 \ + --optimizer="prodigy"\ + --train_text_encoder_ti\ + --train_text_encoder_ti_frac=0.5\ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --rank=8 \ + --max_train_steps=1000 \ + --checkpointing_steps=2000 \ + --seed="0" \ + --push_to_hub +``` + +> [!CAUTION] +> Min-SNR gamma is not yet supported with EDM-style training. When training with the PlaygroundAI model, it's recommended not to pass any "variant". 
### Tips and Tricks Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices) From 246fa369f7619d9cb911f029850c74fa68d6b3cf Mon Sep 17 00:00:00 2001 From: Linoy Date: Mon, 11 Mar 2024 15:17:53 +0000 Subject: [PATCH 08/10] style --- .../train_dreambooth_lora_sdxl_advanced.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 37c42471145b..1febddb6398d 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -1832,7 +1832,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): - #TODO: revisit other sampling algorithms + # TODO: revisit other sampling algorithms sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype) schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device) timesteps = timesteps.to(accelerator.device) From 3cf936a5329646ff75f0cc414ec21246d683636d Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 14 Mar 2024 11:02:27 +0200 Subject: [PATCH 09/10] fix autocast and scheduler config issues when using edm --- .../train_dreambooth_lora_sdxl_advanced.py | 42 ++++++++++--------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index b7a051f85182..880e3d6a06a2 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -22,6 +22,7 @@ import math import os import random +import contextlib import re import shutil import warnings @@ 
-2172,17 +2173,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it scheduler_args = {} - if "variance_type" in pipeline.scheduler.config: - variance_type = pipeline.scheduler.config.variance_type + if not args.do_edm_style_training: + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type - if variance_type in ["learned", "learned_range"]: - variance_type = "fixed_small" + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" - scheduler_args["variance_type"] = variance_type + scheduler_args["variance_type"] = variance_type - pipeline.scheduler = DPMSolverMultistepScheduler.from_config( - pipeline.scheduler.config, **scheduler_args - ) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config( + pipeline.scheduler.config, **scheduler_args + ) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) @@ -2190,12 +2192,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None pipeline_args = {"prompt": args.validation_prompt} + inference_ctx = ( + contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast() + ) - with torch.cuda.amp.autocast(): - images = [ - pipeline(**pipeline_args, generator=generator).images[0] - for _ in range(args.num_validation_images) - ] + with inference_ctx: + images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in + range(args.num_validation_images)] for tracker in accelerator.trackers: if tracker.name == "tensorboard": @@ -2267,15 +2270,16 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it scheduler_args = {} - if "variance_type" in pipeline.scheduler.config: - variance_type = pipeline.scheduler.config.variance_type + if not args.do_edm_style_training: + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type - if variance_type in ["learned", "learned_range"]: - variance_type = "fixed_small" + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" - scheduler_args["variance_type"] = variance_type + scheduler_args["variance_type"] = variance_type - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) # load attention processors pipeline.load_lora_weights(args.output_dir) From b365510e3c1dc598c79e0b65e457f82b0e7ae66b Mon Sep 17 00:00:00 2001 From: Linoy Date: Thu, 14 Mar 2024 09:09:47 +0000 Subject: [PATCH 10/10] style --- .../train_dreambooth_lora_sdxl_advanced.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 880e3d6a06a2..c99824072331 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and import argparse +import contextlib import gc import hashlib import itertools @@ -22,7 +23,6 @@ import math import os import random -import contextlib import re import shutil import warnings @@ -2193,12 +2193,16 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed 
else None pipeline_args = {"prompt": args.validation_prompt} inference_ctx = ( - contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast() + contextlib.nullcontext() + if "playground" in args.pretrained_model_name_or_path + else torch.cuda.amp.autocast() ) with inference_ctx: - images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in - range(args.num_validation_images)] + images = [ + pipeline(**pipeline_args, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] for tracker in accelerator.trackers: if tracker.name == "tensorboard": @@ -2279,7 +2283,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): scheduler_args["variance_type"] = variance_type - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config( + pipeline.scheduler.config, **scheduler_args + ) # load attention processors pipeline.load_lora_weights(args.output_dir)