39 changes: 38 additions & 1 deletion src/diffusers/loaders.py
@@ -25,7 +25,7 @@
 import safetensors
 import torch
 import torch.nn.functional as F
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, model_info
 from torch import nn
 
 from .utils import (
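For context on the new import: `model_info` from `huggingface_hub` returns repo metadata whose `siblings` field lists every file in a Hub repo, which is what lets the loader enumerate candidate weight files without downloading anything. A minimal sketch (the repo id is a hypothetical placeholder):

    from huggingface_hub import model_info

    # Enumerate the .safetensors files in a Hub repo via its metadata only.
    info = model_info("some-user/some-lora")  # hypothetical repo id
    candidates = [s.rfilename for s in info.siblings if s.rfilename.endswith(".safetensors")]
    print(candidates)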
@@ -1021,6 +1021,13 @@ def lora_state_dict(
             weight_name is not None and weight_name.endswith(".safetensors")
         ):
             try:
+                # Here we're relaxing the loading check to enable more Inference API
+                # friendliness where sometimes, it's not at all possible to automatically
+                # determine `weight_name`.
+                if weight_name is None:
+                    weight_name = cls._best_guess_weight_name(
+                        pretrained_model_name_or_path_or_dict, file_extension=".safetensors"
+                    )
                 model_file = _get_model_file(
                     pretrained_model_name_or_path_or_dict,
                     weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
@@ -1041,7 +1048,12 @@
                 # try loading non-safetensors weights
                 model_file = None
                 pass
 
         if model_file is None:
+            if weight_name is None:
+                weight_name = cls._best_guess_weight_name(
+                    pretrained_model_name_or_path_or_dict, file_extension=".bin"
+                )
             model_file = _get_model_file(
                 pretrained_model_name_or_path_or_dict,
                 weights_name=weight_name or LORA_WEIGHT_NAME,
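Taken together, the two fallbacks mean callers can omit `weight_name` entirely: the `.safetensors` guess is tried first, and `.bin` only if no safetensors file was loadable. A usage sketch of the user-facing effect, assuming `load_lora_weights` routes through `lora_state_dict` as elsewhere in this mixin (the LoRA repo id is a hypothetical placeholder):

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )
    # No weight_name passed: the single weights file in the repo is guessed.
    pipe.load_lora_weights("some-user/some-lora")  # hypothetical repo id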
@@ -1077,6 +1089,31 @@
 
         return state_dict, network_alphas
 
+    @classmethod
+    def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors"):
+        targeted_files = []
+
+        if os.path.isfile(pretrained_model_name_or_path_or_dict):
+            return
+        elif os.path.isdir(pretrained_model_name_or_path_or_dict):
+            targeted_files = [
+                f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
+            ]
+        else:
+            files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
+            targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
+
+        if len(targeted_files) == 0:
+            return
+
+        targeted_files = list(filter(lambda x: "scheduler" not in x and "optimizer" not in x, targeted_files))
+        if len(targeted_files) > 1:
+            raise ValueError(
+                f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
+            )
+        weight_name = targeted_files[0]
+        return weight_name

Review thread on the `raise ValueError` above:

Collaborator: WDYT about making this a warning instead of a raise and picking the first file? Not a strong opinion, though.

Member Author: Better to raise an error than to leave it silent, IMO. It also encourages structuring Hub repos properly.

Collaborator: Fair. I had thought about emitting just a warning and still trying the best guess, along the lines of "we picked this one; if you meant another, improve your repo."

Member Author: Sure. But I think we need to meet at a common point, and I prefer the current design because it promotes what I said in the earlier comment.
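For comparison, the warning-based alternative discussed above (pick the first file and warn) would look roughly like this; it is a sketch of the rejected option, not what the PR merges, and it assumes the module-level `logger` that diffusers files typically define:

    if len(targeted_files) > 1:
        # Rejected alternative: warn and fall through to the first candidate.
        logger.warning(
            "More than one weights file found in the %s format; picking %s. "
            "Pass `weight_name` to `load_lora_weights` to override.",
            file_extension,
            targeted_files[0],
        )
    weight_name = targeted_files[0]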

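To see the guessing rules on a local directory (only files with the target extension are considered, and anything containing "scheduler" or "optimizer" is filtered out), here is a self-contained sketch with made-up file names:

    import os
    import tempfile

    # Toy folder mimicking a training-output layout.
    with tempfile.TemporaryDirectory() as repo_dir:
        for name in ("pytorch_lora_weights.safetensors", "optimizer.safetensors"):
            open(os.path.join(repo_dir, name), "w").close()

        files = [f for f in os.listdir(repo_dir) if f.endswith(".safetensors")]
        files = [f for f in files if "scheduler" not in f and "optimizer" not in f]
        # The optimizer state is excluded, leaving a single unambiguous guess.
        assert files == ["pytorch_lora_weights.safetensors"]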
     @classmethod
     def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
         is_all_unet = all(k.startswith("lora_unet") for k in state_dict)