Skip to content

[Variant] Add "variant" as input kwarg so to have better UX when downloading no_ema or fp16 weights #2305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 35 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
ab9ef6a
[Variant] Add variant loading mechanism
patrickvonplaten Feb 9, 2023
91ee04e
clean
patrickvonplaten Feb 9, 2023
0b45377
improve further
patrickvonplaten Feb 9, 2023
cbe2066
up
patrickvonplaten Feb 10, 2023
c760708
add tests
patrickvonplaten Feb 10, 2023
8d77537
add some first tests
patrickvonplaten Feb 10, 2023
4f6d13c
up
patrickvonplaten Feb 10, 2023
e329951
up
patrickvonplaten Feb 10, 2023
847bc0f
Merge branch 'main' into add_variant
patrickvonplaten Feb 10, 2023
710480d
use path splittext
patrickvonplaten Feb 13, 2023
b506882
Merge branch 'main' into add_variant
patrickvonplaten Feb 13, 2023
9262bbf
add deprecate
patrickvonplaten Feb 13, 2023
4a0ff60
deprecation warnings
patrickvonplaten Feb 13, 2023
04622d2
Merge branch 'add_variant' of https://github.com/huggingface/diffuser…
patrickvonplaten Feb 13, 2023
010f2ed
improve docs
patrickvonplaten Feb 14, 2023
bdebb36
up
patrickvonplaten Feb 14, 2023
73bf79f
up
patrickvonplaten Feb 14, 2023
48226f7
up
patrickvonplaten Feb 14, 2023
009ed74
Merge branch 'add_variant' of https://github.com/huggingface/diffuser…
patrickvonplaten Feb 14, 2023
23ace69
fix tests
patrickvonplaten Feb 14, 2023
9e09f62
Apply suggestions from code review
patrickvonplaten Feb 14, 2023
fb5d7b9
Merge branch 'add_variant' of https://github.com/huggingface/diffuser…
patrickvonplaten Feb 14, 2023
70cf040
Apply suggestions from code review
patrickvonplaten Feb 14, 2023
501446d
correct code format
patrickvonplaten Feb 14, 2023
081cd1a
Merge branch 'add_variant' of https://github.com/huggingface/diffuser…
patrickvonplaten Feb 14, 2023
61f7ff2
fix warning
patrickvonplaten Feb 14, 2023
69e3659
Merge branch 'main' into add_variant
patrickvonplaten Feb 14, 2023
81dc107
finish
patrickvonplaten Feb 14, 2023
f26abeb
Apply suggestions from code review
patrickvonplaten Feb 16, 2023
57fbe8f
Apply suggestions from code review
patrickvonplaten Feb 16, 2023
cbffa77
Update docs/source/en/using-diffusers/loading.mdx
patrickvonplaten Feb 16, 2023
fb22078
Apply suggestions from code review
patrickvonplaten Feb 16, 2023
dbdd126
correct loading docs
patrickvonplaten Feb 16, 2023
5bcd411
finish
patrickvonplaten Feb 16, 2023
f7b9fb7
Merge branch 'main' into add_variant
patrickvonplaten Feb 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
383 changes: 330 additions & 53 deletions docs/source/en/using-diffusers/loading.mdx

Large diffs are not rendered by default.

64 changes: 57 additions & 7 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,21 @@

import inspect
import os
import warnings
from functools import partial
from typing import Callable, List, Optional, Tuple, Union

import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from packaging import version
from requests import HTTPError
from torch import Tensor, device

from .. import __version__
from ..utils import (
CONFIG_NAME,
DEPRECATED_REVISION_ARGS,
DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME,
HF_HUB_OFFLINE,
Expand Down Expand Up @@ -89,12 +92,12 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
return first_tuple[1].dtype


def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could directly check the extension of checkpoint_file to know whether or not to load with safetensors. That way we wouldn't have to pass in variant

"""
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
try:
if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
return torch.load(checkpoint_file, map_location="cpu")
else:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
Expand Down Expand Up @@ -141,6 +144,15 @@ def load(module: torch.nn.Module, prefix=""):
return error_msgs


def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
if variant is not None:
splits = weights_name.split(".")
splits = splits[:-1] + [variant] + splits[-1:]
weights_name = ".".join(splits)

return weights_name


class ModelMixin(torch.nn.Module):
r"""
Base class for all models.
Expand Down Expand Up @@ -250,6 +262,7 @@ def save_pretrained(
is_main_process: bool = True,
save_function: Callable = None,
safe_serialization: bool = False,
variant: Optional[str] = None,
):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
Expand All @@ -268,6 +281,8 @@ def save_pretrained(
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `False`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
variant (`str`, *optional*):
If specified, weights are saved in the format pytorch_model.<variant>.bin.
"""
if safe_serialization and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
Expand All @@ -292,6 +307,7 @@ def save_pretrained(
state_dict = model_to_save.state_dict()

weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)

# Save the model
save_function(state_dict, os.path.join(save_directory, weights_name))
Expand Down Expand Up @@ -371,6 +387,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
setting this argument to `True` will raise an error.
variant (`str`, *optional*):
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
ignored when using `from_flax`.

<Tip>

Expand Down Expand Up @@ -401,6 +420,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
subfolder = kwargs.pop("subfolder", None)
device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)

if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
Expand Down Expand Up @@ -488,7 +508,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
try:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=SAFETENSORS_WEIGHTS_NAME,
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
Expand All @@ -504,7 +524,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
if model_file is None:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=WEIGHTS_NAME,
weights_name=_add_variant(WEIGHTS_NAME, variant),
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
Expand Down Expand Up @@ -538,7 +558,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# if device_map is None, load the state dict and move the params from meta device to the cpu
if device_map is None:
param_device = "cpu"
state_dict = load_state_dict(model_file)
state_dict = load_state_dict(model_file, variant=variant)
# move the params from meta device to cpu
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
if len(missing_keys) > 0:
Expand Down Expand Up @@ -587,7 +607,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
)
model = cls.from_config(config, **unused_kwargs)

state_dict = load_state_dict(model_file)
state_dict = load_state_dict(model_file, variant=variant)

model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
Expand Down Expand Up @@ -800,8 +820,38 @@ def _get_model_file(
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
)
else:
# 1. First check if deprecated way of loading from branches is used
if (
revision in DEPRECATED_REVISION_ARGS
and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME)
and version.parse(version.parse(__version__).base_version) >= version.parse("0.15.0")
):
try:
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=_add_variant(weights_name, revision),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
)
warnings.warn(
f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
FutureWarning,
)
return model_file
except: # noqa: E722
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name)}' so that the correct variant file can be added.",
FutureWarning,
)
try:
# Load from URL or cache if already cached
# 2. Load model file as usual
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=weights_name,
Expand Down
Loading