Skip to content

Commit 4e7321f

Browse files
authored
Adding new setting, autotune_max_generations, that allows user to set the maximum number of generations for autotuning (#796)
1 parent 31beca2 commit 4e7321f

File tree

4 files changed

+58
-5
lines changed

4 files changed

+58
-5
lines changed

docs/api/settings.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,12 @@ with helion.set_default_settings(
134134
.. autoattribute:: Settings.autotune_random_seed
135135
136136
Seed used for autotuner random number generation. Defaults to ``HELION_AUTOTUNE_RANDOM_SEED`` if set, otherwise a time-based value.
137+
138+
.. autoattribute:: Settings.autotune_max_generations
139+
140+
Override the default maximum number of generations used by the Pattern Search and Differential Evolution Search autotuning algorithms by setting ``HELION_AUTOTUNE_MAX_GENERATIONS=N`` or passing ``@helion.kernel(autotune_max_generations=N)``.
141+
142+
Lower values result in faster autotuning but may find less optimal configurations.
137143
```
138144

139145
### Debugging and Development

helion/autotuner/differential_evolution.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ def __init__(
2727
kernel: BoundKernel,
2828
args: Sequence[object],
2929
population_size: int = 40,
30-
num_generations: int = 40,
30+
max_generations: int = 40,
3131
crossover_rate: float = 0.8,
3232
immediate_update: bool | None = None,
3333
) -> None:
3434
super().__init__(kernel, args)
3535
if immediate_update is None:
3636
immediate_update = not kernel.settings.autotune_precompile
3737
self.population_size = population_size
38-
self.num_generations = num_generations
38+
self.max_generations = max_generations
3939
self.crossover_rate = crossover_rate
4040
self.immediate_update = immediate_update
4141

@@ -90,11 +90,11 @@ def _autotune(self) -> Config:
9090
self.log(
9191
lambda: (
9292
f"Starting DifferentialEvolutionSearch with population={self.population_size}, "
93-
f"generations={self.num_generations}, crossover_rate={self.crossover_rate}"
93+
f"generations={self.max_generations}, crossover_rate={self.crossover_rate}"
9494
)
9595
)
9696
self.initial_two_generations()
97-
for i in range(2, self.num_generations):
97+
for i in range(2, self.max_generations):
9898
replaced = self.evolve_population()
9999
self.log(f"Generation {i}: replaced={replaced}", self.statistics)
100100
self.rebenchmark_population()

helion/runtime/settings.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,14 @@ def default_autotuner_fn(
7373
f"Unknown HELION_AUTOTUNER value: {autotuner_name}, valid options are: "
7474
f"{', '.join(search_algorithms.keys())}"
7575
)
76+
77+
# Use autotune_max_generations from settings if kwarg is not explicitly provided
78+
if autotuner_name in ("PatternSearch", "DifferentialEvolutionSearch"):
79+
if bound_kernel.settings.autotune_max_generations is not None:
80+
kwargs.setdefault(
81+
"max_generations", bound_kernel.settings.autotune_max_generations
82+
)
83+
7684
return LocalAutotuneCache(autotuner_cls(bound_kernel, args, **kwargs)) # pyright: ignore[reportArgumentType]
7785

7886

@@ -83,6 +91,13 @@ def _get_autotune_random_seed() -> int:
8391
return int(time.time() * 1000) % 2**32
8492

8593

94+
def _get_autotune_max_generations() -> int | None:
95+
value = os.environ.get("HELION_AUTOTUNE_MAX_GENERATIONS")
96+
if value is not None:
97+
return int(value)
98+
return None
99+
100+
86101
@dataclasses.dataclass
87102
class _Settings:
88103
# see __slots__ below for the doc strings that show up in help(Settings)
@@ -114,6 +129,9 @@ class _Settings:
114129
autotune_progress_bar: bool = (
115130
os.environ.get("HELION_AUTOTUNE_PROGRESS_BAR", "1") == "1"
116131
)
132+
autotune_max_generations: int | None = dataclasses.field(
133+
default_factory=_get_autotune_max_generations
134+
)
117135
print_output_code: bool = os.environ.get("HELION_PRINT_OUTPUT_CODE", "0") == "1"
118136
force_autotune: bool = os.environ.get("HELION_FORCE_AUTOTUNE", "0") == "1"
119137
autotune_config_overrides: dict[str, object] = dataclasses.field(
@@ -149,6 +167,7 @@ class Settings(_Settings):
149167
"autotune_accuracy_check": "If True, validate candidate configs against the baseline kernel output before accepting them during autotuning.",
150168
"autotune_rebenchmark_threshold": "If a config is within threshold*best_perf, re-benchmark it to avoid outliers. Default is 1.5x. Set to <1 to disable.",
151169
"autotune_progress_bar": "If True, show progress bar during autotuning. Default is True. Set HELION_AUTOTUNE_PROGRESS_BAR=0 to disable.",
170+
"autotune_max_generations": "Override the maximum number of generations for Pattern Search and Differential Evolution Search autotuning algorithms with HELION_AUTOTUNE_MAX_GENERATIONS=N or @helion.kernel(autotune_max_generations=N).",
152171
"print_output_code": "If True, print the output code of the kernel to stderr.",
153172
"force_autotune": "If True, force autotuning even if a config is provided.",
154173
"autotune_config_overrides": "Dictionary of config key/value pairs forced during autotuning.",

test/test_autotuner.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def test_differential_evolution_search(self):
207207
bound_kernel = examples_matmul.bind(args)
208208
random.seed(123)
209209
best = DifferentialEvolutionSearch(
210-
bound_kernel, args, 5, num_generations=3
210+
bound_kernel, args, 5, max_generations=3
211211
).autotune()
212212
fn = bound_kernel.compile_config(best)
213213
torch.testing.assert_close(fn(*args), args[0] @ args[1], rtol=1e-2, atol=1e-1)
@@ -373,6 +373,34 @@ def wrong_fn(*fn_args, **fn_kwargs):
373373
self.assertEqual(best, good_config)
374374
self.assertGreaterEqual(search.counters.get("accuracy_mismatch", 0), 1)
375375

376+
def test_max_generations(self):
377+
"""Autotuner max generation respects explicit kwargs then setting override."""
378+
379+
with patch.dict(os.environ, {"HELION_AUTOTUNER": "PatternSearch"}):
380+
381+
@helion.kernel(autotune_max_generations=1)
382+
def add(a, b):
383+
out = torch.empty_like(a)
384+
for tile in hl.tile(out.size()):
385+
out[tile] = a[tile] + b[tile]
386+
return out
387+
388+
args = (
389+
torch.randn([8], device=DEVICE),
390+
torch.randn([8], device=DEVICE),
391+
)
392+
393+
bound = add.bind(args)
394+
autotuner_factory = bound.settings.autotuner_fn
395+
396+
# Settings override defaults
397+
autotuner = autotuner_factory(bound, args)
398+
self.assertEqual(autotuner.autotuner.max_generations, 1)
399+
400+
# Explicit constructor value wins
401+
autotuner_override = autotuner_factory(bound, args, max_generations=2)
402+
self.assertEqual(autotuner_override.autotuner.max_generations, 2)
403+
376404
def test_use_default_config(self):
377405
@helion.kernel(use_default_config=True)
378406
def add(a, b):

0 commit comments

Comments
 (0)