Skip to content

Commit 85c694d

Browse files
authored
Merge pull request #537 from jturner314/indexing-methods
Rework subview methods and other related methods
2 parents 1725928 + 8835b67 commit 85c694d

File tree

14 files changed

+275
-143
lines changed

14 files changed

+275
-143
lines changed

examples/axis_ops.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ fn main() {
4747
}
4848
a.swap_axes(0, 1);
4949
a.swap_axes(0, 2);
50-
a.slice_inplace(s![.., ..;-1, ..]);
50+
a.slice_collapse(s![.., ..;-1, ..]);
5151
regularize(&mut a).ok();
5252

5353
let mut b = Array::<u8, _>::zeros((2, 3, 4));
@@ -64,6 +64,6 @@ fn main() {
6464
for (i, elt) in (0..).zip(&mut a) {
6565
*elt = i;
6666
}
67-
a.slice_inplace(s![..;-1, ..;2, ..]);
67+
a.slice_collapse(s![..;-1, ..;2, ..]);
6868
regularize(&mut a).ok();
6969
}

examples/sort-axis.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ impl<A, D> PermuteArray for Array<A, D>
109109
result = Array::from_shape_vec_unchecked(self.dim(), v);
110110
for i in 0..axis_len {
111111
let perm_i = perm.indices[i];
112-
Zip::from(result.subview_mut(axis, perm_i))
113-
.and(self.subview(axis, i))
112+
Zip::from(result.index_axis_mut(axis, perm_i))
113+
.and(self.index_axis(axis, i))
114114
.apply(|to, from| {
115115
copy_nonoverlapping(from, to, 1)
116116
});

serialization-tests/tests/serialize.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ fn serial_many_dim()
5353
{
5454
// Test a sliced array.
5555
let mut a = RcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4));
56-
a.slice_inplace(s![..;-1, .., .., ..2]);
56+
a.slice_collapse(s![..;-1, .., .., ..2]);
5757
let serial = json::encode(&a).unwrap();
5858
println!("Encode {:?} => {:?}", a, serial);
5959
let res = json::decode::<RcArray<f32, _>>(&serial);
@@ -114,7 +114,7 @@ fn serial_many_dim_serde()
114114
{
115115
// Test a sliced array.
116116
let mut a = RcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4));
117-
a.slice_inplace(s![..;-1, .., .., ..2]);
117+
a.slice_collapse(s![..;-1, .., .., ..2]);
118118
let serial = serde_json::to_string(&a).unwrap();
119119
println!("Encode {:?} => {:?}", a, serial);
120120
let res = serde_json::from_str::<RcArray<f32, _>>(&serial);
@@ -221,7 +221,7 @@ fn serial_many_dim_serde_msgpack()
221221
{
222222
// Test a sliced array.
223223
let mut a = RcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4));
224-
a.slice_inplace(s![..;-1, .., .., ..2]);
224+
a.slice_collapse(s![..;-1, .., .., ..2]);
225225

226226
let mut buf = Vec::new();
227227
serde::Serialize::serialize(&a, &mut rmp_serde::Serializer::new(&mut buf)).ok().unwrap();
@@ -273,7 +273,7 @@ fn serial_many_dim_ron()
273273
{
274274
// Test a sliced array.
275275
let mut a = RcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4));
276-
a.slice_inplace(s![..;-1, .., .., ..2]);
276+
a.slice_collapse(s![..;-1, .., .., ..2]);
277277

278278
let a_s = ron_serialize(&a).unwrap();
279279

src/dimension/mod.rs

+9-4
Original file line numberDiff line numberDiff line change
@@ -302,13 +302,18 @@ impl<'a> DimensionExt for [Ix]
302302
///
303303
/// **Panics** if `index` is larger than the size of the axis
304304
// FIXME: Move to Dimension trait
305-
pub fn do_sub<A, D: Dimension>(dims: &mut D, ptr: &mut *mut A, strides: &D,
306-
axis: usize, index: Ix) {
305+
pub fn do_collapse_axis<A, D: Dimension>(
306+
dims: &mut D,
307+
ptr: &mut *mut A,
308+
strides: &D,
309+
axis: usize,
310+
index: usize,
311+
) {
307312
let dim = dims.slice()[axis];
308313
let stride = strides.slice()[axis];
309314
ndassert!(index < dim,
310-
concat!("subview: Index {} must be less than axis length {} ",
311-
"for array with shape {:?}"),
315+
"collapse_axis: Index {} must be less than axis length {} for \
316+
array with shape {:?}",
312317
index, dim, *dims);
313318
dims.slice_mut()[axis] = 1;
314319
let off = stride_offset(index, stride);

src/doc/ndarray_for_numpy_users/mod.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@
229229
//! Only the non-mutable methods that take the array by reference are listed in
230230
//! this table. For example, [`.slice()`][.slice()] also has corresponding
231231
//! methods [`.slice_mut()`][.slice_mut()], [`.slice_move()`][.slice_move()], and
232-
//! [`.slice_inplace()`][.slice_inplace()].
232+
//! [`.slice_collapse()`][.slice_collapse()].
233233
//!
234234
//! * The behavior of slicing is slightly different from NumPy for slices with
235235
//! `step < -1`. See the docs for the [`s![]` macro][s!] for more details.
@@ -238,7 +238,7 @@
238238
//! ------|-----------|------
239239
//! `a[-1]` | [`a[a.len() - 1]`][.index()] | access the last element in 1-D array `a`
240240
//! `a[1, 4]` | [`a[[1, 4]]`][.index()] | access the element in row 1, column 4
241-
//! `a[1]` or `a[1, :, :]` | [`a.slice(s![1, .., ..])`][.slice()] or [`a.subview(Axis(0), 1)`][.subview()] | get a 2-D subview of a 3-D array at index 1 of axis 0
241+
//! `a[1]` or `a[1, :, :]` | [`a.slice(s![1, .., ..])`][.slice()] or [`a.index_axis(Axis(0), 1)`][.index_axis()] | get a 2-D subview of a 3-D array at index 1 of axis 0
242242
//! `a[0:5]` or `a[:5]` or `a[0:5, :]` | [`a.slice(s![0..5, ..])`][.slice()] or [`a.slice(s![..5, ..])`][.slice()] or [`a.slice_axis(Axis(0), Slice::from(0..5))`][.slice_axis()] | get the first 5 rows of a 2-D array
243243
//! `a[-5:]` or `a[-5:, :]` | [`a.slice(s![-5.., ..])`][.slice()] or [`a.slice_axis(Axis(0), Slice::from(-5..))`][.slice_axis()] | get the last 5 rows of a 2-D array
244244
//! `a[:3, 4:9]` | [`a.slice(s![..3, 4..9])`][.slice()] | columns 4, 5, 6, 7, and 8 of the first 3 rows
@@ -618,14 +618,14 @@
618618
//! [.sum()]: ../../struct.ArrayBase.html#method.sum
619619
//! [.slice()]: ../../struct.ArrayBase.html#method.slice
620620
//! [.slice_axis()]: ../../struct.ArrayBase.html#method.slice_axis
621-
//! [.slice_inplace()]: ../../struct.ArrayBase.html#method.slice_inplace
621+
//! [.slice_collapse()]: ../../struct.ArrayBase.html#method.slice_collapse
622622
//! [.slice_move()]: ../../struct.ArrayBase.html#method.slice_move
623623
//! [.slice_mut()]: ../../struct.ArrayBase.html#method.slice_mut
624624
//! [.shape()]: ../../struct.ArrayBase.html#method.shape
625625
//! [stack!]: ../../macro.stack.html
626626
//! [stack()]: ../../fn.stack.html
627627
//! [.strides()]: ../../struct.ArrayBase.html#method.strides
628-
//! [.subview()]: ../../struct.ArrayBase.html#method.subview
628+
//! [.index_axis()]: ../../struct.ArrayBase.html#method.index_axis
629629
//! [.sum_axis()]: ../../struct.ArrayBase.html#method.sum_axis
630630
//! [.t()]: ../../struct.ArrayBase.html#method.t
631631
//! [::uninitialized()]: ../../struct.ArrayBase.html#method.uninitialized

src/impl_2d.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ impl<A, S> ArrayBase<S, Ix2>
1919
/// **Panics** if `index` is out of bounds.
2020
pub fn row(&self, index: Ix) -> ArrayView1<A>
2121
{
22-
self.subview(Axis(0), index)
22+
self.index_axis(Axis(0), index)
2323
}
2424

2525
/// Return a mutable array view of row `index`.
@@ -28,7 +28,7 @@ impl<A, S> ArrayBase<S, Ix2>
2828
pub fn row_mut(&mut self, index: Ix) -> ArrayViewMut1<A>
2929
where S: DataMut
3030
{
31-
self.subview_mut(Axis(0), index)
31+
self.index_axis_mut(Axis(0), index)
3232
}
3333

3434
/// Return the number of rows (length of `Axis(0)`) in the two-dimensional array.
@@ -41,7 +41,7 @@ impl<A, S> ArrayBase<S, Ix2>
4141
/// **Panics** if `index` is out of bounds.
4242
pub fn column(&self, index: Ix) -> ArrayView1<A>
4343
{
44-
self.subview(Axis(1), index)
44+
self.index_axis(Axis(1), index)
4545
}
4646

4747
/// Return a mutable array view of column `index`.
@@ -50,7 +50,7 @@ impl<A, S> ArrayBase<S, Ix2>
5050
pub fn column_mut(&mut self, index: Ix) -> ArrayViewMut1<A>
5151
where S: DataMut
5252
{
53-
self.subview_mut(Axis(1), index)
53+
self.index_axis_mut(Axis(1), index)
5454
}
5555

5656
/// Return the number of columns (length of `Axis(1)`) in the two-dimensional array.

src/impl_dyn.rs

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Copyright 2018 bluss and ndarray developers.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
//! Methods for dynamic-dimensional arrays.
10+
use imp_prelude::*;
11+
12+
/// # Methods for Dynamic-Dimensional Arrays
13+
impl<A, S> ArrayBase<S, IxDyn>
14+
where
15+
S: Data<Elem = A>,
16+
{
17+
/// Insert new array axis of length 1 at `axis`, modifying the shape and
18+
/// strides in-place.
19+
///
20+
/// **Panics** if the axis is out of bounds.
21+
///
22+
/// ```
23+
/// use ndarray::{Axis, arr2, arr3};
24+
///
25+
/// let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]).into_dyn();
26+
/// assert_eq!(a.shape(), &[2, 3]);
27+
///
28+
/// a.insert_axis_inplace(Axis(1));
29+
/// assert_eq!(a, arr3(&[[[1, 2, 3]], [[4, 5, 6]]]).into_dyn());
30+
/// assert_eq!(a.shape(), &[2, 1, 3]);
31+
/// ```
32+
pub fn insert_axis_inplace(&mut self, axis: Axis) {
33+
assert!(axis.index() <= self.ndim());
34+
self.dim = self.dim.insert_axis(axis);
35+
self.strides = self.strides.insert_axis(axis);
36+
}
37+
38+
/// Collapses the array to `index` along the axis and removes the axis,
39+
/// modifying the shape and strides in-place.
40+
///
41+
/// **Panics** if `axis` or `index` is out of bounds.
42+
///
43+
/// ```
44+
/// use ndarray::{Axis, arr1, arr2};
45+
///
46+
/// let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]).into_dyn();
47+
/// assert_eq!(a.shape(), &[2, 3]);
48+
///
49+
/// a.index_axis_inplace(Axis(1), 1);
50+
/// assert_eq!(a, arr1(&[2, 5]).into_dyn());
51+
/// assert_eq!(a.shape(), &[2]);
52+
/// ```
53+
pub fn index_axis_inplace(&mut self, axis: Axis, index: usize) {
54+
self.collapse_axis(axis, index);
55+
self.dim = self.dim.remove_axis(axis);
56+
self.strides = self.strides.remove_axis(axis);
57+
}
58+
}

0 commit comments

Comments
 (0)