Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
weights. If set to `False`, `safetensors` weights are not loaded.
use_flashpack (`bool`, *optional*, defaults to `False`):
If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file
is found. If flashpack is unavailable or the `.flashpack` file cannot be used, automatic fallback to
the standard loading path (for example, `safetensors`). Requires the `flashpack` library: `pip install
flashpack`.
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
Expand Down Expand Up @@ -959,6 +964,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
use_flashpack = kwargs.pop("use_flashpack", False)
quantization_config = kwargs.pop("quantization_config", None)
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
disable_mmap = kwargs.pop("disable_mmap", False)
Expand Down Expand Up @@ -1177,6 +1183,69 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
else:
if use_flashpack:
try:
from flashpack import assign_from_file
except ImportError:
pass
else:
flashpack_weights_name = _add_variant("model.flashpack", variant)

try:
flashpack_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=flashpack_weights_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
)
except EnvironmentError:
pass
else:
dtype_orig = None
if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
if not isinstance(torch_dtype, torch.dtype):
raise ValueError(
f"{torch_dtype} needs to be a `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
)
dtype_orig = cls._set_default_torch_dtype(torch_dtype)

with no_init_weights():
model = cls.from_config(config, **unused_kwargs)

if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

# flashpack requires a single dtype across all parameters

try:
assign_from_file(model, flashpack_file)
model.register_to_config(_name_or_path=pretrained_model_name_or_path)

if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
model = model.to(torch_dtype)

model.eval()

if output_loading_info:
loading_info = {
"missing_keys": [],
"unexpected_keys": [],
"mismatched_keys": [],
"error_msgs": [],
}
return model, loading_info

return model

except Exception:
pass
# in the case it is sharded, we have already the index
if is_sharded:
resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,7 @@ def load_sub_model(
low_cpu_mem_usage: bool,
cached_folder: Union[str, os.PathLike],
use_safetensors: bool,
use_flashpack: bool,
dduf_entries: Optional[Dict[str, DDUFEntry]],
provider_options: Any,
quantization_config: Optional[Any] = None,
Expand Down Expand Up @@ -832,6 +833,9 @@ def load_sub_model(
loading_kwargs["variant"] = model_variants.pop(name, None)
loading_kwargs["use_safetensors"] = use_safetensors

if is_diffusers_model:
loading_kwargs["use_flashpack"] = use_flashpack

if from_flax:
loading_kwargs["from_flax"] = True

Expand Down
7 changes: 7 additions & 0 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
use_flashpack (`bool`, *optional*, defaults to `False`):
If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file
is found. If flashpack is unavailable or the `.flashpack` file cannot be used, automatic fallback to
the standard loading path (for example, `safetensors`). Requires the `flashpack` library: `pip install
flashpack`.
use_onnx (`bool`, *optional*, defaults to `None`):
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
Expand Down Expand Up @@ -755,6 +760,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
variant = kwargs.pop("variant", None)
dduf_file = kwargs.pop("dduf_file", None)
use_safetensors = kwargs.pop("use_safetensors", None)
use_flashpack = kwargs.pop("use_flashpack", False)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
quantization_config = kwargs.pop("quantization_config", None)
Expand Down Expand Up @@ -1039,6 +1045,7 @@ def load_module(name, value):
low_cpu_mem_usage=low_cpu_mem_usage,
cached_folder=cached_folder,
use_safetensors=use_safetensors,
use_flashpack=use_flashpack,
dduf_entries=dduf_entries,
provider_options=provider_options,
quantization_config=quantization_config,
Expand Down