@@ -455,12 +455,18 @@ class array : public buffer {
455
455
456
456
array () : array(0 , static_cast <const double *>(nullptr )) {}
457
457
458
- array (const pybind11::dtype &dt, const std::vector<size_t > &shape,
459
- const std::vector<size_t > &strides, const void *ptr = nullptr ,
460
- handle base = handle()) {
461
- auto & api = detail::npy_api::get ();
462
- auto ndim = shape.size ();
463
- if (shape.size () != strides.size ())
458
+ using ShapeContainer = detail::any_container<Py_intptr_t>;
459
+ using StridesContainer = detail::any_container<Py_intptr_t>;
460
+
461
+ // Constructs an array taking shape/strides from arbitrary container types
462
+ array (const pybind11::dtype &dt, ShapeContainer shape, StridesContainer strides,
463
+ const void *ptr = nullptr , handle base = handle()) {
464
+
465
+ if (strides->empty ())
466
+ strides = default_strides (*shape, dt.itemsize ());
467
+
468
+ auto ndim = shape->size ();
469
+ if (ndim != strides->size ())
464
470
pybind11_fail (" NumPy: shape ndim doesn't match strides ndim" );
465
471
auto descr = dt;
466
472
@@ -474,10 +480,9 @@ class array : public buffer {
474
480
flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
475
481
}
476
482
483
+ auto &api = detail::npy_api::get ();
477
484
auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_ (
478
- api.PyArray_Type_ , descr.release ().ptr (), (int ) ndim,
479
- reinterpret_cast <Py_intptr_t *>(const_cast <size_t *>(shape.data ())),
480
- reinterpret_cast <Py_intptr_t *>(const_cast <size_t *>(strides.data ())),
485
+ api.PyArray_Type_ , descr.release ().ptr (), (int ) ndim, shape->data (), strides->data (),
481
486
const_cast <void *>(ptr), flags, nullptr ));
482
487
if (!tmp)
483
488
pybind11_fail (" NumPy: unable to create array!" );
@@ -491,27 +496,24 @@ class array : public buffer {
491
496
m_ptr = tmp.release ().ptr ();
492
497
}
493
498
494
- array (const pybind11::dtype &dt, const std::vector<size_t > &shape,
495
- const void *ptr = nullptr , handle base = handle())
496
- : array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
499
+ array (const pybind11::dtype &dt, ShapeContainer shape, const void *ptr = nullptr , handle base = handle())
500
+ : array(dt, std::move(shape), {}, ptr, base) { }
497
501
498
502
array (const pybind11::dtype &dt, size_t count, const void *ptr = nullptr ,
499
503
handle base = handle())
500
- : array(dt, std::vector< size_t >{ count }, ptr, base) { }
504
+ : array(dt, ShapeContainer{{ count } }, ptr, base) { }
501
505
502
- template <typename T> array (const std::vector<size_t >& shape,
503
- const std::vector<size_t >& strides,
504
- const T* ptr, handle base = handle())
505
- : array(pybind11::dtype::of<T>(), shape, strides, (const void *) ptr, base) { }
506
+ template <typename T>
507
+ array (ShapeContainer shape, StridesContainer strides, const T *ptr, handle base = handle())
508
+ : array(pybind11::dtype::of<T>(), std::move(shape), std::move(strides), ptr, base) { }
506
509
507
510
template <typename T>
508
- array (const std::vector<size_t > &shape, const T *ptr,
509
- handle base = handle())
510
- : array(shape, default_strides(shape, sizeof (T)), ptr, base) { }
511
+ array (ShapeContainer shape, const T *ptr, handle base = handle())
512
+ : array(std::move(shape), {}, ptr, base) { }
511
513
512
514
template <typename T>
513
515
array (size_t count, const T *ptr, handle base = handle())
514
- : array(std::vector< size_t >{ count }, ptr, base) { }
516
+ : array({{ count } }, ptr, base) { }
515
517
516
518
explicit array (const buffer_info &info)
517
519
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
@@ -673,9 +675,9 @@ class array : public buffer {
673
675
throw std::domain_error (" array is not writeable" );
674
676
}
675
677
676
- static std::vector<size_t > default_strides (const std::vector<size_t >& shape, size_t itemsize) {
678
+ static std::vector<Py_intptr_t > default_strides (const std::vector<Py_intptr_t >& shape, size_t itemsize) {
677
679
auto ndim = shape.size ();
678
- std::vector<size_t > strides (ndim);
680
+ std::vector<Py_intptr_t > strides (ndim);
679
681
if (ndim) {
680
682
std::fill (strides.begin (), strides.end (), itemsize);
681
683
for (size_t i = 0 ; i < ndim - 1 ; i++)
@@ -731,14 +733,11 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
731
733
732
734
explicit array_t (const buffer_info& info) : array(info) { }
733
735
734
- array_t (const std::vector<size_t > &shape,
735
- const std::vector<size_t > &strides, const T *ptr = nullptr ,
736
- handle base = handle())
737
- : array(shape, strides, ptr, base) { }
736
+ array_t (ShapeContainer shape, StridesContainer strides, const T *ptr = nullptr , handle base = handle())
737
+ : array(std::move(shape), std::move(strides), ptr, base) { }
738
738
739
- explicit array_t (const std::vector<size_t > &shape, const T *ptr = nullptr ,
740
- handle base = handle())
741
- : array(shape, ptr, base) { }
739
+ explicit array_t (ShapeContainer shape, const T *ptr = nullptr , handle base = handle())
740
+ : array(std::move(shape), ptr, base) { }
742
741
743
742
explicit array_t (size_t count, const T *ptr = nullptr , handle base = handle())
744
743
: array(count, ptr, base) { }
0 commit comments