diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index e7cfcd71062f..c9fabf93253b 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -398,6 +398,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available for each + GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + If the `device_map` contains any value `"disk"`, the folder where we will offload weights. + offload_state_dict (`bool`, *optional*): + If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU + RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to + `True` when there is some disk offload. low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): Speed up model loading by not initializing the weights and only loading the pre-trained weights. This also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the @@ -439,6 +448,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P torch_dtype = kwargs.pop("torch_dtype", None) subfolder = kwargs.pop("subfolder", None) device_map = kwargs.pop("device_map", None) + max_memory = kwargs.pop("max_memory", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) @@ -510,6 +522,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P revision=revision, subfolder=subfolder, device_map=device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, user_agent=user_agent, **kwargs, ) @@ -614,7 +629,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P else: # else let accelerate handle loading and dispatching. # Load weights and dispatch according to the device_map # by default the device_map is None and the weights are loaded on the CPU - accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype) + accelerate.load_checkpoint_and_dispatch( + model, + model_file, + device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + ) loading_info = { "missing_keys": [], diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index fa71a181f521..aed1139a2a16 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -354,6 +354,9 @@ def load_sub_model( provider: Any, sess_options: Any, device_map: Optional[Union[Dict[str, torch.device], str]], + max_memory: Optional[Dict[Union[int, str], Union[int, str]]], + offload_folder: Optional[Union[str, os.PathLike]], + offload_state_dict: bool, model_variants: Dict[str, str], name: str, from_flax: bool, @@ -416,6 +419,9 @@ def load_sub_model( # This makes sure that the weights won't be initialized which significantly speeds up loading. if is_diffusers_model or is_transformers_model: loading_kwargs["device_map"] = device_map + loading_kwargs["max_memory"] = max_memory + loading_kwargs["offload_folder"] = offload_folder + loading_kwargs["offload_state_dict"] = offload_state_dict loading_kwargs["variant"] = model_variants.pop(name, None) if from_flax: loading_kwargs["from_flax"] = True @@ -808,6 +814,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available for each + GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + If the `device_map` contains any value `"disk"`, the folder where we will offload weights. + offload_state_dict (`bool`, *optional*): + If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU + RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to + `True` when there is some disk offload. low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): Speed up model loading by not initializing the weights and only loading the pre-trained weights. This also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the @@ -873,6 +888,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P provider = kwargs.pop("provider", None) sess_options = kwargs.pop("sess_options", None) device_map = kwargs.pop("device_map", None) + max_memory = kwargs.pop("max_memory", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) @@ -1046,6 +1064,9 @@ def load_module(name, value): provider=provider, sess_options=sess_options, device_map=device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, model_variants=model_variants, name=name, from_flax=from_flax,