
Conversation

@jiqing-feng (Contributor) commented Jan 2, 2024

Hi @sayakpaul @patrickvonplaten. Since Stable-Diffusion-XL is getting increasingly popular, users may want to see how it performs with textual inversion.

I enabled Stable-Diffusion-XL textual inversion and trained for 2000 steps with bfloat16 on an Intel SPR node; the result is as follows:
Text: A cat-toy backpack

(image: cat-backpack)

Would you please help review my changes? Thx!

@sayakpaul (Member) left a comment

Thanks for your contributions. However, it might be better to have this in a separate script and add a proper test for it.

See how we did it for SDXL LoRA DreamBooth here: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py

Tests:

class DreamBoothLoRASDXL(ExamplesTestsAccelerate):

@patrickvonplaten (Contributor) left a comment

Hey @jiqing-feng,

I think such a new textual inversion script would be very helpful. Can we maybe add it as a new _sdxl script, like we've done for LoRA: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py ?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@jiqing-feng (Contributor Author)

Hi @sayakpaul @patrickvonplaten. Thanks for your review. I have added textual inversion for SDXL in a new script, along with a new test script. Would you please review it? Thx!

BTW, from my experiments, fine-tuning only text_encoder_1 gives a better result than fine-tuning both text encoders.

return args


imagenet_templates_small = [
Member:

Where is this coming from?

Contributor Author:

It is copied from textual_inversion

Comment on lines 565 to 574
if self.center_crop:
    crop = min(img.shape[0], img.shape[1])
    (
        h,
        w,
    ) = (
        img.shape[0],
        img.shape[1],
    )
    img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
Member:

Can't we use CenterCrop from torchvision for this?
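For reference, a minimal sketch of what the torchvision-based version could look like (the stand-in image and the HWC/CHW conversion are assumptions, not the script's actual code):

```python
import numpy as np
import torch
from torchvision import transforms

img = np.random.randint(0, 255, (512, 768, 3), dtype=np.uint8)  # stand-in HWC image

# CenterCrop with an int size produces a square crop, matching the manual
# min(h, w) slicing above; it accepts CHW tensors, so convert and restore HWC.
center_crop = transforms.CenterCrop(min(img.shape[0], img.shape[1]))
img = center_crop(torch.from_numpy(img).permute(2, 0, 1)).permute(1, 2, 0).numpy()
```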

Contributor Author:

It is copied from textual_inversion

# Move vae and unet and text_encoder_2 to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder_2 = text_encoder_2.to(accelerator.device, dtype=weight_dtype)
Member:

Why not move the text_encoder too?

Contributor Author:

text_encoder_1 will be moved to the device by the accelerator.prepare function.
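A minimal sketch of the pattern being described (the object names follow common diffusers examples and are assumptions here):

```python
from accelerate import Accelerator

accelerator = Accelerator()

# prepare() wraps each object for (possibly distributed) training and moves
# modules onto accelerator.device, so text_encoder_1 needs no explicit .to().
text_encoder_1, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    text_encoder_1, optimizer, train_dataloader, lr_scheduler
)
```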

# The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
unet.train()
text_encoder_1.gradient_checkpointing_enable()
text_encoder_2.gradient_checkpointing_enable()
Member:

If text_encoder_2 is not trained then why enable gradient checkpointing here?

Contributor Author:

Yes, I will remove it if we don't want to train text_encoder_2.

unet.train()
text_encoder_1.gradient_checkpointing_enable()
text_encoder_2.gradient_checkpointing_enable()
unet.enable_gradient_checkpointing()
Member:

Same for this.

Contributor Author:

I think it should be kept, as the comment at line 698 explains.

Member:

This is weird. None of the other training scripts enable gradient checkpointing on the models that are not being trained.
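For context, a hedged sketch of the pattern at stake (the gradient_checkpointing flag name is an assumption):

```python
if args.gradient_checkpointing:
    # Checkpointing the trained text encoder trades compute for activation memory.
    text_encoder_1.gradient_checkpointing_enable()
    # The UNet is frozen, but gradients still flow through its activations back
    # to the token embeddings, so checkpointing it can also reduce memory; whether
    # that justifies enabling it on an untrained model is the point debated here.
    unet.enable_gradient_checkpointing()
```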

sample_size = unet.config.sample_size * (2 ** (len(vae.config.block_out_channels) - 1))
original_size = (sample_size, sample_size)
add_time_ids = torch.tensor(
    [list(original_size + (0, 0) + original_size)], dtype=weight_dtype, device=accelerator.device
)
Member:

The sample_size calculation seems wrong to me, as it should be the original size of the input images. Also, we're not supplying the crop coordinates here.

Could you refer to https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora_sdxl.py and incorporate the changes here w.r.t how these micro-conditions are implemented?
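For reference, a sketch of the micro-conditioning pattern in the referenced LoRA script (paraphrased, so treat the exact names as approximate): the time ids concatenate the per-image original size, the crop's top-left coordinates, and the target resolution:

```python
def compute_time_ids(original_size, crops_coords_top_left):
    # SDXL micro-conditioning packs (original_size, crop_top_left, target_size)
    # into six scalars per image, fed to the UNet as added time embeddings.
    target_size = (args.resolution, args.resolution)
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    return torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)
```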

    accelerator,
    args,
    save_path,
    safe_serialization=not args.no_safe_serialization,
Member:

Let's default to safetensors and not make it configurable IMO.

@jiqing-feng (Contributor Author) commented Jan 5, 2024

It is also copied from textual_inversion, and it defaults to safetensors in this script. I can make it unconfigurable if you want to.

Member:

safetensors is the default format in diffusers. So, it makes sense to not make it configurable here.

Also, the script you keep referring to is a bit old, so we don't have to follow it note for note :)

    commit_message="End of training",
    ignore_patterns=["step_*", "epoch_*"],
)

Member:

Let's also run validation here? We can use log_validation() and log the images under the "test" key instead.

Contributor Author:

log_validation() is at line 961. I think it would be better to call log_validation() after training and before saving, instead of after pushing to the Hub. WDYT?

Member:

No, what I mean is that inside log_validation we're using the "validation" key to log the media. If we use the same key to cover both the intermediate logging and the logging done after training, it might be inconvenient. To give you a better idea, refer to this:

tracker_key = "test" if is_final_validation else "validation"
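A minimal sketch of how that key would then be used inside log_validation (the tracker loop is modeled on other diffusers examples and is an assumption here):

```python
import numpy as np

tracker_key = "test" if is_final_validation else "validation"
np_images = np.stack([np.asarray(img) for img in images])

for tracker in accelerator.trackers:
    # Final post-training images land under "test", intermediate ones under
    # "validation", so the two runs stay separate in the dashboard.
    if tracker.name == "tensorboard":
        tracker.writer.add_images(tracker_key, np_images, global_step, dataformats="NHWC")
```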

@sayakpaul (Member) left a comment

Great start! Thank you.

Left some initial comments related to the implementation. I'd prefer having the ability to train two text encoders, though.

@jiqing-feng (Contributor Author) commented Jan 8, 2024

Hi @sayakpaul. I think I have addressed all your comments except training two text encoders.

This is the result from training only one text encoder for 500 steps on an A100, and it looks great.
(image)

Unfortunately, I didn't get an acceptable result when training both text encoders. Could we merge this example first? I will investigate what's wrong with two-text-encoder training. WDYT @patrickvonplaten

Thanks!

BTW, I ran the following commands and they didn't report any code style issues in my script.

  ruff check examples tests src utils scripts
  ruff format examples tests src utils scripts --check

@sayakpaul (Member)

You need to run make style && make quality to get the code styling issues fixed.

f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
# create pipeline (note: unet and vae are loaded again in float32)
Member:

Why is this the case?

Comment on lines 748 to 756
if args.validation_epochs is not None:
    warnings.warn(
        f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}."
        " Deprecated validation_epochs in favor of `validation_steps`"
        f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}",
        FutureWarning,
        stacklevel=2,
    )
    args.validation_steps = args.validation_epochs * len(train_dataset)
Member:

Let's simplify this. Let's expose only one argument from the command-line to control the interval of running validation.
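That is, something like this sketch (the flag name and default value are assumptions):

```python
import argparse

parser = argparse.ArgumentParser()
# A single knob controls validation frequency, counted in optimization steps,
# instead of exposing both validation_epochs and validation_steps.
parser.add_argument(
    "--validation_steps",
    type=int,
    default=100,
    help="Run validation every X optimization steps.",
)
args = parser.parse_args()
```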

Comment on lines 694 to 696
text_encoder_2.text_model.encoder.requires_grad_(False)
text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
Member:

We can just call text_encoder_2.requires_grad_(False) here, no? Since we're not training it.

# Move vae and unet and text_encoder_2 to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder_2 = text_encoder_2.to(accelerator.device, dtype=weight_dtype)
Member:

nit: no need to assign back to the text_encoder_2 variable after device placement; nn.Module.to() already modifies the module in place.

@sayakpaul (Member) left a comment

Thanks for the changes; however, the open comments still aren't resolved:

Apart from these, I think we'd need to add a separate README_sdxl.md for this example, like the other SDXL scripts, so that users know what training commands to use and that training text encoder 2 isn't supported for specific reasons.

@jiqing-feng (Contributor Author)

Hi @sayakpaul
Sorry for the mix-up between the training script and the test script. I have fixed it; would you please review it again? Thx!

@jiqing-feng (Contributor Author)

> make style && make quality

Unfortunately, it doesn't work.

@@ -0,0 +1,27 @@
## Textual Inversion fine-tuning example for SDXL

The `textual_inversion.py` do not support training stable-diffusion-XL as it has two text encoders, you can training SDXL by the following command:
Member:

Suggested change:
- The `textual_inversion.py` do not support training stable-diffusion-XL as it has two text encoders, you can training SDXL by the following command:

I don't think we need to mention this. We can just add a note about the SDXL variant in the README.md file.

--output_dir="./textual_inversion_cat_sdxl"
```

We only enabled training the first text encoder because of the precision issue, we will enable training the second text encoder once we fixed the problem.
Member:

Suggested change:
- We only enabled training the first text encoder because of the precision issue, we will enable training the second text encoder once we fixed the problem.
+ For now, only training of the first text encoder is supported.

Comment on lines 710 to 716
optimizer_1 = torch.optim.AdamW(
    text_encoder_1.get_input_embeddings().parameters(),  # only optimize the embeddings
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)
Member:

We usually also support 8-bit Adam too:

Since SDXL is quite a bit heavier than SD, could we add it here as well?
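A minimal sketch of that pattern as it appears across diffusers examples (the --use_8bit_adam flag name is an assumption):

```python
import torch

if args.use_8bit_adam:
    try:
        import bitsandbytes as bnb
    except ImportError:
        raise ImportError("To use 8-bit Adam, please install bitsandbytes: `pip install bitsandbytes`.")
    optimizer_class = bnb.optim.AdamW8bit
else:
    optimizer_class = torch.optim.AdamW

optimizer = optimizer_class(
    text_encoder_1.get_input_embeddings().parameters(),  # only optimize the embeddings
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)
```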

Member:

Also, let's keep the name optimizer for now. No need to use optimizer_1.

    tokenizer_1=tokenizer_1,
    tokenizer_2=tokenizer_2,
    size=args.resolution,
    placeholder_token=(" ".join(tokenizer_1.convert_ids_to_tokens(placeholder_token_ids))),
Member:

Let's assign " ".join(tokenizer_1.convert_ids_to_tokens(placeholder_token_ids)) to a separate variable.
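For example (the dataset class name below is a placeholder, not necessarily the script's):

```python
# Join the placeholder sub-tokens back into a single prompt string once.
placeholder_token = " ".join(tokenizer_1.convert_ids_to_tokens(placeholder_token_ids))

train_dataset = TextualInversionDataset(
    tokenizer_1=tokenizer_1,
    tokenizer_2=tokenizer_2,
    size=args.resolution,
    placeholder_token=placeholder_token,
)
```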

@sayakpaul (Member) left a comment

Thanks a lot for the changes. Just some final set of comments.

@jiqing-feng (Contributor Author)

Hi @sayakpaul. Thanks for your review. I have addressed all the comments; would you please take a look? Thx!


**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**

**___Note: Please follow the README_sdxl.md if you are using the [stable-diffusion-xl](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).___**
Member:

Suggested change:
- **___Note: Please follow the README_sdxl.md if you are using the [stable-diffusion-xl](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).___**
+ **___Note: Please follow the [README_sdxl.md](./README_sdxl.md) if you are using the [stable-diffusion-xl](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).___**

@sayakpaul (Member) left a comment

I think there's this open comment still: https://github.com/huggingface/diffusers/pull/6421/files#r1442509162.

After everything is resolved, I will fix the quality issues.

@jiqing-feng (Contributor Author)

> I think there's this open comment still: https://github.com/huggingface/diffusers/pull/6421/files#r1442509162.
>
> After everything is resolved, I will fix the quality issues.

Hi @sayakpaul. Thanks for your clarification; I have fixed it now, please take a look. Thx!

@sayakpaul (Member) left a comment

Thank you so much for bearing with my requests!

@sayakpaul (Member)

Will merge after the CI is green :)

@jiqing-feng (Contributor Author)

> Thank you so much for bearing with my requests!

Also thanks for your patience :)

@sayakpaul merged commit aa1797e into huggingface:main Jan 9, 2024
@sayakpaul (Member)

Thanks for your great contribution!

AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* enable stable-xl textual inversion

* check if optimizer_2 exists

* check text_encoder_2 before using

* add textual inversion for sdxl in a single file

* fix style

* fix example style

* reset for error changes

* add readme for sdxl

* fix style

* disable autocast as it will cause cast error when weight_dtype=bf16

* fix spelling error

* fix style and readme and 8bit optimizer

* add README_sdxl.md link

* add tracker key on log_validation

* run style

* rm the second center crop

---------

Co-authored-by: Sayak Paul <[email protected]>