Skip to content

Commit 29ceafb

Browse files
committed
Fix hl.rand to use tile-specific offsets instead of fixed offsets, ensuring unique random numbers per tile
stack-info: PR: #685, branch: karthickai/stack/3
1 parent 38b9967 commit 29ceafb

File tree

4 files changed

+540
-0
lines changed

4 files changed

+540
-0
lines changed

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .matmul_ops import dot as dot
2222
from .memory_ops import load as load
2323
from .memory_ops import store as store
24+
from .random_ops import rand as rand
2425
from .reduce_ops import reduce as reduce
2526
from .scan_ops import associative_scan as associative_scan
2627
from .scan_ops import cumprod as cumprod

helion/language/random_ops.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
import torch
6+
7+
from .._compiler.ast_extension import expr_from_string
8+
from .._compiler.compile_environment import CompileEnvironment
9+
from .._compiler.device_function import SymbolArgument
10+
from ..exc import NotInsideKernel
11+
from . import _decorators
12+
from .ref_tile import RefTile
13+
14+
if TYPE_CHECKING:
15+
import ast
16+
17+
from .._compiler.inductor_lowering import CodegenState
18+
19+
__all__ = ["rand"]
20+
21+
22+
@_decorators.api(tiles_as_sizes=True)
def rand(
    shape: list[object],
    seed: int,
    dtype: torch.dtype = torch.float32,
    device: torch.device | None = None,
) -> torch.Tensor:
    """
    Create a device tensor of random values from an explicitly passed seed.

    The main purpose of ``hl.rand`` is to explicitly pass a seed argument for
    deterministic randomness in Helion kernels, whereas ``torch.rand_like`` does
    not take a seed argument (though it can be seeded globally). ``hl.rand``
    lowers to ``tl.rand(seed, offset)``, where ``offset`` is a linearized index
    built from the per-dimension index variables of the requested shape, so
    different tiles are given different offsets.

    Note:
        Only use within ``hl.tile()`` loops for creating local tensors.
        For host allocations, use ``torch.rand()``.

    Args:
        shape: A list of sizes (tiles are accepted and treated as sizes)
        seed: int seed for the random number generator
        dtype: currently only float32 supported
        device: device of the result; when ``None`` the device of the current
            compile environment is used

    Returns:
        torch.Tensor: A device tensor of the given shape and dtype filled with random values

    Raises:
        NotInsideKernel: raised by this Python body itself, i.e. whenever the
            call is not intercepted by one of the registered implementations
            (presumably when used outside of a Helion kernel).

    Examples:
        .. code-block:: python

            @helion.kernel
            def process_kernel(x: torch.Tensor, seed: int) -> torch.Tensor:
                output = torch.zeros_like(x)
                (m,) = x.shape
                for (tile_m,) in hl.tile([m]):
                    output[tile_m] = hl.rand([tile_m], seed=seed)
                return output

    """
    raise NotInsideKernel
60+
61+
62+
@_decorators.register_fake(rand)
def _rand_fake(
    shape: list[int | torch.SymInt],
    seed: int,
    dtype: torch.dtype = torch.float32,
    device: torch.device | None = None,
) -> torch.Tensor:
    """
    Fake (shape-propagation) implementation of ``hl.rand``.

    Registers ``shape`` with the current compile environment and returns an
    uninitialized tensor carrying the requested shape, dtype and device. The
    ``seed`` is not used here.

    Args:
        shape: list or tuple of sizes for the result
        seed: unused by the fake implementation
        dtype: dtype of the returned tensor
        device: device of the returned tensor; the compile environment's
            device is used when ``None``

    Returns:
        torch.Tensor: result of ``torch.empty`` with the requested metadata

    Raises:
        TypeError: if ``shape`` is neither a list nor a tuple
    """
    if isinstance(shape, (list, tuple)):
        env = CompileEnvironment.current()
        env.add_kernel_tensor_size(shape)
        # Only fall back to the environment's device when none was requested.
        target_device = device if device is not None else env.device
        return torch.empty(list(shape), dtype=dtype, device=target_device)
    raise TypeError(f"Expected list[SymInt], got {type(shape).__name__}")
78+
79+
80+
@_decorators.codegen(rand)
def _rand_codegen(state: CodegenState) -> ast.AST:
    """
    Emit a ``tl.rand(seed, offset)`` expression for an ``hl.rand`` call.

    ``offset`` is a row-major style linearization of one index expression per
    output dimension: each index is broadcast to the full rank and multiplied
    by the product of the size names of all trailing dimensions.

    Args:
        state: codegen state for the call; provides the fake result tensor,
            the AST of the ``seed`` argument (positional index 1), the device
            function's arguments and the per-block index variables.

    Returns:
        ast.AST: the ``tl.rand(...)`` call expression

    Raises:
        ValueError: if the fake result tensor is zero-dimensional
        RuntimeError: if a dimension has no block id and no unused
            ``_RDIM_SIZE_*`` argument is left to stand in for it
    """
    fake = state.fake_value
    assert isinstance(fake, torch.Tensor)

    dims = fake.size()
    rank = len(dims)
    if rank == 0:
        raise ValueError("hl.rand() requires at least one dimension")

    seed_ast = state.ast_arg(1)
    env = CompileEnvironment.current()

    # Scan the kernel arguments once: symbol arguments (except the seed) are
    # candidate size names, ``_RDIM_SIZE_*`` arguments mark reduction dims.
    symbol_names = []
    rdim_by_name = {}
    for arg in state.device_function.arguments:
        if isinstance(arg, SymbolArgument) and arg.name != "seed":
            symbol_names.append(arg.name)
        elif arg.name.startswith("_RDIM_SIZE_"):
            rdim_by_name[arg.name] = arg

    index_exprs = []  # one index expression string per output dimension
    extents = []  # one size name per output dimension, used for strides
    claimed_rdims = set()
    next_symbol = 0

    for dim in dims:
        block_id = env.get_block_id(dim)
        if block_id is None:
            # No block for this dimension: borrow the first rdim argument
            # that has not been consumed by an earlier dimension.
            rdim_name = next(
                (name for name in rdim_by_name if name not in claimed_rdims),
                None,
            )
            if rdim_name is None:
                raise RuntimeError(
                    "hl.rand() requires tiled dimensions. "
                    "Use hl.rand() inside hl.tile() loops with tile variables."
                )
        else:
            rdim_name = f"_RDIM_SIZE_{block_id}"
            if rdim_name not in rdim_by_name:
                # Regular tiled dimension: reuse the block's index variable.
                index_exprs.append(state.codegen.index_var(block_id))
                # NOTE(review): tiled dims are paired with symbol arguments
                # purely by position -- confirm the ordering always matches.
                if next_symbol < len(symbol_names):
                    extents.append(symbol_names[next_symbol])
                    next_symbol += 1
                else:
                    extents.append(str(dim))
                continue

        # Reduction-style dimension: index over the whole ``_RDIM_SIZE_*``.
        index_exprs.append(f"tl.arange(0, {rdim_name})")
        extents.append(rdim_name)
        claimed_rdims.add(rdim_name)

    if rank == 1:
        # A single dimension needs neither broadcasting nor a stride.
        offset_expr = expr_from_string(index_exprs[0])
    else:
        terms = []
        for axis, index_expr in enumerate(index_exprs):
            # Keep ``axis`` and insert new axes everywhere else.
            selector = ", ".join(
                ":" if other == axis else "None" for other in range(rank)
            )
            term = f"{index_expr}[{selector}]"
            trailing = extents[axis + 1 :]
            if trailing:
                # Stride of this axis = product of all trailing size names.
                term = f"{term} * {' * '.join(trailing)}"
            terms.append(term)
        offset_expr = expr_from_string(" + ".join(terms))

    return expr_from_string(
        "tl.rand({seed}, {offset})", seed=seed_ast, offset=offset_expr
    )
164+
165+
166+
@_decorators.get_masked_value(rand)
def _(node: torch.fx.Node) -> float:
    """
    Return the masked value registered for ``hl.rand`` nodes.

    The value is always ``0`` and does not depend on ``node``.
    NOTE(review): presumably this is the value assumed for masked-out
    elements of the result -- confirm against the decorator's contract.
    """
    return 0
171+
172+
173+
@_decorators.ref(rand)
def _(
    shape: list[int | RefTile],
    seed: int,
    dtype: torch.dtype = torch.float32,
    device: torch.device | None = None,
) -> torch.Tensor:
    """
    Reference implementation of ``hl.rand`` built on ``torch.rand``.

    Args:
        shape: sizes of the result; a ``RefTile`` contributes its extent
            (``end - begin``), anything else is converted with ``int``
        seed: seed for a fresh ``torch.Generator`` used only for this call
        dtype: dtype of the returned tensor
        device: device for both the generator and the result; the compile
            environment's device is used when ``None``

    Returns:
        torch.Tensor: random tensor of the resolved shape
    """
    sizes: list[int] = [
        dim.end - dim.begin if isinstance(dim, RefTile) else int(dim)
        for dim in shape
    ]
    env = CompileEnvironment.current()
    target_device = device if device is not None else env.device
    # A dedicated generator keeps the result a function of ``seed`` alone,
    # independent of the global RNG state.
    generator = torch.Generator(device=target_device)
    generator.manual_seed(seed)
    return torch.rand(
        sizes,
        dtype=dtype,
        generator=generator,
        device=target_device,
    )

test/test_random.expected

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
This file is automatically generated by assertExpectedJournal calls in test_random.py.
2+
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
3+
4+
--- assertExpectedJournal(TestRandom.test_hl_rand_1d)
5+
from __future__ import annotations
6+
7+
import torch
8+
import triton
9+
import triton.language as tl
10+
from helion.runtime import default_launcher as _default_launcher
11+
12+
@triton.jit
13+
def _helion_rand_kernel_tiled_1d(output, output_stride_0, m, seed, _BLOCK_SIZE_0: tl.constexpr):
14+
pid_0 = tl.program_id(0)
15+
offset_0 = pid_0 * _BLOCK_SIZE_0
16+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
17+
mask_0 = indices_0 < m
18+
rand = tl.rand(seed, indices_0)
19+
tl.store(output + indices_0 * output_stride_0, rand, mask_0)
20+
21+
def rand_kernel_tiled_1d(x: torch.Tensor, seed: int, *, _launcher=_default_launcher):
22+
output = torch.zeros_like(x)
23+
m, = x.shape
24+
_BLOCK_SIZE_0 = 128
25+
_launcher(_helion_rand_kernel_tiled_1d, (triton.cdiv(m, _BLOCK_SIZE_0),), output, output.stride(0), m, seed, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
26+
return output
27+
28+
--- assertExpectedJournal(TestRandom.test_hl_rand_2d)
29+
from __future__ import annotations
30+
31+
import torch
32+
import triton
33+
import triton.language as tl
34+
from helion.runtime import default_launcher as _default_launcher
35+
36+
@triton.jit
37+
def _helion_rand_kernel_tiled_2d(output, output_stride_0, output_stride_1, m, n, seed, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
38+
num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0)
39+
pid_0 = tl.program_id(0) % num_blocks_0
40+
pid_1 = tl.program_id(0) // num_blocks_0
41+
offset_0 = pid_0 * _BLOCK_SIZE_0
42+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
43+
mask_0 = indices_0 < m
44+
offset_1 = pid_1 * _BLOCK_SIZE_1
45+
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
46+
mask_1 = indices_1 < n
47+
rand = tl.rand(seed, indices_0[:, None] * n + indices_1[None, :])
48+
tl.store(output + (indices_0[:, None] * output_stride_0 + indices_1[None, :] * output_stride_1), rand, mask_0[:, None] & mask_1[None, :])
49+
50+
def rand_kernel_tiled_2d(x: torch.Tensor, seed: int, *, _launcher=_default_launcher):
51+
output = torch.zeros_like(x)
52+
m, n = x.shape
53+
_BLOCK_SIZE_0 = 32
54+
_BLOCK_SIZE_1 = 32
55+
_launcher(_helion_rand_kernel_tiled_2d, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), output, output.stride(0), output.stride(1), m, n, seed, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
56+
return output
57+
58+
--- assertExpectedJournal(TestRandom.test_hl_rand_3d)
59+
from __future__ import annotations
60+
61+
import torch
62+
import triton
63+
import triton.language as tl
64+
from helion.runtime import default_launcher as _default_launcher
65+
66+
@triton.jit
67+
def _helion_rand_kernel_tiled_3d(output, output_stride_0, output_stride_1, output_stride_2, b, m, n, seed, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
68+
num_blocks_0 = tl.cdiv(b, _BLOCK_SIZE_0)
69+
num_blocks_1 = tl.cdiv(m, _BLOCK_SIZE_1)
70+
pid_0 = tl.program_id(0) % num_blocks_0
71+
pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
72+
pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
73+
offset_0 = pid_0 * _BLOCK_SIZE_0
74+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
75+
mask_0 = indices_0 < b
76+
offset_1 = pid_1 * _BLOCK_SIZE_1
77+
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
78+
mask_1 = indices_1 < m
79+
offset_2 = pid_2 * _BLOCK_SIZE_2
80+
indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
81+
mask_2 = indices_2 < n
82+
rand = tl.rand(seed, indices_0[:, None, None] * m * n + indices_1[None, :, None] * n + indices_2[None, None, :])
83+
tl.store(output + (indices_0[:, None, None] * output_stride_0 + indices_1[None, :, None] * output_stride_1 + indices_2[None, None, :] * output_stride_2), rand, mask_0[:, None, None] & mask_1[None, :, None] & mask_2[None, None, :])
84+
85+
def rand_kernel_tiled_3d(x: torch.Tensor, seed: int, *, _launcher=_default_launcher):
86+
output = torch.zeros_like(x)
87+
b, m, n = x.shape
88+
_BLOCK_SIZE_0 = 16
89+
_BLOCK_SIZE_1 = 16
90+
_BLOCK_SIZE_2 = 16
91+
_launcher(_helion_rand_kernel_tiled_3d, (triton.cdiv(b, _BLOCK_SIZE_0) * triton.cdiv(m, _BLOCK_SIZE_1) * triton.cdiv(n, _BLOCK_SIZE_2),), output, output.stride(0), output.stride(1), output.stride(2), b, m, n, seed, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
92+
return output
93+
94+
--- assertExpectedJournal(TestRandom.test_hl_rand_non_tiled_dimensions)
95+
from __future__ import annotations
96+
97+
import torch
98+
import triton
99+
import triton.language as tl
100+
from helion.runtime import default_launcher as _default_launcher
101+
102+
@triton.jit
103+
def _helion_rand_kernel_partial_tile(output, output_stride_0, output_stride_1, output_stride_2, m, n, seed, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr):
104+
num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0)
105+
pid_0 = tl.program_id(0) % num_blocks_0
106+
pid_1 = tl.program_id(0) // num_blocks_0
107+
offset_0 = pid_0 * _BLOCK_SIZE_0
108+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
109+
mask_0 = indices_0 < m
110+
offset_1 = pid_1 * _BLOCK_SIZE_1
111+
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
112+
mask_1 = indices_1 < n
113+
indices_2 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
114+
rand = tl.rand(seed, indices_0[:, None, None] * n * _RDIM_SIZE_2 + indices_1[None, :, None] * _RDIM_SIZE_2 + tl.arange(0, _RDIM_SIZE_2)[None, None, :])
115+
tl.store(output + (indices_0[:, None, None] * output_stride_0 + indices_1[None, :, None] * output_stride_1 + indices_2[None, None, :] * output_stride_2), rand, mask_0[:, None, None] & mask_1[None, :, None])
116+
117+
def rand_kernel_partial_tile(x: torch.Tensor, seed: int, *, _launcher=_default_launcher):
118+
output = torch.zeros_like(x)
119+
m, n, k = x.shape
120+
_BLOCK_SIZE_0 = 32
121+
_BLOCK_SIZE_1 = 32
122+
_RDIM_SIZE_2 = 8
123+
_launcher(_helion_rand_kernel_partial_tile, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), output, output.stride(0), output.stride(1), output.stride(2), m, n, seed, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _RDIM_SIZE_2, num_warps=4, num_stages=3)
124+
return output

0 commit comments

Comments
 (0)