Skip to content

Commit ea9ba5f

Browse files
committed
API: Add numpy2.h instead and make numpy.h safe
This means that users of `numpy.h` cannot be broken, but need to update to `numpy2.h` if they want to compile for NumPy 2. Using Macros simply and didn't bother to try to remove unnecessary code paths.
1 parent 9116d69 commit ea9ba5f

File tree

6 files changed

+66
-15
lines changed

6 files changed

+66
-15
lines changed

CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ set(PYBIND11_HEADERS
161161
include/pybind11/iostream.h
162162
include/pybind11/functional.h
163163
include/pybind11/numpy.h
164+
include/pybind11/numpy2.h
164165
include/pybind11/operators.h
165166
include/pybind11/pybind11.h
166167
include/pybind11/pytypes.h

include/pybind11/eigen/matrix.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
#pragma once
1111

12-
#include "../numpy.h"
12+
#include "../numpy2.h"
1313
#include "common.h"
1414

1515
/* HINT: To suppress warnings originating from the Eigen headers, use -isystem.

include/pybind11/eigen/tensor.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
#pragma once
99

10-
#include "../numpy.h"
10+
#include "../numpy2.h"
1111
#include "common.h"
1212

1313
#if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER)

include/pybind11/numpy.h

+57-13
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,6 @@ struct handle_type_name<array> {
5353
template <typename type, typename SFINAE = void>
5454
struct npy_format_descriptor;
5555

56-
struct PyArrayDescr_Proxy {
57-
PyObject_HEAD
58-
PyObject *typeobj;
59-
char kind;
60-
char type;
61-
char byteorder;
62-
char _former_flags;
63-
int type_num;
64-
/* Additional fields are NumPy version specific. */
65-
};
66-
6756
/* NumPy 1 proxy (always includes legacy fields) */
6857
struct PyArrayDescr1_Proxy {
6958
PyObject_HEAD
@@ -80,6 +69,22 @@ struct PyArrayDescr1_Proxy {
8069
PyObject *names;
8170
};
8271

72+
#ifdef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
73+
struct PyArrayDescr_Proxy {
74+
PyObject_HEAD
75+
PyObject *typeobj;
76+
char kind;
77+
char type;
78+
char byteorder;
79+
char _former_flags;
80+
int type_num;
81+
/* Additional fields are NumPy version specific. */
82+
};
83+
#else
84+
/* NumPy 1.x only, we can expose all fields */
85+
typedef PyArrayDescr1_Proxy PyArrayDescr_Proxy;
86+
#endif
87+
8388
/* NumPy 2 proxy, including legacy fields */
8489
struct PyArrayDescr2_Proxy {
8590
PyObject_HEAD
@@ -164,6 +169,13 @@ PYBIND11_NOINLINE module_ import_numpy_core_submodule(const char *submodule_name
164169
object numpy_version = numpy_lib.attr("NumpyVersion")(version_string);
165170
int major_version = numpy_version.attr("major").cast<int>();
166171

172+
#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
173+
if (major_version >= 2) {
174+
throw std::runtime_error("module compiled without NumPy 2 support. Please modify the "
175+
"`pybind11/numpy.h` include to `pybind11/numpy2.h` and recompile "
176+
"(this remains NumPy 1.x compatible but has minor changes).");
177+
}
178+
#endif
167179
/* `numpy.core` was renamed to `numpy._core` in NumPy 2.0 as it officially
168180
became a private module. */
169181
std::string numpy_core_path = major_version >= 2 ? "numpy._core" : "numpy.core";
@@ -276,6 +288,16 @@ struct npy_api {
276288
PyObject *(*PyArray_FromAny_)(PyObject *, PyObject *, int, int, int, PyObject *);
277289
int (*PyArray_DescrConverter_)(PyObject *, PyObject **);
278290
bool (*PyArray_EquivTypes_)(PyObject *, PyObject *);
291+
#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
292+
int (*PyArray_GetArrayParamsFromObject_)(PyObject *,
293+
PyObject *,
294+
unsigned char,
295+
PyObject **,
296+
int *,
297+
Py_intptr_t *,
298+
PyObject **,
299+
PyObject *);
300+
#endif
279301
PyObject *(*PyArray_Squeeze_)(PyObject *);
280302
// Unused. Not removed because that affects ABI of the class.
281303
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
@@ -302,6 +324,9 @@ struct npy_api {
302324
API_PyArray_View = 137,
303325
API_PyArray_DescrConverter = 174,
304326
API_PyArray_EquivTypes = 182,
327+
#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
328+
API_PyArray_GetArrayParamsFromObject = 278,
329+
#endif
305330
API_PyArray_SetBaseObject = 282
306331
};
307332

@@ -336,6 +361,9 @@ struct npy_api {
336361
DECL_NPY_API(PyArray_View);
337362
DECL_NPY_API(PyArray_DescrConverter);
338363
DECL_NPY_API(PyArray_EquivTypes);
364+
#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
365+
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
366+
#endif
339367
DECL_NPY_API(PyArray_SetBaseObject);
340368

341369
#undef DECL_NPY_API
@@ -644,14 +672,21 @@ class dtype : public object {
644672
}
645673

646674
/// Size of the data type in bytes.
675+
#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
676+
ssize_t itemsize() const { return detail::array_descriptor_proxy(m_ptr)->elsize; }
677+
#else
647678
ssize_t itemsize() const {
648679
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
649680
return detail::array_descriptor1_proxy(m_ptr)->elsize;
650681
}
651682
return detail::array_descriptor2_proxy(m_ptr)->elsize;
652683
}
684+
#endif
653685

654686
/// Returns true for structured data types.
687+
#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
688+
bool has_fields() const { return detail::array_descriptor_proxy(m_ptr)->names != nullptr; }
689+
#else
655690
bool has_fields() const {
656691
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
657692
return detail::array_descriptor1_proxy(m_ptr)->names != nullptr;
@@ -661,6 +696,7 @@ class dtype : public object {
661696
}
662697
return detail::array_descriptor2_proxy(m_ptr)->names != nullptr;
663698
}
699+
#endif
664700

665701
/// Single-character code for dtype's kind.
666702
/// For example, floating point types are 'f' and integral types are 'i'.
@@ -686,21 +722,29 @@ class dtype : public object {
686722
/// Single character for byteorder
687723
char byteorder() const { return detail::array_descriptor_proxy(m_ptr)->byteorder; }
688724

689-
/// Alignment of the data type
725+
/// Alignment of the data type
726+
#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
727+
int alignment() const { return detail::array_descriptor_proxy(m_ptr)->alignment; }
728+
#else
690729
ssize_t alignment() const {
691730
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
692731
return detail::array_descriptor1_proxy(m_ptr)->alignment;
693732
}
694733
return detail::array_descriptor2_proxy(m_ptr)->alignment;
695734
}
735+
#endif
696736

697-
/// Flags for the array descriptor
737+
/// Flags for the array descriptor
738+
#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
739+
char flags() const { return detail::array_descriptor_proxy(m_ptr)->flags; }
740+
#else
698741
std::uint64_t flags() const {
699742
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
700743
return (unsigned char) detail::array_descriptor1_proxy(m_ptr)->flags;
701744
}
702745
return detail::array_descriptor2_proxy(m_ptr)->flags;
703746
}
747+
#endif
704748

705749
private:
706750
static object &_dtype_from_pep3118() {

include/pybind11/numpy2.h

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#pragma once
2+
3+
#define PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
4+
#include "numpy.h"
5+
#undef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT

tests/extra_python_package/test_files.py

+1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
"include/pybind11/gil_safe_call_once.h",
3939
"include/pybind11/iostream.h",
4040
"include/pybind11/numpy.h",
41+
"include/pybind11/numpy2.h",
4142
"include/pybind11/operators.h",
4243
"include/pybind11/options.h",
4344
"include/pybind11/pybind11.h",

0 commit comments

Comments
 (0)