1818import safetensors
1919import torch
2020from huggingface_hub import model_info
21+ from huggingface_hub .constants import HF_HUB_OFFLINE
2122from huggingface_hub .utils import validate_hf_hub_args
2223from packaging import version
2324from torch import nn
@@ -229,7 +230,9 @@ def lora_state_dict(
229230 # determine `weight_name`.
230231 if weight_name is None :
231232 weight_name = cls ._best_guess_weight_name (
232- pretrained_model_name_or_path_or_dict , file_extension = ".safetensors"
233+ pretrained_model_name_or_path_or_dict ,
234+ file_extension = ".safetensors" ,
235+ local_files_only = local_files_only ,
233236 )
234237 model_file = _get_model_file (
235238 pretrained_model_name_or_path_or_dict ,
@@ -255,7 +258,7 @@ def lora_state_dict(
255258 if model_file is None :
256259 if weight_name is None :
257260 weight_name = cls ._best_guess_weight_name (
258- pretrained_model_name_or_path_or_dict , file_extension = ".bin"
261+ pretrained_model_name_or_path_or_dict , file_extension = ".bin" , local_files_only = local_files_only
259262 )
260263 model_file = _get_model_file (
261264 pretrained_model_name_or_path_or_dict ,
@@ -294,7 +297,12 @@ def lora_state_dict(
294297 return state_dict , network_alphas
295298
296299 @classmethod
297- def _best_guess_weight_name (cls , pretrained_model_name_or_path_or_dict , file_extension = ".safetensors" ):
300+ def _best_guess_weight_name (
301+ cls , pretrained_model_name_or_path_or_dict , file_extension = ".safetensors" , local_files_only = False
302+ ):
303+ if local_files_only or HF_HUB_OFFLINE :
304+ raise ValueError ("When using the offline mode, you must specify a `weight_name`." )
305+
298306 targeted_files = []
299307
300308 if os .path .isfile (pretrained_model_name_or_path_or_dict ):
0 commit comments