From 1f19f7d1871870b601a2d5782ab2d4cd4a03041f Mon Sep 17 00:00:00 2001
From: erkams
Date: Sat, 4 Feb 2023 02:14:03 +0100
Subject: [PATCH 1/3] [LoRA] Freezing the model weights

Freeze the model weights since we don't need to calculate grads for them.
---
 examples/text_to_image/train_text_to_image_lora.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 4f73af8e79cc..e0d614d68a5f 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -21,6 +21,7 @@
 import random
 from pathlib import Path
 from typing import Optional
+import itertools
 
 import numpy as np
 import torch
@@ -415,7 +416,18 @@ def main():
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
     )
-
+    unet.requires_grad_(False)
+    vae.requires_grad_(False)
+
+    params = itertools.chain(
+        text_encoder.text_model.encoder.parameters(),
+        text_encoder.text_model.final_layer_norm.parameters(),
+        text_encoder.text_model.embeddings.position_embedding.parameters(),
+    )
+
+    for param in params:
+        param.requires_grad = False
+
     # For mixed precision training we cast the text_encoder and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32

From 056fd25d23b82f15915b203dbcbb58a5b0b13677 Mon Sep 17 00:00:00 2001
From: Suraj Patil
Date: Wed, 8 Feb 2023 15:58:14 +0100
Subject: [PATCH 2/3] Apply suggestions from code review

Co-authored-by: Patrick von Platen
---
 examples/text_to_image/train_text_to_image_lora.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 067668eeef3a..2363745a923a 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -416,17 +416,11 @@ def main():
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
     )
+    # freeze parameters of models to save more memory
     unet.requires_grad_(False)
     vae.requires_grad_(False)
 
-    params = itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-    )
-
-    for param in params:
-        param.requires_grad = False
+    text_encoder.requires_grad_(False)
 
     # For mixed precision training we cast the text_encoder and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.

From 5a9f709de528627a7e8244e0a5fd2dc4e8bb2518 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Thu, 9 Feb 2023 10:42:48 +0100
Subject: [PATCH 3/3] Apply suggestions from code review

---
 examples/text_to_image/train_text_to_image_lora.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 2363745a923a..a3c5bef73a95 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -21,7 +21,6 @@
 import random
 from pathlib import Path
 from typing import Optional
-import itertools
 
 import datasets
 import numpy as np
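
For illustration, a minimal sketch of the freezing pattern these patches converge on: torch.nn.Module.requires_grad_(False) recursively disables gradient tracking on every parameter the module owns, which is why the review could collapse the itertools.chain loop over text-encoder submodules from the first commit into a single call per model. The nn.Linear/nn.Sequential stand-ins below are assumptions for brevity; the real script loads UNet2DConditionModel, AutoencoderKL, and CLIPTextModel.

    import itertools

    import torch
    from torch import nn

    # Toy stand-ins for the real models, so the sketch is self-contained.
    unet = nn.Linear(8, 8)
    vae = nn.Linear(8, 8)
    text_encoder = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

    # The form the review settled on: requires_grad_(False) walks all
    # parameters of the module (recursively) and disables grad tracking.
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Equivalent manual form, as in the first commit: chain the parameter
    # iterators and flip the flag tensor by tensor.
    for param in itertools.chain(unet.parameters(), vae.parameters()):
        param.requires_grad = False

    # All frozen: autograd will neither compute nor store gradients for
    # these weights, which is where the memory saving comes from.
    assert all(not p.requires_grad for p in text_encoder.parameters())

In the LoRA script only the injected adapter weights stay trainable, so freezing the UNet, VAE, and text encoder means no gradient buffers or optimizer state are allocated for the base model.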