Skip to content

Commit 81662ae

Browse files
authored
Merge pull request #1019 from jturner314/improve-as_slice_memory_order_mut
Guarantee that `.as_slice_memory_order_mut()` preserves strides
2 parents 1daff26 + 37645bd commit 81662ae

File tree

3 files changed

+74
-13
lines changed

3 files changed

+74
-13
lines changed

src/data_traits.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,11 @@ pub unsafe trait RawData: Sized {
5050
pub unsafe trait RawDataMut: RawData {
5151
/// If possible, ensures that the array has unique access to its data.
5252
///
53-
/// If `Self` provides safe mutable access to array elements, then it
54-
/// **must** panic or ensure that the data is unique.
53+
/// The implementer must ensure that if the input is contiguous, then the
54+
/// output has the same strides as input.
55+
///
56+
/// Additionally, if `Self` provides safe mutable access to array elements,
57+
/// then this method **must** panic or ensure that the data is unique.
5558
#[doc(hidden)]
5659
fn try_ensure_unique<D>(_: &mut ArrayBase<Self, D>)
5760
where
@@ -230,14 +233,9 @@ where
230233
return;
231234
}
232235
if self_.dim.size() <= self_.data.0.len() / 2 {
233-
// Create a new vec if the current view is less than half of
234-
// backing data.
235-
unsafe {
236-
*self_ = ArrayBase::from_shape_vec_unchecked(
237-
self_.dim.clone(),
238-
self_.iter().cloned().collect(),
239-
);
240-
}
236+
// Clone only the visible elements if the current view is less than
237+
// half of backing data.
238+
*self_ = self_.to_owned().into_shared();
241239
return;
242240
}
243241
let rcvec = &mut self_.data.0;

src/impl_methods.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ where
217217
)
218218
}
219219
} else {
220-
self.map(|x| x.clone())
220+
self.map(A::clone)
221221
}
222222
}
223223

@@ -1548,6 +1548,10 @@ where
15481548

15491549
/// Return the array’s data as a slice if it is contiguous,
15501550
/// return `None` otherwise.
1551+
///
1552+
/// In the contiguous case, in order to return a unique reference, this
1553+
/// method unshares the data if necessary, but it preserves the existing
1554+
/// strides.
15511555
pub fn as_slice_memory_order_mut(&mut self) -> Option<&mut [A]>
15521556
where
15531557
S: DataMut,

tests/array.rs

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -990,8 +990,8 @@ fn map1() {
990990
}
991991

992992
#[test]
993-
fn as_slice_memory_order() {
994-
// test that mutation breaks sharing
993+
fn as_slice_memory_order_mut_arcarray() {
994+
// Test that mutation breaks sharing for `ArcArray`.
995995
let a = rcarr2(&[[1., 2.], [3., 4.0f32]]);
996996
let mut b = a.clone();
997997
for elt in b.as_slice_memory_order_mut().unwrap() {
@@ -1000,6 +1000,38 @@ fn as_slice_memory_order() {
10001000
assert!(a != b, "{:?} != {:?}", a, b);
10011001
}
10021002

1003+
#[test]
1004+
fn as_slice_memory_order_mut_cowarray() {
1005+
// Test that mutation breaks sharing for `CowArray`.
1006+
let a = arr2(&[[1., 2.], [3., 4.0f32]]);
1007+
let mut b = CowArray::from(a.view());
1008+
for elt in b.as_slice_memory_order_mut().unwrap() {
1009+
*elt = 0.;
1010+
}
1011+
assert!(a != b, "{:?} != {:?}", a, b);
1012+
}
1013+
1014+
#[test]
1015+
fn as_slice_memory_order_mut_contiguous_arcarray() {
1016+
// Test that unsharing preserves the strides in the contiguous case for `ArcArray`.
1017+
let a = rcarr2(&[[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]).reversed_axes();
1018+
let mut b = a.clone().slice_move(s![.., ..2]);
1019+
assert_eq!(b.strides(), &[1, 2]);
1020+
b.as_slice_memory_order_mut().unwrap();
1021+
assert_eq!(b.strides(), &[1, 2]);
1022+
}
1023+
1024+
#[test]
1025+
fn as_slice_memory_order_mut_contiguous_cowarray() {
1026+
// Test that unsharing preserves the strides in the contiguous case for `CowArray`.
1027+
let a = arr2(&[[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]).reversed_axes();
1028+
let mut b = CowArray::from(a.slice(s![.., ..2]));
1029+
assert!(b.is_view());
1030+
assert_eq!(b.strides(), &[1, 2]);
1031+
b.as_slice_memory_order_mut().unwrap();
1032+
assert_eq!(b.strides(), &[1, 2]);
1033+
}
1034+
10031035
#[test]
10041036
fn array0_into_scalar() {
10051037
// With this kind of setup, the `Array`'s pointer is not the same as the
@@ -1788,6 +1820,33 @@ fn map_memory_order() {
17881820
assert_eq!(amap.strides(), v.strides());
17891821
}
17901822

1823+
#[test]
1824+
fn map_mut_with_unsharing() {
1825+
// Fortran-layout `ArcArray`.
1826+
let a = rcarr2(&[[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]).reversed_axes();
1827+
assert_eq!(a.shape(), &[2, 5]);
1828+
assert_eq!(a.strides(), &[1, 2]);
1829+
assert_eq!(
1830+
a.as_slice_memory_order(),
1831+
Some(&[0, 5, 1, 6, 2, 7, 3, 8, 4, 9][..])
1832+
);
1833+
1834+
// Shared reference of a portion of `a`.
1835+
let mut b = a.clone().slice_move(s![.., ..2]);
1836+
assert_eq!(b.shape(), &[2, 2]);
1837+
assert_eq!(b.strides(), &[1, 2]);
1838+
assert_eq!(b.as_slice_memory_order(), Some(&[0, 5, 1, 6][..]));
1839+
assert_eq!(b, array![[0, 1], [5, 6]]);
1840+
1841+
// `.map_mut()` unshares the data. Earlier versions of `ndarray` failed
1842+
// this assertion. See #1018.
1843+
assert_eq!(b.map_mut(|&mut x| x + 10), array![[10, 11], [15, 16]]);
1844+
1845+
// The strides should be preserved.
1846+
assert_eq!(b.shape(), &[2, 2]);
1847+
assert_eq!(b.strides(), &[1, 2]);
1848+
}
1849+
17911850
#[test]
17921851
fn test_view_from_shape() {
17931852
let s = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];

0 commit comments

Comments
 (0)