60
60
logger = get_logger (__name__ )
61
61
62
62
63
+ def image_grid (imgs , rows , cols ):
64
+ assert len (imgs ) == rows * cols
65
+
66
+ w , h = imgs [0 ].size
67
+ grid = Image .new ("RGB" , size = (cols * w , rows * h ))
68
+
69
+ for i , img in enumerate (imgs ):
70
+ grid .paste (img , box = (i % cols * w , i // cols * h ))
71
+ return grid
72
+
73
+
63
74
def log_validation (vae , text_encoder , tokenizer , unet , controlnet , args , accelerator , weight_dtype , step ):
64
75
logger .info ("Running validation... " )
65
76
@@ -156,6 +167,8 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
156
167
else :
157
168
logger .warn (f"image logging not implemented for { tracker .name } " )
158
169
170
+ return image_logs
171
+
159
172
160
173
def import_model_class_from_model_name_or_path (pretrained_model_name_or_path : str , revision : str ):
161
174
text_encoder_config = PretrainedConfig .from_pretrained (
@@ -177,6 +190,43 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
177
190
raise ValueError (f"{ model_class } is not supported." )
178
191
179
192
193
+ def save_model_card (repo_id : str , image_logs = None , base_model = str , repo_folder = None ):
194
+ img_str = ""
195
+ if image_logs is not None :
196
+ img_str = "You can find some example images below.\n "
197
+ for i , log in enumerate (image_logs ):
198
+ images = log ["images" ]
199
+ validation_prompt = log ["validation_prompt" ]
200
+ validation_image = log ["validation_image" ]
201
+ validation_image .save (os .path .join (repo_folder , "image_control.png" ))
202
+ img_str += f"prompt: { validation_prompt } \n "
203
+ images = [validation_image ] + images
204
+ image_grid (images , 1 , len (images )).save (os .path .join (repo_folder , f"images_{ i } .png" ))
205
+ img_str += f"\n "
206
+
207
+ yaml = f"""
208
+ ---
209
+ license: creativeml-openrail-m
210
+ base_model: { base_model }
211
+ tags:
212
+ - stable-diffusion
213
+ - stable-diffusion-diffusers
214
+ - text-to-image
215
+ - diffusers
216
+ - controlnet
217
+ inference: true
218
+ ---
219
+ """
220
+ model_card = f"""
221
+ # controlnet-{ repo_id }
222
+
223
+ These are controlnet weights trained on { base_model } with new type of conditioning.
224
+ { img_str }
225
+ """
226
+ with open (os .path .join (repo_folder , "README.md" ), "w" ) as f :
227
+ f .write (yaml + model_card )
228
+
229
+
180
230
def parse_args (input_args = None ):
181
231
parser = argparse .ArgumentParser (description = "Simple example of a ControlNet training script." )
182
232
parser .add_argument (
@@ -943,6 +993,7 @@ def load_model_hook(models, input_dir):
943
993
disable = not accelerator .is_local_main_process ,
944
994
)
945
995
996
+ image_logs = None
946
997
for epoch in range (first_epoch , args .num_train_epochs ):
947
998
for step , batch in enumerate (train_dataloader ):
948
999
with accelerator .accumulate (controlnet ):
@@ -1014,7 +1065,7 @@ def load_model_hook(models, input_dir):
1014
1065
logger .info (f"Saved state to { save_path } " )
1015
1066
1016
1067
if args .validation_prompt is not None and global_step % args .validation_steps == 0 :
1017
- log_validation (
1068
+ image_logs = log_validation (
1018
1069
vae ,
1019
1070
text_encoder ,
1020
1071
tokenizer ,
@@ -1040,6 +1091,12 @@ def load_model_hook(models, input_dir):
1040
1091
controlnet .save_pretrained (args .output_dir )
1041
1092
1042
1093
if args .push_to_hub :
1094
+ save_model_card (
1095
+ repo_id ,
1096
+ image_logs = image_logs ,
1097
+ base_model = args .pretrained_model_name_or_path ,
1098
+ repo_folder = args .output_dir ,
1099
+ )
1043
1100
upload_folder (
1044
1101
repo_id = repo_id ,
1045
1102
folder_path = args .output_dir ,
0 commit comments