diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py
index 2e77cb946f92..17e6e107b079 100644
--- a/examples/dreambooth/train_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -15,7 +15,6 @@

 import argparse
 import copy
-import gc
 import itertools
 import logging
 import math
@@ -56,6 +55,7 @@ from diffusers.training_utils import (
     _set_state_dict_into_text_encoder,
     cast_training_params,
+    clear_objs_and_retain_memory,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
 )
@@ -210,9 +210,7 @@ def log_validation(
                 }
             )

-    del pipeline
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+    clear_objs_and_retain_memory(objs=[pipeline])

     return images

@@ -1107,9 +1105,7 @@ def main(args):
                     image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                     image.save(image_filename)

-            del pipeline
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            clear_objs_and_retain_memory(objs=[pipeline])

     # Handle the repository creation
     if accelerator.is_main_process:
@@ -1455,12 +1451,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

     # Clear the memory here
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
-        del tokenizers, text_encoders
         # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
-        del text_encoder_one, text_encoder_two, text_encoder_three
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        clear_objs_and_retain_memory(
+            objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
+        )

     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1795,11 +1789,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     pipeline_args=pipeline_args,
                     epoch=epoch,
                 )
+                objs = []
                 if not args.train_text_encoder:
-                    del text_encoder_one, text_encoder_two, text_encoder_three
+                    objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])

-                torch.cuda.empty_cache()
-                gc.collect()
+                clear_objs_and_retain_memory(objs=objs)

     # Save the lora layers
     accelerator.wait_for_everyone()
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index f497fcc6131c..26d4a2a504c6 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -1,5 +1,6 @@
 import contextlib
 import copy
+import gc
 import math
 import random
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
@@ -259,6 +260,22 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
     return weighting


+def clear_objs_and_retain_memory(objs: List[Any]):
+    """Deletes `objs` and runs garbage collection. Then clears the cache of the available accelerator."""
+    if len(objs) >= 1:
+        for obj in objs:
+            del obj
+
+    gc.collect()
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    elif torch.backends.mps.is_available():
+        torch.mps.empty_cache()
+    elif is_torch_npu_available():
+        torch_npu.empty_cache()
+
+
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
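
A minimal usage sketch of the helper this diff introduces, assuming a diffusers build that already contains clear_objs_and_retain_memory; the tensor names below are illustrative placeholders, not objects from the training script:

    # Illustrative sketch only: pass any large, no-longer-needed objects to the new
    # utility instead of hand-rolling `del ...; gc.collect(); torch.cuda.empty_cache()`.
    import torch

    from diffusers.training_utils import clear_objs_and_retain_memory

    device = "cuda" if torch.cuda.is_available() else "cpu"
    scratch_activations = torch.randn(4096, 4096, device=device)  # placeholder tensor
    cached_embeddings = torch.randn(77, 4096, device=device)      # placeholder tensor

    # Runs gc.collect() and then empties the CUDA, MPS, or NPU cache, whichever backend is available.
    clear_objs_and_retain_memory(objs=[scratch_activations, cached_embeddings])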