Skip to content

Commit 521dcc7

Browse files
committed
Fix hl.rand to use tile specific offsets instead of fixed offsets, ensure unique random num per tile
stack-info: PR: #685, branch: karthickai/stack/3
1 parent c3a59f2 commit 521dcc7

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

helion/language/random_ops.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,24 @@ def _rand_codegen(state: CodegenState) -> ast.AST:
8181
fake_value = state.fake_value
8282
assert isinstance(fake_value, torch.Tensor)
8383
shape_str = state.device_function.tile_strategy.shape_str(fake_value.size())
84-
85-
numel = " * ".join(shape_str.strip("[]").split(","))
8684
seed_ast = state.ast_arg(1)
87-
offs_expr = f"tl.arange(0, {numel}).reshape({shape_str})"
85+
offs_expr = None
86+
env = CompileEnvironment.current()
87+
for size in fake_value.size():
88+
block_id = env.get_block_id(size)
89+
if block_id is not None:
90+
if len(fake_value.size()) == 1:
91+
# 1D: use indices_0 directly, it already has the right values
92+
index_var = state.codegen.index_var(block_id)
93+
offs_expr = f"{index_var}.reshape({shape_str})"
94+
else:
95+
# N_D: use offset_0 + full range
96+
offset_var = state.codegen.offset_var(block_id)
97+
numel = " * ".join(shape_str.strip("[]").split(","))
98+
offs_expr = (
99+
f"({offset_var} + tl.arange(0, {numel})).reshape({shape_str})"
100+
)
88101
expr = f"tl.rand({{seed}}, {offs_expr})"
89-
90102
return expr_from_string(expr, seed=seed_ast)
91103

92104

test/test_rng.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def rand_kernel_tiled_1d(x: torch.Tensor, seed: int) -> torch.Tensor:
366366
"Different seeds should produce different outputs",
367367
)
368368

369-
_, output3 = code_and_output(rand_kernel_tiled_1d, (x_small, 42))
369+
code3, output3 = code_and_output(rand_kernel_tiled_1d, (x_small, 42))
370370
self.assertTrue(
371371
torch.allclose(output, output3),
372372
"Same seed should produce identical outputs",
@@ -376,6 +376,8 @@ def rand_kernel_tiled_1d(x: torch.Tensor, seed: int) -> torch.Tensor:
376376
self.assertTrue(torch.all(output >= 0.0), "All values should be >= 0")
377377
self.assertTrue(torch.all(output < 1.0), "All values should be < 1")
378378

379+
self.assertIn("tl.rand(seed, indices_0", code3)
380+
379381
def test_hl_rand_2d(self):
380382
@helion.kernel
381383
def rand_kernel_tiled_2d(x: torch.Tensor, seed: int) -> torch.Tensor:
@@ -394,14 +396,15 @@ def rand_kernel_tiled_2d(x: torch.Tensor, seed: int) -> torch.Tensor:
394396
"Different seeds should produce different outputs",
395397
)
396398

397-
_, output3 = code_and_output(rand_kernel_tiled_2d, (x_small, 42))
399+
code3, output3 = code_and_output(rand_kernel_tiled_2d, (x_small, 42))
398400
self.assertTrue(
399401
torch.allclose(output, output3),
400402
"Same seed should produce identical outputs",
401403
)
402404

403405
self.assertTrue(torch.all(output >= 0.0), "All values should be >= 0")
404406
self.assertTrue(torch.all(output < 1.0), "All values should be < 1")
407+
self.assertIn("tl.rand(seed, (offset_0 + tl.arange(0,", code3)
405408

406409
def test_hl_rand_3d(self):
407410
@helion.kernel
@@ -423,7 +426,7 @@ def rand_kernel_tiled_3d(x: torch.Tensor, seed: int) -> torch.Tensor:
423426
"Different seeds should produce different outputs",
424427
)
425428

426-
_, output3 = code_and_output(rand_kernel_tiled_3d, (x_small, 42))
429+
code3, output3 = code_and_output(rand_kernel_tiled_3d, (x_small, 42))
427430
self.assertTrue(
428431
torch.allclose(output, output3),
429432
"Same seed should produce identical outputs",
@@ -438,6 +441,7 @@ def rand_kernel_tiled_3d(x: torch.Tensor, seed: int) -> torch.Tensor:
438441
0.4 < mean_val < 0.6,
439442
f"Mean {mean_val:.3f} should be around 0.5 for uniform distribution",
440443
)
444+
self.assertIn("tl.rand(seed, (offset_0 + tl.arange(0,", code3)
441445

442446

443447
if __name__ == "__main__":

0 commit comments

Comments
 (0)