Skip to content

Commit 5f38386

Browse files
committed
Accept abitrary containers and iterators for shape/strides
This adds support for constructing `buffer_info` and `array`s using arbitrary containers or iterator pairs instead of requiring a vector. This is primarily needed by PR #782 (which makes strides signed to properly support negative strides, and will likely also make shape and itemsize to avoid mixed integer issues), but also needs to preserve backwards compatibility with 2.1 and earlier which accepts the strides parameter as a vector of size_t's. Rather than adding nearly duplicate constructors for each stride-taking constructor, it seems nicer to simply allow any type of container (or iterator pairs). This works by replacing the existing vector arguments with a new `detail::any_container` class that handles implicit conversion of arbitrary containers into a vector of the desired type. It can also be explicitly instantiated with a pair of iterators (e.g. by passing {begin, end} instead of the container).
1 parent dbb4c5b commit 5f38386

File tree

5 files changed

+98
-60
lines changed

5 files changed

+98
-60
lines changed

include/pybind11/buffer_info.h

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
BSD-style license that can be found in the LICENSE file.
88
*/
99

10-
#pragma once
10+
#pragma once
1111

1212
#include "common.h"
1313

@@ -26,25 +26,22 @@ struct buffer_info {
2626
buffer_info() { }
2727

2828
buffer_info(void *ptr, size_t itemsize, const std::string &format, size_t ndim,
29-
const std::vector<size_t> &shape, const std::vector<size_t> &strides)
30-
: ptr(ptr), itemsize(itemsize), size(1), format(format),
31-
ndim(ndim), shape(shape), strides(strides) {
29+
detail::any_container<size_t> shape_in, detail::any_container<size_t> strides_in)
30+
: ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim),
31+
shape(std::move(shape_in)), strides(std::move(strides_in)) {
32+
if (ndim != shape.size() || ndim != strides.size())
33+
pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length");
3234
for (size_t i = 0; i < ndim; ++i)
3335
size *= shape[i];
3436
}
3537

3638
buffer_info(void *ptr, size_t itemsize, const std::string &format, size_t size)
37-
: buffer_info(ptr, itemsize, format, 1, std::vector<size_t> { size },
38-
std::vector<size_t> { itemsize }) { }
39-
40-
explicit buffer_info(Py_buffer *view, bool ownview = true)
41-
: ptr(view->buf), itemsize((size_t) view->itemsize), size(1), format(view->format),
42-
ndim((size_t) view->ndim), shape((size_t) view->ndim), strides((size_t) view->ndim), view(view), ownview(ownview) {
43-
for (size_t i = 0; i < (size_t) view->ndim; ++i) {
44-
shape[i] = (size_t) view->shape[i];
45-
strides[i] = (size_t) view->strides[i];
46-
size *= shape[i];
47-
}
39+
: buffer_info(ptr, itemsize, format, 1, { size }, { itemsize }) { }
40+
41+
explicit buffer_info(Py_buffer *view, bool ownview_in = true)
42+
: buffer_info(view->buf, (size_t) view->itemsize, view->format, (size_t) view->ndim,
43+
{view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) {
44+
ownview = ownview_in;
4845
}
4946

5047
buffer_info(const buffer_info &) = delete;

include/pybind11/common.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,12 @@ struct is_instantiation<Class, Class<Us...>> : std::true_type { };
490490
/// Check if T is std::shared_ptr<U> where U can be anything
491491
template <typename T> using is_shared_ptr = is_instantiation<std::shared_ptr, T>;
492492

493+
/// Check if T looks like an input iterator
494+
template <typename T, typename = void> struct is_input_iterator : std::false_type {};
495+
template <typename T>
496+
struct is_input_iterator<T, void_t<decltype(*std::declval<T &>()), decltype(++std::declval<T &>())>>
497+
: std::true_type {};
498+
493499
/// Ignore that a variable is unused in compiler warnings
494500
inline void ignore_unused(const int *) { }
495501

@@ -651,4 +657,46 @@ static constexpr auto const_ = std::true_type{};
651657

652658
#endif // overload_cast
653659

660+
NAMESPACE_BEGIN(detail)
661+
662+
// Adaptor for converting arbitrary container arguments into a vector; implicitly convertible from
663+
// any standard container (or C-style array) supporting std::begin/std::end.
664+
template <typename T>
665+
class any_container {
666+
std::vector<T> v;
667+
public:
668+
any_container() = default;
669+
670+
// Can construct from a pair of iterators
671+
template <typename It, typename = enable_if_t<is_input_iterator<It>::value>>
672+
any_container(It first, It last) : v(first, last) { }
673+
674+
// Implicit conversion constructor from any arbitrary container type with values convertible to T
675+
template <typename Container, typename = enable_if_t<std::is_convertible<decltype(*std::begin(std::declval<const Container &>())), T>::value>>
676+
any_container(const Container &c) : any_container(std::begin(c), std::end(c)) { }
677+
678+
// initializer_list's aren't deducible, so don't get matched by the above template; we need this
679+
// to explicitly allow implicit conversion from one:
680+
template <typename TIn, typename = enable_if_t<std::is_convertible<TIn, T>::value>>
681+
any_container(const std::initializer_list<TIn> &c) : any_container(c.begin(), c.end()) { }
682+
683+
// Avoid copying if given an rvalue vector of the correct type.
684+
any_container(std::vector<T> &&v) : v(std::move(v)) { }
685+
686+
// Moves the vector out of an rvalue any_container
687+
operator std::vector<T> &&() && { return std::move(v); }
688+
689+
// Dereferencing obtains a reference to the underlying vector
690+
std::vector<T> &operator*() { return v; }
691+
const std::vector<T> &operator*() const { return v; }
692+
693+
// -> lets you call methods on the underlying vector
694+
std::vector<T> *operator->() { return &v; }
695+
const std::vector<T> *operator->() const { return &v; }
696+
};
697+
698+
NAMESPACE_END(detail)
699+
700+
701+
654702
NAMESPACE_END(pybind11)

include/pybind11/eigen.h

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -201,18 +201,13 @@ template <typename Type_> struct EigenProps {
201201
// otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array.
202202
template <typename props> handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) {
203203
constexpr size_t elem_size = sizeof(typename props::Scalar);
204-
std::vector<size_t> shape, strides;
205-
if (props::vector) {
206-
shape.push_back(src.size());
207-
strides.push_back(elem_size * src.innerStride());
208-
}
209-
else {
210-
shape.push_back(src.rows());
211-
shape.push_back(src.cols());
212-
strides.push_back(elem_size * src.rowStride());
213-
strides.push_back(elem_size * src.colStride());
214-
}
215-
array a(std::move(shape), std::move(strides), src.data(), base);
204+
array a;
205+
if (props::vector)
206+
a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base);
207+
else
208+
a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() },
209+
src.data(), base);
210+
216211
if (!writeable)
217212
array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
218213

include/pybind11/numpy.h

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -455,12 +455,18 @@ class array : public buffer {
455455

456456
array() : array(0, static_cast<const double *>(nullptr)) {}
457457

458-
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
459-
const std::vector<size_t> &strides, const void *ptr = nullptr,
460-
handle base = handle()) {
461-
auto& api = detail::npy_api::get();
462-
auto ndim = shape.size();
463-
if (shape.size() != strides.size())
458+
using ShapeContainer = detail::any_container<Py_intptr_t>;
459+
using StridesContainer = detail::any_container<Py_intptr_t>;
460+
461+
// Constructs an array taking shape/strides from arbitrary container types
462+
array(const pybind11::dtype &dt, ShapeContainer shape, StridesContainer strides,
463+
const void *ptr = nullptr, handle base = handle()) {
464+
465+
if (strides->empty())
466+
strides = default_strides(*shape, dt.itemsize());
467+
468+
auto ndim = shape->size();
469+
if (ndim != strides->size())
464470
pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
465471
auto descr = dt;
466472

@@ -474,10 +480,9 @@ class array : public buffer {
474480
flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
475481
}
476482

483+
auto &api = detail::npy_api::get();
477484
auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
478-
api.PyArray_Type_, descr.release().ptr(), (int) ndim,
479-
reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(shape.data())),
480-
reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(strides.data())),
485+
api.PyArray_Type_, descr.release().ptr(), (int) ndim, shape->data(), strides->data(),
481486
const_cast<void *>(ptr), flags, nullptr));
482487
if (!tmp)
483488
pybind11_fail("NumPy: unable to create array!");
@@ -491,27 +496,24 @@ class array : public buffer {
491496
m_ptr = tmp.release().ptr();
492497
}
493498

494-
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
495-
const void *ptr = nullptr, handle base = handle())
496-
: array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
499+
array(const pybind11::dtype &dt, ShapeContainer shape, const void *ptr = nullptr, handle base = handle())
500+
: array(dt, std::move(shape), {}, ptr, base) { }
497501

498502
array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
499503
handle base = handle())
500-
: array(dt, std::vector<size_t>{ count }, ptr, base) { }
504+
: array(dt, ShapeContainer{{ count }}, ptr, base) { }
501505

502-
template<typename T> array(const std::vector<size_t>& shape,
503-
const std::vector<size_t>& strides,
504-
const T* ptr, handle base = handle())
505-
: array(pybind11::dtype::of<T>(), shape, strides, (const void *) ptr, base) { }
506+
template <typename T>
507+
array(ShapeContainer shape, StridesContainer strides, const T *ptr, handle base = handle())
508+
: array(pybind11::dtype::of<T>(), std::move(shape), std::move(strides), ptr, base) { }
506509

507510
template <typename T>
508-
array(const std::vector<size_t> &shape, const T *ptr,
509-
handle base = handle())
510-
: array(shape, default_strides(shape, sizeof(T)), ptr, base) { }
511+
array(ShapeContainer shape, const T *ptr, handle base = handle())
512+
: array(std::move(shape), {}, ptr, base) { }
511513

512514
template <typename T>
513515
array(size_t count, const T *ptr, handle base = handle())
514-
: array(std::vector<size_t>{ count }, ptr, base) { }
516+
: array({{ count }}, ptr, base) { }
515517

516518
explicit array(const buffer_info &info)
517519
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
@@ -673,9 +675,9 @@ class array : public buffer {
673675
throw std::domain_error("array is not writeable");
674676
}
675677

676-
static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize) {
678+
static std::vector<Py_intptr_t> default_strides(const std::vector<Py_intptr_t>& shape, size_t itemsize) {
677679
auto ndim = shape.size();
678-
std::vector<size_t> strides(ndim);
680+
std::vector<Py_intptr_t> strides(ndim);
679681
if (ndim) {
680682
std::fill(strides.begin(), strides.end(), itemsize);
681683
for (size_t i = 0; i < ndim - 1; i++)
@@ -731,14 +733,11 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
731733

732734
explicit array_t(const buffer_info& info) : array(info) { }
733735

734-
array_t(const std::vector<size_t> &shape,
735-
const std::vector<size_t> &strides, const T *ptr = nullptr,
736-
handle base = handle())
737-
: array(shape, strides, ptr, base) { }
736+
array_t(ShapeContainer shape, StridesContainer strides, const T *ptr = nullptr, handle base = handle())
737+
: array(std::move(shape), std::move(strides), ptr, base) { }
738738

739-
explicit array_t(const std::vector<size_t> &shape, const T *ptr = nullptr,
740-
handle base = handle())
741-
: array(shape, ptr, base) { }
739+
explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle())
740+
: array(std::move(shape), ptr, base) { }
742741

743742
explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle())
744743
: array(count, ptr, base) { }

tests/test_numpy_array.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#include <pybind11/stl.h>
1414

1515
#include <cstdint>
16-
#include <vector>
1716

1817
using arr = py::array;
1918
using arr_t = py::array_t<uint16_t, 0>;
@@ -119,8 +118,8 @@ test_initializer numpy_array([](py::module &m) {
119118
sm.def("wrap", [](py::array a) {
120119
return py::array(
121120
a.dtype(),
122-
std::vector<size_t>(a.shape(), a.shape() + a.ndim()),
123-
std::vector<size_t>(a.strides(), a.strides() + a.ndim()),
121+
{a.shape(), a.shape() + a.ndim()},
122+
{a.strides(), a.strides() + a.ndim()},
124123
a.data(),
125124
a
126125
);

0 commit comments

Comments
 (0)