Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions examples/dreambooth/train_dreambooth_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def log_validation(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's provide the author courtesy here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@linoytsaban did we resolve this?

pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)

# run inference
Expand Down Expand Up @@ -1579,7 +1579,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
)

# handle guidance
if transformer.config.guidance_embeds:
if accelerator.unwrap_model(transformer).config.guidance_embeds:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:
Expand Down Expand Up @@ -1693,6 +1693,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
# create pipeline
if not args.train_text_encoder:
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
text_encoder_one.to(weight_dtype)
text_encoder_two.to(weight_dtype)
else: # even when training the text encoder we're only training text encoder one
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path,
Expand Down
33 changes: 30 additions & 3 deletions examples/dreambooth/train_dreambooth_lora_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,15 @@ def parse_args(input_args=None):
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
)

parser.add_argument(
"--lora_layers",
type=str,
default=None,
help=(
'The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only'
),
)

parser.add_argument(
"--adam_epsilon",
type=float,
Expand Down Expand Up @@ -1186,12 +1195,30 @@ def main(args):
if args.train_text_encoder:
text_encoder_one.gradient_checkpointing_enable()

# now we will add new LoRA weights to the attention layers
if args.lora_layers is not None:
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
else:
target_modules = [
"attn.to_k",
"attn.to_q",
"attn.to_v",
"attn.to_out.0",
"attn.add_k_proj",
"attn.add_q_proj",
"attn.add_v_proj",
"attn.to_add_out",
"ff.net.0.proj",
"ff.net.2",
"ff_context.net.0.proj",
"ff_context.net.2",
]
Comment on lines +1201 to +1214
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like a bit breaking no? Better to not do it and instead make a note from the README?

WDYT?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Breaking or just changing default behavior? I think it's geared more towards the latter, but I think it's in line with the other trainers & makes sense for Transformer based models, so maybe a Warning note and a guide on how to train it the old way for e.g.?

Copy link
Member

@sayakpaul sayakpaul Oct 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah maybe a warning note at the beginning of the README should cut it.

With this change, we're likely also increasing the total training wall-clock time in the default setting, so, that is worth noting.


# now we will add new LoRA weights the transformer layers
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
target_modules=target_modules,
)
transformer.add_adapter(transformer_lora_config)
if args.train_text_encoder:
Expand Down Expand Up @@ -1367,7 +1394,7 @@ def load_model_hook(models, input_dir):
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
f"When using prodigy only learning_rate is used as the initial learning rate."
)
# changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
# changes the learning rate of text_encoder_parameters_one to be
# --learning_rate
params_to_optimize[1]["lr"] = args.learning_rate

Expand Down
Loading