Skip to content

Commit 7e67f4a

Browse files
committed
Add support for inserting new axes while slicing
1 parent 64aa02d commit 7e67f4a

File tree

7 files changed

+148
-67
lines changed

7 files changed

+148
-67
lines changed

src/dimension/mod.rs

+14-9
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,11 @@ pub fn slices_intersect<D: Dimension>(
534534
indices2: &impl CanSlice<D>,
535535
) -> bool {
536536
debug_assert_eq!(indices1.in_ndim(), indices2.in_ndim());
537-
for (&axis_len, &si1, &si2) in izip!(dim.slice(), indices1.as_ref(), indices2.as_ref()) {
537+
for (&axis_len, &si1, &si2) in izip!(
538+
dim.slice(),
539+
indices1.as_ref().iter().filter(|si| !si.is_new_axis()),
540+
indices2.as_ref().iter().filter(|si| !si.is_new_axis()),
541+
) {
538542
// The slices do not intersect iff any pair of `AxisSliceInfo` does not intersect.
539543
match (si1, si2) {
540544
(
@@ -582,6 +586,7 @@ pub fn slices_intersect<D: Dimension>(
582586
return false;
583587
}
584588
}
589+
(AxisSliceInfo::NewAxis, _) | (_, AxisSliceInfo::NewAxis) => unreachable!(),
585590
}
586591
}
587592
true
@@ -622,7 +627,7 @@ mod test {
622627
max_abs_offset_check_overflow, slice_min_max, slices_intersect,
623628
solve_linear_diophantine_eq, IntoDimension
624629
};
625-
use crate::{Dim, Dimension, Ix0, Ix1, Ix2, Ix3, IxDyn};
630+
use crate::{Dim, Dimension, Ix0, Ix1, Ix2, Ix3, IxDyn, NewAxis};
626631
use crate::error::{from_kind, ErrorKind};
627632
use crate::slice::Slice;
628633
use num_integer::gcd;
@@ -882,17 +887,17 @@ mod test {
882887

883888
#[test]
884889
fn slices_intersect_true() {
885-
assert!(slices_intersect(&Dim([4, 5]), s![.., ..], s![.., ..]));
886-
assert!(slices_intersect(&Dim([4, 5]), s![0, ..], s![0, ..]));
887-
assert!(slices_intersect(&Dim([4, 5]), s![..;2, ..], s![..;3, ..]));
888-
assert!(slices_intersect(&Dim([4, 5]), s![.., ..;2], s![.., 1..;3]));
890+
assert!(slices_intersect(&Dim([4, 5]), s![NewAxis, .., NewAxis, ..], s![.., NewAxis, .., NewAxis]));
891+
assert!(slices_intersect(&Dim([4, 5]), s![NewAxis, 0, ..], s![0, ..]));
892+
assert!(slices_intersect(&Dim([4, 5]), s![..;2, ..], s![..;3, NewAxis, ..]));
893+
assert!(slices_intersect(&Dim([4, 5]), s![.., ..;2], s![.., 1..;3, NewAxis]));
889894
assert!(slices_intersect(&Dim([4, 10]), s![.., ..;9], s![.., 3..;6]));
890895
}
891896

892897
#[test]
893898
fn slices_intersect_false() {
894-
assert!(!slices_intersect(&Dim([4, 5]), s![..;2, ..], s![1..;2, ..]));
895-
assert!(!slices_intersect(&Dim([4, 5]), s![..;2, ..], s![1..;3, ..]));
896-
assert!(!slices_intersect(&Dim([4, 5]), s![.., ..;9], s![.., 3..;6]));
899+
assert!(!slices_intersect(&Dim([4, 5]), s![..;2, ..], s![NewAxis, 1..;2, ..]));
900+
assert!(!slices_intersect(&Dim([4, 5]), s![..;2, NewAxis, ..], s![1..;3, ..]));
901+
assert!(!slices_intersect(&Dim([4, 5]), s![.., ..;9], s![.., 3..;6, NewAxis]));
897902
}
898903
}

src/doc/ndarray_for_numpy_users/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@
519519
//! `a[:] = 3.` | [`a.fill(3.)`][.fill()] | set all array elements to the same scalar value
520520
//! `a[:] = b` | [`a.assign(&b)`][.assign()] | copy the data from array `b` into array `a`
521521
//! `np.concatenate((a,b), axis=1)` | [`stack![Axis(1), a, b]`][stack!] or [`stack(Axis(1), &[a.view(), b.view()])`][stack()] | concatenate arrays `a` and `b` along axis 1
522-
//! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.insert_axis(Axis(1))`][.insert_axis()] | create an array from `a`, inserting a new axis 1
522+
//! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.slice(s![.., NewAxis])`][.slice()] or [`a.insert_axis(Axis(1))`][.insert_axis()] | create an view of 1-D array `a`, inserting a new axis 1
523523
//! `a.transpose()` or `a.T` | [`a.t()`][.t()] or [`a.reversed_axes()`][.reversed_axes()] | transpose of array `a` (view for `.t()` or by-move for `.reversed_axes()`)
524524
//! `np.diag(a)` | [`a.diag()`][.diag()] | view the diagonal of `a`
525525
//! `a.flatten()` | [`Array::from_iter(a.iter())`][::from_iter()] | create a 1-D array by flattening `a`

src/impl_methods.rs

+16-6
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,12 @@ where
366366
// Skip the old axis since it should be removed.
367367
old_axis += 1;
368368
}
369+
&AxisSliceInfo::NewAxis => {
370+
// Set the dim and stride of the new axis.
371+
new_dim[new_axis] = 1;
372+
new_strides[new_axis] = 0;
373+
new_axis += 1;
374+
}
369375
});
370376
debug_assert_eq!(old_axis, self.ndim());
371377
debug_assert_eq!(new_axis, out_ndim);
@@ -381,6 +387,8 @@ where
381387

382388
/// Slice the array in place without changing the number of dimensions.
383389
///
390+
/// Note that `NewAxis` elements in `info` are ignored.
391+
///
384392
/// See [*Slicing*](#slicing) for full documentation.
385393
///
386394
/// **Panics** if an index is out of bounds or step size is zero.<br>
@@ -394,18 +402,20 @@ where
394402
self.ndim(),
395403
"The input dimension of `info` must match the array to be sliced.",
396404
);
397-
info.as_ref()
398-
.iter()
399-
.enumerate()
400-
.for_each(|(axis, ax_info)| match ax_info {
405+
let mut axis = 0;
406+
info.as_ref().iter().for_each(|ax_info| match ax_info {
401407
&AxisSliceInfo::Slice { start, end, step } => {
402-
self.slice_axis_inplace(Axis(axis), Slice { start, end, step })
408+
self.slice_axis_inplace(Axis(axis), Slice { start, end, step });
409+
axis += 1;
403410
}
404411
&AxisSliceInfo::Index(index) => {
405412
let i_usize = abs_index(self.len_of(Axis(axis)), index);
406-
self.collapse_axis(Axis(axis), i_usize)
413+
self.collapse_axis(Axis(axis), i_usize);
414+
axis += 1;
407415
}
416+
&AxisSliceInfo::NewAxis => {}
408417
});
418+
debug_assert_eq!(axis, self.ndim());
409419
}
410420

411421
/// Slice the array in place without changing the number of dimensions.

src/lib.rs

+13-10
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ pub use crate::indexes::{indices, indices_of};
127127
pub use crate::error::{ShapeError, ErrorKind};
128128
pub use crate::slice::{
129129
deref_raw_view_mut_into_view_with_life, deref_raw_view_mut_into_view_mut_with_life,
130-
life_of_view_mut, AxisSliceInfo, Slice, SliceInfo, SliceNextInDim, SliceNextOutDim,
130+
life_of_view_mut, AxisSliceInfo, NewAxis, Slice, SliceInfo, SliceNextInDim, SliceNextOutDim,
131131
};
132132

133133
use crate::iterators::Baseiter;
@@ -467,22 +467,24 @@ pub type Ixs = isize;
467467
///
468468
/// If a range is used, the axis is preserved. If an index is used, that index
469469
/// is selected and the axis is removed; this selects a subview. See
470-
/// [*Subviews*](#subviews) for more information about subviews. Note that
471-
/// [`.slice_collapse()`] behaves like [`.collapse_axis()`] by preserving
472-
/// the number of dimensions.
470+
/// [*Subviews*](#subviews) for more information about subviews. If a
471+
/// [`NewAxis`] instance is used, a new axis is inserted. Note that
472+
/// [`.slice_collapse()`] ignores `NewAxis` elements and behaves like
473+
/// [`.collapse_axis()`] by preserving the number of dimensions.
473474
///
474475
/// [`.slice()`]: #method.slice
475476
/// [`.slice_mut()`]: #method.slice_mut
476477
/// [`.slice_move()`]: #method.slice_move
477478
/// [`.slice_collapse()`]: #method.slice_collapse
479+
/// [`NewAxis`]: struct.NewAxis.html
478480
///
479481
/// It's possible to take multiple simultaneous *mutable* slices with the
480482
/// [`multislice!()`](macro.multislice!.html) macro.
481483
///
482484
/// ```
483485
/// extern crate ndarray;
484486
///
485-
/// use ndarray::{arr2, arr3, multislice, s};
487+
/// use ndarray::{arr2, arr3, multislice, s, NewAxis};
486488
///
487489
/// fn main() {
488490
///
@@ -519,16 +521,17 @@ pub type Ixs = isize;
519521
/// assert_eq!(d, e);
520522
/// assert_eq!(d.shape(), &[2, 1, 3]);
521523
///
522-
/// // Let’s create a slice while selecting a subview with
524+
/// // Let’s create a slice while selecting a subview and inserting a new axis with
523525
/// //
524526
/// // - Both submatrices of the greatest dimension: `..`
525527
/// // - The last row in each submatrix, removing that axis: `-1`
526528
/// // - Row elements in reverse order: `..;-1`
527-
/// let f = a.slice(s![.., -1, ..;-1]);
528-
/// let g = arr2(&[[ 6, 5, 4],
529-
/// [12, 11, 10]]);
529+
/// // - A new axis at the end.
530+
/// let f = a.slice(s![.., -1, ..;-1, NewAxis]);
531+
/// let g = arr3(&[[ [6], [5], [4]],
532+
/// [[12], [11], [10]]]);
530533
/// assert_eq!(f, g);
531-
/// assert_eq!(f.shape(), &[2, 3]);
534+
/// assert_eq!(f.shape(), &[2, 3, 1]);
532535
///
533536
/// // Let's take two disjoint, mutable slices of a matrix with
534537
/// //

src/prelude.rs

+5
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ pub use crate::{
6666
ShapeBuilder,
6767
};
6868

69+
#[doc(no_inline)]
70+
pub use crate::{
71+
NewAxis,
72+
};
73+
6974
#[doc(no_inline)]
7075
pub use crate::{
7176
NdFloat,

src/slice.rs

+70-22
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,12 @@ impl Slice {
6767
}
6868
}
6969

70-
/// A slice (range with step) or an index.
70+
/// Token to represent a new axis in a slice description.
71+
///
72+
/// See also the [`s![]`](macro.s!.html) macro.
73+
pub struct NewAxis;
74+
75+
/// A slice (range with step), an index, or a new axis token.
7176
///
7277
/// See also the [`s![]`](macro.s!.html) macro for a convenient way to create a
7378
/// `&SliceInfo<[AxisSliceInfo; n], Di, Do>`.
@@ -91,6 +96,10 @@ impl Slice {
9196
/// from `a` until the end, in reverse order. It can also be created with
9297
/// `AxisSliceInfo::from(a..).step_by(-1)`. The Python equivalent is `[a::-1]`.
9398
/// The macro equivalent is `s![a..;-1]`.
99+
///
100+
/// `AxisSliceInfo::NewAxis` is a new axis of length 1. It can also be created
101+
/// with `AxisSliceInfo::from(NewAxis)`. The Python equivalent is
102+
/// `[np.newaxis]`. The macro equivalent is `s![NewAxis]`.
94103
#[derive(Debug, PartialEq, Eq, Hash)]
95104
pub enum AxisSliceInfo {
96105
/// A range with step size. `end` is an exclusive index. Negative `begin`
@@ -103,6 +112,8 @@ pub enum AxisSliceInfo {
103112
},
104113
/// A single index.
105114
Index(isize),
115+
/// A new axis of length 1.
116+
NewAxis,
106117
}
107118

108119
copy_and_clone!{AxisSliceInfo}
@@ -124,6 +135,14 @@ impl AxisSliceInfo {
124135
}
125136
}
126137

138+
/// Returns `true` if `self` is a `NewAxis` value.
139+
pub fn is_new_axis(&self) -> bool {
140+
match self {
141+
&AxisSliceInfo::NewAxis => true,
142+
_ => false,
143+
}
144+
}
145+
127146
/// Returns a new `AxisSliceInfo` with the given step size (multiplied with
128147
/// the previous step size).
129148
///
@@ -143,6 +162,7 @@ impl AxisSliceInfo {
143162
step: orig_step * step,
144163
},
145164
AxisSliceInfo::Index(s) => AxisSliceInfo::Index(s),
165+
AxisSliceInfo::NewAxis => AxisSliceInfo::NewAxis,
146166
}
147167
}
148168
}
@@ -163,6 +183,7 @@ impl fmt::Display for AxisSliceInfo {
163183
write!(f, ";{}", step)?;
164184
}
165185
}
186+
AxisSliceInfo::NewAxis => write!(f, "NewAxis")?,
166187
}
167188
Ok(())
168189
}
@@ -282,6 +303,13 @@ impl_sliceorindex_from_index!(isize);
282303
impl_sliceorindex_from_index!(usize);
283304
impl_sliceorindex_from_index!(i32);
284305

306+
impl From<NewAxis> for AxisSliceInfo {
307+
#[inline]
308+
fn from(_: NewAxis) -> AxisSliceInfo {
309+
AxisSliceInfo::NewAxis
310+
}
311+
}
312+
285313
/// A type that can slice an array of dimension `D`.
286314
///
287315
/// This trait is unsafe to implement because the implementation must ensure
@@ -402,12 +430,12 @@ where
402430
/// Errors if `Di` or `Do` is not consistent with `indices`.
403431
pub fn new(indices: T) -> Result<SliceInfo<T, Di, Do>, ShapeError> {
404432
if let Some(ndim) = Di::NDIM {
405-
if ndim != indices.as_ref().len() {
433+
if ndim != indices.as_ref().iter().filter(|s| !s.is_new_axis()).count() {
406434
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));
407435
}
408436
}
409437
if let Some(ndim) = Do::NDIM {
410-
if ndim != indices.as_ref().iter().filter(|s| s.is_slice()).count() {
438+
if ndim != indices.as_ref().iter().filter(|s| !s.is_index()).count() {
411439
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));
412440
}
413441
}
@@ -427,8 +455,18 @@ where
427455
{
428456
/// Returns the number of dimensions of the input array for
429457
/// [`.slice()`](struct.ArrayBase.html#method.slice).
458+
///
459+
/// If `Di` is a fixed-size dimension type, then this is equivalent to
460+
/// `Di::NDIM.unwrap()`. Otherwise, the value is calculated by iterating
461+
/// over the `AxisSliceInfo` elements.
430462
pub fn in_ndim(&self) -> usize {
431-
Di::NDIM.unwrap_or_else(|| self.indices.as_ref().len())
463+
Di::NDIM.unwrap_or_else(|| {
464+
self.indices
465+
.as_ref()
466+
.iter()
467+
.filter(|s| !s.is_new_axis())
468+
.count()
469+
})
432470
}
433471

434472
/// Returns the number of dimensions after calling
@@ -443,7 +481,7 @@ where
443481
self.indices
444482
.as_ref()
445483
.iter()
446-
.filter(|s| s.is_slice())
484+
.filter(|s| !s.is_index())
447485
.count()
448486
})
449487
}
@@ -506,6 +544,12 @@ pub trait SliceNextInDim<D1, D2> {
506544
fn next_dim(&self, _: PhantomData<D1>) -> PhantomData<D2>;
507545
}
508546

547+
impl<D1: Dimension> SliceNextInDim<D1, D1> for NewAxis {
548+
fn next_dim(&self, _: PhantomData<D1>) -> PhantomData<D1> {
549+
PhantomData
550+
}
551+
}
552+
509553
macro_rules! impl_slicenextindim_larger {
510554
(($($generics:tt)*), $self:ty) => {
511555
impl<D1: Dimension, $($generics),*> SliceNextInDim<D1, D1::Larger> for $self {
@@ -560,12 +604,13 @@ impl_slicenextoutdim_larger!((T), RangeTo<T>);
560604
impl_slicenextoutdim_larger!((T), RangeToInclusive<T>);
561605
impl_slicenextoutdim_larger!((), RangeFull);
562606
impl_slicenextoutdim_larger!((), Slice);
607+
impl_slicenextoutdim_larger!((), NewAxis);
563608

564609
/// Slice argument constructor.
565610
///
566-
/// `s![]` takes a list of ranges/slices/indices, separated by comma, with
567-
/// optional step sizes that are separated from the range by a semicolon. It is
568-
/// converted into a [`&SliceInfo`] instance.
611+
/// `s![]` takes a list of ranges/slices/indices/new-axes, separated by comma,
612+
/// with optional step sizes that are separated from the range by a semicolon.
613+
/// It is converted into a [`&SliceInfo`] instance.
569614
///
570615
/// [`&SliceInfo`]: struct.SliceInfo.html
571616
///
@@ -584,22 +629,25 @@ impl_slicenextoutdim_larger!((), Slice);
584629
/// * *slice*: a [`Slice`] instance to use for slicing that axis.
585630
/// * *slice* `;` *step*: a range constructed from the start and end of a [`Slice`]
586631
/// instance, with new step size *step*, to use for slicing that axis.
632+
/// * *new-axis*: a [`NewAxis`] instance that represents the creation of a new axis.
587633
///
588634
/// [`Slice`]: struct.Slice.html
635+
/// [`NewAxis`]: struct.NewAxis.html
589636
///
590-
/// The number of *axis-slice-info* must match the number of axes in the array.
591-
/// *index*, *range*, *slice*, and *step* can be expressions. *index* must be
592-
/// of type `isize`, `usize`, or `i32`. *range* must be of type `Range<I>`,
593-
/// `RangeTo<I>`, `RangeFrom<I>`, or `RangeFull` where `I` is `isize`, `usize`,
594-
/// or `i32`. *step* must be a type that can be converted to `isize` with the
595-
/// `as` keyword.
637+
/// The number of *axis-slice-info*, not including *new-axis*, must match the
638+
/// number of axes in the array. *index*, *range*, *slice*, *step*, and
639+
/// *new-axis* can be expressions. *index* must be of type `isize`, `usize`, or
640+
/// `i32`. *range* must be of type `Range<I>`, `RangeTo<I>`, `RangeFrom<I>`, or
641+
/// `RangeFull` where `I` is `isize`, `usize`, or `i32`. *step* must be a type
642+
/// that can be converted to `isize` with the `as` keyword.
596643
///
597-
/// For example `s![0..4;2, 6, 1..5]` is a slice of the first axis for 0..4
598-
/// with step size 2, a subview of the second axis at index 6, and a slice of
599-
/// the third axis for 1..5 with default step size 1. The input array must have
600-
/// 3 dimensions. The resulting slice would have shape `[2, 4]` for
601-
/// [`.slice()`], [`.slice_mut()`], and [`.slice_move()`], and shape
602-
/// `[2, 1, 4]` for [`.slice_collapse()`].
644+
/// For example `s![0..4;2, 6, 1..5, NewAxis]` is a slice of the first axis for
645+
/// 0..4 with step size 2, a subview of the second axis at index 6, a slice of
646+
/// the third axis for 1..5 with default step size 1, and a new axis of length
647+
/// 1 at the end of the shape. The input array must have 3 dimensions. The
648+
/// resulting slice would have shape `[2, 4, 1]` for [`.slice()`],
649+
/// [`.slice_mut()`], and [`.slice_move()`], and shape `[2, 1, 4]` for
650+
/// [`.slice_collapse()`].
603651
///
604652
/// [`.slice()`]: struct.ArrayBase.html#method.slice
605653
/// [`.slice_mut()`]: struct.ArrayBase.html#method.slice_mut
@@ -726,11 +774,11 @@ macro_rules! s(
726774
}
727775
}
728776
};
729-
// convert range/index into AxisSliceInfo
777+
// convert range/index/new-axis into AxisSliceInfo
730778
(@convert $r:expr) => {
731779
<$crate::AxisSliceInfo as ::std::convert::From<_>>::from($r)
732780
};
733-
// convert range/index and step into AxisSliceInfo
781+
// convert range/index/new-axis and step into AxisSliceInfo
734782
(@convert $r:expr, $s:expr) => {
735783
<$crate::AxisSliceInfo as ::std::convert::From<_>>::from($r).step_by($s as isize)
736784
};

0 commit comments

Comments
 (0)