diff --git a/src/impl_raw_views.rs b/src/impl_raw_views.rs index 90cfa6376..425af96c2 100644 --- a/src/impl_raw_views.rs +++ b/src/impl_raw_views.rs @@ -1,3 +1,4 @@ +use num_complex::Complex; use std::mem; use std::ptr::NonNull; @@ -149,6 +150,73 @@ where } } +impl RawArrayView, D> +where + D: Dimension, +{ + /// Splits the view into views of the real and imaginary components of the + /// elements. + pub fn split_re_im(self) -> Complex> { + // Check that the size and alignment of `Complex` are as expected. + // These assertions should always pass, for arbitrary `T`. + assert_eq!( + mem::size_of::>(), + mem::size_of::().checked_mul(2).unwrap() + ); + assert_eq!(mem::align_of::>(), mem::align_of::()); + + let dim = self.dim.clone(); + + // Double the strides. In the zero-sized element case and for axes of + // length <= 1, we leave the strides as-is to avoid possible overflow. + let mut strides = self.strides.clone(); + if mem::size_of::() != 0 { + for ax in 0..strides.ndim() { + if dim[ax] > 1 { + strides[ax] = (strides[ax] as isize * 2) as usize; + } + } + } + + let ptr_re: *mut T = self.ptr.as_ptr().cast(); + let ptr_im: *mut T = if self.is_empty() { + // In the empty case, we can just reuse the existing pointer since + // it won't be dereferenced anyway. It is not safe to offset by + // one, since the allocation may be empty. + ptr_re + } else { + // In the nonempty case, we can safely offset into the first + // (complex) element. + unsafe { ptr_re.add(1) } + }; + + // `Complex` is `repr(C)` with only fields `re: T` and `im: T`. So, the + // real components of the elements start at the same pointer, and the + // imaginary components start at the pointer offset by one, with + // exactly double the strides. The new, doubled strides still meet the + // overflow constraints: + // + // - For the zero-sized element case, the strides are unchanged in + // units of bytes and in units of the element type. + // + // - For the nonzero-sized element case: + // + // - In units of bytes, the strides are unchanged. The only exception + // is axes of length <= 1, but those strides are irrelevant anyway. + // + // - Since `Complex` for nonzero `T` is always at least 2 bytes, + // and the original strides did not overflow in units of bytes, we + // know that the new, doubled strides will not overflow in units of + // `T`. + unsafe { + Complex { + re: RawArrayView::new_(ptr_re, dim.clone(), strides.clone()), + im: RawArrayView::new_(ptr_im, dim, strides), + } + } + } +} + impl RawArrayViewMut where D: Dimension, @@ -300,3 +368,20 @@ where unsafe { RawArrayViewMut::new(ptr, self.dim, self.strides) } } } + +impl RawArrayViewMut, D> +where + D: Dimension, +{ + /// Splits the view into views of the real and imaginary components of the + /// elements. + pub fn split_re_im(self) -> Complex> { + let Complex { re, im } = self.into_raw_view().split_re_im(); + unsafe { + Complex { + re: RawArrayViewMut::new(re.ptr, re.dim, re.strides), + im: RawArrayViewMut::new(im.ptr, im.dim, im.strides), + } + } + } +} diff --git a/src/impl_views/splitting.rs b/src/impl_views/splitting.rs index f2c0dc41a..84d24038a 100644 --- a/src/impl_views/splitting.rs +++ b/src/impl_views/splitting.rs @@ -8,6 +8,7 @@ use crate::imp_prelude::*; use crate::slice::MultiSliceArg; +use num_complex::Complex; /// Methods for read-only array views. impl<'a, A, D> ArrayView<'a, A, D> @@ -95,6 +96,37 @@ where } } +impl<'a, T, D> ArrayView<'a, Complex, D> +where + D: Dimension, +{ + /// Splits the view into views of the real and imaginary components of the + /// elements. + /// + /// ``` + /// use ndarray::prelude::*; + /// use num_complex::{Complex, Complex64}; + /// + /// let arr = array![ + /// [Complex64::new(1., 2.), Complex64::new(3., 4.)], + /// [Complex64::new(5., 6.), Complex64::new(7., 8.)], + /// [Complex64::new(9., 10.), Complex64::new(11., 12.)], + /// ]; + /// let Complex { re, im } = arr.view().split_re_im(); + /// assert_eq!(re, array![[1., 3.], [5., 7.], [9., 11.]]); + /// assert_eq!(im, array![[2., 4.], [6., 8.], [10., 12.]]); + /// ``` + pub fn split_re_im(self) -> Complex> { + unsafe { + let Complex { re, im } = self.into_raw_view().split_re_im(); + Complex { + re: re.deref_into_view(), + im: im.deref_into_view(), + } + } + } +} + /// Methods for read-write array views. impl<'a, A, D> ArrayViewMut<'a, A, D> where @@ -135,3 +167,41 @@ where info.multi_slice_move(self) } } + +impl<'a, T, D> ArrayViewMut<'a, Complex, D> +where + D: Dimension, +{ + /// Splits the view into views of the real and imaginary components of the + /// elements. + /// + /// ``` + /// use ndarray::prelude::*; + /// use num_complex::{Complex, Complex64}; + /// + /// let mut arr = array![ + /// [Complex64::new(1., 2.), Complex64::new(3., 4.)], + /// [Complex64::new(5., 6.), Complex64::new(7., 8.)], + /// [Complex64::new(9., 10.), Complex64::new(11., 12.)], + /// ]; + /// + /// let Complex { mut re, mut im } = arr.view_mut().split_re_im(); + /// assert_eq!(re, array![[1., 3.], [5., 7.], [9., 11.]]); + /// assert_eq!(im, array![[2., 4.], [6., 8.], [10., 12.]]); + /// + /// re[[0, 1]] = 13.; + /// im[[2, 0]] = 14.; + /// + /// assert_eq!(arr[[0, 1]], Complex64::new(13., 4.)); + /// assert_eq!(arr[[2, 0]], Complex64::new(9., 14.)); + /// ``` + pub fn split_re_im(self) -> Complex> { + unsafe { + let Complex { re, im } = self.into_raw_view_mut().split_re_im(); + Complex { + re: re.deref_into_view_mut(), + im: im.deref_into_view_mut(), + } + } + } +} diff --git a/tests/array.rs b/tests/array.rs index d0fc67def..16d901568 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -7,12 +7,14 @@ clippy::float_cmp )] +use approx::assert_relative_eq; use defmac::defmac; use itertools::{zip, Itertools}; use ndarray::prelude::*; use ndarray::{arr3, rcarr2}; use ndarray::indices; use ndarray::{Slice, SliceInfo, SliceInfoElem}; +use num_complex::Complex; use std::convert::TryFrom; macro_rules! assert_panics { @@ -2501,3 +2503,71 @@ fn test_remove_index_oob3() { let mut a = array![[10], [4], [1]]; a.remove_index(Axis(2), 0); } + +#[test] +fn test_split_re_im_view() { + let a = Array3::from_shape_fn((3, 4, 5), |(i, j, k)| { + Complex::::new(i as f32 * j as f32, k as f32) + }); + let Complex { re, im } = a.view().split_re_im(); + assert_relative_eq!(re.sum(), 90.); + assert_relative_eq!(im.sum(), 120.); +} + +#[test] +fn test_split_re_im_view_roundtrip() { + let a_re = Array3::from_shape_fn((3,1,5), |(i, j, _k)| { + i * j + }); + let a_im = Array3::from_shape_fn((3,1,5), |(_i, _j, k)| { + k + }); + let a = Array3::from_shape_fn((3,1,5), |(i,j,k)| { + Complex::new(a_re[[i,j,k]], a_im[[i,j,k]]) + }); + let Complex { re, im } = a.view().split_re_im(); + assert_eq!(a_re, re); + assert_eq!(a_im, im); +} + +#[test] +fn test_split_re_im_view_mut() { + let eye_scalar = Array2::::eye(4); + let eye_complex = Array2::>::eye(4); + let mut a = Array2::>::zeros((4, 4)); + let Complex { mut re, im } = a.view_mut().split_re_im(); + re.assign(&eye_scalar); + assert_eq!(im.sum(), 0); + assert_eq!(a, eye_complex); +} + +#[test] +fn test_split_re_im_zerod() { + let mut a = Array0::from_elem((), Complex::new(42, 32)); + let Complex { re, im } = a.view().split_re_im(); + assert_eq!(re.get(()), Some(&42)); + assert_eq!(im.get(()), Some(&32)); + let cmplx = a.view_mut().split_re_im(); + cmplx.re.assign_to(cmplx.im); + assert_eq!(a.get(()).unwrap().im, 42); +} + +#[test] +fn test_split_re_im_permuted() { + let a = Array3::from_shape_fn((3, 4, 5), |(i, j, k)| { + Complex::new(i * k + j, k) + }); + let permuted = a.view().permuted_axes([1,0,2]); + let Complex { re, im } = permuted.split_re_im(); + assert_eq!(re.get((3,2,4)).unwrap(), &11); + assert_eq!(im.get((3,2,4)).unwrap(), &4); +} + +#[test] +fn test_split_re_im_invert_axis() { + let mut a = Array::from_shape_fn((2, 3, 2), |(i, j, k)| Complex::new(i as f64 + j as f64, i as f64 + k as f64)); + a.invert_axis(Axis(1)); + let cmplx = a.view().split_re_im(); + assert_eq!(cmplx.re, a.mapv(|z| z.re)); + assert_eq!(cmplx.im, a.mapv(|z| z.im)); +}