diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 90a700d9443f..aab4d16172d9 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -42,7 +42,7 @@
     is_torch_version,
     logging,
 )
-from ..utils.hub_utils import PushToHubMixin
+from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card


 logger = logging.get_logger(__name__)
@@ -377,6 +377,11 @@ def save_pretrained(
         logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")

         if push_to_hub:
+            # Create a new empty model card and tag it if needed
+            model_card = load_or_create_model_card(repo_id, token=token)
+            model_card = populate_model_card(model_card)
+            model_card.save(os.path.join(save_directory, "README.md"))
+
             self._upload_folder(
                 save_directory,
                 repo_id,
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 4d0bc8b13a92..06187645f000 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -60,6 +60,7 @@
     logging,
     numpy_to_pil,
 )
+from ..utils.hub_utils import load_or_create_model_card, populate_model_card
 from ..utils.torch_utils import is_compiled_module


@@ -725,6 +726,11 @@ def is_saveable_module(name, value):
         self.save_config(save_directory)

         if push_to_hub:
+            # Create a new empty model card and tag it if needed
+            model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
+            model_card = populate_model_card(model_card)
+            model_card.save(os.path.join(save_directory, "README.md"))
+
             self._upload_folder(
                 save_directory,
                 repo_id,
diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py
index 6f311b957abf..c6a45b569218 100644
--- a/src/diffusers/utils/hub_utils.py
+++ b/src/diffusers/utils/hub_utils.py
@@ -28,7 +28,6 @@
     ModelCard,
     ModelCardData,
     create_repo,
-    get_full_repo_name,
     hf_hub_download,
     upload_folder,
 )
@@ -67,7 +66,6 @@

 logger = get_logger(__name__)

-MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md"

 SESSION_ID = uuid4().hex

@@ -95,7 +93,20 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
     return ua


-def create_model_card(args, model_name):
+def load_or_create_model_card(
+    repo_id_or_path: Optional[str] = None, token: Optional[str] = None, is_pipeline: bool = False
+) -> ModelCard:
+    """
+    Loads or creates a model card.
+
+    Args:
+        repo_id_or_path (`str`, *optional*):
+            The repo id (or local path) where to look for the model card.
+        token (`str`, *optional*):
+            Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more details.
+        is_pipeline (`bool`, *optional*):
+            Boolean to indicate if we're adding a tag to a [`DiffusionPipeline`].
+    """
     if not is_jinja_available():
         raise ValueError(
             "Modelcard rendering is based on Jinja templates."
@@ -103,45 +114,24 @@
             " To install it, please run `pip install Jinja2`."
         )

-    if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
-        return
-
-    hub_token = args.hub_token if hasattr(args, "hub_token") else None
-    repo_name = get_full_repo_name(model_name, token=hub_token)
-
-    model_card = ModelCard.from_template(
-        card_data=ModelCardData(  # Card metadata object that will be converted to YAML block
-            language="en",
-            license="apache-2.0",
-            library_name="diffusers",
-            tags=[],
-            datasets=args.dataset_name,
-            metrics=[],
-        ),
-        template_path=MODEL_CARD_TEMPLATE_PATH,
-        model_name=model_name,
-        repo_name=repo_name,
-        dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
-        learning_rate=args.learning_rate,
-        train_batch_size=args.train_batch_size,
-        eval_batch_size=args.eval_batch_size,
-        gradient_accumulation_steps=(
-            args.gradient_accumulation_steps if hasattr(args, "gradient_accumulation_steps") else None
-        ),
-        adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
-        adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
-        adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
-        adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
-        lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
-        lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
-        ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
-        ema_power=args.ema_power if hasattr(args, "ema_power") else None,
-        ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
-        mixed_precision=args.mixed_precision,
-    )
-
-    card_path = os.path.join(args.output_dir, "README.md")
-    model_card.save(card_path)
+    try:
+        # Check if the model card is present on the remote repo
+        model_card = ModelCard.load(repo_id_or_path, token=token)
+    except EntryNotFoundError:
+        # Otherwise create a simple model card from template
+        component = "pipeline" if is_pipeline else "model"
+        model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed to the Hub. This model card has been automatically generated."
+        card_data = ModelCardData()
+        model_card = ModelCard.from_template(card_data, model_description=model_description)
+
+    return model_card
+
+
+def populate_model_card(model_card: ModelCard) -> ModelCard:
+    """Populates the `model_card` with the library name."""
+    if model_card.data.library_name is None:
+        model_card.data.library_name = "diffusers"
+    return model_card


 def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None):
@@ -435,6 +425,10 @@ def push_to_hub(
         """
         repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id

+        # Create a new empty model card and tag it if needed
+        model_card = load_or_create_model_card(repo_id, token=token)
+        model_card = populate_model_card(model_card)
+
         # Save all files.
         save_kwargs = {"safe_serialization": safe_serialization}
         if "Scheduler" not in self.__class__.__name__:
@@ -443,6 +437,9 @@ def push_to_hub(
         with tempfile.TemporaryDirectory() as tmpdir:
             self.save_pretrained(tmpdir, **save_kwargs)

+            # Update model card if needed:
+            model_card.save(os.path.join(tmpdir, "README.md"))
+
             return self._upload_folder(
                 tmpdir,
                 repo_id,
diff --git a/src/diffusers/utils/model_card_template.md b/src/diffusers/utils/model_card_template.md
deleted file mode 100644
index f19c85b0fcf2..000000000000
--- a/src/diffusers/utils/model_card_template.md
+++ /dev/null
@@ -1,50 +0,0 @@
----
-{{ card_data }}
----
-
-<!-- This model card has been generated automatically according to the information the training script had access to. You
-should probably proofread and complete it, then remove this comment. -->
-
-# {{ model_name | default("Diffusion Model") }}
-
-## Model description
-
-This diffusion model is trained with the [🤗 Diffusers](https://github.com/huggingface/diffusers) library
-on the `{{ dataset_name }}` dataset.
-
-## Intended uses & limitations
-
-#### How to use
-
-```python
-# TODO: add an example code snippet for running this diffusion pipeline
-```
-
-#### Limitations and bias
-
-[TODO: provide examples of latent issues and potential remediations]
-
-## Training data
-
-[TODO: describe the data used to train the model]
-
-### Training hyperparameters
-
-The following hyperparameters were used during training:
-- learning_rate: {{ learning_rate }}
-- train_batch_size: {{ train_batch_size }}
-- eval_batch_size: {{ eval_batch_size }}
-- gradient_accumulation_steps: {{ gradient_accumulation_steps }}
-- optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }}
-- lr_scheduler: {{ lr_scheduler }}
-- lr_warmup_steps: {{ lr_warmup_steps }}
-- ema_inv_gamma: {{ ema_inv_gamma }}
-- ema_inv_gamma: {{ ema_power }}
-- ema_inv_gamma: {{ ema_max_decay }}
-- mixed_precision: {{ mixed_precision }}
-
-### Training results
-
-📈 [TensorBoard logs](https://huggingface.co/{{ repo_name }}/tensorboard?#scalars)
-
-
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 5ea0d910f3a3..4464b420ac33 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -24,7 +24,8 @@
 import numpy as np
 import requests_mock
 import torch
-from huggingface_hub import delete_repo
+from huggingface_hub import ModelCard, delete_repo
+from huggingface_hub.utils import is_jinja_available
 from requests.exceptions import HTTPError

 from diffusers.models import UNet2DConditionModel
@@ -732,3 +733,26 @@ def test_push_to_hub_in_organization(self):

         # Reset repo
         delete_repo(self.org_repo_id, token=TOKEN)
+
+    @unittest.skipIf(
+        not is_jinja_available(),
+        reason="Model card tests cannot be performed without Jinja installed.",
+    )
+    def test_push_to_hub_library_name(self):
+        model = UNet2DConditionModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+            cross_attention_dim=32,
+        )
+        model.push_to_hub(self.repo_id, token=TOKEN)
+
+        model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data
+        assert model_card.library_name == "diffusers"
+
+        # Reset repo
+        delete_repo(self.repo_id, token=TOKEN)
diff --git a/tests/others/test_hub_utils.py b/tests/others/test_hub_utils.py
index e8b8ea3a2fd9..7d01172125a8 100644
--- a/tests/others/test_hub_utils.py
+++ b/tests/others/test_hub_utils.py
@@ -15,37 +15,15 @@
 import unittest
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from unittest.mock import Mock, patch

-import diffusers.utils.hub_utils
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card


 class CreateModelCardTest(unittest.TestCase):
-    @patch("diffusers.utils.hub_utils.get_full_repo_name")
-    def test_create_model_card(self, repo_name_mock: Mock) -> None:
-        repo_name_mock.return_value = "full_repo_name"
+    def test_generate_model_card_with_library_name(self):
         with TemporaryDirectory() as tmpdir:
-            # Dummy args values
-            args = Mock()
-            args.output_dir = tmpdir
-            args.local_rank = 0
-            args.hub_token = "hub_token"
-            args.dataset_name = "dataset_name"
-            args.learning_rate = 0.01
-            args.train_batch_size = 100000
-            args.eval_batch_size = 10000
-            args.gradient_accumulation_steps = 0.01
-            args.adam_beta1 = 0.02
-            args.adam_beta2 = 0.03
-            args.adam_weight_decay = 0.0005
-            args.adam_epsilon = 0.000001
-            args.lr_scheduler = 1
-            args.lr_warmup_steps = 10
-            args.ema_inv_gamma = 0.001
-            args.ema_power = 0.1
-            args.ema_max_decay = 0.2
-            args.mixed_precision = True
-
-            # Model card mush be rendered and saved
-            diffusers.utils.hub_utils.create_model_card(args, model_name="model_name")
-            self.assertTrue((Path(tmpdir) / "README.md").is_file())
+            file_path = Path(tmpdir) / "README.md"
+            file_path.write_text("---\nlibrary_name: foo\n---\nContent\n")
+            model_card = load_or_create_model_card(file_path)
+            populate_model_card(model_card)
+            assert model_card.data.library_name == "foo"
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index ed2920cb0c73..e107c5772af9 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -13,7 +13,8 @@
 import numpy as np
 import PIL.Image
 import torch
-from huggingface_hub import delete_repo
+from huggingface_hub import ModelCard, delete_repo
+from huggingface_hub.utils import is_jinja_available
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

 import diffusers
@@ -1142,6 +1143,21 @@ def test_push_to_hub_in_organization(self):
         # Reset repo
         delete_repo(self.org_repo_id, token=TOKEN)

+    @unittest.skipIf(
+        not is_jinja_available(),
+        reason="Model card tests cannot be performed without Jinja installed.",
+    )
+    def test_push_to_hub_library_name(self):
+        components = self.get_pipeline_components()
+        pipeline = StableDiffusionPipeline(**components)
+        pipeline.push_to_hub(self.repo_id, token=TOKEN)
+
+        model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data
+        assert model_card.library_name == "diffusers"
+
+        # Reset repo
+        delete_repo(self.repo_id, token=TOKEN)
+

 # For SDXL and its derivative pipelines (such as ControlNet), we have the text encoders
 # and the tokenizers as optional components. So, we need to override the `test_save_load_optional_components()`
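
For reference (not part of the patch): the snippet below is a minimal sketch of how the two helpers added in `src/diffusers/utils/hub_utils.py` behave, modeled on the new `test_generate_model_card_with_library_name` test above. It assumes the patched `diffusers` plus `huggingface_hub` and `Jinja2` are installed; the temporary paths and README front matter are purely illustrative.

```python
# Minimal sketch of the new model-card helpers; runs offline against a local file.
from pathlib import Path
from tempfile import TemporaryDirectory

from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card

with TemporaryDirectory() as tmpdir:
    readme = Path(tmpdir) / "README.md"

    # An existing card is loaded as-is, and a library_name that is already set
    # is left untouched by populate_model_card().
    readme.write_text("---\nlibrary_name: foo\n---\nSome content\n")
    card = populate_model_card(load_or_create_model_card(readme))
    assert card.data.library_name == "foo"

    # A card without a library_name gets tagged with "diffusers".
    readme.write_text("---\nlicense: apache-2.0\n---\nSome content\n")
    card = populate_model_card(load_or_create_model_card(readme))
    assert card.data.library_name == "diffusers"

    # save_pretrained()/push_to_hub() write the card next to the weights so it
    # is uploaded together with them.
    card.save(readme)
```

In `push_to_hub`, the same pair of calls runs against the freshly created repo id; a repo without a README raises `EntryNotFoundError`, so a minimal card is generated from the default template before the folder is uploaded.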