Skip to content

Commit 93f17ed

Browse files
committed
API: Add numpy2.h instead and make numpy.h safe
This means that users of `numpy.h` cannot be broken, but need to update to `numpy2.h` if they want to compile for NumPy 2. Using Macros simply and didn't bother to try to remove unnecessary code paths.
1 parent 9116d69 commit 93f17ed

File tree

5 files changed

+60
-19
lines changed

5 files changed

+60
-19
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ set(PYBIND11_HEADERS
161161
include/pybind11/iostream.h
162162
include/pybind11/functional.h
163163
include/pybind11/numpy.h
164+
include/pybind11/numpy2.h
164165
include/pybind11/operators.h
165166
include/pybind11/pybind11.h
166167
include/pybind11/pytypes.h

include/pybind11/eigen/matrix.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
#pragma once
1111

12-
#include "../numpy.h"
12+
#include "../numpy2.h"
1313
#include "common.h"
1414

1515
/* HINT: To suppress warnings originating from the Eigen headers, use -isystem.

include/pybind11/eigen/tensor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
#pragma once
99

10-
#include "../numpy.h"
10+
#include "../numpy2.h"
1111
#include "common.h"
1212

1313
#if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER)

include/pybind11/numpy.h

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,6 @@ struct handle_type_name<array> {
5353
template <typename type, typename SFINAE = void>
5454
struct npy_format_descriptor;
5555

56-
struct PyArrayDescr_Proxy {
57-
PyObject_HEAD
58-
PyObject *typeobj;
59-
char kind;
60-
char type;
61-
char byteorder;
62-
char _former_flags;
63-
int type_num;
64-
/* Additional fields are NumPy version specific. */
65-
};
66-
6756
/* NumPy 1 proxy (always includes legacy fields) */
6857
struct PyArrayDescr1_Proxy {
6958
PyObject_HEAD
@@ -80,6 +69,22 @@ struct PyArrayDescr1_Proxy {
8069
PyObject *names;
8170
};
8271

72+
#ifdef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
73+
struct PyArrayDescr_Proxy {
74+
PyObject_HEAD
75+
PyObject *typeobj;
76+
char kind;
77+
char type;
78+
char byteorder;
79+
char _former_flags;
80+
int type_num;
81+
/* Additional fields are NumPy version specific. */
82+
};
83+
#else
84+
/* NumPy 1.x only, we can expose all fields */
85+
typedef PyArrayDescr1_Proxy PyArrayDescr_Proxy;
86+
#endif
87+
8388
/* NumPy 2 proxy, including legacy fields */
8489
struct PyArrayDescr2_Proxy {
8590
PyObject_HEAD
@@ -164,6 +169,13 @@ PYBIND11_NOINLINE module_ import_numpy_core_submodule(const char *submodule_name
164169
object numpy_version = numpy_lib.attr("NumpyVersion")(version_string);
165170
int major_version = numpy_version.attr("major").cast<int>();
166171

172+
#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
173+
if (major_version >= 2) {
174+
throw std::runtime_error("module compiled without NumPy 2 support. Please modify the "
175+
"`pybind11/numpy.h` include to `pybind11/numpy2.h` and recompile "
176+
"(this remains NumPy 1.x compatible but has minor changes).");
177+
}
178+
#endif
167179
/* `numpy.core` was renamed to `numpy._core` in NumPy 2.0 as it officially
168180
became a private module. */
169181
std::string numpy_core_path = major_version >= 2 ? "numpy._core" : "numpy.core";
@@ -276,6 +288,16 @@ struct npy_api {
276288
PyObject *(*PyArray_FromAny_)(PyObject *, PyObject *, int, int, int, PyObject *);
277289
int (*PyArray_DescrConverter_)(PyObject *, PyObject **);
278290
bool (*PyArray_EquivTypes_)(PyObject *, PyObject *);
291+
#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
292+
int (*PyArray_GetArrayParamsFromObject_)(PyObject *,
293+
PyObject *,
294+
unsigned char,
295+
PyObject **,
296+
int *,
297+
Py_intptr_t *,
298+
PyObject **,
299+
PyObject *);
300+
#endif
279301
PyObject *(*PyArray_Squeeze_)(PyObject *);
280302
// Unused. Not removed because that affects ABI of the class.
281303
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
@@ -301,7 +323,9 @@ struct npy_api {
301323
API_PyArray_Squeeze = 136,
302324
API_PyArray_View = 137,
303325
API_PyArray_DescrConverter = 174,
326+
#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
304327
API_PyArray_EquivTypes = 182,
328+
#endif
305329
API_PyArray_SetBaseObject = 282
306330
};
307331

@@ -644,12 +668,16 @@ class dtype : public object {
644668
}
645669

646670
/// Size of the data type in bytes.
671+
#ifdef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
672+
int itemsize() const { detail::array_descriptor_proxy(m_ptr)->elsize; }
673+
#else
647674
ssize_t itemsize() const {
648675
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
649676
return detail::array_descriptor1_proxy(m_ptr)->elsize;
650677
}
651678
return detail::array_descriptor2_proxy(m_ptr)->elsize;
652679
}
680+
#endif
653681

654682
/// Returns true for structured data types.
655683
bool has_fields() const {
@@ -686,24 +714,31 @@ class dtype : public object {
686714
/// Single character for byteorder
687715
char byteorder() const { return detail::array_descriptor_proxy(m_ptr)->byteorder; }
688716

689-
/// Alignment of the data type
717+
/// Alignment of the data type
718+
#ifdef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
719+
int alignment() const { detail::array_descriptor2_proxy(m_ptr)->alignment; }
720+
#else
690721
ssize_t alignment() const {
691722
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
692-
return detail::array_descriptor1_proxy(m_ptr)->alignment;
723+
return detail::array_descriptor2_proxy(m_ptr)->alignment;
693724
}
694-
return detail::array_descriptor2_proxy(m_ptr)->alignment;
725+
return detail::array_descriptor1_proxy(m_ptr)->alignment;
695726
}
727+
#endif
696728

697-
/// Flags for the array descriptor
729+
/// Flags for the array descriptor
730+
#ifdef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
731+
char flags() const {detail::array_descriptor_proxy(m_ptr)->flags}
732+
#else
698733
std::uint64_t flags() const {
699734
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
700735
return (unsigned char) detail::array_descriptor1_proxy(m_ptr)->flags;
701736
}
702737
return detail::array_descriptor2_proxy(m_ptr)->flags;
703738
}
739+
#endif
704740

705-
private:
706-
static object &_dtype_from_pep3118() {
741+
private : static object &_dtype_from_pep3118() {
707742
PYBIND11_CONSTINIT static gil_safe_call_once_and_store<object> storage;
708743
return storage
709744
.call_once_and_store_result([]() {

include/pybind11/numpy2.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#pragma once
2+
3+
#define PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
4+
#include "numpy.h"
5+
#undef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT

0 commit comments

Comments
 (0)