Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions examples/dreambooth/train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module


if is_wandb_available():
Expand Down Expand Up @@ -129,15 +130,12 @@ def log_validation(
if vae is not None:
pipeline_args["vae"] = vae

if text_encoder is not None:
text_encoder = accelerator.unwrap_model(text_encoder)

# create pipeline (note: unet and vae are loaded again in float32)
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
tokenizer=tokenizer,
text_encoder=text_encoder,
unet=accelerator.unwrap_model(unet),
unet=unet,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
Expand Down Expand Up @@ -794,6 +792,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
prompt_embeds = text_encoder(
text_input_ids,
attention_mask=attention_mask,
return_dict=False,
)
prompt_embeds = prompt_embeds[0]

Expand Down Expand Up @@ -931,11 +930,16 @@ def main(args):
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)

def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model

# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
for model in models:
sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
sub_dir = "unet" if isinstance(model, type(unwrap_model(unet))) else "text_encoder"
model.save_pretrained(os.path.join(output_dir, sub_dir))

# make sure to pop weight so that corresponding model is not saved again
Expand All @@ -946,7 +950,7 @@ def load_model_hook(models, input_dir):
# pop models so that they are not loaded again
model = models.pop()

if isinstance(model, type(accelerator.unwrap_model(text_encoder))):
if isinstance(model, type(unwrap_model(text_encoder))):
# load transformers style into model
load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
model.config = load_model.config
Expand Down Expand Up @@ -991,15 +995,12 @@ def load_model_hook(models, input_dir):
" doing mixed precision training. copy of the weights should still be float32."
)

if accelerator.unwrap_model(unet).dtype != torch.float32:
raise ValueError(
f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
)
if unwrap_model(unet).dtype != torch.float32:
raise ValueError(f"Unet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}")

if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32:
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
raise ValueError(
f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
f" {low_precision_error_string}"
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
)

# Enable TF32 for faster training on Ampere GPUs,
Expand Down Expand Up @@ -1246,7 +1247,7 @@ def compute_text_embeddings(prompt):
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
)

if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
if unwrap_model(unet).config.in_channels == channels * 2:
noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)

if args.class_labels_conditioning == "timesteps":
Expand All @@ -1256,8 +1257,8 @@ def compute_text_embeddings(prompt):

# Predict the noise residual
model_pred = unet(
noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
).sample
noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels, return_dict=False
)[0]

if model_pred.shape[1] == 6:
model_pred, _ = torch.chunk(model_pred, 2, dim=1)
Expand Down Expand Up @@ -1350,9 +1351,9 @@ def compute_text_embeddings(prompt):

if args.validation_prompt is not None and global_step % args.validation_steps == 0:
images = log_validation(
text_encoder,
unwrap_model(text_encoder) if text_encoder is not None else text_encoder,
tokenizer,
unet,
unwrap_model(unet),
vae,
args,
accelerator,
Expand All @@ -1375,14 +1376,14 @@ def compute_text_embeddings(prompt):
pipeline_args = {}

if text_encoder is not None:
pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
pipeline_args["text_encoder"] = unwrap_model(text_encoder)

if args.skip_save_text_encoder:
pipeline_args["text_encoder"] = None

pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
unet=unwrap_model(unet),
revision=args.revision,
variant=args.variant,
**pipeline_args,
Expand Down