@@ -67,7 +67,12 @@ impl Slice {
67
67
}
68
68
}
69
69
70
- /// A slice (range with step) or an index.
70
+ /// Token to represent a new axis in a slice description.
71
+ ///
72
+ /// See also the [`s![]`](macro.s!.html) macro.
73
+ pub struct NewAxis ;
74
+
75
+ /// A slice (range with step), an index, or a new axis token.
71
76
///
72
77
/// See also the [`s![]`](macro.s!.html) macro for a convenient way to create a
73
78
/// `&SliceInfo<[AxisSliceInfo; n], Di, Do>`.
@@ -91,6 +96,10 @@ impl Slice {
91
96
/// from `a` until the end, in reverse order. It can also be created with
92
97
/// `AxisSliceInfo::from(a..).step_by(-1)`. The Python equivalent is `[a::-1]`.
93
98
/// The macro equivalent is `s![a..;-1]`.
99
+ ///
100
+ /// `AxisSliceInfo::NewAxis` is a new axis of length 1. It can also be created
101
+ /// with `AxisSliceInfo::from(NewAxis)`. The Python equivalent is
102
+ /// `[np.newaxis]`. The macro equivalent is `s![NewAxis]`.
94
103
#[ derive( Debug , PartialEq , Eq , Hash ) ]
95
104
pub enum AxisSliceInfo {
96
105
/// A range with step size. `end` is an exclusive index. Negative `begin`
@@ -103,6 +112,8 @@ pub enum AxisSliceInfo {
103
112
} ,
104
113
/// A single index.
105
114
Index ( isize ) ,
115
+ /// A new axis of length 1.
116
+ NewAxis ,
106
117
}
107
118
108
119
copy_and_clone ! { AxisSliceInfo }
@@ -124,6 +135,14 @@ impl AxisSliceInfo {
124
135
}
125
136
}
126
137
138
+ /// Returns `true` if `self` is a `NewAxis` value.
139
+ pub fn is_new_axis ( & self ) -> bool {
140
+ match self {
141
+ & AxisSliceInfo :: NewAxis => true ,
142
+ _ => false ,
143
+ }
144
+ }
145
+
127
146
/// Returns a new `AxisSliceInfo` with the given step size (multiplied with
128
147
/// the previous step size).
129
148
///
@@ -143,6 +162,7 @@ impl AxisSliceInfo {
143
162
step : orig_step * step,
144
163
} ,
145
164
AxisSliceInfo :: Index ( s) => AxisSliceInfo :: Index ( s) ,
165
+ AxisSliceInfo :: NewAxis => AxisSliceInfo :: NewAxis ,
146
166
}
147
167
}
148
168
}
@@ -163,6 +183,7 @@ impl fmt::Display for AxisSliceInfo {
163
183
write ! ( f, ";{}" , step) ?;
164
184
}
165
185
}
186
+ AxisSliceInfo :: NewAxis => write ! ( f, "NewAxis" ) ?,
166
187
}
167
188
Ok ( ( ) )
168
189
}
@@ -282,6 +303,13 @@ impl_sliceorindex_from_index!(isize);
282
303
impl_sliceorindex_from_index ! ( usize ) ;
283
304
impl_sliceorindex_from_index ! ( i32 ) ;
284
305
306
+ impl From < NewAxis > for AxisSliceInfo {
307
+ #[ inline]
308
+ fn from ( _: NewAxis ) -> AxisSliceInfo {
309
+ AxisSliceInfo :: NewAxis
310
+ }
311
+ }
312
+
285
313
/// A type that can slice an array of dimension `D`.
286
314
///
287
315
/// This trait is unsafe to implement because the implementation must ensure
@@ -402,12 +430,12 @@ where
402
430
/// Errors if `Di` or `Do` is not consistent with `indices`.
403
431
pub fn new ( indices : T ) -> Result < SliceInfo < T , Di , Do > , ShapeError > {
404
432
if let Some ( ndim) = Di :: NDIM {
405
- if ndim != indices. as_ref ( ) . len ( ) {
433
+ if ndim != indices. as_ref ( ) . iter ( ) . filter ( |s| !s . is_new_axis ( ) ) . count ( ) {
406
434
return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ;
407
435
}
408
436
}
409
437
if let Some ( ndim) = Do :: NDIM {
410
- if ndim != indices. as_ref ( ) . iter ( ) . filter ( |s| s . is_slice ( ) ) . count ( ) {
438
+ if ndim != indices. as_ref ( ) . iter ( ) . filter ( |s| !s . is_index ( ) ) . count ( ) {
411
439
return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ;
412
440
}
413
441
}
@@ -427,8 +455,18 @@ where
427
455
{
428
456
/// Returns the number of dimensions of the input array for
429
457
/// [`.slice()`](struct.ArrayBase.html#method.slice).
458
+ ///
459
+ /// If `Di` is a fixed-size dimension type, then this is equivalent to
460
+ /// `Di::NDIM.unwrap()`. Otherwise, the value is calculated by iterating
461
+ /// over the `AxisSliceInfo` elements.
430
462
pub fn in_ndim ( & self ) -> usize {
431
- Di :: NDIM . unwrap_or_else ( || self . indices . as_ref ( ) . len ( ) )
463
+ Di :: NDIM . unwrap_or_else ( || {
464
+ self . indices
465
+ . as_ref ( )
466
+ . iter ( )
467
+ . filter ( |s| !s. is_new_axis ( ) )
468
+ . count ( )
469
+ } )
432
470
}
433
471
434
472
/// Returns the number of dimensions after calling
@@ -443,7 +481,7 @@ where
443
481
self . indices
444
482
. as_ref ( )
445
483
. iter ( )
446
- . filter ( |s| s . is_slice ( ) )
484
+ . filter ( |s| !s . is_index ( ) )
447
485
. count ( )
448
486
} )
449
487
}
@@ -506,6 +544,12 @@ pub trait SliceNextInDim<D1, D2> {
506
544
fn next_dim ( & self , PhantomData < D1 > ) -> PhantomData < D2 > ;
507
545
}
508
546
547
+ impl < D1 : Dimension > SliceNextInDim < D1 , D1 > for NewAxis {
548
+ fn next_dim ( & self , _: PhantomData < D1 > ) -> PhantomData < D1 > {
549
+ PhantomData
550
+ }
551
+ }
552
+
509
553
macro_rules! impl_slicenextindim_larger {
510
554
( ( $( $generics: tt) * ) , $self: ty) => {
511
555
impl <D1 : Dimension , $( $generics) ,* > SliceNextInDim <D1 , D1 :: Larger > for $self {
@@ -560,12 +604,13 @@ impl_slicenextoutdim_larger!((T), RangeTo<T>);
560
604
impl_slicenextoutdim_larger ! ( ( T ) , RangeToInclusive <T >) ;
561
605
impl_slicenextoutdim_larger ! ( ( ) , RangeFull ) ;
562
606
impl_slicenextoutdim_larger ! ( ( ) , Slice ) ;
607
+ impl_slicenextoutdim_larger ! ( ( ) , NewAxis ) ;
563
608
564
609
/// Slice argument constructor.
565
610
///
566
- /// `s![]` takes a list of ranges/slices/indices, separated by comma, with
567
- /// optional step sizes that are separated from the range by a semicolon. It is
568
- /// converted into a [`&SliceInfo`] instance.
611
+ /// `s![]` takes a list of ranges/slices/indices/new-axes , separated by comma,
612
+ /// with optional step sizes that are separated from the range by a semicolon.
613
+ /// It is converted into a [`&SliceInfo`] instance.
569
614
///
570
615
/// [`&SliceInfo`]: struct.SliceInfo.html
571
616
///
@@ -584,22 +629,25 @@ impl_slicenextoutdim_larger!((), Slice);
584
629
/// * *slice*: a [`Slice`] instance to use for slicing that axis.
585
630
/// * *slice* `;` *step*: a range constructed from the start and end of a [`Slice`]
586
631
/// instance, with new step size *step*, to use for slicing that axis.
632
+ /// * *new-axis*: a [`NewAxis`] instance that represents the creation of a new axis.
587
633
///
588
634
/// [`Slice`]: struct.Slice.html
635
+ /// [`NewAxis`]: struct.NewAxis.html
589
636
///
590
- /// The number of *axis-slice-info* must match the number of axes in the array.
591
- /// *index*, *range*, *slice*, and *step* can be expressions. *index* must be
592
- /// of type `isize`, `usize`, or `i32` . *range * must be of type `Range<I>`,
593
- /// `RangeTo <I>`, `RangeFrom <I>`, or `RangeFull` where `I` is `isize`, `usize`,
594
- /// or `i32`. *step* must be a type that can be converted to `isize` with the
595
- /// `as` keyword.
637
+ /// The number of *axis-slice-info*, not including *new-axis*, must match the
638
+ /// number of axes in the array. *index*, *range*, *slice*, *step*, and
639
+ /// *new-axis* can be expressions . *index * must be of type `isize`, `usize`, or
640
+ /// `i32`. *range* must be of type `Range <I>`, `RangeTo <I>`, `RangeFrom<I>`, or
641
+ /// `RangeFull` where `I` is `isize`, `usize`, or `i32`. *step* must be a type
642
+ /// that can be converted to `isize` with the `as` keyword.
596
643
///
597
- /// For example `s![0..4;2, 6, 1..5]` is a slice of the first axis for 0..4
598
- /// with step size 2, a subview of the second axis at index 6, and a slice of
599
- /// the third axis for 1..5 with default step size 1. The input array must have
600
- /// 3 dimensions. The resulting slice would have shape `[2, 4]` for
601
- /// [`.slice()`], [`.slice_mut()`], and [`.slice_move()`], and shape
602
- /// `[2, 1, 4]` for [`.slice_collapse()`].
644
+ /// For example `s![0..4;2, 6, 1..5, NewAxis]` is a slice of the first axis for
645
+ /// 0..4 with step size 2, a subview of the second axis at index 6, a slice of
646
+ /// the third axis for 1..5 with default step size 1, and a new axis of length
647
+ /// 1 at the end of the shape. The input array must have 3 dimensions. The
648
+ /// resulting slice would have shape `[2, 4, 1]` for [`.slice()`],
649
+ /// [`.slice_mut()`], and [`.slice_move()`], and shape `[2, 1, 4]` for
650
+ /// [`.slice_collapse()`].
603
651
///
604
652
/// [`.slice()`]: struct.ArrayBase.html#method.slice
605
653
/// [`.slice_mut()`]: struct.ArrayBase.html#method.slice_mut
@@ -726,11 +774,11 @@ macro_rules! s(
726
774
}
727
775
}
728
776
} ;
729
- // convert range/index into AxisSliceInfo
777
+ // convert range/index/new-axis into AxisSliceInfo
730
778
( @convert $r: expr) => {
731
779
<$crate:: AxisSliceInfo as :: std:: convert:: From <_>>:: from( $r)
732
780
} ;
733
- // convert range/index and step into AxisSliceInfo
781
+ // convert range/index/new-axis and step into AxisSliceInfo
734
782
( @convert $r: expr, $s: expr) => {
735
783
<$crate:: AxisSliceInfo as :: std:: convert:: From <_>>:: from( $r) . step_by( $s as isize )
736
784
} ;
0 commit comments