 import sys
 from pathlib import Path
 from shark.parser import shark_args
+from google.cloud import storage
 
 
-def resource_path(relative_path):
-    """Get absolute path to resource, works for dev and for PyInstaller"""
-    base_path = getattr(
-        sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
-    )
-    return os.path.join(base_path, relative_path)
+def download_public_file(full_gs_url, destination_file_name):
+    """Downloads a public blob from the bucket."""
+    # bucket_name = "gs://your-bucket-name/path/to/file"
+    # destination_file_name = "local/path/to/file"
+
+    storage_client = storage.Client.create_anonymous_client()
+    bucket_name = full_gs_url.split("/")[2]
+    source_blob_name = "/".join(full_gs_url.split("/")[3:])
+    bucket = storage_client.bucket(bucket_name)
+    blob = bucket.blob(source_blob_name)
+    blob.download_to_filename(destination_file_name)
 
 
-GSUTIL_PATH = resource_path("gsutil")
 
 GSUTIL_FLAGS = ' -o "GSUtil:parallel_process_count=1" -m cp -r '
 
 
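(Editor's note: a minimal sketch of how the download_public_file helper added above resolves a gs:// URL via the google-cloud-storage anonymous client; the URL below is a hypothetical example, not taken from the patch.)

# full_gs_url.split("/") on "gs://shark_tank/latest/some_model_torch/hash.npy"
# yields ["gs:", "", "shark_tank", "latest", "some_model_torch", "hash.npy"]:
# index 2 is the bucket, everything from index 3 onward is the blob path.
full_gs_url = "gs://shark_tank/latest/some_model_torch/hash.npy"  # example URL
bucket_name = full_gs_url.split("/")[2]                  # "shark_tank"
source_blob_name = "/".join(full_gs_url.split("/")[3:])  # "latest/some_model_torch/hash.npy"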
@@ -98,103 +103,23 @@ def check_dir_exists(model_name, frontend="torch", dynamic=""):
 
 
 # Downloads the torch model from gs://shark_tank dir.
-def download_torch_model(
-    model_name, dynamic=False, tank_url="gs://shark_tank/latest"
+def download_model(
+    model_name,
+    dynamic=False,
+    tank_url="gs://shark_tank/latest",
+    frontend=None,
+    tuned=None,
 ):
     model_name = model_name.replace("/", "_")
     dyn_str = "_dynamic" if dynamic else ""
     os.makedirs(WORKDIR, exist_ok=True)
-    model_dir_name = model_name + "_torch"
-
-    def gs_download_model():
-        gs_command = (
-            GSUTIL_PATH
-            + GSUTIL_FLAGS
-            + tank_url
-            + "/"
-            + model_dir_name
-            + ' "'
-            + WORKDIR
-            + '"'
-        )
-        if os.system(gs_command) != 0:
-            raise Exception("model not present in the tank. Contact Nod Admin")
-
-    if not check_dir_exists(model_dir_name, frontend="torch", dynamic=dyn_str):
-        gs_download_model()
-    else:
-        if not _internet_connected():
-            print(
-                "No internet connection. Using the model already present in the tank."
-            )
-        else:
-            model_dir = os.path.join(WORKDIR, model_dir_name)
-            local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
-            gs_hash = (
-                GSUTIL_PATH
-                + GSUTIL_FLAGS
-                + tank_url
-                + "/"
-                + model_dir_name
-                + "/hash.npy"
-                + " "
-                + os.path.join(model_dir, "upstream_hash.npy")
-            )
-            if os.system(gs_hash) != 0:
-                raise Exception("hash of the model not present in the tank.")
-            upstream_hash = str(
-                np.load(os.path.join(model_dir, "upstream_hash.npy"))
-            )
-            if local_hash != upstream_hash:
-                if shark_args.update_tank == True:
-                    gs_download_model()
-                else:
-                    print(
-                        "Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
-                    )
-
-    model_dir = os.path.join(WORKDIR, model_dir_name)
-    with open(
-        os.path.join(model_dir, model_name + dyn_str + "_torch.mlir"),
-        mode="rb",
-    ) as f:
-        mlir_file = f.read()
-
-    function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
-    inputs = np.load(os.path.join(model_dir, "inputs.npz"))
-    golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
-
-    inputs_tuple = tuple([inputs[key] for key in inputs])
-    golden_out_tuple = tuple([golden_out[key] for key in golden_out])
-    return mlir_file, function_name, inputs_tuple, golden_out_tuple
-
-
-# Downloads the tflite model from gs://shark_tank dir.
-def download_tflite_model(
-    model_name, dynamic=False, tank_url="gs://shark_tank/latest"
-):
-    dyn_str = "_dynamic" if dynamic else ""
-    os.makedirs(WORKDIR, exist_ok=True)
-    model_dir_name = model_name + "_tflite"
-
-    def gs_download_model():
-        gs_command = (
-            GSUTIL_PATH
-            + GSUTIL_FLAGS
-            + tank_url
-            + "/"
-            + model_dir_name
-            + ' "'
-            + WORKDIR
-            + '"'
-        )
-        if os.system(gs_command) != 0:
-            raise Exception("model not present in the tank. Contact Nod Admin")
+    model_dir_name = model_name + "_" + frontend
+    full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
 
     if not check_dir_exists(
-        model_dir_name, frontend="tflite", dynamic=dyn_str
+        model_dir_name, frontend=frontend, dynamic=dyn_str
     ):
-        gs_download_model()
+        download_public_file(full_gs_url, WORKDIR)
     else:
         if not _internet_connected():
             print(
@@ -203,104 +128,34 @@ def gs_download_model():
         else:
             model_dir = os.path.join(WORKDIR, model_dir_name)
             local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
-            gs_hash = (
-                GSUTIL_PATH
-                + GSUTIL_FLAGS
-                + tank_url
-                + "/"
-                + model_dir_name
-                + "/hash.npy"
-                + " "
-                + os.path.join(model_dir, "upstream_hash.npy")
+            gs_hash_url = (
+                tank_url.rstrip("/") + "/" + model_dir_name + "/hash.npy"
             )
-            if os.system(gs_hash) != 0:
-                raise Exception("hash of the model not present in the tank.")
-            upstream_hash = str(
-                np.load(os.path.join(model_dir, "upstream_hash.npy"))
+            download_public_file(
+                gs_hash_url, os.path.join(model_dir, "upstream_hash.npy")
             )
-            if local_hash != upstream_hash:
-                if shark_args.update_tank == True:
-                    gs_download_model()
-                else:
-                    print(
-                        "Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
-                    )
-
-    model_dir = os.path.join(WORKDIR, model_dir_name)
-    with open(
-        os.path.join(model_dir, model_name + dyn_str + "_tflite.mlir"),
-        mode="rb",
-    ) as f:
-        mlir_file = f.read()
-
-    function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
-    inputs = np.load(os.path.join(model_dir, "inputs.npz"))
-    golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
-
-    inputs_tuple = tuple([inputs[key] for key in inputs])
-    golden_out_tuple = tuple([golden_out[key] for key in golden_out])
-    return mlir_file, function_name, inputs_tuple, golden_out_tuple
-
-
-def download_tf_model(
-    model_name, tuned=None, tank_url="gs://shark_tank/latest"
-):
-    model_name = model_name.replace("/", "_")
-    os.makedirs(WORKDIR, exist_ok=True)
-    model_dir_name = model_name + "_tf"
-
-    def gs_download_model():
-        gs_command = (
-            GSUTIL_PATH
-            + GSUTIL_FLAGS
-            + tank_url
-            + "/"
-            + model_dir_name
-            + ' "'
-            + WORKDIR
-            + '"'
-        )
-        if os.system(gs_command) != 0:
-            raise Exception("model not present in the tank. Contact Nod Admin")
-
-    if not check_dir_exists(model_dir_name, frontend="tf"):
-        gs_download_model()
-    else:
-        if not _internet_connected():
-            print(
-                "No internet connection. Using the model already present in the tank."
-            )
-        else:
-            model_dir = os.path.join(WORKDIR, model_dir_name)
-            local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
-            gs_hash = (
-                GSUTIL_PATH
-                + GSUTIL_FLAGS
-                + tank_url
-                + "/"
-                + model_dir_name
-                + "/hash.npy"
-                + " "
-                + os.path.join(model_dir, "upstream_hash.npy")
-            )
-            if os.system(gs_hash) != 0:
-                raise Exception("hash of the model not present in the tank.")
             upstream_hash = str(
                 np.load(os.path.join(model_dir, "upstream_hash.npy"))
             )
             if local_hash != upstream_hash:
                 if shark_args.update_tank == True:
-                    gs_download_model()
+                    download_public_file(full_gs_url, WORKDIR)
                 else:
                     print(
                         "Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
                     )
 
     model_dir = os.path.join(WORKDIR, model_dir_name)
-    suffix = "_tf.mlir" if tuned is None else "_tf_" + tuned + ".mlir"
+    suffix = (
+        "_" + frontend + ".mlir"
+        if tuned is None
+        else "_" + frontend + "_" + tuned + ".mlir"
+    )
     filename = os.path.join(model_dir, model_name + suffix)
     if not os.path.isfile(filename):
-        filename = os.path.join(model_dir, model_name + "_tf.mlir")
+        filename = os.path.join(
+            model_dir, model_name + "_" + frontend + ".mlir"
+        )
 
     with open(filename, mode="rb") as f:
         mlir_file = f.read()
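(Editor's note: a hedged usage sketch of the consolidated download_model entry point introduced by this patch; the model name below is illustrative, and the unpacked return tuple is inferred from the per-frontend functions the patch removes, which returned the MLIR bytes, function name, inputs, and golden outputs.)

# Hypothetical call site; "resnet50" and the unpacking are examples only.
mlir_file, func_name, inputs, golden_out = download_model(
    "resnet50",
    dynamic=False,
    tank_url="gs://shark_tank/latest",
    frontend="torch",  # previously implied by download_torch_model
)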