Skip to content

Commit c221cab

Browse files
authored
Display error message when too many arguments are passed (#526)
1 parent 08de077 commit c221cab

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

helion/runtime/kernel.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,10 @@ def bind(self, args: tuple[object, ...]) -> BoundKernel[_R]:
150150
if not isinstance(args, tuple):
151151
assert isinstance(args, list), "args must be a tuple or list"
152152
args = tuple(args)
153+
if len(args) > len(self.signature.parameters):
154+
raise TypeError(
155+
f"Too many arguments passed to the kernel, expected: {len(self.signature.parameters)} got: {len(args)}."
156+
)
153157
signature = self.specialization_key(args)
154158
cache_key = self._get_bound_kernel_cache_key(args, signature)
155159
bound_kernel = (

test/test_errors.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,20 @@ def fn(a: torch.Tensor) -> torch.Tensor:
214214
):
215215
code_and_output(fn, (torch.randn(4, device=DEVICE),))
216216

217+
def test_too_many_args(self):
218+
@helion.kernel()
219+
def kernel(x: torch.Tensor) -> torch.Tensor:
220+
result = torch.zeros_like(x)
221+
for i in hl.tile(x.size()):
222+
result[i] = x[i]
223+
return result
224+
225+
with self.assertRaisesRegex(
226+
TypeError, r"Too many arguments passed to the kernel, expected: 1 got: 2."
227+
):
228+
a = torch.randn(8, device=DEVICE)
229+
code_and_output(kernel, (a, a))
230+
217231

218232
if __name__ == "__main__":
219233
unittest.main()

0 commit comments

Comments
 (0)