Skip to content

Commit 9b3e31d

Browse files
committed
[Autotuner] Use cudagraph for time measurement on Nvidia hardware
1 parent 4db264a commit 9b3e31d

File tree

4 files changed

+127
-11
lines changed

4 files changed

+127
-11
lines changed

helion/_testing.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from typing import Generator
1818
import unittest
1919

20-
import pytest
2120
import torch
2221
from torch.utils._pytree import tree_map
2322
import triton
@@ -267,6 +266,8 @@ def setUp(self) -> None:
267266
if not self._in_ref_eager_mode:
268267
return
269268

269+
import pytest
270+
270271
# Reset assert_close counter for this test
271272
RefEagerTestBase._assert_close_count = 0
272273
# Reset assertRaises counter for this test
@@ -361,6 +362,8 @@ def tearDown(self) -> None:
361362
super().tearDown() # type: ignore[misc]
362363
return
363364

365+
import pytest
366+
364367
try:
365368
# Exit the run_ref tracker
366369
self._run_ref_tracker.__exit__(None, None, None)

helion/autotuner/base_search.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@
4040
from triton.testing import do_bench
4141

4242
from .. import exc
43+
from .._testing import is_cuda
4344
from ..runtime.kernel import BoundKernel
4445
from ..runtime.precompile_shim import already_compiled
4546
from ..runtime.precompile_shim import make_precompiler
47+
from .bench_utils import do_bench_cudagraph_with_cache_clear
4648
from .benchmarking import interleaved_bench
4749
from .config_generation import ConfigGeneration
4850
from .config_generation import FlatConfig
@@ -325,12 +327,18 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
325327
# Accuracy check failed; reject this config
326328
return inf
327329
t1 = time.perf_counter()
328-
res = do_bench(
329-
functools.partial(fn, *self.args),
330-
return_mode="median",
331-
warmup=1, # we are already warmed up above
332-
rep=50,
333-
)
330+
kwargs = {
331+
"fn": functools.partial(fn, *self.args),
332+
"rep": 50,
333+
}
334+
if is_cuda():
335+
res = do_bench_cudagraph_with_cache_clear(**kwargs)
336+
else:
337+
res = do_bench(
338+
**kwargs,
339+
warmup=1, # we are already warmed up above
340+
return_mode="median",
341+
)
334342
t2 = time.perf_counter()
335343
assert isinstance(res, float)
336344
self.log.debug(

helion/autotuner/bench_utils.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
from __future__ import annotations
2+
3+
from typing import Callable
4+
from typing import Sequence
5+
6+
import torch
7+
import triton
8+
9+
10+
def do_bench_cudagraph_with_cache_clear(
    fn: Callable[[], object],
    rep: int = 20,
    grad_to_none: Sequence[torch.Tensor] | None = None,
) -> float:
    """
    Clone of triton.testing.do_bench_cudagraph with explicit L2 cache clearing.
    Only supports calculating mean execution time.

    The function is timed by capturing ``n_repeat`` back-to-back
    "clear cache + fn()" iterations into one CUDA graph and replaying it once.
    A second graph containing only the ``n_repeat`` cache clears is replayed
    the same way, and its per-iteration time is subtracted from the total.

    Args:
        fn: Function to benchmark
        rep: Target total measurement time in milliseconds; used together with
            a short eager estimate to choose how many iterations are captured
        grad_to_none: Tensors whose gradients should be cleared before each measurement

    Returns:
        Mean execution time in milliseconds (per-iteration total time minus
        per-iteration cache clearing time)
    """
    # Get a cache tensor and function to zero it for L2 cache clearing
    cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()  # type: ignore[attr-defined]
    clear_cache_fn = cache.zero_

    # Run everything on a fresh side stream, as the triton original does.
    with torch.cuda.stream(torch.cuda.Stream()):
        # Warmup: clear cache and run function once to ensure it's compiled
        clear_cache_fn()
        fn()

        # Reset gradients if needed
        if grad_to_none is not None:
            for x in grad_to_none:
                x.detach_()
                x.requires_grad_(True)
                x.grad = None

        # Estimate execution time from a handful of eager iterations
        n_estimate = 5
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        for _ in range(n_estimate):
            clear_cache_fn()
            fn()
        end_event.record()
        torch.cuda.synchronize()
        estimate_ms = start_event.elapsed_time(end_event) / n_estimate

        # Calculate number of repetitions needed to reach target measurement
        # time (rep); fall back to 1000 when the estimate is exactly zero.
        n_repeat = 1000 if estimate_ms == 0 else max(1, int(rep / estimate_ms))

        # Create a CUDA graph for the actual kernel execution + cache clearing
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            for _ in range(n_repeat):
                if grad_to_none is not None:
                    for x in grad_to_none:
                        x.grad = None
                clear_cache_fn()
                fn()
        torch.cuda.synchronize()

        # Create a separate CUDA graph for just cache clearing
        cache_clear_graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(cache_clear_graph):
            for _ in range(n_repeat):
                clear_cache_fn()
        torch.cuda.synchronize()

        # Measure time for cache clearing only
        cache_clear_time = _graph_replay_ms_per_iter(cache_clear_graph, n_repeat)

        # Measure total time (cache clearing + kernel execution)
        total_time = _graph_replay_ms_per_iter(g, n_repeat)

        # Subtract cache clearing overhead to get pure kernel execution time.
        # NOTE(review): each graph is replayed only once, so for very short
        # kernels this difference may be noisy -- confirm that is acceptable.
        return total_time - cache_clear_time


def _graph_replay_ms_per_iter(graph: torch.cuda.CUDAGraph, n_repeat: int) -> float:
    """Replay ``graph`` once and return the elapsed milliseconds divided by ``n_repeat``."""
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    graph.replay()
    end_event.record()
    # Wait for the replay to finish so that elapsed_time() is valid.
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / n_repeat

test/test_debug_utils.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from helion._testing import DEVICE
1414
from helion._testing import RefEagerTestDisabled
1515
from helion._testing import TestCase
16+
from helion._testing import is_cuda
1617
from helion._testing import skipIfCpu
1718
import helion.language as hl
1819

@@ -142,20 +143,28 @@ def test_print_repro_on_autotune_error(self):
142143
torch.manual_seed(0)
143144
x = torch.randn([128], dtype=torch.float32, device=DEVICE)
144145

145-
# Mock do_bench to fail on the second config with PTXASError (warn level)
146+
# Mock benchmark helper to fail on the second config with PTXASError (warn level)
146147
from torch._inductor.runtime.triton_compat import PTXASError
147-
from triton.testing import do_bench as original_do_bench
148+
149+
from helion.autotuner import base_search
148150

149151
call_count = [0]
150152

153+
bench_attr = (
154+
"do_bench_cudagraph_with_cache_clear" if is_cuda() else "do_bench"
155+
)
156+
157+
original_bench = getattr(base_search, bench_attr)
158+
bench_target = f"helion.autotuner.base_search.{bench_attr}"
159+
151160
def mock_do_bench(*args, **kwargs):
152161
call_count[0] += 1
153162
if call_count[0] == 2: # Fail on second config
154163
raise PTXASError("Mocked PTXAS error")
155-
return original_do_bench(*args, **kwargs)
164+
return original_bench(*args, **kwargs)
156165

157166
with self.capture_output() as output_capture:
158-
with mock.patch("helion.autotuner.base_search.do_bench", mock_do_bench):
167+
with mock.patch(bench_target, mock_do_bench):
159168
# Autotune will try both configs, second one will fail and print repro
160169
kernel.autotune([x], force=False)
161170

0 commit comments

Comments
 (0)