diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index 2bf3cc8f7c9c..df5477d0d643 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -123,16 +123,26 @@ def save_model_card(
     """
     trigger_str = f"You should use {instance_prompt} to trigger the image generation."
+    diffusers_imports_pivotal = ""
+    diffusers_example_pivotal = ""
     if train_text_encoder_ti:
         trigger_str = (
             "To trigger image generation of trained concept(or concepts) replace each concept identifier "
             "in you prompt with the new inserted tokens:\n"
         )
+        diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+        """
+        diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id="{repo_id}", filename="embeddings.safetensors", repo_type="model")
+state_dict = load_file(embedding_path)
+pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
+pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
+        """
     if token_abstraction_dict:
         for key, value in token_abstraction_dict.items():
             tokens = "".join(value)
             trigger_str += f"""
-to trigger concept `{key}->` use `{tokens}` in your prompt \n
+to trigger concept `{key}` → use `{tokens}` in your prompt \n
 """

     yaml = f"""
@@ -172,7 +182,21 @@ def save_model_card(

 {trigger_str}

-## Download model
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+{diffusers_imports_pivotal}
+pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16).to('cuda')
+pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
+{diffusers_example_pivotal}
+image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+
+## Download model (use it with UIs such as AUTO1111, Comfy, SD.Next, Invoke)

 Weights for this model are available in Safetensors format.
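For context, here is a self-contained version of the inference snippet the updated model card template renders when pivotal tuning is enabled. This is a minimal sketch, not output from a real run: the repo id and the `<s0>`/`<s1>` token names are hypothetical placeholders for whatever a given training run produces.

```python
# Minimal sketch of the rendered model-card example, assuming a hypothetical
# LoRA repo "your-username/your-lora" trained with --train_text_encoder_ti
# (so the repo contains embeddings.safetensors and inserted tokens <s0><s1>).
import torch
from diffusers import AutoPipelineForText2Image
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

repo_id = "your-username/your-lora"  # hypothetical; substitute your trained repo

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(repo_id, weight_name="pytorch_lora_weights.safetensors")

# Pivotal tuning stores one learned embedding per text encoder; load each
# into the matching SDXL encoder/tokenizer pair.
embedding_path = hf_hub_download(repo_id=repo_id, filename="embeddings.safetensors", repo_type="model")
state_dict = load_file(embedding_path)
pipeline.load_textual_inversion(
    state_dict["clip_l"],
    token=["<s0>", "<s1>"],
    text_encoder=pipeline.text_encoder,
    tokenizer=pipeline.tokenizer,
)
pipeline.load_textual_inversion(
    state_dict["clip_g"],
    token=["<s0>", "<s1>"],
    text_encoder=pipeline.text_encoder_2,
    tokenizer=pipeline.tokenizer_2,
)

image = pipeline("a photo of <s0><s1>").images[0]
image.save("output.png")
```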
@@ -791,6 +815,12 @@ def __init__(
         instance_data_root,
         instance_prompt,
         class_prompt,
+        dataset_name,
+        dataset_config_name,
+        cache_dir,
+        image_column,
+        caption_column,
+        train_text_encoder_ti,
         class_data_root=None,
         class_num=None,
         token_abstraction_dict=None,  # token mapping for textual inversion
@@ -805,10 +835,10 @@ def __init__(
         self.custom_instance_prompts = None
         self.class_prompt = class_prompt
         self.token_abstraction_dict = token_abstraction_dict
-
+        self.train_text_encoder_ti = train_text_encoder_ti
         # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
         # we load the training data using load_dataset
-        if args.dataset_name is not None:
+        if dataset_name is not None:
             try:
                 from datasets import load_dataset
             except ImportError:
@@ -821,26 +851,25 @@ def __init__(
             # See more about loading custom images at
             # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
             dataset = load_dataset(
-                args.dataset_name,
-                args.dataset_config_name,
-                cache_dir=args.cache_dir,
+                dataset_name,
+                dataset_config_name,
+                cache_dir=cache_dir,
             )
             # Preprocessing the datasets.
             column_names = dataset["train"].column_names

             # 6. Get the column names for input/target.
-            if args.image_column is None:
+            if image_column is None:
                 image_column = column_names[0]
                 logger.info(f"image column defaulting to {image_column}")
             else:
-                image_column = args.image_column
                 if image_column not in column_names:
                     raise ValueError(
-                        f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+                        f"`--image_column` value '{image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
                     )
             instance_images = dataset["train"][image_column]

-            if args.caption_column is None:
+            if caption_column is None:
                 logger.info(
                     "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
                     "contains captions/prompts for the images, make sure to specify the "
@@ -848,11 +877,11 @@ def __init__(
                 )
                 self.custom_instance_prompts = None
             else:
-                if args.caption_column not in column_names:
+                if caption_column not in column_names:
                     raise ValueError(
-                        f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+                        f"`--caption_column` value '{caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
                     )
-                custom_instance_prompts = dataset["train"][args.caption_column]
+                custom_instance_prompts = dataset["train"][caption_column]
                 # create final list of captions according to --repeats
                 self.custom_instance_prompts = []
                 for caption in custom_instance_prompts:
@@ -907,7 +936,7 @@ def __getitem__(self, index):
         if self.custom_instance_prompts:
             caption = self.custom_instance_prompts[index % self.num_instance_images]
             if caption:
-                if args.train_text_encoder_ti:
+                if self.train_text_encoder_ti:
                     # replace instances of --token_abstraction in caption with the new tokens: "<si><si+1>" etc.
                    for token_abs, token_replacement in self.token_abstraction_dict.items():
                        caption = caption.replace(token_abs, "".join(token_replacement))
@@ -1093,10 +1122,10 @@ def main(args):
         if args.output_dir is not None:
             os.makedirs(args.output_dir, exist_ok=True)

+        model_id = args.hub_model_id or Path(args.output_dir).name
+        repo_id = None
         if args.push_to_hub:
-            repo_id = create_repo(
-                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
-            ).repo_id
+            repo_id = create_repo(repo_id=model_id, exist_ok=True, token=args.hub_token).repo_id

     # Load the tokenizers
     tokenizer_one = AutoTokenizer.from_pretrained(
@@ -1464,6 +1493,12 @@ def load_model_hook(models, input_dir):
         instance_data_root=args.instance_data_dir,
         instance_prompt=args.instance_prompt,
         class_prompt=args.class_prompt,
+        dataset_name=args.dataset_name,
+        dataset_config_name=args.dataset_config_name,
+        cache_dir=args.cache_dir,
+        image_column=args.image_column,
+        train_text_encoder_ti=args.train_text_encoder_ti,
+        caption_column=args.caption_column,
         class_data_root=args.class_data_dir if args.with_prior_preservation else None,
         token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,
         class_num=args.num_class_images,
@@ -2004,23 +2039,23 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 }
             )

-        if args.push_to_hub:
-            if args.train_text_encoder_ti:
-                embedding_handler.save_embeddings(
-                    f"{args.output_dir}/embeddings.safetensors",
-                )
-            save_model_card(
-                repo_id,
-                images=images,
-                base_model=args.pretrained_model_name_or_path,
-                train_text_encoder=args.train_text_encoder,
-                train_text_encoder_ti=args.train_text_encoder_ti,
-                token_abstraction_dict=train_dataset.token_abstraction_dict,
-                instance_prompt=args.instance_prompt,
-                validation_prompt=args.validation_prompt,
-                repo_folder=args.output_dir,
-                vae_path=args.pretrained_vae_model_name_or_path,
+        if args.train_text_encoder_ti:
+            embedding_handler.save_embeddings(
+                f"{args.output_dir}/embeddings.safetensors",
             )
+        save_model_card(
+            model_id if not args.push_to_hub else repo_id,
+            images=images,
+            base_model=args.pretrained_model_name_or_path,
+            train_text_encoder=args.train_text_encoder,
+            train_text_encoder_ti=args.train_text_encoder_ti,
+            token_abstraction_dict=train_dataset.token_abstraction_dict,
+            instance_prompt=args.instance_prompt,
+            validation_prompt=args.validation_prompt,
+            repo_folder=args.output_dir,
+            vae_path=args.pretrained_vae_model_name_or_path,
+        )
+        if args.push_to_hub:
             upload_folder(
                 repo_id=repo_id,
                 folder_path=args.output_dir,
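The net effect of the dataset hunks is that `DreamBoothDataset` no longer reaches into the global `args` namespace: everything it needs now arrives through the constructor, and the model card is saved whether or not `--push_to_hub` is set. A minimal sketch of constructing the dataset directly with the new signature (the dataset name, prompt, columns, and tokens below are hypothetical placeholders, and it assumes the script's `DreamBoothDataset` class is in scope):

```python
# Minimal sketch using the new explicit constructor arguments; all values
# below are hypothetical placeholders.
train_dataset = DreamBoothDataset(
    instance_data_root=None,  # unused when a hub dataset name is provided
    instance_prompt="a photo of TOK dog",
    class_prompt=None,
    dataset_name="your-username/your-dataset",  # hypothetical hub dataset
    dataset_config_name=None,
    cache_dir=None,
    image_column="image",
    caption_column="prompt",
    train_text_encoder_ti=True,
    token_abstraction_dict={"TOK": ["<s0>", "<s1>"]},
)
```

Because the constructor no longer reads module-level `args`, the class can be instantiated and unit-tested outside the CLI entry point, which is the point of threading the values through explicitly.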