Commit 7f37599

Eliasj42 and Elias Joseph authored

Added a dispatch benchmarking tool (huggingface#441)

To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your command-line arguments. Co-authored-by: Elias Joseph <[email protected]>

1 parent 77c9a2c commit 7f37599

File tree

6 files changed: +221 −2 lines changed

README.md

Lines changed: 27 additions & 1 deletion

@@ -121,6 +121,33 @@ pytest tank/test_models.py -k "MiniLM"
 <details>
 <summary>Testing and Benchmarks</summary>
 
+## Benchmarking Dispatches
+
+To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your command-line arguments.
+If you only want to compile specific dispatches, you can specify them with a space-separated string instead of `"All"`, e.g. `--dispatch_benchmarks="0 1 2 10"`.
+
+If you instead want to incorporate this into a Python script, you can pass the `dispatch_benchmark` and `dispatch_benchmark_dir` arguments when initializing `SharkInference`, and the benchmarks will be generated when the module is compiled, e.g.:
+
+```
+shark_module = SharkInference(
+    mlir_model,
+    func_name,
+    device=args.device,
+    mlir_dialect="tm_tensor",
+    dispatch_benchmark="all",
+    dispatch_benchmark_dir="results",
+)
+```
+
+Output will include:
+- a directory for each dispatch inside the specified directory (.mlir files are produced for all dispatches, but compiled binaries and benchmark data only for the specified ones)
+- an .mlir file containing the dispatch benchmark
+- a compiled .vmfb file containing the dispatch benchmark
+- an .mlir file containing just the HAL executable
+- a compiled .vmfb file of the HAL executable
+- a .txt file containing the benchmark output
+
 See tank/README.md for instructions on how to run model tests and benchmarks from the SHARK tank.
 
 </details>
@@ -175,7 +202,6 @@ result = shark_module.forward((arg0, arg1))
 ```
 </details>
 
-
 ## Supported and Validated Models
 
 SHARK is maintained to support the latest innovations in ML Models:
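To make the README's output list concrete, here is a sketch of the layout produced for a single benchmarked dispatch. The dispatch name `forward_dispatch_0` is hypothetical; the file-name patterns follow `create_dispatch_dirs` and `compile_benchmark_dirs` in `shark/iree_utils/compile_utils.py` below.

```
results/
└── forward_dispatch_0/
    ├── forward_dispatch_0.mlir            # HAL executable source
    ├── forward_dispatch_0_benchmark.mlir  # dispatch benchmark module
    ├── forward_dispatch_0_benchmark.vmfb  # compiled dispatch benchmark
    ├── forward_dispatch_0_benchmark.sh    # reproducer command line
    ├── forward_dispatch_0_spirv.vmfb      # compiled HAL executable
    └── forward_dispatch_0_data.txt        # benchmark output
```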

shark/examples/shark_inference/resnet50_script.py

Lines changed: 1 addition & 1 deletion

@@ -69,7 +69,7 @@ def forward(self, img):
 mlir_model, func_name, inputs, golden_out = download_torch_model("resnet50")
 
 shark_module = SharkInference(mlir_model, func_name, mlir_dialect="linalg")
-# shark_module.compile()
+shark_module.compile()
 path = shark_module.save_module()
 shark_module.load_module(path)
 result = shark_module.forward((img.detach().numpy(),))
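For reference, a minimal sketch of how this example combines with the new dispatch-benchmark options. The keyword names follow the `SharkInference` changes below, `"results"` is an arbitrary output directory, and the names reused here (`download_torch_model`, `img`) come from the script itself.

```
mlir_model, func_name, inputs, golden_out = download_torch_model("resnet50")

# Sketch: enable per-dispatch benchmarking; artifacts land in ./results.
shark_module = SharkInference(
    mlir_model,
    func_name,
    mlir_dialect="linalg",
    dispatch_benchmark="all",
    dispatch_benchmark_dir="results",
)
shark_module.compile()  # dispatch benchmarks are generated during compilation
result = shark_module.forward((img.detach().numpy(),))
```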

shark/iree_utils/benchmark_utils.py

Lines changed: 25 additions & 0 deletions

@@ -78,6 +78,31 @@ def build_benchmark_args(
     return benchmark_cl
 
 
+def build_benchmark_args_non_tensor_input(
+    input_file: str,
+    device: str,
+    inputs: tuple,
+    mlir_dialect: str,
+    function_name: str,
+):
+    """
+    Inputs: input_file leading to the vmfb, the target device, scalar inputs
+    to the function, the mlir dialect, and the entry function name.
+    Outputs: command list that executes iree-benchmark-module on the target module.
+    """
+    path = benchmark_module.__path__[0]
+    benchmarker_path = os.path.join(path, "..", "..", "iree-benchmark-module")
+    benchmark_cl = [benchmarker_path, f"--module_file={input_file}"]
+    # TODO: The function name can be passed as one of the args.
+    benchmark_cl.append(f"--entry_function={function_name}")
+    benchmark_cl.append(f"--device={IREE_DEVICE_MAP[device]}")
+    for input in inputs:
+        benchmark_cl.append(f"--function_input={input}")
+    time_extractor = "| awk 'END{{print $2 $3}}'"
+    benchmark_cl.append(time_extractor)
+    return benchmark_cl
+
+
 def run_benchmark_module(benchmark_cl):
     """
     Run benchmark command, extract result and return iteration/seconds.
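A hedged usage sketch of the new helper: the module path and entry-function name are hypothetical, and the exact `IREE_DEVICE_MAP` value for `"vulkan"` is assumed.

```
from shark.iree_utils.benchmark_utils import build_benchmark_args_non_tensor_input

benchmark_cl = build_benchmark_args_non_tensor_input(
    input_file="results/forward_dispatch_0/forward_dispatch_0_benchmark.vmfb",  # hypothetical
    device="vulkan",
    inputs=(0,),  # scalar inputs, forwarded as --function_input flags
    mlir_dialect="linalg",
    function_name="forward_dispatch_0_benchmark",  # hypothetical entry point
)
# benchmark_cl is roughly:
# ['.../iree-benchmark-module',
#  '--module_file=results/forward_dispatch_0/forward_dispatch_0_benchmark.vmfb',
#  '--entry_function=forward_dispatch_0_benchmark',
#  '--device=vulkan',  # assuming IREE_DEVICE_MAP maps "vulkan" to itself
#  '--function_input=0',
#  "| awk 'END{{print $2 $3}}'"]
```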

shark/iree_utils/compile_utils.py

Lines changed: 121 additions & 0 deletions

@@ -14,8 +14,10 @@
 import iree.runtime as ireert
 import iree.compiler as ireec
 from shark.iree_utils._common import IREE_DEVICE_MAP, IREE_TARGET_MAP
+from shark.iree_utils.benchmark_utils import *
 import numpy as np
 import os
+import re
 
 
 # Get the iree-compile arguments given device.
 def get_iree_device_args(device, extra_args=[]):
@@ -62,6 +64,125 @@ def get_iree_common_args():
     ]
 
 
+def create_dispatch_dirs(bench_dir, device):
+    bench_dir_path = bench_dir.split("/")
+    bench_dir_path[-1] = "temp_" + bench_dir_path[-1]
+    tmp_bench_dir = "/".join(bench_dir_path)
+    # Give each dumped dispatch source its own directory.
+    for f_ in os.listdir(bench_dir):
+        if os.path.isfile(f"{bench_dir}/{f_}"):
+            dir_name = re.sub(r"\.\S*$", "", f_)
+            if os.path.exists(f"{bench_dir}/{dir_name}"):
+                os.system(f"rm -rf {bench_dir}/{dir_name}")
+            os.system(f"mkdir {bench_dir}/{dir_name}")
+            os.system(f"mv {bench_dir}/{f_} {bench_dir}/{dir_name}/{f_}")
+    # Pair each dumped benchmark module with its dispatch directory.
+    for f_ in os.listdir(tmp_bench_dir):
+        if os.path.isfile(f"{tmp_bench_dir}/{f_}"):
+            dir_name = ""
+            for d_ in os.listdir(bench_dir):
+                if re.search(rf"{d_}(?=\D)", f_):
+                    dir_name = d_
+            if dir_name != "":
+                os.system(
+                    f"mv {tmp_bench_dir}/{f_} {bench_dir}/{dir_name}/{dir_name}_benchmark.mlir"
+                )
+
+
+def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
+    dispatch_list = []
+    all_dispatches = False
+
+    if dispatch_benchmarks.lower().strip() == "all":
+        all_dispatches = True
+    else:
+        try:
+            dispatch_list = [
+                int(dispatch_index)
+                for dispatch_index in dispatch_benchmarks.split(" ")
+            ]
+        except ValueError:
+            print("ERROR: Invalid dispatch benchmarks")
+            return None
+    for d_ in os.listdir(bench_dir):
+        in_dispatches = False
+        for dispatch in dispatch_list:
+            if str(dispatch) in d_:
+                in_dispatches = True
+        if all_dispatches or in_dispatches:
+            for f_ in os.listdir(f"{bench_dir}/{d_}"):
+                if "benchmark.mlir" in f_:
+                    # Compile the dumped benchmark module, save the vmfb and a
+                    # reproducer script, then run and record the benchmark.
+                    dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")
+                    module = dispatch_file.read()
+                    dispatch_file.close()
+
+                    flatbuffer_blob = ireec.compile_str(
+                        module, target_backends=[IREE_TARGET_MAP[device]]
+                    )
+
+                    vmfb_file = open(
+                        f"{bench_dir}/{d_}/{d_}_benchmark.vmfb", "wb"
+                    )
+                    vmfb_file.write(flatbuffer_blob)
+                    vmfb_file.close()
+
+                    config = ireert.Config(IREE_DEVICE_MAP[device])
+                    vm_module = ireert.VmModule.from_flatbuffer(
+                        config.vm_instance, flatbuffer_blob
+                    )
+
+                    benchmark_cl = build_benchmark_args_non_tensor_input(
+                        input_file=f"{bench_dir}/{d_}/{d_}_benchmark.vmfb",
+                        device=device,
+                        inputs=(0,),
+                        mlir_dialect="linalg",
+                        function_name=vm_module.function_names[0],
+                    )
+
+                    benchmark_bash = open(
+                        f"{bench_dir}/{d_}/{d_}_benchmark.sh", "w+"
+                    )
+                    benchmark_bash.write("#!/bin/bash\n")
+                    benchmark_bash.write(" ".join(benchmark_cl))
+                    benchmark_bash.close()
+
+                    benchmark_data = run_benchmark_module(benchmark_cl)
+
+                    benchmark_file = open(
+                        f"{bench_dir}/{d_}/{d_}_data.txt", "w+"
+                    )
+                    benchmark_file.write(f"DISPATCH: {d_}\n")
+                    benchmark_file.write(str(benchmark_data) + "\n")
+                    benchmark_file.write(
+                        "SHARK BENCHMARK RESULT: "
+                        + str(1 / (benchmark_data * 0.001))
+                        + "\n"
+                    )
+                    benchmark_file.close()
+
+                elif ".mlir" in f_ and "benchmark" not in f_:
+                    # Compile just the HAL executable for inspection.
+                    dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")
+                    module = dispatch_file.read()
+                    dispatch_file.close()
+
+                    module = re.sub(
+                        "hal.executable private",
+                        "hal.executable public",
+                        module,
+                    )
+
+                    flatbuffer_blob = ireec.compile_str(
+                        module,
+                        target_backends=[IREE_TARGET_MAP[device]],
+                        extra_args=["--compile-mode=hal-executable"],
+                    )
+
+                    spirv_file = open(
+                        f"{bench_dir}/{d_}/{d_}_spirv.vmfb", "wb"
+                    )
+                    spirv_file.write(flatbuffer_blob)
+                    spirv_file.close()
+
+
 def compile_module_to_flatbuffer(
     module, device, frontend, func_name, model_config_path, extra_args
 ):
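Taken together, a driver sketch of the two helpers (the directory name and device are placeholders; `SharkInference.compile()` below performs exactly these steps after IREE dumps the executable sources and benchmarks):

```
from shark.iree_utils.compile_utils import (
    create_dispatch_dirs,
    compile_benchmark_dirs,
)

# Sketch: assumes the module was compiled with
#   --iree-hal-dump-executable-sources-to=results
#   --iree-hal-dump-executable-benchmarks-to=temp_results
create_dispatch_dirs("results", "vulkan")           # one directory per dispatch
compile_benchmark_dirs("results", "vulkan", "0 1")  # benchmark dispatches 0 and 1
```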

shark/parser.py

Lines changed: 12 additions & 0 deletions

@@ -93,4 +93,16 @@ def dir_file(path):
     help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
 )
 
+parser.add_argument(
+    "--dispatch_benchmarks",
+    default=None,
+    help='Dispatches to return benchmark data on. Use "All" for all, and None for none.',
+)
+
+parser.add_argument(
+    "--dispatch_benchmarks_dir",
+    default="temp_dispatch_benchmarks",
+    help='Directory in which to store dispatch data generated with "--dispatch_benchmarks".',
+)
+
 shark_args, unknown = parser.parse_known_args()
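A short sketch of how these flags reach a script through `shark_args` (the script name is hypothetical):

```
# Invoked as, e.g.:
#   python my_script.py --dispatch_benchmarks="0 1" --dispatch_benchmarks_dir=results
from shark.parser import shark_args

print(shark_args.dispatch_benchmarks)      # "0 1" (None when the flag is absent)
print(shark_args.dispatch_benchmarks_dir)  # "results"
```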

shark/shark_inference.py

Lines changed: 35 additions & 0 deletions

@@ -12,6 +12,8 @@
 from shark.iree_utils.compile_utils import (
     export_iree_module_to_vmfb,
     load_flatbuffer,
+    create_dispatch_dirs,
+    compile_benchmark_dirs,
 )
 import os
 from shark.shark_runner import SharkRunner
@@ -68,17 +70,41 @@ def __init__(
         device: str = "none",
         mlir_dialect: str = "linalg",
         is_benchmark: bool = False,
+        dispatch_benchmark: str = None,
+        dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
     ):
         self.mlir_module = mlir_module
         self.function_name = function_name
         self.device = shark_args.device if device == "none" else device
         self.mlir_dialect = mlir_dialect
         self.is_benchmark = is_benchmark
+        # Fall back to the command-line flags when no constructor argument is given.
+        self.dispatch_benchmarks = (
+            shark_args.dispatch_benchmarks
+            if dispatch_benchmark is None
+            else dispatch_benchmark
+        )
+        self.dispatch_benchmarks_dir = (
+            shark_args.dispatch_benchmarks_dir
+            if dispatch_benchmark_dir == "temp_dispatch_benchmarks"
+            else dispatch_benchmark_dir
+        )
 
         self.shark_runner = None
 
     def compile(self, extra_args=[]):
 
+        if self.dispatch_benchmarks is not None:
+            # Ask IREE to dump per-dispatch sources and benchmark modules.
+            extra_args.append(
+                f"--iree-hal-dump-executable-sources-to={self.dispatch_benchmarks_dir}"
+            )
+            temp_dir = self.dispatch_benchmarks_dir.split("/")
+            temp_dir[-1] = "temp_" + temp_dir[-1]
+            temp_dir = "/".join(temp_dir)
+            self.temp_dispatch_benchmarks_dir = temp_dir
+            extra_args.append(
+                f"--iree-hal-dump-executable-benchmarks-to={self.temp_dispatch_benchmarks_dir}"
+            )
+
         if self.is_benchmark == True:
             from shark.shark_benchmark_runner import SharkBenchmarkRunner
 
@@ -99,6 +125,15 @@ def compile(self, extra_args=[]):
             extra_args=extra_args,
         )
 
+        if self.dispatch_benchmarks is not None:
+            # Sort the dumped artifacts, benchmark the requested dispatches,
+            # and clean up the temporary benchmark dump directory.
+            create_dispatch_dirs(self.dispatch_benchmarks_dir, self.device)
+            compile_benchmark_dirs(
+                self.dispatch_benchmarks_dir,
+                self.device,
+                self.dispatch_benchmarks,
+            )
+            os.system(f"rm -rf {self.temp_dispatch_benchmarks_dir}")
+
     # inputs are considered to be tuple of np.array.
     def forward(self, inputs: tuple):
         return self.shark_runner.run(inputs)
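For reference, with `dispatch_benchmark="all"` and `dispatch_benchmark_dir="results"`, `compile()` extends `extra_args` with roughly the following before handing off to IREE (a sketch derived from the code above):

```
# Flags appended by compile() when dispatch benchmarking is enabled (sketch):
extra_args = [
    "--iree-hal-dump-executable-sources-to=results",
    "--iree-hal-dump-executable-benchmarks-to=temp_results",
]
```

Note that `extra_args` defaults to a shared mutable list, so calling `compile()` repeatedly with the default argument will accumulate these flags across calls.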
