[T2I LoRA training] fix: unscale fp16 gradient problem #6119
Conversation
```python
pipeline = pipeline.to(accelerator.device)
# Final inference
# Load previous pipeline
if args.validation_prompt is not None:
```
If no `validation_prompt` was passed, we must not run this step.
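A sketch of the guarded final-inference block; the surrounding names (`args`, `weight_dtype`, `accelerator`) follow the training script and are assumed here, and the exact loading call is illustrative:

```python
# Final inference: skip entirely if no validation prompt was provided.
if args.validation_prompt is not None:
    # Load the previous pipeline only when we actually need to validate.
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path, torch_dtype=weight_dtype
    )
    pipeline = pipeline.to(accelerator.device)
    # ... run validation inference with args.validation_prompt ...
```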
```diff
- # load attention processors
- pipeline.unet.load_attn_procs(args.output_dir)
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
```
Make sure to use load_lora_weights() instead of load_attn_procs().
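For reference, a minimal sketch of that loading path; the base model ID and output directory below are illustrative, not taken from this PR:

```python
import torch
from diffusers import DiffusionPipeline

# Load the base pipeline, then attach the trained LoRA weights.
pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# `load_lora_weights()` loads LoRA layers for the whole pipeline,
# whereas `unet.load_attn_procs()` only touches the UNet.
pipeline.load_lora_weights("path/to/output_dir")
pipeline = pipeline.to("cuda")
image = pipeline("a photo of a dog").images[0]
```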
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@sayakpaul I am getting this error for a regular LoRA fine-tune. This is in a free-GPU-enabled Google Colab.
I am not sure what script you're using here.
younesbelkada left a comment:
Makes sense, thanks! In the future we could also expose a method in PEFT to upcast trainable params to fp32, similarly to `prepare_model_for_kbit_training`. cc @BenjaminBossan @pacman100
Yes, for sure, this isn't the first time this has come up. Do we know exactly when this condition appears? Is it only when the user explicitly loads a model in float16? If yes, we may want to add a corresponding check to this PR.
@sayakpaul can confirm, but I think that's the case, right?
Indeed that's the case. Only reduced precisions. |
@patil-suraj @williamberman can you please also take a look here? |
I am using the …
Does this also apply to bf16? If not, I think the dtype conversion should be conditional, i.e. only performed when training in fp16.
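A sketch of what that conditional cast could look like, mirroring the change that ends up in the thread below; `args.mixed_precision` and `unet` follow the training script's naming:

```python
import torch

# bf16 training does not go through the fp16 GradScaler, so only the
# fp16 mixed-precision path needs the upcast of trainable params.
if args.mixed_precision == "fp16":
    for param in unet.parameters():
        # Only upcast the trainable (LoRA) parameters to fp32.
        if param.requires_grad:
            param.data = param.to(torch.float32)
```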
@BenjaminBossan done in 8ac462b.
Hmm, but now we're just silently disabling fp16 training. Didn't this work before (e.g. with the whole UNet kept in fp16 while the LoRA is trained)? Why doesn't it work anymore?
The problem here is the following IMO:
the changes proposed only upcast the LoRA parameters to fp32 (with the check at `diffusers/src/diffusers/models/lora.py`, line 204 in a0c5482). For reference, this is how the training script installs the LoRA layers:

```python
for attn_processor_name, attn_processor in unet.attn_processors.items():
    # Parse the attention module.
    attn_module = unet
    for n in attn_processor_name.split(".")[:-1]:
        attn_module = getattr(attn_module, n)

    # Set the `lora_layer` attribute of the attention-related matrices.
    attn_module.to_q.set_lora_layer(
        LoRALinearLayer(
            in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
        )
    )
    attn_module.to_k.set_lora_layer(
        LoRALinearLayer(
            in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
        )
    )
    attn_module.to_v.set_lora_layer(
        LoRALinearLayer(
            in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
        )
    )
```
One cleaner check could be to test whether the module is an instance of the LoRA layer class.
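A sketch of such an instance check; `LoRALinearLayer` is an assumption here, chosen because it is the class the snippet above installs via `set_lora_layer()`, and `unet` follows the training script's naming:

```python
import torch
from diffusers.models.lora import LoRALinearLayer

# Walk the UNet and upcast only the LoRA modules themselves,
# instead of keying off `param.requires_grad`.
for module in unet.modules():
    if isinstance(module, LoRALinearLayer):
        module.to(torch.float32)
```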
I see, that makes sense! Thanks for the explanation.
Co-authored-by: Patrick von Platen <[email protected]>
I pulled in the changes from this PR and added them to #6225:

```diff
  text_encoder_one.add_adapter(text_lora_config)
  text_encoder_two.add_adapter(text_lora_config)

+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+     models = [unet]
+     if args.train_text_encoder:
+         models.extend([text_encoder_one, text_encoder_two])
+     for model in models:
+         for param in model.parameters():
+             # only upcast trainable parameters (LoRA) into fp32
+             if param.requires_grad:
+                 param.data = param.to(torch.float32)
+
  # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
  def save_model_hook(models, weights, output_dir):
      if accelerator.is_main_process:
```

I can confirm that things are working well: https://wandb.ai/sayakpaul/dreambooth-lora-sd-xl/runs/ow1vrez8. See the "test" media pictures. Command I ran:

```bash
CUDA_VISIBLE_DEVICES=1 accelerate launch train_with_fixes.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
  --instance_data_dir="dog" \
  --output_dir="corgy_dog_LoRA" \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of TOK dog" \
  --resolution=1024 \
  --train_batch_size=4 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --snr_gamma=5.0 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --mixed_precision="fp16" \
  --use_8bit_adam \
  --max_train_steps=500 \
  --checkpointing_steps=100 \
  --push_to_hub \
  --validation_prompt="a photo of TOK dog in a bucket at the beach" \
  --report_to="wandb" \
  --seed="0"
```

Trained model: https://huggingface.co/sayakpaul/corgy_dog_LoRA. I am gonna try to run it on the Colab free tier too and report back here.
To those wondering if this would run on a free-tier Colab notebook, https://colab.research.google.com/gist/sayakpaul/9615b89369f3ef23cc29d0dac58253dd/scratchpad.ipynb should clear all doubts once and for all 💪
This seems to be working, but when I added the …
That is an unrelated problem and you should instead file it in the …
Sure, will do. But is there an alternative way to train an SDXL model with captions for each image?
You will have to debug your way through this, since it's not exactly the same code that you're using.
* fix: unscale fp16 gradient problem
* fix for dreambooth lora sdxl
* make the type-casting conditional.
* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

What does this PR do?
Fixes: #6086
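For context, a minimal sketch of the failure mode behind #6086, assuming standard PyTorch AMP behavior: `GradScaler.unscale_()` refuses to unscale gradients stored in fp16, which is exactly what happens when the trainable LoRA parameters are themselves fp16.

```python
import torch

# Requires a CUDA device; GradScaler is a CUDA AMP utility.
scaler = torch.cuda.amp.GradScaler()

# A trainable parameter kept in fp16, as with LoRA params before this fix
# when the whole model is loaded in float16.
param = torch.nn.Parameter(
    torch.randn(4, 4, device="cuda", dtype=torch.float16)
)
optimizer = torch.optim.AdamW([param])

loss = (param.float() ** 2).sum()
scaler.scale(loss).backward()

# Raises ValueError("Attempting to unscale FP16 gradients.") because the
# gradient has the parameter's dtype (fp16). Upcasting the trainable
# params to fp32, as this PR does, avoids the error.
scaler.unscale_(optimizer)
```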