Commit 29bd548

[Benchmark] Add low mem dropout example
stack-info: PR: #641, branch: karthickai/stack/1
1 parent 8cf9e61 commit 29bd548

4 files changed: +218, -0 lines changed

benchmarks/run.py

Lines changed: 13 additions & 0 deletions
@@ -205,6 +205,11 @@ class RunResult:
         "examples.int4_gemm",
         "int4_gemm_tritonbench",
     ),
+    "low_mem_dropout": (
+        "tritonbench.operators.low_mem_dropout.operator",
+        "examples.low_mem_dropout",
+        "low_mem_dropout_tritonbench",
+    ),
 }

@@ -321,6 +326,14 @@ class RunResult:
         "helion_grouped_gemm_jagged_persistent_tritonbench-speedup": "helion_speedup",
         "helion_grouped_gemm_jagged_persistent_tritonbench-accuracy": "helion_accuracy",
     },
+    "low_mem_dropout": {
+        "seeded_dropout-accuracy": "triton_accuracy",
+        "seeded_dropout-speedup": "triton_speedup",
+        "torch_compile_dropout-accuracy": "torch_compile_accuracy",
+        "torch_compile_dropout-speedup": "torch_compile_speedup",
+        "helion_low_mem_dropout_tritonbench-accuracy": "helion_accuracy",
+        "helion_low_mem_dropout_tritonbench-speedup": "helion_speedup",
+    },
 }

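The first hunk registers the new operator as a (TritonBench operator module, Helion example module, wrapper function) triple; the second maps the per-backend metric names reported by TritonBench onto normalized speedup/accuracy columns. A rough sketch of how a harness could resolve such an entry (illustrative only; the resolve() helper is hypothetical, not the actual benchmarks/run.py logic):

# Hypothetical helper, for illustration only; benchmarks/run.py may do this differently.
import importlib
from typing import Callable

OPERATOR_MAPPING: dict[str, tuple[str, str, str]] = {
    "low_mem_dropout": (
        "tritonbench.operators.low_mem_dropout.operator",  # TritonBench operator module
        "examples.low_mem_dropout",  # Helion example module
        "low_mem_dropout_tritonbench",  # wrapper exported by the example
    ),
}


def resolve(name: str) -> Callable:
    # Import the Helion example module and return its TritonBench wrapper.
    _op_module_path, example_module, wrapper_name = OPERATOR_MAPPING[name]
    example = importlib.import_module(example_module)
    return getattr(example, wrapper_name)
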
examples/low_mem_dropout.py

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
+"""
+Low-Memory Dropout Example
+==========================
+
+This example demonstrates how to implement a low-memory dropout using Helion.
+"""
+
+# %%
+# Imports
+# -------
+from __future__ import annotations
+
+from typing import Callable
+
+import torch
+
+import helion
+import helion.language as hl
+
+
+# %%
+# Low mem dropout forward implementation
+# -------------------
+@helion.kernel()
+def low_mem_dropout(p: float, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Applies dropout on x using p
+    Args:
+        p (float): dropout probability
+        x (torch.Tensor): input tensor
+    Returns:
+        Output tensor, mask tensor
+    """
+    scale = 1.0 / (1.0 - p)
+    # flatten to 1D so we can use tile
+    n = x.numel()
+    x_flat = x.view(-1)
+    out_flat = torch.empty_like(x_flat)
+    mask_flat = torch.empty_like(x_flat, dtype=torch.bool)
+
+    for tidx in hl.tile(n):
+        xi = x_flat[tidx].to(torch.float32)
+        r = torch.rand_like(xi, dtype=torch.float32)
+        keep = r > p
+        yscaled = xi * scale
+        yi = torch.where(keep, yscaled, 0.0)
+        out_flat[tidx] = yi.to(x.dtype)
+        mask_flat[tidx] = keep
+    return out_flat.view_as(x), mask_flat.view_as(x)
+
+
+# %%
+# Low mem dropout backward implementation
+# -------------------
+@helion.kernel()
+def low_mem_dropout_bwd(p: float, grad_y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Low-memory dropout applies the same seeded randomness in the backward pass
+    as in the forward pass, so the backward computation mirrors the forward.
+    Args:
+        p (float): Dropout probability
+        grad_y (torch.Tensor): Gradient tensor
+    Returns:
+        Output tensor, mask tensor
+    """
+    scale = 1.0 / (1.0 - p)
+    n = grad_y.numel()
+    grad_y_flat = grad_y.view(-1)
+    out_flat = torch.empty_like(grad_y_flat)
+    mask_flat = torch.empty_like(grad_y_flat, dtype=torch.bool)
+    for tidx in hl.tile(n):
+        gi = grad_y_flat[tidx].to(torch.float32)
+        r = torch.rand_like(gi, dtype=torch.float32)
+        keep = r > p
+        g_scaled = gi * scale
+        gxi = torch.where(keep, g_scaled, 0.0)
+        out_flat[tidx] = gxi.to(grad_y.dtype)
+        mask_flat[tidx] = keep
+    return out_flat.view_as(grad_y), mask_flat.view_as(grad_y)
+
+
+# %%
+# TritonBench Wrapper
+# -------------------
+def low_mem_dropout_tritonbench(tb_op: object, p: float, x: torch.Tensor) -> Callable:
+    """
+    Wrapper for TritonBench compatibility.
+
+    Args:
+        tb_op: TritonBench operator instance
+        p (float): dropout probability
+        x (torch.Tensor): Input tensor
+
+    Returns:
+        Callable: A function that runs low_mem_dropout and returns the output tensor.
+    """
+
+    def _inner() -> torch.Tensor:
+        out, _ = low_mem_dropout(p, x)
+        return out
+
+    return _inner
+
+
+# %%
+# Verification Function
+# -------------------
+def check(p: float, size: int) -> None:
+    """
+    Verify that the forward and backward kernels regenerate the same dropout mask when seeded identically.
+
+    Args:
+        p (float): dropout probability
+        size (int): input tensor size
+    """
+
+    x = torch.randn(size=(size,)).cuda()
+
+    torch.manual_seed(123)
+    y, fwd_mask = low_mem_dropout(p, x)
+
+    # reseed so the backward pass regenerates the same random mask
+    torch.manual_seed(123)
+    grad_y = torch.ones_like(x)
+    grad_x, bwd_mask = low_mem_dropout_bwd(p, grad_y)
+    assert torch.equal(fwd_mask, bwd_mask)
+
+
+# %%
+# Main Function
+# -----------
+def main() -> None:
+    """
+    Main entry point that runs the low-memory dropout kernel verification with different tensor sizes.
+    Tests two configurations:
+    - p=0.25, size=8192
+    - p=0.25, size=32768
+    """
+    check(0.25, 8192)
+    check(0.25, 32768)
+
+
+if __name__ == "__main__":
+    main()
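
The key property of the "low memory" variant is that the forward pass does not need to save the dropout mask for backward: because the randomness is seeded, the backward kernel can regenerate an identical mask from the same seed, trading a small amount of recomputation for memory. A minimal eager-PyTorch sketch of the same idea, independent of Helion (the explicit seed argument and per-call torch.Generator are assumptions made for illustration):

# Hypothetical sketch: store only an integer seed and regenerate the keep-mask
# in backward instead of saving a mask tensor.
import torch


class ToyLowMemDropout(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, p: float, seed: int) -> torch.Tensor:
        ctx.p, ctx.seed = p, seed
        g = torch.Generator(device=x.device).manual_seed(seed)
        keep = torch.rand(x.shape, generator=g, device=x.device) > p
        return torch.where(keep, x / (1.0 - p), torch.zeros_like(x))

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        # Same seed, same stream: the mask is recomputed, never stored.
        g = torch.Generator(device=grad_out.device).manual_seed(ctx.seed)
        keep = torch.rand(grad_out.shape, generator=g, device=grad_out.device) > ctx.p
        grad_x = torch.where(keep, grad_out / (1.0 - ctx.p), torch.zeros_like(grad_out))
        return grad_x, None, None


x = torch.randn(8192, requires_grad=True)
y = ToyLowMemDropout.apply(x, 0.25, 123)
y.sum().backward()  # the mask is rebuilt here from seed 123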

test/test_examples.expected

Lines changed: 37 additions & 0 deletions
@@ -2276,6 +2276,43 @@ def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.T
     _launcher(_helion_layer_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), x, weight, out, mean, rstd, mean.stride(0), out.stride(0), out.stride(1), rstd.stride(0), weight.stride(0), x.stride(0), x.stride(1), m, n, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
     return (out, mean, rstd)
 
+--- assertExpectedJournal(TestExamples.test_low_mem_dropout)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_low_mem_dropout(x_flat, out_flat, mask_flat, mask_flat_stride_0, out_flat_stride_0, x_flat_stride_0, n, p, scale, _BLOCK_SIZE_0: tl.constexpr, rng_seed_buffer):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < n
+    xi = tl.load(x_flat + indices_0 * x_flat_stride_0, mask_0, other=0)
+    rand = tl.rand(tl.load(rng_seed_buffer + 0), indices_0).to(tl.float32)
+    v_0 = rand > p
+    v_1 = xi * scale
+    v_2 = 0.0
+    v_3 = v_2[None]
+    v_4 = tl.where(v_0, v_1, v_3)
+    tl.store(out_flat + indices_0 * out_flat_stride_0, v_4, mask_0)
+    tl.store(mask_flat + indices_0 * mask_flat_stride_0, v_0, mask_0)
+
+def low_mem_dropout(p: float, x: torch.Tensor, *, _launcher=_default_launcher):
+    from torch._inductor import inductor_prims
+    _rng_seed_buffer = inductor_prims.seeds(1, torch.device('cuda'))
+    '\n    Applies dropout on x using p\n    Args:\n        p (float): dropout probability\n        x (torch.Tensor): input tensor\n    Returns:\n        Output tensor, mask tensor\n    '
+    scale = 1.0 / (1.0 - p)
+    n = x.numel()
+    x_flat = x.view(-1)
+    out_flat = torch.empty_like(x_flat)
+    mask_flat = torch.empty_like(x_flat, dtype=torch.bool)
+    _BLOCK_SIZE_0 = 8
+    _launcher(_helion_low_mem_dropout, (triton.cdiv(n, _BLOCK_SIZE_0),), x_flat, out_flat, mask_flat, mask_flat.stride(0), out_flat.stride(0), x_flat.stride(0), n, p, scale, _BLOCK_SIZE_0, _rng_seed_buffer, num_warps=4, num_stages=3)
+    return (out_flat.view_as(x), mask_flat.view_as(x))
+
 --- assertExpectedJournal(TestExamples.test_matmul)
 from __future__ import annotations

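In the generated kernel above, the per-element randomness comes from tl.rand(seed, offsets), with the seed read from a buffer produced by inductor_prims.seeds(1, ...). Each element's random value is therefore a pure function of (seed, element index), which is what lets a backward kernel running under the same seed recompute exactly the keep-mask the forward produced, without materializing it. A toy counter-based illustration of that property (a cheap integer hash; not Philox, and not what Triton actually uses):

# Toy stand-in for tl.rand(seed, offsets): a deterministic map (seed, index) -> [0, 1).
import torch


def counter_rand(seed: int, indices: torch.Tensor) -> torch.Tensor:
    x = (indices.to(torch.int64) * 747796405 + seed) & 0xFFFFFFFF
    x = ((x ^ (x >> 16)) * 65521) & 0xFFFFFFFF
    x = x ^ (x >> 13)
    return x.to(torch.float64) / 2**32


idx = torch.arange(16)
assert torch.equal(counter_rand(123, idx), counter_rand(123, idx))  # reproducible per (seed, index)
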
test/test_examples.py

Lines changed: 24 additions & 0 deletions
@@ -308,6 +308,30 @@ def test_welford(self):
             )
         )
 
+    def test_low_mem_dropout(self):
+        from examples.low_mem_dropout import low_mem_dropout
+        from examples.low_mem_dropout import low_mem_dropout_bwd
+
+        from helion._testing import code_and_output
+
+        p, size = 0.25, 8
+
+        x = torch.randn(size=(size,)).cuda()
+
+        torch.manual_seed(123)
+        code, (_, fwd_mask) = code_and_output(
+            low_mem_dropout,
+            (p, x),
+        )
+
+        # reseed so the backward kernel regenerates the same random mask
+        torch.manual_seed(123)
+        grad_y = torch.ones_like(x)
+        _, bwd_mask = low_mem_dropout_bwd(p, grad_y)
+        assert torch.equal(fwd_mask, bwd_mask)
+
+        self.assertExpectedJournal(code)
+
     def test_rms_norm_fwd(self):
         args = (
             torch.randn([128, 256], device=DEVICE, dtype=torch.float16),
