@@ -757,6 +757,33 @@ where
757
757
}
758
758
}
759
759
760
+ /// Attempt to merge axes if possible, starting from the back
761
+ ///
762
+ /// Given axes [Axis(0), Axis(1), Axis(2), Axis(3)] this attempts
763
+ /// to merge all axes one by one into Axis(3); when/if this fails,
764
+ /// it attempts to merge the rest of the axes together into the next
765
+ /// axis in line, for example a result could be:
766
+ ///
767
+ /// [1, Axis(0) + Axis(1), 1, Axis(2) + Axis(3)] where `+` would
768
+ /// mean axes were merged.
769
+ pub ( crate ) fn merge_axes_from_the_back < D > ( dim : & mut D , strides : & mut D )
770
+ where
771
+ D : Dimension ,
772
+ {
773
+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
774
+ match dim. ndim ( ) {
775
+ 0 | 1 => { }
776
+ n => {
777
+ let mut last = n - 1 ;
778
+ for i in ( 0 ..last) . rev ( ) {
779
+ if !merge_axes ( dim, strides, Axis ( i) , Axis ( last) ) {
780
+ last = i;
781
+ }
782
+ }
783
+ }
784
+ }
785
+ }
786
+
760
787
/// Move the axis which has the smallest absolute stride and a length
761
788
/// greater than one to be the last axis.
762
789
pub fn move_min_stride_axis_to_last < D > ( dim : & mut D , strides : & mut D )
@@ -821,12 +848,40 @@ where
821
848
* strides = new_strides;
822
849
}
823
850
851
+
852
+ /// Sort axes to standard/row major order, i.e Axis(0) has biggest stride and Axis(n - 1) least
853
+ /// stride
854
+ ///
855
+ /// The axes are sorted according to the .abs() of their stride.
856
+ pub ( crate ) fn sort_axes_to_standard < D > ( dim : & mut D , strides : & mut D )
857
+ where
858
+ D : Dimension ,
859
+ {
860
+ debug_assert ! ( dim. ndim( ) > 1 ) ;
861
+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
862
+ // bubble sort axes
863
+ let mut changed = true ;
864
+ while changed {
865
+ changed = false ;
866
+ for i in 0 ..dim. ndim ( ) - 1 {
867
+ // make sure higher stride axes sort before.
868
+ if strides. get_stride ( Axis ( i) ) . abs ( ) < strides. get_stride ( Axis ( i + 1 ) ) . abs ( ) {
869
+ changed = true ;
870
+ dim. slice_mut ( ) . swap ( i, i + 1 ) ;
871
+ strides. slice_mut ( ) . swap ( i, i + 1 ) ;
872
+ }
873
+ }
874
+ }
875
+ }
876
+
877
+
824
878
#[ cfg( test) ]
825
879
mod test {
826
880
use super :: {
827
881
arith_seq_intersect, can_index_slice, can_index_slice_not_custom, extended_gcd,
828
882
max_abs_offset_check_overflow, slice_min_max, slices_intersect,
829
883
solve_linear_diophantine_eq, IntoDimension , squeeze,
884
+ merge_axes_from_the_back,
830
885
} ;
831
886
use crate :: error:: { from_kind, ErrorKind } ;
832
887
use crate :: slice:: Slice ;
@@ -1191,4 +1246,26 @@ mod test {
1191
1246
assert_eq ! ( d, dans) ;
1192
1247
assert_eq ! ( s, sans) ;
1193
1248
}
1249
+
1250
+ #[ test]
1251
+ fn test_merge_axes_from_the_back ( ) {
1252
+ let dyndim = Dim :: < & [ usize ] > ;
1253
+
1254
+ let mut d = Dim ( [ 3 , 4 , 5 ] ) ;
1255
+ let mut s = Dim ( [ 20 , 5 , 1 ] ) ;
1256
+ merge_axes_from_the_back ( & mut d, & mut s) ;
1257
+ assert_eq ! ( d, Dim ( [ 1 , 1 , 60 ] ) ) ;
1258
+ assert_eq ! ( s, Dim ( [ 20 , 5 , 1 ] ) ) ;
1259
+
1260
+ let mut d = Dim ( [ 3 , 4 , 5 , 2 ] ) ;
1261
+ let mut s = Dim ( [ 80 , 20 , 2 , 1 ] ) ;
1262
+ merge_axes_from_the_back ( & mut d, & mut s) ;
1263
+ assert_eq ! ( d, Dim ( [ 1 , 12 , 1 , 10 ] ) ) ;
1264
+ assert_eq ! ( s, Dim ( [ 80 , 20 , 2 , 1 ] ) ) ;
1265
+ let mut d = d. into_dyn ( ) ;
1266
+ let mut s = s. into_dyn ( ) ;
1267
+ squeeze ( & mut d, & mut s) ;
1268
+ assert_eq ! ( d, dyndim( & [ 12 , 10 ] ) ) ;
1269
+ assert_eq ! ( s, dyndim( & [ 20 , 1 ] ) ) ;
1270
+ }
1194
1271
}
0 commit comments