@@ -525,11 +525,19 @@ class array : public buffer {
525
525
526
526
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
527
527
public:
528
- array_t () : array() { }
528
+ array_t () = default ;
529
+ array_t (handle h, borrowed_t ) : array(h, borrowed) { }
530
+ array_t (handle h, stolen_t ) : array(h, stolen) { }
529
531
530
- array_t (handle h, bool is_borrowed) : array(h, is_borrowed) { m_ptr = ensure_ (m_ptr); }
532
+ PYBIND11_DEPRECATED (" Use array_t<T>::ensure() instead" )
533
+ array_t (handle h, bool is_borrowed) : array(raw_array_from_any(h.ptr()), stolen) {
534
+ if (!m_ptr) PyErr_Clear ();
535
+ if (!is_borrowed) Py_XDECREF (h.ptr ());
536
+ }
531
537
532
- array_t (const object &o) : array(o) { m_ptr = ensure_ (m_ptr); }
538
+ array_t (const object &o) : array(raw_array_from_any(o.ptr()), stolen) {
539
+ if (!m_ptr) throw error_already_set ();
540
+ }
533
541
534
542
explicit array_t (const buffer_info& info) : array(info) { }
535
543
@@ -577,15 +585,24 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
577
585
return *(static_cast <T*>(array::mutable_data ()) + get_byte_offset (index ...) / itemsize ());
578
586
}
579
587
580
- static PyObject *ensure_ (PyObject *ptr) {
588
+ // / Ensure that the argument is a NumPy array of the correct dtype.
589
+ // / In case of an error, nullptr is returned and the Python error is cleared.
590
+ static array_t ensure (handle h) {
591
+ auto result = reinterpret_steal<array_t >(raw_array_from_any (h.ptr ()));
592
+ if (!result)
593
+ PyErr_Clear ();
594
+ return result;
595
+ }
596
+
597
+ static bool _check (handle) = delete; // Make sure py::instance<array_t<T>>() can't compile.
598
+
599
+ private:
600
+ static PyObject *raw_array_from_any (PyObject *ptr) {
581
601
if (ptr == nullptr )
582
602
return nullptr ;
583
603
auto & api = detail::npy_api::get ();
584
604
PyObject *result = api.PyArray_FromAny_ (ptr, pybind11::dtype::of<T>().release ().ptr (), 0 , 0 ,
585
605
detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr );
586
- if (!result)
587
- PyErr_Clear ();
588
- Py_DECREF (ptr);
589
606
return result;
590
607
}
591
608
};
@@ -618,7 +635,7 @@ struct pyobject_caster<array_t<T, ExtraFlags>> {
618
635
using type = array_t <T, ExtraFlags>;
619
636
620
637
bool load (handle src, bool /* convert */ ) {
621
- value = type (src, true );
638
+ value = type::ensure (src);
622
639
return static_cast <bool >(value);
623
640
}
624
641
0 commit comments