Commit ac88d91

[Autotuner] Use cudagraph for time measurement on Nvidia hardware
1 parent 4db264a commit ac88d91
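
On NVIDIA GPUs the autotuner now times candidate configs by capturing the kernel into a CUDA graph and replaying it, which keeps per-launch CPU overhead out of the measurement; other backends keep using triton.testing.do_bench. A minimal standalone sketch of the capture-and-replay timing pattern this relies on (illustrative only, not code from this commit; time_with_cudagraph and fn are hypothetical names):

import torch

def time_with_cudagraph(fn, n_repeat: int = 100) -> float:
    # Warm up and capture on a side stream, as CUDA graph capture requires.
    with torch.cuda.stream(torch.cuda.Stream()):
        fn()  # compile/warm up outside the graph
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            for _ in range(n_repeat):
                fn()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        g.replay()  # one replay launches all captured iterations
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / n_repeat  # milliseconds per iteration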

4 files changed: 164 additions & 11 deletions


helion/_testing.py

Lines changed: 4 additions & 1 deletion
@@ -17,7 +17,6 @@
 from typing import Generator
 import unittest

-import pytest
 import torch
 from torch.utils._pytree import tree_map
 import triton
@@ -267,6 +266,8 @@ def setUp(self) -> None:
         if not self._in_ref_eager_mode:
             return

+        import pytest
+
         # Reset assert_close counter for this test
         RefEagerTestBase._assert_close_count = 0
         # Reset assertRaises counter for this test
@@ -361,6 +362,8 @@ def tearDown(self) -> None:
             super().tearDown()  # type: ignore[misc]
             return

+        import pytest
+
         try:
             # Exit the run_ref tracker
             self._run_ref_tracker.__exit__(None, None, None)

helion/autotuner/base_search.py

Lines changed: 14 additions & 6 deletions
@@ -40,9 +40,11 @@
 from triton.testing import do_bench

 from .. import exc
+from .._testing import is_cuda
 from ..runtime.kernel import BoundKernel
 from ..runtime.precompile_shim import already_compiled
 from ..runtime.precompile_shim import make_precompiler
+from .bench_utils import do_bench_cudagraph_with_cache_clear
 from .benchmarking import interleaved_bench
 from .config_generation import ConfigGeneration
 from .config_generation import FlatConfig
@@ -325,12 +327,18 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
                 # Accuracy check failed; reject this config
                 return inf
             t1 = time.perf_counter()
-            res = do_bench(
-                functools.partial(fn, *self.args),
-                return_mode="median",
-                warmup=1,  # we are already warmed up above
-                rep=50,
-            )
+            kwargs = {
+                "fn": functools.partial(fn, *self.args),
+                "return_mode": "median",
+                "rep": 50,
+            }
+            if is_cuda():
+                res = do_bench_cudagraph_with_cache_clear(**kwargs)
+            else:
+                res = do_bench(
+                    **kwargs,
+                    warmup=1,  # we are already warmed up above
+                )
             t2 = time.perf_counter()
             assert isinstance(res, float)
             self.log.debug(
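
The shared kwargs dict keeps the two timing paths interchangeable: both helpers accept fn, return_mode, and rep, and both report milliseconds, so only the extra warmup argument on the do_bench path differs. A rough self-contained sketch of the same dispatch (compiled_fn and args are hypothetical placeholders; the imports mirror the ones added above):

import functools

from triton.testing import do_bench

from helion._testing import is_cuda
from helion.autotuner.bench_utils import do_bench_cudagraph_with_cache_clear

def measure_ms(compiled_fn, args) -> float:
    kwargs = {
        "fn": functools.partial(compiled_fn, *args),
        "return_mode": "median",
        "rep": 50,
    }
    if is_cuda():
        # CUDA-graph path: the helper does its own warmup and L2 cache clearing.
        return do_bench_cudagraph_with_cache_clear(**kwargs)
    # Fallback path: triton's eager do_bench with minimal warmup
    # (the caller is assumed to have warmed the kernel up already).
    return do_bench(**kwargs, warmup=1)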

helion/autotuner/bench_utils.py

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
+from __future__ import annotations
+
+from typing import Callable
+from typing import Sequence
+
+import torch
+import triton
+
+
+def _summarize_statistics(
+    times: torch.Tensor,
+    quantiles: Sequence[float] | None,
+    return_mode: str,
+) -> float | list[float]:
+    if quantiles is not None:
+        ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
+        if len(ret) == 1:
+            ret = ret[0]
+        return ret
+    if return_mode == "all":
+        return times.tolist()
+    return getattr(torch, return_mode)(times).item()
+
+
+def do_bench_cudagraph_with_cache_clear(
+    fn: Callable[[], object],
+    rep: int = 20,
+    grad_to_none: Sequence[torch.Tensor] | None = None,
+    quantiles: Sequence[float] | None = None,
+    return_mode: str = "mean",
+) -> float | list[float]:
+    """
+    Clone of triton.testing.do_bench_cudagraph with explicit L2 cache clearing.
+
+    NOTE: We will switch to use triton.testing.do_bench_cudagraph once it has explicit L2 cache clearing.
+
+    Args:
+        fn: Function to benchmark
+        rep: Target total measurement time in milliseconds
+        grad_to_none: Tensors whose gradients should be cleared before each measurement
+        quantiles: Quantiles to compute from the timing measurements
+        return_mode: "min", "max", "mean", "median", or "all"
+
+    Returns:
+        Timing measurement(s) in milliseconds according to return_mode
+    """
+    assert return_mode in ["min", "max", "mean", "median", "all"]
+
+    # Get a cache tensor and function to zero it for L2 cache clearing
+    cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()  # type: ignore[attr-defined]
+    clear_cache_fn = cache.zero_
+
+    # Use a separate CUDA stream for all benchmark operations
+    with torch.cuda.stream(torch.cuda.Stream()):
+        # Warmup: clear cache and run function once to ensure it's compiled
+        clear_cache_fn()
+        fn()
+
+        # Reset gradients if needed (for autograd-enabled benchmarks)
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.detach_()
+                x.requires_grad_(True)
+                x.grad = None
+
+        # Estimate execution time
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+        for _ in range(5):
+            clear_cache_fn()
+            fn()
+        end_event.record()
+        torch.cuda.synchronize()
+        estimate_ms = start_event.elapsed_time(end_event) / 5
+
+        # Calculate number of repetitions needed to reach target measurement time (rep)
+        n_repeat = 1000 if estimate_ms == 0 else max(1, int(rep / estimate_ms))
+
+        # Create a CUDA graph for the actual kernel execution + cache clearing
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            for _ in range(n_repeat):
+                if grad_to_none is not None:
+                    for x in grad_to_none:
+                        x.grad = None
+                clear_cache_fn()
+                fn()
+        torch.cuda.synchronize()
+
+        # Create a separate CUDA graph for just cache clearing
+        cache_clear_graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(cache_clear_graph):
+            for _ in range(n_repeat):
+                clear_cache_fn()
+        torch.cuda.synchronize()
+
+        # Run multiple retries to get stable measurements
+        n_retries = 10
+        cache_clear_times = []
+        total_times = []
+        for _ in range(n_retries):
+            # Measure time for cache clearing only
+            cache_clear_start_event = torch.cuda.Event(enable_timing=True)
+            cache_clear_end_event = torch.cuda.Event(enable_timing=True)
+            cache_clear_start_event.record()
+            cache_clear_graph.replay()
+            cache_clear_end_event.record()
+            torch.cuda.synchronize()
+            cache_clear_times.append(
+                cache_clear_start_event.elapsed_time(cache_clear_end_event) / n_repeat
+            )
+
+            # Measure total time (cache clearing + kernel execution)
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+            start_event.record()
+            g.replay()
+            end_event.record()
+            torch.cuda.synchronize()
+            total_times.append(start_event.elapsed_time(end_event) / n_repeat)
+
+        # Subtract cache clearing overhead to get pure kernel execution time
+        all_kernel_times = []
+        for total_time, cache_clear_time in zip(
+            total_times, cache_clear_times, strict=True
+        ):
+            kernel_time = total_time - cache_clear_time
+            all_kernel_times.append(kernel_time)
+
+        # Compute the requested statistic
+        times = torch.tensor(all_kernel_times, dtype=torch.float)
+        return _summarize_statistics(times, quantiles, return_mode)
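
A possible usage sketch of the new helper on a throwaway elementwise workload (the lambda, tensor sizes, and the CUDA availability guard are illustrative, not taken from this commit):

import torch

from helion.autotuner.bench_utils import do_bench_cudagraph_with_cache_clear

if torch.cuda.is_available():
    a = torch.randn(1 << 20, device="cuda")
    b = torch.randn(1 << 20, device="cuda")
    # fn must be CUDA-graph capturable: fixed shapes and no host-side synchronization.
    median_ms = do_bench_cudagraph_with_cache_clear(
        fn=lambda: torch.add(a, b),
        rep=50,
        return_mode="median",
    )
    print(f"median time per call: {median_ms:.4f} ms")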

test/test_debug_utils.py

Lines changed: 13 additions & 4 deletions
@@ -13,6 +13,7 @@
 from helion._testing import DEVICE
 from helion._testing import RefEagerTestDisabled
 from helion._testing import TestCase
+from helion._testing import is_cuda
 from helion._testing import skipIfCpu
 import helion.language as hl

@@ -142,20 +143,28 @@ def test_print_repro_on_autotune_error(self):
         torch.manual_seed(0)
         x = torch.randn([128], dtype=torch.float32, device=DEVICE)

-        # Mock do_bench to fail on the second config with PTXASError (warn level)
+        # Mock benchmark helper to fail on the second config with PTXASError (warn level)
         from torch._inductor.runtime.triton_compat import PTXASError
-        from triton.testing import do_bench as original_do_bench
+
+        from helion.autotuner import base_search

         call_count = [0]

+        bench_attr = (
+            "do_bench_cudagraph_with_cache_clear" if is_cuda() else "do_bench"
+        )
+
+        original_bench = getattr(base_search, bench_attr)
+        bench_target = f"helion.autotuner.base_search.{bench_attr}"
+
         def mock_do_bench(*args, **kwargs):
             call_count[0] += 1
             if call_count[0] == 2:  # Fail on second config
                 raise PTXASError("Mocked PTXAS error")
-            return original_do_bench(*args, **kwargs)
+            return original_bench(*args, **kwargs)

         with self.capture_output() as output_capture:
-            with mock.patch("helion.autotuner.base_search.do_bench", mock_do_bench):
+            with mock.patch(bench_target, mock_do_bench):
                 # Autotune will try both configs, second one will fail and print repro
                 kernel.autotune([x], force=False)
