diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 42666a894..6e3eeaf58 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -1694,6 +1694,34 @@ impl ArrayBase where S: Data, D: Dimension } } + /// Call `f` on a mutable reference of each element and create a new array + /// with the new values. + /// + /// Elements are visited in arbitrary order. + /// + /// Return an array with the same shape as `self`. + pub fn map_mut<'a, B, F>(&'a mut self, f: F) -> Array + where F: FnMut(&'a mut A) -> B, + A: 'a, + S: DataMut + { + let dim = self.dim.clone(); + if self.is_contiguous() { + let strides = self.strides.clone(); + let slc = self.as_slice_memory_order_mut().unwrap(); + let v = ::iterators::to_vec_mapped(slc.iter_mut(), f); + unsafe { + ArrayBase::from_shape_vec_unchecked( + dim.strides(strides), v) + } + } else { + let v = ::iterators::to_vec_mapped(self.iter_mut(), f); + unsafe { + ArrayBase::from_shape_vec_unchecked(dim, v) + } + } + } + /// Call `f` by **v**alue on each element and create a new array /// with the new values. /// @@ -1819,4 +1847,32 @@ impl ArrayBase where S: Data, D: Dimension } }) } + + /// Reduce the values along an axis into just one value, producing a new + /// array with one less dimension. + /// 1-dimensional lanes are passed as mutable references to the reducer, + /// allowing for side-effects. + /// + /// Elements are visited in arbitrary order. + /// + /// Return the result as an `Array`. + /// + /// **Panics** if `axis` is out of bounds. + pub fn map_axis_mut<'a, B, F>(&'a mut self, axis: Axis, mut mapping: F) + -> Array + where D: RemoveAxis, + F: FnMut(ArrayViewMut1<'a, A>) -> B, + A: 'a, + S: DataMut, + { + let view_len = self.len_of(axis); + let view_stride = self.strides.axis(axis); + // use the 0th subview as a map to each 1d array view extended from + // the 0th element. + self.subview_mut(axis, 0).map_mut(|first_elt: &mut A| { + unsafe { + mapping(ArrayViewMut::new_(first_elt, Ix1(view_len), Ix1(view_stride))) + } + }) + } } diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index 4ed5d99be..984cc2179 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -1180,9 +1180,11 @@ use indexes::IndicesIterF; unsafe impl TrustedIterator for Linspace { } unsafe impl<'a, A, D> TrustedIterator for Iter<'a, A, D> { } +unsafe impl<'a, A, D> TrustedIterator for IterMut<'a, A, D> { } unsafe impl TrustedIterator for std::iter::Map where I: TrustedIterator { } unsafe impl<'a, A> TrustedIterator for slice::Iter<'a, A> { } +unsafe impl<'a, A> TrustedIterator for slice::IterMut<'a, A> { } unsafe impl TrustedIterator for ::std::ops::Range { } // FIXME: These indices iter are dubious -- size needs to be checked up front. unsafe impl TrustedIterator for IndicesIter where D: Dimension { }