[Flux] Dreambooth LoRA training scripts #9086


Merged: 59 commits into huggingface:main from flux-fine-tuning on Aug 9, 2024

Conversation

@linoytsaban (Collaborator) commented Aug 5, 2024:

Add DreamBooth & DreamBooth LoRA training scripts for FLUX [dev]

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@linoytsaban marked this pull request as ready for review August 5, 2024 12:58
@linoytsaban requested a review from sayakpaul August 5, 2024 13:05
@linoytsaban changed the title from "[Flux] Dreambooth training script" to "[Flux] Dreambooth LoRA training scripts" Aug 5, 2024
@sayakpaul (Member) left a comment:

Thank you for starting this @linoytsaban!

I think this could stay as a minimal educational resource for the community. Could we see some good first results as well? We should probably mention in the README that, for a more advanced and memory-friendly setup, readers can refer to @bghira's guide here: #9057 (comment).

LMK what you think.

"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)

vae.to(accelerator.device, dtype=torch.float32)
Member:

The VAE seems to be better. So, let's first try with weight_dtype?
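
A minimal sketch of the suggested change, assuming weight_dtype is the mixed-precision dtype already resolved from accelerator.mixed_precision earlier in the script:

# cast the VAE to the training weight dtype instead of forcing float32
vae.to(accelerator.device, dtype=weight_dtype)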

Contributor:

bf16 is good, yes, but fp16 is not; there are some huge activation values in the attn layers. But you can also just disable fp16 for attn or upcast it. Roughly the same issue as the SD 2.x UNet, but not quite.
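
A generic illustration of the upcasting idea (not the diffusers attention implementation, just the pattern of running the attention math in float32 when fp16 activations overflow):

import torch
import torch.nn.functional as F

def attention_fp32(q, k, v):
    # compute attention in float32 to avoid fp16 overflow, then cast back
    orig_dtype = q.dtype
    out = F.scaled_dot_product_attention(q.float(), k.float(), v.float())
    return out.to(orig_dtype)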

Contributor:

maybe fp16 is okay after #9097

Comment on lines +1567 to +1573
u = compute_density_for_timestep_sampling(
    weighting_scheme=args.weighting_scheme,
    batch_size=bsz,
    logit_mean=args.logit_mean,
    logit_std=args.logit_std,
    mode_scale=args.mode_scale,
)
Member:

cc: @bghira. Should we default to "log_normal" weighting here as well?

Contributor:

I used 'none' at first, which is an option I added to mine that just uses uniform sampling and uniform loss.

But adding 'cosmap' as an option seems to be ideal; it maybe approximates some min-snr-gamma-like behaviour.

The loss seems to follow a bathtub curve in some tests, where it's high at both ends and low in the middle. Maybe that's why cosmap helps.

Member:

We support:

choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],

In your experiments, what's been the best scheme so far?

Contributor:

cosmap or none

Member:

Do you have a reference for "none"?

Contributor:

Yes, just adding "none" to the above choices list allows it to use uniform sampling and no weighting without any additional changes.
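
A minimal sketch of what that could look like around the existing call (the "none" branch is the assumption here; the else branch is the current code):

if args.weighting_scheme == "none":
    # uniform timestep sampling, no extra loss weighting
    u = torch.rand(size=(bsz,), device="cpu")
else:
    u = compute_density_for_timestep_sampling(
        weighting_scheme=args.weighting_scheme,
        batch_size=bsz,
        logit_mean=args.logit_mean,
        logit_std=args.logit_std,
        mode_scale=args.mode_scale,
    )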

Member:

Ah thanks. Good to know.

Contributor:

[plot comparing the different weighting-scheme options]

Here's how the different selection options look. I think uniform makes the most sense for low batch sizes, e.g. "none".



# handle guidance
if transformer.config.guidance_embeds:
    guidance = torch.tensor([args.guidance_scale], device=device)
Contributor:

I think an option for min and max values, randomized within that range, might be more effective after experimentation.
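
A rough sketch of that idea; guidance_scale_min and guidance_scale_max are hypothetical arguments, not part of this PR:

# handle guidance (sketch): sample a guidance value from a configurable range
if transformer.config.guidance_embeds:
    g = torch.empty(1).uniform_(args.guidance_scale_min, args.guidance_scale_max).item()
    guidance = torch.tensor([g], device=device)
else:
    guidance = None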

Comment on lines +1632 to +1635
if accelerator.is_main_process:
    transformer = unwrap_model(transformer)
    transformer = transformer.to(torch.float32)
    transformer_lora_layers = get_peft_model_state_dict(transformer)
@linoytsaban (Collaborator, Author):

@sayakpaul @bghira
I wonder if upcasting here is necessary, or whether we can support an option to save in lower precision. For me, even when training a rank-4 LoRA with 8-bit Adam, it still crashes here (training on an A100). I think @bghira doesn't upcast when saving. What's your impression?
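
One way this could look, as a sketch; the upcast_before_saving flag is an assumption, not something the script currently exposes:

if accelerator.is_main_process:
    transformer = unwrap_model(transformer)
    # hypothetical flag: only upcast to fp32 when explicitly requested,
    # otherwise keep the LoRA weights in the training dtype (e.g. bf16)
    if args.upcast_before_saving:
        transformer = transformer.to(torch.float32)
    transformer_lora_layers = get_peft_model_state_dict(transformer)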

Contributor:

I am using quanto, so I don't touch the dtype before export. Using bf16 exports made the files half the size, so I've kept it that way for some time now.

Contributor:

Alternatively, you can move it to CPU, but then you're consuming system memory; users report SIGKILL when it OOMs and Linux gets angry.

@bghira (Contributor) left a comment:

Great work, can't wait to see what people do with it.

@sayakpaul merged commit 65e3090 into huggingface:main Aug 9, 2024
8 checks passed
@sayakpaul (Member):

Thank you!

@sevenclay:

requirements_flux.txt not found

@kohya-ss mentioned this pull request Aug 9, 2024
@rvorias commented Aug 9, 2024:

How much VRAM does the default script use? I'm OOMing on an A100 80GB with fp16.

@dannypostma commented Aug 10, 2024:

Amazing work! FYI, A100 80GB out of memory on Replicate.com

@erwold commented Aug 10, 2024:

[screenshot of the loss computation in this PR] This seems not correct. From the SimpleTuner code below: [screenshots of the SimpleTuner code] I'm not sure if I get it, but if your target is "noise - model_input", then you shouldn't have this line of code: "model_pred = model_pred * (-sigmas) + noisy_latents".

@bghira (Contributor) commented Aug 10, 2024:

Ah, erwold is right:

        # noise
        z1 = torch.randn_like(x)
        # noisy_latents = ( 1.0 - sigmas  ) * latents.float() + sigmas * noise.float()
        zt = (1 - texp) * x + texp * z1
        # model_pred
        vtheta = self.model(zt, t, cond)
        # target = z1 - x
        # mse = target - model_pred
        batchwise_mse = ((z1 - x - vtheta) ** 2).mean(dim=list(range(1, len(x.shape))))

@linoytsaban (Collaborator, Author):

You're totally right, thanks! It was left over from the pre-conditioning option that we eventually decided to remove.
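
For reference, a minimal sketch of the loss without the leftover preconditioning step; weighting here stands for the per-sample loss weight from the chosen weighting scheme, and the other names follow the script:

# flow-matching velocity target: no preconditioning of model_pred is needed
target = noise - model_input
loss = torch.mean(
    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
    dim=1,
).mean()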

@linoytsaban (Collaborator, Author):

All fixes are here: #9139.

@linoytsaban deleted the flux-fine-tuning branch August 12, 2024 14:08
sayakpaul added a commit that referenced this pull request Dec 23, 2024
* initial commit - dreambooth for flux

* update transformer to be FluxTransformer2DModel

* update training loop and validation inference

* fix sd3->flux docs

* add guidance handling, not sure if it makes sense(?)

* inital dreambooth lora commit

* fix text_ids in compute_text_embeddings

* fix imports of static methods

* fix pipeline loading in readme, remove auto1111 docs for now

* fix pipeline loading in readme, remove auto1111 docs for now, remove some irrelevant text_encoder_3 refs

* Update examples/dreambooth/train_dreambooth_flux.py

Co-authored-by: Bagheera <[email protected]>

* fix te2 loading and remove te2 refs from text encoder training

* fix tokenizer_2 initialization

* remove text_encoder training refs from lora script (for now)

* try with vae in bfloat16, fix model hook save

* fix tokenization

* fix static imports

* fix CLIP import

* remove text_encoder training refs (for now) from lora script

* fix minor bug in encode_prompt, add guidance def in lora script, ...

* fix unpack_latents args

* fix license in readme

* add "none" to weighting_scheme options for uniform sampling

* style

* adapt model saving - remove text encoder refs

* adapt model loading - remove text encoder refs

* initial commit for readme

* Update examples/dreambooth/train_dreambooth_lora_flux.py

Co-authored-by: Sayak Paul <[email protected]>

* Update examples/dreambooth/train_dreambooth_lora_flux.py

Co-authored-by: Sayak Paul <[email protected]>

* fix vae casting

* remove precondition_outputs

* readme

* readme

* style

* readme

* readme

* update weighting scheme default & docs

* style

* add text_encoder training to lora script, change vae_scale_factor value in both

* style

* text encoder training fixes

* style

* update readme

* minor fixes

* fix te params

* fix te params

---------

Co-authored-by: Bagheera <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>