Skip to content

Commit 9a30bd1

Browse files
authored
Always raise FailedToUnpackTile when for tile_m, tile_d in hl.tile(m, d) is used (#1009)
1 parent c3c962d commit 9a30bd1

File tree

4 files changed

+46
-0
lines changed

4 files changed

+46
-0
lines changed

docs/api/exceptions.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ These exceptions occur when Helion language functions are used incorrectly with
115115
116116
Raised when tuple unpacking fails for single tile.
117117
118+
.. autoclass:: InvalidTileRange
119+
120+
Raised when ``hl.tile`` is given a range where the begin exceeds the end.
121+
118122
.. autoclass:: OverpackedTile
119123
120124
Raised when tile is wrapped in container when indexing.

helion/exc.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,13 @@ class OverpackedTile(BaseError):
223223
)
224224

225225

226+
class InvalidTileRange(BaseError):
227+
message = (
228+
"hl.tile() expects the begin of the range to be less than or equal to the end. "
229+
"Got begin={0!s}, end={1!s}."
230+
)
231+
232+
226233
class AssignmentMultipleTargets(NotAllowedOnDevice):
227234
message = "Assignment with multiple targets (a=b=1) is not allowed inside the `hl.tile` or `hl.grid` loop."
228235

helion/language/loops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,11 @@ def _(
329329
)
330330
block_size_list = Tile._tiles_to_sizes(block_size_list)
331331

332+
if unpack:
333+
target = getattr(parent, "target", None)
334+
if isinstance(target, (ast.Tuple, ast.List)) and len(target.elts) > 1:
335+
raise exc.FailedToUnpackTile from None
336+
332337
results = []
333338
for begin_part, end_part, bs in zip(
334339
begin_list,
@@ -339,6 +344,8 @@ def _(
339344
if isinstance(begin_part, Tile) or isinstance(end_part, Tile):
340345
raise exc.TileOfTile
341346
size = end_part - begin_part # type: ignore[operator]
347+
if isinstance(size, int) and size < 0:
348+
raise exc.InvalidTileRange(begin_part, end_part)
342349
if isinstance(size, torch.Tensor):
343350
size = None # data dependent size
344351
if bs is None:

test/test_errors.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,34 @@ def fn(x: torch.Tensor) -> torch.Tensor:
135135
with self.assertRaises(helion.exc.OverpackedTile):
136136
code_and_output(fn, (torch.randn(100, 100, device=DEVICE),))
137137

138+
def test_tile_invalid_range_unpack(self):
139+
@helion.kernel()
140+
def fn(x: torch.Tensor) -> torch.Tensor:
141+
m = x.size(0)
142+
m = hl.specialize(m)
143+
d = x.size(2)
144+
for _tile_m, _tile_d in hl.tile(m, d):
145+
pass
146+
return x
147+
148+
with self.assertRaises(helion.exc.FailedToUnpackTile):
149+
code_and_output(fn, (torch.randn(192, 4, 128, device=DEVICE),))
150+
151+
def test_tile_invalid_range_single_dim(self):
152+
@helion.kernel()
153+
def fn(x: torch.Tensor) -> torch.Tensor:
154+
start = hl.specialize(x.size(0))
155+
end = x.size(2)
156+
for _tile_m in hl.tile(start, end):
157+
pass
158+
return x
159+
160+
with self.assertRaisesRegex(
161+
helion.exc.InvalidTileRange,
162+
r"begin=192, end=128",
163+
):
164+
code_and_output(fn, (torch.randn(192, 4, 128, device=DEVICE),))
165+
138166
def test_invalid_config_insufficient_block_sizes(self):
139167
"""Test that InvalidConfig shows helpful message for missing block sizes."""
140168

0 commit comments

Comments
 (0)