diff --git a/benches/bench1.rs b/benches/bench1.rs index 38d93c22e..7ccab0395 100644 --- a/benches/bench1.rs +++ b/benches/bench1.rs @@ -12,6 +12,7 @@ use rblas::matrix::Matrix; use ndarray::{ OwnedArray, + Axis, }; use ndarray::{arr0, arr1, arr2}; @@ -562,5 +563,5 @@ fn dot_f32_1024(bench: &mut test::Bencher) fn means(bench: &mut test::Bencher) { let a = OwnedArray::from_iter(0..100_000i64); let a = a.into_shape((100, 1000)).unwrap(); - bench.iter(|| a.mean(0)); + bench.iter(|| a.mean(Axis(0))); } diff --git a/examples/axis.rs b/examples/axis.rs new file mode 100644 index 000000000..0812f95e9 --- /dev/null +++ b/examples/axis.rs @@ -0,0 +1,15 @@ +extern crate ndarray; + +use ndarray::{ + OwnedArray, + Axis, +}; + +fn main() { + let a = OwnedArray::::linspace(0., 24., 25).into_shape((5, 5)).unwrap(); + println!("{:?}", a.subview(Axis(0), 0)); + println!("{:?}", a.subview(Axis(0), 1)); + println!("{:?}", a.subview(Axis(1), 1)); + println!("{:?}", a.subview(Axis(0), 1)); + println!("{:?}", a.subview(Axis(2), 1)); // PANIC +} diff --git a/src/dimension.rs b/src/dimension.rs index 1440ffdf5..99e61943c 100644 --- a/src/dimension.rs +++ b/src/dimension.rs @@ -715,3 +715,13 @@ mod test { assert!(super::dim_stride_overlap(&dim, &strides)); } } + +/// An axis index. +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] +pub struct Axis(pub usize); + +impl Axis { + #[inline(always)] + pub fn axis(&self) -> usize { self.0 } +} + diff --git a/src/lib.rs b/src/lib.rs index f840aaa5f..b29180fe9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -92,6 +92,7 @@ use itertools::free::enumerate; pub use dimension::{ Dimension, RemoveAxis, + Axis, }; pub use dimension::NdIndex; @@ -292,7 +293,7 @@ pub type Ixs = isize; /// Subview takes two arguments: `axis` and `index`. /// /// ``` -/// use ndarray::{arr3, aview2}; +/// use ndarray::{arr3, aview2, Axis}; /// /// // 2 submatrices of 2 rows with 3 elements per row, means a shape of `[2, 2, 3]`. /// @@ -308,8 +309,8 @@ pub type Ixs = isize; /// // Let’s take a subview along the greatest dimension (axis 0), /// // taking submatrix 0, then submatrix 1 /// -/// let sub_0 = a.subview(0, 0); -/// let sub_1 = a.subview(0, 1); +/// let sub_0 = a.subview(Axis(0), 0); +/// let sub_1 = a.subview(Axis(0), 1); /// /// assert_eq!(sub_0, aview2(&[[ 1, 2, 3], /// [ 4, 5, 6]])); @@ -318,7 +319,7 @@ pub type Ixs = isize; /// assert_eq!(sub_0.shape(), &[2, 3]); /// /// // This is the subview picking only axis 2, column 0 -/// let sub_col = a.subview(2, 0); +/// let sub_col = a.subview(Axis(2), 0); /// /// assert_eq!(sub_col, aview2(&[[ 1, 4], /// [ 7, 10]])); @@ -1265,7 +1266,7 @@ impl ArrayBase where S: Data, D: Dimension /// **Panics** if `axis` or `index` is out of bounds. /// /// ``` - /// use ndarray::{arr1, arr2}; + /// use ndarray::{arr1, arr2, Axis}; /// /// let a = arr2(&[[1., 2.], // -- axis 0, row 0 /// [3., 4.], // -- axis 0, row 1 @@ -1274,13 +1275,13 @@ impl ArrayBase where S: Data, D: Dimension /// // \ axis 1, column 1 /// // axis 1, column 0 /// assert!( - /// a.subview(0, 1) == arr1(&[3., 4.]) && - /// a.subview(1, 1) == arr1(&[2., 4., 6.]) + /// a.subview(Axis(0), 1) == arr1(&[3., 4.]) && + /// a.subview(Axis(1), 1) == arr1(&[2., 4., 6.]) /// ); /// ``` - pub fn subview(&self, axis: usize, index: Ix) + pub fn subview(&self, axis: Axis, index: Ix) -> ArrayView::Smaller> - where D: RemoveAxis + where D: RemoveAxis, { self.view().into_subview(axis, index) } @@ -1291,19 +1292,19 @@ impl ArrayBase where S: Data, D: Dimension /// **Panics** if `axis` or `index` is out of bounds. /// /// ``` - /// use ndarray::{arr2, aview2}; + /// use ndarray::{arr2, aview2, Axis}; /// /// let mut a = arr2(&[[1., 2.], /// [3., 4.]]); /// - /// a.subview_mut(1, 1).iadd_scalar(&10.); + /// a.subview_mut(Axis(1), 1).iadd_scalar(&10.); /// /// assert!( /// a == aview2(&[[1., 12.], /// [3., 14.]]) /// ); /// ``` - pub fn subview_mut(&mut self, axis: usize, index: Ix) + pub fn subview_mut(&mut self, axis: Axis, index: Ix) -> ArrayViewMut where S: DataMut, D: RemoveAxis, @@ -1315,19 +1316,21 @@ impl ArrayBase where S: Data, D: Dimension /// and select the subview of `index` along that axis. /// /// **Panics** if `index` is past the length of the axis. - pub fn isubview(&mut self, axis: usize, index: Ix) { - dimension::do_sub(&mut self.dim, &mut self.ptr, &self.strides, axis, index) + pub fn isubview(&mut self, axis: Axis, index: Ix) { + dimension::do_sub(&mut self.dim, &mut self.ptr, &self.strides, + axis.axis(), index) } /// Along `axis`, select the subview `index` and return `self` /// with that axis removed. /// /// See [`.subview()`](#method.subview) and [*Subviews*](#subviews) for full documentation. - pub fn into_subview(mut self, axis: usize, index: Ix) + pub fn into_subview(mut self, axis: Axis, index: Ix) -> ArrayBase::Smaller> - where D: RemoveAxis + where D: RemoveAxis, { self.isubview(axis, index); + let axis = axis.axis(); // don't use reshape -- we always know it will fit the size, // and we can use remove_axis on the strides as well ArrayBase { @@ -1379,15 +1382,16 @@ impl ArrayBase where S: Data, D: Dimension /// Iterator element is `ArrayView` (read-only array view). /// /// ``` - /// use ndarray::arr3; + /// use ndarray::{arr3, Axis}; + /// /// let a = arr3(&[[[ 0, 1, 2], // \ axis 0, submatrix 0 /// [ 3, 4, 5]], // / /// [[ 6, 7, 8], // \ axis 0, submatrix 1 /// [ 9, 10, 11]]]); // / /// // `outer_iter` yields the two submatrices along axis 0. /// let mut iter = a.outer_iter(); - /// assert_eq!(iter.next().unwrap(), a.subview(0, 0)); - /// assert_eq!(iter.next().unwrap(), a.subview(0, 1)); + /// assert_eq!(iter.next().unwrap(), a.subview(Axis(0), 0)); + /// assert_eq!(iter.next().unwrap(), a.subview(Axis(0), 1)); /// ``` pub fn outer_iter(&self) -> OuterIter where D: RemoveAxis, @@ -1418,10 +1422,10 @@ impl ArrayBase where S: Data, D: Dimension /// See [*Subviews*](#subviews) for full documentation. /// /// **Panics** if `axis` is out of bounds. - pub fn axis_iter(&self, axis: usize) -> OuterIter - where D: RemoveAxis + pub fn axis_iter(&self, axis: Axis) -> OuterIter + where D: RemoveAxis, { - iterators::new_axis_iter(self.view(), axis) + iterators::new_axis_iter(self.view(), axis.axis()) } @@ -1432,11 +1436,11 @@ impl ArrayBase where S: Data, D: Dimension /// (read-write array view). /// /// **Panics** if `axis` is out of bounds. - pub fn axis_iter_mut(&mut self, axis: usize) -> OuterIterMut + pub fn axis_iter_mut(&mut self, axis: Axis) -> OuterIterMut where S: DataMut, D: RemoveAxis, { - iterators::new_axis_iter_mut(self.view_mut(), axis) + iterators::new_axis_iter_mut(self.view_mut(), axis.axis()) } /// Return an iterator that traverses over `axis` by chunks of `size`, @@ -1451,20 +1455,22 @@ impl ArrayBase where S: Data, D: Dimension /// /// ``` /// use ndarray::OwnedArray; - /// use ndarray::arr3; + /// use ndarray::{arr3, Axis}; /// /// let a = OwnedArray::from_iter(0..28).into_shape((2, 7, 2)).unwrap(); - /// let mut iter = a.axis_chunks_iter(1, 2); + /// let mut iter = a.axis_chunks_iter(Axis(1), 2); /// /// // first iteration yields a 2 × 2 × 2 view /// assert_eq!(iter.next().unwrap(), - /// arr3(&[[[0, 1], [2, 3]], [[14, 15], [16, 17]]])); + /// arr3(&[[[ 0, 1], [ 2, 3]], + /// [[14, 15], [16, 17]]])); /// /// // however the last element is a 2 × 1 × 2 view since 7 % 2 == 1 - /// assert_eq!(iter.next_back().unwrap(), arr3(&[[[12, 13]], [[26, 27]]])); + /// assert_eq!(iter.next_back().unwrap(), arr3(&[[[12, 13]], + /// [[26, 27]]])); /// ``` - pub fn axis_chunks_iter(&self, axis: usize, size: usize) -> AxisChunksIter { - iterators::new_chunk_iter(self.view(), axis, size) + pub fn axis_chunks_iter(&self, axis: Axis, size: usize) -> AxisChunksIter { + iterators::new_chunk_iter(self.view(), axis.axis(), size) } /// Return an iterator that traverses over `axis` by chunks of `size`, @@ -1473,11 +1479,11 @@ impl ArrayBase where S: Data, D: Dimension /// Iterator element is `ArrayViewMut` /// /// **Panics** if `axis` is out of bounds. - pub fn axis_chunks_iter_mut(&mut self, axis: usize, size: usize) + pub fn axis_chunks_iter_mut(&mut self, axis: Axis, size: usize) -> AxisChunksIterMut where S: DataMut { - iterators::new_chunk_iter_mut(self.view_mut(), axis, size) + iterators::new_chunk_iter_mut(self.view_mut(), axis.axis(), size) } // Return (length, stride) for diagonal @@ -2229,24 +2235,24 @@ impl ArrayBase /// Return sum along `axis`. /// /// ``` - /// use ndarray::{aview0, aview1, arr2}; + /// use ndarray::{aview0, aview1, arr2, Axis}; /// /// let a = arr2(&[[1., 2.], /// [3., 4.]]); /// assert!( - /// a.sum(0) == aview1(&[4., 6.]) && - /// a.sum(1) == aview1(&[3., 7.]) && + /// a.sum(Axis(0)) == aview1(&[4., 6.]) && + /// a.sum(Axis(1)) == aview1(&[3., 7.]) && /// - /// a.sum(0).sum(0) == aview0(&10.) + /// a.sum(Axis(0)).sum(Axis(0)) == aview0(&10.) /// ); /// ``` /// /// **Panics** if `axis` is out of bounds. - pub fn sum(&self, axis: usize) -> OwnedArray::Smaller> + pub fn sum(&self, axis: Axis) -> OwnedArray::Smaller> where A: Clone + Add, D: RemoveAxis, { - let n = self.shape()[axis]; + let n = self.shape()[axis.axis()]; let mut res = self.subview(axis, 0).to_owned(); for i in 1..n { let view = self.subview(axis, i); @@ -2283,24 +2289,23 @@ impl ArrayBase /// Return mean along `axis`. /// + /// **Panics** if `axis` is out of bounds. + /// /// ``` - /// use ndarray::{aview1, arr2}; + /// use ndarray::{aview1, arr2, Axis}; /// /// let a = arr2(&[[1., 2.], /// [3., 4.]]); /// assert!( - /// a.mean(0) == aview1(&[2.0, 3.0]) && - /// a.mean(1) == aview1(&[1.5, 3.5]) + /// a.mean(Axis(0)) == aview1(&[2.0, 3.0]) && + /// a.mean(Axis(1)) == aview1(&[1.5, 3.5]) /// ); /// ``` - /// - /// - /// **Panics** if `axis` is out of bounds. - pub fn mean(&self, axis: usize) -> OwnedArray::Smaller> + pub fn mean(&self, axis: Axis) -> OwnedArray::Smaller> where A: LinalgScalar, D: RemoveAxis, { - let n = self.shape()[axis]; + let n = self.shape()[axis.axis()]; let mut sum = self.sum(axis); let one = libnum::one::(); let mut cnt = one; @@ -2413,7 +2418,7 @@ impl ArrayBase /// **Panics** if `index` is out of bounds. pub fn row(&self, index: Ix) -> ArrayView { - self.subview(0, index) + self.subview(Axis(0), index) } /// Return a mutable array view of row `index`. @@ -2422,7 +2427,7 @@ impl ArrayBase pub fn row_mut(&mut self, index: Ix) -> ArrayViewMut where S: DataMut { - self.subview_mut(0, index) + self.subview_mut(Axis(0), index) } /// Return an array view of column `index`. @@ -2430,7 +2435,7 @@ impl ArrayBase /// **Panics** if `index` is out of bounds. pub fn column(&self, index: Ix) -> ArrayView { - self.subview(1, index) + self.subview(Axis(1), index) } /// Return a mutable array view of column `index`. @@ -2439,7 +2444,7 @@ impl ArrayBase pub fn column_mut(&mut self, index: Ix) -> ArrayViewMut where S: DataMut { - self.subview_mut(1, index) + self.subview_mut(Axis(1), index) } /// Perform matrix multiplication of rectangular arrays `self` and `rhs`. diff --git a/tests/array.rs b/tests/array.rs index 0274dd428..3f0280a3b 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -14,6 +14,7 @@ use ndarray::{arr0, arr1, arr2, aview_mut1, }; use ndarray::Indexes; +use ndarray::Axis; #[test] fn test_matmul_rcarray() @@ -201,14 +202,14 @@ fn test_cow() fn test_sub() { let mat = RcArray::linspace(0., 15., 16).reshape((2, 4, 2)); - let s1 = mat.subview(0,0); - let s2 = mat.subview(0,1); + let s1 = mat.subview(Axis(0),0); + let s2 = mat.subview(Axis(0),1); assert_eq!(s1.dim(), (4, 2)); assert_eq!(s2.dim(), (4, 2)); let n = RcArray::linspace(8., 15., 8).reshape((4,2)); assert_eq!(n, s2); let m = RcArray::from_vec(vec![2., 3., 10., 11.]).reshape((2, 2)); - assert_eq!(m, mat.subview(1, 1)); + assert_eq!(m, mat.subview(Axis(1), 1)); } #[test] @@ -248,9 +249,9 @@ fn standard_layout() assert!(!a.is_standard_layout()); a.swap_axes(0, 1); assert!(a.is_standard_layout()); - let x1 = a.subview(0, 0); + let x1 = a.subview(Axis(0), 0); assert!(x1.is_standard_layout()); - let x2 = a.subview(1, 0); + let x2 = a.subview(Axis(1), 0); assert!(!x2.is_standard_layout()); } @@ -284,12 +285,12 @@ fn assign() fn sum_mean() { let a = arr2(&[[1., 2.], [3., 4.]]); - assert_eq!(a.sum(0), arr1(&[4., 6.])); - assert_eq!(a.sum(1), arr1(&[3., 7.])); - assert_eq!(a.mean(0), arr1(&[2., 3.])); - assert_eq!(a.mean(1), arr1(&[1.5, 3.5])); - assert_eq!(a.sum(1).sum(0), arr0(10.)); - assert_eq!(a.view().mean(1), aview1(&[1.5, 3.5])); + assert_eq!(a.sum(Axis(0)), arr1(&[4., 6.])); + assert_eq!(a.sum(Axis(1)), arr1(&[3., 7.])); + assert_eq!(a.mean(Axis(0)), arr1(&[2., 3.])); + assert_eq!(a.mean(Axis(1)), arr1(&[1.5, 3.5])); + assert_eq!(a.sum(Axis(1)).sum(Axis(0)), arr0(10.)); + assert_eq!(a.view().mean(Axis(1)), aview1(&[1.5, 3.5])); assert_eq!(a.scalar_sum(), 10.); } @@ -341,7 +342,7 @@ fn zero_axes() println!("{:?}\n{:?}", b.shape(), b); // we can even get a subarray of b - let bsub = b.subview(0, 2); + let bsub = b.subview(Axis(0), 2); assert_eq!(bsub.dim(), 0); } @@ -595,7 +596,7 @@ fn char_array() { // test compilation & basics of non-numerical array let cc = RcArray::from_iter("alphabet".chars()).reshape((4, 2)); - assert!(cc.subview(1, 0) == RcArray::from_iter("apae".chars())); + assert!(cc.subview(Axis(1), 0) == RcArray::from_iter("apae".chars())); } #[test] diff --git a/tests/complex.rs b/tests/complex.rs index 96d79c562..049b5c8c4 100644 --- a/tests/complex.rs +++ b/tests/complex.rs @@ -2,7 +2,7 @@ extern crate num; extern crate ndarray; -use ndarray::{arr1, arr2}; +use ndarray::{arr1, arr2, Axis}; use ndarray::OwnedArray; use num::{Num, Complex}; @@ -20,5 +20,5 @@ fn complex_mat_mul() let r = a.mat_mul(&e); println!("{}", a); assert_eq!(r, a); - assert_eq!(a.mean(0), arr1(&[c(1.5, 1.), c(2.5, 0.)])); + assert_eq!(a.mean(Axis(0)), arr1(&[c(1.5, 1.), c(2.5, 0.)])); } diff --git a/tests/dimension.rs b/tests/dimension.rs index d590ce36e..84f72a421 100644 --- a/tests/dimension.rs +++ b/tests/dimension.rs @@ -5,6 +5,7 @@ use ndarray::{ OwnedArray, RemoveAxis, arr2, + Axis, }; #[test] @@ -17,8 +18,11 @@ fn remove_axis() assert_eq!(vec![1,2].remove_axis(0), vec![2]); assert_eq!(vec![4, 5, 6].remove_axis(1), vec![4, 6]); + let a = RcArray::::zeros((4,5)); + a.subview(Axis(1), 0); + let a = RcArray::::zeros(vec![4,5,6]); - let _b = a.into_subview(1, 0).reshape((4, 6)).reshape(vec![2, 3, 4]); + let _b = a.into_subview(Axis(1), 0).reshape((4, 6)).reshape(vec![2, 3, 4]); } diff --git a/tests/iterators.rs b/tests/iterators.rs index 0ed035cb4..df917e684 100644 --- a/tests/iterators.rs +++ b/tests/iterators.rs @@ -11,6 +11,7 @@ use ndarray::{ aview1, arr2, arr3, + Axis, }; use itertools::assert_equal; @@ -89,12 +90,12 @@ fn as_slice() { let a = a.reshape((2, 4)); assert_slice_correct(&a); - assert!(a.view().subview(1, 0).as_slice().is_none()); + assert!(a.view().subview(Axis(1), 0).as_slice().is_none()); let v = a.view(); assert_slice_correct(&v); - assert_slice_correct(&v.subview(0, 0)); - assert_slice_correct(&v.subview(0, 1)); + assert_slice_correct(&v.subview(Axis(0), 0)); + assert_slice_correct(&v.subview(Axis(0), 1)); assert!(v.slice(&[S, Si(0, Some(1), 1)]).as_slice().is_none()); println!("{:?}", v.slice(&[Si(0, Some(1), 2), S])); @@ -176,12 +177,12 @@ fn outer_iter() { // [8, 9], // ... assert_equal(a.outer_iter(), - vec![a.subview(0, 0), a.subview(0, 1)]); + vec![a.subview(Axis(0), 0), a.subview(Axis(0), 1)]); let mut b = RcArray::zeros((2, 3, 2)); b.swap_axes(0, 2); b.assign(&a); assert_equal(b.outer_iter(), - vec![a.subview(0, 0), a.subview(0, 1)]); + vec![a.subview(Axis(0), 0), a.subview(Axis(0), 1)]); let mut found_rows = Vec::new(); for sub in b.outer_iter() { @@ -206,7 +207,7 @@ fn outer_iter() { cv.assign(&a); assert_eq!(&a, &cv); assert_equal(cv.outer_iter(), - vec![a.subview(0, 0), a.subview(0, 1)]); + vec![a.subview(Axis(0), 0), a.subview(Axis(0), 1)]); let mut found_rows = Vec::new(); for sub in cv.outer_iter() { @@ -228,10 +229,10 @@ fn axis_iter() { // [[6, 7], // [8, 9], // ... - assert_equal(a.axis_iter(1), - vec![a.subview(1, 0), - a.subview(1, 1), - a.subview(1, 2)]); + assert_equal(a.axis_iter(Axis(1)), + vec![a.subview(Axis(1), 0), + a.subview(Axis(1), 1), + a.subview(Axis(1), 2)]); } #[test] @@ -259,7 +260,7 @@ fn outer_iter_mut() { b.swap_axes(0, 2); b.assign(&a); assert_equal(b.outer_iter_mut(), - vec![a.subview(0, 0), a.subview(0, 1)]); + vec![a.subview(Axis(0), 0), a.subview(Axis(0), 1)]); let mut found_rows = Vec::new(); for sub in b.outer_iter_mut() { @@ -282,7 +283,7 @@ fn axis_iter_mut() { // ... let mut a = a.to_owned(); - for mut subview in a.axis_iter_mut(1) { + for mut subview in a.axis_iter_mut(Axis(1)) { subview[[0, 0]] = 42; } @@ -300,7 +301,7 @@ fn axis_chunks_iter() { let a = RcArray::from_iter(0..24); let a = a.reshape((2, 6, 2)); - let it = a.axis_chunks_iter(1, 2); + let it = a.axis_chunks_iter(Axis(1), 2); assert_equal(it, vec![arr3(&[[[0, 1], [2, 3]], [[12, 13], [14, 15]]]), arr3(&[[[4, 5], [6, 7]], [[16, 17], [18, 19]]]), @@ -309,24 +310,24 @@ fn axis_chunks_iter() { let a = RcArray::from_iter(0..28); let a = a.reshape((2, 7, 2)); - let it = a.axis_chunks_iter(1, 2); + let it = a.axis_chunks_iter(Axis(1), 2); assert_equal(it, vec![arr3(&[[[0, 1], [2, 3]], [[14, 15], [16, 17]]]), arr3(&[[[4, 5], [6, 7]], [[18, 19], [20, 21]]]), arr3(&[[[8, 9], [10, 11]], [[22, 23], [24, 25]]]), arr3(&[[[12, 13]], [[26, 27]]])]); - let it = a.axis_chunks_iter(1, 2).rev(); + let it = a.axis_chunks_iter(Axis(1), 2).rev(); assert_equal(it, vec![arr3(&[[[12, 13]], [[26, 27]]]), arr3(&[[[8, 9], [10, 11]], [[22, 23], [24, 25]]]), arr3(&[[[4, 5], [6, 7]], [[18, 19], [20, 21]]]), arr3(&[[[0, 1], [2, 3]], [[14, 15], [16, 17]]])]); - let it = a.axis_chunks_iter(1, 7); + let it = a.axis_chunks_iter(Axis(1), 7); assert_equal(it, vec![a.view()]); - let it = a.axis_chunks_iter(1, 9); + let it = a.axis_chunks_iter(Axis(1), 9); assert_equal(it, vec![a.view()]); } @@ -338,12 +339,12 @@ fn axis_chunks_iter_corner_cases() { // checking the absence of of out of bounds offseting cannot (?) be // done automatically, so one has to launch this test in a debugger. let a = RcArray::::linspace(0., 7., 8).reshape((8, 1)); - let it = a.axis_chunks_iter(0, 4); + let it = a.axis_chunks_iter(Axis(0), 4); assert_equal(it, vec![a.slice(s![..4, ..]), a.slice(s![4.., ..])]); let a = a.slice(s![..;-1,..]); - let it = a.axis_chunks_iter(0, 8); + let it = a.axis_chunks_iter(Axis(0), 8); assert_equal(it, vec![a.view()]); - let it = a.axis_chunks_iter(0, 3); + let it = a.axis_chunks_iter(Axis(0), 3); assert_equal(it, vec![arr2(&[[7.], [6.], [5.]]), arr2(&[[4.], [3.], [2.]]), @@ -351,10 +352,10 @@ fn axis_chunks_iter_corner_cases() { let b = RcArray::::zeros((8, 2)); let a = b.slice(s![1..;2,..]); - let it = a.axis_chunks_iter(0, 8); + let it = a.axis_chunks_iter(Axis(0), 8); assert_equal(it, vec![a.view()]); - let it = a.axis_chunks_iter(0, 1); + let it = a.axis_chunks_iter(Axis(0), 1); assert_equal(it, vec![RcArray::zeros((1, 2)); 4]); } @@ -363,7 +364,7 @@ fn axis_chunks_iter_mut() { let a = RcArray::from_iter(0..24); let mut a = a.reshape((2, 6, 2)); - let mut it = a.axis_chunks_iter_mut(1, 2); + let mut it = a.axis_chunks_iter_mut(Axis(1), 2); let mut col0 = it.next().unwrap(); col0[[0, 0, 0]] = 42; assert_eq!(col0, arr3(&[[[42, 1], [2, 3]], [[12, 13], [14, 15]]]));