@@ -912,10 +912,10 @@ def pack_weights(layers, prefix):
912912 )
913913
914914 if unet_lora_layers :
915- state_dict .update (pack_weights (unet_lora_layers , cls . unet_name ))
915+ state_dict .update (pack_weights (unet_lora_layers , "unet" ))
916916
917917 if text_encoder_lora_layers :
918- state_dict .update (pack_weights (text_encoder_lora_layers , cls . text_encoder_name ))
918+ state_dict .update (pack_weights (text_encoder_lora_layers , "text_encoder" ))
919919
920920 if transformer_lora_layers :
921921 state_dict .update (pack_weights (transformer_lora_layers , "transformer" ))
@@ -975,22 +975,20 @@ def unload_lora_weights(self):
975975 >>> ...
976976 ```
977977 """
978- unet = getattr (self , self .unet_name ) if not hasattr (self , "unet" ) else self .unet
979-
980978 if not USE_PEFT_BACKEND :
981979 if version .parse (__version__ ) > version .parse ("0.23" ):
982980 logger .warn (
983981 "You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,"
984982 "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
985983 )
986984
987- for _ , module in unet .named_modules ():
985+ for _ , module in self . unet .named_modules ():
988986 if hasattr (module , "set_lora_layer" ):
989987 module .set_lora_layer (None )
990988 else :
991- recurse_remove_peft_layers (unet )
992- if hasattr (unet , "peft_config" ):
993- del unet .peft_config
989+ recurse_remove_peft_layers (self . unet )
990+ if hasattr (self . unet , "peft_config" ):
991+ del self . unet .peft_config
994992
995993 # Safe to call the following regardless of LoRA.
996994 self ._remove_text_encoder_monkey_patch ()
@@ -1029,8 +1027,7 @@ def fuse_lora(
10291027 )
10301028
10311029 if fuse_unet :
1032- unet = getattr (self , self .unet_name ) if not hasattr (self , "unet" ) else self .unet
1033- unet .fuse_lora (lora_scale , safe_fusing = safe_fusing )
1030+ self .unet .fuse_lora (lora_scale , safe_fusing = safe_fusing )
10341031
10351032 if USE_PEFT_BACKEND :
10361033 from peft .tuners .tuners_utils import BaseTunerLayer
@@ -1083,14 +1080,13 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True
10831080 Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
10841081 LoRA parameters then it won't have any effect.
10851082 """
1086- unet = getattr (self , self .unet_name ) if not hasattr (self , "unet" ) else self .unet
10871083 if unfuse_unet :
10881084 if not USE_PEFT_BACKEND :
1089- unet .unfuse_lora ()
1085+ self . unet .unfuse_lora ()
10901086 else :
10911087 from peft .tuners .tuners_utils import BaseTunerLayer
10921088
1093- for module in unet .modules ():
1089+ for module in self . unet .modules ():
10941090 if isinstance (module , BaseTunerLayer ):
10951091 module .unmerge ()
10961092
@@ -1206,9 +1202,8 @@ def set_adapters(
12061202 adapter_names : Union [List [str ], str ],
12071203 adapter_weights : Optional [List [float ]] = None ,
12081204 ):
1209- unet = getattr (self , self .unet_name ) if not hasattr (self , "unet" ) else self .unet
12101205 # Handle the UNET
1211- unet .set_adapters (adapter_names , adapter_weights )
1206+ self . unet .set_adapters (adapter_names , adapter_weights )
12121207
12131208 # Handle the Text Encoder
12141209 if hasattr (self , "text_encoder" ):
@@ -1221,8 +1216,7 @@ def disable_lora(self):
12211216 raise ValueError ("PEFT backend is required for this method." )
12221217
12231218 # Disable unet adapters
1224- unet = getattr (self , self .unet_name ) if not hasattr (self , "unet" ) else self .unet
1225- unet .disable_lora ()
1219+ self .unet .disable_lora ()
12261220
12271221 # Disable text encoder adapters
12281222 if hasattr (self , "text_encoder" ):
@@ -1235,8 +1229,7 @@ def enable_lora(self):
12351229 raise ValueError ("PEFT backend is required for this method." )
12361230
12371231 # Enable unet adapters
1238- unet = getattr (self , self .unet_name ) if not hasattr (self , "unet" ) else self .unet
1239- unet .enable_lora ()
1232+ self .unet .enable_lora ()
12401233
12411234 # Enable text encoder adapters
12421235 if hasattr (self , "text_encoder" ):
@@ -1258,8 +1251,7 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
12581251 adapter_names = [adapter_names ]
12591252
12601253 # Delete unet adapters
1261- unet = getattr (self , self .unet_name ) if not hasattr (self , "unet" ) else self .unet
1262- unet .delete_adapters (adapter_names )
1254+ self .unet .delete_adapters (adapter_names )
12631255
12641256 for adapter_name in adapter_names :
12651257 # Delete text encoder adapters
@@ -1292,8 +1284,8 @@ def get_active_adapters(self) -> List[str]:
12921284 from peft .tuners .tuners_utils import BaseTunerLayer
12931285
12941286 active_adapters = []
1295- unet = getattr ( self , self . unet_name ) if not hasattr ( self , "unet" ) else self . unet
1296- for module in unet .modules ():
1287+
1288+ for module in self . unet .modules ():
12971289 if isinstance (module , BaseTunerLayer ):
12981290 active_adapters = module .active_adapters
12991291 break
@@ -1317,9 +1309,8 @@ def get_list_adapters(self) -> Dict[str, List[str]]:
13171309 if hasattr (self , "text_encoder_2" ) and hasattr (self .text_encoder_2 , "peft_config" ):
13181310 set_adapters ["text_encoder_2" ] = list (self .text_encoder_2 .peft_config .keys ())
13191311
1320- unet = getattr (self , self .unet_name ) if not hasattr (self , "unet" ) else self .unet
1321- if hasattr (self , self .unet_name ) and hasattr (unet , "peft_config" ):
1322- set_adapters [self .unet_name ] = list (self .unet .peft_config .keys ())
1312+ if hasattr (self , "unet" ) and hasattr (self .unet , "peft_config" ):
1313+ set_adapters ["unet" ] = list (self .unet .peft_config .keys ())
13231314
13241315 return set_adapters
13251316
@@ -1340,8 +1331,7 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
13401331 from peft .tuners .tuners_utils import BaseTunerLayer
13411332
13421333 # Handle the UNET
1343- unet = getattr (self , self .unet_name ) if not hasattr (self , "unet" ) else self .unet
1344- for unet_module in unet .modules ():
1334+ for unet_module in self .unet .modules ():
13451335 if isinstance (unet_module , BaseTunerLayer ):
13461336 for adapter_name in adapter_names :
13471337 unet_module .lora_A [adapter_name ].to (device )
0 commit comments