Skip to content

Commit 9ec5905

Browse files
pytypes: Add iterable_t<>
1 parent 6709abb commit 9ec5905

File tree

4 files changed

+123
-2
lines changed

4 files changed

+123
-2
lines changed

include/pybind11/cast.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,18 @@ PYBIND11_NAMESPACE_BEGIN(detail)
5050
template <typename type, typename SFINAE = void> class type_caster : public type_caster_base<type> { };
5151
template <typename type> using make_caster = type_caster<intrinsic_t<type>>;
5252

53+
template <typename T>
54+
struct is_generic_type<
55+
T,
56+
enable_if_t<std::is_base_of<type_caster_generic, make_caster<T>>::value>
57+
> : public std::true_type {};
58+
59+
template <typename T>
60+
struct is_generic_type<
61+
T,
62+
enable_if_t<!std::is_base_of<type_caster_generic, make_caster<T>>::value>
63+
> : public std::false_type {};
64+
5365
// Shortcut for calling a caster's `cast_op_type` cast operator for casting a type_caster to a T
5466
template <typename T> typename make_caster<T>::template cast_op_type<T> cast_op(make_caster<T> &caster) {
5567
return caster.operator typename make_caster<T>::template cast_op_type<T>();
@@ -747,11 +759,16 @@ template <typename T> struct handle_type_name { static constexpr auto name = _<T
747759
template <> struct handle_type_name<bytes> { static constexpr auto name = _(PYBIND11_BYTES_NAME); };
748760
template <> struct handle_type_name<int_> { static constexpr auto name = _("int"); };
749761
template <> struct handle_type_name<iterable> { static constexpr auto name = _("Iterable"); };
762+
template <typename T>
763+
struct handle_type_name<iterable_t<T>> {
764+
static constexpr auto name = _("Iterable[") + type_caster<T>::name + _("]");
765+
};
750766
template <> struct handle_type_name<iterator> { static constexpr auto name = _("Iterator"); };
751767
template <> struct handle_type_name<none> { static constexpr auto name = _("None"); };
752768
template <> struct handle_type_name<args> { static constexpr auto name = _("*args"); };
753769
template <> struct handle_type_name<kwargs> { static constexpr auto name = _("**kwargs"); };
754770

771+
755772
template <typename type>
756773
struct pyobject_caster {
757774
template <typename T = type, enable_if_t<std::is_same<T, handle>::value, int> = 0>

include/pybind11/pytypes.h

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ PYBIND11_NAMESPACE_BEGIN(detail)
2626
class args_proxy;
2727
inline bool isinstance_generic(handle obj, const std::type_info &tp);
2828

29+
// Indicates that type is generic and and does not have a specialized
30+
// `type_caster<>` specialization. Defined in `cast.h`.
31+
template <typename T, typename SFINAE = void>
32+
struct is_generic_type;
33+
2934
// Accessor forward declarations
3035
template <typename Policy> class accessor;
3136
namespace accessor_policies {
@@ -380,13 +385,18 @@ class error_already_set : public std::runtime_error {
380385
/** \ingroup python_builtins
381386
\rst
382387
Return true if ``obj`` is an instance of ``T``. Type ``T`` must be a subclass of
383-
`object` or a class which was exposed to Python as ``py::class_<T>``.
388+
`object` or a class which was exposed to Python as ``py::class_<T>`` (generic).
384389
\endrst */
385390
template <typename T, detail::enable_if_t<std::is_base_of<object, T>::value, int> = 0>
386391
bool isinstance(handle obj) { return T::check_(obj); }
387392

388393
template <typename T, detail::enable_if_t<!std::is_base_of<object, T>::value, int> = 0>
389-
bool isinstance(handle obj) { return detail::isinstance_generic(obj, typeid(T)); }
394+
bool isinstance(handle obj) {
395+
static_assert(
396+
detail::is_generic_type<T>::value,
397+
"isisntance<T>() requires specialization for this type");
398+
return detail::isinstance_generic(obj, typeid(T));
399+
}
390400

391401
template <> inline bool isinstance<handle>(handle) = delete;
392402
template <> inline bool isinstance<object>(handle obj) { return obj.ptr() != nullptr; }
@@ -753,6 +763,39 @@ inline bool PyIterable_Check(PyObject *obj) {
753763
}
754764
}
755765

766+
template <typename T>
767+
bool PyIterableT_Check(PyObject *obj) {
768+
static_assert(
769+
is_generic_type<T>::value || is_pyobject<T>::value,
770+
"iterable_t can only be used with pyobjects and generic types "
771+
"(py::class_<T>)");
772+
PyObject *iter = PyObject_GetIter(obj);
773+
if (iter) {
774+
if (iter == obj) {
775+
// If they are the same, then that's bad! For now, just throw a
776+
// cast error.
777+
Py_DECREF(iter);
778+
throw cast_error(
779+
"iterable_t<T> cannot be used with exhausitble iterables "
780+
"(e.g., iterators, generators).");
781+
}
782+
bool good = true;
783+
// Now that we know that the iterable `obj` will not be exhausted,
784+
// let's check the contained types.
785+
for (handle h : handle(iter)) {
786+
if (!isinstance<T>(h)) {
787+
good = false;
788+
break;
789+
}
790+
}
791+
Py_DECREF(iter);
792+
return good;
793+
} else {
794+
PyErr_Clear();
795+
return false;
796+
}
797+
}
798+
756799
inline bool PyNone_Check(PyObject *o) { return o == Py_None; }
757800
inline bool PyEllipsis_Check(PyObject *o) { return o == Py_Ellipsis; }
758801

@@ -941,6 +984,20 @@ class iterable : public object {
941984
PYBIND11_OBJECT_DEFAULT(iterable, object, detail::PyIterable_Check)
942985
};
943986

987+
/// Provides similar interface to `iterable`, but constraining the intended
988+
/// type.
989+
/// @warning Due to technical reasons, this is constrained in two ways:
990+
/// - Due to how `isinstance<T>()` works, this does *not* work for iterables of
991+
/// type-converted values (e.g. `int`).
992+
/// - Because we must check the contained types within the iterable (for
993+
/// overloads), we must iterate through the iterable. For this reason, the
994+
/// iterable should *not* be exhaustible (e.g., iterator, generator).
995+
template <typename T>
996+
class iterable_t : public iterable {
997+
public:
998+
PYBIND11_OBJECT_DEFAULT(iterable_t, iterable, detail::PyIterableT_Check<T>)
999+
};
1000+
9441001
class bytes;
9451002

9461003
class str : public object {

tests/test_pytypes.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@ TEST_SUBMODULE(pytypes, m) {
1717
m.def("get_iterator", []{return py::iterator();});
1818
// test_iterable
1919
m.def("get_iterable", []{return py::iterable();});
20+
// test_iterable_t
21+
m.def("get_iterable_t", []{return py::iterable_t<py::str>();});
22+
// test_iterable_t_overloads
23+
m.def("accept_iterable_t", [](py::iterable_t<py::str>) { return "str"; });
24+
m.def("accept_iterable_t", [](py::iterable_t<py::bytes>) { return "bytes"; });
25+
// // Uncomment to see compiler error.
26+
// m.def("accept_iterable_t", [](py::iterable_t<int>) { return "int"; });
27+
2028
// test_list
2129
m.def("get_list", []() {
2230
py::list list;

tests/test_pytypes.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,45 @@ def test_iterable(doc):
2121
assert doc(m.get_iterable) == "get_iterable() -> Iterable"
2222

2323

24+
def test_iterable_t(doc):
25+
assert doc(m.get_iterable_t) == "get_iterable_t() -> Iterable[str]"
26+
27+
28+
def test_iterable_t_overloads():
29+
# Empty: First one wins.
30+
list_empty = []
31+
set_empty = set()
32+
assert m.accept_iterable_t(list_empty) == "str"
33+
assert m.accept_iterable_t(set_empty) == "str"
34+
# Negative: Exhaustible iterables (e.g. iterators, generators).
35+
gen_empty = (x for x in set_empty)
36+
with pytest.raises(RuntimeError):
37+
m.accept_iterable_t(gen_empty)
38+
iter_empty = iter(list_empty)
39+
with pytest.raises(RuntimeError):
40+
m.accept_iterable_t(iter_empty)
41+
42+
# Str.
43+
list_of_str = ["hey", "you"]
44+
set_of_str = {"hey", "you"}
45+
assert m.accept_iterable_t(list_of_str) == "str"
46+
assert m.accept_iterable_t(set_of_str) == "str"
47+
# - Negative: Not fully `str`.
48+
list_of_str_and_then_some = ["hey", 0]
49+
with pytest.raises(TypeError):
50+
m.accept_iterable_t(list_of_str_and_then_some)
51+
52+
# Bytes.
53+
list_of_bytes = [b"hey", b"you"]
54+
set_of_bytes = {b"hey", b"you"}
55+
assert m.accept_iterable_t(list_of_bytes) == "bytes"
56+
assert m.accept_iterable_t(set_of_bytes) == "bytes"
57+
# - Negative: Not fully `bytes`.
58+
list_of_bytes_and_then_some = [b"hey", 0]
59+
with pytest.raises(TypeError):
60+
m.accept_iterable_t(list_of_bytes_and_then_some)
61+
62+
2463
def test_list(capture, doc):
2564
with capture:
2665
lst = m.get_list()

0 commit comments

Comments
 (0)