diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py
index 6ebf67c4f2a7..dde717959f8e 100644
--- a/src/diffusers/loaders/lora.py
+++ b/src/diffusers/loaders/lora.py
@@ -391,6 +391,10 @@ def load_lora_into_unet(
         # their prefixes.
         keys = list(state_dict.keys())

+        if all(key.startswith("unet.unet") for key in keys):
+            deprecation_message = "Keys starting with 'unet.unet' are deprecated."
+            deprecate("unet.unet keys", "0.27", deprecation_message)
+
         if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
@@ -407,8 +411,9 @@ def load_lora_into_unet(
         else:
             # Otherwise, we're dealing with the old format. This means the `state_dict` should only
             # contain the module names of the `unet` as its keys WITHOUT any prefix.
-            warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
-            logger.warn(warn_message)
+            if not USE_PEFT_BACKEND:
+                warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
+                logger.warn(warn_message)

         if USE_PEFT_BACKEND and len(state_dict.keys()) > 0:
             from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
@@ -800,29 +805,21 @@ def save_lora_weights(
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
-        # Create a flat dictionary.
         state_dict = {}

-        # Populate the dictionary.
-        if unet_lora_layers is not None:
-            weights = (
-                unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
-            )
+        def pack_weights(layers, prefix):
+            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+            return layers_state_dict

-            unet_lora_state_dict = {f"{cls.unet_name}.{module_name}": param for module_name, param in weights.items()}
-            state_dict.update(unet_lora_state_dict)
+        if not (unet_lora_layers or text_encoder_lora_layers):
+            raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")

-        if text_encoder_lora_layers is not None:
-            weights = (
-                text_encoder_lora_layers.state_dict()
-                if isinstance(text_encoder_lora_layers, torch.nn.Module)
-                else text_encoder_lora_layers
-            )
+        if unet_lora_layers:
+            state_dict.update(pack_weights(unet_lora_layers, "unet"))

-            text_encoder_lora_state_dict = {
-                f"{cls.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
-            }
-            state_dict.update(text_encoder_lora_state_dict)
+        if text_encoder_lora_layers:
+            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))

         # Save the model
         cls.write_lora_layers(
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index fdc83237f9f9..992ae7d1b194 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -67,7 +67,7 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
                 current_lora_layer_sd = lora_layer.state_dict()
                 for lora_layer_matrix_name, lora_param in current_lora_layer_sd.items():
                     # The matrix name can either be "down" or "up".
-                    lora_state_dict[f"unet.{name}.lora.{lora_layer_matrix_name}"] = lora_param
+                    lora_state_dict[f"{name}.lora.{lora_layer_matrix_name}"] = lora_param

     return lora_state_dict
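Below is a minimal, self-contained sketch (not part of the diff) of the effect of the refactor: `unet_lora_state_dict` now emits unprefixed module keys, and `save_lora_weights` adds a single `unet.`/`text_encoder.` prefix through the new `pack_weights` helper, so the deprecated double `unet.unet` prefix no longer appears. The module names and tensor shapes are made up for illustration, and `pack_weights` is mirrored here as a free function rather than the closure defined inside `save_lora_weights`.

```python
import torch

# Mirror of the pack_weights closure introduced in save_lora_weights,
# written as a free function for this standalone example.
def pack_weights(layers, prefix):
    layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
    return {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}

# Hypothetical unprefixed LoRA state dict, as unet_lora_state_dict now returns it
# (previously its keys would already start with "unet.").
unet_lora_layers = {
    "down_blocks.0.attentions.0.proj_in.lora.down.weight": torch.zeros(4, 320),
    "down_blocks.0.attentions.0.proj_in.lora.up.weight": torch.zeros(320, 4),
}

state_dict = {}
state_dict.update(pack_weights(unet_lora_layers, "unet"))

# Each key is prefixed exactly once, e.g.
# "unet.down_blocks.0.attentions.0.proj_in.lora.down.weight",
# rather than the deprecated "unet.unet...." form.
print(list(state_dict))
```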