Skip to content

Commit 6bf991e

Browse files
authored
adding more robust main.py testing (huggingface#889)
Co-authored-by: dan <[email protected]>
1 parent 9644e78 commit 6bf991e

File tree

4 files changed

+93
-16
lines changed

4 files changed

+93
-16
lines changed

.github/workflows/test-models.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ jobs:
115115
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="${GITHUB_WORKSPACE}/shark_tmp/shark_cache" -k cuda
116116
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
117117
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
118-
sh build_tools/stable_diff_main_test.sh
118+
# Disabled due to black image bug
119+
# python build_tools/stable_diffusion_testing.py --device=cuda
119120
120121
- name: Validate Vulkan Models (MacOS)
121122
if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
@@ -135,3 +136,4 @@ jobs:
135136
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
136137
source shark.venv/bin/activate
137138
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="${GITHUB_WORKSPACE}/shark_tmp/shark_cache" -k vulkan
139+
python build_tools/stable_diffusion_testing.py --device=vulkan

build_tools/image_comparison.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import argparse
2-
import torchvision
2+
from PIL import Image
33
import numpy as np
44

55
import requests
@@ -22,20 +22,24 @@ def get_image(url, local_filename):
2222
if res.status_code == 200:
2323
with open(local_filename, "wb") as f:
2424
shutil.copyfileobj(res.raw, f)
25-
return torchvision.io.read_image(local_filename).numpy()
2625

2726

28-
if __name__ == "__main__":
29-
args = parser.parse_args()
30-
new = torchvision.io.read_image(args.newfile).numpy() / 255.0
31-
tempfile_name = os.path.join(os.getcwd(), "golden.png")
32-
golden = get_image(args.golden_url, tempfile_name) / 255.0
27+
def compare_images(new_filename, golden_filename):
28+
new = np.array(Image.open(new_filename)) / 255.0
29+
golden = np.array(Image.open(golden_filename)) / 255.0
3330
diff = np.abs(new - golden)
3431
mean = np.mean(diff)
35-
if not mean < 0.2:
32+
if mean > 0.01:
3633
subprocess.run(
37-
["gsutil", "cp", args.newfile, "gs://shark_tank/testdata/builder/"]
34+
["gsutil", "cp", new_filename, "gs://shark_tank/testdata/builder/"]
3835
)
3936
raise SystemExit("new and golden not close")
4037
else:
4138
print("SUCCESS")
39+
40+
41+
if __name__ == "__main__":
42+
args = parser.parse_args()
43+
tempfile_name = os.path.join(os.getcwd(), "golden.png")
44+
get_image(args.golden_url, tempfile_name)
45+
compare_images(args.newfile, tempfile_name)

build_tools/stable_diff_main_test.sh

Lines changed: 0 additions & 6 deletions
This file was deleted.
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import os
2+
import subprocess
3+
from shark.examples.shark_inference.stable_diffusion.resources import (
4+
get_json_file,
5+
)
6+
from shark.shark_downloader import download_public_file
7+
from image_comparison import compare_images
8+
import argparse
9+
from glob import glob
10+
import shutil
11+
12+
model_config_dicts = get_json_file(
13+
os.path.join(
14+
os.getcwd(),
15+
"shark/examples/shark_inference/stable_diffusion/resources/model_config.json",
16+
)
17+
)
18+
19+
20+
def test_loop(device="vulkan", beta=False, extra_flags=[]):
21+
# Get golden values from tank
22+
shutil.rmtree("./test_images", ignore_errors=True)
23+
os.mkdir("./test_images")
24+
os.mkdir("./test_images/golden")
25+
hf_model_names = model_config_dicts[0].values()
26+
tuned_options = ["--no-use_tuned"] #'use_tuned']
27+
devices = ["vulkan"]
28+
if beta:
29+
extra_flags.append("--beta_models=True")
30+
for model_name in hf_model_names:
31+
for use_tune in tuned_options:
32+
command = [
33+
"python",
34+
"shark/examples/shark_inference/stable_diffusion/main.py",
35+
"--device=" + device,
36+
"--output_dir=./test_images/" + model_name,
37+
"--hf_model_id=" + model_name,
38+
use_tune,
39+
]
40+
command += extra_flags
41+
generated_image = not subprocess.call(
42+
command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
43+
)
44+
if generated_image:
45+
os.makedirs(
46+
"./test_images/golden/" + model_name, exist_ok=True
47+
)
48+
download_public_file(
49+
"gs://shark_tank/testdata/golden/" + model_name,
50+
"./test_images/golden/" + model_name,
51+
)
52+
comparison = [
53+
"python",
54+
"build_tools/image_comparison.py",
55+
"--golden_url=gs://shark_tank/testdata/golden/"
56+
+ model_name
57+
+ "/*.png",
58+
"--newfile=./test_images/" + model_name + "/*.png",
59+
]
60+
test_file = glob("./test_images/" + model_name + "/*.png")[0]
61+
golden_path = "./test_images/golden/" + model_name + "/*.png"
62+
golden_file = glob(golden_path)[0]
63+
compare_images(test_file, golden_file)
64+
65+
66+
parser = argparse.ArgumentParser()
67+
68+
parser.add_argument("-d", "--device", default="vulkan")
69+
parser.add_argument(
70+
"-b", "--beta", action=argparse.BooleanOptionalAction, default=False
71+
)
72+
73+
74+
if __name__ == "__main__":
75+
args = parser.parse_args()
76+
print(args)
77+
test_loop(args.device, args.beta, [])

0 commit comments

Comments
 (0)