[Flux] Dreambooth LoRA training scripts #9086
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thank you for starting this @linoytsaban!
I think this could stay as a minimal educational resource for the community. Could we see some good first results as well? We should probably mention in the README that, for a more advanced and memory-friendly setup, readers should refer to @bghira's guide here: #9057 (comment).
LMK what you think.
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." | ||
) | ||
|
||
vae.to(accelerator.device, dtype=torch.float32) |
The VAE seems to be better. So, let's first try with weight_dtype?
bf16 is good, yes, but fp16 is not: there are some huge activation values in the attention layers. But you can also just disable fp16 for the attention or upcast it. Roughly the same issue as the SD 2.x UNet, but not quite.
maybe fp16 is okay after #9097
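A minimal sketch of the dtype handling discussed above, assuming `args`, `accelerator`, `vae`, and `transformer` exist as in the training script: keep the VAE in fp32 when training in fp16 (because of the large attention activations), and let it follow the mixed-precision dtype otherwise. This is illustrative, not the merged behaviour.

import torch

weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

# The VAE stays in fp32 under fp16 training; otherwise it can follow weight_dtype.
vae_dtype = torch.float32 if weight_dtype == torch.float16 else weight_dtype
vae.to(accelerator.device, dtype=vae_dtype)
transformer.to(accelerator.device, dtype=weight_dtype)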
u = compute_density_for_timestep_sampling(
    weighting_scheme=args.weighting_scheme,
    batch_size=bsz,
    logit_mean=args.logit_mean,
    logit_std=args.logit_std,
    mode_scale=args.mode_scale,
)
cc: @bghira. Should we default to "logit_normal" weighting here as well?
I used 'none' at first, which is an option I added to my script that just uses uniform sampling and uniform loss weighting.
But adding 'cosmap' as an option seems to be ideal; maybe it approximates some min-snr-gamma-like behaviour.
The loss seems to follow a bathtub curve in some tests, where it's high at both ends and low in the middle. Maybe that's why cosmap helps.
We support:
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
In your experiments, what's been the best scheme so far?
cosmap or none
Do you have a reference for "none"?
Yes, just adding "none" to the above choices list allows it to use uniform sampling and no weighting without any additional changes.
Ah thanks. Good to know.
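A minimal sketch of what the "none" setting effectively amounts to, written out explicitly, assuming the script's `args`, `bsz`, and `sigmas` variables: uniform timestep sampling and unit loss weighting, with the diffusers helpers used for the other schemes. Names mirror the script, but this is illustrative rather than the exact implementation.

import torch
from diffusers.training_utils import (
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
)

if args.weighting_scheme == "none":
    # Uniform timestep sampling.
    u = torch.rand(size=(bsz,))
else:
    u = compute_density_for_timestep_sampling(
        weighting_scheme=args.weighting_scheme,
        batch_size=bsz,
        logit_mean=args.logit_mean,
        logit_std=args.logit_std,
        mode_scale=args.mode_scale,
    )

# ... later, once sigmas have been derived from the sampled timesteps:
if args.weighting_scheme == "none":
    weighting = torch.ones_like(sigmas)  # uniform loss, no reweighting
else:
    weighting = compute_loss_weighting_for_sd3(
        weighting_scheme=args.weighting_scheme, sigmas=sigmas
    )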
# handle guidance
if transformer.config.guidance_embeds:
    guidance = torch.tensor([args.guidance_scale], device=device)
I think an option for min and max values with a random range might be more effective, after some experimentation.
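A rough sketch of that suggestion, assuming hypothetical `--guidance_scale_min` / `--guidance_scale_max` arguments and the script's `transformer`, `args`, `model_input`, and `device`: sample a random guidance value per step instead of the fixed `args.guidance_scale`. Not part of the merged script.

import torch

if transformer.config.guidance_embeds:
    # Draw one guidance value uniformly from [min, max] and repeat it across the batch.
    scale = torch.empty(1, device=device).uniform_(
        args.guidance_scale_min, args.guidance_scale_max
    )
    guidance = scale.expand(model_input.shape[0])
else:
    guidance = None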
if accelerator.is_main_process:
    transformer = unwrap_model(transformer)
    transformer = transformer.to(torch.float32)
    transformer_lora_layers = get_peft_model_state_dict(transformer)
@sayakpaul @bghira
I wonder if upcasting here is necessary, or whether we could support an option to save in lower precision. For me, even when training a rank-4 LoRA with 8-bit Adam, it still crashes here (training on an A100). I think @bghira doesn't upcast when saving; what's your impression?
I am using quanto, so I don't touch the dtype before export. Using bf16 exports made the files half the size, so I've kept it that way for some time now.
Alternatively you can move it to CPU, but then you're consuming system memory, and users report SIGKILL when it OOMs and Linux gets angry.
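A minimal sketch of the lower-precision saving option discussed here, assuming a hypothetical `--upcast_before_saving` flag and the script's `accelerator`, `unwrap_model`, `transformer`, and `weight_dtype`: only upcast to fp32 when explicitly requested, otherwise export in the training dtype (e.g. bf16) for smaller files and lower peak memory. Illustrative only.

import torch
from peft.utils import get_peft_model_state_dict

if accelerator.is_main_process:
    transformer = unwrap_model(transformer)
    if args.upcast_before_saving:
        transformer = transformer.to(torch.float32)  # exact, but larger files and more memory
    else:
        transformer = transformer.to(weight_dtype)   # e.g. bf16, roughly half the file size
    transformer_lora_layers = get_peft_model_state_dict(transformer)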
Great work, can't wait to see what people do with it!
Thank you!
requirements_flux.txt not found
How much VRAM does the default script use? I'm OOMing on an A100 80GB with fp16.
Amazing work! FYI, A100 80GB out of memory on Replicate.com.
Ah, erwold is right:

# noise
z1 = torch.randn_like(x)
# noisy_latents = (1.0 - sigmas) * latents.float() + sigmas * noise.float()
zt = (1 - texp) * x + texp * z1
# model_pred
vtheta = self.model(zt, t, cond)
# target = z1 - x
# mse = target - model_pred
batchwise_mse = ((z1 - x - vtheta) ** 2).mean(dim=list(range(1, len(x.shape))))
You're totally right, thanks! It was left over from the pre-conditioning option that we eventually decided to remove.
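In the script's terms, the corrected flow-matching loss would look roughly like the sketch below, assuming `noise`, `model_input`, `weighting`, and `model_pred` as in the training loop: the target is the velocity `noise - latents`, with no extra pre-conditioning applied to the prediction.

import torch

# Velocity target for rectified flow: noise minus clean latents.
target = noise - model_input
loss = torch.mean(
    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(
        target.shape[0], -1
    ),
    dim=1,
).mean()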
All fixes are here: #9139
* initial commit - dreambooth for flux
* update transformer to be FluxTransformer2DModel
* update training loop and validation inference
* fix sd3->flux docs
* add guidance handling, not sure if it makes sense(?)
* inital dreambooth lora commit
* fix text_ids in compute_text_embeddings
* fix imports of static methods
* fix pipeline loading in readme, remove auto1111 docs for now
* fix pipeline loading in readme, remove auto1111 docs for now, remove some irrelevant text_encoder_3 refs
* Update examples/dreambooth/train_dreambooth_flux.py (Co-authored-by: Bagheera <[email protected]>)
* fix te2 loading and remove te2 refs from text encoder training
* fix tokenizer_2 initialization
* remove text_encoder training refs from lora script (for now)
* try with vae in bfloat16, fix model hook save
* fix tokenization
* fix static imports
* fix CLIP import
* remove text_encoder training refs (for now) from lora script
* fix minor bug in encode_prompt, add guidance def in lora script, ...
* fix unpack_latents args
* fix license in readme
* add "none" to weighting_scheme options for uniform sampling
* style
* adapt model saving - remove text encoder refs
* adapt model loading - remove text encoder refs
* initial commit for readme
* Update examples/dreambooth/train_dreambooth_lora_flux.py (Co-authored-by: Sayak Paul <[email protected]>)
* Update examples/dreambooth/train_dreambooth_lora_flux.py (Co-authored-by: Sayak Paul <[email protected]>)
* fix vae casting
* remove precondition_outputs
* readme
* readme
* style
* readme
* readme
* update weighting scheme default & docs
* style
* add text_encoder training to lora script, change vae_scale_factor value in both
* style
* text encoder training fixes
* style
* update readme
* minor fixes
* fix te params
* fix te params

Co-authored-by: Bagheera <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
add dreambooth & dreambooth lora training scripts for FLUX [dev]