diff --git a/examples/advanced_diffusion_training/README.md b/examples/advanced_diffusion_training/README.md
new file mode 100644
index 000000000000..0a49284543d2
--- /dev/null
+++ b/examples/advanced_diffusion_training/README.md
@@ -0,0 +1,244 @@
+# Advanced diffusion training examples
+
+## Train Dreambooth LoRA with Stable Diffusion XL
+> [!TIP]
+> 💡 This example follows the techniques and recommended practices covered in the blog post: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script). Make sure to check it out before starting 🤗
+
+[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3-5) images of a subject.
+
+LoRA - Low-Rank Adaptation of Large Language Models - was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+In a nutshell, LoRA adapts pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
+- The previous pretrained weights are kept frozen, so the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
+- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
+- LoRA attention layers let you control the extent to which the model is adapted towards new training images via a `scale` parameter.
+
+[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in
+the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
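+
+For intuition, here is a minimal, illustrative PyTorch sketch of that idea (the training scripts themselves rely on the [PEFT](https://github.com/huggingface/peft) library rather than this hypothetical module): the original weight stays frozen, and only the small rank-decomposition matrices are trained, with their contribution controlled by `scale`.
+
+```python
+import torch
+import torch.nn as nn
+
+class LoRALinear(nn.Module):
+    """Wraps a frozen nn.Linear with a trainable low-rank update (illustrative only)."""
+
+    def __init__(self, base_layer: nn.Linear, rank: int = 8, scale: float = 1.0):
+        super().__init__()
+        self.base_layer = base_layer
+        self.base_layer.requires_grad_(False)   # pretrained weights stay frozen
+        self.lora_down = nn.Linear(base_layer.in_features, rank, bias=False)
+        self.lora_up = nn.Linear(rank, base_layer.out_features, bias=False)
+        nn.init.normal_(self.lora_down.weight, std=1.0 / rank)
+        nn.init.zeros_(self.lora_up.weight)      # the update starts as a no-op
+        self.scale = scale
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # frozen base output + scaled low-rank correction
+        return self.base_layer(x) + self.scale * self.lora_up(self.lora_down(x))
+```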
+
+The `train_dreambooth_lora_sdxl_advanced.py` script shows how to implement DreamBooth LoRA training, combining the training process shown in `train_dreambooth_lora_sdxl.py` with
+advanced features and techniques, inspired by and built upon contributions from [Nataniel Ruiz](https://twitter.com/natanielruizg): [Dreambooth](https://dreambooth.github.io), [Rinon Gal](https://twitter.com/RinonGal): [Textual Inversion](https://textual-inversion.github.io), [Ron Mokady](https://twitter.com/MokadyRon): [Pivotal Tuning](https://arxiv.org/abs/2106.05744), [Simo Ryu](https://twitter.com/cloneofsimo): [cog-sdxl](https://github.com/replicate/cog-sdxl),
+[Kohya](https://twitter.com/kohya_tech/): [sd-scripts](https://github.com/kohya-ss/sd-scripts), [The Last Ben](https://twitter.com/__TheBen): [fast-stable-diffusion](https://github.com/TheLastBen/fast-stable-diffusion) ❤️
+
+> [!NOTE]
+> 💡 If this is your first time training a Dreambooth LoRA, congrats! 🥳
+> You might want to familiarize yourself more with the techniques: [Dreambooth blog](https://huggingface.co/blog/dreambooth), [Using LoRA for Efficient Stable Diffusion Fine-Tuning blog](https://huggingface.co/blog/lora)
+
+📚 Read more about the advanced features and best practices in this community-derived blog post: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script)
+
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the `examples/advanced_diffusion_training` folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment:
+
+```bash
+accelerate config default
+```
+
+Or, if your environment doesn't support an interactive shell (e.g. a notebook):
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, specifying torch compile mode as True can give dramatic speedups.
+Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.
+
+### Pivotal Tuning
+**Training with text encoder(s)**
+
+Alongside the UNet, LoRA fine-tuning of the text encoders is also supported. In addition to the regular text encoder optimization
+available in `train_dreambooth_lora_sdxl_advanced.py` (via `--train_text_encoder`), the advanced script also supports **pivotal tuning**.
+[Pivotal tuning](https://huggingface.co/blog/sdxl_lora_advanced_script#pivotal-tuning) combines Textual Inversion with regular diffusion fine-tuning -
+we insert new tokens into the text encoders of the model, instead of reusing existing ones.
+We then optimize the newly-inserted token embeddings to represent the new concept.
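+
+As a rough sketch of what happens under the hood (conceptually similar to what the script's `TokenEmbeddingsHandler` does; the token names below are purely illustrative), new tokens are added to the tokenizer, the text encoder's embedding matrix is resized, and only the embedding rows of the new tokens are optimized:
+
+```python
+from transformers import CLIPTextModel, CLIPTokenizer
+
+model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
+text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
+
+# 1. insert brand-new tokens instead of reusing existing vocabulary entries
+new_tokens = ["<s0>", "<s1>"]  # illustrative token names
+num_added = tokenizer.add_tokens(new_tokens)
+text_encoder.resize_token_embeddings(len(tokenizer))
+
+# 2. only the embedding rows of the new tokens are trained; the rest of the
+#    text encoder stays frozen (or is adapted separately with LoRA)
+new_token_ids = tokenizer.convert_tokens_to_ids(new_tokens)
+trainable_rows = text_encoder.get_input_embeddings().weight[new_token_ids]
+print(num_added, trainable_rows.shape)  # 2 torch.Size([2, 768])
+```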
+
+To do so, just specify `--train_text_encoder_ti` while launching training (for regular text encoder optimization, use `--train_text_encoder`).
+Please keep the following points in mind:
+
+* SDXL has two text encoders, so we fine-tune both using LoRA.
+* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.
+
+
+### 3D icon example
+
+Now let's get our dataset. For this example we will use some cool images of 3D rendered icons: https://huggingface.co/datasets/linoyts/3d_icon.
+
+Let's first download it locally:
+
+```python
+from huggingface_hub import snapshot_download
+
+local_dir = "./3d_icon"
+snapshot_download(
+    "LinoyTsaban/3d_icon",
+    local_dir=local_dir, repo_type="dataset",
+    ignore_patterns=".gitattributes",
+)
+```
+
+Let's review some of the advanced features we're going to be using for this example:
+- **custom captions**:
+To use custom captioning, first make sure that you have the `datasets` library installed; if not, you can install it with
+```bash
+pip install datasets
+```
+
+Now we'll simply specify the name of the dataset and caption column (in this case it's "prompt"):
+
+```
+--dataset_name=./3d_icon
+--caption_column=prompt
+```
+
+You can also load a dataset straight from the Hugging Face Hub by specifying its name in `dataset_name`.
+Look [here](https://huggingface.co/blog/sdxl_lora_advanced_script#custom-captioning) for more info on creating/loading your own caption dataset.
+
+- **optimizer**: for this example, we'll use [prodigy](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers) - an adaptive optimizer
+- **pivotal tuning**
+- **min SNR gamma**
+
+**Now, we can launch training:**
+
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
+export DATASET_NAME="./3d_icon"
+export OUTPUT_DIR="3d-icon-SDXL-LoRA"
+export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"
+
+accelerate launch train_dreambooth_lora_sdxl_advanced.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --pretrained_vae_model_name_or_path=$VAE_PATH \
+  --dataset_name=$DATASET_NAME \
+  --instance_prompt="3d icon in the style of TOK" \
+  --validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
+  --output_dir=$OUTPUT_DIR \
+  --caption_column="prompt" \
+  --mixed_precision="bf16" \
+  --resolution=1024 \
+  --train_batch_size=3 \
+  --repeats=1 \
+  --report_to="wandb"\
+  --gradient_accumulation_steps=1 \
+  --gradient_checkpointing \
+  --learning_rate=1.0 \
+  --text_encoder_lr=1.0 \
+  --optimizer="prodigy"\
+  --train_text_encoder_ti\
+  --train_text_encoder_ti_frac=0.5\
+  --snr_gamma=5.0 \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --rank=8 \
+  --max_train_steps=1000 \
+  --checkpointing_steps=2000 \
+  --seed="0" \
+  --push_to_hub
+```
+
+To better track our training experiments, we're using the following flags in the command above:
+
+* `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
+* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
+
+Our experiments were conducted on a single 40GB A100 GPU.
+
+
+### Inference
+
+Once training is done, we can perform inference like so:
+1. Start by loading the UNet LoRA weights:
+```python
+import torch
+from huggingface_hub import hf_hub_download, upload_file
+from diffusers import DiffusionPipeline
+from diffusers.models import AutoencoderKL
+from safetensors.torch import load_file
+
+username = "linoyts"
+repo_id = f"{username}/3d-icon-SDXL-LoRA"
+
+pipe = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float16,
+    variant="fp16",
+).to("cuda")
+
+
+pipe.load_lora_weights(repo_id, weight_name="pytorch_lora_weights.safetensors")
+```
+2. Now, load the pivotal tuning embeddings:
+
+```python
+text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
+tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
+
+embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-SDXL-LoRA_emb.safetensors", repo_type="model")
+
+state_dict = load_file(embedding_path)
+# load embeddings of text_encoder 1 (CLIP ViT-L/14)
+pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
+# load embeddings of text_encoder 2 (CLIP ViT-G/14)
+pipe.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
+```
+
+3. Let's generate images:
+
+```python
+instance_token = "<s0><s1>"
+prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"
+
+image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0]
+image.save("llama.png")
+```
+
+### ComfyUI / AUTOMATIC1111 Inference
+The new script fully supports textual inversion loading with ComfyUI and AUTOMATIC1111 formats!
+
+**AUTOMATIC1111 / SD.Next** \
+In AUTOMATIC1111/SD.Next we will load a LoRA and a textual embedding at the same time.
+- *LoRA*: Besides the diffusers format, the script will also train a WebUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory.
+- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `embeddings` directory.
+
+You can then run inference by prompting `a y2k_emb webpage about the movie Mean Girls <lora:y2k:0.9>`. You can use the `y2k_emb` token normally, including increasing its weight by doing `(y2k_emb:1.2)`.
+
+**ComfyUI** \
+In ComfyUI we will load a LoRA and a textual embedding at the same time.
+- *LoRA*: Besides the diffusers format, the script will also train a ComfyUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. Then you will load the LoRALoader node and hook that up with your model and CLIP. [Official guide for loading LoRAs](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
+- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `models/embeddings` directory and use it in your prompts like `embedding:y2k_emb`. [Official guide for loading embeddings](https://comfyanonymous.github.io/ComfyUI_examples/textual_inversion_embeddings/).
+
+### Specifying a better VAE
+
+SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, namely `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
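+
+The training command above already passes this VAE through `--pretrained_vae_model_name_or_path`. As a minimal sketch (assuming the same `madebyollin/sdxl-vae-fp16-fix` checkpoint mentioned above), you can also swap it in at inference time:
+
+```python
+import torch
+from diffusers import AutoencoderKL, DiffusionPipeline
+
+# load the more numerically stable VAE and hand it to the pipeline
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+pipe = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    vae=vae,
+    torch_dtype=torch.float16,
+    variant="fp16",
+).to("cuda")
+```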
+
+
+### Tips and Tricks
+Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices).
+
+## Running on Colab Notebook
+Check out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_advanced_example.ipynb)
+to train using the advanced features (including pivotal tuning), and [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb) to train on a free Colab, using some of the advanced features (excluding pivotal tuning).
+
diff --git a/examples/advanced_diffusion_training/requirements.txt b/examples/advanced_diffusion_training/requirements.txt
new file mode 100644
index 000000000000..3f86855e1d1e
--- /dev/null
+++ b/examples/advanced_diffusion_training/requirements.txt
@@ -0,0 +1,7 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
+peft==0.7.0
\ No newline at end of file
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py index 385144b133a6..3f660c5a3f4f 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py @@ -119,10 +119,9 @@ def save_model_card( diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download from safetensors.torch import load_file """ - diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors' repo_type="model") + diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model") state_dict = load_file(embedding_path) pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer) -pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2) """ webui_example_pivotal = f"""- *Embeddings*: download **[`{embeddings_filename}.safetensors` here 💾](/{repo_id}/blob/main/{embeddings_filename}.safetensors)**. - Place it on it on your `embeddings` folder @@ -389,7 +388,7 @@ def parse_args(input_args=None): parser.add_argument( "--resolution", type=int, - default=1024, + default=512, help=( "The resolution for input images, all the images in the train/validation dataset will be resized to this" " resolution" ), ) @@ -645,6 +644,7 @@ def parse_args(input_args=None): parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") parser.add_argument( "--rank", type=int, @@ -745,10 +745,11 @@ def initialize_new_tokens(self, inserting_toks: List[str]): idx += 1 + # copied from train_dreambooth_lora_sdxl_advanced.py def save_embeddings(self, file_path: str): assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
tensors = {} - # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14 + # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14 - TODO - change for sd idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"} for idx, text_encoder in enumerate(self.text_encoders): assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len( @@ -1634,6 +1635,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn( + (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device + ) bsz = model_input.shape[0] # Sample a random timestep for each image timesteps = torch.randint( @@ -1788,6 +1794,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): pipeline = StableDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, + tokenizer=tokenizer_one, text_encoder=accelerator.unwrap_model(text_encoder_one), unet=accelerator.unwrap_model(unet), revision=args.revision, @@ -1860,6 +1867,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): unet_lora_layers=unet_lora_layers, text_encoder_lora_layers=text_encoder_lora_layers, ) + + if args.train_text_encoder_ti: + embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors" + embedding_handler.save_embeddings(embeddings_path) + images = [] if args.validation_prompt and args.num_validation_images > 0: # Final inference @@ -1895,6 +1907,18 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # load attention processors pipeline.load_lora_weights(args.output_dir) + # load new tokens + if args.train_text_encoder_ti: + state_dict = load_file(embeddings_path) + all_new_tokens = [] + for key, value in token_abstraction_dict.items(): + all_new_tokens.extend(value) + pipeline.load_textual_inversion( + state_dict["clip_l"], + token=all_new_tokens, + text_encoder=pipeline.text_encoder, + tokenizer=pipeline.tokenizer, + ) # run inference pipeline = pipeline.to(accelerator.device) generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None @@ -1917,11 +1941,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): } ) - if args.train_text_encoder_ti: - embedding_handler.save_embeddings( - f"{args.output_dir}/{args.output_dir}_emb.safetensors", - ) - # Conver to WebUI format lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors") peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index e35630e3e8af..6ae3d315f8ff 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -20,6 +20,7 @@ import logging import math import os +import random import re import shutil import warnings @@ -45,6 +46,7 @@ from safetensors.torch import load_file, save_file from torch.utils.data import Dataset from torchvision import transforms +from torchvision.transforms.functional import crop from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig @@ -121,7 +123,7 @@ def save_model_card( diffusers_imports_pivotal = """from huggingface_hub import 
hf_hub_download from safetensors.torch import load_file """ - diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors' repo_type="model") + diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model") state_dict = load_file(embedding_path) pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer) pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2) @@ -397,18 +399,6 @@ def parse_args(input_args=None): " resolution" ), ) - parser.add_argument( - "--crops_coords_top_left_h", - type=int, - default=0, - help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), - ) - parser.add_argument( - "--crops_coords_top_left_w", - type=int, - default=0, - help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), - ) parser.add_argument( "--center_crop", default=False, @@ -418,6 +408,11 @@ def parse_args(input_args=None): " cropped. The images will be resized to the resolution first before cropping." ), ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) parser.add_argument( "--train_text_encoder", action="store_true", @@ -659,6 +654,7 @@ def parse_args(input_args=None): parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") parser.add_argument( "--rank", type=int, @@ -901,6 +897,41 @@ def __init__( self.instance_images = [] for img in instance_images: self.instance_images.extend(itertools.repeat(img, repeats)) + + # image processing to prepare for using SD-XL micro-conditioning + self.original_sizes = [] + self.crop_top_lefts = [] + self.pixel_values = [] + train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + train_transforms = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + for image in self.instance_images: + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + self.original_sizes.append((image.height, image.width)) + image = train_resize(image) + if args.random_flip and random.random() < 0.5: + # flip + image = train_flip(image) + if args.center_crop: + y1 = max(0, int(round((image.height - args.resolution) / 2.0))) + x1 = max(0, int(round((image.width - args.resolution) / 2.0))) + image = train_crop(image) + else: + y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + image = crop(image, y1, x1, h, w) + crop_top_left = (y1, x1) + self.crop_top_lefts.append(crop_top_left) + image = train_transforms(image) + self.pixel_values.append(image) + self.num_instance_images = len(self.instance_images) self._length = self.num_instance_images @@ -930,12 +961,12 @@ def __len__(self): def __getitem__(self, index): example = {} - instance_image = self.instance_images[index % self.num_instance_images] - instance_image = 
exif_transpose(instance_image) - - if not instance_image.mode == "RGB": - instance_image = instance_image.convert("RGB") - example["instance_images"] = self.image_transforms(instance_image) + instance_image = self.pixel_values[index % self.num_instance_images] + original_size = self.original_sizes[index % self.num_instance_images] + crop_top_left = self.crop_top_lefts[index % self.num_instance_images] + example["instance_images"] = instance_image + example["original_size"] = original_size + example["crop_top_left"] = crop_top_left if self.custom_instance_prompts: caption = self.custom_instance_prompts[index % self.num_instance_images] @@ -966,6 +997,8 @@ def __getitem__(self, index): def collate_fn(examples, with_prior_preservation=False): pixel_values = [example["instance_images"] for example in examples] prompts = [example["instance_prompt"] for example in examples] + original_sizes = [example["original_size"] for example in examples] + crop_top_lefts = [example["crop_top_left"] for example in examples] # Concat class and instance examples for prior preservation. # We do this to avoid doing two forward passes. @@ -976,7 +1009,12 @@ def collate_fn(examples, with_prior_preservation=False): pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - batch = {"pixel_values": pixel_values, "prompts": prompts} + batch = { + "pixel_values": pixel_values, + "prompts": prompts, + "original_sizes": original_sizes, + "crop_top_lefts": crop_top_lefts, + } return batch @@ -1198,7 +1236,9 @@ def main(args): args.instance_prompt = args.instance_prompt.replace(token_abs, "".join(token_replacement)) if args.with_prior_preservation: args.class_prompt = args.class_prompt.replace(token_abs, "".join(token_replacement)) - + if args.validation_prompt: + args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement)) + print("validation prompt:", args.validation_prompt) # initialize the new tokens for textual inversion embedding_handler = TokenEmbeddingsHandler( [text_encoder_one, text_encoder_two], [tokenizer_one, tokenizer_two] @@ -1539,11 +1579,11 @@ def load_model_hook(models, input_dir): # pooled text embeddings # time ids - def compute_time_ids(): + def compute_time_ids(crops_coords_top_left, original_size=None): # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids - original_size = (args.resolution, args.resolution) + if original_size is None: + original_size = (args.resolution, args.resolution) target_size = (args.resolution, args.resolution) - crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w) add_time_ids = list(original_size + crops_coords_top_left + target_size) add_time_ids = torch.tensor([add_time_ids]) add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) @@ -1560,9 +1600,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) return prompt_embeds, pooled_prompt_embeds - # Handle instance prompt. - instance_time_ids = compute_time_ids() - # If no type of tuning is done on the text_encoder and custom instance prompts are NOT # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid # the redundant encoding. @@ -1573,7 +1610,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Handle class prompt for prior-preservation. 
if args.with_prior_preservation: - class_time_ids = compute_time_ids() if freeze_text_encoder: class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings( args.class_prompt, text_encoders, tokenizers @@ -1588,9 +1624,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), # pack the statically computed variables appropriately here. This is so that we don't # have to pass them to the dataloader. - add_time_ids = instance_time_ids - if args.with_prior_preservation: - add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0) # if --train_text_encoder_ti we need add_special_tokens to be True fo textual inversion add_special_tokens = True if args.train_text_encoder_ti else False @@ -1613,12 +1646,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) - if args.train_text_encoder_ti and args.validation_prompt: - # replace instances of --token_abstraction in validation prompt with the new tokens: "" etc. - for token_abs, token_replacement in train_dataset.token_abstraction_dict.items(): - args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement)) - print("validation prompt:", args.validation_prompt) - if args.cache_latents: latents_cache = [] for batch in tqdm(train_dataloader, desc="Caching latents"): @@ -1778,6 +1805,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn( + (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device + ) + bsz = model_input.shape[0] # Sample a random timestep for each image timesteps = torch.randint( @@ -1789,19 +1822,26 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # (this is the forward diffusion process) noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + # time ids + add_time_ids = torch.cat( + [ + compute_time_ids(original_size=s, crops_coords_top_left=c) + for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"]) + ] + ) + # Calculate the elements to repeat depending on the use of prior-preservation and custom captions. 
if not train_dataset.custom_instance_prompts: elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz - elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz else: elems_to_repeat_text_embeds = 1 - elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz # Predict the noise residual if freeze_text_encoder: unet_added_conditions = { - "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1), + "time_ids": add_time_ids, + # "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1), "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1), } prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) @@ -1812,7 +1852,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): added_cond_kwargs=unet_added_conditions, ).sample else: - unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)} + unet_added_conditions = {"time_ids": add_time_ids} prompt_embeds, pooled_prompt_embeds = encode_prompt( text_encoders=[text_encoder_one, text_encoder_two], tokenizers=None, @@ -1954,6 +1994,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): pipeline = StableDiffusionXLPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, + tokenizer=tokenizer_one, + tokenizer_2=tokenizer_two, text_encoder=accelerator.unwrap_model(text_encoder_one), text_encoder_2=accelerator.unwrap_model(text_encoder_two), unet=accelerator.unwrap_model(unet), @@ -2033,6 +2075,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): text_encoder_lora_layers=text_encoder_lora_layers, text_encoder_2_lora_layers=text_encoder_2_lora_layers, ) + + if args.train_text_encoder_ti: + embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors" + embedding_handler.save_embeddings(embeddings_path) + images = [] if args.validation_prompt and args.num_validation_images > 0: # Final inference @@ -2068,6 +2115,25 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # load attention processors pipeline.load_lora_weights(args.output_dir) + # load new tokens + if args.train_text_encoder_ti: + state_dict = load_file(embeddings_path) + all_new_tokens = [] + for key, value in token_abstraction_dict.items(): + all_new_tokens.extend(value) + pipeline.load_textual_inversion( + state_dict["clip_l"], + token=all_new_tokens, + text_encoder=pipeline.text_encoder, + tokenizer=pipeline.tokenizer, + ) + pipeline.load_textual_inversion( + state_dict["clip_g"], + token=all_new_tokens, + text_encoder=pipeline.text_encoder_2, + tokenizer=pipeline.tokenizer_2, + ) + # run inference pipeline = pipeline.to(accelerator.device) generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None @@ -2090,11 +2156,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): } ) - if args.train_text_encoder_ti: - embedding_handler.save_embeddings( - f"{args.output_dir}/{args.output_dir}_emb.safetensors", - ) - # Conver to WebUI format lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors") peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)