Skip to content

Rework subview methods and other related methods #537

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Nov 19, 2018
4 changes: 2 additions & 2 deletions examples/axis_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ fn main() {
}
a.swap_axes(0, 1);
a.swap_axes(0, 2);
a.slice_inplace(s![.., ..;-1, ..]);
a.slice_collapse(s![.., ..;-1, ..]);
regularize(&mut a).ok();

let mut b = Array::<u8, _>::zeros((2, 3, 4));
Expand All @@ -64,6 +64,6 @@ fn main() {
for (i, elt) in (0..).zip(&mut a) {
*elt = i;
}
a.slice_inplace(s![..;-1, ..;2, ..]);
a.slice_collapse(s![..;-1, ..;2, ..]);
regularize(&mut a).ok();
}
4 changes: 2 additions & 2 deletions examples/sort-axis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ impl<A, D> PermuteArray for Array<A, D>
result = Array::from_shape_vec_unchecked(self.dim(), v);
for i in 0..axis_len {
let perm_i = perm.indices[i];
Zip::from(result.subview_mut(axis, perm_i))
.and(self.subview(axis, i))
Zip::from(result.index_axis_mut(axis, perm_i))
.and(self.index_axis(axis, i))
.apply(|to, from| {
copy_nonoverlapping(from, to, 1)
});
Expand Down
8 changes: 4 additions & 4 deletions serialization-tests/tests/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ fn serial_many_dim()
{
// Test a sliced array.
let mut a = RcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4));
a.slice_inplace(s![..;-1, .., .., ..2]);
a.slice_collapse(s![..;-1, .., .., ..2]);
let serial = json::encode(&a).unwrap();
println!("Encode {:?} => {:?}", a, serial);
let res = json::decode::<RcArray<f32, _>>(&serial);
Expand Down Expand Up @@ -114,7 +114,7 @@ fn serial_many_dim_serde()
{
// Test a sliced array.
let mut a = RcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4));
a.slice_inplace(s![..;-1, .., .., ..2]);
a.slice_collapse(s![..;-1, .., .., ..2]);
let serial = serde_json::to_string(&a).unwrap();
println!("Encode {:?} => {:?}", a, serial);
let res = serde_json::from_str::<RcArray<f32, _>>(&serial);
Expand Down Expand Up @@ -221,7 +221,7 @@ fn serial_many_dim_serde_msgpack()
{
// Test a sliced array.
let mut a = RcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4));
a.slice_inplace(s![..;-1, .., .., ..2]);
a.slice_collapse(s![..;-1, .., .., ..2]);

let mut buf = Vec::new();
serde::Serialize::serialize(&a, &mut rmp_serde::Serializer::new(&mut buf)).ok().unwrap();
Expand Down Expand Up @@ -273,7 +273,7 @@ fn serial_many_dim_ron()
{
// Test a sliced array.
let mut a = RcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4));
a.slice_inplace(s![..;-1, .., .., ..2]);
a.slice_collapse(s![..;-1, .., .., ..2]);

let a_s = ron_serialize(&a).unwrap();

Expand Down
13 changes: 9 additions & 4 deletions src/dimension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,18 @@ impl<'a> DimensionExt for [Ix]
///
/// **Panics** if `index` is larger than the size of the axis
// FIXME: Move to Dimension trait
pub fn do_sub<A, D: Dimension>(dims: &mut D, ptr: &mut *mut A, strides: &D,
axis: usize, index: Ix) {
pub fn do_collapse_axis<A, D: Dimension>(
dims: &mut D,
ptr: &mut *mut A,
strides: &D,
axis: usize,
index: usize,
) {
let dim = dims.slice()[axis];
let stride = strides.slice()[axis];
ndassert!(index < dim,
concat!("subview: Index {} must be less than axis length {} ",
"for array with shape {:?}"),
"collapse_axis: Index {} must be less than axis length {} for \
array with shape {:?}",
index, dim, *dims);
dims.slice_mut()[axis] = 1;
let off = stride_offset(index, stride);
Expand Down
8 changes: 4 additions & 4 deletions src/doc/ndarray_for_numpy_users/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@
//! Only the non-mutable methods that take the array by reference are listed in
//! this table. For example, [`.slice()`][.slice()] also has corresponding
//! methods [`.slice_mut()`][.slice_mut()], [`.slice_move()`][.slice_move()], and
//! [`.slice_inplace()`][.slice_inplace()].
//! [`.slice_collapse()`][.slice_collapse()].
//!
//! * The behavior of slicing is slightly different from NumPy for slices with
//! `step < -1`. See the docs for the [`s![]` macro][s!] for more details.
Expand All @@ -238,7 +238,7 @@
//! ------|-----------|------
//! `a[-1]` | [`a[a.len() - 1]`][.index()] | access the last element in 1-D array `a`
//! `a[1, 4]` | [`a[[1, 4]]`][.index()] | access the element in row 1, column 4
//! `a[1]` or `a[1, :, :]` | [`a.slice(s![1, .., ..])`][.slice()] or [`a.subview(Axis(0), 1)`][.subview()] | get a 2-D subview of a 3-D array at index 1 of axis 0
//! `a[1]` or `a[1, :, :]` | [`a.slice(s![1, .., ..])`][.slice()] or [`a.index_axis(Axis(0), 1)`][.index_axis()] | get a 2-D subview of a 3-D array at index 1 of axis 0
//! `a[0:5]` or `a[:5]` or `a[0:5, :]` | [`a.slice(s![0..5, ..])`][.slice()] or [`a.slice(s![..5, ..])`][.slice()] or [`a.slice_axis(Axis(0), Slice::from(0..5))`][.slice_axis()] | get the first 5 rows of a 2-D array
//! `a[-5:]` or `a[-5:, :]` | [`a.slice(s![-5.., ..])`][.slice()] or [`a.slice_axis(Axis(0), Slice::from(-5..))`][.slice_axis()] | get the last 5 rows of a 2-D array
//! `a[:3, 4:9]` | [`a.slice(s![..3, 4..9])`][.slice()] | columns 4, 5, 6, 7, and 8 of the first 3 rows
Expand Down Expand Up @@ -618,14 +618,14 @@
//! [.sum()]: ../../struct.ArrayBase.html#method.sum
//! [.slice()]: ../../struct.ArrayBase.html#method.slice
//! [.slice_axis()]: ../../struct.ArrayBase.html#method.slice_axis
//! [.slice_inplace()]: ../../struct.ArrayBase.html#method.slice_inplace
//! [.slice_collapse()]: ../../struct.ArrayBase.html#method.slice_collapse
//! [.slice_move()]: ../../struct.ArrayBase.html#method.slice_move
//! [.slice_mut()]: ../../struct.ArrayBase.html#method.slice_mut
//! [.shape()]: ../../struct.ArrayBase.html#method.shape
//! [stack!]: ../../macro.stack.html
//! [stack()]: ../../fn.stack.html
//! [.strides()]: ../../struct.ArrayBase.html#method.strides
//! [.subview()]: ../../struct.ArrayBase.html#method.subview
//! [.index_axis()]: ../../struct.ArrayBase.html#method.index_axis
//! [.sum_axis()]: ../../struct.ArrayBase.html#method.sum_axis
//! [.t()]: ../../struct.ArrayBase.html#method.t
//! [::uninitialized()]: ../../struct.ArrayBase.html#method.uninitialized
Expand Down
8 changes: 4 additions & 4 deletions src/impl_2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ impl<A, S> ArrayBase<S, Ix2>
/// **Panics** if `index` is out of bounds.
pub fn row(&self, index: Ix) -> ArrayView1<A>
{
self.subview(Axis(0), index)
self.index_axis(Axis(0), index)
}

/// Return a mutable array view of row `index`.
Expand All @@ -28,7 +28,7 @@ impl<A, S> ArrayBase<S, Ix2>
pub fn row_mut(&mut self, index: Ix) -> ArrayViewMut1<A>
where S: DataMut
{
self.subview_mut(Axis(0), index)
self.index_axis_mut(Axis(0), index)
}

/// Return the number of rows (length of `Axis(0)`) in the two-dimensional array.
Expand All @@ -41,7 +41,7 @@ impl<A, S> ArrayBase<S, Ix2>
/// **Panics** if `index` is out of bounds.
pub fn column(&self, index: Ix) -> ArrayView1<A>
{
self.subview(Axis(1), index)
self.index_axis(Axis(1), index)
}

/// Return a mutable array view of column `index`.
Expand All @@ -50,7 +50,7 @@ impl<A, S> ArrayBase<S, Ix2>
pub fn column_mut(&mut self, index: Ix) -> ArrayViewMut1<A>
where S: DataMut
{
self.subview_mut(Axis(1), index)
self.index_axis_mut(Axis(1), index)
}

/// Return the number of columns (length of `Axis(1)`) in the two-dimensional array.
Expand Down
58 changes: 58 additions & 0 deletions src/impl_dyn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright 2018 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Methods for dynamic-dimensional arrays.
use imp_prelude::*;

/// # Methods for Dynamic-Dimensional Arrays
impl<A, S> ArrayBase<S, IxDyn>
where
S: Data<Elem = A>,
{
/// Insert new array axis of length 1 at `axis`, modifying the shape and
/// strides in-place.
///
/// **Panics** if the axis is out of bounds.
///
/// ```
/// use ndarray::{Axis, arr2, arr3};
///
/// let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]).into_dyn();
/// assert_eq!(a.shape(), &[2, 3]);
///
/// a.insert_axis_inplace(Axis(1));
/// assert_eq!(a, arr3(&[[[1, 2, 3]], [[4, 5, 6]]]).into_dyn());
/// assert_eq!(a.shape(), &[2, 1, 3]);
/// ```
pub fn insert_axis_inplace(&mut self, axis: Axis) {
assert!(axis.index() <= self.ndim());
self.dim = self.dim.insert_axis(axis);
self.strides = self.strides.insert_axis(axis);
}

/// Collapses the array to `index` along the axis and removes the axis,
/// modifying the shape and strides in-place.
///
/// **Panics** if `axis` or `index` is out of bounds.
///
/// ```
/// use ndarray::{Axis, arr1, arr2};
///
/// let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]).into_dyn();
/// assert_eq!(a.shape(), &[2, 3]);
///
/// a.index_axis_inplace(Axis(1), 1);
/// assert_eq!(a, arr1(&[2, 5]).into_dyn());
/// assert_eq!(a.shape(), &[2]);
/// ```
pub fn index_axis_inplace(&mut self, axis: Axis, index: usize) {
self.collapse_axis(axis, index);
self.dim = self.dim.remove_axis(axis);
self.strides = self.strides.remove_axis(axis);
}
}
Loading