
Commit 6e22133 (1 parent: 53bc30d)

[advanced_dreambooth_lora_sdxl_training_script] save embeddings locally fix (#6058)

* Update train_dreambooth_lora_sdxl_advanced.py
* remove global function args from the DreamBoothDataset class
* style
* style

Co-authored-by: Sayak Paul <[email protected]>

File tree: 1 file changed, +69 −34 lines

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py (69 additions, 34 deletions)
@@ -123,16 +123,26 @@ def save_model_card(
     """

     trigger_str = f"You should use {instance_prompt} to trigger the image generation."
+    diffusers_imports_pivotal = ""
+    diffusers_example_pivotal = ""
     if train_text_encoder_ti:
         trigger_str = (
             "To trigger image generation of trained concept(or concepts) replace each concept identifier "
             "in you prompt with the new inserted tokens:\n"
         )
+        diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+        """
+        diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id="{repo_id}", filename="embeddings.safetensors", repo_type="model")
+state_dict = load_file(embedding_path)
+pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
+pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
+        """
     if token_abstraction_dict:
         for key, value in token_abstraction_dict.items():
             tokens = "".join(value)
             trigger_str += f"""
-to trigger concept `{key}->` use `{tokens}` in your prompt \n
+to trigger concept `{key}` → use `{tokens}` in your prompt \n
 """

     yaml = f"""
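The arrow fix above changes how each concept line renders in the generated model card. A minimal, runnable sketch of the rendering; the sample `token_abstraction_dict` value is illustrative, not taken from the commit:

```py
# Illustrative rendering of the fixed trigger string; "TOK" and the
# "<s0>", "<s1>" tokens are sample values.
token_abstraction_dict = {"TOK": ["<s0>", "<s1>"]}

trigger_str = ""
for key, value in token_abstraction_dict.items():
    tokens = "".join(value)
    trigger_str += f"""
to trigger concept `{key}` → use `{tokens}` in your prompt \n
"""

print(trigger_str)
# to trigger concept `TOK` → use `<s0><s1>` in your prompt
```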
@@ -172,7 +182,21 @@ def save_model_card(

 {trigger_str}

-## Download model
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+{diffusers_imports_pivotal}
+pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16).to('cuda')
+pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
+{diffusers_example_pivotal}
+image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+
+## Download model (use it with UIs such as AUTO1111, Comfy, SD.Next, Invoke)

 Weights for this model are available in Safetensors format.

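For a pivotal-tuning run, the two placeholders above expand into a complete inference snippet. Roughly what the rendered card would contain for a hypothetical repo `username/my-sdxl-lora` (repo id, prompt, and token names are illustrative); the generated snippet addresses the pipeline as `pipe` inside the `load_textual_inversion` calls, so this sketch normalizes to a single `pipeline` name:

```py
# Approximate rendering of the generated "Use it with diffusers" section for a
# hypothetical repo "username/my-sdxl-lora"; all names below are illustrative.
import torch
from diffusers import AutoPipelineForText2Image
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("username/my-sdxl-lora", weight_name="pytorch_lora_weights.safetensors")

# Pivotal tuning: fetch the learned embeddings and register them in both
# SDXL text encoders under the inserted tokens.
embedding_path = hf_hub_download(
    repo_id="username/my-sdxl-lora", filename="embeddings.safetensors", repo_type="model"
)
state_dict = load_file(embedding_path)
pipeline.load_textual_inversion(
    state_dict["clip_l"], token=["<s0>", "<s1>"],
    text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer,
)
pipeline.load_textual_inversion(
    state_dict["clip_g"], token=["<s0>", "<s1>"],
    text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2,
)

image = pipeline("a photo of <s0><s1>").images[0]
```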
@@ -791,6 +815,12 @@ def __init__(
791815
instance_data_root,
792816
instance_prompt,
793817
class_prompt,
818+
dataset_name,
819+
dataset_config_name,
820+
cache_dir,
821+
image_column,
822+
caption_column,
823+
train_text_encoder_ti,
794824
class_data_root=None,
795825
class_num=None,
796826
token_abstraction_dict=None, # token mapping for textual inversion
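Threading these values through the constructor removes the dataset class's dependence on a module-level `args`, so it can be built outside the CLI entry point. A minimal sketch, assuming the class is `DreamBoothDataset` and the remaining keyword arguments keep their defaults; all values are illustrative:

```py
# Hypothetical direct construction (e.g. in a notebook or a test); the
# argument values are illustrative, not from the commit.
dataset = DreamBoothDataset(
    instance_data_root="./my-images",
    instance_prompt="a photo of TOK dog",
    class_prompt=None,
    dataset_name=None,          # load images from disk rather than the Hub
    dataset_config_name=None,
    cache_dir=None,
    image_column=None,
    caption_column=None,
    train_text_encoder_ti=True,
)
```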
@@ -805,10 +835,10 @@ def __init__(
805835
self.custom_instance_prompts = None
806836
self.class_prompt = class_prompt
807837
self.token_abstraction_dict = token_abstraction_dict
808-
838+
self.train_text_encoder_ti = train_text_encoder_ti
809839
# if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
810840
# we load the training data using load_dataset
811-
if args.dataset_name is not None:
841+
if dataset_name is not None:
812842
try:
813843
from datasets import load_dataset
814844
except ImportError:
@@ -821,38 +851,37 @@ def __init__(
             # See more about loading custom images at
             # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
             dataset = load_dataset(
-                args.dataset_name,
-                args.dataset_config_name,
-                cache_dir=args.cache_dir,
+                dataset_name,
+                dataset_config_name,
+                cache_dir=cache_dir,
             )
             # Preprocessing the datasets.
             column_names = dataset["train"].column_names

             # 6. Get the column names for input/target.
-            if args.image_column is None:
+            if image_column is None:
                 image_column = column_names[0]
                 logger.info(f"image column defaulting to {image_column}")
             else:
-                image_column = args.image_column
                 if image_column not in column_names:
                     raise ValueError(
-                        f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+                        f"`--image_column` value '{image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
                     )
             instance_images = dataset["train"][image_column]

-            if args.caption_column is None:
+            if caption_column is None:
                 logger.info(
                     "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
                     "contains captions/prompts for the images, make sure to specify the "
                     "column as --caption_column"
                 )
                 self.custom_instance_prompts = None
             else:
-                if args.caption_column not in column_names:
+                if caption_column not in column_names:
                     raise ValueError(
-                        f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+                        f"`--caption_column` value '{caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
                     )
-                custom_instance_prompts = dataset["train"][args.caption_column]
+                custom_instance_prompts = dataset["train"][caption_column]
                 # create final list of captions according to --repeats
                 self.custom_instance_prompts = []
                 for caption in custom_instance_prompts:
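With `image_column` and `caption_column` now plain parameters, the defaulting logic is easy to exercise in isolation. A small sketch of the column resolution; the column names are illustrative:

```py
# Sketch of the column-defaulting behavior; "image"/"text" are sample names.
column_names = ["image", "text"]

image_column = None
if image_column is None:
    image_column = column_names[0]  # falls back to the first column: "image"

caption_column = "text"
if caption_column not in column_names:
    raise ValueError(f"`--caption_column` value '{caption_column}' not found in dataset columns.")
```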
@@ -907,7 +936,7 @@ def __getitem__(self, index):
         if self.custom_instance_prompts:
             caption = self.custom_instance_prompts[index % self.num_instance_images]
             if caption:
-                if args.train_text_encoder_ti:
+                if self.train_text_encoder_ti:
                     # replace instances of --token_abstraction in caption with the new tokens: "<si><si+1>" etc.
                     for token_abs, token_replacement in self.token_abstraction_dict.items():
                         caption = caption.replace(token_abs, "".join(token_replacement))
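The gate now reads the flag from the instance rather than the global namespace; the replacement itself is unchanged. A runnable sketch with illustrative values:

```py
# Replace each concept identifier in a caption with its inserted tokens;
# the mapping and caption are sample values.
token_abstraction_dict = {"TOK": ["<s0>", "<s1>"]}
caption = "a photo of TOK wearing sunglasses"

for token_abs, token_replacement in token_abstraction_dict.items():
    caption = caption.replace(token_abs, "".join(token_replacement))

print(caption)  # a photo of <s0><s1> wearing sunglasses
```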
@@ -1093,10 +1122,10 @@ def main(args):
     if args.output_dir is not None:
         os.makedirs(args.output_dir, exist_ok=True)

+    model_id = args.hub_model_id or Path(args.output_dir).name
+    repo_id = None
     if args.push_to_hub:
-        repo_id = create_repo(
-            repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
-        ).repo_id
+        repo_id = create_repo(repo_id=model_id, exist_ok=True, token=args.hub_token).repo_id

     # Load the tokenizers
     tokenizer_one = AutoTokenizer.from_pretrained(
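Hoisting `model_id` out of the `push_to_hub` branch gives the model card a name even when nothing is uploaded. A sketch of the two cases; the paths are illustrative:

```py
from pathlib import Path

# Case 1: no --hub_model_id; the output directory name is used.
hub_model_id = None
output_dir = "outputs/my-advanced-lora"
model_id = hub_model_id or Path(output_dir).name
print(model_id)  # my-advanced-lora

# Case 2: an explicit --hub_model_id wins.
hub_model_id = "username/my-sdxl-lora"
model_id = hub_model_id or Path(output_dir).name
print(model_id)  # username/my-sdxl-lora
```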
@@ -1464,6 +1493,12 @@ def load_model_hook(models, input_dir):
         instance_data_root=args.instance_data_dir,
         instance_prompt=args.instance_prompt,
         class_prompt=args.class_prompt,
+        dataset_name=args.dataset_name,
+        dataset_config_name=args.dataset_config_name,
+        cache_dir=args.cache_dir,
+        image_column=args.image_column,
+        train_text_encoder_ti=args.train_text_encoder_ti,
+        caption_column=args.caption_column,
         class_data_root=args.class_data_dir if args.with_prior_preservation else None,
         token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,
         class_num=args.num_class_images,
@@ -2004,23 +2039,23 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                     }
                 )

-        if args.push_to_hub:
-            if args.train_text_encoder_ti:
-                embedding_handler.save_embeddings(
-                    f"{args.output_dir}/embeddings.safetensors",
-                )
-            save_model_card(
-                repo_id,
-                images=images,
-                base_model=args.pretrained_model_name_or_path,
-                train_text_encoder=args.train_text_encoder,
-                train_text_encoder_ti=args.train_text_encoder_ti,
-                token_abstraction_dict=train_dataset.token_abstraction_dict,
-                instance_prompt=args.instance_prompt,
-                validation_prompt=args.validation_prompt,
-                repo_folder=args.output_dir,
-                vae_path=args.pretrained_vae_model_name_or_path,
+        if args.train_text_encoder_ti:
+            embedding_handler.save_embeddings(
+                f"{args.output_dir}/embeddings.safetensors",
             )
+        save_model_card(
+            model_id if not args.push_to_hub else repo_id,
+            images=images,
+            base_model=args.pretrained_model_name_or_path,
+            train_text_encoder=args.train_text_encoder,
+            train_text_encoder_ti=args.train_text_encoder_ti,
+            token_abstraction_dict=train_dataset.token_abstraction_dict,
+            instance_prompt=args.instance_prompt,
+            validation_prompt=args.validation_prompt,
+            repo_folder=args.output_dir,
+            vae_path=args.pretrained_vae_model_name_or_path,
+        )
+        if args.push_to_hub:
             upload_folder(
                 repo_id=repo_id,
                 folder_path=args.output_dir,
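This hunk is the fix named in the commit title: saving the embeddings and writing the model card now happen unconditionally, and only the upload is gated on `--push_to_hub`. A schematic, runnable reduction of the new control flow; the function and step names are mine, not the script's:

```py
# Schematic reduction of the reordered finalization; names are illustrative.
def finalize(output_dir, train_text_encoder_ti, push_to_hub):
    steps = []
    if train_text_encoder_ti:
        steps.append(f"save embeddings -> {output_dir}/embeddings.safetensors")  # now always local
    steps.append(f"write model card -> {output_dir}/README.md")                  # now always written
    if push_to_hub:
        steps.append("upload_folder -> Hub")                                     # upload only on push
    return steps

# Before the fix, none of this ran without --push_to_hub:
print(finalize("out", train_text_encoder_ti=True, push_to_hub=False))
```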
