Skip to content

Commit 0645384

Browse files
Wauplinrwightman
authored andcommitted
Review huggingface_hub integration
1 parent e7bd97b commit 0645384

File tree

2 files changed

+8
-15
lines changed

2 files changed

+8
-15
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
torch>=1.7
22
torchvision
33
pyyaml
4-
huggingface_hub
4+
huggingface_hub>=0.17.0
55
safetensors>=0.2
66
numpy

timm/models/_hub.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,7 @@
3030
from timm.models._pretrained import filter_pretrained_cfg
3131

3232
try:
33-
from huggingface_hub import (
34-
create_repo, get_hf_file_metadata,
35-
hf_hub_download, hf_hub_url,
36-
repo_type_and_id_from_hf_id, upload_folder)
33+
from huggingface_hub import HfApi, hf_hub_download
3734
from huggingface_hub.utils import EntryNotFoundError
3835
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
3936
_has_hf_hub = True
@@ -414,20 +411,16 @@ def push_to_hf_hub(
414411
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
415412
Can be set to `"both"` in order to push both safe and unsafe weights.
416413
"""
414+
api = HfApi(token=token, library_name="timm", library_version=__version__)
415+
417416
# Create repo if it doesn't exist yet
418-
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
417+
repo_url = api.create_repo(repo_id, private=private, exist_ok=True)
419418

420-
# Infer complete repo_id from repo_url
421419
# Can be different from the input `repo_id` if repo_owner was implicit
422-
_, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
423-
repo_id = f"{repo_owner}/{repo_name}"
420+
repo_id = repo_url.repo_id
424421

425422
# Check if README file already exist in repo
426-
try:
427-
get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
428-
has_readme = True
429-
except EntryNotFoundError:
430-
has_readme = False
423+
has_readme = api.file_exists(repo_id=repo_id, filename="README.md", revision=revision)
431424

432425
# Dump model and push to Hub
433426
with TemporaryDirectory() as tmpdir:
@@ -449,7 +442,7 @@ def push_to_hf_hub(
449442
readme_path.write_text(readme_text)
450443

451444
# Upload model and return
452-
return upload_folder(
445+
return api.upload_folder(
453446
repo_id=repo_id,
454447
folder_path=tmpdir,
455448
revision=revision,

0 commit comments

Comments
 (0)