 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
-    TEXT_ENCODER_TARGET_MODULES,
+    TEXT_ENCODER_ATTN_MODULE,
     _get_model_file,
     deprecate,
     is_safetensors_available,
@@ -955,6 +955,19 @@ def text_encoder_lora_attn_procs(self):
             return self._text_encoder_lora_attn_procs
         return
 
+    def _remove_text_encoder_monkey_patch(self):
+        # Loop over the CLIPAttention modules of the text encoder.
+        for name, attn_module in self.text_encoder.named_modules():
+            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
+                # Loop over the LoRA layers.
+                for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
+                    # Retrieve the q/k/v/out projection of CLIPAttention.
+                    module = attn_module.get_submodule(text_encoder_attr)
+                    if hasattr(module, "old_forward"):
+                        # Restore the original `forward` to remove the monkey-patch.
+                        module.forward = module.old_forward
+                        delattr(module, "old_forward")
+
     def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
@@ -963,37 +976,41 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
             attn_processors: Dict[str, `LoRAAttnProcessor`]:
                 A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`].
         """
-        # Loop over the original attention modules.
-        for name, _ in self.text_encoder.named_modules():
-            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
-                # Retrieve the module and its corresponding LoRA processor.
-                module = self.text_encoder.get_submodule(name)
-                # Construct a new function that performs the LoRA merging. We will monkey patch
-                # this forward pass.
-                attn_processor_name = ".".join(name.split(".")[:-1])
-                lora_layer = getattr(attn_processors[attn_processor_name], self._get_lora_layer_attribute(name))
-                old_forward = module.forward
-
-                # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
-                # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
-                def make_new_forward(old_forward, lora_layer):
-                    def new_forward(x):
-                        return old_forward(x) + lora_layer(x)
-
-                    return new_forward
-
-                # Monkey-patch.
-                module.forward = make_new_forward(old_forward, lora_layer)
-
-    def _get_lora_layer_attribute(self, name: str) -> str:
-        if "q_proj" in name:
-            return "to_q_lora"
-        elif "v_proj" in name:
-            return "to_v_lora"
-        elif "k_proj" in name:
-            return "to_k_lora"
-        else:
-            return "to_out_lora"
+
+        # First, remove any monkey-patch that might have been applied before.
+        self._remove_text_encoder_monkey_patch()
+
+        # Loop over the CLIPAttention modules of the text encoder.
+        for name, attn_module in self.text_encoder.named_modules():
+            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
+                # Loop over the LoRA layers.
+                for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
+                    # Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer.
+                    module = attn_module.get_submodule(text_encoder_attr)
+                    lora_layer = attn_processors[name].get_submodule(attn_proc_attr)
+
+                    # Save old_forward on the module so the monkey-patch can be removed later.
+                    old_forward = module.old_forward = module.forward
+
+                    # Create a new scope that locks in the old_forward, lora_layer values for each new_forward function;
+                    # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
+                    def make_new_forward(old_forward, lora_layer):
+                        def new_forward(x):
+                            return old_forward(x) + lora_layer(x)
+
+                        return new_forward
+
+                    # Monkey-patch.
+                    module.forward = make_new_forward(old_forward, lora_layer)
+
+    @property
+    def _lora_attn_processor_attr_to_text_encoder_attr(self):
+        return {
+            "to_q_lora": "q_proj",
+            "to_k_lora": "k_proj",
+            "to_v_lora": "v_proj",
+            "to_out_lora": "out_proj",
+        }
 
     def _load_text_encoder_attn_procs(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
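
For reference, a minimal self-contained sketch of the patch/unpatch pattern introduced in this diff. `base_proj` and `lora_layer` are hypothetical stand-ins for a CLIP q/k/v/out projection and a `LoRAAttnProcessor` sub-layer; they are not names from the diff.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: a base projection and a low-rank LoRA branch.
base_proj = nn.Linear(8, 8)
lora_layer = nn.Sequential(nn.Linear(8, 2, bias=False), nn.Linear(2, 8, bias=False))


def make_new_forward(old_forward, lora_layer):
    # Factory so each patched module captures its own old_forward / lora_layer.
    def new_forward(x):
        return old_forward(x) + lora_layer(x)

    return new_forward


# Apply the monkey-patch, keeping the original forward reachable on the module.
base_proj.old_forward = base_proj.forward
base_proj.forward = make_new_forward(base_proj.old_forward, lora_layer)

x = torch.randn(1, 8)
patched = base_proj(x)  # base output + LoRA delta

# Remove the monkey-patch by restoring the saved forward, mirroring
# _remove_text_encoder_monkey_patch above.
if hasattr(base_proj, "old_forward"):
    base_proj.forward = base_proj.old_forward
    delattr(base_proj, "old_forward")

restored = base_proj(x)  # plain base output again
assert torch.allclose(patched, restored + lora_layer(x), atol=1e-6)
```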
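The `make_new_forward` factory (and the PR comment linked in the code) guards against Python's late binding of closure variables: a closure created directly inside the loop would see whichever `old_forward`/`lora_layer` the loop assigned last. A small illustration of the pitfall, using hypothetical helpers `add_one`, `add_two`, and `make_scaled` rather than the real modules:

```python
def make_scaled(fn):
    # The factory creates a new scope, like make_new_forward in the diff above.
    def scaled(x):
        return fn(x) * 10

    return scaled


def add_one(x):
    return x + 1


def add_two(x):
    return x + 2


naive, scoped = [], []
for fn in (add_one, add_two):
    # Late binding: `fn` is resolved when the lambda is *called*, i.e. after the
    # loop has finished, so every naive closure ends up pointing at add_two.
    naive.append(lambda x: fn(x) * 10)
    scoped.append(make_scaled(fn))

print([f(0) for f in naive])   # [20, 20]
print([f(0) for f in scoped])  # [10, 20]
```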