Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 72 additions & 44 deletions vllm/transformers_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,49 +265,64 @@ def get_config(
return config


def get_hf_file_to_dict(file_name: str,
model: Union[str, Path],
revision: Optional[str] = 'main'):
def try_get_local_file(model: Union[str, Path],
                       file_name: str,
                       revision: Optional[str] = 'main') -> Optional[Path]:
    """
    Looks for a file under the model directory first and, failing that,
    asks `try_to_load_from_cache` for it.

    Parameters:
    - model (Union[str, Path]): Directory expected to contain the file;
        the same value is passed as `repo_id` to the cache lookup.
    - file_name (str): The name of the file to look for.
    - revision (Optional[str]): The revision passed to the cache lookup.

    Returns:
    - Optional[Path]: The path of the file when it exists under `model`
        or when the cache lookup returns a string, otherwise None.
    """
    candidate = Path(model) / file_name
    if candidate.is_file():
        return candidate

    cache_hit = try_to_load_from_cache(repo_id=model,
                                       filename=file_name,
                                       revision=revision)
    # Only a string result is treated as a usable path; any other
    # return value from the cache lookup counts as a miss.
    return Path(cache_hit) if isinstance(cache_hit, str) else None


def get_hf_file_to_dict(
        file_name: str,
        model: Union[str, Path],
        revision: Optional[str] = 'main') -> Optional[Dict[str, Any]]:
    """
    Downloads a file from the Hugging Face Hub and returns
    its contents as a dictionary.

    A copy found by `try_get_local_file` is used when available; the Hub
    download is attempted only when no such copy exists and
    `file_or_path_exists` reports that the file exists.

    Parameters:
    - file_name (str): The name of the file to download.
    - model (str): The name of the model on the Hugging Face Hub.
    - revision (str): The specific version of the model.

    Returns:
    - config_dict (dict): A dictionary containing
        the contents of the downloaded file, or None when the file
        cannot be located or the download fails.
    """
    file_path = try_get_local_file(model=model,
                                   file_name=file_name,
                                   revision=revision)

    if file_path is None and file_or_path_exists(
            model=model, config_name=file_name, revision=revision):
        try:
            hf_hub_file = hf_hub_download(model, file_name, revision=revision)
        except (RepositoryNotFoundError, RevisionNotFoundError,
                EntryNotFoundError, LocalEntryNotFoundError) as e:
            # The message needs a %s placeholder: passing `e` without one
            # makes the logging module report a formatting error when the
            # record is emitted, and the exception text is lost.
            logger.debug(
                "File or repository not found in hf_hub_download: %s", e)
            return None
        except HfHubHTTPError as e:
            logger.warning(
                "Cannot connect to Hugging Face Hub. Skipping file "
                "download for '%s':",
                file_name,
                exc_info=e)
            return None
        file_path = Path(hf_hub_file)

    if file_path is not None and file_path.is_file():
        # JSON is UTF-8 by specification; do not rely on the platform's
        # default text encoding.
        with open(file_path, encoding="utf-8") as file:
            return json.load(file)

    return None


Expand Down Expand Up @@ -378,21 +393,21 @@ def get_pooling_config_name(pooling_name: str) -> Union[str, None]:
return None


def get_sentence_transformer_tokenizer_config(model: str,
revision: Optional[str] = 'main'
):
def get_sentence_transformer_tokenizer_config(
model: str,
revision: Optional[str] = 'main') -> Optional[Dict[str, Any]]:
"""
Returns the tokenization configuration dictionary for a
Returns the tokenization configuration dictionary for a
given Sentence Transformer BERT model.

Parameters:
- model (str): The name of the Sentence Transformer
- model (str): The name of the Sentence Transformer
BERT model.
- revision (str, optional): The revision of the model
to use. Defaults to 'main'.

Returns:
- dict: A dictionary containing the configuration parameters
- dict: A dictionary containing the configuration parameters
for the Sentence Transformer BERT model.
"""
sentence_transformer_config_files = [
Expand All @@ -404,20 +419,33 @@ def get_sentence_transformer_tokenizer_config(model: str,
"sentence_xlm-roberta_config.json",
"sentence_xlnet_config.json",
]
try:
# If model is on HuggingfaceHub, get the repo files
repo_files = list_repo_files(model, revision=revision, token=HF_TOKEN)
except Exception as e:
logger.debug("Error getting repo files", e)
repo_files = []

encoder_dict = None
for config_name in sentence_transformer_config_files:
if config_name in repo_files or Path(model).exists():
encoder_dict = get_hf_file_to_dict(config_name, model, revision)

for config_file in sentence_transformer_config_files:
if try_get_local_file(model=model,
file_name=config_file,
revision=revision) is not None:
encoder_dict = get_hf_file_to_dict(config_file, model, revision)
if encoder_dict:
break

if not encoder_dict:
try:
# If model is on HuggingfaceHub, get the repo files
repo_files = list_repo_files(model,
revision=revision,
token=HF_TOKEN)
except Exception as e:
logger.debug("Error getting repo files", e)
repo_files = []

for config_name in sentence_transformer_config_files:
if config_name in repo_files:
encoder_dict = get_hf_file_to_dict(config_name, model,
revision)
if encoder_dict:
break

if not encoder_dict:
return None

Expand Down
Loading