Skip to content

Commit 8e1f9d5

Browse files
authored
Add format_descriptor<> & npy_format_descriptor<> PyObject * specializations. (#4674)
* Add `npy_format_descriptor<PyObject *>` to enable `py::array_t<PyObject *>` to/from-python conversions. * resolve clang-tidy warning * Use existing constructor instead of adding a static method. Thanks @Skylion007 for pointing out. * Add `format_descriptor<PyObject *>` Trivial addition, but still in search for a meaningful test. * Add test_format_descriptor_format * Ensure the Eigen `type_caster`s do not segfault when loading arrays with dtype=object * Use `static_assert()` `!std::is_pointer<>` to replace runtime guards. * Add comments to explain how to check for ref-count bugs. (NO code changes.) * Make the "Pointer types ... are not supported" message Eigen-specific, as suggested by @lalaland. Move to new pybind11/eigen/common.h header. * Change "format_descriptor_format" implementation as suggested by @lalaland. Additional tests meant to ensure consistency between py::format_descriptor<>, np.array, np.format_parser turn out to be useful only to highlight long-standing inconsistencies. * resolve clang-tidy warning * Account for np.float128, np.complex256 not being available on Windows, in a future-proof way. * Fully address i|q|l ambiguity (hopefully). * Remove the new `np.format_parser()`-based test, it's much more distracting than useful. * Use bi.itemsize to disambiguate "l" or "L" * Use `py::detail::compare_buffer_info<T>::compare()` to validate the `format_descriptor<T>::format()` strings. * Add `buffer_info::compare<T>` to make `detail::compare_buffer_info<T>::compare` more visible & accessible. * silence clang-tidy warning * pytest-compatible access to np.float128, np.complex256 * Revert "pytest-compatible access to np.float128, np.complex256" This reverts commit e9a289c. * Use `sizeof(long double) == sizeof(double)` instead of `std::is_same<>` * Report skipped `long double` tests. * Change the name of the new `buffer_info` member function to `item_type_is_equivalent_to`. Add comment defining "equivalent" by example. * Change `item_type_is_equivalent_to<>()` from `static` function to member function, as suggested by @lalaland
1 parent 6e6bcca commit 8e1f9d5

File tree

12 files changed

+255
-7
lines changed

12 files changed

+255
-7
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ set(PYBIND11_HEADERS
126126
include/pybind11/complex.h
127127
include/pybind11/options.h
128128
include/pybind11/eigen.h
129+
include/pybind11/eigen/common.h
129130
include/pybind11/eigen/matrix.h
130131
include/pybind11/eigen/tensor.h
131132
include/pybind11/embed.h

include/pybind11/buffer_info.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ inline std::vector<ssize_t> f_strides(const std::vector<ssize_t> &shape, ssize_t
3737
return strides;
3838
}
3939

40+
template <typename T, typename SFINAE = void>
41+
struct compare_buffer_info;
42+
4043
PYBIND11_NAMESPACE_END(detail)
4144

4245
/// Information record describing a Python buffer object
@@ -150,6 +153,17 @@ struct buffer_info {
150153
Py_buffer *view() const { return m_view; }
151154
Py_buffer *&view() { return m_view; }
152155

156+
/* True if the buffer item type is equivalent to `T`. */
157+
// To define "equivalent" by example:
158+
// `buffer_info::item_type_is_equivalent_to<int>(b)` and
159+
// `buffer_info::item_type_is_equivalent_to<long>(b)` may both be true
160+
// on some platforms, but `int` and `unsigned` will never be equivalent.
161+
// For the ground truth, please inspect `detail::compare_buffer_info<>`.
162+
template <typename T>
163+
bool item_type_is_equivalent_to() const {
164+
return detail::compare_buffer_info<T>::compare(*this);
165+
}
166+
153167
private:
154168
struct private_ctr_tag {};
155169

@@ -170,9 +184,10 @@ struct buffer_info {
170184

171185
PYBIND11_NAMESPACE_BEGIN(detail)
172186

173-
template <typename T, typename SFINAE = void>
187+
template <typename T, typename SFINAE>
174188
struct compare_buffer_info {
175189
static bool compare(const buffer_info &b) {
190+
// NOLINTNEXTLINE(bugprone-sizeof-expression) Needed for `PyObject *`
176191
return b.format == format_descriptor<T>::format() && b.itemsize == (ssize_t) sizeof(T);
177192
}
178193
};

include/pybind11/detail/common.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,15 @@ PYBIND11_RUNTIME_EXCEPTION(reference_cast_error, PyExc_RuntimeError) /// Used in
10251025
template <typename T, typename SFINAE = void>
10261026
struct format_descriptor {};
10271027

1028+
template <typename T>
1029+
struct format_descriptor<
1030+
T,
1031+
detail::enable_if_t<detail::is_same_ignoring_cvref<T, PyObject *>::value>> {
1032+
static constexpr const char c = 'O';
1033+
static constexpr const char value[2] = {c, '\0'};
1034+
static std::string format() { return std::string(1, c); }
1035+
};
1036+
10281037
PYBIND11_NAMESPACE_BEGIN(detail)
10291038
// Returns the index of the given type in the type char array below, and in the list in numpy.h
10301039
// The order here is: bool; 8 ints ((signed,unsigned)x(8,16,32,64)bits); float,double,long double;

include/pybind11/eigen/common.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// Copyright (c) 2023 The pybind Community.
2+
3+
#pragma once
4+
5+
// Common message for `static_assert()`s, which are useful to easily
6+
// preempt much less obvious errors.
7+
#define PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED \
8+
"Pointer types (in particular `PyObject *`) are not supported as scalar types for Eigen " \
9+
"types."

include/pybind11/eigen/matrix.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#pragma once
1111

1212
#include "../numpy.h"
13+
#include "common.h"
1314

1415
/* HINT: To suppress warnings originating from the Eigen headers, use -isystem.
1516
See also:
@@ -287,6 +288,8 @@ handle eigen_encapsulate(Type *src) {
287288
template <typename Type>
288289
struct type_caster<Type, enable_if_t<is_eigen_dense_plain<Type>::value>> {
289290
using Scalar = typename Type::Scalar;
291+
static_assert(!std::is_pointer<Scalar>::value,
292+
PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED);
290293
using props = EigenProps<Type>;
291294

292295
bool load(handle src, bool convert) {
@@ -405,6 +408,9 @@ struct type_caster<Type, enable_if_t<is_eigen_dense_plain<Type>::value>> {
405408
// Base class for casting reference/map/block/etc. objects back to python.
406409
template <typename MapType>
407410
struct eigen_map_caster {
411+
static_assert(!std::is_pointer<typename MapType::Scalar>::value,
412+
PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED);
413+
408414
private:
409415
using props = EigenProps<MapType>;
410416

@@ -457,6 +463,8 @@ struct type_caster<
457463
using Type = Eigen::Ref<PlainObjectType, 0, StrideType>;
458464
using props = EigenProps<Type>;
459465
using Scalar = typename props::Scalar;
466+
static_assert(!std::is_pointer<Scalar>::value,
467+
PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED);
460468
using MapType = Eigen::Map<PlainObjectType, 0, StrideType>;
461469
using Array
462470
= array_t<Scalar,
@@ -604,6 +612,9 @@ struct type_caster<
604612
// regular Eigen::Matrix, then casting that.
605613
template <typename Type>
606614
struct type_caster<Type, enable_if_t<is_eigen_other<Type>::value>> {
615+
static_assert(!std::is_pointer<typename Type::Scalar>::value,
616+
PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED);
617+
607618
protected:
608619
using Matrix
609620
= Eigen::Matrix<typename Type::Scalar, Type::RowsAtCompileTime, Type::ColsAtCompileTime>;
@@ -632,6 +643,8 @@ struct type_caster<Type, enable_if_t<is_eigen_other<Type>::value>> {
632643
template <typename Type>
633644
struct type_caster<Type, enable_if_t<is_eigen_sparse<Type>::value>> {
634645
using Scalar = typename Type::Scalar;
646+
static_assert(!std::is_pointer<Scalar>::value,
647+
PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED);
635648
using StorageIndex = remove_reference_t<decltype(*std::declval<Type>().outerIndexPtr())>;
636649
using Index = typename Type::Index;
637650
static constexpr bool rowMajor = Type::IsRowMajor;

include/pybind11/eigen/tensor.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#pragma once
99

1010
#include "../numpy.h"
11+
#include "common.h"
1112

1213
#if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER)
1314
static_assert(__GNUC__ > 5, "Eigen Tensor support in pybind11 requires GCC > 5.0");
@@ -164,6 +165,8 @@ PYBIND11_WARNING_POP
164165

165166
template <typename Type>
166167
struct type_caster<Type, typename eigen_tensor_helper<Type>::ValidType> {
168+
static_assert(!std::is_pointer<typename Type::Scalar>::value,
169+
PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED);
167170
using Helper = eigen_tensor_helper<Type>;
168171
static constexpr auto temp_name = get_tensor_descriptor<Type, false>::value;
169172
PYBIND11_TYPE_CASTER(Type, temp_name);
@@ -359,6 +362,8 @@ struct get_storage_pointer_type<MapType, void_t<typename MapType::PointerArgType
359362
template <typename Type, int Options>
360363
struct type_caster<Eigen::TensorMap<Type, Options>,
361364
typename eigen_tensor_helper<remove_cv_t<Type>>::ValidType> {
365+
static_assert(!std::is_pointer<typename Type::Scalar>::value,
366+
PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED);
362367
using MapType = Eigen::TensorMap<Type, Options>;
363368
using Helper = eigen_tensor_helper<remove_cv_t<Type>>;
364369

include/pybind11/numpy.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,8 @@ class dtype : public object {
564564
m_ptr = from_args(args).release().ptr();
565565
}
566566

567+
/// Return dtype for the given typenum (one of the NPY_TYPES).
568+
/// https://numpy.org/devdocs/reference/c-api/array.html#c.PyArray_DescrFromType
567569
explicit dtype(int typenum)
568570
: object(detail::npy_api::get().PyArray_DescrFromType_(typenum), stolen_t{}) {
569571
if (m_ptr == nullptr) {
@@ -1283,12 +1285,16 @@ struct npy_format_descriptor<
12831285
public:
12841286
static constexpr int value = values[detail::is_fmt_numeric<T>::index];
12851287

1286-
static pybind11::dtype dtype() {
1287-
if (auto *ptr = npy_api::get().PyArray_DescrFromType_(value)) {
1288-
return reinterpret_steal<pybind11::dtype>(ptr);
1289-
}
1290-
pybind11_fail("Unsupported buffer format!");
1291-
}
1288+
static pybind11::dtype dtype() { return pybind11::dtype(/*typenum*/ value); }
1289+
};
1290+
1291+
template <typename T>
1292+
struct npy_format_descriptor<T, enable_if_t<is_same_ignoring_cvref<T, PyObject *>::value>> {
1293+
static constexpr auto name = const_name("object");
1294+
1295+
static constexpr int value = npy_api::NPY_OBJECT_;
1296+
1297+
static pybind11::dtype dtype() { return pybind11::dtype(/*typenum*/ value); }
12921298
};
12931299

12941300
#define PYBIND11_DECL_CHAR_FMT \

tests/extra_python_package/test_files.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
}
5858

5959
eigen_headers = {
60+
"include/pybind11/eigen/common.h",
6061
"include/pybind11/eigen/matrix.h",
6162
"include/pybind11/eigen/tensor.h",
6263
}

tests/test_buffers.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,47 @@
77
BSD-style license that can be found in the LICENSE file.
88
*/
99

10+
#include <pybind11/complex.h>
1011
#include <pybind11/stl.h>
1112

1213
#include "constructor_stats.h"
1314
#include "pybind11_tests.h"
1415

1516
TEST_SUBMODULE(buffers, m) {
17+
m.attr("long_double_and_double_have_same_size") = (sizeof(long double) == sizeof(double));
18+
19+
m.def("format_descriptor_format_buffer_info_equiv",
20+
[](const std::string &cpp_name, const py::buffer &buffer) {
21+
// https://google.github.io/styleguide/cppguide.html#Static_and_Global_Variables
22+
static auto *format_table = new std::map<std::string, std::string>;
23+
static auto *equiv_table
24+
= new std::map<std::string, bool (py::buffer_info::*)() const>;
25+
if (format_table->empty()) {
26+
#define PYBIND11_ASSIGN_HELPER(...) \
27+
(*format_table)[#__VA_ARGS__] = py::format_descriptor<__VA_ARGS__>::format(); \
28+
(*equiv_table)[#__VA_ARGS__] = &py::buffer_info::item_type_is_equivalent_to<__VA_ARGS__>;
29+
PYBIND11_ASSIGN_HELPER(PyObject *)
30+
PYBIND11_ASSIGN_HELPER(bool)
31+
PYBIND11_ASSIGN_HELPER(std::int8_t)
32+
PYBIND11_ASSIGN_HELPER(std::uint8_t)
33+
PYBIND11_ASSIGN_HELPER(std::int16_t)
34+
PYBIND11_ASSIGN_HELPER(std::uint16_t)
35+
PYBIND11_ASSIGN_HELPER(std::int32_t)
36+
PYBIND11_ASSIGN_HELPER(std::uint32_t)
37+
PYBIND11_ASSIGN_HELPER(std::int64_t)
38+
PYBIND11_ASSIGN_HELPER(std::uint64_t)
39+
PYBIND11_ASSIGN_HELPER(float)
40+
PYBIND11_ASSIGN_HELPER(double)
41+
PYBIND11_ASSIGN_HELPER(long double)
42+
PYBIND11_ASSIGN_HELPER(std::complex<float>)
43+
PYBIND11_ASSIGN_HELPER(std::complex<double>)
44+
PYBIND11_ASSIGN_HELPER(std::complex<long double>)
45+
#undef PYBIND11_ASSIGN_HELPER
46+
}
47+
return std::pair<std::string, bool>(
48+
(*format_table)[cpp_name], (buffer.request().*((*equiv_table)[cpp_name]))());
49+
});
50+
1651
// test_from_python / test_to_python:
1752
class Matrix {
1853
public:

tests/test_buffers.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,63 @@
1010

1111
np = pytest.importorskip("numpy")
1212

13+
if m.long_double_and_double_have_same_size:
14+
# Determined by the compiler used to build the pybind11 tests
15+
# (e.g. MSVC gets here, but MinGW might not).
16+
np_float128 = None
17+
np_complex256 = None
18+
else:
19+
# Determined by the compiler used to build numpy (e.g. MinGW).
20+
np_float128 = getattr(np, *["float128"] * 2)
21+
np_complex256 = getattr(np, *["complex256"] * 2)
22+
23+
CPP_NAME_FORMAT_NP_DTYPE_TABLE = [
24+
("PyObject *", "O", object),
25+
("bool", "?", np.bool_),
26+
("std::int8_t", "b", np.int8),
27+
("std::uint8_t", "B", np.uint8),
28+
("std::int16_t", "h", np.int16),
29+
("std::uint16_t", "H", np.uint16),
30+
("std::int32_t", "i", np.int32),
31+
("std::uint32_t", "I", np.uint32),
32+
("std::int64_t", "q", np.int64),
33+
("std::uint64_t", "Q", np.uint64),
34+
("float", "f", np.float32),
35+
("double", "d", np.float64),
36+
("long double", "g", np_float128),
37+
("std::complex<float>", "Zf", np.complex64),
38+
("std::complex<double>", "Zd", np.complex128),
39+
("std::complex<long double>", "Zg", np_complex256),
40+
]
41+
CPP_NAME_FORMAT_TABLE = [
42+
(cpp_name, format)
43+
for cpp_name, format, np_dtype in CPP_NAME_FORMAT_NP_DTYPE_TABLE
44+
if np_dtype is not None
45+
]
46+
CPP_NAME_NP_DTYPE_TABLE = [
47+
(cpp_name, np_dtype) for cpp_name, _, np_dtype in CPP_NAME_FORMAT_NP_DTYPE_TABLE
48+
]
49+
50+
51+
@pytest.mark.parametrize(("cpp_name", "np_dtype"), CPP_NAME_NP_DTYPE_TABLE)
52+
def test_format_descriptor_format_buffer_info_equiv(cpp_name, np_dtype):
53+
if np_dtype is None:
54+
pytest.skip(
55+
f"cpp_name=`{cpp_name}`: `long double` and `double` have same size."
56+
)
57+
if isinstance(np_dtype, str):
58+
pytest.skip(f"np.{np_dtype} does not exist.")
59+
np_array = np.array([], dtype=np_dtype)
60+
for other_cpp_name, expected_format in CPP_NAME_FORMAT_TABLE:
61+
format, np_array_is_matching = m.format_descriptor_format_buffer_info_equiv(
62+
other_cpp_name, np_array
63+
)
64+
assert format == expected_format
65+
if other_cpp_name == cpp_name:
66+
assert np_array_is_matching
67+
else:
68+
assert not np_array_is_matching
69+
1370

1471
def test_from_python():
1572
with pytest.raises(RuntimeError) as excinfo:

tests/test_numpy_array.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,4 +523,30 @@ TEST_SUBMODULE(numpy_array, sm) {
523523
sm.def("test_fmt_desc_const_double", [](const py::array_t<const double> &) {});
524524

525525
sm.def("round_trip_float", [](double d) { return d; });
526+
527+
sm.def("pass_array_pyobject_ptr_return_sum_str_values",
528+
[](const py::array_t<PyObject *> &objs) {
529+
std::string sum_str_values;
530+
for (const auto &obj : objs) {
531+
sum_str_values += py::str(obj.attr("value"));
532+
}
533+
return sum_str_values;
534+
});
535+
536+
sm.def("pass_array_pyobject_ptr_return_as_list",
537+
[](const py::array_t<PyObject *> &objs) -> py::list { return objs; });
538+
539+
sm.def("return_array_pyobject_ptr_cpp_loop", [](const py::list &objs) {
540+
py::size_t arr_size = py::len(objs);
541+
py::array_t<PyObject *> arr_from_list(static_cast<py::ssize_t>(arr_size));
542+
PyObject **data = arr_from_list.mutable_data();
543+
for (py::size_t i = 0; i < arr_size; i++) {
544+
assert(data[i] == nullptr);
545+
data[i] = py::cast<PyObject *>(objs[i].attr("value"));
546+
}
547+
return arr_from_list;
548+
});
549+
550+
sm.def("return_array_pyobject_ptr_from_list",
551+
[](const py::list &objs) -> py::array_t<PyObject *> { return objs; });
526552
}

0 commit comments

Comments
 (0)