diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 4745801c067b..79bc66288338 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -58,6 +58,7 @@ convert_unet_state_dict_to_peft, is_wandb_available, ) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module @@ -70,33 +71,20 @@ def save_model_card( repo_id: str, - images=None, - base_model=str, - dataset_name=str, - train_text_encoder=False, - repo_folder=None, - vae_path=None, + images: list = None, + base_model: str = None, + dataset_name: str = None, + train_text_encoder: bool = False, + repo_folder: str = None, + vae_path: str = None, ): img_str = "" - for i, image in enumerate(images): - image.save(os.path.join(repo_folder, f"image_{i}.png")) - img_str += f"![img_{i}](./image_{i}.png)\n" - - yaml = f""" ---- -license: creativeml-openrail-m -base_model: {base_model} -dataset: {dataset_name} -tags: -- stable-diffusion-xl -- stable-diffusion-xl-diffusers -- text-to-image -- diffusers -- lora -inference: true ---- - """ - model_card = f""" + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + model_description = f""" # LoRA text2image fine-tuning - {repo_id} These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n @@ -106,8 +94,19 @@ def save_model_card( Special VAE used for training: {vae_path}. """ - with open(os.path.join(repo_folder, "README.md"), "w") as f: - f.write(yaml + model_card) + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="creativeml-openrail-m", + base_model=base_model, + model_description=model_description, + inference=True, + ) + + tags = ["stable-diffusion-xl", "stable-diffusion-xl-diffusers", "text-to-image", "diffusers", "lora"] + model_card = populate_model_card(model_card, tags=tags) + + model_card.save(os.path.join(repo_folder, "README.md")) def import_model_class_from_model_name_or_path(