Skip to content

Commit 1701f8d

Browse files
authored
Raise better error when hl.atomic_* is used on device tensor (#658)
1 parent 94b0650 commit 1701f8d

File tree

3 files changed

+28
-0
lines changed

3 files changed

+28
-0
lines changed

helion/exc.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,13 @@ class CannotModifyHostVariableOnDevice(BaseError):
370370
message = "Cannot modify host variable '{0}' inside `hl.tile` or `hl.grid` loop without subscript assignment. Use '{0}[tile] = ...' instead."
371371

372372

373+
class AtomicOnDeviceTensor(BaseError):
374+
message = (
375+
"hl.{0}() target must be host-allocated tensor (i.e. allocated outside of hl.tile or hl.grid loop). "
376+
"Tensors created inside device loops do not have an addressable pointer for atomics."
377+
)
378+
379+
373380
class CannotReadDeviceVariableOnHost(BaseError):
374381
message = "Cannot read variable '{0}' defined inside `hl.tile` or `hl.grid` loop from host code."
375382

helion/language/atomic_ops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from .. import exc
1212
from .._compiler.ast_extension import expr_from_string
13+
from .._compiler.host_function import HostFunction
1314
from .._compiler.indexing_strategy import SubscriptIndexing
1415
from . import _decorators
1516

@@ -64,6 +65,10 @@ def _codegen_common(
6465
assert isinstance(target, torch.Tensor)
6566
assert isinstance(index, list)
6667

68+
host_function = HostFunction.current()
69+
if target not in host_function.tensor_to_origin:
70+
raise exc.AtomicOnDeviceTensor(tl_func)
71+
6772
indices = SubscriptIndexing.create(state, target, index)
6873
name = state.device_function.tensor_arg(target).name
6974

test/test_atomic_ops.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,22 @@ def test_atomic_add_w_tile_attr(self):
276276
torch.testing.assert_close(result, expected)
277277
self.assertExpectedJournal(code)
278278

279+
@skipIfRefEager("Error only raises in normal mode")
280+
def test_atomic_add_device_tensor_error(self):
281+
@helion.kernel(static_shapes=True)
282+
def kernel(x: torch.Tensor) -> torch.Tensor:
283+
for tile in hl.tile(x.size(0), block_size=128):
284+
device_tensor = hl.zeros([tile], dtype=x.dtype)
285+
hl.atomic_add(device_tensor, [tile], x[tile])
286+
return x
287+
288+
x = torch.ones(256, device=DEVICE, dtype=torch.float32)
289+
with self.assertRaisesRegex(
290+
helion.exc.AtomicOnDeviceTensor,
291+
r"hl\.atomic_add\(\)",
292+
):
293+
kernel(x)
294+
279295
# New tests for other atomics (correctness only; no journal asserts)
280296
def test_atomic_and(self):
281297
x0 = torch.full((8,), 0b1111, device=DEVICE, dtype=torch.int32)

0 commit comments

Comments
 (0)