Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 9 additions & 15 deletions examples/dreambooth/train_dreambooth_lora_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import argparse
import copy
import gc
import itertools
import logging
import math
Expand Down Expand Up @@ -56,6 +55,7 @@
from diffusers.training_utils import (
_set_state_dict_into_text_encoder,
cast_training_params,
clear_objs_and_retain_memory,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
)
Expand Down Expand Up @@ -210,9 +210,7 @@ def log_validation(
}
)

del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
clear_objs_and_retain_memory(objs=[pipeline])

return images

Expand Down Expand Up @@ -1107,9 +1105,7 @@ def main(args):
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)

del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
clear_objs_and_retain_memory(objs=[pipeline])

# Handle the repository creation
if accelerator.is_main_process:
Expand Down Expand Up @@ -1455,12 +1451,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

# Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
del tokenizers, text_encoders
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
del text_encoder_one, text_encoder_two, text_encoder_three
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
clear_objs_and_retain_memory(
objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
)

# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
Expand Down Expand Up @@ -1795,11 +1789,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
pipeline_args=pipeline_args,
epoch=epoch,
)
objs = []
if not args.train_text_encoder:
del text_encoder_one, text_encoder_two, text_encoder_three
objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])

torch.cuda.empty_cache()
gc.collect()
clear_objs_and_retain_memory(objs=objs)

# Save the lora layers
accelerator.wait_for_everyone()
Expand Down
17 changes: 17 additions & 0 deletions src/diffusers/training_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import copy
import gc
import math
import random
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
Expand Down Expand Up @@ -259,6 +260,22 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
return weighting


def clear_objs_and_retain_memory(objs: List[Any]):
"""Deletes `objs` and runs garbage collection. Then clears the cache of the available accelerator."""
if len(objs) >= 1:
for obj in objs:
del obj

gc.collect()

if torch.cuda.is_available():
torch.cuda.empty_cache()
elif torch.backends.mps.is_available():
torch.mps.empty_cache()
elif is_torch_npu_available():
torch_npu.empty_cache()


# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
"""
Expand Down