Commit 68e9623

Add converter method for ip adapters (#6150)
* Add converter method for ip adapters
* Move converter method
* Update to image proj converter

Co-authored-by: Sayak Paul <[email protected]>
1 parent 781775e commit 68e9623

File tree

1 file changed: +77 -98 lines changed


src/diffusers/loaders/unet.py

Lines changed: 77 additions & 98 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from collections import OrderedDict, defaultdict
+from collections import defaultdict
 from contextlib import nullcontext
 from typing import Callable, Dict, List, Optional, Union

@@ -664,6 +664,80 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
             if hasattr(self, "peft_config"):
                 self.peft_config.pop(adapter_name, None)
 
+    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
+        updated_state_dict = {}
+        image_projection = None
+
+        if "proj.weight" in state_dict:
+            # IP-Adapter
+            num_image_text_embeds = 4
+            clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
+            cross_attention_dim = state_dict["proj.weight"].shape[0] // 4
+
+            image_projection = ImageProjection(
+                cross_attention_dim=cross_attention_dim,
+                image_embed_dim=clip_embeddings_dim,
+                num_image_text_embeds=num_image_text_embeds,
+            )
+
+            for key, value in state_dict.items():
+                diffusers_name = key.replace("proj", "image_embeds")
+                updated_state_dict[diffusers_name] = value
+
+        elif "proj.3.weight" in state_dict:
+            # IP-Adapter Full
+            clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
+            cross_attention_dim = state_dict["proj.3.weight"].shape[0]
+
+            image_projection = MLPProjection(
+                cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
+            )
+
+            for key, value in state_dict.items():
+                diffusers_name = key.replace("proj.0", "ff.net.0.proj")
+                diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
+                diffusers_name = diffusers_name.replace("proj.3", "norm")
+                updated_state_dict[diffusers_name] = value
+
+        else:
+            # IP-Adapter Plus
+            num_image_text_embeds = state_dict["latents"].shape[1]
+            embed_dims = state_dict["proj_in.weight"].shape[1]
+            output_dims = state_dict["proj_out.weight"].shape[0]
+            hidden_dims = state_dict["latents"].shape[2]
+            heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
+
+            image_projection = Resampler(
+                embed_dims=embed_dims,
+                output_dims=output_dims,
+                hidden_dims=hidden_dims,
+                heads=heads,
+                num_queries=num_image_text_embeds,
+            )
+
+            for key, value in state_dict.items():
+                diffusers_name = key.replace("0.to", "2.to")
+                diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight")
+                diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias")
+                diffusers_name = diffusers_name.replace("1.1.weight", "3.1.net.0.proj.weight")
+                diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight")
+
+                if "norm1" in diffusers_name:
+                    updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
+                elif "norm2" in diffusers_name:
+                    updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
+                elif "to_kv" in diffusers_name:
+                    v_chunk = value.chunk(2, dim=0)
+                    updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
+                    updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
+                elif "to_out" in diffusers_name:
+                    updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
+                else:
+                    updated_state_dict[diffusers_name] = value
+
+        image_projection.load_state_dict(updated_state_dict)
+        return image_projection
+
     def _load_ip_adapter_weights(self, state_dict):
         from ..models.attention_processor import (
             AttnProcessor,
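Note: the first branch of the new converter handles original IP-Adapter checkpoints, whose image_proj weights live under "proj.*" keys. Below is a standalone Python sketch of just that mapping; the SD 1.5-style shapes are made up for illustration and not read from any real checkpoint.

# Illustrative only: mimics the "proj.weight" branch of
# _convert_ip_adapter_image_proj_to_diffusers with a dummy state dict.
import torch

# Assumed SD 1.5-like shapes: proj maps CLIP image embeds (1024)
# to 4 tokens of cross_attention_dim (768), i.e. 4 * 768 rows.
state_dict = {
    "proj.weight": torch.randn(4 * 768, 1024),
    "proj.bias": torch.randn(4 * 768),
    "norm.weight": torch.randn(768),
    "norm.bias": torch.randn(768),
}

clip_embeddings_dim = state_dict["proj.weight"].shape[-1]      # 1024
cross_attention_dim = state_dict["proj.weight"].shape[0] // 4  # 768

# The branch renames "proj.*" to "image_embeds.*" and leaves "norm.*" alone.
updated_state_dict = {k.replace("proj", "image_embeds"): v for k, v in state_dict.items()}
print(sorted(updated_state_dict))
# ['image_embeds.bias', 'image_embeds.weight', 'norm.bias', 'norm.weight']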
@@ -724,103 +798,8 @@ def _load_ip_adapter_weights(self, state_dict):
 
         self.set_attn_processor(attn_procs)
 
-        # create image projection layers.
-        if "proj.weight" in state_dict["image_proj"]:
-            # IP-Adapter
-            clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
-            cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
-
-            image_projection = ImageProjection(
-                cross_attention_dim=cross_attention_dim,
-                image_embed_dim=clip_embeddings_dim,
-                num_image_text_embeds=num_image_text_embeds,
-            )
-            image_projection.to(dtype=self.dtype, device=self.device)
-
-            # load image projection layer weights
-            image_proj_state_dict = {}
-            image_proj_state_dict.update(
-                {
-                    "image_embeds.weight": state_dict["image_proj"]["proj.weight"],
-                    "image_embeds.bias": state_dict["image_proj"]["proj.bias"],
-                    "norm.weight": state_dict["image_proj"]["norm.weight"],
-                    "norm.bias": state_dict["image_proj"]["norm.bias"],
-                }
-            )
-            image_projection.load_state_dict(image_proj_state_dict)
-            del image_proj_state_dict
-
-        elif "proj.3.weight" in state_dict["image_proj"]:
-            clip_embeddings_dim = state_dict["image_proj"]["proj.0.weight"].shape[0]
-            cross_attention_dim = state_dict["image_proj"]["proj.3.weight"].shape[0]
-
-            image_projection = MLPProjection(
-                cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
-            )
-            image_projection.to(dtype=self.dtype, device=self.device)
-
-            # load image projection layer weights
-            image_proj_state_dict = {}
-            image_proj_state_dict.update(
-                {
-                    "ff.net.0.proj.weight": state_dict["image_proj"]["proj.0.weight"],
-                    "ff.net.0.proj.bias": state_dict["image_proj"]["proj.0.bias"],
-                    "ff.net.2.weight": state_dict["image_proj"]["proj.2.weight"],
-                    "ff.net.2.bias": state_dict["image_proj"]["proj.2.bias"],
-                    "norm.weight": state_dict["image_proj"]["proj.3.weight"],
-                    "norm.bias": state_dict["image_proj"]["proj.3.bias"],
-                }
-            )
-            image_projection.load_state_dict(image_proj_state_dict)
-            del image_proj_state_dict
-
-        else:
-            # IP-Adapter Plus
-            embed_dims = state_dict["image_proj"]["proj_in.weight"].shape[1]
-            output_dims = state_dict["image_proj"]["proj_out.weight"].shape[0]
-            hidden_dims = state_dict["image_proj"]["latents"].shape[2]
-            heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64
-
-            image_projection = Resampler(
-                embed_dims=embed_dims,
-                output_dims=output_dims,
-                hidden_dims=hidden_dims,
-                heads=heads,
-                num_queries=num_image_text_embeds,
-            )
-
-            image_proj_state_dict = state_dict["image_proj"]
-
-            new_sd = OrderedDict()
-            for k, v in image_proj_state_dict.items():
-                if "0.to" in k:
-                    k = k.replace("0.to", "2.to")
-                elif "1.0.weight" in k:
-                    k = k.replace("1.0.weight", "3.0.weight")
-                elif "1.0.bias" in k:
-                    k = k.replace("1.0.bias", "3.0.bias")
-                elif "1.1.weight" in k:
-                    k = k.replace("1.1.weight", "3.1.net.0.proj.weight")
-                elif "1.3.weight" in k:
-                    k = k.replace("1.3.weight", "3.1.net.2.weight")
-
-                if "norm1" in k:
-                    new_sd[k.replace("0.norm1", "0")] = v
-                elif "norm2" in k:
-                    new_sd[k.replace("0.norm2", "1")] = v
-                elif "to_kv" in k:
-                    v_chunk = v.chunk(2, dim=0)
-                    new_sd[k.replace("to_kv", "to_k")] = v_chunk[0]
-                    new_sd[k.replace("to_kv", "to_v")] = v_chunk[1]
-                elif "to_out" in k:
-                    new_sd[k.replace("to_out", "to_out.0")] = v
-                else:
-                    new_sd[k] = v
-
-            image_projection.load_state_dict(new_sd)
-            del image_proj_state_dict
+        # convert IP-Adapter Image Projection layers to diffusers
+        image_projection = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
 
         self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
         self.config.encoder_hid_dim_type = "ip_image_proj"
-
-        delete_adapter_layers
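Both the removed inline code and the new converter perform the same fused-KV split for IP-Adapter Plus checkpoints: the original Resampler weights store the key and value projections as a single to_kv matrix, while the diffusers module expects separate to_k and to_v parameters. A toy Python sketch of just that split; the dimensions are invented for the example.

# Toy illustration of the fused-KV split done for IP-Adapter Plus keys.
# The 1280/768 dimensions are made up; only the chunk/rename logic is
# taken from the diff above.
import torch

hidden_dims = 1280
fused_kv = torch.randn(2 * hidden_dims, 768)  # K weights stacked over V weights

new_sd = {}
key = "layers.0.0.to_kv.weight"
v_chunk = fused_kv.chunk(2, dim=0)  # split along the output dimension
new_sd[key.replace("to_kv", "to_k")] = v_chunk[0]
new_sd[key.replace("to_kv", "to_v")] = v_chunk[1]

assert new_sd["layers.0.0.to_k.weight"].shape == (hidden_dims, 768)
assert new_sd["layers.0.0.to_v.weight"].shape == (hidden_dims, 768)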

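For context, the converter is not called directly by users; it runs inside _load_ip_adapter_weights, which is reached through the pipeline-level load_ip_adapter entry point. A minimal usage sketch in Python, assuming the publicly released IP-Adapter checkpoints on the Hub (the repo and weight-file names below are the commonly documented ones, not taken from this commit).

# Minimal sketch of the code path that exercises this commit; the
# "h94/IP-Adapter" repo and weight-file names are assumptions here.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# load_ip_adapter ends up calling unet._load_ip_adapter_weights(state_dict),
# which now delegates the image-projection conversion to
# _convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"]).
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")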