diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index f06822c741ca..0b691d6be773 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -916,6 +916,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` weights. If set to `False`, `safetensors` weights are not loaded. + use_flashpack (`bool`, *optional*, defaults to `False`): + If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file + is found. If flashpack is unavailable or the `.flashpack` file cannot be used, loading automatically + falls back to the standard loading path (for example, `safetensors`). Requires the `flashpack` library: `pip install + flashpack`. disable_mmap ('bool', *optional*, defaults to 'False'): Whether to disable mmap when loading a Safetensors model. This option can perform better when the model is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well. 
@@ -959,6 +964,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) + use_flashpack = kwargs.pop("use_flashpack", False) quantization_config = kwargs.pop("quantization_config", None) dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None) disable_mmap = kwargs.pop("disable_mmap", False) @@ -1177,6 +1183,69 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file) else: + if use_flashpack: + try: + from flashpack import assign_from_file + except ImportError: + pass + else: + flashpack_weights_name = _add_variant("model.flashpack", variant) + + try: + flashpack_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=flashpack_weights_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + except EnvironmentError: + pass + else: + dtype_orig = None + if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None): + if not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be a `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." 
+ ) + dtype_orig = cls._set_default_torch_dtype(torch_dtype) + + with no_init_weights(): + model = cls.from_config(config, **unused_kwargs) + + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) + + # flashpack requires a single dtype across all parameters + + try: + assign_from_file(model, flashpack_file) + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None): + model = model.to(torch_dtype) + + model.eval() + + if output_loading_info: + loading_info = { + "missing_keys": [], + "unexpected_keys": [], + "mismatched_keys": [], + "error_msgs": [], + } + return model, loading_info + + return model + + except Exception: + pass # in the case it is sharded, we have already the index if is_sharded: resolved_model_file, sharded_metadata = _get_checkpoint_shard_files( diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 8868e942ce3d..5153195d5d2c 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -756,6 +756,7 @@ def load_sub_model( low_cpu_mem_usage: bool, cached_folder: Union[str, os.PathLike], use_safetensors: bool, + use_flashpack: bool, dduf_entries: Optional[Dict[str, DDUFEntry]], provider_options: Any, quantization_config: Optional[Any] = None, @@ -832,6 +833,9 @@ def load_sub_model( loading_kwargs["variant"] = model_variants.pop(name, None) loading_kwargs["use_safetensors"] = use_safetensors + if is_diffusers_model: + loading_kwargs["use_flashpack"] = use_flashpack + if from_flax: loading_kwargs["from_flax"] = True diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 392d5fb3feb4..c57581314fc3 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -693,6 +693,11 @@ def from_pretrained(cls, 
pretrained_model_name_or_path: Optional[Union[str, os.P If set to `None`, the safetensors weights are downloaded if they're available **and** if the safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors weights. If set to `False`, safetensors weights are not loaded. + use_flashpack (`bool`, *optional*, defaults to `False`): + If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file + is found. If flashpack is unavailable or the `.flashpack` file cannot be used, loading automatically + falls back to the standard loading path (for example, `safetensors`). Requires the `flashpack` library: `pip install + flashpack`. use_onnx (`bool`, *optional*, defaults to `None`): If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is @@ -755,6 +760,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P variant = kwargs.pop("variant", None) dduf_file = kwargs.pop("dduf_file", None) use_safetensors = kwargs.pop("use_safetensors", None) + use_flashpack = kwargs.pop("use_flashpack", False) use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) quantization_config = kwargs.pop("quantization_config", None) @@ -1039,6 +1045,7 @@ def load_module(name, value): low_cpu_mem_usage=low_cpu_mem_usage, cached_folder=cached_folder, use_safetensors=use_safetensors, + use_flashpack=use_flashpack, dduf_entries=dduf_entries, provider_options=provider_options, quantization_config=quantization_config,