Skip to content

Commit c4126b4

Browse files
authored
Add torch compile to benchmark (#545)
1 parent d2bf84d commit c4126b4

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

benchmarks/run.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,10 @@ class RunResult:
4444
device: str
4545
shape: list[str]
4646
triton_speedup: list[float]
47-
helion_speedup: list[float]
4847
triton_accuracy: list[float]
48+
torch_compile_speedup: list[float]
49+
torch_compile_accuracy: list[float]
50+
helion_speedup: list[float]
4951
helion_accuracy: list[float]
5052

5153

@@ -539,9 +541,11 @@ def process_result(
539541

540542
shape = []
541543
triton_speedup = []
542-
helion_speedup = []
543544
triton_accuracy = []
545+
helion_speedup = []
544546
helion_accuracy = []
547+
torch_compile_speedup = []
548+
torch_compile_accuracy = []
545549
for row in lines[1:]:
546550
row_data = row.strip().split(";")
547551
if row_data[0] == "average":
@@ -554,6 +558,10 @@ def process_result(
554558
triton_speedup.append(float(item))
555559
elif name.startswith("triton") and name.endswith("-accuracy"):
556560
triton_accuracy.append(float(item))
561+
elif name.startswith("torch_compile") and name.endswith("-speedup"):
562+
torch_compile_speedup.append(float(item))
563+
elif name.startswith("torch_compile") and name.endswith("-accuracy"):
564+
torch_compile_accuracy.append(float(item))
557565
elif name.startswith("helion") and name.endswith("-speedup"):
558566
helion_speedup.append(float(item))
559567
elif name.startswith("helion") and name.endswith("-accuracy"):
@@ -567,8 +575,10 @@ def process_result(
567575
device=get_device_name(),
568576
shape=shape,
569577
triton_speedup=triton_speedup,
570-
helion_speedup=helion_speedup,
571578
triton_accuracy=triton_accuracy,
579+
torch_compile_speedup=torch_compile_speedup,
580+
torch_compile_accuracy=torch_compile_accuracy,
581+
helion_speedup=helion_speedup,
572582
helion_accuracy=helion_accuracy,
573583
)
574584
)
@@ -582,8 +592,10 @@ def write_results_to_json(output: str, results: list[RunResult]) -> None:
582592
for result in results:
583593
for metric_name in [
584594
"triton_speedup",
585-
"helion_speedup",
586595
"triton_accuracy",
596+
"torch_compile_speedup",
597+
"torch_compile_accuracy",
598+
"helion_speedup",
587599
"helion_accuracy",
588600
]:
589601
records.append(

0 commit comments

Comments
 (0)