16
16
17
17
import inspect
18
18
import os
19
+ import warnings
19
20
from functools import partial
20
21
from typing import Callable , List , Optional , Tuple , Union
21
22
22
23
import torch
23
24
from huggingface_hub import hf_hub_download
24
25
from huggingface_hub .utils import EntryNotFoundError , RepositoryNotFoundError , RevisionNotFoundError
26
+ from packaging import version
25
27
from requests import HTTPError
26
28
from torch import Tensor , device
27
29
28
30
from .. import __version__
29
31
from ..utils import (
30
32
CONFIG_NAME ,
33
+ DEPRECATED_REVISION_ARGS ,
31
34
DIFFUSERS_CACHE ,
32
35
FLAX_WEIGHTS_NAME ,
33
36
HF_HUB_OFFLINE ,
@@ -89,12 +92,12 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
89
92
return first_tuple [1 ].dtype
90
93
91
94
92
- def load_state_dict (checkpoint_file : Union [str , os .PathLike ]):
95
+ def load_state_dict (checkpoint_file : Union [str , os .PathLike ], variant : Optional [ str ] = None ):
93
96
"""
94
97
Reads a checkpoint file, returning properly formatted errors if they arise.
95
98
"""
96
99
try :
97
- if os .path .basename (checkpoint_file ) == WEIGHTS_NAME :
100
+ if os .path .basename (checkpoint_file ) == _add_variant ( WEIGHTS_NAME , variant ) :
98
101
return torch .load (checkpoint_file , map_location = "cpu" )
99
102
else :
100
103
return safetensors .torch .load_file (checkpoint_file , device = "cpu" )
@@ -141,6 +144,15 @@ def load(module: torch.nn.Module, prefix=""):
141
144
return error_msgs
142
145
143
146
147
+ def _add_variant (weights_name : str , variant : Optional [str ] = None ) -> str :
148
+ if variant is not None :
149
+ splits = weights_name .split ("." )
150
+ splits = splits [:- 1 ] + [variant ] + splits [- 1 :]
151
+ weights_name = "." .join (splits )
152
+
153
+ return weights_name
154
+
155
+
144
156
class ModelMixin (torch .nn .Module ):
145
157
r"""
146
158
Base class for all models.
@@ -250,6 +262,7 @@ def save_pretrained(
250
262
is_main_process : bool = True ,
251
263
save_function : Callable = None ,
252
264
safe_serialization : bool = False ,
265
+ variant : Optional [str ] = None ,
253
266
):
254
267
"""
255
268
Save a model and its configuration file to a directory, so that it can be re-loaded using the
@@ -268,6 +281,8 @@ def save_pretrained(
268
281
`DIFFUSERS_SAVE_MODE`.
269
282
safe_serialization (`bool`, *optional*, defaults to `False`):
270
283
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
284
+ variant (`str`, *optional*):
285
+ If specified, weights are saved in the format pytorch_model.<variant>.bin.
271
286
"""
272
287
if safe_serialization and not is_safetensors_available ():
273
288
raise ImportError ("`safe_serialization` requires the `safetensors library: `pip install safetensors`." )
@@ -292,6 +307,7 @@ def save_pretrained(
292
307
state_dict = model_to_save .state_dict ()
293
308
294
309
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
310
+ weights_name = _add_variant (weights_name , variant )
295
311
296
312
# Save the model
297
313
save_function (state_dict , os .path .join (save_directory , weights_name ))
@@ -371,6 +387,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
371
387
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
372
388
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
373
389
setting this argument to `True` will raise an error.
390
+ variant (`str`, *optional*):
391
+ If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
392
+ ignored when using `from_flax`.
374
393
375
394
<Tip>
376
395
@@ -401,6 +420,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
401
420
subfolder = kwargs .pop ("subfolder" , None )
402
421
device_map = kwargs .pop ("device_map" , None )
403
422
low_cpu_mem_usage = kwargs .pop ("low_cpu_mem_usage" , _LOW_CPU_MEM_USAGE_DEFAULT )
423
+ variant = kwargs .pop ("variant" , None )
404
424
405
425
if low_cpu_mem_usage and not is_accelerate_available ():
406
426
low_cpu_mem_usage = False
@@ -488,7 +508,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
488
508
try :
489
509
model_file = _get_model_file (
490
510
pretrained_model_name_or_path ,
491
- weights_name = SAFETENSORS_WEIGHTS_NAME ,
511
+ weights_name = _add_variant ( SAFETENSORS_WEIGHTS_NAME , variant ) ,
492
512
cache_dir = cache_dir ,
493
513
force_download = force_download ,
494
514
resume_download = resume_download ,
@@ -504,7 +524,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
504
524
if model_file is None :
505
525
model_file = _get_model_file (
506
526
pretrained_model_name_or_path ,
507
- weights_name = WEIGHTS_NAME ,
527
+ weights_name = _add_variant ( WEIGHTS_NAME , variant ) ,
508
528
cache_dir = cache_dir ,
509
529
force_download = force_download ,
510
530
resume_download = resume_download ,
@@ -538,7 +558,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
538
558
# if device_map is None, load the state dict and move the params from meta device to the cpu
539
559
if device_map is None :
540
560
param_device = "cpu"
541
- state_dict = load_state_dict (model_file )
561
+ state_dict = load_state_dict (model_file , variant = variant )
542
562
# move the params from meta device to cpu
543
563
missing_keys = set (model .state_dict ().keys ()) - set (state_dict .keys ())
544
564
if len (missing_keys ) > 0 :
@@ -587,7 +607,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
587
607
)
588
608
model = cls .from_config (config , ** unused_kwargs )
589
609
590
- state_dict = load_state_dict (model_file )
610
+ state_dict = load_state_dict (model_file , variant = variant )
591
611
592
612
model , missing_keys , unexpected_keys , mismatched_keys , error_msgs = cls ._load_pretrained_model (
593
613
model ,
@@ -800,8 +820,38 @@ def _get_model_file(
800
820
f"Error no file named { weights_name } found in directory { pretrained_model_name_or_path } ."
801
821
)
802
822
else :
823
+ # 1. First check if deprecated way of loading from branches is used
824
+ if (
825
+ revision in DEPRECATED_REVISION_ARGS
826
+ and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME )
827
+ and version .parse (version .parse (__version__ ).base_version ) >= version .parse ("0.15.0" )
828
+ ):
829
+ try :
830
+ model_file = hf_hub_download (
831
+ pretrained_model_name_or_path ,
832
+ filename = _add_variant (weights_name , revision ),
833
+ cache_dir = cache_dir ,
834
+ force_download = force_download ,
835
+ proxies = proxies ,
836
+ resume_download = resume_download ,
837
+ local_files_only = local_files_only ,
838
+ use_auth_token = use_auth_token ,
839
+ user_agent = user_agent ,
840
+ subfolder = subfolder ,
841
+ revision = revision ,
842
+ )
843
+ warnings .warn (
844
+ f"Loading the variant { revision } from { pretrained_model_name_or_path } via `revision='{ revision } '` is deprecated. Loading instead from `revision='main'` with `variant={ revision } `. Loading model variants via `revision='{ revision } '` will be removed in diffusers v1. Please use `variant='{ revision } '` instead." ,
845
+ FutureWarning ,
846
+ )
847
+ return model_file
848
+ except : # noqa: E722
849
+ warnings .warn (
850
+ f"You are loading the variant { revision } from { pretrained_model_name_or_path } via `revision='{ revision } '`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{ revision } '` instead. However, it appears that { pretrained_model_name_or_path } currently does not have a { _add_variant (weights_name )} file in the 'main' branch of { pretrained_model_name_or_path } . \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{ pretrained_model_name_or_path } is missing { _add_variant (weights_name )} ' so that the correct variant file can be added." ,
851
+ FutureWarning ,
852
+ )
803
853
try :
804
- # Load from URL or cache if already cached
854
+ # 2. Load model file as usual
805
855
model_file = hf_hub_download (
806
856
pretrained_model_name_or_path ,
807
857
filename = weights_name ,
0 commit comments