Commit fd64acf

Revert "[Peft] fix saving / loading when unet is not "unet" (#6046)"
This reverts commit 4c7e983.
1 parent: f645b87

2 files changed (+20, -32 lines)

src/diffusers/loaders/ip_adapter.py

Lines changed: 2 additions & 4 deletions
@@ -149,11 +149,9 @@ def load_ip_adapter(
             self.feature_extractor = CLIPImageProcessor()
 
         # load ip-adapter into unet
-        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-        unet._load_ip_adapter_weights(state_dict)
+        self.unet._load_ip_adapter_weights(state_dict)
 
     def set_ip_adapter_scale(self, scale):
-        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-        for attn_processor in unet.attn_processors.values():
+        for attn_processor in self.unet.attn_processors.values():
             if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
                 attn_processor.scale = scale
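
For context, a minimal usage sketch of the path this hunk reverts to, assuming a pipeline whose UNet is registered as `unet`; the model and IP-Adapter checkpoint names below are the commonly used ones, not taken from this commit:

```python
import torch
from diffusers import StableDiffusionPipeline

# Illustrative model ID; any Stable Diffusion pipeline with a standard `unet` attribute works.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# load_ip_adapter() ends by loading the weights into pipe.unet.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

# set_ip_adapter_scale() walks pipe.unet.attn_processors and sets the scale
# on every IP-Adapter attention processor.
pipe.set_ip_adapter_scale(0.6)
```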

src/diffusers/loaders/lora.py

Lines changed: 18 additions & 28 deletions
@@ -912,10 +912,10 @@ def pack_weights(layers, prefix):
             )
 
         if unet_lora_layers:
-            state_dict.update(pack_weights(unet_lora_layers, cls.unet_name))
+            state_dict.update(pack_weights(unet_lora_layers, "unet"))
 
         if text_encoder_lora_layers:
-            state_dict.update(pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
 
         if transformer_lora_layers:
             state_dict.update(pack_weights(transformer_lora_layers, "transformer"))
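
After the revert, the UNet and text encoder LoRA layers are packed under the hard-coded prefixes `"unet"` and `"text_encoder"` instead of `cls.unet_name` / `cls.text_encoder_name`. A toy, self-contained sketch of what the resulting keys look like; the helper below is a stand-in mirroring the dict branch of `pack_weights`, not the library code:

```python
import torch

# Stand-in for pack_weights(): flatten LoRA tensors under a fixed prefix
# ("unet" / "text_encoder" after this revert).
def pack_weights(layers, prefix):
    return {f"{prefix}.{name}": tensor for name, tensor in layers.items()}

unet_lora_layers = {"down_blocks.0.attentions.0.to_q.lora_A.weight": torch.zeros(4, 320)}
state_dict = {}
state_dict.update(pack_weights(unet_lora_layers, "unet"))
print(list(state_dict))  # ['unet.down_blocks.0.attentions.0.to_q.lora_A.weight']
```
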
@@ -975,22 +975,20 @@ def unload_lora_weights(self):
         >>> ...
         ```
         """
-        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-
         if not USE_PEFT_BACKEND:
             if version.parse(__version__) > version.parse("0.23"):
                 logger.warn(
                     "You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,"
                     "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
                 )
 
-            for _, module in unet.named_modules():
+            for _, module in self.unet.named_modules():
                 if hasattr(module, "set_lora_layer"):
                     module.set_lora_layer(None)
         else:
-            recurse_remove_peft_layers(unet)
-            if hasattr(unet, "peft_config"):
-                del unet.peft_config
+            recurse_remove_peft_layers(self.unet)
+            if hasattr(self.unet, "peft_config"):
+                del self.unet.peft_config
 
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
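
A minimal load/unload round trip exercising this method; the LoRA repo ID and weight name are placeholders:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Placeholder LoRA checkpoint in diffusers/PEFT format.
pipe.load_lora_weights("some-user/some-sd15-lora", weight_name="pytorch_lora_weights.safetensors")

# ... run inference with the LoRA applied ...

# After the revert, unload_lora_weights() strips the LoRA layers from pipe.unet
# directly and removes the text encoder monkey patch, restoring the base weights.
pipe.unload_lora_weights()
```
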
@@ -1029,8 +1027,7 @@ def fuse_lora(
             )
 
         if fuse_unet:
-            unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-            unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
+            self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
 
         if USE_PEFT_BACKEND:
             from peft.tuners.tuners_utils import BaseTunerLayer
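
`fuse_lora` / `unfuse_lora` merge LoRA deltas into (and back out of) `self.unet` after the revert. A sketch with a placeholder repo ID:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("some-user/some-sd15-lora")  # placeholder repo ID

# Merge the LoRA weights into pipe.unet (and the text encoder) at 70% strength,
# so inference runs without separate LoRA layers.
pipe.fuse_lora(lora_scale=0.7)

# ... inference ...

# Undo the merge and restore the original weights.
pipe.unfuse_lora()
```
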
@@ -1083,14 +1080,13 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True
                 Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                 LoRA parameters then it won't have any effect.
         """
-        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
         if unfuse_unet:
             if not USE_PEFT_BACKEND:
-                unet.unfuse_lora()
+                self.unet.unfuse_lora()
             else:
                 from peft.tuners.tuners_utils import BaseTunerLayer
 
-                for module in unet.modules():
+                for module in self.unet.modules():
                     if isinstance(module, BaseTunerLayer):
                         module.unmerge()
 
@@ -1206,9 +1202,8 @@ def set_adapters(
         adapter_names: Union[List[str], str],
         adapter_weights: Optional[List[float]] = None,
     ):
-        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
         # Handle the UNET
-        unet.set_adapters(adapter_names, adapter_weights)
+        self.unet.set_adapters(adapter_names, adapter_weights)
 
         # Handle the Text Encoder
         if hasattr(self, "text_encoder"):
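
With the PEFT backend, several adapters can be loaded under names and weighted per call; after the revert the UNet half goes straight through `self.unet.set_adapters`. Repo IDs and adapter names below are placeholders:

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Requires the PEFT backend; repo IDs and adapter names are placeholders.
pipe.load_lora_weights("some-user/style-lora", adapter_name="style")
pipe.load_lora_weights("some-user/detail-lora", adapter_name="detail")

# Activate both adapters on pipe.unet (and the text encoder) with per-adapter weights.
pipe.set_adapters(["style", "detail"], adapter_weights=[1.0, 0.5])
```
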
@@ -1221,8 +1216,7 @@ def disable_lora(self):
             raise ValueError("PEFT backend is required for this method.")
 
         # Disable unet adapters
-        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-        unet.disable_lora()
+        self.unet.disable_lora()
 
         # Disable text encoder adapters
         if hasattr(self, "text_encoder"):
@@ -1235,8 +1229,7 @@ def enable_lora(self):
             raise ValueError("PEFT backend is required for this method.")
 
         # Enable unet adapters
-        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-        unet.enable_lora()
+        self.unet.enable_lora()
 
         # Enable text encoder adapters
         if hasattr(self, "text_encoder"):
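
`disable_lora` / `enable_lora` toggle already-loaded adapters without removing them (PEFT backend required). A sketch, with a placeholder repo ID:

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora_weights("some-user/style-lora", adapter_name="style")  # placeholder

pipe.disable_lora()  # run with the base model only; the adapter stays in memory
# ... base-model inference ...
pipe.enable_lora()   # switch the adapter back on for pipe.unet and the text encoder
```
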
@@ -1258,8 +1251,7 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
             adapter_names = [adapter_names]
 
         # Delete unet adapters
-        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-        unet.delete_adapters(adapter_names)
+        self.unet.delete_adapters(adapter_names)
 
         for adapter_name in adapter_names:
             # Delete text encoder adapters
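
`delete_adapters` removes an adapter entirely from `self.unet` and the text encoder(s); a sketch with a placeholder repo ID:

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora_weights("some-user/style-lora", adapter_name="style")  # placeholder

# Drop the adapter completely (weights and config), unlike disable_lora(),
# which only deactivates it.
pipe.delete_adapters("style")
```
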
@@ -1292,8 +1284,8 @@ def get_active_adapters(self) -> List[str]:
         from peft.tuners.tuners_utils import BaseTunerLayer
 
         active_adapters = []
-        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-        for module in unet.modules():
+
+        for module in self.unet.modules():
             if isinstance(module, BaseTunerLayer):
                 active_adapters = module.active_adapters
                 break
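
The two inspection helpers read adapter state from `self.unet` (and the text encoders) after the revert; a sketch with placeholder repo IDs, and illustrative output in the comments:

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora_weights("some-user/style-lora", adapter_name="style")    # placeholder
pipe.load_lora_weights("some-user/detail-lora", adapter_name="detail")  # placeholder
pipe.set_adapters(["style", "detail"], adapter_weights=[1.0, 0.5])

print(pipe.get_active_adapters())  # e.g. ['style', 'detail'], read from pipe.unet's tuner layers
print(pipe.get_list_adapters())    # e.g. {'unet': ['style', 'detail'], 'text_encoder': [...]}
```
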
@@ -1317,9 +1309,8 @@ def get_list_adapters(self) -> Dict[str, List[str]]:
         if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"):
             set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys())
 
-        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-        if hasattr(self, self.unet_name) and hasattr(unet, "peft_config"):
-            set_adapters[self.unet_name] = list(self.unet.peft_config.keys())
+        if hasattr(self, "unet") and hasattr(self.unet, "peft_config"):
+            set_adapters["unet"] = list(self.unet.peft_config.keys())
 
         return set_adapters

@@ -1340,8 +1331,7 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
         from peft.tuners.tuners_utils import BaseTunerLayer
 
         # Handle the UNET
-        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-        for unet_module in unet.modules():
+        for unet_module in self.unet.modules():
             if isinstance(unet_module, BaseTunerLayer):
                 for adapter_name in adapter_names:
                     unet_module.lora_A[adapter_name].to(device)
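
`set_lora_device` moves a named adapter's LoRA weights between devices; after the revert its UNet half iterates over `self.unet.modules()`. A sketch with a placeholder repo ID:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora_weights("some-user/style-lora", adapter_name="style")  # placeholder

# Move just this adapter's LoRA A/B matrices (in pipe.unet and the text encoder)
# to the CPU, e.g. to free GPU memory while keeping the adapter available.
pipe.set_lora_device(["style"], torch.device("cpu"))
```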
