
Commit 4458226

[Benchmark] Add low mem dropout example
stack-info: PR: #641, branch: karthickai/stack/1
1 parent c3a59f2 commit 4458226

File tree

benchmarks/run.py
examples/low_mem_dropout.py
test/test_examples.expected
test/test_examples.py

4 files changed: +207 -0 lines changed

benchmarks/run.py

Lines changed: 13 additions & 0 deletions
@@ -210,6 +210,11 @@ class RunResult:
         "examples.int4_gemm",
         "int4_gemm_tritonbench",
     ),
+    "low_mem_dropout": (
+        "tritonbench.operators.low_mem_dropout.operator",
+        "examples.low_mem_dropout",
+        "low_mem_dropout_tritonbench",
+    ),
 }


@@ -334,6 +339,14 @@ class RunResult:
         "helion_grouped_gemm_jagged_persistent_tritonbench-speedup": "helion_speedup",
         "helion_grouped_gemm_jagged_persistent_tritonbench-accuracy": "helion_accuracy",
     },
+    "low_mem_dropout": {
+        "seeded_dropout-accuracy": "triton_accuracy",
+        "seeded_dropout-speedup": "triton_speedup",
+        "torch_compile_dropout-accuracy": "torch_compile_accuracy",
+        "torch_compile_dropout-speedup": "torch_compile_speedup",
+        "helion_low_mem_dropout_tritonbench-accuracy": "helion_accuracy",
+        "helion_low_mem_dropout_tritonbench-speedup": "helion_speedup",
+    },
 }

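For context: the first hunk registers the operator as a key mapping to a 3-tuple of (TritonBench operator module, Helion example module, wrapper function name), and the second hunk maps TritonBench's per-backend metric columns ("seeded_dropout", "torch_compile_dropout", the Helion wrapper) onto the harness's triton / torch_compile / helion speedup and accuracy columns. Below is a minimal sketch of how such a tuple could be resolved into a callable; the helper name and resolution logic are illustrative assumptions, not run.py's actual internals.

# Illustrative only: how an entry such as
#   "low_mem_dropout": ("tritonbench.operators.low_mem_dropout.operator",
#                       "examples.low_mem_dropout", "low_mem_dropout_tritonbench")
# could be turned into a callable. run.py's real plumbing may differ.
import importlib
from typing import Callable

def resolve_helion_wrapper(example_module: str, wrapper_name: str) -> Callable:
    """Import the Helion example module and look up its TritonBench wrapper."""
    module = importlib.import_module(example_module)  # e.g. examples.low_mem_dropout
    return getattr(module, wrapper_name)              # e.g. low_mem_dropout_tritonbench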

examples/low_mem_dropout.py

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
+"""
+Low-Memory Dropout Example
+==========================
+
+This example demonstrates how to implement low-memory (seeded) dropout using Helion.
+"""
+
+# %%
+# Imports
+# -------
+from __future__ import annotations
+
+from typing import Callable
+
+import torch
+
+import helion
+import helion.language as hl
+
+
+# %%
+# Low mem dropout forward implementation
+# -------------------
+@helion.kernel()
+def low_mem_dropout(p: float, x: torch.Tensor, seed: int) -> torch.Tensor:
+    """
+    Applies dropout to x with probability p, regenerating the mask from seed.
+    Args:
+        p (float): dropout probability
+        x (torch.Tensor): input tensor
+    Returns:
+        Output tensor
+    """
+    scale = 1.0 / (1.0 - p)
+    # flatten to 1D so we can use tile
+    n = x.numel()
+    x_flat = x.view(-1)
+    out_flat = torch.empty_like(x_flat)
+    for tidx in hl.tile(n):
+        xi = x_flat[tidx].to(torch.float32)
+        r = hl.rand([tidx], seed=seed)
+        keep = r > p
+        yscaled = xi * scale
+        yi = torch.where(keep, yscaled, 0.0)
+        out_flat[tidx] = yi.to(x.dtype)
+    return out_flat.view_as(x)
+
+
+# %%
+# Low mem dropout backward implementation
+# -------------------
+@helion.kernel()
+def low_mem_dropout_bwd(p: float, grad_y: torch.Tensor, seed: int) -> torch.Tensor:
+    """
+    Low-memory dropout regenerates the randomness from the seed inside both the
+    forward and backward passes, so the backward pass mirrors the forward computation.
+    Args:
+        p (float): dropout probability
+        grad_y (torch.Tensor): gradient tensor
+    Returns:
+        Output tensor
+    """
+    scale = 1.0 / (1.0 - p)
+    n = grad_y.numel()
+    grad_y_flat = grad_y.view(-1)
+    out_flat = torch.empty_like(grad_y_flat)
+    for tidx in hl.tile(n):
+        gi = grad_y_flat[tidx].to(torch.float32)
+        r = hl.rand([tidx], seed=seed)
+        keep = r > p
+        g_scaled = gi * scale
+        gxi = torch.where(keep, g_scaled, 0.0)
+        out_flat[tidx] = gxi.to(grad_y.dtype)
+    return out_flat.view_as(grad_y)
+
+
+# %%
+# TritonBench Wrapper
+# -------------------
+def low_mem_dropout_tritonbench(tb_op: object, p: float, x: torch.Tensor) -> Callable:
+    """
+    Wrapper for TritonBench compatibility.
+
+    Args:
+        tb_op: TritonBench operator instance
+        p (float): dropout probability
+        x (torch.Tensor): input tensor
+
+    Returns:
+        Callable: A function that performs the low_mem_dropout.
+    """
+
+    def _inner() -> torch.Tensor:
+        return low_mem_dropout(p, x, seed=123)
+
+    return _inner
+
+
+# %%
+# Verification Function
+# -------------------
+def check(p: float, size: int) -> None:
+    """
+    Verify that the forward and backward low-memory dropout kernels drop the same elements for a given seed.
+
+    Args:
+        p (float): dropout probability
+        size (int): input tensor size
+    """
+    x = torch.randn(size=(size,)).cuda()
+    seed = 123
+
+    out = low_mem_dropout(p, x, seed)
+    grad_y = torch.ones_like(x)
+    grad_x = low_mem_dropout_bwd(p, grad_y, seed)
+    mask_fwd = out != 0
+    mask_bwd = grad_x != 0
+    assert torch.equal(mask_fwd, mask_bwd)
+
+
+# %%
+# Main Function
+# -----------
+def main() -> None:
+    """
+    Main entry point that runs the low-memory dropout kernel verification with different tensor sizes.
+    Tests with two configurations:
+    - p=0.25, size=8192
+    - p=0.25, size=32768
+    """
+    check(0.25, 8192)
+    check(0.25, 32768)
+
+
+if __name__ == "__main__":
+    main()
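
The point of the "low memory" variant is that the dropout mask is never written out or saved between the forward and backward passes: both kernels recompute it on the fly from the integer seed, so the only state carried across is (p, seed). As a hedged sketch that is not part of this commit, the two kernels could be tied into autograd like this, with only scalars saved for backward:

# Hypothetical glue code, not included in this commit: wraps the two Helion
# kernels in a torch.autograd.Function so that only the dropout probability
# and the integer seed are saved for backward -- no mask tensor is stored.
import torch

from examples.low_mem_dropout import low_mem_dropout, low_mem_dropout_bwd

class LowMemDropout(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, p: float, seed: int) -> torch.Tensor:
        ctx.p = p        # scalar, negligible memory
        ctx.seed = seed  # scalar, replaces the usual saved mask
        return low_mem_dropout(p, x, seed)

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):
        # Regenerate the same mask from the seed instead of loading a saved one.
        grad_x = low_mem_dropout_bwd(ctx.p, grad_y, ctx.seed)
        return grad_x, None, None

# Usage: y = LowMemDropout.apply(x, 0.25, 123)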

test/test_examples.expected

Lines changed: 40 additions & 0 deletions
@@ -2346,6 +2346,46 @@ def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.T
     _launcher(_helion_layer_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), x, weight, out, mean, rstd, mean.stride(0), out.stride(0), out.stride(1), rstd.stride(0), weight.stride(0), x.stride(0), x.stride(1), m, n, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
     return (out, mean, rstd)
 
+--- assertExpectedJournal(TestExamples.test_low_mem_dropout)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_low_mem_dropout(x_flat, out_flat, out_flat_stride_0, x_flat_stride_0, n, seed, p, scale, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < n
+    xi = tl.load(x_flat + indices_0 * x_flat_stride_0, mask_0, other=0)
+    rand = tl.rand(seed, tl.arange(0, _BLOCK_SIZE_0).reshape([_BLOCK_SIZE_0]))
+    v_0 = rand > p
+    v_1 = xi * scale
+    v_2 = 0.0
+    v_3 = v_2[None]
+    v_4 = tl.where(v_0, v_1, v_3)
+    tl.store(out_flat + indices_0 * out_flat_stride_0, v_4, mask_0)
+
+def low_mem_dropout(p: float, x: torch.Tensor, seed: int, *, _launcher=_default_launcher):
+    """
+    Applies dropout to x with probability p, regenerating the mask from seed.
+    Args:
+        p (float): dropout probability
+        x (torch.Tensor): input tensor
+    Returns:
+        Output tensor
+    """
+    scale = 1.0 / (1.0 - p)
+    n = x.numel()
+    x_flat = x.view(-1)
+    out_flat = torch.empty_like(x_flat)
+    _BLOCK_SIZE_0 = 1024
+    _launcher(_helion_low_mem_dropout, (triton.cdiv(n, _BLOCK_SIZE_0),), x_flat, out_flat, out_flat.stride(0), x_flat.stride(0), n, seed, p, scale, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out_flat.view_as(x)
+
 --- assertExpectedJournal(TestExamples.test_matmul)
 from __future__ import annotations
 
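For readers new to the expected-journal format: the block above is the Triton code Helion is expected to generate for the forward kernel. The rough correspondence between the Helion source and the generated kernel, as an annotation that only restates lines already shown above, is:

# Helion source                 ->  generated Triton (from the journal above)
# for tidx in hl.tile(n):       ->  pid_0 = tl.program_id(0); offset_0 = pid_0 * _BLOCK_SIZE_0
#                                   indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0); mask_0 = indices_0 < n
# x_flat[tidx]                  ->  tl.load(x_flat + indices_0 * x_flat_stride_0, mask_0, other=0)
# hl.rand([tidx], seed=seed)    ->  tl.rand(seed, ...)
# r > p / torch.where(...)      ->  v_0 = rand > p; v_4 = tl.where(v_0, v_1, v_3)
# out_flat[tidx] = yi           ->  tl.store(out_flat + indices_0 * out_flat_stride_0, v_4, mask_0)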

test/test_examples.py

Lines changed: 18 additions & 0 deletions
@@ -308,6 +308,24 @@ def test_welford(self):
             )
         )
 
+    def test_low_mem_dropout(self):
+        from examples.low_mem_dropout import low_mem_dropout_bwd
+
+        from helion._testing import code_and_output
+
+        p = 0.25
+        size = 8192
+        seed = 123
+        x = torch.randn(size=(size,)).cuda()
+        _, grad_x = code_and_output(
+            low_mem_dropout_bwd,
+            (p, x, seed),
+        )
+
+        self.assertExpectedJournal(
+            check_example("low_mem_dropout", (p, x, seed), grad_x),
+        )
+
     def test_rms_norm_fwd(self):
         args = (
             torch.randn([128, 256], device=DEVICE, dtype=torch.float16),
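
Outside of the TritonBench harness and this unit test, the wrapper's returned callable can also be timed directly. A minimal sketch, assuming a CUDA device and using triton.testing.do_bench; this snippet is illustrative and not part of the test suite:

# Standalone timing sketch, independent of TritonBench; illustrative only.
import torch
from triton.testing import do_bench

from examples.low_mem_dropout import low_mem_dropout_tritonbench

x = torch.randn(32768, device="cuda")
fn = low_mem_dropout_tritonbench(None, 0.25, x)  # tb_op is unused by the inner callable
ms = do_bench(fn)                                # runtime in milliseconds
print(f"low_mem_dropout: {ms:.3f} ms")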
