Skip to content

Commit 6560e3b

Browse files
committed
FEAT: Add dimension merge function to merge contiguous axes
1 parent 03675c2 commit 6560e3b

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed

src/dimension/mod.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,33 @@ where
757757
}
758758
}
759759

760+
/// Attempt to merge axes if possible, starting from the back
761+
///
762+
/// Given axes [Axis(0), Axis(1), Axis(2), Axis(3)] this attempts
763+
/// to merge all axes one by one into Axis(3); when/if this fails,
764+
/// it attempts to merge the rest of the axes together into the next
765+
/// axis in line, for example a result could be:
766+
///
767+
/// [1, Axis(0) + Axis(1), 1, Axis(2) + Axis(3)] where `+` would
768+
/// mean axes were merged.
769+
pub(crate) fn merge_axes_from_the_back<D>(dim: &mut D, strides: &mut D)
770+
where
771+
D: Dimension,
772+
{
773+
debug_assert_eq!(dim.ndim(), strides.ndim());
774+
match dim.ndim() {
775+
0 | 1 => {}
776+
n => {
777+
let mut last = n - 1;
778+
for i in (0..last).rev() {
779+
if !merge_axes(dim, strides, Axis(i), Axis(last)) {
780+
last = i;
781+
}
782+
}
783+
}
784+
}
785+
}
786+
760787
/// Move the axis which has the smallest absolute stride and a length
761788
/// greater than one to be the last axis.
762789
pub fn move_min_stride_axis_to_last<D>(dim: &mut D, strides: &mut D)
@@ -821,12 +848,40 @@ where
821848
*strides = new_strides;
822849
}
823850

851+
852+
/// Sort axes to standard/row major order, i.e Axis(0) has biggest stride and Axis(n - 1) least
853+
/// stride
854+
///
855+
/// The axes are sorted according to the .abs() of their stride.
856+
pub(crate) fn sort_axes_to_standard<D>(dim: &mut D, strides: &mut D)
857+
where
858+
D: Dimension,
859+
{
860+
debug_assert!(dim.ndim() > 1);
861+
debug_assert_eq!(dim.ndim(), strides.ndim());
862+
// bubble sort axes
863+
let mut changed = true;
864+
while changed {
865+
changed = false;
866+
for i in 0..dim.ndim() - 1 {
867+
// make sure higher stride axes sort before.
868+
if strides.get_stride(Axis(i)).abs() < strides.get_stride(Axis(i + 1)).abs() {
869+
changed = true;
870+
dim.slice_mut().swap(i, i + 1);
871+
strides.slice_mut().swap(i, i + 1);
872+
}
873+
}
874+
}
875+
}
876+
877+
824878
#[cfg(test)]
825879
mod test {
826880
use super::{
827881
arith_seq_intersect, can_index_slice, can_index_slice_not_custom, extended_gcd,
828882
max_abs_offset_check_overflow, slice_min_max, slices_intersect,
829883
solve_linear_diophantine_eq, IntoDimension, squeeze,
884+
merge_axes_from_the_back,
830885
};
831886
use crate::error::{from_kind, ErrorKind};
832887
use crate::slice::Slice;
@@ -1191,4 +1246,26 @@ mod test {
11911246
assert_eq!(d, dans);
11921247
assert_eq!(s, sans);
11931248
}
1249+
1250+
#[test]
1251+
fn test_merge_axes_from_the_back() {
1252+
let dyndim = Dim::<&[usize]>;
1253+
1254+
let mut d = Dim([3, 4, 5]);
1255+
let mut s = Dim([20, 5, 1]);
1256+
merge_axes_from_the_back(&mut d, &mut s);
1257+
assert_eq!(d, Dim([1, 1, 60]));
1258+
assert_eq!(s, Dim([20, 5, 1]));
1259+
1260+
let mut d = Dim([3, 4, 5, 2]);
1261+
let mut s = Dim([80, 20, 2, 1]);
1262+
merge_axes_from_the_back(&mut d, &mut s);
1263+
assert_eq!(d, Dim([1, 12, 1, 10]));
1264+
assert_eq!(s, Dim([80, 20, 2, 1]));
1265+
let mut d = d.into_dyn();
1266+
let mut s = s.into_dyn();
1267+
squeeze(&mut d, &mut s);
1268+
assert_eq!(d, dyndim(&[12, 10]));
1269+
assert_eq!(s, dyndim(&[20, 1]));
1270+
}
11941271
}

0 commit comments

Comments
 (0)