Skip to content

Commit a81334e

Browse files
authored
[LoRA] add an error message when dealing with _best_guess_weight_name ofline (#6184)
* add an error message when dealing with _best_guess_weight_name ofline * simplify condition
1 parent d704a73 commit a81334e

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

src/diffusers/loaders/lora.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import safetensors
1919
import torch
2020
from huggingface_hub import model_info
21+
from huggingface_hub.constants import HF_HUB_OFFLINE
2122
from huggingface_hub.utils import validate_hf_hub_args
2223
from packaging import version
2324
from torch import nn
@@ -229,7 +230,9 @@ def lora_state_dict(
229230
# determine `weight_name`.
230231
if weight_name is None:
231232
weight_name = cls._best_guess_weight_name(
232-
pretrained_model_name_or_path_or_dict, file_extension=".safetensors"
233+
pretrained_model_name_or_path_or_dict,
234+
file_extension=".safetensors",
235+
local_files_only=local_files_only,
233236
)
234237
model_file = _get_model_file(
235238
pretrained_model_name_or_path_or_dict,
@@ -255,7 +258,7 @@ def lora_state_dict(
255258
if model_file is None:
256259
if weight_name is None:
257260
weight_name = cls._best_guess_weight_name(
258-
pretrained_model_name_or_path_or_dict, file_extension=".bin"
261+
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
259262
)
260263
model_file = _get_model_file(
261264
pretrained_model_name_or_path_or_dict,
@@ -294,7 +297,12 @@ def lora_state_dict(
294297
return state_dict, network_alphas
295298

296299
@classmethod
297-
def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors"):
300+
def _best_guess_weight_name(
301+
cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
302+
):
303+
if local_files_only or HF_HUB_OFFLINE:
304+
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
305+
298306
targeted_files = []
299307

300308
if os.path.isfile(pretrained_model_name_or_path_or_dict):

0 commit comments

Comments
 (0)