Skip to content

Commit 18f447d

Browse files
authored
fix hash comparison (huggingface#563)
Co-authored-by: dan <[email protected]>
1 parent d7e1078 commit 18f447d

File tree

1 file changed

+28
-4
lines changed

1 file changed

+28
-4
lines changed

shark/shark_downloader.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,42 @@
2020
from google.cloud import storage
2121

2222

23-
def download_public_file(full_gs_url, destination_folder_name):
23+
def download_public_file(
24+
full_gs_url, destination_folder_name, single_file=False
25+
):
2426
"""Downloads a public blob from the bucket."""
2527
# bucket_name = "gs://your-bucket-name/path/to/file"
2628
# destination_file_name = "local/path/to/file"
2729

28-
storage_client = storage.Client()
30+
storage_client = storage.Client.create_anonymous_client()
2931
bucket_name = full_gs_url.split("/")[2]
30-
source_blob_name = "/".join(full_gs_url.split("/")[3:])
32+
source_blob_name = None
33+
dest_filename = None
34+
desired_file = None
35+
if single_file:
36+
37+
desired_file = full_gs_url.split("/")[-1]
38+
source_blob_name = "/".join(full_gs_url.split("/")[3:-1])
39+
destination_folder_name, dest_filename = os.path.split(
40+
destination_folder_name
41+
)
42+
else:
43+
source_blob_name = "/".join(full_gs_url.split("/")[3:])
3144
bucket = storage_client.bucket(bucket_name)
3245
blobs = bucket.list_blobs(prefix=source_blob_name)
3346
if not os.path.exists(destination_folder_name):
3447
os.mkdir(destination_folder_name)
3548
for blob in blobs:
3649
blob_name = blob.name.split("/")[-1]
50+
if single_file:
51+
if blob_name == desired_file:
52+
destination_filename = os.path.join(
53+
destination_folder_name, dest_filename
54+
)
55+
blob.download_to_filename(destination_filename)
56+
else:
57+
continue
58+
3759
destination_filename = os.path.join(destination_folder_name, blob_name)
3860
blob.download_to_filename(destination_filename)
3961

@@ -134,7 +156,9 @@ def download_model(
134156
tank_url.rstrip("/") + "/" + model_dir_name + "/hash.npy"
135157
)
136158
download_public_file(
137-
gs_hash_url, os.path.join(model_dir, "upstream_hash.npy")
159+
gs_hash_url,
160+
os.path.join(model_dir, "upstream_hash.npy"),
161+
single_file=True,
138162
)
139163
upstream_hash = str(
140164
np.load(os.path.join(model_dir, "upstream_hash.npy"))

0 commit comments

Comments
 (0)