[Benchmark] Add low mem dropout example #641
Conversation
Speedup and Accuracy
HELION_USE_DEFAULT_CONFIG=0 python benchmarks/run.py --kernel low_mem_dropout --metrics accuracy,speedup
| x_val | triton_dropout-speedup | triton_dropout-accuracy | torch_compile_dropout-speedup | torch_compile_dropout-accuracy | seeded_dropout-speedup | seeded_dropout-accuracy | helion_low_mem_dropout_tritonbench-speedup | helion_low_mem_dropout_tritonbench-accuracy |
|---|---|---|---|---|---|---|---|---|
| 32 | 1.99543 | 1 | 1.99543 | 1 | 2.17413 | 0 | 2.13171 | 0 |
| 128 | 1.68981 | 1 | 1.89119 | 1 | 1.74641 | 0 | 1.88144 | 0 |
| 512 | 2.16744 | 1 | 2.18779 | 1 | 2.08036 | 0 | 2.51892 | 0 |
| 2048 | 2 | 1 | 2.31088 | 1 | 1.93913 | 0 | 2.31088 | 0 |
| 8192 | 2.09302 | 1 | 2.21675 | 1 | 1.98238 | 0 | 2.0362 | 0 |
| 32768 | 2.47264 | 1 | 2.4604 | 1 | 2.05372 | 0 | 2.47264 | 0 |
| 131072 | 2.18884 | 1 | 2.26667 | 1 | 1.96911 | 0 | 2.10744 | 0 |
| 524288 | 1.85586 | 1 | 1.91331 | 1 | 1.7913 | 0 | 2.20714 | 0 |
| average | 2.05788 | 1 | 2.1553 | 1 | 1.96707 | 0 | 2.2083 | |
examples/low_mem_dropout.py
Outdated
Args:
    p (float): dropout probability
    x (torch.Tensor): input tensor
    x_keep (torch.Tensor): mask tensor indicating which elements to keep
Do you know if the tritonbench Triton kernel also takes this as input? Ideally we want to run x_keep = torch.rand_like(x) > p within the Helion kernel's hl.tile device loop; however, if the tritonbench Triton kernel is not doing that either, we can stay with the current design.
Tritonbench has two variants, _triton_dropout and _seeded_triton_dropout; I referred to the _triton_dropout variant, which takes x_keep as an arg.
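For reference, here is a minimal eager sketch of that mask-taking signature (the helper name dropout_with_mask is mine; this is illustrative, not the tritonbench kernel):

import torch

def dropout_with_mask(p: float, x: torch.Tensor, x_keep: torch.Tensor) -> torch.Tensor:
    # Kept elements are rescaled by 1/(1-p); dropped elements become zero.
    scale = 1.0 / (1.0 - p)
    return torch.where(x_keep, x * scale, torch.zeros_like(x))

# The mask itself is computed outside the kernel, e.g. x_keep = torch.rand_like(x) > p,
# which is exactly the extra O(n) read discussed below.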
examples/low_mem_dropout.py
Outdated
for tidx in hl.tile(n):
    xi = x_flat[tidx].to(torch.float32)
    mi = m_flat[tidx].to(torch.float32) > 0.5
I think this is wrong:
- We should be generating a random number here, not reading an input. What makes "low mem" dropout "low mem" is that you don't take the random mask as an input; you generate it from a seed. So you use O(1) reads rather than O(n).
- We may need to add some hl.random ops to make this possible.
- Low mem dropout is mainly interesting when you include the backwards, since you can use the same seed for forwards and backwards (see the sketch below).
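To make the shared-seed idea concrete, here is a rough eager sketch (the function names and the torch.Generator usage are mine, not the PR's API): the backward regenerates the same keep-mask from the seed instead of loading a stored mask.

import torch

def seeded_dropout_fwd(p: float, x: torch.Tensor, seed: int) -> torch.Tensor:
    # Regenerate the randomness from the seed instead of reading a mask tensor.
    g = torch.Generator(device=x.device).manual_seed(seed)
    keep = torch.rand(x.shape, generator=g, device=x.device) > p
    return torch.where(keep, x / (1.0 - p), torch.zeros_like(x))

def seeded_dropout_bwd(p: float, grad_y: torch.Tensor, seed: int) -> torch.Tensor:
    # Same seed -> same keep-mask, so no mask tensor ever needs to be stored.
    g = torch.Generator(device=grad_y.device).manual_seed(seed)
    keep = torch.rand(grad_y.shape, generator=g, device=grad_y.device) > p
    return torch.where(keep, grad_y / (1.0 - p), torch.zeros_like(grad_y))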
Agreed - maybe we should model this after _seeded_triton_dropout, and we could also try to use the torch.rand_like support (PR).
Thanks @jansel, I agree with your points. I've updated the kernel to generate randomness inside the kernel using torch.rand_like per tile. Thanks @yf225 for your suggestion.
@helion.kernel()
def low_mem_dropout(p: float, x: torch.Tensor) -> torch.Tensor:
    """
    Applies dropout on x using p
    Args:
        p (float): dropout probability
        x (torch.Tensor): input tensor
    Returns:
        Output tensor
    """
    scale = 1.0 / (1.0 - p)
    # flatten to 1D so we can use tile
    n = x.numel()
    x_flat = x.view(-1)
    out_flat = torch.empty_like(x_flat)
    for tidx in hl.tile(n):
        xi = x_flat[tidx].to(torch.float32)
        r = torch.rand_like(xi, dtype=torch.float32)
        keep = r > p
        yscaled = xi * scale
        zeros = xi - xi
        yi = torch.where(keep, yscaled, zeros)
        out_flat[tidx] = yi.to(x.dtype)
    return out_flat.view_as(x)

After the change, the kernel's output doesn't match eager, even with manual_seed. I believe this is expected; could you kindly advise whether I should adjust the testing approach?
Triton/Helion random will not match eager mode; this is expected. What you do need to do is make sure the randomness matches between forwards and backwards.
I worry torch.rand_like will make this hard since it doesn't accept a seed arg.
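A quick standalone illustration of that concern (this snippet is mine, not from the PR): without a seed argument, consecutive torch.rand_like calls draw from the advancing global RNG state, so matching fwd and bwd requires resetting that state.

import torch

x = torch.randn(1024)

# Two plain rand_like calls produce different masks:
mask_a = torch.rand_like(x) > 0.5
mask_b = torch.rand_like(x) > 0.5
assert not torch.equal(mask_a, mask_b)

# Matching them requires reseeding the global RNG before each call,
# which is the clunky pattern discussed further down the thread:
torch.manual_seed(123)
mask_fwd = torch.rand_like(x) > 0.5
torch.manual_seed(123)
mask_bwd = torch.rand_like(x) > 0.5
assert torch.equal(mask_fwd, mask_bwd)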
Thanks @jansel, I've updated the test to verify that the dropout mask from fwd matches the mask in bwd. Since torch.rand_like doesn't take a seed, I reseed with torch.manual_seed before each kernel.
See below
examples/low_mem_dropout.py
Outdated
torch.manual_seed(123)
y, fwd_mask = low_mem_dropout(p, x)

# need to set seed again else we can't reproduce
torch.manual_seed(123)
grad_y = torch.ones_like(x)
grad_x, bwd_mask = low_mem_dropout_bwd(p, grad_y)
This still isn't right. The forward is returning a fwd_mask, which is the signature of non-low-mem dropout. The point of low-mem dropout is to not store the mask in memory. If you have the mask, then the backward is just fwd_mask*grad_y*scale (sketched below).
Also, needing to call torch.manual_seed(123) is kind of clunky since it mutates global state and results in extra kernel launches.
I'd suggest either:
- Just implement regular (not low-mem) dropout and make the backward use the fwd_mask
- Add a seeded random op so we can do low mem dropout properly
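For concreteness, a minimal sketch of the first option (the function name is illustrative): with a stored fwd_mask the backward needs no RNG at all.

import torch

def dropout_bwd_with_mask(p: float, grad_y: torch.Tensor, fwd_mask: torch.Tensor) -> torch.Tensor:
    # Regular (mask-storing) dropout backward: elementwise product with the saved mask.
    scale = 1.0 / (1.0 - p)
    return fwd_mask.to(grad_y.dtype) * grad_y * scale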
Thanks again, I will try to implement the seeded random op; once that's ready, I'll update the low mem dropout example.
I've updated low_mem_dropout with the newly created hl.rand op, which takes a seed arg. Now it doesn't store the mask in memory.
test/test_examples.expected
Outdated
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < n
xi = tl.load(x_flat + indices_0 * x_flat_stride_0, mask_0, other=0)
rand = tl.rand(seed, tl.arange(0, _BLOCK_SIZE_0).reshape([_BLOCK_SIZE_0]))
I believe this will result in the same RNG being used for every tile.
Tile 1 will be: seed, range(0, BLOCK_SIZE)
Tile 2 will be: seed, range(0, BLOCK_SIZE)
So the same elements will get dropped in each tile. I think this needs to be the index (see the sketch below).
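For comparison, a hand-written Triton sketch of the intended behaviour (kernel and argument names are illustrative, not the Helion-generated code): the key point is that tl.rand is fed the global element indices, so every block draws a distinct slice of the random stream.

import triton
import triton.language as tl

@triton.jit
def seeded_dropout_kernel(x_ptr, out_ptr, n, p, seed, BLOCK_SIZE: tl.constexpr):
    offset = tl.program_id(0) * BLOCK_SIZE
    indices = offset + tl.arange(0, BLOCK_SIZE)  # global indices, unique per block
    mask = indices < n
    x = tl.load(x_ptr + indices, mask=mask)
    r = tl.rand(seed, indices)                   # seeded, offset by global index
    keep = r > p
    y = tl.where(keep, x / (1.0 - p), 0.0)
    tl.store(out_ptr + indices, y, mask=mask)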
Thanks for the excellent catch! I'll update hl.rand to use the index instead of tl.arange(0, BLOCK_SIZE).
I've updated hl.rand to generate unique RNG values for each tile (#685). I'll update this once it's merged.
)
)

def test_low_mem_dropout(self):
Test backwards (not just fwd) and assert that the same elements are dropped out in bwd as fwd (and different elements are dropped out if you change the seed).
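A rough sketch of such a test (the call signatures low_mem_dropout(p, x, seed) and low_mem_dropout_bwd(p, grad_y, seed) are assumed here; adapt them to whatever examples/low_mem_dropout.py actually exposes):

import torch
# assumed import; adjust to the actual example module
from examples.low_mem_dropout import low_mem_dropout, low_mem_dropout_bwd

def check_fwd_bwd_masks(p: float = 0.25, n: int = 8192, seed: int = 42) -> None:
    x = torch.randn(n, device="cuda")
    grad_y = torch.ones_like(x)

    y = low_mem_dropout(p, x, seed)
    gx = low_mem_dropout_bwd(p, grad_y, seed)

    fwd_dropped = y == 0   # elements zeroed in the forward
    bwd_dropped = gx == 0  # elements zeroed in the backward
    assert torch.equal(fwd_dropped, bwd_dropped), "same seed must drop the same elements"

    gx_other = low_mem_dropout_bwd(p, grad_y, seed + 1)
    assert not torch.equal(bwd_dropped, gx_other == 0), "a different seed should drop different elements"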
Thanks, I've updated the test case with dropout mask checking.
Stacked PRs:
[Benchmark] Add low mem dropout example