|
11 | 11 | from math import inf |
12 | 12 | from multiprocessing import connection |
13 | 13 | import os |
14 | | -import re |
15 | 14 | import sys |
16 | 15 | import time |
17 | 16 | from typing import TYPE_CHECKING |
|
21 | 20 | if TYPE_CHECKING: |
22 | 21 | from triton.runtime.jit import JITFunction |
23 | 22 |
|
24 | | -from torch._inductor.runtime.triton_compat import OutOfResources |
25 | | -from torch._inductor.runtime.triton_compat import PTXASError |
26 | 23 | import torch.multiprocessing as mp |
27 | 24 | from triton.testing import do_bench |
28 | 25 |
|
|
32 | 29 | from .config_generation import ConfigGeneration |
33 | 30 | from .config_generation import FlatConfig |
34 | 31 | from .logger import LambdaLogger |
| 32 | +from .logger import classify_triton_exception |
| 33 | +from .logger import format_triton_compile_failure |
35 | 34 |
|
36 | 35 | if TYPE_CHECKING: |
37 | 36 | from collections.abc import Sequence |
|
44 | 43 | from ..runtime.settings import Settings |
45 | 44 | from . import ConfigSpec |
46 | 45 |
|
47 | | -_expected_errors_regexp: re.Pattern[str] = re.compile( |
48 | | - r"|".join( |
49 | | - map( |
50 | | - re.escape, |
51 | | - [ |
52 | | - "[CUDA]: invalid argument", # CUDA Error |
53 | | - "misaligned address", # CUDA Error |
54 | | - "PassManager::run failed", # Triton Error |
55 | | - "illegal memory access", # CUDA Error |
56 | | - ], |
57 | | - ) |
58 | | - ) |
59 | | -) |
60 | | - |
61 | 46 |
|
62 | 47 | class BaseAutotuner(abc.ABC): |
63 | 48 | """ |
@@ -143,22 +128,15 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float: |
143 | 128 | lambda: f"result: {res:.4f}ms (took {t1 - t0:.1f}s + {t2 - t1:.1f}s)", |
144 | 129 | ) |
145 | 130 | return res # pyright: ignore[reportReturnType] |
146 | | - except OutOfResources: |
147 | | - self.log.debug("Benchmarking failed: OutOfResources") |
148 | | - except PTXASError: |
149 | | - self.log.warning(f"PTXASError compiling config: {config}") |
150 | 131 | except Exception as e: |
151 | | - msg = str(e) |
152 | | - if not _expected_errors_regexp.search(msg): |
| 132 | + action = classify_triton_exception(e) |
| 133 | + if action == "raise": |
153 | 134 | raise exc.TritonError(f"{type(e).__qualname__}: {e}", config) from e |
154 | | - # Surface Triton IR pass failures more prominently for easier bug reports. |
155 | | - if "PassManager::run failed" in msg: |
156 | | - self.log.warning( |
157 | | - f"Triton PassManager::run failed while compiling config: {config}. Error: {e}" |
158 | | - ) |
| 135 | + if action == "warn": |
| 136 | + self.log.warning(format_triton_compile_failure(config, e)) |
159 | 137 | else: |
160 | 138 | self.log.debug(f"Benchmarking failed: {type(e).__name__}: {e}") |
161 | | - return inf |
| 139 | + return inf |
162 | 140 |
|
163 | 141 | def start_precompile_and_check_for_hangs( |
164 | 142 | self, config: Config, fn: CompiledConfig |
@@ -195,7 +173,7 @@ def extract_launcher( |
195 | 173 | # Should not reach here |
196 | 174 | raise RuntimeError("Expected _ExtractedLaunchArgs exception") |
197 | 175 | except _ExtractedLaunchArgs as e: |
198 | | - precompiler = make_precompiler(e.kernel)(*e.args, **e.kwargs) |
| 176 | + precompiler = make_precompiler(e.kernel, config)(*e.args, **e.kwargs) |
199 | 177 | if precompiler is already_compiled: |
200 | 178 | return PrecompileFuture.skip(self, config, True) |
201 | 179 | process: mp.Process = ctx.Process(target=precompiler) # pyright: ignore[reportAssignmentType] |
@@ -575,8 +553,8 @@ def _mark_complete(self) -> bool: |
575 | 553 | if not self.started: |
576 | 554 | self.start() |
577 | 555 | if not process.is_alive(): |
578 | | - self.ok = True |
579 | | - return True |
| 556 | + self.ok = process.exitcode == 0 |
| 557 | + return self.ok |
580 | 558 | process.terminate() |
581 | 559 | process.join(10) |
582 | 560 | msg = f"Timeout after {self.elapsed:.0f}s compiling {self.config}" |
|
0 commit comments