Skip to content

Commit e0a2bd1

Browse files
authored
Write model card in controlnet training script (#3229)
Write model card in controlnet training script.
1 parent c399de3 commit e0a2bd1

File tree

1 file changed

+58
-1
lines changed

1 file changed

+58
-1
lines changed

examples/controlnet/train_controlnet.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,17 @@
6060
logger = get_logger(__name__)
6161

6262

63+
def image_grid(imgs, rows, cols):
64+
assert len(imgs) == rows * cols
65+
66+
w, h = imgs[0].size
67+
grid = Image.new("RGB", size=(cols * w, rows * h))
68+
69+
for i, img in enumerate(imgs):
70+
grid.paste(img, box=(i % cols * w, i // cols * h))
71+
return grid
72+
73+
6374
def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step):
6475
logger.info("Running validation... ")
6576

@@ -156,6 +167,8 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
156167
else:
157168
logger.warn(f"image logging not implemented for {tracker.name}")
158169

170+
return image_logs
171+
159172

160173
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
161174
text_encoder_config = PretrainedConfig.from_pretrained(
@@ -177,6 +190,43 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
177190
raise ValueError(f"{model_class} is not supported.")
178191

179192

193+
def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
194+
img_str = ""
195+
if image_logs is not None:
196+
img_str = "You can find some example images below.\n"
197+
for i, log in enumerate(image_logs):
198+
images = log["images"]
199+
validation_prompt = log["validation_prompt"]
200+
validation_image = log["validation_image"]
201+
validation_image.save(os.path.join(repo_folder, "image_control.png"))
202+
img_str += f"prompt: {validation_prompt}\n"
203+
images = [validation_image] + images
204+
image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
205+
img_str += f"![images_{i})](./images_{i}.png)\n"
206+
207+
yaml = f"""
208+
---
209+
license: creativeml-openrail-m
210+
base_model: {base_model}
211+
tags:
212+
- stable-diffusion
213+
- stable-diffusion-diffusers
214+
- text-to-image
215+
- diffusers
216+
- controlnet
217+
inference: true
218+
---
219+
"""
220+
model_card = f"""
221+
# controlnet-{repo_id}
222+
223+
These are controlnet weights trained on {base_model} with new type of conditioning.
224+
{img_str}
225+
"""
226+
with open(os.path.join(repo_folder, "README.md"), "w") as f:
227+
f.write(yaml + model_card)
228+
229+
180230
def parse_args(input_args=None):
181231
parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
182232
parser.add_argument(
@@ -943,6 +993,7 @@ def load_model_hook(models, input_dir):
943993
disable=not accelerator.is_local_main_process,
944994
)
945995

996+
image_logs = None
946997
for epoch in range(first_epoch, args.num_train_epochs):
947998
for step, batch in enumerate(train_dataloader):
948999
with accelerator.accumulate(controlnet):
@@ -1014,7 +1065,7 @@ def load_model_hook(models, input_dir):
10141065
logger.info(f"Saved state to {save_path}")
10151066

10161067
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
1017-
log_validation(
1068+
image_logs = log_validation(
10181069
vae,
10191070
text_encoder,
10201071
tokenizer,
@@ -1040,6 +1091,12 @@ def load_model_hook(models, input_dir):
10401091
controlnet.save_pretrained(args.output_dir)
10411092

10421093
if args.push_to_hub:
1094+
save_model_card(
1095+
repo_id,
1096+
image_logs=image_logs,
1097+
base_model=args.pretrained_model_name_or_path,
1098+
repo_folder=args.output_dir,
1099+
)
10431100
upload_folder(
10441101
repo_id=repo_id,
10451102
folder_path=args.output_dir,

0 commit comments

Comments
 (0)