Is your feature request related to a problem? Please describe.
I'm trying to implement a lock mechanism using Helion, which would look like the following Triton code:
```python
# lock acquire
while tl.atomic_cas(lock_ptr, 0, 1) == 1:
    pass
# critical section
tl.debug_barrier()
# lock release
tl.atomic_xchg(lock_ptr, 0)
```

Right now Helion supports neither the `while` statement nor the `pass` statement:

```
helion.exc.UnsupportedPythonType: ast.Pass is not supported in Helion kernels
helion.exc.StatementNotSupported: The statement While is not supported.
```
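For context, the acquire/release semantics of the pattern above can be sketched in plain Python, using a `threading.Lock`-guarded integer to stand in for the hardware atomics. This is a hypothetical illustration of what `tl.atomic_cas`/`tl.atomic_xchg` do, not Helion or Triton code:

```python
import threading

class SpinLock:
    """Software sketch of the Triton spin-lock pattern:
    acquire = CAS(lock, 0, 1) until it succeeds; release = XCHG(lock, 0)."""

    def __init__(self) -> None:
        self._val = 0
        self._guard = threading.Lock()  # stands in for hardware atomicity

    def atomic_cas(self, expected: int, new: int) -> int:
        # Returns the OLD value, mirroring tl.atomic_cas semantics.
        with self._guard:
            old = self._val
            if old == expected:
                self._val = new
            return old

    def atomic_xchg(self, new: int) -> int:
        with self._guard:
            old, self._val = self._val, new
            return old

    def acquire(self) -> None:
        # while tl.atomic_cas(lock_ptr, 0, 1) == 1: pass
        while self.atomic_cas(0, 1) == 1:
            pass

    def release(self) -> None:
        # tl.atomic_xchg(lock_ptr, 0)
        self.atomic_xchg(0)
```

The busy-wait loop only makes progress because a losing CAS returns the old value `1`, so the thread keeps retrying until the holder writes `0` back on release.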
Describe the solution you'd like
Native Helion support that lets users write the same patterns as in Triton would be great!
Describe alternatives you've considered
I tried hl.inline_triton() but couldn't make it work; maybe I got it wrong. I'm not that familiar with inlining, so the documentation for inline_triton wasn't much help to me. I ended up checking the unit test test/test_inline_triton.py and still couldn't make it work.
Here are the code snippets I tried to replace with inline_triton:
acquire
```python
# while hl.atomic_cas(grad_x_lock, [tile_bt.id, tile_h.id], 0, 1, sem="acquire") == 1:
#     dummy = 1
hl.inline_triton(
    """
    while tl.atomic_cas({0} + {1}, 0, 1) == 1:
        pass
    dummy
    """,
    args=(grad_x_lock, tile_bt.id * num_block_h + tile_h.id),
    output_like=None,
)
```

I got:
```
...
# src[fused_linear_cross_entropy.py:121]: args=(grad_x_lock, tile_bt.id * H + tile_h.id),
tile_id = offset_0 // _BLOCK_SIZE_0
mul = 4096 * tile_id
tile_id_1 = offset_4 // _BLOCK_SIZE_1
add = tile_id_1 + 4096 * tile_id
# src[fused_linear_cross_entropy.py:11]: @helion.kernel(autotune_effort="none", ignore_warnings=[helion.exc.TensorOperationInWrapper])
# src[fused_linear_cross_entropy.py:12]: def fused_linear_cross_entropy_fwd_bwd(
while tl.atomic_cas(_host_tensor + add, 0, 1) == 1:
      ^
NameError('_host_tensor is not defined')
```
It fails to resolve the host tensor pointer (grad_x_lock), which is defined before the hl.tile loop. Refer to the unit test:
helion/test/test_inline_triton.py
Lines 155 to 166 in dee9f57
```python
def test_inline_triton_side_effect_only(self) -> None:
    @helion.kernel(autotune_effort="none")
    def kernel(x: torch.Tensor) -> torch.Tensor:
        flag = torch.zeros(1, device=x.device, dtype=x.dtype)
        for tile in hl.tile(x.shape):
            val = x[tile]
            _ = hl.inline_triton(
                "tl.store({0}, {1}[0])",
                args=(flag, val),
                output_like=None,
            )
        return flag
```
Are we not able to directly pass a tensor as the base pointer?
release
```python
# hl.atomic_xchg(grad_x_lock, [tile_bt.id, tile_h.id], 0, sem="release")
hl.inline_triton(
    """
    tl.atomic_xchg({0} + {1}, 0)
    """,
    args=(grad_x_lock, tile_bt.id * num_block_h + tile_h.id),
    output_like=None,
)
```

I haven't checked this part yet.
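As a side note, the `tile_bt.id * num_block_h + tile_h.id` expression in the snippets above is plain row-major flattening of a 2D grid of locks into a 1D buffer. A minimal standalone sketch (the 3x4 grid size here is made up for illustration):

```python
def lock_slot(bt_id: int, h_id: int, num_block_h: int) -> int:
    """Row-major flattening: one lock slot per (bt, h) tile pair."""
    return bt_id * num_block_h + h_id

# With a hypothetical 3 x 4 tile grid, every tile maps to a unique slot.
num_block_h = 4
slots = [lock_slot(bt, h, num_block_h) for bt in range(3) for h in range(4)]
print(sorted(slots) == list(range(12)))  # True
```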
Additional context
This is part of the fused linear cross entropy work (linkedin/Liger-Kernel#928) I'm currently doing. I'm willing to open a PR for helion/examples once I get it done!