Commit 1be7df0

erkams, patrickvonplaten, and patil-suraj authored
[LoRA] Freezing the model weights (#2245)
* [LoRA] Freezing the model weights

  Freeze the model weights since we don't need to calculate grads for them.

* Apply suggestions from code review

  Co-authored-by: Patrick von Platen <[email protected]>

* Apply suggestions from code review

---------

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Suraj Patil <[email protected]>
1 parent 62a15ce commit 1be7df0

File tree

1 file changed: +6 -1 lines changed

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 6 additions & 1 deletion
@@ -415,7 +415,12 @@ def main():
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
     )
-
+    # freeze parameters of models to save more memory
+    unet.requires_grad_(False)
+    vae.requires_grad_(False)
+
+    text_encoder.requires_grad_(False)
+
     # For mixed precision training we cast the text_encoder and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
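For context, here is a minimal, self-contained sketch (not part of the commit) of what this freezing pattern buys in a LoRA-style setup: requires_grad_(False) flips requires_grad on every parameter of a module, so autograd neither tracks the frozen base weights nor allocates gradient buffers for them, while the small trainable adapter layers still receive gradients. The layer names and shapes below are illustrative, not taken from the training script.

import torch

# Stand-in for a frozen base-model layer (shapes are illustrative).
base = torch.nn.Linear(512, 512)
base.requires_grad_(False)  # freeze: no grads computed or stored for base.weight

# LoRA-style low-rank adapter layers remain trainable.
lora_down = torch.nn.Linear(512, 4, bias=False)
lora_up = torch.nn.Linear(4, 512, bias=False)

x = torch.randn(2, 512)
out = base(x) + lora_up(lora_down(x))
out.sum().backward()

assert base.weight.grad is None           # frozen weights got no gradient
assert lora_down.weight.grad is not None  # adapter weights still learn

Since only trainable parameters need gradient storage and optimizer state, freezing the UNet, VAE, and text encoder is where the memory savings described in the commit message come from.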
