Skip to content

Commit 32d3f4b

Browse files
Eliasj42Elias Joseph
andauthored
added ordered benchmarks to dispatch benchmarking tool (huggingface#450)
* added ordered benchmarks to dispatch benchmarking tool * saved changes * updated readme Co-authored-by: Elias Joseph <[email protected]>
1 parent 18689af commit 32d3f4b

File tree

2 files changed

+96
-79
lines changed

2 files changed

+96
-79
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ shark_module = SharkInference(
154154
```
155155

156156
Output will include:
157+
- An ordered list ordered-dispatches.txt of all the dispatches with their runtime
157158
- Inside the specified directory, there will be a directory for each dispatch (there will be mlir files for all dispatches, but only compiled binaries and benchmark data for the specified dispatches)
158159
- An .mlir file containing the dispatch benchmark
159160
- A compiled .vmfb file containing the dispatch benchmark

shark/iree_utils/compile_utils.py

Lines changed: 95 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,12 @@ def get_iree_common_args():
6565

6666

6767
def create_dispatch_dirs(bench_dir, device):
68+
protected_files = ["ordered-dispatches.txt"]
6869
bench_dir_path = bench_dir.split("/")
6970
bench_dir_path[-1] = "temp_" + bench_dir_path[-1]
7071
tmp_bench_dir = "/".join(bench_dir_path)
7172
for f_ in os.listdir(bench_dir):
72-
if os.path.isfile(f"{bench_dir}/{f_}"):
73+
if os.path.isfile(f"{bench_dir}/{f_}") and f_ not in protected_files:
7374
dir_name = re.sub("\.\S*$", "", f_)
7475
if os.path.exists(f"{bench_dir}/{dir_name}"):
7576
os.system(f"rm -rf {bench_dir}/{dir_name}")
@@ -88,6 +89,7 @@ def create_dispatch_dirs(bench_dir, device):
8889

8990

9091
def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
92+
benchmark_runtimes = {}
9193
dispatch_list = []
9294
all_dispatches = False
9395

@@ -103,84 +105,98 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
103105
print("ERROR: Invalid dispatch benchmarks")
104106
return None
105107
for d_ in os.listdir(bench_dir):
106-
in_dispatches = False
107-
for dispatch in dispatch_list:
108-
if str(dispatch) in d_:
109-
in_dispatches = True
110-
if all_dispatches or in_dispatches:
111-
for f_ in os.listdir(f"{bench_dir}/{d_}"):
112-
113-
if "benchmark.mlir" in f_:
114-
dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")
115-
module = dispatch_file.read()
116-
dispatch_file.close()
117-
118-
flatbuffer_blob = ireec.compile_str(
119-
module, target_backends=[IREE_TARGET_MAP[device]]
120-
)
121-
122-
vmfb_file = open(
123-
f"{bench_dir}/{d_}/{d_}_benchmark.vmfb", "wb"
124-
)
125-
vmfb_file.write(flatbuffer_blob)
126-
vmfb_file.close()
127-
128-
config = ireert.Config(IREE_DEVICE_MAP[device])
129-
vm_module = ireert.VmModule.from_flatbuffer(
130-
config.vm_instance, flatbuffer_blob
131-
)
132-
133-
benchmark_cl = build_benchmark_args_non_tensor_input(
134-
input_file=f"{bench_dir}/{d_}/{d_}_benchmark.vmfb",
135-
device=device,
136-
inputs=(0,),
137-
mlir_dialect="linalg",
138-
function_name=vm_module.function_names[0],
139-
)
140-
141-
benchmark_bash = open(
142-
f"{bench_dir}/{d_}/{d_}_benchmark.sh", "w+"
143-
)
144-
benchmark_bash.write("#!/bin/bash\n")
145-
benchmark_bash.write(" ".join(benchmark_cl))
146-
benchmark_bash.close()
147-
148-
benchmark_data = run_benchmark_module(benchmark_cl)
149-
150-
benchmark_file = open(
151-
f"{bench_dir}/{d_}/{d_}_data.txt", "w+"
152-
)
153-
benchmark_file.write(f"DISPATCH: {d_}\n")
154-
benchmark_file.write(str(benchmark_data) + "\n")
155-
benchmark_file.write(
156-
"SHARK BENCHMARK RESULT: "
157-
+ str(1 / (benchmark_data * 0.001))
158-
+ "\n"
159-
)
160-
benchmark_file.close()
161-
162-
elif ".mlir" in f_ and "benchmark" not in f_:
163-
dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")
164-
module = dispatch_file.read()
165-
dispatch_file.close()
166-
167-
module = re.sub(
168-
"hal.executable private",
169-
"hal.executable public",
170-
module,
171-
)
172-
173-
flatbuffer_blob = ireec.compile_str(
174-
module,
175-
target_backends=[IREE_TARGET_MAP[device]],
176-
extra_args=["--compile-mode=hal-executable"],
177-
)
178-
179-
spirv_file = open(
180-
f"{bench_dir}/{d_}/{d_}_spirv.vmfb", "wb"
181-
)
182-
spirv_file.write(flatbuffer_blob)
183-
spirv_file.close()
108+
if os.path.isdir(f"{bench_dir}/{d_}"):
109+
in_dispatches = False
110+
for dispatch in dispatch_list:
111+
if str(dispatch) in d_:
112+
in_dispatches = True
113+
if all_dispatches or in_dispatches:
114+
for f_ in os.listdir(f"{bench_dir}/{d_}"):
115+
116+
if "benchmark.mlir" in f_:
117+
dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")
118+
module = dispatch_file.read()
119+
dispatch_file.close()
120+
121+
flatbuffer_blob = ireec.compile_str(
122+
module, target_backends=[IREE_TARGET_MAP[device]]
123+
)
124+
125+
vmfb_file = open(
126+
f"{bench_dir}/{d_}/{d_}_benchmark.vmfb", "wb"
127+
)
128+
vmfb_file.write(flatbuffer_blob)
129+
vmfb_file.close()
130+
131+
config = ireert.Config(IREE_DEVICE_MAP[device])
132+
vm_module = ireert.VmModule.from_flatbuffer(
133+
config.vm_instance, flatbuffer_blob
134+
)
135+
136+
benchmark_cl = build_benchmark_args_non_tensor_input(
137+
input_file=f"{bench_dir}/{d_}/{d_}_benchmark.vmfb",
138+
device=device,
139+
inputs=(0,),
140+
mlir_dialect="linalg",
141+
function_name=vm_module.function_names[0],
142+
)
143+
144+
benchmark_bash = open(
145+
f"{bench_dir}/{d_}/{d_}_benchmark.sh", "w+"
146+
)
147+
benchmark_bash.write("#!/bin/bash\n")
148+
benchmark_bash.write(" ".join(benchmark_cl))
149+
benchmark_bash.close()
150+
151+
benchmark_data = run_benchmark_module(benchmark_cl)
152+
153+
benchmark_file = open(
154+
f"{bench_dir}/{d_}/{d_}_data.txt", "w+"
155+
)
156+
benchmark_file.write(f"DISPATCH: {d_}\n")
157+
benchmark_file.write(str(benchmark_data) + "\n")
158+
benchmark_file.write(
159+
"SHARK BENCHMARK RESULT: "
160+
+ str(1 / (benchmark_data * 0.001))
161+
+ "\n"
162+
)
163+
benchmark_file.close()
164+
165+
benchmark_runtimes[d_] = 1 / (benchmark_data * 0.001)
166+
167+
elif ".mlir" in f_ and "benchmark" not in f_:
168+
dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")
169+
module = dispatch_file.read()
170+
dispatch_file.close()
171+
172+
module = re.sub(
173+
"hal.executable private",
174+
"hal.executable public",
175+
module,
176+
)
177+
178+
flatbuffer_blob = ireec.compile_str(
179+
module,
180+
target_backends=[IREE_TARGET_MAP[device]],
181+
extra_args=["--compile-mode=hal-executable"],
182+
)
183+
184+
spirv_file = open(
185+
f"{bench_dir}/{d_}/{d_}_spirv.vmfb", "wb"
186+
)
187+
spirv_file.write(flatbuffer_blob)
188+
spirv_file.close()
189+
190+
ordered_dispatches = [
191+
(k, v)
192+
for k, v in sorted(
193+
benchmark_runtimes.items(), key=lambda item: item[1]
194+
)
195+
][::-1]
196+
f_ = open(f"{bench_dir}/ordered-dispatches.txt", "w+")
197+
for dispatch in ordered_dispatches:
198+
f_.write(f"{dispatch[0]}: {dispatch[1]}ms\n")
199+
f_.close()
184200

185201

186202
def compile_module_to_flatbuffer(

0 commit comments

Comments
 (0)