     _get_model_file,
     delete_adapter_layers,
     is_accelerate_available,
+    is_torch_version,
     logging,
     set_adapter_layers,
     set_weights_and_activate_adapters,
@@ -168,15 +169,6 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
             "framework": "pytorch",
         }

-        if low_cpu_mem_usage and not is_accelerate_available():
-            low_cpu_mem_usage = False
-            logger.warning(
-                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
-                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
-                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
-                " install accelerate\n```\n."
-            )
-
         model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
             # Let's first try to load .safetensors weights
@@ -694,21 +686,42 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
             if hasattr(self, "peft_config"):
                 self.peft_config.pop(adapter_name, None)

-    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
+    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
+        if low_cpu_mem_usage:
+            if is_accelerate_available():
+                from accelerate import init_empty_weights
+
+            else:
+                low_cpu_mem_usage = False
+                logger.warning(
+                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                    " install accelerate\n```\n."
+                )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
         updated_state_dict = {}
         image_projection = None
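+        # when low_cpu_mem_usage is enabled, build the projection module on the meta device
+        # (no memory is allocated for its weights); otherwise use a no-op context manager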
+        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext

         if "proj.weight" in state_dict:
             # IP-Adapter
             num_image_text_embeds = 4
             clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
             cross_attention_dim = state_dict["proj.weight"].shape[0] // 4

-            image_projection = ImageProjection(
-                cross_attention_dim=cross_attention_dim,
-                image_embed_dim=clip_embeddings_dim,
-                num_image_text_embeds=num_image_text_embeds,
-            )
+            with init_context():
+                image_projection = ImageProjection(
+                    cross_attention_dim=cross_attention_dim,
+                    image_embed_dim=clip_embeddings_dim,
+                    num_image_text_embeds=num_image_text_embeds,
+                )

         for key, value in state_dict.items():
             diffusers_name = key.replace("proj", "image_embeds")
@@ -719,9 +732,10 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
             clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
             cross_attention_dim = state_dict["proj.3.weight"].shape[0]

-            image_projection = IPAdapterFullImageProjection(
-                cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
-            )
+            with init_context():
+                image_projection = IPAdapterFullImageProjection(
+                    cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
+                )

         for key, value in state_dict.items():
             diffusers_name = key.replace("proj.0", "ff.net.0.proj")
@@ -737,13 +751,14 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
             hidden_dims = state_dict["latents"].shape[2]
             heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64

-            image_projection = IPAdapterPlusImageProjection(
-                embed_dims=embed_dims,
-                output_dims=output_dims,
-                hidden_dims=hidden_dims,
-                heads=heads,
-                num_queries=num_image_text_embeds,
-            )
+            with init_context():
+                image_projection = IPAdapterPlusImageProjection(
+                    embed_dims=embed_dims,
+                    output_dims=output_dims,
+                    hidden_dims=hidden_dims,
+                    heads=heads,
+                    num_queries=num_image_text_embeds,
+                )

         for key, value in state_dict.items():
             diffusers_name = key.replace("0.to", "2.to")
@@ -765,20 +780,44 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
                 else:
                     updated_state_dict[diffusers_name] = value

-        image_projection.load_state_dict(updated_state_dict)
+        if not low_cpu_mem_usage:
+            image_projection.load_state_dict(updated_state_dict)
+        else:
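+            # replace the meta-device placeholders with real tensors on the target device/dtype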
+            load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
+
         return image_projection

-    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts):
+    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
         from ..models.attention_processor import (
             AttnProcessor,
             AttnProcessor2_0,
             IPAdapterAttnProcessor,
             IPAdapterAttnProcessor2_0,
         )

+        if low_cpu_mem_usage:
+            if is_accelerate_available():
+                from accelerate import init_empty_weights
+
+            else:
+                low_cpu_mem_usage = False
+                logger.warning(
+                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                    " install accelerate\n```\n."
+                )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
         # set ip-adapter cross-attention processors & load state_dict
         attn_procs = {}
         key_id = 1
+        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
         for name in self.attn_processors.keys():
             cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
             if name.startswith("mid_block"):
@@ -811,39 +850,49 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts):
                         # IP-Adapter Plus
                         num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]

-            attn_procs[name] = attn_processor_class(
-                hidden_size=hidden_size,
-                cross_attention_dim=cross_attention_dim,
-                scale=1.0,
-                num_tokens=num_image_text_embeds,
-            ).to(dtype=self.dtype, device=self.device)
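+                # instantiate the processor on the meta device when low_cpu_mem_usage is enabled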
+                with init_context():
+                    attn_procs[name] = attn_processor_class(
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                        scale=1.0,
+                        num_tokens=num_image_text_embeds,
+                    )

                 value_dict = {}
                 for i, state_dict in enumerate(state_dicts):
                     value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
                     value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})

-                attn_procs[name].load_state_dict(value_dict)
+                if not low_cpu_mem_usage:
+                    attn_procs[name].load_state_dict(value_dict)
+                else:
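+                    # the processors were created on the meta device, so take the target
+                    # device/dtype from the checkpoint tensors themselves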
+                    device = next(iter(value_dict.values())).device
+                    dtype = next(iter(value_dict.values())).dtype
+                    load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
+
                 key_id += 2

         return attn_procs

-    def _load_ip_adapter_weights(self, state_dicts):
+    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
         if not isinstance(state_dicts, list):
             state_dicts = [state_dicts]
         # Set encoder_hid_proj after loading ip_adapter weights,
         # because `IPAdapterPlusImageProjection` also has `attn_processors`.
         self.encoder_hid_proj = None

-        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts)
+        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
         self.set_attn_processor(attn_procs)

         # convert IP-Adapter Image Projection layers to diffusers
         image_projection_layers = []
         for state_dict in state_dicts:
-            image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
-            image_projection_layer.to(device=self.device, dtype=self.dtype)
+            image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
+                state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
+            )
             image_projection_layers.append(image_projection_layer)

         self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
         self.config.encoder_hid_dim_type = "ip_image_proj"
+
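+        # make sure the freshly loaded adapter layers end up on the same device and dtype
+        # as the rest of the UNet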
+        self.to(dtype=self.dtype, device=self.device)
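
With the flag threaded through the three helpers above, the whole IP-Adapter can be materialized straight from the checkpoint. A minimal usage sketch follows; it assumes the public `load_ip_adapter` entry point forwards `low_cpu_mem_usage` to `_load_ip_adapter_weights` as the signature above suggests, and the model and weight names are only illustrative:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# With low_cpu_mem_usage=True, the IP-Adapter attention processors and image
# projection layers are created on the meta device via accelerate's
# init_empty_weights and filled directly from the checkpoint, instead of being
# randomly initialized first and then overwritten by load_state_dict.
pipe.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="models",
    weight_name="ip-adapter_sd15.bin",
    low_cpu_mem_usage=True,
)
```

The saving is the transient memory of the random initialization: under `init_empty_weights` the new modules hold only meta tensors until `load_model_dict_into_meta` copies in the real weights once.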