
add support for while/pass #1086

@Tcc0403

Description


Is your feature request related to a problem? Please describe.

I'm trying to implement a lock mechanism using Helion, which would look like the following Triton code:

# lock acquire
while tl.atomic_cas(lock_ptr, 0, 1) == 1:
  pass

# critical section

tl.debug_barrier()
# lock release
tl.atomic_xchg(lock_ptr, 0)
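To make the intended protocol concrete, here is a plain-Python sketch (not Helion or Triton code) of the same compare-and-swap spin lock. The AtomicInt class and its method names are illustrative stand-ins for tl.atomic_cas / tl.atomic_xchg; hardware atomicity is emulated with a threading.Lock:

```python
import threading

class AtomicInt:
    """Illustrative stand-in for the atomic int that tl.atomic_cas / tl.atomic_xchg operate on."""
    def __init__(self, value=0):
        self._value = value
        self._guard = threading.Lock()  # emulates hardware atomicity

    def atomic_cas(self, cmp, val):
        # Returns the OLD value, matching tl.atomic_cas semantics.
        with self._guard:
            old = self._value
            if old == cmp:
                self._value = val
            return old

    def atomic_xchg(self, val):
        with self._guard:
            old = self._value
            self._value = val
            return old

lock = AtomicInt(0)
counter = 0

def worker():
    global counter
    for _ in range(1000):
        # lock acquire: spin until CAS flips 0 -> 1
        while lock.atomic_cas(0, 1) == 1:
            pass
        # critical section
        counter += 1
        # lock release
        lock.atomic_xchg(0)

threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(counter)  # 4000: every increment was protected by the spin lock
```

The spin body is exactly the `while ...: pass` shape that Helion currently rejects.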

Right now Helion supports neither the while statement nor the pass statement:

helion.exc.UnsupportedPythonType: ast.Pass is not supported in Helion kernels
helion.exc.StatementNotSupported: The statement While is not supported.
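The node names in the errors line up with how Python's standard ast module parses the spin-wait; a minimal check (assuming only that the kernel body is parsed as ordinary Python source) shows the two node types involved:

```python
import ast

# The spin-wait as it appears in the kernel body (atomic_cas is a placeholder name).
src = """
while atomic_cas(lock, 0, 1) == 1:
    pass
"""
tree = ast.parse(src)
loop = tree.body[0]
print(type(loop).__name__)          # While
print(type(loop.body[0]).__name__)  # Pass
```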

Describe the solution you'd like

Having native Helion support so that users can write the same patterns as in Triton would be great!

Describe alternatives you've considered

I tried hl.inline_triton() but couldn't make it work; maybe I got it wrong. I'm not that familiar with inlining, so the inline_triton documentation wasn't much help to me. I ended up checking the unit test test/test_inline_triton.py and still couldn't make it work.

Here are the code snippets I tried to replace with inline_triton:

acquire

# while hl.atomic_cas(grad_x_lock, [tile_bt.id, tile_h.id], 0, 1, sem="acquire") == 1:
#     dummy = 1
hl.inline_triton(
    """
    while tl.atomic_cas({0} + {1}, 0, 1) == 1:
       pass
    dummy
    """,
    args=(grad_x_lock, tile_bt.id * num_block_h + tile_h.id),
    output_like=None,
)

I got

...
            # src[fused_linear_cross_entropy.py:121]: args=(grad_x_lock, tile_bt.id * H + tile_h.id),
            tile_id = offset_0 // _BLOCK_SIZE_0
            mul = 4096 * tile_id
            tile_id_1 = offset_4 // _BLOCK_SIZE_1
            add = tile_id_1 + 4096 * tile_id
            # src[fused_linear_cross_entropy.py:11]: @helion.kernel(autotune_effort="none", ignore_warnings=[helion.exc.TensorOperationInWrapper])
            # src[fused_linear_cross_entropy.py:12]: def fused_linear_cross_entropy_fwd_bwd(
            while tl.atomic_cas(_host_tensor + add, 0, 1) == 1:
                                ^
NameError('_host_tensor is not defined')

It failed to resolve the host tensor pointer (grad_x_lock), which is defined before the hl.tile loops. Refer to the unit test:

def test_inline_triton_side_effect_only(self) -> None:
    @helion.kernel(autotune_effort="none")
    def kernel(x: torch.Tensor) -> torch.Tensor:
        flag = torch.zeros(1, device=x.device, dtype=x.dtype)
        for tile in hl.tile(x.shape):
            val = x[tile]
            _ = hl.inline_triton(
                "tl.store({0}, {1}[0])",
                args=(flag, val),
                output_like=None,
            )
        return flag

Are we not able to pass a tensor directly as the base pointer?

release

# hl.atomic_xchg(grad_x_lock, [tile_bt.id, tile_h.id], 0, sem="release")
hl.inline_triton(
    """
    tl.atomic_xchg({0} + {1}, 0)
    """,
    args=(grad_x_lock, tile_bt.id * num_block_h + tile_h.id),
    output_like=None,
)

I haven't checked this part yet.

Additional context

It's part of the fused linear cross entropy development (linkedin/Liger-Kernel#928) I'm currently working on. I'm willing to open a PR for helion/examples once I get it done!
