diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index dd307350d385..7401345d93b3 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -62,6 +62,7 @@
 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+TOTAL_EXAMPLE_KEYS = 5

 TEXT_INVERSION_NAME = "learned_embeds.bin"
 TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
@@ -187,6 +188,7 @@ def map_from(module, state_dict, *args, **kwargs):
 class UNet2DConditionLoadersMixin:
     text_encoder_name = TEXT_ENCODER_NAME
     unet_name = UNET_NAME
+    aux_state_dict_populated = None

     def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
         r"""
@@ -1062,6 +1064,7 @@ def load_lora_into_unet(cls, state_dict, network_alpha, unet, state_dict_aux=Non

         if state_dict_aux:
             unet._load_lora_aux(state_dict_aux, network_alpha=network_alpha)
+            unet.aux_state_dict_populated = True

     @classmethod
     def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0, state_dict_aux=None):
@@ -1314,9 +1317,12 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
         unet_state_dict_aux = {}
         te_state_dict_aux = {}
         network_alpha = None
+        unloaded_keys = []

         for key, value in state_dict.items():
-            if "lora_down" in key:
+            if "hada" in key or "skip" in key:
+                unloaded_keys.append(key)
+            elif "lora_down" in key:
                 lora_name = key.split(".")[0]
                 lora_name_up = lora_name + ".lora_up.weight"
                 lora_name_alpha = lora_name + ".alpha"
@@ -1351,6 +1357,7 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                     elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
                         unet_state_dict_aux[diffusers_name] = value
                         unet_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+
             elif lora_name.startswith("lora_te_"):
                 diffusers_name = key.replace("lora_te_", "").replace("_", ".")
                 diffusers_name = diffusers_name.replace("text.model", "text_model")
@@ -1366,6 +1373,13 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                         te_state_dict_aux[diffusers_name] = value
                         te_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]

+        logger.info("Kohya-style checkpoint detected.")
+        if len(unloaded_keys) > 0:
+            example_unloaded_keys = ", ".join(x for x in unloaded_keys[:TOTAL_EXAMPLE_KEYS])
+            logger.warning(
+                f"There are some keys (such as: {example_unloaded_keys}) in the checkpoints we don't provide support for."
+            )
+
         unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
         te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
         new_state_dict = {**unet_state_dict, **te_state_dict}
@@ -1400,6 +1414,12 @@ def unload_lora_weights(self):
         else:
             self.unet.set_default_attn_processor()

+        if self.unet.aux_state_dict_populated:
+            for _, module in self.unet.named_modules():
+                if hasattr(module, "old_forward") and module.old_forward is not None:
+                    module.forward = module.old_forward
+            self.unet.aux_state_dict_populated = False
+
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py
index 4949e3c082be..78ab03081fc5 100644
--- a/src/diffusers/models/lora.py
+++ b/src/diffusers/models/lora.py
@@ -87,11 +87,13 @@ class Conv2dWithLoRA(nn.Conv2d):
     def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
         super().__init__(*args, **kwargs)
         self.lora_layer = lora_layer
+        self.old_forward = None

     def forward(self, x):
         if self.lora_layer is None:
             return super().forward(x)
         else:
+            self.old_forward = super().forward
             return super().forward(x) + self.lora_layer(x)
@@ -103,9 +105,11 @@ class LinearWithLoRA(nn.Linear):
     def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
         super().__init__(*args, **kwargs)
         self.lora_layer = lora_layer
+        self.old_forward = None

     def forward(self, x):
         if self.lora_layer is None:
             return super().forward(x)
         else:
+            self.old_forward = super().forward
             return super().forward(x) + self.lora_layer(x)
diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py
index 1396561367e0..a378429f75ab 100644
--- a/tests/models/test_lora_layers.py
+++ b/tests/models/test_lora_layers.py
@@ -554,7 +554,7 @@ def test_a1111(self):
         images = images[0, -3:, -3:, -1].flatten()

-        expected = np.array([0.3743, 0.3893, 0.3835, 0.3891, 0.3949, 0.3649, 0.3858, 0.3802, 0.3245])
+        expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])

         self.assertTrue(np.allclose(images, expected, atol=1e-4))
@@ -594,6 +594,7 @@ def test_unload_lora(self):
         lora_filename = "Colored_Icons_by_vizsumit.safetensors"
         pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)

+        generator = torch.manual_seed(0)
         lora_images = pipe(
             prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
         ).images
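For context, here is a minimal usage sketch of the behavior this patch targets: loading a Kohya-style LoRA checkpoint and then unloading it so the monkey-patched forwards on the auxiliary (non-attention) layers are restored. The LoRA repository, weight filename, and prompt below are hypothetical placeholders, not values taken from this PR.

# Sketch only; "some-user/some-lora" and "lora.safetensors" are placeholders.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Kohya-style file: unsupported key types (e.g. "hada", "skip") are collected
# and reported via a warning rather than loaded.
pipe.load_lora_weights("some-user/some-lora", weight_name="lora.safetensors")

generator = torch.manual_seed(0)
lora_images = pipe(
    "an icon of a bird", output_type="np", generator=generator, num_inference_steps=15
).images

# unload_lora_weights() now also restores module.forward from old_forward on the
# aux LoRA layers, so subsequent calls use the original base weights.
pipe.unload_lora_weights()

generator = torch.manual_seed(0)
original_images = pipe(
    "an icon of a bird", output_type="np", generator=generator, num_inference_steps=15
).images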