diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index ea657ccbdf63..d50a855c83c0 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -25,7 +25,7 @@
 import safetensors
 import torch
 import torch.nn.functional as F
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, model_info
 from torch import nn
 
 from .utils import (
@@ -1021,6 +1021,13 @@ def lora_state_dict(
             weight_name is not None and weight_name.endswith(".safetensors")
         ):
             try:
+                # Here we're relaxing the loading check to enable more Inference API
+                # friendliness where sometimes, it's not at all possible to automatically
+                # determine `weight_name`.
+                if weight_name is None:
+                    weight_name = cls._best_guess_weight_name(
+                        pretrained_model_name_or_path_or_dict, file_extension=".safetensors"
+                    )
                 model_file = _get_model_file(
                     pretrained_model_name_or_path_or_dict,
                     weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
@@ -1041,7 +1048,12 @@ def lora_state_dict(
                 # try loading non-safetensors weights
                 model_file = None
                 pass
+
         if model_file is None:
+            if weight_name is None:
+                weight_name = cls._best_guess_weight_name(
+                    pretrained_model_name_or_path_or_dict, file_extension=".bin"
+                )
             model_file = _get_model_file(
                 pretrained_model_name_or_path_or_dict,
                 weights_name=weight_name or LORA_WEIGHT_NAME,
@@ -1077,6 +1089,31 @@ def lora_state_dict(
 
         return state_dict, network_alphas
 
+    @classmethod
+    def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors"):
+        targeted_files = []
+
+        if os.path.isfile(pretrained_model_name_or_path_or_dict):
+            return
+        elif os.path.isdir(pretrained_model_name_or_path_or_dict):
+            targeted_files = [
+                f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
+            ]
+        else:
+            files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
+            targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
+
+        if len(targeted_files) == 0:
+            return
+
+        targeted_files = list(filter(lambda x: "scheduler" not in x and "optimizer" not in x, targeted_files))
+        if len(targeted_files) > 1:
+            raise ValueError(
+                f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
+            )
+        weight_name = targeted_files[0]
+        return weight_name
+
     @classmethod
     def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
         is_all_unet = all(k.startswith("lora_unet") for k in state_dict)
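
In practice, this change means `load_lora_weights` can resolve the weight file on its own when the given repo or local directory contains exactly one LoRA file with a `.safetensors` (or `.bin`) extension, falling back to an explicit `weight_name` only when there is ambiguity. Below is a minimal usage sketch; the LoRA repo id is hypothetical, and the base checkpoint is only an example:

```python
import torch
from diffusers import DiffusionPipeline

# Load a base pipeline (example checkpoint).
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# With the relaxed check, omitting `weight_name` lets `_best_guess_weight_name`
# pick the single `.safetensors`/`.bin` file found in the repo or directory.
pipe.load_lora_weights("some-user/some-lora")  # hypothetical repo id

# Passing `weight_name` explicitly still works and remains required when the
# repo contains more than one weights file in the targeted format.
pipe.load_lora_weights(
    "some-user/some-lora", weight_name="pytorch_lora_weights.safetensors"
)
```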