From 8ab999e501ce56fc333a7b85cd08a4f70e437a88 Mon Sep 17 00:00:00 2001 From: Vincent Date: Thu, 11 Jan 2024 23:33:53 +0700 Subject: [PATCH 1/7] support compile --- examples/dreambooth/train_dreambooth.py | 32 +++++++++++++++---------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index f652b1e79bcc..8b02db9ba95b 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -55,6 +55,7 @@ from diffusers.training_utils import compute_snr from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module if is_wandb_available(): @@ -106,6 +107,10 @@ def save_model_card( with open(os.path.join(repo_folder, "README.md"), "w") as f: f.write(yaml + model_card) +def unwrap_model(accelerator, model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model def log_validation( text_encoder, @@ -130,14 +135,14 @@ def log_validation( pipeline_args["vae"] = vae if text_encoder is not None: - text_encoder = accelerator.unwrap_model(text_encoder) + text_encoder = unwrap_model(accelerator, text_encoder) # create pipeline (note: unet and vae are loaded again in float32) pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, tokenizer=tokenizer, text_encoder=text_encoder, - unet=accelerator.unwrap_model(unet), + unet=unwrap_model(accelerator, unet), revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, @@ -794,6 +799,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte prompt_embeds = text_encoder( text_input_ids, attention_mask=attention_mask, + return_dict=False, ) prompt_embeds = prompt_embeds[0] @@ -935,7 +941,7 @@ def main(args): def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: for model in models: - sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder" + sub_dir = "unet" if isinstance(model, type(unwrap_model(accelerator, unet))) else "text_encoder" model.save_pretrained(os.path.join(output_dir, sub_dir)) # make sure to pop weight so that corresponding model is not saved again @@ -946,7 +952,7 @@ def load_model_hook(models, input_dir): # pop models so that they are not loaded again model = models.pop() - if isinstance(model, type(accelerator.unwrap_model(text_encoder))): + if isinstance(model, type(unwrap_model(accelerator, text_encoder))): # load transformers style into model load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder") model.config = load_model.config @@ -991,14 +997,14 @@ def load_model_hook(models, input_dir): " doing mixed precision training. copy of the weights should still be float32." ) - if accelerator.unwrap_model(unet).dtype != torch.float32: + if unwrap_model(accelerator, unet).dtype != torch.float32: raise ValueError( - f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" + f"Unet loaded as datatype {unwrap_model(accelerator, unet).dtype}. {low_precision_error_string}" ) - if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32: + if args.train_text_encoder and unwrap_model(accelerator, text_encoder).dtype != torch.float32: raise ValueError( - f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}." + f"Text encoder loaded as datatype {unwrap_model(accelerator, text_encoder).dtype}." f" {low_precision_error_string}" ) @@ -1246,7 +1252,7 @@ def compute_text_embeddings(prompt): text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, ) - if accelerator.unwrap_model(unet).config.in_channels == channels * 2: + if unwrap_model(accelerator, unet).config.in_channels == channels * 2: noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) if args.class_labels_conditioning == "timesteps": @@ -1256,8 +1262,8 @@ def compute_text_embeddings(prompt): # Predict the noise residual model_pred = unet( - noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels - ).sample + noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels, return_dict=False + )[0] if model_pred.shape[1] == 6: model_pred, _ = torch.chunk(model_pred, 2, dim=1) @@ -1375,14 +1381,14 @@ def compute_text_embeddings(prompt): pipeline_args = {} if text_encoder is not None: - pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder) + pipeline_args["text_encoder"] = unwrap_model(accelerator, text_encoder) if args.skip_save_text_encoder: pipeline_args["text_encoder"] = None pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet), + unet=unwrap_model(accelerator, unet), revision=args.revision, variant=args.variant, **pipeline_args, From c4468f5f015bd15a155ce6b3985df9e4a48693f0 Mon Sep 17 00:00:00 2001 From: Vincent Date: Thu, 11 Jan 2024 23:36:21 +0700 Subject: [PATCH 2/7] make style --- examples/dreambooth/train_dreambooth.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 8b02db9ba95b..a43bc6f9775e 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -107,11 +107,13 @@ def save_model_card( with open(os.path.join(repo_folder, "README.md"), "w") as f: f.write(yaml + model_card) + def unwrap_model(accelerator, model): model = accelerator.unwrap_model(model) model = model._orig_mod if is_compiled_module(model) else model return model + def log_validation( text_encoder, tokenizer, From 40b219135e6ecba7480b9cc14203e8a3b256ffc0 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Fri, 12 Jan 2024 12:26:59 +0700 Subject: [PATCH 3/7] move unwrap_model inside function --- examples/dreambooth/train_dreambooth.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index a43bc6f9775e..00088767a4cb 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -108,12 +108,6 @@ def save_model_card( f.write(yaml + model_card) -def unwrap_model(accelerator, model): - model = accelerator.unwrap_model(model) - model = model._orig_mod if is_compiled_module(model) else model - return model - - def log_validation( text_encoder, tokenizer, @@ -136,15 +130,12 @@ def log_validation( if vae is not None: pipeline_args["vae"] = vae - if text_encoder is not None: - text_encoder = unwrap_model(accelerator, text_encoder) - # create pipeline (note: unet and vae are loaded again in float32) pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, tokenizer=tokenizer, text_encoder=text_encoder, - unet=unwrap_model(accelerator, unet), + unet=unet, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, @@ -939,6 +930,11 @@ def main(args): args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: @@ -1358,9 +1354,9 @@ def compute_text_embeddings(prompt): if args.validation_prompt is not None and global_step % args.validation_steps == 0: images = log_validation( - text_encoder, + unwrap_model(text_encoder) if text_encoder is not None else text_encoder, tokenizer, - unet, + unwrap_model(unet), vae, args, accelerator, From 4191fc30f4d0c133ed34cc3664636ca9426bd9f8 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Fri, 12 Jan 2024 12:32:33 +0700 Subject: [PATCH 4/7] change unwrap call --- examples/dreambooth/train_dreambooth.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 00088767a4cb..17b13db22a66 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -939,7 +939,7 @@ def unwrap_model(model): def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: for model in models: - sub_dir = "unet" if isinstance(model, type(unwrap_model(accelerator, unet))) else "text_encoder" + sub_dir = "unet" if isinstance(model, type(unwrap_model(unet))) else "text_encoder" model.save_pretrained(os.path.join(output_dir, sub_dir)) # make sure to pop weight so that corresponding model is not saved again @@ -950,7 +950,7 @@ def load_model_hook(models, input_dir): # pop models so that they are not loaded again model = models.pop() - if isinstance(model, type(unwrap_model(accelerator, text_encoder))): + if isinstance(model, type(unwrap_model(text_encoder))): # load transformers style into model load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder") model.config = load_model.config @@ -995,14 +995,14 @@ def load_model_hook(models, input_dir): " doing mixed precision training. copy of the weights should still be float32." ) - if unwrap_model(accelerator, unet).dtype != torch.float32: + if unwrap_model(unet).dtype != torch.float32: raise ValueError( - f"Unet loaded as datatype {unwrap_model(accelerator, unet).dtype}. {low_precision_error_string}" + f"Unet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}" ) - if args.train_text_encoder and unwrap_model(accelerator, text_encoder).dtype != torch.float32: + if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32: raise ValueError( - f"Text encoder loaded as datatype {unwrap_model(accelerator, text_encoder).dtype}." + f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}" ) @@ -1250,7 +1250,7 @@ def compute_text_embeddings(prompt): text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, ) - if unwrap_model(accelerator, unet).config.in_channels == channels * 2: + if unwrap_model(unet).config.in_channels == channels * 2: noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) if args.class_labels_conditioning == "timesteps": @@ -1379,14 +1379,14 @@ def compute_text_embeddings(prompt): pipeline_args = {} if text_encoder is not None: - pipeline_args["text_encoder"] = unwrap_model(accelerator, text_encoder) + pipeline_args["text_encoder"] = unwrap_model(text_encoder) if args.skip_save_text_encoder: pipeline_args["text_encoder"] = None pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, - unet=unwrap_model(accelerator, unet), + unet=unwrap_model(unet), revision=args.revision, variant=args.variant, **pipeline_args, From 362461ba04d66998c50ab334d170f795a97341c4 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Fri, 12 Jan 2024 12:33:53 +0700 Subject: [PATCH 5/7] run make style --- examples/dreambooth/train_dreambooth.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 17b13db22a66..532e134a6153 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -996,14 +996,11 @@ def load_model_hook(models, input_dir): ) if unwrap_model(unet).dtype != torch.float32: - raise ValueError( - f"Unet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}" - ) + raise ValueError(f"Unet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}") if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32: raise ValueError( - f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." - f" {low_precision_error_string}" + f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}" ) # Enable TF32 for faster training on Ampere GPUs, From 70ab09732e7cfec0b19c497f823ddd1c8259dad0 Mon Sep 17 00:00:00 2001 From: "Vinh H. Pham" Date: Fri, 12 Jan 2024 13:30:55 +0700 Subject: [PATCH 6/7] Update examples/dreambooth/train_dreambooth.py Co-authored-by: Sayak Paul --- examples/dreambooth/train_dreambooth.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 532e134a6153..d1e9c28e05c0 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -789,10 +789,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte else: attention_mask = None - prompt_embeds = text_encoder( - text_input_ids, - attention_mask=attention_mask, - return_dict=False, + prompt_embeds = text_encoder(text_input_ids, attention_mask=attention_mask, return_dict=False) ) prompt_embeds = prompt_embeds[0] From 6e8b573db99eb19314cd225edfeaf29fb2e93703 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Fri, 12 Jan 2024 13:45:17 +0700 Subject: [PATCH 7/7] Revert "Update examples/dreambooth/train_dreambooth.py" This reverts commit 70ab09732e7cfec0b19c497f823ddd1c8259dad0. --- examples/dreambooth/train_dreambooth.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index d1e9c28e05c0..532e134a6153 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -789,7 +789,10 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte else: attention_mask = None - prompt_embeds = text_encoder(text_input_ids, attention_mask=attention_mask, return_dict=False) + prompt_embeds = text_encoder( + text_input_ids, + attention_mask=attention_mask, + return_dict=False, ) prompt_embeds = prompt_embeds[0]