Skip to content

Update HF mixin #910

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 250 additions & 0 deletions examples/save_load_model_and_share_with_hf_hub.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import segmentation_models_pytorch as smp"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save to local directory and load back"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading weights from local directory\n"
]
}
],
"source": [
"model = smp.Unet()\n",
"\n",
"# save the model\n",
"model.save_pretrained(\"saved-model-dir/unet/\")\n",
"\n",
"# load the model\n",
"restored_model = smp.from_pretrained(\"saved-model-dir/unet/\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save model with additional metadata"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"model = smp.Unet()\n",
"\n",
"# save the model\n",
"model.save_pretrained(\n",
" \"saved-model-dir/unet-with-metadata/\",\n",
"\n",
" # additional information to be saved with the model\n",
" # only \"dataset\" and \"metrics\" are supported\n",
" dataset=\"PASCAL VOC\", # only string name is supported\n",
" metrics={ # should be a dictionary with metric name as key and metric value as value\n",
" \"mIoU\": 0.95,\n",
" \"accuracy\": 0.96\n",
" }\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---\n",
"library_name: segmentation-models-pytorch\n",
"license: mit\n",
"pipeline_tag: image-segmentation\n",
"tags:\n",
"- semantic-segmentation\n",
"- pytorch\n",
"- segmentation-models-pytorch\n",
"languages:\n",
"- python\n",
"---\n",
"# Unet Model Card\n",
"\n",
"Table of Contents:\n",
"- [Load trained model](#load-trained-model)\n",
"- [Model init parameters](#model-init-parameters)\n",
"- [Model metrics](#model-metrics)\n",
"- [Dataset](#dataset)\n",
"\n",
"## Load trained model\n",
"```python\n",
"import segmentation_models_pytorch as smp\n",
"\n",
"model = smp.from_pretrained(\"<save-directory-or-this-repo>\")\n",
"```\n",
"\n",
"## Model init parameters\n",
"```python\n",
"model_init_params = {\n",
" \"encoder_name\": \"resnet34\",\n",
" \"encoder_depth\": 5,\n",
" \"encoder_weights\": \"imagenet\",\n",
" \"decoder_use_batchnorm\": True,\n",
" \"decoder_channels\": (256, 128, 64, 32, 16),\n",
" \"decoder_attention_type\": None,\n",
" \"in_channels\": 3,\n",
" \"classes\": 1,\n",
" \"activation\": None,\n",
" \"aux_params\": None\n",
"}\n",
"```\n",
"\n",
"## Model metrics\n",
"```json\n",
"{\n",
" \"mIoU\": 0.95,\n",
" \"accuracy\": 0.96\n",
"}\n",
"```\n",
"\n",
"## Dataset\n",
"Dataset name: PASCAL VOC\n",
"\n",
"## More Information\n",
"- Library: https://github.com/qubvel/segmentation_models.pytorch\n",
"- Docs: https://smp.readthedocs.io/en/latest/\n",
"\n",
"This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin)"
]
}
],
"source": [
"!cat \"saved-model-dir/unet-with-metadata/README.md\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Share model with HF Hub"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "075ae026811542bdb4030e53b943efc7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from huggingface_hub import notebook_login\n",
"\n",
"# You only need to run this once on the machine,\n",
"# the token will be stored for later use\n",
"notebook_login()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2921a81d7fd747939b4a425cc17d6104",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model.safetensors: 0%| | 0.00/97.8M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"CommitInfo(commit_url='https://huggingface.co/qubvel-hf/unet-with-metadata/commit/9f821c7bc3a12db827c0da96a31f354ec6ba5253', commit_message='Push model using huggingface_hub.', commit_description='', oid='9f821c7bc3a12db827c0da96a31f354ec6ba5253', pr_url=None, pr_revision=None, pr_num=None)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = smp.Unet()\n",
"\n",
"# save the model and share it on the HF Hub (https://huggingface.co/models)\n",
"model.save_pretrained(\n",
" \"qubvel-hf/unet-with-metadata/\",\n",
" push_to_hub=True, # <---------- push the model to the hub\n",
" private=False, # <---------- make the model private or public\n",
" dataset=\"PASCAL VOC\",\n",
" metrics={\n",
" \"mIoU\": 0.95,\n",
" \"accuracy\": 0.96\n",
" }\n",
")\n",
"\n",
"# see result here https://huggingface.co/qubvel-hf/unet-with-metadata"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ torchvision>=0.5.0
pretrainedmodels==0.7.4
efficientnet-pytorch==0.7.1
timm==0.9.7
huggingface_hub>=0.24.6

tqdm
pillow
Expand Down
5 changes: 5 additions & 0 deletions segmentation_models_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

from . import datasets
from . import encoders
from . import decoders
Expand All @@ -20,6 +22,9 @@
from typing import Optional as _Optional
import torch as _torch

# Suppress the specific SyntaxWarning for `pretrainedmodels`
warnings.filterwarnings("ignore", message="is with a literal", category=SyntaxWarning)


def create_model(
arch: str,
Expand Down
68 changes: 26 additions & 42 deletions segmentation_models_pytorch/base/hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
```python
import segmentation_models_pytorch as smp

model = smp.{{ model_name }}.from_pretrained("{{ save_directory | default("<save-directory-or-repo>", true)}}")
model = smp.from_pretrained("<save-directory-or-this-repo>")
```

## Model init parameters
Expand Down Expand Up @@ -61,23 +61,22 @@ def _format_parameters(parameters: dict):

class SMPHubMixin(PyTorchModelHubMixin):
def generate_model_card(self, *args, **kwargs) -> ModelCard:
model_parameters_json = _format_parameters(self._hub_mixin_config)
directory = self._save_directory if hasattr(self, "_save_directory") else None
repo_id = self._repo_id if hasattr(self, "_repo_id") else None
repo_or_directory = repo_id if repo_id is not None else directory

metrics = self._metrics if hasattr(self, "_metrics") else None
dataset = self._dataset if hasattr(self, "_dataset") else None
model_parameters_json = _format_parameters(self.config)
metrics = kwargs.get("metrics", None)
dataset = kwargs.get("dataset", None)

if metrics is not None:
metrics = json.dumps(metrics, indent=4)
metrics = f"```json\n{metrics}\n```"

tags = self._hub_mixin_info.model_card_data.get("tags", []) or []
tags.extend(["segmentation-models-pytorch", "semantic-segmentation", "pytorch"])

model_card_data = ModelCardData(
languages=["python"],
library_name="segmentation-models-pytorch",
license="mit",
tags=["semantic-segmentation", "pytorch", "segmentation-models-pytorch"],
tags=tags,
pipeline_tag="image-segmentation",
)
model_card = ModelCard.from_template(
Expand All @@ -86,64 +85,49 @@ def generate_model_card(self, *args, **kwargs) -> ModelCard:
repo_url="https://github.com/qubvel/segmentation_models.pytorch",
docs_url="https://smp.readthedocs.io/en/latest/",
model_parameters=model_parameters_json,
save_directory=repo_or_directory,
model_name=self.__class__.__name__,
metrics=metrics,
dataset=dataset,
)
return model_card

def _set_attrs_from_kwargs(self, attrs, kwargs):
for attr in attrs:
if attr in kwargs:
setattr(self, f"_{attr}", kwargs.pop(attr))

def _del_attrs(self, attrs):
for attr in attrs:
if hasattr(self, f"_{attr}"):
delattr(self, f"_{attr}")

@wraps(PyTorchModelHubMixin.save_pretrained)
def save_pretrained(
self, save_directory: Union[str, Path], *args, **kwargs
) -> Optional[str]:
# set additional attributes to be used in generate_model_card
self._save_directory = save_directory
self._set_attrs_from_kwargs(["metrics", "dataset"], kwargs)
model_card_kwargs = kwargs.pop("model_card_kwargs", {})
if "dataset" in kwargs:
model_card_kwargs["dataset"] = kwargs.pop("dataset")
if "metrics" in kwargs:
model_card_kwargs["metrics"] = kwargs.pop("metrics")
kwargs["model_card_kwargs"] = model_card_kwargs

# set additional attribute to be used in from_pretrained
self._hub_mixin_config["_model_class"] = self.__class__.__name__
# set additional attribute to be able to deserialize the model
self.config["_model_class"] = self.__class__.__name__

try:
# call the original save_pretrained
result = super().save_pretrained(save_directory, *args, **kwargs)
finally:
# delete the additional attributes
self._del_attrs(["save_directory", "metrics", "dataset"])
self._hub_mixin_config.pop("_model_class", None)
self.config.pop("_model_class", None)

return result

@wraps(PyTorchModelHubMixin.push_to_hub)
def push_to_hub(self, repo_id: str, *args, **kwargs):
self._repo_id = repo_id
self._set_attrs_from_kwargs(["metrics", "dataset"], kwargs)
result = super().push_to_hub(repo_id, *args, **kwargs)
self._del_attrs(["repo_id", "metrics", "dataset"])
return result

@property
def config(self):
def config(self) -> dict:
return self._hub_mixin_config


@wraps(PyTorchModelHubMixin.from_pretrained)
def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs):
config_path = hf_hub_download(
pretrained_model_name_or_path,
filename="config.json",
revision=kwargs.get("revision", None),
)
config_path = Path(pretrained_model_name_or_path) / "config.json"
if not config_path.exists():
config_path = hf_hub_download(
pretrained_model_name_or_path,
filename="config.json",
revision=kwargs.get("revision", None),
)

with open(config_path, "r") as f:
config = json.load(f)
model_class_name = config.pop("_model_class")
Expand Down