|
20 | 20 | from google.cloud import storage |
21 | 21 |
|
22 | 22 |
|
23 | | -def download_public_file(full_gs_url, destination_folder_name): |
| 23 | +def download_public_file( |
| 24 | + full_gs_url, destination_folder_name, single_file=False |
| 25 | +): |
24 | 26 | """Downloads a public blob from the bucket.""" |
25 | 27 | # bucket_name = "gs://your-bucket-name/path/to/file" |
26 | 28 | # destination_file_name = "local/path/to/file" |
27 | 29 |
|
28 | | - storage_client = storage.Client() |
| 30 | + storage_client = storage.Client.create_anonymous_client() |
29 | 31 | bucket_name = full_gs_url.split("/")[2] |
30 | | - source_blob_name = "/".join(full_gs_url.split("/")[3:]) |
| 32 | + source_blob_name = None |
| 33 | + dest_filename = None |
| 34 | + desired_file = None |
| 35 | + if single_file: |
| 36 | + |
| 37 | + desired_file = full_gs_url.split("/")[-1] |
| 38 | + source_blob_name = "/".join(full_gs_url.split("/")[3:-1]) |
| 39 | + destination_folder_name, dest_filename = os.path.split( |
| 40 | + destination_folder_name |
| 41 | + ) |
| 42 | + else: |
| 43 | + source_blob_name = "/".join(full_gs_url.split("/")[3:]) |
31 | 44 | bucket = storage_client.bucket(bucket_name) |
32 | 45 | blobs = bucket.list_blobs(prefix=source_blob_name) |
33 | 46 | if not os.path.exists(destination_folder_name): |
34 | 47 | os.mkdir(destination_folder_name) |
35 | 48 | for blob in blobs: |
36 | 49 | blob_name = blob.name.split("/")[-1] |
| 50 | + if single_file: |
| 51 | + if blob_name == desired_file: |
| 52 | + destination_filename = os.path.join( |
| 53 | + destination_folder_name, dest_filename |
| 54 | + ) |
| 55 | + blob.download_to_filename(destination_filename) |
| 56 | + else: |
| 57 | + continue |
| 58 | + |
37 | 59 | destination_filename = os.path.join(destination_folder_name, blob_name) |
38 | 60 | blob.download_to_filename(destination_filename) |
39 | 61 |
|
@@ -134,7 +156,9 @@ def download_model( |
134 | 156 | tank_url.rstrip("/") + "/" + model_dir_name + "/hash.npy" |
135 | 157 | ) |
136 | 158 | download_public_file( |
137 | | - gs_hash_url, os.path.join(model_dir, "upstream_hash.npy") |
| 159 | + gs_hash_url, |
| 160 | + os.path.join(model_dir, "upstream_hash.npy"), |
| 161 | + single_file=True, |
138 | 162 | ) |
139 | 163 | upstream_hash = str( |
140 | 164 | np.load(os.path.join(model_dir, "upstream_hash.npy")) |
|
0 commit comments