Skip to content

array: implement array resize #795

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 29, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ struct npy_api {
NPY_STRING_, NPY_UNICODE_, NPY_VOID_
};

typedef struct {
Py_intptr_t *ptr;
int len;
} PyArray_Dims;

static npy_api& get() {
static npy_api api = lookup();
return api;
Expand Down Expand Up @@ -159,6 +164,7 @@ struct npy_api {
Py_ssize_t *, PyObject **, PyObject *);
PyObject *(*PyArray_Squeeze_)(PyObject *);
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int);
private:
enum functions {
API_PyArray_GetNDArrayCFeatureVersion = 211,
Expand All @@ -168,6 +174,7 @@ struct npy_api {
API_PyArray_DescrFromType = 45,
API_PyArray_DescrFromScalar = 57,
API_PyArray_FromAny = 69,
API_PyArray_Resize = 80,
API_PyArray_NewCopy = 85,
API_PyArray_NewFromDescr = 94,
API_PyArray_DescrNewFromType = 9,
Expand Down Expand Up @@ -197,6 +204,7 @@ struct npy_api {
DECL_NPY_API(PyArray_DescrFromType);
DECL_NPY_API(PyArray_DescrFromScalar);
DECL_NPY_API(PyArray_FromAny);
DECL_NPY_API(PyArray_Resize);
DECL_NPY_API(PyArray_NewCopy);
DECL_NPY_API(PyArray_NewFromDescr);
DECL_NPY_API(PyArray_DescrNewFromType);
Expand Down Expand Up @@ -652,6 +660,21 @@ class array : public buffer {
return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr));
}

/// Resize array to given shape
/// If refcheck is true and more that one reference exist to this array
/// then resize will succeed only if it makes a reshape, i.e. original size doesn't change
void resize(ShapeContainer new_shape, bool refcheck = true) {
detail::npy_api::PyArray_Dims d = {
new_shape->data(), int(new_shape->size())
};
// try to resize, set ordering param to -1 cause it's not used anyway
object new_array = reinterpret_steal<object>(
detail::npy_api::get().PyArray_Resize_(m_ptr, &d, int(refcheck), -1)
);
if (!new_array) throw error_already_set();
if (isinstance<array>(new_array)) { *this = std::move(new_array); }
}

/// Ensure that the argument is a NumPy array
/// In case of an error, nullptr is returned and the Python error is cleared.
static array ensure(handle h, int ExtraFlags = 0) {
Expand Down
23 changes: 22 additions & 1 deletion tests/test_numpy_array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,4 +273,25 @@ test_initializer numpy_array([](py::module &m) {
sm.def("array_initializer_list", []() { return py::array_t<float>({ 1, 2 }); });
sm.def("array_initializer_list", []() { return py::array_t<float>({ 1, 2, 3 }); });
sm.def("array_initializer_list", []() { return py::array_t<float>({ 1, 2, 3, 4 }); });
});

// reshape array to 2D without changing size
sm.def("array_reshape2", [](py::array_t<double> a) {
const size_t dim_sz = (size_t)std::sqrt(a.size());
if (dim_sz * dim_sz != a.size())
throw std::domain_error("array_reshape2: input array total size is not a squared integer");
a.resize({dim_sz, dim_sz});
});

// resize to 3D array with each dimension = N
sm.def("array_resize3", [](py::array_t<double> a, size_t N, bool refcheck) {
a.resize({N, N, N}, refcheck);
});

// return 2D array with Nrows = Ncols = N
sm.def("create_and_resize", [](size_t N) {
py::array_t<double> a;
a.resize({N, N});
std::fill(a.mutable_data(), a.mutable_data() + a.size(), 42.);
return a;
});
});
35 changes: 35 additions & 0 deletions tests/test_numpy_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,38 @@ def test_array_failure():
with pytest.raises(ValueError) as excinfo:
array_t_fail_test()
assert str(excinfo.value) == 'cannot create a pybind11::array_t from a nullptr'


def test_array_resize(msg):
from pybind11_tests.array import (array_reshape2, array_resize3)

a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype='float64')
array_reshape2(a)
assert(a.size == 9)
assert(np.all(a == [[1, 2, 3], [4, 5, 6], [7, 8, 9]]))

# total size change should succced with refcheck off
array_resize3(a, 4, False)
assert(a.size == 64)
# ... and fail with refcheck on
try:
array_resize3(a, 3, True)
except ValueError as e:
assert(str(e).startswith("cannot resize an array"))
# transposed array doesn't own data
b = a.transpose()
try:
array_resize3(b, 3, False)
except ValueError as e:
assert(str(e).startswith("cannot resize this array: it does not own its data"))
# ... but reshape should be fine
array_reshape2(b)
assert(b.shape == (8, 8))


@pytest.unsupported_on_pypy
def test_array_create_and_resize(msg):
from pybind11_tests.array import create_and_resize
a = create_and_resize(2)
assert(a.size == 4)
assert(np.all(a == 42.))