8
8
use std:: cmp:: Ordering ;
9
9
use std:: fmt:: Debug ;
10
10
use std:: slice;
11
+ use itertools:: free:: enumerate;
11
12
12
13
use super :: { Si , Ix , Ixs } ;
13
14
use super :: zipsl;
@@ -19,30 +20,6 @@ pub fn stride_offset(n: Ix, stride: Ix) -> isize {
19
20
( n as isize ) * ( ( stride as Ixs ) as isize )
20
21
}
21
22
22
- /// Check whether `stride` is strictly positive
23
- #[ inline]
24
- fn stride_is_positive ( stride : Ix ) -> bool {
25
- ( stride as Ixs ) > 0
26
- }
27
-
28
- /// Return the axis ordering corresponding to the fastest variation
29
- ///
30
- /// Assumes that no stride value appears twice. This cannot yield the correct
31
- /// result the strides are not positive.
32
- fn fastest_varying_order < D : Dimension > ( strides : & D ) -> D {
33
- let mut sorted = strides. clone ( ) ;
34
- sorted. slice_mut ( ) . sort ( ) ;
35
- let mut res = strides. clone ( ) ;
36
- for ( index, & val) in strides. slice ( ) . iter ( ) . enumerate ( ) {
37
- let sorted_ind = sorted. slice ( )
38
- . iter ( )
39
- . position ( |& x| x == val)
40
- . unwrap ( ) ; // cannot panic by construction
41
- res. slice_mut ( ) [ sorted_ind] = index;
42
- }
43
- res
44
- }
45
-
46
23
/// Check whether the given `dim` and `stride` lead to overlapping indices
47
24
///
48
25
/// There is overlap if, when iterating through the dimensions in the order
@@ -51,15 +28,19 @@ fn fastest_varying_order<D: Dimension>(strides: &D) -> D {
51
28
///
52
29
/// The current implementation assumes strides to be positive
53
30
pub fn dim_stride_overlap < D : Dimension > ( dim : & D , strides : & D ) -> bool {
54
- let order = fastest_varying_order ( strides) ;
31
+ let order = strides. _fastest_varying_stride_order ( ) ;
55
32
33
+ let dim = dim. slice ( ) ;
34
+ let strides = strides. slice ( ) ;
56
35
let mut prev_offset = 1 ;
57
- for & index in order. slice ( ) . iter ( ) {
58
- let s = strides. slice ( ) [ index] ;
59
- if ( s as isize ) < prev_offset {
36
+ for & index in order. slice ( ) {
37
+ let d = dim[ index] ;
38
+ let s = strides[ index] ;
39
+ // any stride is ok if dimension is 1
40
+ if d != 1 && ( s as isize ) < prev_offset {
60
41
return true ;
61
42
}
62
- prev_offset = stride_offset ( dim . slice ( ) [ index ] , s) ;
43
+ prev_offset = stride_offset ( d , s) ;
63
44
}
64
45
false
65
46
}
@@ -74,33 +55,42 @@ pub fn dim_stride_overlap<D: Dimension>(dim: &D, strides: &D) -> bool {
74
55
pub fn can_index_slice < A , D : Dimension > ( data : & [ A ] , dim : & D , strides : & D )
75
56
-> Result < ( ) , ShapeError >
76
57
{
77
- if strides. slice ( ) . iter ( ) . cloned ( ) . all ( stride_is_positive) {
78
- if dim. size_checked ( ) . is_none ( ) {
79
- return Err ( from_kind ( ErrorKind :: OutOfBounds ) ) ;
58
+ // check lengths of axes.
59
+ let len = match dim. size_checked ( ) {
60
+ Some ( l) => l,
61
+ None => return Err ( from_kind ( ErrorKind :: OutOfBounds ) ) ,
62
+ } ;
63
+ // check if strides are strictly positive (zero ok for len 0)
64
+ for & s in strides. slice ( ) {
65
+ let s = s as Ixs ;
66
+ if s < 1 && ( len != 0 || s < 0 ) {
67
+ return Err ( from_kind ( ErrorKind :: Unsupported ) ) ;
80
68
}
81
- let mut last_index = dim. clone ( ) ;
82
- for mut index in last_index. slice_mut ( ) . iter_mut ( ) {
83
- * index -= 1 ;
84
- }
85
- if let Some ( offset) = stride_offset_checked_arithmetic ( dim,
86
- strides,
87
- & last_index)
88
- {
89
- // offset is guaranteed to be positive so no issue converting
90
- // to usize here
91
- if ( offset as usize ) >= data. len ( ) {
92
- return Err ( from_kind ( ErrorKind :: OutOfBounds ) ) ;
93
- }
94
- if dim_stride_overlap ( dim, strides) {
95
- return Err ( from_kind ( ErrorKind :: Unsupported ) ) ;
96
- }
97
- } else {
69
+ }
70
+ if len == 0 {
71
+ return Ok ( ( ) ) ;
72
+ }
73
+ // check that the maximum index is in bounds
74
+ let mut last_index = dim. clone ( ) ;
75
+ for mut index in last_index. slice_mut ( ) . iter_mut ( ) {
76
+ * index -= 1 ;
77
+ }
78
+ if let Some ( offset) = stride_offset_checked_arithmetic ( dim,
79
+ strides,
80
+ & last_index)
81
+ {
82
+ // offset is guaranteed to be positive so no issue converting
83
+ // to usize here
84
+ if ( offset as usize ) >= data. len ( ) {
98
85
return Err ( from_kind ( ErrorKind :: OutOfBounds ) ) ;
99
86
}
100
- Ok ( ( ) )
87
+ if dim_stride_overlap ( dim, strides) {
88
+ return Err ( from_kind ( ErrorKind :: Unsupported ) ) ;
89
+ }
101
90
} else {
102
- return Err ( from_kind ( ErrorKind :: Unsupported ) ) ;
91
+ return Err ( from_kind ( ErrorKind :: OutOfBounds ) ) ;
103
92
}
93
+ Ok ( ( ) )
104
94
}
105
95
106
96
/// Return stride offset for this dimension and index.
@@ -335,6 +325,21 @@ pub unsafe trait Dimension : Clone + Eq + Debug + Send + Sync {
335
325
offset
336
326
}
337
327
328
+ /// Return the axis ordering corresponding to the fastest variation
329
+ /// (in ascending order).
330
+ ///
331
+ /// Assumes that no stride value appears twice. This cannot yield the correct
332
+ /// result the strides are not positive.
333
+ #[ doc( hidden) ]
334
+ fn _fastest_varying_stride_order ( & self ) -> Self {
335
+ let mut indices = self . clone ( ) ;
336
+ for ( i, elt) in enumerate ( indices. slice_mut ( ) ) {
337
+ * elt = i;
338
+ }
339
+ let strides = self . slice ( ) ;
340
+ indices. slice_mut ( ) . sort_by_key ( |& i| strides[ i] ) ;
341
+ indices
342
+ }
338
343
}
339
344
340
345
/// Implementation-specific extensions to `Dimension`
@@ -484,6 +489,11 @@ unsafe impl Dimension for (Ix, Ix) {
484
489
( self . 1 , 1 )
485
490
}
486
491
492
+ #[ inline]
493
+ fn _fastest_varying_stride_order ( & self ) -> Self {
494
+ if self . 0 as Ixs <= self . 1 as Ixs { ( 0 , 1 ) } else { ( 1 , 0 ) }
495
+ }
496
+
487
497
#[ inline]
488
498
fn first_index ( & self ) -> Option < ( Ix , Ix ) > {
489
499
let ( m, n) = * self ;
@@ -563,6 +573,29 @@ unsafe impl Dimension for (Ix, Ix, Ix) {
563
573
let ( s, t, u) = * strides;
564
574
stride_offset ( i, s) + stride_offset ( j, t) + stride_offset ( k, u)
565
575
}
576
+
577
+ #[ inline]
578
+ fn _fastest_varying_stride_order ( & self ) -> Self {
579
+ let mut stride = * self ;
580
+ let mut order = ( 0 , 1 , 2 ) ;
581
+ macro_rules! swap {
582
+ ( $stride: expr, $order: expr, $x: expr, $y: expr) => {
583
+ if $stride[ $x] > $stride[ $y] {
584
+ $stride. swap( $x, $y) ;
585
+ $order. swap( $x, $y) ;
586
+ }
587
+ }
588
+ }
589
+ {
590
+ // stable sorting network for 3 elements
591
+ let order = order. slice_mut ( ) ;
592
+ let strides = stride. slice_mut ( ) ;
593
+ swap ! [ strides, order, 1 , 2 ] ;
594
+ swap ! [ strides, order, 0 , 1 ] ;
595
+ swap ! [ strides, order, 1 , 2 ] ;
596
+ }
597
+ order
598
+ }
566
599
}
567
600
568
601
macro_rules! large_dim {
@@ -742,13 +775,6 @@ mod test {
742
775
use super :: Dimension ;
743
776
use error:: StrideError ;
744
777
745
- #[ test]
746
- fn fastest_varying_order ( ) {
747
- let strides = ( 2 , 8 , 4 , 1 ) ;
748
- let order = super :: fastest_varying_order ( & strides) ;
749
- assert_eq ! ( order. slice( ) , & [ 3 , 0 , 2 , 1 ] ) ;
750
- }
751
-
752
778
#[ test]
753
779
fn slice_indexing_uncommon_strides ( ) {
754
780
let v: Vec < _ > = ( 0 ..12 ) . collect ( ) ;
0 commit comments