Skip to content

Commit ca61280

Browse files
committed
Merge pull request #94 from vbarrielle/splitAt
split_at for Array
2 parents 0aab138 + 53a7dc5 commit ca61280

File tree

2 files changed

+120
-1
lines changed

2 files changed

+120
-1
lines changed

src/lib.rs

+72
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ pub use dimension::{
9494
RemoveAxis,
9595
};
9696

97+
use dimension::stride_offset;
98+
9799
pub use dimension::NdIndex;
98100
pub use indexes::Indexes;
99101
pub use shape_error::ShapeError;
@@ -923,6 +925,40 @@ impl<'a, A, D> ArrayView<'a, A, D>
923925
{
924926
iterators::new_outer_iter(self)
925927
}
928+
929+
/// Split the array along `axis` and return one view strictly before the
930+
/// split and one view after the split.
931+
///
932+
/// **Panics** if `axis` is out of bounds.
933+
pub fn axis_split_at(self, axis: usize, index: Ix)
934+
-> (Self, Self)
935+
{
936+
assert!(index <= self.shape()[axis]);
937+
let left_ptr = self.ptr;
938+
let right_ptr = if index == self.shape()[axis] {
939+
self.ptr
940+
} else {
941+
let offset = stride_offset(index, self.strides.slice()[axis]);
942+
unsafe {
943+
self.ptr.offset(offset)
944+
}
945+
};
946+
947+
let mut dim_left = self.dim.clone();
948+
dim_left.slice_mut()[axis] = index;
949+
let left = unsafe {
950+
Self::new_(left_ptr, dim_left, self.strides.clone())
951+
};
952+
953+
let mut dim_right = self.dim.clone();
954+
dim_right.slice_mut()[axis] = self.dim.slice()[axis] - index;
955+
let right = unsafe {
956+
Self::new_(right_ptr, dim_right, self.strides.clone())
957+
};
958+
959+
(left, right)
960+
}
961+
926962
}
927963

928964
impl<'a, A, D> ArrayViewMut<'a, A, D>
@@ -1018,6 +1054,41 @@ impl<'a, A, D> ArrayViewMut<'a, A, D>
10181054
{
10191055
iterators::new_outer_iter_mut(self)
10201056
}
1057+
1058+
/// Split the array along `axis` and return one mutable view strictly
1059+
/// before the split and one mutable view after the split.
1060+
///
1061+
/// **Panics** if `axis` is out of bounds.
1062+
pub fn axis_split_at(self, axis: usize, index: Ix)
1063+
-> (Self, Self)
1064+
{
1065+
assert!(index <= self.shape()[axis]);
1066+
let left_ptr = self.ptr;
1067+
let right_ptr = if index == self.shape()[axis] {
1068+
self.ptr
1069+
}
1070+
else {
1071+
let offset = stride_offset(index, self.strides.slice()[axis]);
1072+
unsafe {
1073+
self.ptr.offset(offset)
1074+
}
1075+
};
1076+
1077+
let mut dim_left = self.dim.clone();
1078+
dim_left.slice_mut()[axis] = index;
1079+
let left = unsafe {
1080+
Self::new_(left_ptr, dim_left, self.strides.clone())
1081+
};
1082+
1083+
let mut dim_right = self.dim.clone();
1084+
dim_right.slice_mut()[axis] = self.dim.slice()[axis] - index;
1085+
let right = unsafe {
1086+
Self::new_(right_ptr, dim_right, self.strides.clone())
1087+
};
1088+
1089+
(left, right)
1090+
}
1091+
10211092
}
10221093

10231094
impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
@@ -1439,6 +1510,7 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
14391510
iterators::new_axis_iter_mut(self.view_mut(), axis)
14401511
}
14411512

1513+
14421514
/// Return an iterator that traverses over `axis` by chunks of `size`,
14431515
/// yielding non-overlapping views along that axis.
14441516
///

tests/array.rs

+48-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ extern crate ndarray;
77
use ndarray::{RcArray, S, Si,
88
OwnedArray,
99
};
10-
use ndarray::{arr0, arr1, arr2,
10+
use ndarray::{arr0, arr1, arr2, arr3,
1111
aview0,
1212
aview1,
1313
aview2,
@@ -658,3 +658,50 @@ fn deny_wraparound_reshape() {
658658
let five = OwnedArray::<f32, _>::zeros(5);
659659
let _five_large = five.into_shape((3, 7, 29, 36760123, 823996703)).unwrap();
660660
}
661+
662+
#[test]
663+
fn split_at() {
664+
let mut a = arr2(&[[1., 2.], [3., 4.]]);
665+
666+
{
667+
let (c0, c1) = a.view().axis_split_at(1, 1);
668+
669+
assert_eq!(c0, arr2(&[[1.], [3.]]));
670+
assert_eq!(c1, arr2(&[[2.], [4.]]));
671+
}
672+
673+
{
674+
let (mut r0, mut r1) = a.view_mut().axis_split_at(0, 1);
675+
r0[[0, 1]] = 5.;
676+
r1[[0, 0]] = 8.;
677+
}
678+
assert_eq!(a, arr2(&[[1., 5.], [8., 4.]]));
679+
680+
681+
let b = RcArray::linspace(0., 59., 60).reshape((3, 4, 5));
682+
683+
let (left, right) = b.view().axis_split_at(2, 2);
684+
assert_eq!(left.shape(), [3, 4, 2]);
685+
assert_eq!(right.shape(), [3, 4, 3]);
686+
assert_eq!(left, arr3(&[[[0., 1.], [5., 6.], [10., 11.], [15., 16.]],
687+
[[20., 21.], [25., 26.], [30., 31.], [35., 36.]],
688+
[[40., 41.], [45., 46.], [50., 51.], [55., 56.]]]));
689+
690+
// we allow for an empty right view when index == dim[axis]
691+
let (_, right) = b.view().axis_split_at(1, 4);
692+
assert_eq!(right.shape(), [3, 0, 5]);
693+
}
694+
695+
#[test]
696+
#[should_panic]
697+
fn deny_split_at_axis_out_of_bounds() {
698+
let a = arr2(&[[1., 2.], [3., 4.]]);
699+
a.view().axis_split_at(2, 0);
700+
}
701+
702+
#[test]
703+
#[should_panic]
704+
fn deny_split_at_index_out_of_bounds() {
705+
let a = arr2(&[[1., 2.], [3., 4.]]);
706+
a.view().axis_split_at(1, 3);
707+
}

0 commit comments

Comments
 (0)