Commit 89f5048

[Benchmark] Add low mem dropout example
1 parent 8442721 commit 89f5048

File tree: 2 files changed, +127 -0 lines

benchmarks/run.py

Lines changed: 13 additions & 0 deletions
@@ -198,6 +198,11 @@ class RunResult:
         "examples.int4_gemm",
         "int4_gemm_tritonbench",
     ),
+    "low_mem_dropout": (
+        "tritonbench.operators.low_mem_dropout.operator",
+        "examples.low_mem_dropout",
+        "low_mem_dropout_tritonbench",
+    ),
 }
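Each entry in this table maps a benchmark name to a triple of (tritonbench operator module, Helion example module, wrapper function name). As a rough sketch of how such an entry could be resolved into a callable, the `resolve_kernel` helper below is invented for illustration and is not part of this commit; it assumes the repository root is on the import path:

# Hypothetical sketch, not from this commit: resolve one mapping entry
# into the wrapper callable it names.
import importlib

def resolve_kernel(entry: tuple[str, str, str]):
    _tritonbench_module, example_module, wrapper_name = entry
    # Import the Helion example module and fetch the tritonbench wrapper.
    module = importlib.import_module(example_module)
    return getattr(module, wrapper_name)

wrapper = resolve_kernel(
    (
        "tritonbench.operators.low_mem_dropout.operator",
        "examples.low_mem_dropout",
        "low_mem_dropout_tritonbench",
    )
)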
@@ -306,6 +311,14 @@ class RunResult:
         "helion_int4_gemm_tritonbench-speedup": "helion_speedup",
         "helion_int4_gemm_tritonbench-accuracy": "helion_accuracy",
     },
+    "low_mem_dropout": {
+        "triton_dropout-accuracy": "triton_accuracy",
+        "triton_dropout-speedup": "triton_speedup",
+        "torch_compile_dropout-accuracy": "torch_compile_accuracy",
+        "torch_compile_dropout-speedup": "torch_compile_speedup",
+        "helion_low_mem_dropout_tritonbench-accuracy": "helion_accuracy",
+        "helion_low_mem_dropout_tritonbench-speedup": "helion_speedup",
+    },
 }
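This second table renames tritonbench's per-implementation metric columns to the normalized triton_*/torch_compile_*/helion_* names used in the benchmark report. A minimal sketch of that renaming, with made-up metric values:

# Illustrative only: apply the "low_mem_dropout" renaming above to raw
# tritonbench metric columns. The numeric values here are invented.
rename = {
    "triton_dropout-speedup": "triton_speedup",
    "helion_low_mem_dropout_tritonbench-speedup": "helion_speedup",
}
raw = {
    "triton_dropout-speedup": 1.02,
    "helion_low_mem_dropout_tritonbench-speedup": 1.10,
}
normalized = {rename.get(name, name): value for name, value in raw.items()}
print(normalized)  # {'triton_speedup': 1.02, 'helion_speedup': 1.1}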
examples/low_mem_dropout.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
"""
Low-Memory Dropout Example
==========================

This example demonstrates how to implement a low-memory dropout kernel using Helion.
"""

# %%
# Imports
# -------
from __future__ import annotations

from typing import Callable

import torch

import helion
from helion._testing import run_example
import helion.language as hl


# %%
# Low-Memory Dropout Implementation
# ---------------------------------
@helion.kernel()
def low_mem_dropout(p: float, x: torch.Tensor, x_keep: torch.Tensor) -> torch.Tensor:
    """
    Applies dropout to x with probability p, using x_keep as the keep-mask.

    Args:
        p (float): dropout probability
        x (torch.Tensor): input tensor
        x_keep (torch.Tensor): mask tensor indicating which elements to keep

    Returns:
        Output tensor
    """
    scale = 1.0 / (1.0 - p)
    # Flatten to 1D so we can tile over a single dimension.
    n = x.numel()
    x_flat, m_flat = x.view(-1), x_keep.view(-1)
    out_flat = torch.empty_like(x_flat)

    for tidx in hl.tile(n):
        xi = x_flat[tidx].to(torch.float32)
        mi = m_flat[tidx].to(torch.float32) > 0.5
        yscaled = xi * scale
        zeros = xi - xi  # zero tile with the same shape as xi
        yi = torch.where(mi, yscaled, zeros)
        out_flat[tidx] = yi.to(x.dtype)
    return out_flat.view_as(x)


# %%
# TritonBench Wrapper
# -------------------
def low_mem_dropout_tritonbench(tb_op: object, p: float, x: torch.Tensor) -> Callable:
    """
    Wrapper for TritonBench compatibility.

    Args:
        tb_op: TritonBench operator instance
        p (float): dropout probability
        x (torch.Tensor): input tensor

    Returns:
        Callable: A function that performs the low-memory dropout.
    """
    torch.manual_seed(123)  # Set seed for reproducibility.
    x_keep = torch.rand_like(x) > p
    return lambda: low_mem_dropout(p, x, x_keep)


# %%
# Baseline Function
# -----------------
def eager_dropout(p: float, x: torch.Tensor, x_keep: torch.Tensor) -> torch.Tensor:
    # Eager reference: mask, then rescale kept elements by 1 / (1 - p).
    return x * x_keep.to(x.dtype) / (1 - p)


# %%
# Verification Function
# ---------------------
def check(p: float, size: int) -> None:
    """
    Verify the low-memory dropout kernel against the eager PyTorch baseline above.

    Args:
        p (float): dropout probability
        size (int): input tensor size
    """
    x = torch.randn(size=(size,)).cuda()
    torch.manual_seed(123)  # Set seed for reproducibility.
    x_keep = torch.rand_like(x) > p
    kernels = {"low_mem_dropout": low_mem_dropout}
    run_example(kernels, eager_dropout, (p, x, x_keep))


# %%
# Main Function
# -------------
def main() -> None:
    """
    Main entry point that runs the low-memory dropout kernel verification
    with different tensor sizes.

    Tests with two configurations:
    - p=0.25, size=8192
    - p=0.25, size=32768
    """
    check(0.25, 8192)
    check(0.25, 32768)


if __name__ == "__main__":
    main()
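Outside the benchmark harness, a Helion kernel can be called like an ordinary function. A minimal usage sketch, assuming a CUDA device and that the file above is importable as examples.low_mem_dropout:

# Minimal usage sketch under the assumptions stated above.
import torch
from examples.low_mem_dropout import low_mem_dropout

p = 0.25
x = torch.randn(8192, device="cuda")
torch.manual_seed(123)
x_keep = torch.rand_like(x) > p  # keep each element with probability 1 - p
y = low_mem_dropout(p, x, x_keep)
# Kept elements are scaled by 1 / (1 - p); dropped elements are exactly zero.
assert torch.all(y[~x_keep] == 0)

Because kept elements are rescaled by 1 / (1 - p), the output matches the input in expectation, which is the standard inverted-dropout convention.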
