Skip to content

Commit 4bc7d3c

Browse files
GregoryComerfacebook-github-bot
authored andcommitted
Handle unsupported pybind inputs non-fatally
Summary: Currently, passing unsupported python types to pybind method execution, such as lists, dicts, or tuples, will crash the kernel due to hitting an assert. This PR updates the logic to raise an exception, which gets nicely bubbled up to the notebook. This gives the user a nicer error message and does not crash the bento/jupyter process. Differential Revision: D74118509
1 parent 94f7b10 commit 4bc7d3c

File tree

2 files changed

+34
-20
lines changed

2 files changed

+34
-20
lines changed

extension/pybindings/pybindings.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,10 @@ struct PyModule final {
757757
} else if (py::isinstance<py::int_>(python_input)) {
758758
cpp_inputs.push_back(EValue(py::cast<int64_t>(python_input)));
759759
} else {
760-
ET_ASSERT_UNREACHABLE_MSG("Unsupported pytype: %s", type_str.c_str());
760+
throw std::runtime_error(
761+
"Unsupported python type " +
762+
type_str +
763+
". Ensure that inputs are passed as a flat list of tensors.");
761764
}
762765
}
763766

extension/pybindings/test/make_test.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
# pyre-unsafe
88

9+
import sys
910
import unittest
11+
from io import StringIO
1012
from types import ModuleType
1113
from typing import Any, Callable, Optional, Tuple
1214

@@ -16,6 +18,23 @@
1618
from torch.export import export
1719

1820

21+
class RedirectedStderr:
22+
def __init__(self):
23+
self._stderr = None
24+
self._string_io = None
25+
26+
def __enter__(self):
27+
self._stderr = sys.stderr
28+
sys.stderr = self._string_io = StringIO()
29+
return self
30+
31+
def __exit__(self, type, value, traceback):
32+
sys.stderr = self._stderr
33+
34+
def __str__(self):
35+
return self._string_io.getvalue()
36+
37+
1938
class ModuleAdd(torch.nn.Module):
2039
"""The module to serialize and execute."""
2140

@@ -237,25 +256,6 @@ def test_module_single_input(tester):
237256
tester.assertEqual(str(expected), str(executorch_output))
238257

239258
def test_stderr_redirect(tester):
240-
import sys
241-
from io import StringIO
242-
243-
class RedirectedStderr:
244-
def __init__(self):
245-
self._stderr = None
246-
self._string_io = None
247-
248-
def __enter__(self):
249-
self._stderr = sys.stderr
250-
sys.stderr = self._string_io = StringIO()
251-
return self
252-
253-
def __exit__(self, type, value, traceback):
254-
sys.stderr = self._stderr
255-
256-
def __str__(self):
257-
return self._string_io.getvalue()
258-
259259
with RedirectedStderr() as out:
260260
try:
261261
# Create an ExecuTorch program from ModuleAdd.
@@ -464,6 +464,16 @@ def test_verification_config(tester) -> None:
464464

465465
tester.assertEqual(str(expected), str(executorch_output))
466466

467+
def test_unsupported_input_type(tester):
468+
exported_program, inputs = create_program(ModuleAdd())
469+
executorch_module = load_fn(exported_program.buffer)
470+
471+
# Pass an unsupported input type to the module.
472+
inputs = ([*inputs],)
473+
474+
# This should raise a Python error, not hit a fatal assert in the C++ code.
475+
tester.assertRaises(RuntimeError, executorch_module, inputs)
476+
467477
######### RUN TEST CASES #########
468478
test_e2e(tester)
469479
test_multiple_entry(tester)
@@ -479,5 +489,6 @@ def test_verification_config(tester) -> None:
479489
test_method_meta(tester)
480490
test_bad_name(tester)
481491
test_verification_config(tester)
492+
test_unsupported_input_type(tester)
482493

483494
return wrapper

0 commit comments

Comments
 (0)