Skip to content

Commit f8cfb5a

Browse files
authored
Refactor autotune error handling (#595)
1 parent 3d8af25 commit f8cfb5a

File tree

4 files changed

+83
-37
lines changed

4 files changed

+83
-37
lines changed

helion/autotuner/base_search.py

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from math import inf
1212
from multiprocessing import connection
1313
import os
14-
import re
1514
import sys
1615
import time
1716
from typing import TYPE_CHECKING
@@ -21,8 +20,6 @@
2120
if TYPE_CHECKING:
2221
from triton.runtime.jit import JITFunction
2322

24-
from torch._inductor.runtime.triton_compat import OutOfResources
25-
from torch._inductor.runtime.triton_compat import PTXASError
2623
import torch.multiprocessing as mp
2724
from triton.testing import do_bench
2825

@@ -32,6 +29,8 @@
3229
from .config_generation import ConfigGeneration
3330
from .config_generation import FlatConfig
3431
from .logger import LambdaLogger
32+
from .logger import classify_triton_exception
33+
from .logger import format_triton_compile_failure
3534

3635
if TYPE_CHECKING:
3736
from collections.abc import Sequence
@@ -44,20 +43,6 @@
4443
from ..runtime.settings import Settings
4544
from . import ConfigSpec
4645

47-
_expected_errors_regexp: re.Pattern[str] = re.compile(
48-
r"|".join(
49-
map(
50-
re.escape,
51-
[
52-
"[CUDA]: invalid argument", # CUDA Error
53-
"misaligned address", # CUDA Error
54-
"PassManager::run failed", # Triton Error
55-
"illegal memory access", # CUDA Error
56-
],
57-
)
58-
)
59-
)
60-
6146

6247
class BaseAutotuner(abc.ABC):
6348
"""
@@ -143,22 +128,15 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
143128
lambda: f"result: {res:.4f}ms (took {t1 - t0:.1f}s + {t2 - t1:.1f}s)",
144129
)
145130
return res # pyright: ignore[reportReturnType]
146-
except OutOfResources:
147-
self.log.debug("Benchmarking failed: OutOfResources")
148-
except PTXASError:
149-
self.log.warning(f"PTXASError compiling config: {config}")
150131
except Exception as e:
151-
msg = str(e)
152-
if not _expected_errors_regexp.search(msg):
132+
action = classify_triton_exception(e)
133+
if action == "raise":
153134
raise exc.TritonError(f"{type(e).__qualname__}: {e}", config) from e
154-
# Surface Triton IR pass failures more prominently for easier bug reports.
155-
if "PassManager::run failed" in msg:
156-
self.log.warning(
157-
f"Triton PassManager::run failed while compiling config: {config}. Error: {e}"
158-
)
135+
if action == "warn":
136+
self.log.warning(format_triton_compile_failure(config, e))
159137
else:
160138
self.log.debug(f"Benchmarking failed: {type(e).__name__}: {e}")
161-
return inf
139+
return inf
162140

163141
def start_precompile_and_check_for_hangs(
164142
self, config: Config, fn: CompiledConfig
@@ -195,7 +173,7 @@ def extract_launcher(
195173
# Should not reach here
196174
raise RuntimeError("Expected _ExtractedLaunchArgs exception")
197175
except _ExtractedLaunchArgs as e:
198-
precompiler = make_precompiler(e.kernel)(*e.args, **e.kwargs)
176+
precompiler = make_precompiler(e.kernel, config)(*e.args, **e.kwargs)
199177
if precompiler is already_compiled:
200178
return PrecompileFuture.skip(self, config, True)
201179
process: mp.Process = ctx.Process(target=precompiler) # pyright: ignore[reportAssignmentType]
@@ -575,8 +553,8 @@ def _mark_complete(self) -> bool:
575553
if not self.started:
576554
self.start()
577555
if not process.is_alive():
578-
self.ok = True
579-
return True
556+
self.ok = process.exitcode == 0
557+
return self.ok
580558
process.terminate()
581559
process.join(10)
582560
msg = f"Timeout after {self.elapsed:.0f}s compiling {self.config}"

helion/autotuner/logger.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,18 @@
22

33
import itertools
44
import logging
5+
import re
56
import sys
67
import time
8+
from typing import TYPE_CHECKING
79
from typing import Callable
10+
from typing import Literal
11+
12+
from torch._inductor.runtime.triton_compat import OutOfResources
13+
from torch._inductor.runtime.triton_compat import PTXASError
14+
15+
if TYPE_CHECKING:
16+
from ..runtime.config import Config
817

918

1019
class LambdaLogger:
@@ -81,3 +90,52 @@ def _maybe_call(fn: Callable[[], str] | str) -> str:
8190
if callable(fn):
8291
return fn()
8392
return fn
93+
94+
95+
def format_triton_compile_failure(config: Config, err: BaseException) -> str:
96+
return (
97+
"Triton compile failed. This likely indicates a bug in Triton. "
98+
"Skipping failing config.\n"
99+
f"Config: {config!r}\n"
100+
f"Error: {type(err).__name__}: {err}"
101+
)
102+
103+
104+
# Common logic to decide how to surface Triton errors
105+
_EXPECTED_TRITON_ERRORS_RE: re.Pattern[str] = re.compile(
106+
"|".join(
107+
map(
108+
re.escape,
109+
[
110+
"[CUDA]: invalid argument", # CUDA Error
111+
"misaligned address", # CUDA Error
112+
"illegal memory access", # CUDA Error
113+
"PassManager::run failed", # Triton Error
114+
],
115+
)
116+
)
117+
)
118+
119+
120+
def classify_triton_exception(err: BaseException) -> Literal["raise", "warn", "debug"]:
121+
"""
122+
Classify a Triton compile/runtime exception during autotuning.
123+
124+
Returns one of:
125+
- "raise": unexpected error, caller should raise
126+
- "warn": notable expected error (e.g., PassManager pipeline failure)
127+
- "debug": benign/expected error; caller can log at debug level
128+
"""
129+
# Known exception types first
130+
if isinstance(err, OutOfResources):
131+
return "debug"
132+
# Different PTXASError classes may be raised from different modules; match by name as well
133+
if isinstance(err, PTXASError) or err.__class__.__name__ == "PTXASError":
134+
return "warn"
135+
136+
msg = str(err)
137+
if "PassManager::run failed" in msg:
138+
return "warn"
139+
if _EXPECTED_TRITON_ERRORS_RE.search(msg):
140+
return "debug"
141+
return "raise"

helion/runtime/precompile_shim.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
11
from __future__ import annotations
22

33
import os
4+
import sys
45
from typing import TYPE_CHECKING
56

7+
from ..autotuner.logger import classify_triton_exception
8+
from ..autotuner.logger import format_triton_compile_failure
9+
610
if TYPE_CHECKING:
711
from collections.abc import Callable
812

913
from triton.runtime.jit import JITFunction
1014

15+
from .config import Config
16+
1117

12-
def make_precompiler(fn: JITFunction[object]) -> Callable[..., Callable[[], None]]:
18+
def make_precompiler(
19+
fn: JITFunction[object], config: Config
20+
) -> Callable[..., Callable[[], None]]:
1321
from triton.runtime.jit import find_paths_if
1422
from triton.runtime.jit import get_iterable_path
1523

@@ -48,14 +56,16 @@ def _make_precompiler(*args: object, **kwargs: object) -> Callable[[], None]:
4856
def finish_it() -> None:
4957
src = fn.ASTSource(fn, signature, constexprs, attrs)
5058
# here we update the cache so if this is called in the parent we skip an extra compile
51-
from triton.runtime.errors import PTXASError
5259

5360
try:
5461
kernel_cache[key] = fn.compile(
5562
src, target=target, options=options.__dict__
5663
)
57-
except PTXASError:
58-
return
64+
except Exception as e:
65+
action = classify_triton_exception(e)
66+
if action != "debug":
67+
print(format_triton_compile_failure(config, e), file=sys.stderr)
68+
sys.exit(1)
5969

6070
return finish_it
6171

test/test_autotuner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def test_random_search(self):
131131
torch.randn([512, 512], device=DEVICE),
132132
)
133133
bound_kernel = examples_matmul.bind(args)
134-
best = RandomSearch(bound_kernel, args, 5).autotune()
134+
best = RandomSearch(bound_kernel, args, 10).autotune()
135135
fn = bound_kernel.compile_config(best)
136136
torch.testing.assert_close(fn(*args), args[0] @ args[1], rtol=1e-2, atol=1e-1)
137137

0 commit comments

Comments (0)