
Commit 08d6b15

Merge branch 'main' into ip-adapter-test-mixin
2 parents 91a35e3 + e6d1728 commit 08d6b15

File tree: 3 files changed (+117 / -40 lines)

docs/source/en/using-diffusers/ip_adapter.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -231,6 +231,9 @@ export_to_gif(frames, "gummy_bear.gif")
 </hfoption>
 </hfoptions>
 
+> [!TIP]
+> While calling `load_ip_adapter()`, pass `low_cpu_mem_usage=True` to speed up the loading time.
+
 ## Specific use cases
 
 IP-Adapter's image prompting and compatibility with other adapters and models makes it a versatile tool for a variety of use cases. This section covers some of the more popular applications of IP-Adapter, and we can't wait to see what you come up with!
```
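
The flag mentioned in the tip is purely opt-in on the user's side. A minimal usage sketch, assuming the standard IP-Adapter checkpoint layout on the Hub (`h94/IP-Adapter`) and a pipeline that mixes in `IPAdapterMixin`:

```python
import torch
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# With this commit, low_cpu_mem_usage=True constructs the adapter modules on
# the meta device and fills them straight from the checkpoint, skipping the
# usual random initialization and speeding up load_ip_adapter().
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="models",
    weight_name="ip-adapter_sd15.bin",
    low_cpu_mem_usage=True,
)
```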

src/diffusers/loaders/ip_adapter.py

Lines changed: 27 additions & 2 deletions
```diff
@@ -19,8 +19,11 @@
 from huggingface_hub.utils import validate_hf_hub_args
 from safetensors import safe_open
 
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
 from ..utils import (
     _get_model_file,
+    is_accelerate_available,
+    is_torch_version,
     is_transformers_available,
     logging,
 )
@@ -86,6 +89,11 @@ def load_ip_adapter(
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
         """
 
         # handle the list inputs for multiple IP Adapters
@@ -116,6 +124,22 @@ def load_ip_adapter(
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+        if low_cpu_mem_usage and not is_accelerate_available():
+            low_cpu_mem_usage = False
+            logger.warning(
+                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                " install accelerate\n```\n."
+            )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
 
         user_agent = {
             "file_type": "attn_procs_weights",
@@ -165,6 +189,7 @@ def load_ip_adapter(
                 image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                     pretrained_model_name_or_path_or_dict,
                     subfolder=Path(subfolder, "image_encoder").as_posix(),
+                    low_cpu_mem_usage=low_cpu_mem_usage,
                 ).to(self.device, dtype=self.dtype)
                 self.register_modules(image_encoder=image_encoder)
             else:
@@ -175,9 +200,9 @@ def load_ip_adapter(
                 feature_extractor = CLIPImageProcessor()
                 self.register_modules(feature_extractor=feature_extractor)
 
-        # load ip-adapter into unet
+        # load ip-adapter into unet
         unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-        unet._load_ip_adapter_weights(state_dicts)
+        unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
 
     def set_ip_adapter_scale(self, scale):
         """
```

src/diffusers/loaders/unet.py

Lines changed: 87 additions & 38 deletions
```diff
@@ -37,6 +37,7 @@
     _get_model_file,
     delete_adapter_layers,
     is_accelerate_available,
+    is_torch_version,
     logging,
     set_adapter_layers,
     set_weights_and_activate_adapters,
@@ -168,15 +169,6 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
             "framework": "pytorch",
         }
 
-        if low_cpu_mem_usage and not is_accelerate_available():
-            low_cpu_mem_usage = False
-            logger.warning(
-                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
-                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
-                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
-                " install accelerate\n```\n."
-            )
-
         model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
             # Let's first try to load .safetensors weights
@@ -694,21 +686,42 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
             if hasattr(self, "peft_config"):
                 self.peft_config.pop(adapter_name, None)
 
-    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
+    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
+        if low_cpu_mem_usage:
+            if is_accelerate_available():
+                from accelerate import init_empty_weights
+
+            else:
+                low_cpu_mem_usage = False
+                logger.warning(
+                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                    " install accelerate\n```\n."
+                )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
         updated_state_dict = {}
         image_projection = None
+        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
 
         if "proj.weight" in state_dict:
             # IP-Adapter
             num_image_text_embeds = 4
             clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
             cross_attention_dim = state_dict["proj.weight"].shape[0] // 4
 
-            image_projection = ImageProjection(
-                cross_attention_dim=cross_attention_dim,
-                image_embed_dim=clip_embeddings_dim,
-                num_image_text_embeds=num_image_text_embeds,
-            )
+            with init_context():
+                image_projection = ImageProjection(
+                    cross_attention_dim=cross_attention_dim,
+                    image_embed_dim=clip_embeddings_dim,
+                    num_image_text_embeds=num_image_text_embeds,
+                )
 
             for key, value in state_dict.items():
                 diffusers_name = key.replace("proj", "image_embeds")
@@ -719,9 +732,10 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
             clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
             cross_attention_dim = state_dict["proj.3.weight"].shape[0]
 
-            image_projection = IPAdapterFullImageProjection(
-                cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
-            )
+            with init_context():
+                image_projection = IPAdapterFullImageProjection(
+                    cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
+                )
 
             for key, value in state_dict.items():
                 diffusers_name = key.replace("proj.0", "ff.net.0.proj")
@@ -737,13 +751,14 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
             hidden_dims = state_dict["latents"].shape[2]
             heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
 
-            image_projection = IPAdapterPlusImageProjection(
-                embed_dims=embed_dims,
-                output_dims=output_dims,
-                hidden_dims=hidden_dims,
-                heads=heads,
-                num_queries=num_image_text_embeds,
-            )
+            with init_context():
+                image_projection = IPAdapterPlusImageProjection(
+                    embed_dims=embed_dims,
+                    output_dims=output_dims,
+                    hidden_dims=hidden_dims,
+                    heads=heads,
+                    num_queries=num_image_text_embeds,
+                )
 
             for key, value in state_dict.items():
                 diffusers_name = key.replace("0.to", "2.to")
@@ -765,20 +780,44 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
             else:
                 updated_state_dict[diffusers_name] = value
 
-        image_projection.load_state_dict(updated_state_dict)
+        if not low_cpu_mem_usage:
+            image_projection.load_state_dict(updated_state_dict)
+        else:
+            load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
+
         return image_projection
 
-    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts):
+    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
         from ..models.attention_processor import (
             AttnProcessor,
             AttnProcessor2_0,
             IPAdapterAttnProcessor,
             IPAdapterAttnProcessor2_0,
         )
 
+        if low_cpu_mem_usage:
+            if is_accelerate_available():
+                from accelerate import init_empty_weights
+
+            else:
+                low_cpu_mem_usage = False
+                logger.warning(
+                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                    " install accelerate\n```\n."
+                )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
         # set ip-adapter cross-attention processors & load state_dict
         attn_procs = {}
         key_id = 1
+        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
         for name in self.attn_processors.keys():
             cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
             if name.startswith("mid_block"):
@@ -811,39 +850,49 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts):
                     # IP-Adapter Plus
                     num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
 
-                attn_procs[name] = attn_processor_class(
-                    hidden_size=hidden_size,
-                    cross_attention_dim=cross_attention_dim,
-                    scale=1.0,
-                    num_tokens=num_image_text_embeds,
-                ).to(dtype=self.dtype, device=self.device)
+                with init_context():
+                    attn_procs[name] = attn_processor_class(
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                        scale=1.0,
+                        num_tokens=num_image_text_embeds,
+                    )
 
                 value_dict = {}
                 for i, state_dict in enumerate(state_dicts):
                     value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
                     value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
 
-                attn_procs[name].load_state_dict(value_dict)
+                if not low_cpu_mem_usage:
+                    attn_procs[name].load_state_dict(value_dict)
+                else:
+                    device = next(iter(value_dict.values())).device
+                    dtype = next(iter(value_dict.values())).dtype
+                    load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
+
                 key_id += 2
 
         return attn_procs
 
-    def _load_ip_adapter_weights(self, state_dicts):
+    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
         if not isinstance(state_dicts, list):
             state_dicts = [state_dicts]
         # Set encoder_hid_proj after loading ip_adapter weights,
         # because `IPAdapterPlusImageProjection` also has `attn_processors`.
         self.encoder_hid_proj = None
 
-        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts)
+        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
         self.set_attn_processor(attn_procs)
 
         # convert IP-Adapter Image Projection layers to diffusers
         image_projection_layers = []
         for state_dict in state_dicts:
-            image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
-            image_projection_layer.to(device=self.device, dtype=self.dtype)
+            image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
+                state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
+            )
             image_projection_layers.append(image_projection_layer)
 
         self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
         self.config.encoder_hid_dim_type = "ip_image_proj"
+
+        self.to(dtype=self.dtype, device=self.device)
```
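
The mechanism behind these changes is the same meta-device pattern diffusers already uses for `from_pretrained`: construct modules under `accelerate.init_empty_weights()` so no parameter memory is allocated, then materialize tensors directly from the checkpoint, with the final `self.to(dtype=..., device=...)` moving everything onto the pipeline's device at the end. A self-contained sketch of the idea (requires `accelerate`); `load_into_meta` below is a deliberately simplified stand-in for the diffusers helper `load_model_dict_into_meta`, not its actual implementation:

```python
import torch
import torch.nn as nn
from accelerate import init_empty_weights


def load_into_meta(model: nn.Module, state_dict: dict, device="cpu", dtype=None):
    # Swap each meta-device parameter for the real checkpoint tensor.
    for name, tensor in state_dict.items():
        module_path, _, param_name = name.rpartition(".")
        module = model.get_submodule(module_path) if module_path else model
        tensor = tensor.to(device=device, dtype=dtype or tensor.dtype)
        module._parameters[param_name] = nn.Parameter(tensor)


with init_empty_weights():
    # Created on the "meta" device: shapes only, no memory, no random init.
    proj = nn.Linear(1024, 768)

checkpoint = {"weight": torch.randn(768, 1024), "bias": torch.randn(768)}
load_into_meta(proj, checkpoint, device="cpu", dtype=torch.float32)
print(proj.weight.device)  # cpu: the weights came straight from the checkpoint
```

This is also why the per-module `.to(dtype=self.dtype, device=self.device)` calls were dropped from the attention-processor and image-projection paths: under `low_cpu_mem_usage` those modules hold meta tensors until their state dict is loaded, and copying out of a meta tensor would error.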
