@@ -67,6 +67,14 @@ struct StringStruct {
67
67
std::array<char , 3 > b;
68
68
};
69
69
70
+ enum class E1 : int64_t { A = -1 , B = 1 };
71
+ enum E2 : uint8_t { X = 1 , Y = 2 };
72
+
73
+ PYBIND11_PACKED (struct EnumStruct {
74
+ E1 e1 ;
75
+ E2 e2 ;
76
+ });
77
+
70
78
std::ostream& operator <<(std::ostream& os, const StringStruct& v) {
71
79
os << " a='" ;
72
80
for (size_t i = 0 ; i < 3 && v.a [i]; i++) os << v.a [i];
@@ -75,6 +83,10 @@ std::ostream& operator<<(std::ostream& os, const StringStruct& v) {
75
83
return os << " '" ;
76
84
}
77
85
86
+ std::ostream& operator <<(std::ostream& os, const EnumStruct& v) {
87
+ return os << " e1=" << (v.e1 == E1 ::A ? " A" : " B" ) << " ,e2=" << (v.e2 == E2 ::X ? " X" : " Y" );
88
+ }
89
+
78
90
template <typename T>
79
91
py::array mkarray_via_buffer (size_t n) {
80
92
return py::array (py::buffer_info (nullptr , sizeof (T),
@@ -137,6 +149,16 @@ py::array_t<StringStruct, 0> create_string_array(bool non_empty) {
137
149
return arr;
138
150
}
139
151
152
+ py::array_t <EnumStruct, 0 > create_enum_array (size_t n) {
153
+ auto arr = mkarray_via_buffer<EnumStruct>(n);
154
+ auto ptr = (EnumStruct *) arr.mutable_data ();
155
+ for (size_t i = 0 ; i < n; i++) {
156
+ ptr[i].e1 = static_cast <E1 >(-1 + ((int ) i % 2 ) * 2 );
157
+ ptr[i].e2 = static_cast <E2 >(1 + (i % 2 ));
158
+ }
159
+ return arr;
160
+ }
161
+
140
162
template <typename S>
141
163
py::list print_recarray (py::array_t <S, 0 > arr) {
142
164
const auto req = arr.request ();
@@ -157,7 +179,8 @@ py::list print_format_descriptors() {
157
179
py::format_descriptor<NestedStruct>::format (),
158
180
py::format_descriptor<PartialStruct>::format (),
159
181
py::format_descriptor<PartialNestedStruct>::format (),
160
- py::format_descriptor<StringStruct>::format ()
182
+ py::format_descriptor<StringStruct>::format (),
183
+ py::format_descriptor<EnumStruct>::format ()
161
184
};
162
185
auto l = py::list ();
163
186
for (const auto &fmt : fmts) {
@@ -173,7 +196,8 @@ py::list print_dtypes() {
173
196
py::dtype::of<NestedStruct>().str (),
174
197
py::dtype::of<PartialStruct>().str (),
175
198
py::dtype::of<PartialNestedStruct>().str (),
176
- py::dtype::of<StringStruct>().str ()
199
+ py::dtype::of<StringStruct>().str (),
200
+ py::dtype::of<EnumStruct>().str ()
177
201
};
178
202
auto l = py::list ();
179
203
for (const auto &s : dtypes) {
@@ -280,6 +304,7 @@ test_initializer numpy_dtypes([](py::module &m) {
280
304
PYBIND11_NUMPY_DTYPE (PartialStruct, x, y, z);
281
305
PYBIND11_NUMPY_DTYPE (PartialNestedStruct, a);
282
306
PYBIND11_NUMPY_DTYPE (StringStruct, a, b);
307
+ PYBIND11_NUMPY_DTYPE (EnumStruct, e1 , e2 );
283
308
284
309
m.def (" create_rec_simple" , &create_recarray<SimpleStruct>);
285
310
m.def (" create_rec_packed" , &create_recarray<PackedStruct>);
@@ -294,6 +319,8 @@ test_initializer numpy_dtypes([](py::module &m) {
294
319
m.def (" get_format_unbound" , &get_format_unbound);
295
320
m.def (" create_string_array" , &create_string_array);
296
321
m.def (" print_string_array" , &print_recarray<StringStruct>);
322
+ m.def (" create_enum_array" , &create_enum_array);
323
+ m.def (" print_enum_array" , &print_recarray<EnumStruct>);
297
324
m.def (" test_array_ctors" , &test_array_ctors);
298
325
m.def (" test_dtype_ctors" , &test_dtype_ctors);
299
326
m.def (" test_dtype_methods" , &test_dtype_methods);
0 commit comments