Skip to content

Commit e5810e6

Browse files
patrickvonplatenpcuencapatil-surajwilliamberman
authored
[Variant] Add "variant" as input kwarg so to have better UX when downloading no_ema or fp16 weights (#2305)
* [Variant] Add variant loading mechanism * clean * improve further * up * add tests * add some first tests * up * up * use path splittetx * add deprecate * deprecation warnings * improve docs * up * up * up * fix tests * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * correct code format * fix warning * finish * Apply suggestions from code review Co-authored-by: Suraj Patil <[email protected]> * Apply suggestions from code review Co-authored-by: Suraj Patil <[email protected]> * Update docs/source/en/using-diffusers/loading.mdx Co-authored-by: Suraj Patil <[email protected]> * Apply suggestions from code review Co-authored-by: Will Berman <[email protected]> Co-authored-by: Suraj Patil <[email protected]> * correct loading docs * finish --------- Co-authored-by: Pedro Cuenca <[email protected]> Co-authored-by: Suraj Patil <[email protected]> Co-authored-by: Will Berman <[email protected]>
1 parent e3ddbe2 commit e5810e6

File tree

7 files changed

+773
-106
lines changed

7 files changed

+773
-106
lines changed

docs/source/en/using-diffusers/loading.mdx

Lines changed: 330 additions & 53 deletions
Large diffs are not rendered by default.

src/diffusers/models/modeling_utils.py

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,21 @@
1616

1717
import inspect
1818
import os
19+
import warnings
1920
from functools import partial
2021
from typing import Callable, List, Optional, Tuple, Union
2122

2223
import torch
2324
from huggingface_hub import hf_hub_download
2425
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
26+
from packaging import version
2527
from requests import HTTPError
2628
from torch import Tensor, device
2729

2830
from .. import __version__
2931
from ..utils import (
3032
CONFIG_NAME,
33+
DEPRECATED_REVISION_ARGS,
3134
DIFFUSERS_CACHE,
3235
FLAX_WEIGHTS_NAME,
3336
HF_HUB_OFFLINE,
@@ -89,12 +92,12 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
8992
return first_tuple[1].dtype
9093

9194

92-
def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
95+
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
9396
"""
9497
Reads a checkpoint file, returning properly formatted errors if they arise.
9598
"""
9699
try:
97-
if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
100+
if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
98101
return torch.load(checkpoint_file, map_location="cpu")
99102
else:
100103
return safetensors.torch.load_file(checkpoint_file, device="cpu")
@@ -141,6 +144,15 @@ def load(module: torch.nn.Module, prefix=""):
141144
return error_msgs
142145

143146

147+
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
148+
if variant is not None:
149+
splits = weights_name.split(".")
150+
splits = splits[:-1] + [variant] + splits[-1:]
151+
weights_name = ".".join(splits)
152+
153+
return weights_name
154+
155+
144156
class ModelMixin(torch.nn.Module):
145157
r"""
146158
Base class for all models.
@@ -250,6 +262,7 @@ def save_pretrained(
250262
is_main_process: bool = True,
251263
save_function: Callable = None,
252264
safe_serialization: bool = False,
265+
variant: Optional[str] = None,
253266
):
254267
"""
255268
Save a model and its configuration file to a directory, so that it can be re-loaded using the
@@ -268,6 +281,8 @@ def save_pretrained(
268281
`DIFFUSERS_SAVE_MODE`.
269282
safe_serialization (`bool`, *optional*, defaults to `False`):
270283
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
284+
variant (`str`, *optional*):
285+
If specified, weights are saved in the format pytorch_model.<variant>.bin.
271286
"""
272287
if safe_serialization and not is_safetensors_available():
273288
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
@@ -292,6 +307,7 @@ def save_pretrained(
292307
state_dict = model_to_save.state_dict()
293308

294309
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
310+
weights_name = _add_variant(weights_name, variant)
295311

296312
# Save the model
297313
save_function(state_dict, os.path.join(save_directory, weights_name))
@@ -371,6 +387,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
371387
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
372388
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
373389
setting this argument to `True` will raise an error.
390+
variant (`str`, *optional*):
391+
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
392+
ignored when using `from_flax`.
374393
375394
<Tip>
376395
@@ -401,6 +420,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
401420
subfolder = kwargs.pop("subfolder", None)
402421
device_map = kwargs.pop("device_map", None)
403422
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
423+
variant = kwargs.pop("variant", None)
404424

405425
if low_cpu_mem_usage and not is_accelerate_available():
406426
low_cpu_mem_usage = False
@@ -488,7 +508,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
488508
try:
489509
model_file = _get_model_file(
490510
pretrained_model_name_or_path,
491-
weights_name=SAFETENSORS_WEIGHTS_NAME,
511+
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
492512
cache_dir=cache_dir,
493513
force_download=force_download,
494514
resume_download=resume_download,
@@ -504,7 +524,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
504524
if model_file is None:
505525
model_file = _get_model_file(
506526
pretrained_model_name_or_path,
507-
weights_name=WEIGHTS_NAME,
527+
weights_name=_add_variant(WEIGHTS_NAME, variant),
508528
cache_dir=cache_dir,
509529
force_download=force_download,
510530
resume_download=resume_download,
@@ -538,7 +558,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
538558
# if device_map is None, load the state dict and move the params from meta device to the cpu
539559
if device_map is None:
540560
param_device = "cpu"
541-
state_dict = load_state_dict(model_file)
561+
state_dict = load_state_dict(model_file, variant=variant)
542562
# move the params from meta device to cpu
543563
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
544564
if len(missing_keys) > 0:
@@ -587,7 +607,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
587607
)
588608
model = cls.from_config(config, **unused_kwargs)
589609

590-
state_dict = load_state_dict(model_file)
610+
state_dict = load_state_dict(model_file, variant=variant)
591611

592612
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
593613
model,
@@ -800,8 +820,38 @@ def _get_model_file(
800820
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
801821
)
802822
else:
823+
# 1. First check if deprecated way of loading from branches is used
824+
if (
825+
revision in DEPRECATED_REVISION_ARGS
826+
and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME)
827+
and version.parse(version.parse(__version__).base_version) >= version.parse("0.15.0")
828+
):
829+
try:
830+
model_file = hf_hub_download(
831+
pretrained_model_name_or_path,
832+
filename=_add_variant(weights_name, revision),
833+
cache_dir=cache_dir,
834+
force_download=force_download,
835+
proxies=proxies,
836+
resume_download=resume_download,
837+
local_files_only=local_files_only,
838+
use_auth_token=use_auth_token,
839+
user_agent=user_agent,
840+
subfolder=subfolder,
841+
revision=revision,
842+
)
843+
warnings.warn(
844+
f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
845+
FutureWarning,
846+
)
847+
return model_file
848+
except: # noqa: E722
849+
warnings.warn(
850+
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name)}' so that the correct variant file can be added.",
851+
FutureWarning,
852+
)
803853
try:
804-
# Load from URL or cache if already cached
854+
# 2. Load model file as usual
805855
model_file = hf_hub_download(
806856
pretrained_model_name_or_path,
807857
filename=weights_name,

0 commit comments

Comments
 (0)