Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
b566c95
feat: support saving a model in sharded checkpoints.
sayakpaul May 1, 2024
8605909
feat: make loading of sharded checkpoints work.
sayakpaul May 1, 2024
885d5b6
add tests
sayakpaul May 1, 2024
560fe32
cleanse the loading logic a bit more.
sayakpaul May 1, 2024
fc5d837
more resilience while loading from the Hub.
sayakpaul May 1, 2024
0d3b9e1
parallelize shard downloads by using snapshot_download()/
sayakpaul May 1, 2024
df8e945
default to a shard size.
sayakpaul May 1, 2024
6eff632
more fix
sayakpaul May 1, 2024
ed83244
Empty-Commit
sayakpaul May 1, 2024
642ee39
debug
sayakpaul May 1, 2024
36de0c4
fix
sayakpaul May 1, 2024
cc5656e
quality
sayakpaul May 1, 2024
8898717
more debugging
sayakpaul May 1, 2024
2dfb9a1
fix more
sayakpaul May 1, 2024
179495f
Merge branch 'main' into feat-save-sharded-ckpt
sayakpaul May 1, 2024
7e2c09b
merge main and fix conflicts.
sayakpaul May 10, 2024
3535701
resolve conflicts.
sayakpaul May 15, 2024
5ae8e46
initial comments from Benjamin
sayakpaul May 15, 2024
aefd0db
move certain methods to loading_utils
sayakpaul May 15, 2024
80005be
add test to check if the correct number of shards are present.
sayakpaul May 15, 2024
d144526
add a test to check if loading of sharded checkpoints from the Hub is…
sayakpaul May 15, 2024
8e52c6d
clarify the unit when passed as an int.
sayakpaul May 15, 2024
9d2f19a
Merge branch 'main' into feat-save-sharded-ckpt
sayakpaul May 17, 2024
a8f5c03
Merge branch 'main' into feat-save-sharded-ckpt
sayakpaul May 21, 2024
c917be2
Merge branch 'main' into feat-save-sharded-ckpt
sayakpaul May 22, 2024
7cf5340
Merge branch 'main' into feat-save-sharded-ckpt
sayakpaul May 27, 2024
d6f9b17
use hf_hub for sharding.
sayakpaul May 27, 2024
1ae5987
remove unnecessary code
sayakpaul May 27, 2024
fda5d99
Merge branch 'main' into feat-save-sharded-ckpt
sayakpaul May 27, 2024
2ec27d5
remove unnecessary function
sayakpaul May 27, 2024
c98d779
Merge branch 'main' into feat-save-sharded-ckpt
sayakpaul May 28, 2024
ed289c9
lucain's comments.
sayakpaul May 28, 2024
9acbbea
fixes
sayakpaul May 28, 2024
7326e25
address high-level comments.
sayakpaul May 28, 2024
0706cae
fix test
sayakpaul May 28, 2024
711fd50
subfolder shenanigans./
sayakpaul May 28, 2024
32419ac
Update src/diffusers/utils/hub_utils.py
sayakpaul May 28, 2024
b03e13c
Merge branch 'main' into feat-save-sharded-ckpt
sayakpaul May 28, 2024
cbfd70f
Apply suggestions from code review
sayakpaul May 29, 2024
868cfb6
remove _huggingface_hub_version as not needed.
sayakpaul May 29, 2024
13fd063
address more feedback.
sayakpaul May 29, 2024
0d32c45
Merge branch 'main' into feat-save-sharded-ckpt
sayakpaul May 29, 2024
bad44c0
add a test for local_files_only=True/
sayakpaul May 29, 2024
c779618
need hf hub to be at least 0.23.2
sayakpaul May 29, 2024
302d59d
style
sayakpaul May 29, 2024
ab3a5aa
Merge branch 'main' into feat-save-sharded-ckpt
sayakpaul May 29, 2024
7f88742
Merge branch 'main' into feat-save-sharded-ckpt
sayakpaul May 29, 2024
a7fc2ae
final comment.
sayakpaul May 29, 2024
f74fc67
Merge branch 'main' into feat-save-sharded-ckpt
sayakpaul Jun 3, 2024
38749fc
clean up subfolder.
sayakpaul Jun 3, 2024
edbd8de
deal with suffixes in code.
sayakpaul Jun 3, 2024
2ecd4da
_add_variant default.
sayakpaul Jun 3, 2024
d51d0b9
use weights_name_pattern
sayakpaul Jun 3, 2024
c2a71a0
remove add_suffix_keyword
sayakpaul Jun 3, 2024
a70e927
clean up downloading of sharded ckpts.
sayakpaul Jun 3, 2024
65da7dc
don't return something special when using index.json
sayakpaul Jun 3, 2024
5599388
fix more
sayakpaul Jun 3, 2024
7cdf958
don't use bare except
sayakpaul Jun 3, 2024
16dcdf8
remove comments and catch the errors better
sayakpaul Jun 3, 2024
737e627
fix a couple of things when using is_file()
sayakpaul Jun 3, 2024
491e1e2
resolve conflicts
sayakpaul Jun 7, 2024
2a18518
empty
sayakpaul Jun 7, 2024
51906c6
Merge branch 'main' into feat-save-sharded-ckpt
sayakpaul Jun 7, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
"filelock",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
"huggingface-hub>=0.20.2",
"huggingface-hub>=0.23.2",
"requests-mock==1.10.0",
"importlib_metadata",
"invisible-watermark>=0.2.0",
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"filelock": "filelock",
"flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.20.2",
"huggingface-hub": "huggingface-hub>=0.23.2",
"requests-mock": "requests-mock==1.10.0",
"importlib_metadata": "importlib_metadata",
"invisible-watermark": "invisible-watermark>=0.2.0",
Expand Down
55 changes: 55 additions & 0 deletions src/diffusers/models/model_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,19 @@
import inspect
import os
from collections import OrderedDict
from pathlib import Path
from typing import List, Optional, Union

import safetensors
import torch
from huggingface_hub.utils import EntryNotFoundError

from ..utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_FILE_EXTENSION,
WEIGHTS_INDEX_NAME,
_add_variant,
_get_model_file,
is_accelerate_available,
is_torch_version,
logging,
Expand Down Expand Up @@ -175,3 +181,52 @@ def load(module: torch.nn.Module, prefix: str = ""):
load(model_to_load)

return error_msgs


def _fetch_index_file(
    is_local,
    pretrained_model_name_or_path,
    subfolder,
    use_safetensors,
    cache_dir,
    variant,
    force_download,
    resume_download,
    proxies,
    local_files_only,
    token,
    revision,
    user_agent,
    commit_hash,
):
    """
    Build (local case) or fetch (non-local case) the path of the index file of a sharded checkpoint.

    The index file name is `SAFE_WEIGHTS_INDEX_NAME` when `use_safetensors` is truthy and `WEIGHTS_INDEX_NAME`
    otherwise, passed through `_add_variant` together with `variant`.

    Args:
        is_local (`bool`):
            When truthy, `pretrained_model_name_or_path` is treated as a directory and the index path is only
            assembled, not checked for existence.
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            Local directory, or identifier forwarded to `_get_model_file`.
        subfolder (`str`, *optional*):
            Subfolder that contains the weights. `None` and `""` both mean "no subfolder".
        use_safetensors (`bool`):
            Selects which of the two index file names is looked up.
        variant (`str`, *optional*):
            Variant forwarded to `_add_variant`.
        cache_dir, force_download, resume_download, proxies, local_files_only, token, revision, user_agent, commit_hash:
            Forwarded unchanged to `_get_model_file` in the non-local case; unused in the local case.

    Returns:
        `pathlib.Path` or `None`: In the local case, the assembled path (which may not exist, so callers need to
        check it themselves). In the non-local case, the path returned by `_get_model_file`, or `None` when that call
        raises `EntryNotFoundError` or `EnvironmentError`.
    """
    index_name = _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant)

    if is_local:
        return Path(pretrained_model_name_or_path, subfolder or "", index_name)

    # The subfolder is folded into the file name here, so it must not be handed to `_get_model_file` a second time.
    index_file_in_repo = Path(subfolder or "", index_name).as_posix()
    try:
        index_file = _get_model_file(
            pretrained_model_name_or_path,
            weights_name=index_file_in_repo,
            cache_dir=cache_dir,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            # NOTE(review): `_get_model_file` presumably forwards `subfolder` to `hf_hub_download`, which prepends
            # it to the file name. Passing it again would then look up `<subfolder>/<subfolder>/<index>` and
            # silently fall into the "no index file" path below -- confirm against `_get_model_file`.
            subfolder=None,
            user_agent=user_agent,
            commit_hash=commit_hash,
        )
    except (EntryNotFoundError, EnvironmentError):
        # A missing index file is not an error here; `None` is returned instead.
        return None

    return Path(index_file)
137 changes: 121 additions & 16 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import inspect
import itertools
import json
import os
import re
from collections import OrderedDict
Expand All @@ -25,17 +26,20 @@

import safetensors
import torch
from huggingface_hub import create_repo
from huggingface_hub import create_repo, split_torch_state_dict_into_shards
from huggingface_hub.utils import validate_hf_hub_args
from torch import Tensor, nn

from .. import __version__
from ..utils import (
CONFIG_NAME,
FLAX_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
_add_variant,
_get_checkpoint_shard_files,
_get_model_file,
deprecate,
is_accelerate_available,
Expand All @@ -49,6 +53,7 @@
)
from .model_loading_utils import (
_determine_device_map,
_fetch_index_file,
_load_state_dict_into_model,
load_model_dict_into_meta,
load_state_dict,
Expand All @@ -57,6 +62,8 @@

logger = logging.get_logger(__name__)

_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")


if is_torch_version(">=", "1.9.0"):
_LOW_CPU_MEM_USAGE_DEFAULT = True
Expand Down Expand Up @@ -263,6 +270,7 @@ def save_pretrained(
save_function: Optional[Callable] = None,
safe_serialization: bool = True,
variant: Optional[str] = None,
max_shard_size: Union[int, str] = "5GB",
push_to_hub: bool = False,
**kwargs,
):
Expand All @@ -285,6 +293,10 @@ def save_pretrained(
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
variant (`str`, *optional*):
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
max_shard_size (`int` or `str`, defaults to `"5GB"`):
The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
If expressed as an integer, the unit is bytes.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
Expand All @@ -296,6 +308,14 @@ def save_pretrained(
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return

weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weight_name_split = weights_name.split(".")
if len(weight_name_split) in [2, 3]:
weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
else:
raise ValueError(f"Invalid {weights_name} provided.")

os.makedirs(save_directory, exist_ok=True)

if push_to_hub:
Expand All @@ -317,18 +337,58 @@ def save_pretrained(
# Save the model
state_dict = model_to_save.state_dict()

weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)

# Save the model
if safe_serialization:
safetensors.torch.save_file(
state_dict, Path(save_directory, weights_name).as_posix(), metadata={"format": "pt"}
state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
)

# Clean the folder from a previous save
if is_main_process:
for filename in os.listdir(save_directory):
if filename in state_dict_split.filename_to_tensors.keys():
continue
full_filename = os.path.join(save_directory, filename)
if not os.path.isfile(full_filename):
continue
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
weights_without_ext = weights_without_ext.replace("{suffix}", "")
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
if (
filename.startswith(weights_without_ext)
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
):
os.remove(full_filename)

for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor] for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
else:
torch.save(shard, filepath)

if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
torch.save(state_dict, Path(save_directory, weights_name).as_posix())

logger.info(f"Model weights saved in {Path(save_directory, weights_name).as_posix()}")
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")

if push_to_hub:
# Create a new empty model card and eventually tag it
Expand Down Expand Up @@ -566,6 +626,32 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
**kwargs,
)

# Determine if we're loading from a directory of sharded checkpoints.
is_sharded = False
index_file = None
is_local = os.path.isdir(pretrained_model_name_or_path)
index_file = _fetch_index_file(
is_local=is_local,
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder or "",
use_safetensors=use_safetensors,
cache_dir=cache_dir,
variant=variant,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
user_agent=user_agent,
commit_hash=commit_hash,
)
if index_file is not None and index_file.is_file():
is_sharded = True

if is_sharded and from_flax:
raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")

# load model
model_file = None
if from_flax:
Expand All @@ -590,7 +676,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

model = load_flax_checkpoint_in_pytorch_model(model, model_file)
else:
if use_safetensors:
if is_sharded:
sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
pretrained_model_name_or_path,
index_file,
cache_dir=cache_dir,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder or "",
)

elif use_safetensors and not is_sharded:
try:
model_file = _get_model_file(
pretrained_model_name_or_path,
Expand All @@ -606,11 +706,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
user_agent=user_agent,
commit_hash=commit_hash,
)

except IOError as e:
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
if not allow_pickle:
raise e
pass
if model_file is None:
raise
logger.warning(
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
)

if model_file is None and not is_sharded:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant(WEIGHTS_NAME, variant),
Expand All @@ -632,7 +737,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
model = cls.from_config(config, **unused_kwargs)

# if device_map is None, load the state dict and move the params from meta device to the cpu
if device_map is None:
if device_map is None and not is_sharded:
param_device = "cpu"
state_dict = load_state_dict(model_file, variant=variant)
model._convert_deprecated_attention_blocks(state_dict)
Expand Down Expand Up @@ -670,7 +775,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
try:
accelerate.load_checkpoint_and_dispatch(
model,
model_file,
model_file if not is_sharded else sharded_ckpt_cached_folder,
device_map,
max_memory=max_memory,
offload_folder=offload_folder,
Expand Down
3 changes: 3 additions & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
MIN_PEFT_VERSION,
ONNX_EXTERNAL_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_FILE_EXTENSION,
SAFETENSORS_WEIGHTS_NAME,
USE_PEFT_BACKEND,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
)
from .deprecation_utils import deprecate
Expand All @@ -40,6 +42,7 @@
from .hub_utils import (
PushToHubMixin,
_add_variant,
_get_checkpoint_shard_files,
_get_model_file,
extract_commit_hash,
http_user_agent,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@

CONFIG_NAME = "config.json"
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.bin.index.json"
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
ONNX_WEIGHTS_NAME = "model.onnx"
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json"
SAFETENSORS_FILE_EXTENSION = "safetensors"
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
Expand Down
Loading