Skip to content

Commit 37645bd

Browse files
committed
Guarantee that as_slice_memory_order_mut preserves strides
This fixes bugs in `.map_mut()` and `.zip_mut_with_same_shape()`. Before this commit, strides obtained before calling `.as_slice_memory_order_mut()` could not be used to correctly interpret the data in the returned slice. Now, the strides are preserved, so the implementations of `.map_mut()` and `.zip_mut_with_same_shape()` work correctly. This also makes it much easier for users of the crate to use `.as_slice_memory_order_mut()` correctly in generic code. Fixes #1018.
1 parent 15b0808 commit 37645bd

File tree

3 files changed

+50
-12
lines changed

3 files changed

+50
-12
lines changed

src/data_traits.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,11 @@ pub unsafe trait RawData: Sized {
5050
pub unsafe trait RawDataMut: RawData {
5151
/// If possible, ensures that the array has unique access to its data.
5252
///
53-
/// If `Self` provides safe mutable access to array elements, then it
54-
/// **must** panic or ensure that the data is unique.
53+
/// The implementer must ensure that if the input is contiguous, then the
54+
/// output has the same strides as input.
55+
///
56+
/// Additionally, if `Self` provides safe mutable access to array elements,
57+
/// then this method **must** panic or ensure that the data is unique.
5558
#[doc(hidden)]
5659
fn try_ensure_unique<D>(_: &mut ArrayBase<Self, D>)
5760
where
@@ -230,14 +233,9 @@ where
230233
return;
231234
}
232235
if self_.dim.size() <= self_.data.0.len() / 2 {
233-
// Create a new vec if the current view is less than half of
234-
// backing data.
235-
unsafe {
236-
*self_ = ArrayBase::from_shape_vec_unchecked(
237-
self_.dim.clone(),
238-
self_.iter().cloned().collect(),
239-
);
240-
}
236+
// Clone only the visible elements if the current view is less than
237+
// half of backing data.
238+
*self_ = self_.to_owned().into_shared();
241239
return;
242240
}
243241
let rcvec = &mut self_.data.0;

src/impl_methods.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1536,6 +1536,10 @@ where
15361536

15371537
/// Return the array’s data as a slice if it is contiguous,
15381538
/// return `None` otherwise.
1539+
///
1540+
/// In the contiguous case, in order to return a unique reference, this
1541+
/// method unshares the data if necessary, but it preserves the existing
1542+
/// strides.
15391543
pub fn as_slice_memory_order_mut(&mut self) -> Option<&mut [A]>
15401544
where
15411545
S: DataMut,

tests/array.rs

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -990,8 +990,8 @@ fn map1() {
990990
}
991991

992992
#[test]
993-
fn as_slice_memory_order() {
994-
// test that mutation breaks sharing
993+
fn as_slice_memory_order_mut_arcarray() {
994+
// Test that mutation breaks sharing for `ArcArray`.
995995
let a = rcarr2(&[[1., 2.], [3., 4.0f32]]);
996996
let mut b = a.clone();
997997
for elt in b.as_slice_memory_order_mut().unwrap() {
@@ -1000,6 +1000,38 @@ fn as_slice_memory_order() {
10001000
assert!(a != b, "{:?} != {:?}", a, b);
10011001
}
10021002

1003+
#[test]
1004+
fn as_slice_memory_order_mut_cowarray() {
1005+
// Test that mutation breaks sharing for `CowArray`.
1006+
let a = arr2(&[[1., 2.], [3., 4.0f32]]);
1007+
let mut b = CowArray::from(a.view());
1008+
for elt in b.as_slice_memory_order_mut().unwrap() {
1009+
*elt = 0.;
1010+
}
1011+
assert!(a != b, "{:?} != {:?}", a, b);
1012+
}
1013+
1014+
#[test]
1015+
fn as_slice_memory_order_mut_contiguous_arcarray() {
1016+
// Test that unsharing preserves the strides in the contiguous case for `ArcArray`.
1017+
let a = rcarr2(&[[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]).reversed_axes();
1018+
let mut b = a.clone().slice_move(s![.., ..2]);
1019+
assert_eq!(b.strides(), &[1, 2]);
1020+
b.as_slice_memory_order_mut().unwrap();
1021+
assert_eq!(b.strides(), &[1, 2]);
1022+
}
1023+
1024+
#[test]
1025+
fn as_slice_memory_order_mut_contiguous_cowarray() {
1026+
// Test that unsharing preserves the strides in the contiguous case for `CowArray`.
1027+
let a = arr2(&[[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]).reversed_axes();
1028+
let mut b = CowArray::from(a.slice(s![.., ..2]));
1029+
assert!(b.is_view());
1030+
assert_eq!(b.strides(), &[1, 2]);
1031+
b.as_slice_memory_order_mut().unwrap();
1032+
assert_eq!(b.strides(), &[1, 2]);
1033+
}
1034+
10031035
#[test]
10041036
fn array0_into_scalar() {
10051037
// With this kind of setup, the `Array`'s pointer is not the same as the
@@ -1809,6 +1841,10 @@ fn map_mut_with_unsharing() {
18091841
// `.map_mut()` unshares the data. Earlier versions of `ndarray` failed
18101842
// this assertion. See #1018.
18111843
assert_eq!(b.map_mut(|&mut x| x + 10), array![[10, 11], [15, 16]]);
1844+
1845+
// The strides should be preserved.
1846+
assert_eq!(b.shape(), &[2, 2]);
1847+
assert_eq!(b.strides(), &[1, 2]);
18121848
}
18131849

18141850
#[test]

0 commit comments

Comments
 (0)