Skip to content

Commit 88b93c9

Browse files
committed
add an error message when dealing with _best_guess_weight_name ofline
1 parent 93ea26f commit 88b93c9

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

src/diffusers/loaders/lora.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import safetensors
1919
import torch
2020
from huggingface_hub import model_info
21+
from huggingface_hub.constants import HF_HUB_OFFLINE
2122
from huggingface_hub.utils import validate_hf_hub_args
2223
from packaging import version
2324
from torch import nn
@@ -229,7 +230,9 @@ def lora_state_dict(
229230
# determine `weight_name`.
230231
if weight_name is None:
231232
weight_name = cls._best_guess_weight_name(
232-
pretrained_model_name_or_path_or_dict, file_extension=".safetensors"
233+
pretrained_model_name_or_path_or_dict,
234+
file_extension=".safetensors",
235+
local_files_only=local_files_only,
233236
)
234237
model_file = _get_model_file(
235238
pretrained_model_name_or_path_or_dict,
@@ -255,7 +258,7 @@ def lora_state_dict(
255258
if model_file is None:
256259
if weight_name is None:
257260
weight_name = cls._best_guess_weight_name(
258-
pretrained_model_name_or_path_or_dict, file_extension=".bin"
261+
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
259262
)
260263
model_file = _get_model_file(
261264
pretrained_model_name_or_path_or_dict,
@@ -294,7 +297,12 @@ def lora_state_dict(
294297
return state_dict, network_alphas
295298

296299
@classmethod
297-
def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors"):
300+
def _best_guess_weight_name(
301+
cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
302+
):
303+
if local_files_only or HF_HUB_OFFLINE:
304+
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
305+
298306
targeted_files = []
299307

300308
if os.path.isfile(pretrained_model_name_or_path_or_dict):
@@ -303,7 +311,7 @@ def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_ext
303311
targeted_files = [
304312
f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
305313
]
306-
else:
314+
elif not local_files_only:
307315
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
308316
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
309317
if len(targeted_files) == 0:

0 commit comments

Comments
 (0)