
Commit 121567b

Revert "[Wuerstchen] fix fp16 training and correct lora args (#6245)"
This reverts commit 0bb9cf0.
1 parent fd64acf commit 121567b

File tree: 1 file changed (+1, -9 lines)


examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py

Lines changed: 1 addition & 9 deletions
@@ -527,17 +527,9 @@ def deepspeed_zero_init_disabled_context_manager():
 
     # lora attn processor
     prior_lora_config = LoraConfig(
-        r=args.rank,
-        lora_alpha=args.rank,
-        target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
+        r=args.rank, target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"]
     )
-    # Add adapter and make sure the trainable params are in float32.
     prior.add_adapter(prior_lora_config)
-    if args.mixed_precision == "fp16":
-        for param in prior.parameters():
-            # only upcast trainable parameters (LoRA) into fp32
-            if param.requires_grad:
-                param.data = param.to(torch.float32)
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
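
For reference, the change being reverted had done two things: it passed lora_alpha to the LoraConfig and it upcast the trainable LoRA parameters to float32 when training with --mixed_precision fp16. A minimal sketch of that reverted pattern follows; the prior model and the args namespace are assumed to come from the surrounding training script, so this is illustrative rather than a drop-in snippet.

import torch
from peft import LoraConfig

# Reverted variant: pass lora_alpha explicitly alongside the rank and list the
# attention projections to adapt.
prior_lora_config = LoraConfig(
    r=args.rank,
    lora_alpha=args.rank,
    target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
)
prior.add_adapter(prior_lora_config)

# Reverted variant: keep only the trainable (LoRA) parameters in float32 so their
# updates stay numerically stable while the frozen base weights remain in fp16.
if args.mixed_precision == "fp16":
    for param in prior.parameters():
        if param.requires_grad:
            param.data = param.to(torch.float32)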
