From 2ad983734b8d24e858a7e6f99199ab88db53023d Mon Sep 17 00:00:00 2001
From: Linoy Tsaban
Date: Sat, 25 Nov 2023 16:23:28 +0200
Subject: [PATCH 01/13] imports and readme bug fixes

---
 .../train_dreambooth_lora_sdxl_advanced.py | 44 ++++++++++++++++---
 1 file changed, 39 insertions(+), 5 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index f032634a11f0..3619596adbd2 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -54,7 +54,7 @@
     UNet2DConditionModel,
 )
 from diffusers.loaders import LoraLoaderMixin
-from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
+from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr, unet_lora_state_dict
 from diffusers.utils import check_min_version, is_wandb_available
@@ -67,6 +67,39 @@
 logger = get_logger(__name__)
 
 
+# TODO: This function should be removed once training scripts are rewritten in PEFT
+def text_encoder_lora_state_dict(text_encoder):
+    state_dict = {}
+
+    def text_encoder_attn_modules(text_encoder):
+        from transformers import CLIPTextModel, CLIPTextModelWithProjection
+
+        attn_modules = []
+
+        if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+            for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+                name = f"text_model.encoder.layers.{i}.self_attn"
+                mod = layer.self_attn
+                attn_modules.append((name, mod))
+
+        return attn_modules
+
+    for name, module in text_encoder_attn_modules(text_encoder):
+        for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+
+    return state_dict
+
+
 def save_model_card(
     repo_id: str,
     images=None,
@@ -83,7 +116,7 @@ def save_model_card(
         img_str += f"""
         - text: '{validation_prompt if validation_prompt else ' ' }'
           output:
-            url: >-
+            url:
                 "image_{i}.png"
         """
 
@@ -96,9 +129,7 @@ def save_model_card(
 - diffusers
 - lora
 - template:sd-lora
-widget:
 {img_str}
----
 base_model: {base_model}
 instance_prompt: {instance_prompt}
 license: openrail++
@@ -112,9 +143,12 @@ def save_model_card(
 
 ## Model description
 
-These are {repo_id} LoRA adaption weights for {base_model}.
+### These are {repo_id} LoRA adaption weights for {base_model}.
+
 The weights were trained using [DreamBooth](https://dreambooth.github.io/).
+
 LoRA for the text encoder was enabled: {train_text_encoder}.
+
 Special VAE used for training: {vae_path}.
 
 ## Trigger words
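
Note on the helper restored by PATCH 01: `text_encoder_lora_state_dict` is no longer imported from `diffusers.models.lora`; the script now carries a local copy until the example is ported to PEFT. The helper assumes that every attention projection of the CLIP text encoders has already been monkey-patched with a `lora_linear_layer` attribute. A minimal sketch of that assumption, not taken from the patch itself (the checkpoint name and rank are illustrative):

    import torch
    from transformers import CLIPTextModel
    from diffusers.models.lora import LoRALinearLayer

    text_encoder = CLIPTextModel.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder"
    )

    # attach a LoRA layer to every attention projection, mirroring the structure the
    # helper reads back out through `.lora_linear_layer.state_dict()`
    for layer in text_encoder.text_model.encoder.layers:
        attn = layer.self_attn
        for proj in (attn.q_proj, attn.k_proj, attn.v_proj, attn.out_proj):
            proj.lora_linear_layer = LoRALinearLayer(proj.in_features, proj.out_features, rank=4)

    # with the helper above in scope, this collects one down/up weight pair per projection
    lora_state = text_encoder_lora_state_dict(text_encoder)
    print(len(lora_state), next(iter(lora_state)))

If the projections have not been patched this way, reading `.lora_linear_layer` raises an `AttributeError`.
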
From bf783ff751800efb61a0da238bb614063ee79ca5 Mon Sep 17 00:00:00 2001
From: Linoy Tsaban
Date: Sun, 26 Nov 2023 17:51:37 +0200
Subject: [PATCH 02/13] bug fix - ensures text_encoder params are
 dtype==float32 (when using pivotal tuning) even if the rest of the model is
 loaded in fp16

---
 .../train_dreambooth_lora_sdxl_advanced.py | 26 +++-------
 1 file changed, 4 insertions(+), 22 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index 3619596adbd2..429c2cc834a5 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -730,19 +730,6 @@ def dtype(self):
     def device(self):
         return self.text_encoders[0].device
 
-    # def _load_embeddings(self, loaded_embeddings, tokenizer, text_encoder):
-    #     # Assuming new tokens are of the format <s_i>
-    #     self.inserting_toks = [f"<s{i}>" for i in range(loaded_embeddings.shape[0])]
-    #     special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
-    #     tokenizer.add_special_tokens(special_tokens_dict)
-    #     text_encoder.resize_token_embeddings(len(tokenizer))
-    #
-    #     self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
-    #     assert self.train_ids is not None, "New tokens could not be converted to IDs."
-    #     text_encoder.text_model.embeddings.token_embedding.weight.data[
-    #         self.train_ids
-    #     ] = loaded_embeddings.to(device=self.device).to(dtype=self.dtype)
-
     @torch.no_grad()
     def retract_embeddings(self):
         for idx, text_encoder in enumerate(self.text_encoders):
@@ -764,15 +751,6 @@ def retract_embeddings(self):
                 new_embeddings = new_embeddings * (off_ratio**0.1)
                 text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
 
-    # def load_embeddings(self, file_path: str):
-    #     with safe_open(file_path, framework="pt", device=self.device.type) as f:
-    #         for idx in range(len(self.text_encoders)):
-    #             text_encoder = self.text_encoders[idx]
-    #             tokenizer = self.tokenizers[idx]
-    #
-    #             loaded_embeddings = f.get_tensor(f"text_encoders_{idx}")
-    #             self._load_embeddings(loaded_embeddings, tokenizer, text_encoder)
-
 
 class DreamBoothDataset(Dataset):
     """
@@ -1250,6 +1228,8 @@ def main(args):
             text_lora_parameters_one = []
             for name, param in text_encoder_one.named_parameters():
                 if "token_embedding" in name:
+                    # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+                    param = param.to(dtype=torch.float32)
                     param.requires_grad = True
                     text_lora_parameters_one.append(param)
                 else:
@@ -1257,6 +1237,8 @@ def main(args):
             text_lora_parameters_two = []
             for name, param in text_encoder_two.named_parameters():
                 if "token_embedding" in name:
+                    # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+                    param = param.to(dtype=torch.float32)
                     param.requires_grad = True
                     text_lora_parameters_two.append(param)
                 else:
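
The fix in PATCH 02 matters when pivotal tuning runs under mixed precision: the token-embedding rows being optimized must stay in float32 even though the rest of the text encoder is loaded in fp16. A minimal sketch of that invariant, not the training script's exact code path (the checkpoint name is only an example):

    import torch
    from transformers import CLIPTextModel

    text_encoder = CLIPTextModel.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", torch_dtype=torch.float16
    )

    # upcast only the embedding table that will receive gradient updates
    token_embedding = text_encoder.text_model.embeddings.token_embedding
    token_embedding.to(dtype=torch.float32)
    token_embedding.weight.requires_grad_(True)

    trainable = [p for n, p in text_encoder.named_parameters() if "token_embedding" in n]
    assert all(p.dtype == torch.float32 for p in trainable)
    # everything that stays frozen is still fp16
    assert text_encoder.text_model.encoder.layers[0].self_attn.q_proj.weight.dtype == torch.float16

Keeping the trained embedding weights in full precision is the usual way to keep their optimizer updates numerically stable while the frozen layers remain in half precision.
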
From cdceaa6a64e85afb4947a9924b0a402305702225 Mon Sep 17 00:00:00 2001
From: Linoy Tsaban
Date: Sun, 26 Nov 2023 23:02:28 +0200
Subject: [PATCH 03/13] added pivotal tuning to readme

---
 .../train_dreambooth_lora_sdxl_advanced.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index 429c2cc834a5..d203c3cb23f8 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -105,6 +105,7 @@ def save_model_card(
     images=None,
     base_model=str,
     train_text_encoder=False,
+    train_text_encoder_ti=False,
     instance_prompt=str,
     validation_prompt=str,
     repo_folder=None,
@@ -149,6 +150,8 @@ def save_model_card(
 
 LoRA for the text encoder was enabled: {train_text_encoder}.
 
+Pivotal tuning was enabled: {train_text_encoder_ti}.
+
 Special VAE used for training: {vae_path}.
 
 ## Trigger words
@@ -1964,6 +1967,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             images=images,
             base_model=args.pretrained_model_name_or_path,
             train_text_encoder=args.train_text_encoder,
+            train_text_encoder_ti=args.train_text_encoder_ti,
             instance_prompt=args.instance_prompt,
             validation_prompt=args.validation_prompt,
             repo_folder=args.output_dir,

From 0bbffa9b4a8d1e5e4ce2e045f8413447b1b01962 Mon Sep 17 00:00:00 2001
From: Linoy Tsaban
Date: Sun, 26 Nov 2023 23:32:46 +0200
Subject: [PATCH 04/13] mapping token identifier to new inserted token in
 validation prompt (if used)

---
 .../train_dreambooth_lora_sdxl_advanced.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index d203c3cb23f8..17b7fd01642a 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -1513,6 +1513,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
             tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
 
+    if args.train_text_encoder_ti and args.validation_prompt:
+        # replace instances of --token_abstraction in validation prompt with the new tokens: "" etc.
+        for token_abs, token_replacement in train_dataset.token_abstraction_dict.items():
+            args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
+        print("validation prompt:", args.validation_prompt)
+
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1647,7 +1653,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             with accelerator.accumulate(unet):
                 pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                 prompts = batch["prompts"]
-                print(prompts)
+                # print(prompts)
                 # encode batch prompts when custom prompts are provided for each image -
                 if train_dataset.custom_instance_prompts:
                     if freeze_text_encoder:

From b83a846ea52b975b44c2107763c58de6d0191710 Mon Sep 17 00:00:00 2001
From: Linoy Tsaban
Date: Mon, 27 Nov 2023 01:05:29 +0200
Subject: [PATCH 05/13] correct default value of --train_text_encoder_frac

---
 .../train_dreambooth_lora_sdxl_advanced.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index 17b7fd01642a..3baacd326f62 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -492,7 +492,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--train_text_encoder_frac",
         type=float,
-        default=0.5,
+        default=1.0,
         help=("The percentage of epochs to perform text encoder tuning"),
     )
 

From d53be8ff789ea37e13569729b7f174ad2f20c9bd Mon Sep 17 00:00:00 2001
From: Linoy Tsaban
Date: Mon, 27 Nov 2023 08:52:25 +0200
Subject: [PATCH 06/13] change default value of --adam_weight_decay_text_encoder

---
 .../train_dreambooth_lora_sdxl_advanced.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index 3baacd326f62..93000d73f863 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -525,7 +525,7 @@ def parse_args(input_args=None):
     parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
     parser.add_argument(
-        "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+        "--adam_weight_decay_text_encoder", type=float, default=None, help="Weight decay to use for text_encoder"
     )
 
     parser.add_argument(
@@ -1328,12 +1328,12 @@ def load_model_hook(models, input_dir):
         # different learning rate for text encoder and unet
         text_lora_parameters_one_with_lr = {
             "params": text_lora_parameters_one,
-            "weight_decay": args.adam_weight_decay_text_encoder,
+            "weight_decay": args.adam_weight_decay_text_encoder if args.adam_weight_decay_text_encoder else args.adam_weight_decay,
             "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
         }
         text_lora_parameters_two_with_lr = {
             "params": text_lora_parameters_two,
-            "weight_decay": args.adam_weight_decay_text_encoder,
+            "weight_decay": args.adam_weight_decay_text_encoder if args.adam_weight_decay_text_encoder else args.adam_weight_decay,
             "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
         }
         params_to_optimize = [
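
PATCH 06 only changes a default, but it leans on how the optimizer's parameter groups are laid out: group 0 holds the unet LoRA parameters and groups 1 and 2 hold the two text encoders, each with its own learning rate and a weight decay that now falls back to the unet value when `--adam_weight_decay_text_encoder` is left unset. A small self-contained sketch of that layout (tensor shapes and hyperparameter values are placeholders):

    import torch

    unet_params = [torch.nn.Parameter(torch.zeros(8))]
    te1_params = [torch.nn.Parameter(torch.zeros(8))]
    te2_params = [torch.nn.Parameter(torch.zeros(8))]

    learning_rate, text_encoder_lr = 1e-4, 5e-5
    adam_weight_decay, adam_weight_decay_text_encoder = 1e-4, None  # None now means "reuse the unet value"

    params_to_optimize = [
        {"params": unet_params, "lr": learning_rate, "weight_decay": adam_weight_decay},
        {
            "params": te1_params,
            "lr": text_encoder_lr if text_encoder_lr else learning_rate,
            "weight_decay": adam_weight_decay_text_encoder if adam_weight_decay_text_encoder else adam_weight_decay,
        },
        {
            "params": te2_params,
            "lr": text_encoder_lr if text_encoder_lr else learning_rate,
            "weight_decay": adam_weight_decay_text_encoder if adam_weight_decay_text_encoder else adam_weight_decay,
        },
    ]
    optimizer = torch.optim.AdamW(params_to_optimize)

    # the same group indices are what a later patch in this series uses to stop
    # text-encoder tuning partway through training:
    optimizer.param_groups[1]["lr"] = 0.0
    optimizer.param_groups[2]["lr"] = 0.0
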
From 6edc0f5a044e35ed3a0f6a35d468c9a3c79c5775 Mon Sep 17 00:00:00 2001
From: Linoy Tsaban
Date: Mon, 27 Nov 2023 10:28:37 +0200
Subject: [PATCH 07/13] validation prompt generations when using pivotal
 tuning bug fix

---
 .../train_dreambooth_lora_sdxl_advanced.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index 93000d73f863..788fec49898c 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -1826,7 +1826,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                         f" {args.validation_prompt}."
                     )
                     # create pipeline
-                    if not args.train_text_encoder:
+                    if freeze_text_encoder:
                         text_encoder_one = text_encoder_cls_one.from_pretrained(
                             args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
                         )

From 3c2afc59c711fef72665665c4aa48bd6a7eed4a0 Mon Sep 17 00:00:00 2001
From: Linoy Tsaban
Date: Mon, 27 Nov 2023 08:38:42 +0000
Subject: [PATCH 08/13] style fix

---
 .../train_dreambooth_lora_sdxl_advanced.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index 788fec49898c..103e72b05db0 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -1328,12 +1328,16 @@ def load_model_hook(models, input_dir):
         # different learning rate for text encoder and unet
         text_lora_parameters_one_with_lr = {
             "params": text_lora_parameters_one,
-            "weight_decay": args.adam_weight_decay_text_encoder if args.adam_weight_decay_text_encoder else args.adam_weight_decay,
+            "weight_decay": args.adam_weight_decay_text_encoder
+            if args.adam_weight_decay_text_encoder
+            else args.adam_weight_decay,
             "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
         }
         text_lora_parameters_two_with_lr = {
             "params": text_lora_parameters_two,
-            "weight_decay": args.adam_weight_decay_text_encoder if args.adam_weight_decay_text_encoder else args.adam_weight_decay,
+            "weight_decay": args.adam_weight_decay_text_encoder
+            if args.adam_weight_decay_text_encoder
+            else args.adam_weight_decay,
             "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
         }
         params_to_optimize = [

From 9265cb09e069694e5dfe3447a51b46daa79066d2 Mon Sep 17 00:00:00 2001
From: Linoy Tsaban
Date: Mon, 27 Nov 2023 12:44:00 +0200
Subject: [PATCH 09/13] textual inversion embeddings name change

---
 .../train_dreambooth_lora_sdxl_advanced.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index 103e72b05db0..2ad0d00d51b4 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -716,12 +716,19 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
 
     def save_embeddings(self, file_path: str):
         assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
         tensors = {}
+        # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
+        idx_to_text_encoder_name = {0:"clip_l", 1:"clip_g"}
         for idx, text_encoder in enumerate(self.text_encoders):
             assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
                 self.tokenizers[0]
             ), "Tokenizers should be the same."
             new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
-            tensors[f"text_encoders_{idx}"] = new_token_embeddings
+
+            # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
+            # text_encoder 1) to keep compatible with the ecosystem.
+            # Note: When loading with diffusers, any name can work - simply specify in inference
+            tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings
+            # tensors[f"text_encoders_{idx}"] = new_token_embeddings
 
         save_file(tensors, file_path)

From 0af8f44755f0ca6e9e835f85f285a2d7133cf579 Mon Sep 17 00:00:00 2001
From: Linoy Tsaban
Date: Mon, 27 Nov 2023 10:44:42 +0000
Subject: [PATCH 10/13] style fix

---
 .../train_dreambooth_lora_sdxl_advanced.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index 2ad0d00d51b4..7d3e1dda0c9c 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -717,7 +717,7 @@ def save_embeddings(self, file_path: str):
         assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
         tensors = {}
         # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
-        idx_to_text_encoder_name = {0:"clip_l", 1:"clip_g"}
+        idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
         for idx, text_encoder in enumerate(self.text_encoders):
             assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
                 self.tokenizers[0]
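
The "clip_l"/"clip_g" keys introduced in PATCHES 09-10 are what downstream tooling expects when it re-attaches the pivotal-tuning embeddings at inference time. A hedged sketch of consuming such a file with diffusers; the file names, repo id and token strings are placeholders, not values produced by this series:

    import torch
    from diffusers import StableDiffusionXLPipeline
    from safetensors.torch import load_file

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")

    # LoRA weights and textual-inversion embeddings written by a training run (placeholder paths)
    pipe.load_lora_weights("output_dir", weight_name="pytorch_lora_weights.safetensors")
    state_dict = load_file("output_dir/embeddings.safetensors")

    tokens = ["<s0>", "<s1>"]  # placeholder names for the inserted tokens
    pipe.load_textual_inversion(
        state_dict["clip_l"], token=tokens, text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer
    )
    pipe.load_textual_inversion(
        state_dict["clip_g"], token=tokens, text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2
    )

    image = pipe("a photo of <s0><s1> wearing a spacesuit").images[0]

As the in-code note above says, other key names would also work; these are simply the ones the saved file ships with, so they are what gets specified at inference.
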
From e53a84ade44b0997b056d5289b0c920a82a235c8 Mon Sep 17 00:00:00 2001
From: Linoy Tsaban
Date: Tue, 28 Nov 2023 12:25:31 +0200
Subject: [PATCH 11/13] bug fix - stopping text encoder optimization halfway

---
 .../train_dreambooth_lora_sdxl_advanced.py | 25 +++----------------
 1 file changed, 4 insertions(+), 21 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index 7d3e1dda0c9c..d1645d1cb387 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -1629,27 +1629,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             if epoch == num_train_epochs_text_encoder:
                 print("PIVOT HALFWAY", epoch)
                 # stopping optimization of text_encoder params
-                params_to_optimize = params_to_optimize[:1]
-                # reinitializing the optimizer to optimize only on unet params
-                if args.optimizer.lower() == "prodigy":
-                    optimizer = optimizer_class(
-                        params_to_optimize,
-                        lr=args.learning_rate,
-                        betas=(args.adam_beta1, args.adam_beta2),
-                        beta3=args.prodigy_beta3,
-                        weight_decay=args.adam_weight_decay,
-                        eps=args.adam_epsilon,
-                        decouple=args.prodigy_decouple,
-                        use_bias_correction=args.prodigy_use_bias_correction,
-                        safeguard_warmup=args.prodigy_safeguard_warmup,
-                    )
-                else:  # AdamW or 8-bit-AdamW
-                    optimizer = optimizer_class(
-                        params_to_optimize,
-                        betas=(args.adam_beta1, args.adam_beta2),
-                        weight_decay=args.adam_weight_decay,
-                        eps=args.adam_epsilon,
-                    )
+                # re setting the optimizer to optimize only on unet params
+                optimizer.param_groups[1]["lr"] = 0.0
+                optimizer.param_groups[2]["lr"] = 0.0
+
             else:
                 # still optimizng the text encoder
                 text_encoder_one.train()

From 9092ab28320d252001fbc45c0a27ef5dc9d42655 Mon Sep 17 00:00:00 2001
From: Linoy Tsaban
Date: Fri, 1 Dec 2023 14:44:32 +0200
Subject: [PATCH 12/13] readme - will include token abstraction and new
 inserted tokens when using pivotal tuning - added type to
 --num_new_tokens_per_abstraction

---
 .../train_dreambooth_lora_sdxl_advanced.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index d1645d1cb387..a65f51309f92 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -106,6 +106,7 @@ def save_model_card(
     base_model=str,
     train_text_encoder=False,
     train_text_encoder_ti=False,
+    token_abstraction_dict = None,
     instance_prompt=str,
     validation_prompt=str,
     repo_folder=None,
@@ -121,6 +122,17 @@ def save_model_card(
             "image_{i}.png"
         """
 
+    trigger_str = f"You should use {instance_prompt} to trigger the image generation."
+    if train_text_encoder_ti:
+        trigger_str = "To trigger image generation of trained concept(or concepts) replace each concept identifier " \
+                      "in you prompt with the new inserted tokens:\n"
+        if token_abstraction_dict:
+            for key,value in token_abstraction_dict.items():
+                tokens = "".join(value)
+                trigger_str += f"""
+            to trigger concept {key}-> use {tokens} in your prompt \n
+            """
+
     yaml = f"""
 ---
 tags:
@@ -281,6 +293,7 @@ def parse_args(input_args=None):
 
     parser.add_argument(
         "--num_new_tokens_per_abstraction",
+        type=int,
         default=2,
         help="number of new tokens inserted to the tokenizers per token_abstraction value when "
         "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new "
@@ -1968,6 +1981,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             images=images,
             base_model=args.pretrained_model_name_or_path,
             train_text_encoder=args.train_text_encoder,
             train_text_encoder_ti=args.train_text_encoder_ti,
+            token_abstraction_dict = train_dataset.token_abstraction_dict,
             instance_prompt=args.instance_prompt,
             validation_prompt=args.validation_prompt,
             repo_folder=args.output_dir,

From 9868e675fa69f2636a9a2969862e39ac3258d18f Mon Sep 17 00:00:00 2001
From: Linoy Tsaban
Date: Fri, 1 Dec 2023 12:50:51 +0000
Subject: [PATCH 13/13] style fix

---
 .../train_dreambooth_lora_sdxl_advanced.py | 14 ++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index a65f51309f92..3fccd1786be5 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -106,7 +106,7 @@ def save_model_card(
     base_model=str,
     train_text_encoder=False,
     train_text_encoder_ti=False,
-    token_abstraction_dict = None,
+    token_abstraction_dict=None,
     instance_prompt=str,
     validation_prompt=str,
     repo_folder=None,
@@ -124,13 +124,15 @@ def save_model_card(
 
     trigger_str = f"You should use {instance_prompt} to trigger the image generation."
     if train_text_encoder_ti:
-        trigger_str = "To trigger image generation of trained concept(or concepts) replace each concept identifier " \
-                      "in you prompt with the new inserted tokens:\n"
+        trigger_str = (
+            "To trigger image generation of trained concept(or concepts) replace each concept identifier "
+            "in you prompt with the new inserted tokens:\n"
+        )
         if token_abstraction_dict:
-            for key,value in token_abstraction_dict.items():
+            for key, value in token_abstraction_dict.items():
                 tokens = "".join(value)
                 trigger_str += f"""
-            to trigger concept {key}-> use {tokens} in your prompt \n
+    to trigger concept {key}-> use {tokens} in your prompt \n
             """
 
     yaml = f"""
@@ -1981,7 +1983,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             base_model=args.pretrained_model_name_or_path,
             train_text_encoder=args.train_text_encoder,
             train_text_encoder_ti=args.train_text_encoder_ti,
-            token_abstraction_dict = train_dataset.token_abstraction_dict,
+            token_abstraction_dict=train_dataset.token_abstraction_dict,
             instance_prompt=args.instance_prompt,
             validation_prompt=args.validation_prompt,
             repo_folder=args.output_dir,
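
For reference, the trigger-word section that PATCHES 12-13 add to the generated model card comes from rendering `token_abstraction_dict` into `trigger_str`. A tiny standalone run of that logic, using a placeholder mapping (in the script the real mapping is built by the training dataset):

    token_abstraction_dict = {"TOK": ["<s0>", "<s1>"]}  # placeholder mapping

    trigger_str = (
        "To trigger image generation of trained concept(or concepts) replace each concept identifier "
        "in you prompt with the new inserted tokens:\n"
    )
    for key, value in token_abstraction_dict.items():
        tokens = "".join(value)
        trigger_str += f"""
    to trigger concept {key}-> use {tokens} in your prompt \n
    """
    print(trigger_str)

So a user who trained with `--token_abstraction TOK` writes prompts with the inserted tokens (here `<s0><s1>`) instead of the identifier `TOK`, and the model card spells that out per concept.
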