Skip to content

Commit 3eb569f

Browse files
authored
Merge pull request #452 from aldanor/feature/numpy-enum
Auto-implement format/numpy descriptors for enum types
2 parents 3599585 + 2f3f368 commit 3eb569f

File tree

3 files changed

+67
-4
lines changed

3 files changed

+67
-4
lines changed

include/pybind11/numpy.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,14 @@ template <size_t N> struct format_descriptor<std::array<char, N>> {
552552
static std::string format() { return std::to_string(N) + "s"; }
553553
};
554554

555+
template <typename T>
556+
struct format_descriptor<T, detail::enable_if_t<std::is_enum<T>::value>> {
557+
static std::string format() {
558+
return format_descriptor<
559+
typename std::remove_cv<typename std::underlying_type<T>::type>::type>::format();
560+
}
561+
};
562+
555563
NAMESPACE_BEGIN(detail)
556564
template <typename T> struct is_std_array : std::false_type { };
557565
template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { };
@@ -563,6 +571,7 @@ struct is_pod_struct {
563571
!std::is_array<T>::value &&
564572
!is_std_array<T>::value &&
565573
!std::is_integral<T>::value &&
574+
!std::is_enum<T>::value &&
566575
!std::is_same<typename std::remove_cv<T>::type, float>::value &&
567576
!std::is_same<typename std::remove_cv<T>::type, double>::value &&
568577
!std::is_same<typename std::remove_cv<T>::type, bool>::value &&
@@ -612,6 +621,14 @@ template <size_t N> struct npy_format_descriptor<char[N]> { DECL_CHAR_FMT };
612621
template <size_t N> struct npy_format_descriptor<std::array<char, N>> { DECL_CHAR_FMT };
613622
#undef DECL_CHAR_FMT
614623

624+
template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> {
625+
private:
626+
using base_descr = npy_format_descriptor<typename std::underlying_type<T>::type>;
627+
public:
628+
static PYBIND11_DESCR name() { return base_descr::name(); }
629+
static pybind11::dtype dtype() { return base_descr::dtype(); }
630+
};
631+
615632
struct field_descriptor {
616633
const char *name;
617634
size_t offset;

tests/test_numpy_dtypes.cpp

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,14 @@ struct StringStruct {
6767
std::array<char, 3> b;
6868
};
6969

70+
enum class E1 : int64_t { A = -1, B = 1 };
71+
enum E2 : uint8_t { X = 1, Y = 2 };
72+
73+
PYBIND11_PACKED(struct EnumStruct {
74+
E1 e1;
75+
E2 e2;
76+
});
77+
7078
std::ostream& operator<<(std::ostream& os, const StringStruct& v) {
7179
os << "a='";
7280
for (size_t i = 0; i < 3 && v.a[i]; i++) os << v.a[i];
@@ -75,6 +83,10 @@ std::ostream& operator<<(std::ostream& os, const StringStruct& v) {
7583
return os << "'";
7684
}
7785

86+
std::ostream& operator<<(std::ostream& os, const EnumStruct& v) {
87+
return os << "e1=" << (v.e1 == E1::A ? "A" : "B") << ",e2=" << (v.e2 == E2::X ? "X" : "Y");
88+
}
89+
7890
template <typename T>
7991
py::array mkarray_via_buffer(size_t n) {
8092
return py::array(py::buffer_info(nullptr, sizeof(T),
@@ -137,6 +149,16 @@ py::array_t<StringStruct, 0> create_string_array(bool non_empty) {
137149
return arr;
138150
}
139151

152+
py::array_t<EnumStruct, 0> create_enum_array(size_t n) {
153+
auto arr = mkarray_via_buffer<EnumStruct>(n);
154+
auto ptr = (EnumStruct *) arr.mutable_data();
155+
for (size_t i = 0; i < n; i++) {
156+
ptr[i].e1 = static_cast<E1>(-1 + ((int) i % 2) * 2);
157+
ptr[i].e2 = static_cast<E2>(1 + (i % 2));
158+
}
159+
return arr;
160+
}
161+
140162
template <typename S>
141163
py::list print_recarray(py::array_t<S, 0> arr) {
142164
const auto req = arr.request();
@@ -157,7 +179,8 @@ py::list print_format_descriptors() {
157179
py::format_descriptor<NestedStruct>::format(),
158180
py::format_descriptor<PartialStruct>::format(),
159181
py::format_descriptor<PartialNestedStruct>::format(),
160-
py::format_descriptor<StringStruct>::format()
182+
py::format_descriptor<StringStruct>::format(),
183+
py::format_descriptor<EnumStruct>::format()
161184
};
162185
auto l = py::list();
163186
for (const auto &fmt : fmts) {
@@ -173,7 +196,8 @@ py::list print_dtypes() {
173196
py::dtype::of<NestedStruct>().str(),
174197
py::dtype::of<PartialStruct>().str(),
175198
py::dtype::of<PartialNestedStruct>().str(),
176-
py::dtype::of<StringStruct>().str()
199+
py::dtype::of<StringStruct>().str(),
200+
py::dtype::of<EnumStruct>().str()
177201
};
178202
auto l = py::list();
179203
for (const auto &s : dtypes) {
@@ -280,6 +304,7 @@ test_initializer numpy_dtypes([](py::module &m) {
280304
PYBIND11_NUMPY_DTYPE(PartialStruct, x, y, z);
281305
PYBIND11_NUMPY_DTYPE(PartialNestedStruct, a);
282306
PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
307+
PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2);
283308

284309
m.def("create_rec_simple", &create_recarray<SimpleStruct>);
285310
m.def("create_rec_packed", &create_recarray<PackedStruct>);
@@ -294,6 +319,8 @@ test_initializer numpy_dtypes([](py::module &m) {
294319
m.def("get_format_unbound", &get_format_unbound);
295320
m.def("create_string_array", &create_string_array);
296321
m.def("print_string_array", &print_recarray<StringStruct>);
322+
m.def("create_enum_array", &create_enum_array);
323+
m.def("print_enum_array", &print_recarray<EnumStruct>);
297324
m.def("test_array_ctors", &test_array_ctors);
298325
m.def("test_dtype_ctors", &test_dtype_ctors);
299326
m.def("test_dtype_methods", &test_dtype_methods);

tests/test_numpy_dtypes.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ def test_format_descriptors():
2626
"T{=T{=?:x:3x=I:y:=f:z:}:a:=T{=?:x:=I:y:=f:z:}:b:}",
2727
"T{=?:x:3x=I:y:=f:z:12x}",
2828
"T{8x=T{=?:x:3x=I:y:=f:z:12x}:a:8x}",
29-
"T{=3s:a:=3s:b:}"
29+
"T{=3s:a:=3s:b:}",
30+
'T{=q:e1:=B:e2:}'
3031
]
3132

3233

@@ -40,7 +41,8 @@ def test_dtype():
4041
"[('a', {'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':12}), ('b', [('x', '?'), ('y', '<u4'), ('z', '<f4')])]",
4142
"{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}",
4243
"{'names':['a'], 'formats':[{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}], 'offsets':[8], 'itemsize':40}",
43-
"[('a', 'S3'), ('b', 'S3')]"
44+
"[('a', 'S3'), ('b', 'S3')]",
45+
"[('e1', '<i8'), ('e2', 'u1')]"
4446
]
4547

4648
d1 = np.dtype({'names': ['a', 'b'], 'formats': ['int32', 'float64'],
@@ -150,6 +152,23 @@ def test_string_array():
150152
assert dtype == arr.dtype
151153

152154

155+
@pytest.requires_numpy
156+
def test_enum_array():
157+
from pybind11_tests import create_enum_array, print_enum_array
158+
159+
arr = create_enum_array(3)
160+
dtype = arr.dtype
161+
assert dtype == np.dtype([('e1', '<i8'), ('e2', 'u1')])
162+
assert print_enum_array(arr) == [
163+
"e1=A,e2=X",
164+
"e1=B,e2=Y",
165+
"e1=A,e2=X"
166+
]
167+
assert arr['e1'].tolist() == [-1, 1, -1]
168+
assert arr['e2'].tolist() == [1, 2, 1]
169+
assert create_enum_array(0).dtype == dtype
170+
171+
153172
@pytest.requires_numpy
154173
def test_signature(doc):
155174
from pybind11_tests import create_rec_nested

0 commit comments

Comments
 (0)