diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index 568279d9be3e..e35630e3e8af 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -1279,7 +1279,7 @@ def main(args):
         for name, param in text_encoder_one.named_parameters():
             if "token_embedding" in name:
                 # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-                param = param.to(dtype=torch.float32)
+                param.data = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_one.append(param)
             else:
@@ -1288,7 +1288,7 @@ def main(args):
         for name, param in text_encoder_two.named_parameters():
             if "token_embedding" in name:
                 # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-                param = param.to(dtype=torch.float32)
+                param.data = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_two.append(param)
             else:
@@ -1725,19 +1725,19 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
     elif args.train_text_encoder_ti:  # args.train_text_encoder_ti
         num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs)
-
+    # flag used for textual inversion
+    pivoted = False
     for epoch in range(first_epoch, args.num_train_epochs):
         # if performing any kind of optimization of text_encoder params
         if args.train_text_encoder or args.train_text_encoder_ti:
             if epoch == num_train_epochs_text_encoder:
                 print("PIVOT HALFWAY", epoch)
                 # stopping optimization of text_encoder params
-                # re setting the optimizer to optimize only on unet params
-                optimizer.param_groups[1]["lr"] = 0.0
-                optimizer.param_groups[2]["lr"] = 0.0
+                # this flag is used to reset the optimizer to optimize only on unet params
+                pivoted = True
 
             else:
-                # still optimizng the text encoder
+                # still optimizing the text encoder
                 text_encoder_one.train()
                 text_encoder_two.train()
                 # set top parameter requires_grad = True for gradient checkpointing works
@@ -1747,6 +1747,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
         unet.train()
         for step, batch in enumerate(train_dataloader):
+            if pivoted:
+                # stopping optimization of text_encoder params
+                # re setting the optimizer to optimize only on unet params
+                optimizer.param_groups[1]["lr"] = 0.0
+                optimizer.param_groups[2]["lr"] = 0.0
+
             with accelerator.accumulate(unet):
                 prompts = batch["prompts"]
                 # encode batch prompts when custom prompts are provided for each image -
@@ -1885,8 +1891,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
                 # every step, we reset the embeddings to the original embeddings.
                 if args.train_text_encoder_ti:
-                    for idx, text_encoder in enumerate(text_encoders):
-                        embedding_handler.retract_embeddings()
+                    embedding_handler.retract_embeddings()
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
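
Note (illustrative, not part of the patch): the first two hunks change param = param.to(dtype=torch.float32) to param.data = param.to(dtype=torch.float32) because rebinding the loop variable only creates a new local tensor; the text encoder keeps its fp16 token embedding and the optimizer would be handed the detached copy. Assigning to param.data upcasts the tensor the module itself holds. A minimal standalone sketch of the difference, assuming only torch is installed:

    import torch
    from torch import nn

    # toy fp16 module standing in for a text encoder's token embedding
    emb = nn.Embedding(4, 8).to(dtype=torch.float16)

    for _, param in emb.named_parameters():
        param = param.to(dtype=torch.float32)  # rebinds the local name only
    print(emb.weight.dtype)  # torch.float16 -- module unchanged

    for _, param in emb.named_parameters():
        param.data = param.to(dtype=torch.float32)  # replaces the module's storage
    print(emb.weight.dtype)  # torch.float32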