@@ -123,16 +123,26 @@ def save_model_card(
123123 """
124124
125125 trigger_str = f"You should use { instance_prompt } to trigger the image generation."
126+ diffusers_imports_pivotal = ""
127+ diffusers_example_pivotal = ""
126128 if train_text_encoder_ti :
127129 trigger_str = (
128130 "To trigger image generation of trained concept(or concepts) replace each concept identifier "
129131 "in you prompt with the new inserted tokens:\n "
130132 )
133+ diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
134+ from safetensors.torch import load_file
135+ """
136+ diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id="{ repo_id } ", filename="embeddings.safetensors", repo_type="model")
137+ state_dict = load_file(embedding_path)
138+ pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
139+ pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
140+ """
131141 if token_abstraction_dict :
132142 for key , value in token_abstraction_dict .items ():
133143 tokens = "" .join (value )
134144 trigger_str += f"""
135- to trigger concept `{ key } ->` use `{ tokens } ` in your prompt \n
145+ to trigger concept `{ key } ` → use `{ tokens } ` in your prompt \n
136146"""
137147
138148 yaml = f"""
@@ -172,7 +182,21 @@ def save_model_card(
172182
173183{ trigger_str }
174184
175- ## Download model
185+ ## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
186+
187+ ```py
188+ from diffusers import AutoPipelineForText2Image
189+ import torch
190+ { diffusers_imports_pivotal }
191+ pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16).to('cuda')
192+ pipeline.load_lora_weights('{ repo_id } ', weight_name='pytorch_lora_weights.safetensors')
193+ { diffusers_example_pivotal }
194+ image = pipeline('{ validation_prompt if validation_prompt else instance_prompt } ').images[0]
195+ ```
196+
197+ For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
198+
199+ ## Download model (use it with UIs such as AUTO1111, Comfy, SD.Next, Invoke)
176200
177201Weights for this model are available in Safetensors format.
178202
@@ -791,6 +815,12 @@ def __init__(
791815 instance_data_root ,
792816 instance_prompt ,
793817 class_prompt ,
818+ dataset_name ,
819+ dataset_config_name ,
820+ cache_dir ,
821+ image_column ,
822+ caption_column ,
823+ train_text_encoder_ti ,
794824 class_data_root = None ,
795825 class_num = None ,
796826 token_abstraction_dict = None , # token mapping for textual inversion
@@ -805,10 +835,10 @@ def __init__(
805835 self .custom_instance_prompts = None
806836 self .class_prompt = class_prompt
807837 self .token_abstraction_dict = token_abstraction_dict
808-
838+ self . train_text_encoder_ti = train_text_encoder_ti
809839 # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
810840 # we load the training data using load_dataset
811- if args . dataset_name is not None :
841+ if dataset_name is not None :
812842 try :
813843 from datasets import load_dataset
814844 except ImportError :
@@ -821,38 +851,37 @@ def __init__(
821851 # See more about loading custom images at
822852 # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
823853 dataset = load_dataset (
824- args . dataset_name ,
825- args . dataset_config_name ,
826- cache_dir = args . cache_dir ,
854+ dataset_name ,
855+ dataset_config_name ,
856+ cache_dir = cache_dir ,
827857 )
828858 # Preprocessing the datasets.
829859 column_names = dataset ["train" ].column_names
830860
831861 # 6. Get the column names for input/target.
832- if args . image_column is None :
862+ if image_column is None :
833863 image_column = column_names [0 ]
834864 logger .info (f"image column defaulting to { image_column } " )
835865 else :
836- image_column = args .image_column
837866 if image_column not in column_names :
838867 raise ValueError (
839- f"`--image_column` value '{ args . image_column } ' not found in dataset columns. Dataset columns are: { ', ' .join (column_names )} "
868+ f"`--image_column` value '{ image_column } ' not found in dataset columns. Dataset columns are: { ', ' .join (column_names )} "
840869 )
841870 instance_images = dataset ["train" ][image_column ]
842871
843- if args . caption_column is None :
872+ if caption_column is None :
844873 logger .info (
845874 "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
846875 "contains captions/prompts for the images, make sure to specify the "
847876 "column as --caption_column"
848877 )
849878 self .custom_instance_prompts = None
850879 else :
851- if args . caption_column not in column_names :
880+ if caption_column not in column_names :
852881 raise ValueError (
853- f"`--caption_column` value '{ args . caption_column } ' not found in dataset columns. Dataset columns are: { ', ' .join (column_names )} "
882+ f"`--caption_column` value '{ caption_column } ' not found in dataset columns. Dataset columns are: { ', ' .join (column_names )} "
854883 )
855- custom_instance_prompts = dataset ["train" ][args . caption_column ]
884+ custom_instance_prompts = dataset ["train" ][caption_column ]
856885 # create final list of captions according to --repeats
857886 self .custom_instance_prompts = []
858887 for caption in custom_instance_prompts :
@@ -907,7 +936,7 @@ def __getitem__(self, index):
907936 if self .custom_instance_prompts :
908937 caption = self .custom_instance_prompts [index % self .num_instance_images ]
909938 if caption :
910- if args .train_text_encoder_ti :
939+ if self .train_text_encoder_ti :
911940 # replace instances of --token_abstraction in caption with the new tokens: "<si><si+1>" etc.
912941 for token_abs , token_replacement in self .token_abstraction_dict .items ():
913942 caption = caption .replace (token_abs , "" .join (token_replacement ))
@@ -1093,10 +1122,10 @@ def main(args):
10931122 if args .output_dir is not None :
10941123 os .makedirs (args .output_dir , exist_ok = True )
10951124
1125+ model_id = args .hub_model_id or Path (args .output_dir ).name
1126+ repo_id = None
10961127 if args .push_to_hub :
1097- repo_id = create_repo (
1098- repo_id = args .hub_model_id or Path (args .output_dir ).name , exist_ok = True , token = args .hub_token
1099- ).repo_id
1128+ repo_id = create_repo (repo_id = model_id , exist_ok = True , token = args .hub_token ).repo_id
11001129
11011130 # Load the tokenizers
11021131 tokenizer_one = AutoTokenizer .from_pretrained (
@@ -1464,6 +1493,12 @@ def load_model_hook(models, input_dir):
14641493 instance_data_root = args .instance_data_dir ,
14651494 instance_prompt = args .instance_prompt ,
14661495 class_prompt = args .class_prompt ,
1496+ dataset_name = args .dataset_name ,
1497+ dataset_config_name = args .dataset_config_name ,
1498+ cache_dir = args .cache_dir ,
1499+ image_column = args .image_column ,
1500+ train_text_encoder_ti = args .train_text_encoder_ti ,
1501+ caption_column = args .caption_column ,
14671502 class_data_root = args .class_data_dir if args .with_prior_preservation else None ,
14681503 token_abstraction_dict = token_abstraction_dict if args .train_text_encoder_ti else None ,
14691504 class_num = args .num_class_images ,
@@ -2004,23 +2039,23 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
20042039 }
20052040 )
20062041
2007- if args .push_to_hub :
2008- if args .train_text_encoder_ti :
2009- embedding_handler .save_embeddings (
2010- f"{ args .output_dir } /embeddings.safetensors" ,
2011- )
2012- save_model_card (
2013- repo_id ,
2014- images = images ,
2015- base_model = args .pretrained_model_name_or_path ,
2016- train_text_encoder = args .train_text_encoder ,
2017- train_text_encoder_ti = args .train_text_encoder_ti ,
2018- token_abstraction_dict = train_dataset .token_abstraction_dict ,
2019- instance_prompt = args .instance_prompt ,
2020- validation_prompt = args .validation_prompt ,
2021- repo_folder = args .output_dir ,
2022- vae_path = args .pretrained_vae_model_name_or_path ,
2042+ if args .train_text_encoder_ti :
2043+ embedding_handler .save_embeddings (
2044+ f"{ args .output_dir } /embeddings.safetensors" ,
20232045 )
2046+ save_model_card (
2047+ model_id if not args .push_to_hub else repo_id ,
2048+ images = images ,
2049+ base_model = args .pretrained_model_name_or_path ,
2050+ train_text_encoder = args .train_text_encoder ,
2051+ train_text_encoder_ti = args .train_text_encoder_ti ,
2052+ token_abstraction_dict = train_dataset .token_abstraction_dict ,
2053+ instance_prompt = args .instance_prompt ,
2054+ validation_prompt = args .validation_prompt ,
2055+ repo_folder = args .output_dir ,
2056+ vae_path = args .pretrained_vae_model_name_or_path ,
2057+ )
2058+ if args .push_to_hub :
20242059 upload_folder (
20252060 repo_id = repo_id ,
20262061 folder_path = args .output_dir ,
0 commit comments