 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from collections import OrderedDict, defaultdict
+from collections import defaultdict
 from contextlib import nullcontext
 from typing import Callable, Dict, List, Optional, Union
 
@@ -664,6 +664,80 @@ def delete_adapters(self, adapter_names: Union[List[str], str]): |
         if hasattr(self, "peft_config"):
             self.peft_config.pop(adapter_name, None)
 
+    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
+        updated_state_dict = {}
+        image_projection = None
+
+        if "proj.weight" in state_dict:
+            # IP-Adapter
+            num_image_text_embeds = 4
+            clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
+            cross_attention_dim = state_dict["proj.weight"].shape[0] // 4
+
+            image_projection = ImageProjection(
+                cross_attention_dim=cross_attention_dim,
+                image_embed_dim=clip_embeddings_dim,
+                num_image_text_embeds=num_image_text_embeds,
+            )
+
+            for key, value in state_dict.items():
+                diffusers_name = key.replace("proj", "image_embeds")
+                updated_state_dict[diffusers_name] = value
+
+        elif "proj.3.weight" in state_dict:
+            # IP-Adapter Full
+            clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
+            cross_attention_dim = state_dict["proj.3.weight"].shape[0]
+
+            image_projection = MLPProjection(
+                cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
+            )
+
+            for key, value in state_dict.items():
+                diffusers_name = key.replace("proj.0", "ff.net.0.proj")
+                diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
+                diffusers_name = diffusers_name.replace("proj.3", "norm")
+                updated_state_dict[diffusers_name] = value
+
+        else:
+            # IP-Adapter Plus
+            num_image_text_embeds = state_dict["latents"].shape[1]
+            embed_dims = state_dict["proj_in.weight"].shape[1]
+            output_dims = state_dict["proj_out.weight"].shape[0]
+            hidden_dims = state_dict["latents"].shape[2]
+            heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
+
+            image_projection = Resampler(
+                embed_dims=embed_dims,
+                output_dims=output_dims,
+                hidden_dims=hidden_dims,
+                heads=heads,
+                num_queries=num_image_text_embeds,
+            )
+
+            for key, value in state_dict.items():
+                diffusers_name = key.replace("0.to", "2.to")
+                diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight")
+                diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias")
+                diffusers_name = diffusers_name.replace("1.1.weight", "3.1.net.0.proj.weight")
+                diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight")
+
+                if "norm1" in diffusers_name:
+                    updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
+                elif "norm2" in diffusers_name:
+                    updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
+                elif "to_kv" in diffusers_name:
+                    v_chunk = value.chunk(2, dim=0)
+                    updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
+                    updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
+                elif "to_out" in diffusers_name:
+                    updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
+                else:
+                    updated_state_dict[diffusers_name] = value
+
+        image_projection.load_state_dict(updated_state_dict)
+        return image_projection
+
     def _load_ip_adapter_weights(self, state_dict):
         from ..models.attention_processor import (
             AttnProcessor,
@@ -724,103 +798,8 @@ def _load_ip_adapter_weights(self, state_dict): |
 
         self.set_attn_processor(attn_procs)
 
-        # create image projection layers.
-        if "proj.weight" in state_dict["image_proj"]:
-            # IP-Adapter
-            clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
-            cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
-
-            image_projection = ImageProjection(
-                cross_attention_dim=cross_attention_dim,
-                image_embed_dim=clip_embeddings_dim,
-                num_image_text_embeds=num_image_text_embeds,
-            )
-            image_projection.to(dtype=self.dtype, device=self.device)
-
-            # load image projection layer weights
-            image_proj_state_dict = {}
-            image_proj_state_dict.update(
-                {
-                    "image_embeds.weight": state_dict["image_proj"]["proj.weight"],
-                    "image_embeds.bias": state_dict["image_proj"]["proj.bias"],
-                    "norm.weight": state_dict["image_proj"]["norm.weight"],
-                    "norm.bias": state_dict["image_proj"]["norm.bias"],
-                }
-            )
-            image_projection.load_state_dict(image_proj_state_dict)
-            del image_proj_state_dict
-
-        elif "proj.3.weight" in state_dict["image_proj"]:
-            clip_embeddings_dim = state_dict["image_proj"]["proj.0.weight"].shape[0]
-            cross_attention_dim = state_dict["image_proj"]["proj.3.weight"].shape[0]
-
-            image_projection = MLPProjection(
-                cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
-            )
-            image_projection.to(dtype=self.dtype, device=self.device)
-
-            # load image projection layer weights
-            image_proj_state_dict = {}
-            image_proj_state_dict.update(
-                {
-                    "ff.net.0.proj.weight": state_dict["image_proj"]["proj.0.weight"],
-                    "ff.net.0.proj.bias": state_dict["image_proj"]["proj.0.bias"],
-                    "ff.net.2.weight": state_dict["image_proj"]["proj.2.weight"],
-                    "ff.net.2.bias": state_dict["image_proj"]["proj.2.bias"],
-                    "norm.weight": state_dict["image_proj"]["proj.3.weight"],
-                    "norm.bias": state_dict["image_proj"]["proj.3.bias"],
-                }
-            )
-            image_projection.load_state_dict(image_proj_state_dict)
-            del image_proj_state_dict
-
-        else:
-            # IP-Adapter Plus
-            embed_dims = state_dict["image_proj"]["proj_in.weight"].shape[1]
-            output_dims = state_dict["image_proj"]["proj_out.weight"].shape[0]
-            hidden_dims = state_dict["image_proj"]["latents"].shape[2]
-            heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64
-
-            image_projection = Resampler(
-                embed_dims=embed_dims,
-                output_dims=output_dims,
-                hidden_dims=hidden_dims,
-                heads=heads,
-                num_queries=num_image_text_embeds,
-            )
-
-            image_proj_state_dict = state_dict["image_proj"]
-
-            new_sd = OrderedDict()
-            for k, v in image_proj_state_dict.items():
-                if "0.to" in k:
-                    k = k.replace("0.to", "2.to")
-                elif "1.0.weight" in k:
-                    k = k.replace("1.0.weight", "3.0.weight")
-                elif "1.0.bias" in k:
-                    k = k.replace("1.0.bias", "3.0.bias")
-                elif "1.1.weight" in k:
-                    k = k.replace("1.1.weight", "3.1.net.0.proj.weight")
-                elif "1.3.weight" in k:
-                    k = k.replace("1.3.weight", "3.1.net.2.weight")
-
-                if "norm1" in k:
-                    new_sd[k.replace("0.norm1", "0")] = v
-                elif "norm2" in k:
-                    new_sd[k.replace("0.norm2", "1")] = v
-                elif "to_kv" in k:
-                    v_chunk = v.chunk(2, dim=0)
-                    new_sd[k.replace("to_kv", "to_k")] = v_chunk[0]
-                    new_sd[k.replace("to_kv", "to_v")] = v_chunk[1]
-                elif "to_out" in k:
-                    new_sd[k.replace("to_out", "to_out.0")] = v
-                else:
-                    new_sd[k] = v
-
-            image_projection.load_state_dict(new_sd)
-            del image_proj_state_dict
+        # convert IP-Adapter Image Projection layers to diffusers
+        image_projection = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
 
         self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
         self.config.encoder_hid_dim_type = "ip_image_proj"
-
-        delete_adapter_layers
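The extracted helper dispatches on checkpoint keys and infers every layer size from tensor shapes, so no config file is needed. Below is a minimal standalone sketch of that detection and shape arithmetic; the detect_variant helper, the toy dimensions, and the toy state dict are illustrative assumptions, not part of the diffusers API:

import torch

def detect_variant(image_proj_sd):
    # Mirrors the branching in _convert_ip_adapter_image_proj_to_diffusers:
    # the plain IP-Adapter stores a single linear "proj.weight", the "Full"
    # variant an MLP ending in "proj.3.weight", and IP-Adapter Plus (a
    # Resampler) is whatever falls through to the else branch.
    if "proj.weight" in image_proj_sd:
        return "ip-adapter"
    if "proj.3.weight" in image_proj_sd:
        return "ip-adapter-full"
    return "ip-adapter-plus"

# Toy plain-variant checkpoint (dimensions are assumptions): a 1280-dim CLIP
# image embedding is projected to 4 tokens in a 2048-dim cross-attention
# space, so the weight has 4 * 2048 rows and 1280 columns.
toy_sd = {
    "proj.weight": torch.randn(4 * 2048, 1280),
    "proj.bias": torch.randn(4 * 2048),
    "norm.weight": torch.randn(2048),
    "norm.bias": torch.randn(2048),
}

assert detect_variant(toy_sd) == "ip-adapter"
clip_embeddings_dim = toy_sd["proj.weight"].shape[-1]      # 1280
cross_attention_dim = toy_sd["proj.weight"].shape[0] // 4  # 2048

# The Resampler branch additionally splits fused key/value projections in
# half along dim 0; the same idea on a toy fused tensor:
to_kv = torch.randn(2 * 1280, 768)   # fused [k; v] weight
to_k, to_v = to_kv.chunk(2, dim=0)   # two (1280, 768) halves
assert to_k.shape == to_v.shape == (1280, 768)

Centralizing this logic in one helper is what lets _load_ip_adapter_weights shrink to a single call, as the second hunk shows.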