Skip to content

Commit 4a2f2ee

Browse files
committed
Make array_t’s converting constructor consistent with other pytypes
* `array_t(const object &)` throws on error * `array_t::ensure()` is intended for casters * `py::isinstance<array_T<T>>()` is intentionally disabled
1 parent e41484a commit 4a2f2ee

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

include/pybind11/eigen.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ struct type_caster<Type, enable_if_t<is_eigen_dense<Type>::value && !is_eigen_re
5454
static constexpr bool isVector = Type::IsVectorAtCompileTime;
5555

5656
bool load(handle src, bool) {
57-
array_t<Scalar> buf(src, true);
57+
auto buf = array_t<Scalar>::ensure(src);
5858
if (!buf)
5959
return false;
6060

include/pybind11/numpy.h

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -525,11 +525,19 @@ class array : public buffer {
525525

526526
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
527527
public:
528-
array_t() : array() { }
528+
array_t() = default;
529+
array_t(handle h, borrowed_t) : array(h, borrowed) { }
530+
array_t(handle h, stolen_t) : array(h, stolen) { }
529531

530-
array_t(handle h, bool is_borrowed) : array(h, is_borrowed) { m_ptr = ensure_(m_ptr); }
532+
PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
533+
array_t(handle h, bool is_borrowed) : array(raw_array_from_any(h.ptr()), stolen) {
534+
if (!m_ptr) PyErr_Clear();
535+
if (!is_borrowed) Py_XDECREF(h.ptr());
536+
}
531537

532-
array_t(const object &o) : array(o) { m_ptr = ensure_(m_ptr); }
538+
array_t(const object &o) : array(raw_array_from_any(o.ptr()), stolen) {
539+
if (!m_ptr) throw error_already_set();
540+
}
533541

534542
explicit array_t(const buffer_info& info) : array(info) { }
535543

@@ -577,15 +585,24 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
577585
return *(static_cast<T*>(array::mutable_data()) + get_byte_offset(index...) / itemsize());
578586
}
579587

580-
static PyObject *ensure_(PyObject *ptr) {
588+
/// Ensure that the argument is a NumPy array of the correct dtype.
589+
/// In case of an error, nullptr is returned and the Python error is cleared.
590+
static array_t ensure(handle h) {
591+
auto result = reinterpret_steal<array_t>(raw_array_from_any(h.ptr()));
592+
if (!result)
593+
PyErr_Clear();
594+
return result;
595+
}
596+
597+
static bool _check(handle) = delete; // Make sure py::instance<array_t<T>>() can't compile.
598+
599+
private:
600+
static PyObject *raw_array_from_any(PyObject *ptr) {
581601
if (ptr == nullptr)
582602
return nullptr;
583603
auto& api = detail::npy_api::get();
584604
PyObject *result = api.PyArray_FromAny_(ptr, pybind11::dtype::of<T>().release().ptr(), 0, 0,
585605
detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
586-
if (!result)
587-
PyErr_Clear();
588-
Py_DECREF(ptr);
589606
return result;
590607
}
591608
};
@@ -618,7 +635,7 @@ struct pyobject_caster<array_t<T, ExtraFlags>> {
618635
using type = array_t<T, ExtraFlags>;
619636

620637
bool load(handle src, bool /* convert */) {
621-
value = type(src, true);
638+
value = type::ensure(src);
622639
return static_cast<bool>(value);
623640
}
624641

0 commit comments

Comments
 (0)