Skip to content

Commit 0090fc7

Browse files
committed
API: Make numpy.h compatible with both NumPy 1.x and 2.x
1 parent 8b48ff8 commit 0090fc7

File tree

1 file changed

+76
-18
lines changed

1 file changed

+76
-18
lines changed

include/pybind11/numpy.h

Lines changed: 76 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,18 @@ template <typename type, typename SFINAE = void>
5454
struct npy_format_descriptor;
5555

5656
struct PyArrayDescr_Proxy {
57+
PyObject_HEAD
58+
PyObject *typeobj;
59+
char kind;
60+
char type;
61+
char byteorder;
62+
char _former_flags;
63+
int type_num;
64+
/* Additional fields are NumPy version specific. */
65+
};
66+
67+
/* NumPy 1 proxy (always includes legacy fields) */
68+
struct PyArrayDescr1_Proxy {
5769
PyObject_HEAD
5870
PyObject *typeobj;
5971
char kind;
@@ -68,6 +80,27 @@ struct PyArrayDescr_Proxy {
6880
PyObject *names;
6981
};
7082

83+
/* NumPy 2 proxy, including legacy fields */
84+
struct PyArrayDescr2_Proxy {
85+
PyObject_HEAD
86+
PyObject *typeobj;
87+
char kind;
88+
char type;
89+
char byteorder;
90+
char _former_flags;
91+
int type_num;
92+
std::uint64_t flags;
93+
ssize_t elsize;
94+
ssize_t alignment;
95+
PyObject *metadata;
96+
Py_hash_t hash;
97+
void *reserved_null;
98+
/* The following fields only exist if 0 < type_num < 2000 */
99+
struct _arr_descr *subarray;
100+
PyObject *fields;
101+
PyObject *names;
102+
};
103+
71104
struct PyArray_Proxy {
72105
PyObject_HEAD
73106
char *data;
@@ -203,6 +236,8 @@ struct npy_api {
203236
NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_),
204237
};
205238

239+
unsigned int PyArray_RUNTIME_VERSION_;
240+
206241
struct PyArray_Dims {
207242
Py_intptr_t *ptr;
208243
int len;
@@ -241,14 +276,6 @@ struct npy_api {
241276
PyObject *(*PyArray_FromAny_)(PyObject *, PyObject *, int, int, int, PyObject *);
242277
int (*PyArray_DescrConverter_)(PyObject *, PyObject **);
243278
bool (*PyArray_EquivTypes_)(PyObject *, PyObject *);
244-
int (*PyArray_GetArrayParamsFromObject_)(PyObject *,
245-
PyObject *,
246-
unsigned char,
247-
PyObject **,
248-
int *,
249-
Py_intptr_t *,
250-
PyObject **,
251-
PyObject *);
252279
PyObject *(*PyArray_Squeeze_)(PyObject *);
253280
// Unused. Not removed because that affects ABI of the class.
254281
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
@@ -275,7 +302,6 @@ struct npy_api {
275302
API_PyArray_View = 137,
276303
API_PyArray_DescrConverter = 174,
277304
API_PyArray_EquivTypes = 182,
278-
API_PyArray_GetArrayParamsFromObject = 278,
279305
API_PyArray_SetBaseObject = 282
280306
};
281307

@@ -290,7 +316,8 @@ struct npy_api {
290316
npy_api api;
291317
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
292318
DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
293-
if (api.PyArray_GetNDArrayCFeatureVersion_() < 0x7) {
319+
api.PyArray_RUNTIME_VERSION_ = api.PyArray_GetNDArrayCFeatureVersion_();
320+
if (api.PyArray_RUNTIME_VERSION_ < 0x7) {
294321
pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0");
295322
}
296323
DECL_NPY_API(PyArray_Type);
@@ -309,7 +336,6 @@ struct npy_api {
309336
DECL_NPY_API(PyArray_View);
310337
DECL_NPY_API(PyArray_DescrConverter);
311338
DECL_NPY_API(PyArray_EquivTypes);
312-
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
313339
DECL_NPY_API(PyArray_SetBaseObject);
314340

315341
#undef DECL_NPY_API
@@ -331,6 +357,14 @@ inline const PyArrayDescr_Proxy *array_descriptor_proxy(const PyObject *ptr) {
331357
return reinterpret_cast<const PyArrayDescr_Proxy *>(ptr);
332358
}
333359

360+
inline const PyArrayDescr1_Proxy *array_descriptor1_proxy(const PyObject *ptr) {
361+
return reinterpret_cast<const PyArrayDescr1_Proxy *>(ptr);
362+
}
363+
364+
inline const PyArrayDescr2_Proxy *array_descriptor2_proxy(const PyObject *ptr) {
365+
return reinterpret_cast<const PyArrayDescr2_Proxy *>(ptr);
366+
}
367+
334368
inline bool check_flags(const void *ptr, int flag) {
335369
return (flag == (array_proxy(ptr)->flags & flag));
336370
}
@@ -610,10 +644,24 @@ class dtype : public object {
610644
}
611645

612646
/// Size of the data type in bytes.
613-
ssize_t itemsize() const { return detail::array_descriptor_proxy(m_ptr)->elsize; }
647+
ssize_t itemsize() const {
648+
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
649+
return detail::array_descriptor1_proxy(m_ptr)->elsize;
650+
} else {
651+
return detail::array_descriptor2_proxy(m_ptr)->elsize;
652+
}
653+
}
614654

615655
/// Returns true for structured data types.
616-
bool has_fields() const { return detail::array_descriptor_proxy(m_ptr)->names != nullptr; }
656+
bool has_fields() const {
657+
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
658+
return detail::array_descriptor1_proxy(m_ptr)->names != nullptr;
659+
} else if (num() < 0 || num() > 2000) {
660+
return false;
661+
} else {
662+
return detail::array_descriptor2_proxy(m_ptr)->names != nullptr;
663+
}
664+
}
617665

618666
/// Single-character code for dtype's kind.
619667
/// For example, floating point types are 'f' and integral types are 'i'.
@@ -640,10 +688,22 @@ class dtype : public object {
640688
char byteorder() const { return detail::array_descriptor_proxy(m_ptr)->byteorder; }
641689

642690
/// Alignment of the data type
643-
int alignment() const { return detail::array_descriptor_proxy(m_ptr)->alignment; }
691+
ssize_t alignment() const {
692+
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
693+
return detail::array_descriptor1_proxy(m_ptr)->alignment;
694+
} else {
695+
return detail::array_descriptor2_proxy(m_ptr)->alignment;
696+
}
697+
}
644698

645699
/// Flags for the array descriptor
646-
char flags() const { return detail::array_descriptor_proxy(m_ptr)->flags; }
700+
std::uint64_t flags() const {
701+
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
702+
return (unsigned char) detail::array_descriptor1_proxy(m_ptr)->flags;
703+
} else {
704+
return detail::array_descriptor2_proxy(m_ptr)->flags;
705+
}
706+
}
647707

648708
private:
649709
static object &_dtype_from_pep3118() {
@@ -810,9 +870,7 @@ class array : public buffer {
810870
}
811871

812872
/// Byte size of a single element
813-
ssize_t itemsize() const {
814-
return detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize;
815-
}
873+
ssize_t itemsize() const { return dtype().itemsize(); }
816874

817875
/// Total number of bytes
818876
ssize_t nbytes() const { return size() * itemsize(); }

0 commit comments

Comments
 (0)