Skip to content

Commit 536d1f8

Browse files
committed
[RFC] Implement basic on disk caching
stack-info: PR: #336, branch: oulgen/stack/26
1 parent 83f3253 commit 536d1f8

File tree

8 files changed

+369
-8
lines changed

8 files changed

+369
-8
lines changed

helion/_testing.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import collections
4+
import contextlib
45
import importlib
56
import inspect
67
import operator
@@ -15,6 +16,7 @@
1516
import torch
1617
from triton.testing import do_bench
1718

19+
from ._utils import counters
1820
from .runtime.config import Config
1921
from helion._compat import get_tensor_descriptor_fn_name
2022

@@ -291,6 +293,20 @@ def tearDownClass(cls) -> None:
291293
super().tearDownClass()
292294
del cls._expected_journal
293295

296+
def setUp(self) -> None:
    """
    Set up per-test isolation.

    Enters torch inductor's ``fresh_cache()`` context for the duration of
    the test and resets the shared Helion ``counters``.
    """
    super().setUp()
    self._test_stack = contextlib.ExitStack()
    # unittest does not call tearDown() when setUp() raises (for example in a
    # subclass that fails after calling super().setUp()), but it does run
    # registered cleanups. Closing an ExitStack that is already closed is a
    # no-op, so this coexists with the explicit close in tearDown().
    self.addCleanup(self._test_stack.close)

    from torch._inductor.utils import fresh_cache

    self._test_stack.enter_context(fresh_cache())

    counters.clear()
305+
306+
def tearDown(self) -> None:
    """
    Run the parent teardown, then exit every context entered on the test stack.

    The stack is closed in a ``finally`` clause so the contexts are released
    even when the parent teardown raises.
    """
    try:
        super().tearDown()
    finally:
        self._test_stack.close()
309+
294310
def assertExpectedJournal(self, value: str) -> None:
295311
"""
296312
Assert that the given value matches the expected output stored in <testfile>.expected.

helion/_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from __future__ import annotations
2+
3+
import collections
4+
5+
# Shared event counters, grouped by category name. Both levels are created on
# first access: counters["category"]["event"] += 1 works without any setup.
counters: collections.defaultdict[str, collections.Counter[str]]
counters = collections.defaultdict(collections.Counter)

helion/autotuner/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,6 @@
99
DifferentialEvolutionSearch as DifferentialEvolutionSearch,
1010
)
1111
from .finite_search import FiniteSearch as FiniteSearch
12+
from .local_cache import LocalAutotuneCache as LocalAutotuneCache
13+
from .local_cache import StrictLocalAutotuneCache as StrictLocalAutotuneCache
1214
from .random_search import RandomSearch as RandomSearch

helion/autotuner/base_cache.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
from __future__ import annotations
2+
3+
import abc
4+
import dataclasses
5+
import functools
6+
import hashlib
7+
import logging
8+
import os
9+
from typing import TYPE_CHECKING
10+
from typing import Hashable
11+
from typing import Sequence
12+
13+
from .._utils import counters
14+
15+
if TYPE_CHECKING:
16+
from ..runtime.config import Config
17+
from ..runtime.kernel import BoundKernel
18+
from .base_search import BaseSearch
19+
20+
log: logging.Logger = logging.getLogger(__name__)
21+
22+
23+
@functools.cache
def helion_key() -> str:
    """
    Return a hex SHA-256 digest computed over the Helion package directory.

    The directory hashed is the one two levels above this file. The digest is
    produced by inductor's ``build_code_hash`` and memoized for the lifetime
    of the process.
    """
    from torch._inductor.codecache import build_code_hash

    package_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    hasher = hashlib.sha256()
    build_code_hash([package_root], "", hasher)
    return hasher.hexdigest()
33+
34+
35+
@functools.cache
def torch_key_wrapper() -> str:
    """
    Return inductor's ``torch_key()`` rendered as a hex string.

    The value is memoized for the lifetime of the process.
    """
    from torch._inductor.codecache import torch_key

    raw_key = torch_key()
    return raw_key.hex()
40+
41+
42+
@functools.cache
def triton_key_wrapper() -> str:
    """
    Return the value of inductor's ``triton_key()`` helper.

    The value is memoized for the lifetime of the process.
    """
    from torch._inductor.runtime.triton_compat import triton_key

    key = triton_key()
    return key
47+
48+
49+
class CacheKeyBase:
    """
    Mixin supplying helpers shared by all cache key dataclasses.
    """

    def stable_hash(self) -> str:
        """
        Return the hex SHA-256 digest of this object's UTF-8 encoded ``repr``.
        """
        digest = hashlib.sha256()
        digest.update(repr(self).encode("utf-8"))
        return digest.hexdigest()
56+
57+
58+
@dataclasses.dataclass(frozen=True)
class BoundKernelInMemoryCacheKey(CacheKeyBase):
    """
    Key used by the default in-memory cache.

    Fields:

    specialization_key: Describes every kernel input. For tensors this
        covers properties such as device, shape and size.
    extra_results: Records the decisions made through `hl.specialize`.
    """

    specialization_key: tuple[Hashable, ...]
    extra_results: tuple[Hashable, ...]
72+
73+
74+
@dataclasses.dataclass(frozen=True)
class LooseAutotuneCacheKey(BoundKernelInMemoryCacheKey):
    """
    Autotune Cache key to use for most use cases.

    This key includes (in addition to BoundKernelInMemoryCacheKey):

    kernel_source_hash: Hash of source code of input Helion kernel
    hardware: Hardware of the input device
    runtime_name: Version of the CUDA/ROCm runtime

    ``stable_hash()`` is inherited unchanged from ``CacheKeyBase``; it hashes
    the dataclass ``repr``, which already covers the fields added here.
    """

    kernel_source_hash: str
    hardware: str
    runtime_name: str
92+
93+
94+
@dataclasses.dataclass(frozen=True)
class StrictAutotuneCacheKey(LooseAutotuneCacheKey):
    """
    Autotune Cache key to use for utmost strictness in terms of re-autotuning
    when library source code changes.

    This key includes (in addition to LooseAutotuneCacheKey):

    helion_key: Hash of source code of Helion
    torch_key: Hash of source code of PyTorch
    triton_key: Hash of source code of Triton

    Each of these fields is filled in by its default factory when not passed
    explicitly, so a strict key can be built from the loose key's fields alone.
    """

    helion_key: str = dataclasses.field(default_factory=helion_key)
    torch_key: str = dataclasses.field(default_factory=torch_key_wrapper)
    triton_key: str = dataclasses.field(default_factory=triton_key_wrapper)
110+
111+
112+
class AutotuneCacheBase(abc.ABC):
    """
    Abstract base class that all autotune caches need to implement.
    Any user defined cache will need to extend this class, and
    provide implementations for get and put methods.
    """

    def __init__(
        self, kernel: BoundKernel, args: Sequence[object], autotuner: BaseSearch
    ) -> None:
        """
        Store the bound kernel, its call arguments and the wrapped autotuner.

        The autotuner is only invoked by ``autotune()`` on a cache miss.
        """
        self.autotuner = autotuner
        self.kernel = kernel
        self.args = args

    @abc.abstractmethod
    def get(self) -> Config | None:
        """Return the cached config, or None when there is no usable entry."""
        raise NotImplementedError

    @abc.abstractmethod
    def put(self, config: Config) -> None:
        """Store ``config`` so that a later ``get()`` can return it."""
        raise NotImplementedError

    def autotune(self) -> Config:
        """
        Return the cached config if present, otherwise autotune and cache it.

        Updates ``counters["autotune"]`` with ``cache_hit``, ``cache_miss``
        and ``cache_put`` events. A failure while storing the freshly tuned
        config is logged and otherwise ignored: the config is still returned.
        """
        if (config := self.get()) is not None:
            counters["autotune"]["cache_hit"] += 1
            log.debug("cache hit: %s", config)
            return config

        counters["autotune"]["cache_miss"] += 1
        log.debug("cache miss")

        config = self.autotuner.autotune()

        try:
            self.put(config)
        except Exception:
            # The search result is valid regardless of whether it could be
            # persisted; do not throw away the tuning work because the cache
            # backend (disk, permissions, ...) misbehaved.
            log.warning("failed to store autotune result in cache", exc_info=True)
        else:
            counters["autotune"]["cache_put"] += 1
            log.debug("cache put: %s", config)

        return config

helion/autotuner/local_cache.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from __future__ import annotations
2+
3+
import hashlib
4+
import inspect
5+
import logging
6+
import os
7+
from pathlib import Path
8+
import textwrap
9+
from typing import TYPE_CHECKING
10+
from typing import Sequence
11+
12+
import torch
13+
14+
from ..runtime.config import Config
15+
from .base_cache import AutotuneCacheBase
16+
from .base_cache import LooseAutotuneCacheKey
17+
from .base_cache import StrictAutotuneCacheKey
18+
19+
if TYPE_CHECKING:
20+
from ..runtime.kernel import BoundKernel
21+
from .base_search import BaseSearch
22+
23+
log: logging.Logger = logging.getLogger(__name__)
24+
25+
26+
class LocalAutotuneCache(AutotuneCacheBase):
    """
    This class implements the local autotune cache, storing the
    best config artifact on the local file system either by default
    on torch's cache directory, or at a user specified HELION_CACHE_DIR
    directory.
    It uses the LooseAutotuneCacheKey implementation for the cache key
    which takes into account device and source code properties, but does
    not account for library level code changes such as Triton, Helion or
    PyTorch. Use StrictLocalAutotuneCache to consider these properties.
    """

    def __init__(
        self, kernel: BoundKernel, args: Sequence[object], autotuner: BaseSearch
    ) -> None:
        """Initialize the base cache and compute the cache key once up front."""
        super().__init__(kernel, args, autotuner)
        self.key = self._generate_key()

    def _generate_key(self) -> LooseAutotuneCacheKey:
        """
        Build the cache key from the kernel, its arguments and the GPU in use.

        Raises:
            RuntimeError: if none of the arguments is a tensor on a
                CUDA/ROCm device, so no hardware can be identified.
        """
        in_memory_cache_key = self.kernel.kernel._create_bound_kernel_cache_key(
            self.kernel,
            tuple(self.args),
            self.kernel.kernel.specialization_key(self.args),
        )
        kernel_source = textwrap.dedent(inspect.getsource(self.kernel.kernel.fn))
        kernel_source_hash = hashlib.sha256(kernel_source.encode("utf-8")).hexdigest()

        hardware = None
        runtime_name = None

        for arg in self.args:
            if not isinstance(arg, torch.Tensor):
                continue
            # torch.cuda.get_device_properties rejects non-GPU devices, so
            # tensors living elsewhere (e.g. on the CPU) must be skipped.
            # ROCm devices also report the "cuda" device type.
            if arg.device.type != "cuda":
                continue
            device_properties = torch.cuda.get_device_properties(arg.device)
            if torch.version.cuda is not None:  # pyright: ignore[reportAttributeAccessIssue]
                hardware = device_properties.name
                runtime_name = torch.version.cuda  # pyright: ignore[reportAttributeAccessIssue]
            else:
                hardware = device_properties.gcnArchName
                runtime_name = torch.version.hip  # pyright: ignore[reportAttributeAccessIssue]

        # Raise explicitly rather than assert: asserts vanish under `python -O`
        # and a key containing None would then be silently produced.
        if hardware is None or runtime_name is None:
            raise RuntimeError(
                "cannot build autotune cache key: no CUDA/ROCm tensor argument "
                "was found to identify the hardware and runtime"
            )
        return LooseAutotuneCacheKey(
            specialization_key=in_memory_cache_key.specialization_key,
            extra_results=in_memory_cache_key.extra_results,
            kernel_source_hash=kernel_source_hash,
            hardware=hardware,
            runtime_name=runtime_name,
        )

    def _get_local_cache_path(self) -> Path:
        """
        Return the file path of the cache entry for this key.

        The directory is HELION_CACHE_DIR when that environment variable is
        set, otherwise a "helion" directory under torch inductor's cache dir.
        """
        if (user_path := os.environ.get("HELION_CACHE_DIR", None)) is not None:
            cache_path = Path(user_path)
        else:
            from torch._inductor.runtime.cache_dir_utils import (
                cache_dir,  # pyright: ignore[reportPrivateImportUsage]
            )

            cache_path = Path(cache_dir()) / "helion"

        return cache_path / f"{self.key.stable_hash()}.best_config"

    def get(self) -> Config | None:
        """
        Load the cached config from disk, or return None when unavailable.

        Lookup is best effort: a missing entry is a normal miss, and any other
        failure to load is logged at debug level and also treated as a miss.
        """
        path = self._get_local_cache_path()
        try:
            return Config.load(path)
        except FileNotFoundError:
            # Nothing has been cached for this key yet.
            return None
        except Exception:
            # Unreadable or invalid entry: fall back to autotuning, but leave
            # a trace so the cause can be diagnosed.
            log.debug("ignoring unusable cache entry %s", path, exc_info=True)
            return None

    def put(self, config: Config) -> None:
        """Write ``config`` to this key's cache file."""
        path = self._get_local_cache_path()
        config.save(path)
97+
98+
99+
class StrictLocalAutotuneCache(LocalAutotuneCache):
    """
    Local autotune cache variant with a stricter key: it also takes
    library level code changes (Triton, Helion, PyTorch) into account.
    """

    def _generate_key(self) -> StrictAutotuneCacheKey:
        """
        Build the strict key by copying every field of the loose key.

        The additional library-hash fields are not passed here, so
        StrictAutotuneCacheKey supplies them itself.
        """
        loose_fields = vars(super()._generate_key())
        return StrictAutotuneCacheKey(**loose_fields)

helion/runtime/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def from_json(cls, json_str: str) -> Config:
118118

119119
def save(self, path: str | Path) -> None:
    """
    Write this config as JSON to ``path``.

    Missing parent directories are created first.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    serialized = self.to_json()
    target.write_text(serialized)
122123

123124
@classmethod

helion/runtime/kernel.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from torch._guards import Source
4646

4747
from ..autotuner import ConfigSpec
48+
from ..autotuner.base_cache import BoundKernelInMemoryCacheKey
4849

4950
ConfigLike = Config | dict[str, object]
5051

@@ -53,12 +54,6 @@
5354
CompiledConfig = Callable[..., _R]
5455

5556

56-
@dataclasses.dataclass(frozen=True)
57-
class BoundKernelInMemoryCacheKey:
58-
specialization_key: tuple[Hashable, ...]
59-
extra_results: tuple[Hashable, ...]
60-
61-
6257
class Kernel(Generic[_R]):
6358
def __init__(
6459
self,
@@ -114,6 +109,8 @@ def __init__(
114109
def _get_bound_kernel_cache_key(
115110
self, args: tuple[object, ...], signature: tuple[Hashable, ...]
116111
) -> BoundKernelInMemoryCacheKey | None:
112+
from ..autotuner.base_cache import BoundKernelInMemoryCacheKey
113+
117114
extra_fns = self._specialize_extra.get(signature)
118115
if extra_fns is not None:
119116
extra_results: tuple[Hashable, ...] = tuple([s(args) for s in extra_fns])
@@ -126,6 +123,8 @@ def _create_bound_kernel_cache_key(
126123
args: tuple[object, ...],
127124
signature: tuple[Hashable, ...],
128125
) -> BoundKernelInMemoryCacheKey:
126+
from ..autotuner.base_cache import BoundKernelInMemoryCacheKey
127+
129128
self._specialize_extra[signature] = extra_fns = bound_kernel._specialize_extra()
130129
extra_results: tuple[Hashable, ...] = tuple([s(args) for s in extra_fns])
131130
return BoundKernelInMemoryCacheKey(signature, extra_results)
@@ -458,12 +457,18 @@ def autotune(
458457
self.settings.check_autotuning_disabled()
459458

460459
from ..autotuner import DifferentialEvolutionSearch
460+
from ..autotuner import LocalAutotuneCache
461461

462-
config = DifferentialEvolutionSearch(
462+
config = LocalAutotuneCache(
463463
self,
464464
args,
465-
**kwargs, # pyright: ignore[reportArgumentType]
465+
DifferentialEvolutionSearch(
466+
self,
467+
args,
468+
**kwargs, # pyright: ignore[reportArgumentType]
469+
),
466470
).autotune()
471+
467472
self.set_config(config)
468473
return config
469474

0 commit comments

Comments
 (0)