Skip to content

Commit 8859853

Browse files
authored
Revert "Revert "find gsutil on linux (huggingface#557)" (huggingface#560)" (huggingface#561)
This reverts commit 3c46021.
1 parent 3c46021 commit 8859853

File tree

20 files changed

+98
-249
lines changed

20 files changed

+98
-249
lines changed

cpp/save_img.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import numpy as np
22
import tensorflow as tf
33
from shark.shark_inference import SharkInference
4-
from shark.shark_downloader import download_tf_model
54

65

76
def load_and_preprocess_image(fname: str):

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ pyinstaller
66
tqdm
77

88
# SHARK Downloader
9-
gsutil
9+
google-cloud-storage
1010

1111
# Testing
1212
pytest

shark/examples/shark_inference/bloom_tank.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from shark.shark_inference import SharkInference
2-
from shark.shark_downloader import download_torch_model
2+
from shark.shark_downloader import download_model
33

4-
mlir_model, func_name, inputs, golden_out = download_torch_model("bloom")
4+
mlir_model, func_name, inputs, golden_out = download_model(
5+
"bloom", frontend="torch"
6+
)
57

68
shark_module = SharkInference(
79
mlir_model, func_name, device="cpu", mlir_dialect="tm_tensor"

shark/examples/shark_inference/minilm_jit.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from shark.shark_inference import SharkInference
2-
from shark.shark_downloader import download_torch_model
2+
from shark.shark_downloader import download_model
33

44

5-
mlir_model, func_name, inputs, golden_out = download_torch_model(
6-
"microsoft/MiniLM-L12-H384-uncased"
5+
mlir_model, func_name, inputs, golden_out = download_model(
6+
"microsoft/MiniLM-L12-H384-uncased",
7+
frontend="torch",
78
)
89

910

shark/examples/shark_inference/resnet50_script.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torchvision import transforms
66
import sys
77
from shark.shark_inference import SharkInference
8-
from shark.shark_downloader import download_torch_model
8+
from shark.shark_downloader import download_model
99

1010

1111
################################## Preprocessing inputs and model ############
@@ -66,7 +66,9 @@ def forward(self, img):
6666

6767

6868
## Can pass any img or input to the forward module.
69-
mlir_model, func_name, inputs, golden_out = download_torch_model("resnet50")
69+
mlir_model, func_name, inputs, golden_out = download_model(
70+
"resnet50", frontend="torch"
71+
)
7072

7173
shark_module = SharkInference(mlir_model, func_name, mlir_dialect="linalg")
7274
shark_module.compile()

shark/examples/shark_inference/stable_diff_f16.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@
3737

3838

3939
def fp16_unet():
40-
from shark.shark_downloader import download_torch_model
40+
from shark.shark_downloader import download_model
4141

42-
mlir_model, func_name, inputs, golden_out = download_torch_model(
43-
"stable_diff_f16_18_OCT", tank_url="gs://shark_tank/prashant_nod"
42+
mlir_model, func_name, inputs, golden_out = download_model(
43+
"stable_diff_f16_18_OCT",
44+
tank_url="gs://shark_tank/prashant_nod",
45+
frontend="torch",
4446
)
4547
shark_module = SharkInference(
4648
mlir_model, func_name, device=args.device, mlir_dialect="linalg"

shark/examples/shark_inference/stable_diff_tf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818

1919
from shark.shark_inference import SharkInference
20-
from shark.shark_downloader import download_tf_model
20+
from shark.shark_downloader import download_model
2121
from PIL import Image
2222

2323
# pip install "git+https://github.com/keras-team/keras-cv.git"
@@ -75,8 +75,8 @@ def __init__(self, device="cpu", jit_compile=True):
7575
# Create models
7676
self.text_encoder = TextEncoder(MAX_PROMPT_LENGTH)
7777

78-
mlir_model, func_name, inputs, golden_out = download_tf_model(
79-
"stable_diff", tank_url="gs://shark_tank/quinn"
78+
mlir_model, func_name, inputs, golden_out = download_model(
79+
"stable_diff", tank_url="gs://shark_tank/quinn", frontend="tf"
8080
)
8181
shark_module = SharkInference(
8282
mlir_model, func_name, device=device, mlir_dialect="mhlo"

shark/examples/shark_inference/stable_diffusion/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,12 @@ def _compile_module(shark_module, model_name, extra_args=[]):
3939

4040
# Downloads the model from shark_tank and returns the shark_module.
4141
def get_shark_model(tank_url, model_name, extra_args=[]):
42-
from shark.shark_downloader import download_torch_model
42+
from shark.shark_downloader import download_model
4343

44-
mlir_model, func_name, inputs, golden_out = download_torch_model(
45-
model_name, tank_url=tank_url
44+
mlir_model, func_name, inputs, golden_out = download_model(
45+
model_name,
46+
tank_url=tank_url,
47+
frontend="torch",
4648
)
4749
shark_module = SharkInference(
4850
mlir_model, func_name, device=args.device, mlir_dialect="linalg"

shark/examples/shark_inference/v_diffusion.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from shark.shark_inference import SharkInference
2-
from shark.shark_downloader import download_torch_model
2+
from shark.shark_downloader import download_model
33

44

5-
mlir_model, func_name, inputs, golden_out = download_torch_model("v_diffusion")
5+
mlir_model, func_name, inputs, golden_out = download_model(
6+
"v_diffusion", frontend="torch"
7+
)
68

79
shark_module = SharkInference(
810
mlir_model, func_name, device="vulkan", mlir_dialect="linalg"

shark/shark_downloader.py

Lines changed: 35 additions & 180 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,22 @@
1717
import sys
1818
from pathlib import Path
1919
from shark.parser import shark_args
20+
from google.cloud import storage
2021

2122

22-
def resource_path(relative_path):
23-
"""Get absolute path to resource, works for dev and for PyInstaller"""
24-
base_path = getattr(
25-
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
26-
)
27-
return os.path.join(base_path, relative_path)
23+
def download_public_file(full_gs_url, destination_file_name):
24+
"""Downloads a public blob from the bucket."""
25+
# bucket_name = "gs://your-bucket-name/path/to/file"
26+
# destination_file_name = "local/path/to/file"
27+
28+
storage_client = storage.Client.create_anonymous_client()
29+
bucket_name = full_gs_url.split("/")[2]
30+
source_blob_name = "/".join(full_gs_url.split("/")[3:])
31+
bucket = storage_client.bucket(bucket_name)
32+
blob = bucket.blob(source_blob_name)
33+
blob.download_to_filename(destination_file_name)
2834

2935

30-
GSUTIL_PATH = resource_path("gsutil")
3136
GSUTIL_FLAGS = ' -o "GSUtil:parallel_process_count=1" -m cp -r '
3237

3338

@@ -98,103 +103,23 @@ def check_dir_exists(model_name, frontend="torch", dynamic=""):
98103

99104

100105
# Downloads the torch model from gs://shark_tank dir.
101-
def download_torch_model(
102-
model_name, dynamic=False, tank_url="gs://shark_tank/latest"
106+
def download_model(
107+
model_name,
108+
dynamic=False,
109+
tank_url="gs://shark_tank/latest",
110+
frontend=None,
111+
tuned=None,
103112
):
104113
model_name = model_name.replace("/", "_")
105114
dyn_str = "_dynamic" if dynamic else ""
106115
os.makedirs(WORKDIR, exist_ok=True)
107-
model_dir_name = model_name + "_torch"
108-
109-
def gs_download_model():
110-
gs_command = (
111-
GSUTIL_PATH
112-
+ GSUTIL_FLAGS
113-
+ tank_url
114-
+ "/"
115-
+ model_dir_name
116-
+ ' "'
117-
+ WORKDIR
118-
+ '"'
119-
)
120-
if os.system(gs_command) != 0:
121-
raise Exception("model not present in the tank. Contact Nod Admin")
122-
123-
if not check_dir_exists(model_dir_name, frontend="torch", dynamic=dyn_str):
124-
gs_download_model()
125-
else:
126-
if not _internet_connected():
127-
print(
128-
"No internet connection. Using the model already present in the tank."
129-
)
130-
else:
131-
model_dir = os.path.join(WORKDIR, model_dir_name)
132-
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
133-
gs_hash = (
134-
GSUTIL_PATH
135-
+ GSUTIL_FLAGS
136-
+ tank_url
137-
+ "/"
138-
+ model_dir_name
139-
+ "/hash.npy"
140-
+ " "
141-
+ os.path.join(model_dir, "upstream_hash.npy")
142-
)
143-
if os.system(gs_hash) != 0:
144-
raise Exception("hash of the model not present in the tank.")
145-
upstream_hash = str(
146-
np.load(os.path.join(model_dir, "upstream_hash.npy"))
147-
)
148-
if local_hash != upstream_hash:
149-
if shark_args.update_tank == True:
150-
gs_download_model()
151-
else:
152-
print(
153-
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
154-
)
155-
156-
model_dir = os.path.join(WORKDIR, model_dir_name)
157-
with open(
158-
os.path.join(model_dir, model_name + dyn_str + "_torch.mlir"),
159-
mode="rb",
160-
) as f:
161-
mlir_file = f.read()
162-
163-
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
164-
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
165-
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
166-
167-
inputs_tuple = tuple([inputs[key] for key in inputs])
168-
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
169-
return mlir_file, function_name, inputs_tuple, golden_out_tuple
170-
171-
172-
# Downloads the tflite model from gs://shark_tank dir.
173-
def download_tflite_model(
174-
model_name, dynamic=False, tank_url="gs://shark_tank/latest"
175-
):
176-
dyn_str = "_dynamic" if dynamic else ""
177-
os.makedirs(WORKDIR, exist_ok=True)
178-
model_dir_name = model_name + "_tflite"
179-
180-
def gs_download_model():
181-
gs_command = (
182-
GSUTIL_PATH
183-
+ GSUTIL_FLAGS
184-
+ tank_url
185-
+ "/"
186-
+ model_dir_name
187-
+ ' "'
188-
+ WORKDIR
189-
+ '"'
190-
)
191-
if os.system(gs_command) != 0:
192-
raise Exception("model not present in the tank. Contact Nod Admin")
116+
model_dir_name = model_name + "_" + frontend
117+
full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
193118

194119
if not check_dir_exists(
195-
model_dir_name, frontend="tflite", dynamic=dyn_str
120+
model_dir_name, frontend=frontend, dynamic=dyn_str
196121
):
197-
gs_download_model()
122+
download_public_file(full_gs_url, WORKDIR)
198123
else:
199124
if not _internet_connected():
200125
print(
@@ -203,104 +128,34 @@ def gs_download_model():
203128
else:
204129
model_dir = os.path.join(WORKDIR, model_dir_name)
205130
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
206-
gs_hash = (
207-
GSUTIL_PATH
208-
+ GSUTIL_FLAGS
209-
+ tank_url
210-
+ "/"
211-
+ model_dir_name
212-
+ "/hash.npy"
213-
+ " "
214-
+ os.path.join(model_dir, "upstream_hash.npy")
131+
gs_hash_url = (
132+
tank_url.rstrip("/") + "/" + model_dir_name + "/hash.npy"
215133
)
216-
if os.system(gs_hash) != 0:
217-
raise Exception("hash of the model not present in the tank.")
218-
upstream_hash = str(
219-
np.load(os.path.join(model_dir, "upstream_hash.npy"))
134+
download_public_file(
135+
gs_hash_url, os.path.join(model_dir, "upstream_hash.npy")
220136
)
221-
if local_hash != upstream_hash:
222-
if shark_args.update_tank == True:
223-
gs_download_model()
224-
else:
225-
print(
226-
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
227-
)
228-
229-
model_dir = os.path.join(WORKDIR, model_dir_name)
230-
with open(
231-
os.path.join(model_dir, model_name + dyn_str + "_tflite.mlir"),
232-
mode="rb",
233-
) as f:
234-
mlir_file = f.read()
235-
236-
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
237-
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
238-
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
239-
240-
inputs_tuple = tuple([inputs[key] for key in inputs])
241-
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
242-
return mlir_file, function_name, inputs_tuple, golden_out_tuple
243-
244-
245-
def download_tf_model(
246-
model_name, tuned=None, tank_url="gs://shark_tank/latest"
247-
):
248-
model_name = model_name.replace("/", "_")
249-
os.makedirs(WORKDIR, exist_ok=True)
250-
model_dir_name = model_name + "_tf"
251-
252-
def gs_download_model():
253-
gs_command = (
254-
GSUTIL_PATH
255-
+ GSUTIL_FLAGS
256-
+ tank_url
257-
+ "/"
258-
+ model_dir_name
259-
+ ' "'
260-
+ WORKDIR
261-
+ '"'
262-
)
263-
if os.system(gs_command) != 0:
264-
raise Exception("model not present in the tank. Contact Nod Admin")
265-
266-
if not check_dir_exists(model_dir_name, frontend="tf"):
267-
gs_download_model()
268-
else:
269-
if not _internet_connected():
270-
print(
271-
"No internet connection. Using the model already present in the tank."
272-
)
273-
else:
274-
model_dir = os.path.join(WORKDIR, model_dir_name)
275-
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
276-
gs_hash = (
277-
GSUTIL_PATH
278-
+ GSUTIL_FLAGS
279-
+ tank_url
280-
+ "/"
281-
+ model_dir_name
282-
+ "/hash.npy"
283-
+ " "
284-
+ os.path.join(model_dir, "upstream_hash.npy")
285-
)
286-
if os.system(gs_hash) != 0:
287-
raise Exception("hash of the model not present in the tank.")
288137
upstream_hash = str(
289138
np.load(os.path.join(model_dir, "upstream_hash.npy"))
290139
)
291140
if local_hash != upstream_hash:
292141
if shark_args.update_tank == True:
293-
gs_download_model()
142+
download_public_file(full_gs_url, WORKDIR)
294143
else:
295144
print(
296145
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
297146
)
298147

299148
model_dir = os.path.join(WORKDIR, model_dir_name)
300-
suffix = "_tf.mlir" if tuned is None else "_tf_" + tuned + ".mlir"
149+
suffix = (
150+
"_" + frontend + ".mlir"
151+
if tuned is None
152+
else "_" + frontend + "_" + tuned + ".mlir"
153+
)
301154
filename = os.path.join(model_dir, model_name + suffix)
302155
if not os.path.isfile(filename):
303-
filename = os.path.join(model_dir, model_name + "_tf.mlir")
156+
filename = os.path.join(
157+
model_dir, model_name + "_" + frontend + ".mlir"
158+
)
304159

305160
with open(filename, mode="rb") as f:
306161
mlir_file = f.read()

0 commit comments

Comments
 (0)