enable stable-xl textual inversion #6421
Conversation
sayakpaul
left a comment
Thanks for your contributions. However, it might be better to have it in a separate script and add proper tests for it.
See how we did it for SDXL LoRA DreamBooth here: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py
Tests:
```python
class DreamBoothLoRASDXL(ExamplesTestsAccelerate):
```
patrickvonplaten
left a comment
Hey @jiqing-feng,
I think such a new textual inversion script would be very helpful. Can we maybe add it as a new _sdxl script like we've done for LoRA: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py ?
Hi @sayakpaul @patrickvonplaten. Thanks for your review. I have added textual inversion for SDXL in a new script, along with a new test script. Would you please review it? Thanks! BTW, from my experiments, fine-tuning only text_encoder_1 gives a better result than fine-tuning both text encoders.
```python
    return args
```

```python
imagenet_templates_small = [
```
Where is this coming from?
It is copied from textual_inversion
```python
if self.center_crop:
    crop = min(img.shape[0], img.shape[1])
    (
        h,
        w,
    ) = (
        img.shape[0],
        img.shape[1],
    )
    img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
```
Can't we use CenterCrop from torchvision for this?
It is copied from textual_inversion
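For reference, a minimal sketch of the torchvision replacement the reviewer suggests; the image size here is an arbitrary stand-in for a loaded training image:

```python
from PIL import Image
from torchvision import transforms

img = Image.new("RGB", (640, 480))           # stand-in for a loaded training image
crop_size = min(img.size)                    # crop to the smaller spatial dimension
img = transforms.CenterCrop(crop_size)(img)  # same result as the manual slicing
print(img.size)                              # (480, 480)
```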
```python
# Move vae and unet and text_encoder_2 to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder_2 = text_encoder_2.to(accelerator.device, dtype=weight_dtype)
```
Why not move the text_encoder too?
text_encoder_1 is moved to the device by the accelerator.prepare function.
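For context, a minimal sketch of the pattern the author describes, with a tiny stand-in module in place of the real text encoder: models passed through accelerator.prepare() get device placement (and distributed wrapping) automatically, so only the frozen models need manual .to() calls.

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

model = torch.nn.Linear(4, 4)  # stand-in for text_encoder_1
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# prepare() moves the trained model to the right device; frozen models like
# vae/unet/text_encoder_2 are moved by hand with .to(accelerator.device, ...).
model, optimizer = accelerator.prepare(model, optimizer)
print(next(model.parameters()).device)
```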
```python
# The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
unet.train()
text_encoder_1.gradient_checkpointing_enable()
text_encoder_2.gradient_checkpointing_enable()
```
If text_encoder_2 is not trained then why enable gradient checkpointing here?
Yes, I will remove it if we don't want to train text_encoder_2
```python
unet.train()
text_encoder_1.gradient_checkpointing_enable()
text_encoder_2.gradient_checkpointing_enable()
unet.enable_gradient_checkpointing()
```
Same for this.
I think it should be kept, as the comment on line 698 explains.
This is weird. None of the other training scripts enable gradient checkpointing on the models that are not being trained.
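A hedged sketch of the convention the reviewer is pointing at, with a tiny randomly initialized text encoder standing in for text_encoder_1: enable gradient checkpointing only where gradients actually flow.

```python
from transformers import CLIPTextConfig, CLIPTextModel

# Tiny random config, used only to illustrate the pattern.
config = CLIPTextConfig(
    hidden_size=32, intermediate_size=64, num_hidden_layers=2, num_attention_heads=4
)
text_encoder_1 = CLIPTextModel(config)  # stand-in for the trained encoder
text_encoder_1.train()
text_encoder_1.gradient_checkpointing_enable()  # trained model: recompute activations
# Frozen models (unet, text_encoder_2) run without gradients, so enabling
# gradient checkpointing on them saves nothing and is left out.
```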
```python
sample_size = unet.config.sample_size * (2 ** (len(vae.config.block_out_channels) - 1))
original_size = (sample_size, sample_size)
add_time_ids = torch.tensor(
    [list(original_size + (0, 0) + original_size)], dtype=weight_dtype, device=accelerator.device
```
The sample_size calculation seems wrong to me, as it should be the original size of the input images. Also, we're not supplying the crop coordinates here.
Could you refer to https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora_sdxl.py and incorporate the changes here w.r.t how these micro-conditions are implemented?
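For reference, a minimal sketch of the micro-conditioning in the referenced train_text_to_image_lora_sdxl.py; the sizes here are placeholder values, whereas the real script computes them per sample from the resize/crop transforms:

```python
import torch

original_size = (1024, 1024)   # true size of the raw input image
crop_coords_top_left = (0, 0)  # top-left corner produced by the crop transform
target_size = (1024, 1024)     # resolution the model is trained at

# SDXL's additional time ids concatenate the three micro-conditions.
add_time_ids = torch.tensor([list(original_size + crop_coords_top_left + target_size)])
print(add_time_ids)  # tensor([[1024, 1024,    0,    0, 1024, 1024]])
```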
```python
    accelerator,
    args,
    save_path,
    safe_serialization=not args.no_safe_serialization,
```
Let's default to safetensors and not make it configurable IMO.
It is also copied from textual_inversion, and it defaults to safetensors in this script. I can make it non-configurable if you want.
safetensors is the default format in diffusers. So, makes sense to not make it configurable here.
Also, the script you keep referring to is a bit old, so we don't have to follow it note for note :)
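For illustration, a minimal sketch of a non-configurable safetensors save; the embedding name and shape are hypothetical:

```python
import torch
from safetensors.torch import save_file

learned_embeds = {"<cat-toy>": torch.randn(768)}  # hypothetical learned embedding
save_file(learned_embeds, "learned_embeds.safetensors")  # always safetensors
```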
```python
    commit_message="End of training",
    ignore_patterns=["step_*", "epoch_*"],
)
```
Let's also run validation here? We can use log_validation() here and log the images under the "test" key instead.
log_validation() is on line 961. I think it would be better to run log_validation() after training and before saving, instead of after pushing to the Hub. WDYT?
No, what I mean is that inside log_validation we're using the "validation" key to log the media. If we use the same key for both intermediate logging and the logging done after training, it might be inconvenient. To give you a better idea, refer to this:
```python
tracker_key = "test" if is_final_validation else "validation"
```
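In other words, a minimal sketch of the split (the helper name is hypothetical):

```python
def tracker_key_for(is_final_validation: bool) -> str:
    # Intermediate runs log under "validation"; the post-training run logs
    # under "test", so the two image sets stay separate in the tracker UI.
    return "test" if is_final_validation else "validation"

print(tracker_key_for(False))  # validation
print(tracker_key_for(True))   # test
```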
sayakpaul
left a comment
Great start! Thank you.
Left some initial comments related to the implementation. I'd prefer having the ability to train two text encoders, though.
Hi @sayakpaul. I think I have addressed all your comments except training two text encoders. This is the result from training only one text encoder for 500 steps on an A100, and it looks great. Unfortunately, I didn't get an acceptable result when training two text encoders. Could we merge this example first? I will look into what's wrong with two-text-encoder training. WDYT @patrickvonplaten? Thanks! BTW, I ran the following commands and they didn't change any code styles in my script.

You need to run
| f"Running validation... \n Generating {args.num_validation_images} images with prompt:" | ||
| f" {args.validation_prompt}." | ||
| ) | ||
| # create pipeline (note: unet and vae are loaded again in float32) |
Why is this the case?
```python
if args.validation_epochs is not None:
    warnings.warn(
        f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}."
        " Deprecated validation_epochs in favor of `validation_steps`"
        f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}",
        FutureWarning,
        stacklevel=2,
    )
    args.validation_steps = args.validation_epochs * len(train_dataset)
```
Let's simplify this. Let's expose only one argument from the command-line to control the interval of running validation.
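A minimal sketch of the single-argument interval the reviewer asks for; the numbers are placeholders:

```python
validation_steps = 100  # the one exposed CLI argument (placeholder value)

for global_step in range(1, 501):
    if global_step % validation_steps == 0:
        print(f"running validation at step {global_step}")
```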
```python
text_encoder_2.text_model.encoder.requires_grad_(False)
text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
```
Can't we just call text_encoder_2.requires_grad_(False) here, since we're not training it?
```python
# Move vae and unet and text_encoder_2 to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder_2 = text_encoder_2.to(accelerator.device, dtype=weight_dtype)
```
nit: no need to assign back to the text_encoder_2 variable after device placement.
sayakpaul
left a comment
Thanks for the changes; however, the open comments still aren't resolved:
- https://github.com/huggingface/diffusers/pull/6421/files#r1442508090
- https://github.com/huggingface/diffusers/pull/6421/files#r1442508672
- https://github.com/huggingface/diffusers/pull/6421/files#r1442509162
Apart from these, I think we'd need to add a separate README_sdxl.md for this example like other SDXL scripts, so that users know what training commands to use and that training text encoder 2 isn't supported for specific reasons.
Hi @sayakpaul. Unfortunately, it doesn't work.
```
@@ -0,0 +1,27 @@
## Textual Inversion fine-tuning example for SDXL

The `textual_inversion.py` do not support training stable-diffusion-XL as it has two text encoders, you can training SDXL by the following command:
```
Suggested change (remove the line):

```diff
- The `textual_inversion.py` do not support training stable-diffusion-XL as it has two text encoders, you can training SDXL by the following command:
```

I don't think we need to mention this. We can just add a note about the SDXL variant in the README.md file.
```bash
  --output_dir="./textual_inversion_cat_sdxl"
```

We only enabled training the first text encoder because of the precision issue, we will enable training the second text encoder once we fixed the problem.
Suggested change:

```diff
- We only enabled training the first text encoder because of the precision issue, we will enable training the second text encoder once we fixed the problem.
+ For now, only training of the first text encoder is supported.
```
```python
optimizer_1 = torch.optim.AdamW(
    text_encoder_1.get_input_embeddings().parameters(),  # only optimize the embeddings
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)
```
We usually also support 8-bit Adam:

```python
if args.use_8bit_adam:
```

Since SDXL is quite a bit heavier than SD, could we add it here too?
Also, let's keep the name optimizer for now. No need to use optimizer_1.
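For reference, a sketch of the optional 8-bit Adam path mirroring the pattern in the existing diffusers training scripts; the embedding module here is a stand-in for the real input embeddings, and the hyperparameter values are placeholders:

```python
import torch

use_8bit_adam = False  # would come from args.use_8bit_adam

if use_8bit_adam:
    try:
        import bitsandbytes as bnb
    except ImportError:
        raise ImportError("To use 8-bit Adam, please install the bitsandbytes library.")
    optimizer_cls = bnb.optim.AdamW8bit  # quantized optimizer states save memory
else:
    optimizer_cls = torch.optim.AdamW

embeddings = torch.nn.Embedding(49409, 768)  # stand-in for the input embeddings
optimizer = optimizer_cls(
    embeddings.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-8,
)
```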
```python
    tokenizer_1=tokenizer_1,
    tokenizer_2=tokenizer_2,
    size=args.resolution,
    placeholder_token=(" ".join(tokenizer_1.convert_ids_to_tokens(placeholder_token_ids))),
```
Let's assign `" ".join(tokenizer_1.convert_ids_to_tokens(placeholder_token_ids))` to a separate variable.
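A minimal sketch of the suggested refactor; the checkpoint and placeholder token names are assumptions for illustration:

```python
from transformers import CLIPTokenizer

tokenizer_1 = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
tokenizer_1.add_tokens(["<cat-toy>"])  # hypothetical placeholder token
placeholder_token_ids = tokenizer_1.convert_tokens_to_ids(["<cat-toy>"])

# Compute the joined string once and pass the variable to the dataset.
placeholder_token = " ".join(tokenizer_1.convert_ids_to_tokens(placeholder_token_ids))
print(placeholder_token)  # <cat-toy>
```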
sayakpaul
left a comment
Thanks a lot for the changes. Just some final set of comments.
Hi @sayakpaul. Thanks for your review. I have addressed all the comments; would you please take a look? Thanks!
examples/textual_inversion/README.md (outdated):

```
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**

**___Note: Please follow the README_sdxl.md if you are using the [stable-diffusion-xl](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).___**
```
Suggested change:

```diff
- **___Note: Please follow the README_sdxl.md if you are using the [stable-diffusion-xl](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).___**
+ **___Note: Please follow the [README_sdxl.md](./README_sdxl.md) if you are using the [stable-diffusion-xl](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).___**
```
sayakpaul
left a comment
I think there's this open comment still: https://github.com/huggingface/diffusers/pull/6421/files#r1442509162.
After everything is resolved, I will fix the quality issues.
Hi @sayakpaul. Thanks for the clarification; I have fixed it now, please take a look. Thanks!
sayakpaul
left a comment
Thank you so much for bearing with my requests!
Will merge after the CI is green :)

Also, thanks for your patience :)

Thanks for your great contribution!
* enable stable-xl textual inversion
* check if optimizer_2 exists
* check text_encoder_2 before using
* add textual inversion for sdxl in a single file
* fix style
* fix example style
* reset for error changes
* add readme for sdxl
* fix style
* disable autocast as it will cause cast error when weight_dtype=bf16
* fix spelling error
* fix style and readme and 8bit optimizer
* add README_sdxl.md link
* add tracker key on log_validation
* run style
* rm the second center crop

Co-authored-by: Sayak Paul <[email protected]>

Hi @sayakpaul @patrickvonplaten. Since Stable-Diffusion-XL is getting increasingly popular, users may want to see how it performs with textual inversion.
I enabled Stable-Diffusion-XL textual inversion and trained for 2000 steps with bfloat16 on an Intel SPR node; the result is as follows:
Text: A cat-toy backpack
Would you please help review my changes? Thanks!