Skip to content

Implement real/imag splitting of arrays #1107

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Nov 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions src/impl_raw_views.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use num_complex::Complex;
use std::mem;
use std::ptr::NonNull;

Expand Down Expand Up @@ -149,6 +150,73 @@ where
}
}

impl<T, D> RawArrayView<Complex<T>, D>
where
D: Dimension,
{
/// Splits the view into views of the real and imaginary components of the
/// elements.
pub fn split_re_im(self) -> Complex<RawArrayView<T, D>> {
// Check that the size and alignment of `Complex<T>` are as expected.
// These assertions should always pass, for arbitrary `T`.
assert_eq!(
mem::size_of::<Complex<T>>(),
mem::size_of::<T>().checked_mul(2).unwrap()
);
assert_eq!(mem::align_of::<Complex<T>>(), mem::align_of::<T>());

let dim = self.dim.clone();

// Double the strides. In the zero-sized element case and for axes of
// length <= 1, we leave the strides as-is to avoid possible overflow.
let mut strides = self.strides.clone();
if mem::size_of::<T>() != 0 {
for ax in 0..strides.ndim() {
if dim[ax] > 1 {
strides[ax] = (strides[ax] as isize * 2) as usize;
}
}
}

let ptr_re: *mut T = self.ptr.as_ptr().cast();
let ptr_im: *mut T = if self.is_empty() {
// In the empty case, we can just reuse the existing pointer since
// it won't be dereferenced anyway. It is not safe to offset by
// one, since the allocation may be empty.
ptr_re
} else {
// In the nonempty case, we can safely offset into the first
// (complex) element.
unsafe { ptr_re.add(1) }
};

// `Complex` is `repr(C)` with only fields `re: T` and `im: T`. So, the
// real components of the elements start at the same pointer, and the
// imaginary components start at the pointer offset by one, with
// exactly double the strides. The new, doubled strides still meet the
// overflow constraints:
//
// - For the zero-sized element case, the strides are unchanged in
// units of bytes and in units of the element type.
//
// - For the nonzero-sized element case:
//
// - In units of bytes, the strides are unchanged. The only exception
// is axes of length <= 1, but those strides are irrelevant anyway.
//
// - Since `Complex<T>` for nonzero `T` is always at least 2 bytes,
// and the original strides did not overflow in units of bytes, we
// know that the new, doubled strides will not overflow in units of
// `T`.
unsafe {
Complex {
re: RawArrayView::new_(ptr_re, dim.clone(), strides.clone()),
im: RawArrayView::new_(ptr_im, dim, strides),
}
}
}
}

impl<A, D> RawArrayViewMut<A, D>
where
D: Dimension,
Expand Down Expand Up @@ -300,3 +368,20 @@ where
unsafe { RawArrayViewMut::new(ptr, self.dim, self.strides) }
}
}

impl<T, D> RawArrayViewMut<Complex<T>, D>
where
D: Dimension,
{
/// Splits the view into views of the real and imaginary components of the
/// elements.
pub fn split_re_im(self) -> Complex<RawArrayViewMut<T, D>> {
let Complex { re, im } = self.into_raw_view().split_re_im();
unsafe {
Complex {
re: RawArrayViewMut::new(re.ptr, re.dim, re.strides),
im: RawArrayViewMut::new(im.ptr, im.dim, im.strides),
}
}
}
}
70 changes: 70 additions & 0 deletions src/impl_views/splitting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

use crate::imp_prelude::*;
use crate::slice::MultiSliceArg;
use num_complex::Complex;

/// Methods for read-only array views.
impl<'a, A, D> ArrayView<'a, A, D>
Expand Down Expand Up @@ -95,6 +96,37 @@ where
}
}

impl<'a, T, D> ArrayView<'a, Complex<T>, D>
where
D: Dimension,
{
/// Splits the view into views of the real and imaginary components of the
/// elements.
///
/// ```
/// use ndarray::prelude::*;
/// use num_complex::{Complex, Complex64};
///
/// let arr = array![
/// [Complex64::new(1., 2.), Complex64::new(3., 4.)],
/// [Complex64::new(5., 6.), Complex64::new(7., 8.)],
/// [Complex64::new(9., 10.), Complex64::new(11., 12.)],
/// ];
/// let Complex { re, im } = arr.view().split_re_im();
/// assert_eq!(re, array![[1., 3.], [5., 7.], [9., 11.]]);
/// assert_eq!(im, array![[2., 4.], [6., 8.], [10., 12.]]);
/// ```
pub fn split_re_im(self) -> Complex<ArrayView<'a, T, D>> {
unsafe {
let Complex { re, im } = self.into_raw_view().split_re_im();
Complex {
re: re.deref_into_view(),
im: im.deref_into_view(),
}
}
}
}

/// Methods for read-write array views.
impl<'a, A, D> ArrayViewMut<'a, A, D>
where
Expand Down Expand Up @@ -135,3 +167,41 @@ where
info.multi_slice_move(self)
}
}

impl<'a, T, D> ArrayViewMut<'a, Complex<T>, D>
where
D: Dimension,
{
/// Splits the view into views of the real and imaginary components of the
/// elements.
///
/// ```
/// use ndarray::prelude::*;
/// use num_complex::{Complex, Complex64};
///
/// let mut arr = array![
/// [Complex64::new(1., 2.), Complex64::new(3., 4.)],
/// [Complex64::new(5., 6.), Complex64::new(7., 8.)],
/// [Complex64::new(9., 10.), Complex64::new(11., 12.)],
/// ];
///
/// let Complex { mut re, mut im } = arr.view_mut().split_re_im();
/// assert_eq!(re, array![[1., 3.], [5., 7.], [9., 11.]]);
/// assert_eq!(im, array![[2., 4.], [6., 8.], [10., 12.]]);
///
/// re[[0, 1]] = 13.;
/// im[[2, 0]] = 14.;
///
/// assert_eq!(arr[[0, 1]], Complex64::new(13., 4.));
/// assert_eq!(arr[[2, 0]], Complex64::new(9., 14.));
/// ```
pub fn split_re_im(self) -> Complex<ArrayViewMut<'a, T, D>> {
unsafe {
let Complex { re, im } = self.into_raw_view_mut().split_re_im();
Complex {
re: re.deref_into_view_mut(),
im: im.deref_into_view_mut(),
}
}
}
}
70 changes: 70 additions & 0 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
clippy::float_cmp
)]

use approx::assert_relative_eq;
use defmac::defmac;
use itertools::{zip, Itertools};
use ndarray::prelude::*;
use ndarray::{arr3, rcarr2};
use ndarray::indices;
use ndarray::{Slice, SliceInfo, SliceInfoElem};
use num_complex::Complex;
use std::convert::TryFrom;

macro_rules! assert_panics {
Expand Down Expand Up @@ -2501,3 +2503,71 @@ fn test_remove_index_oob3() {
let mut a = array![[10], [4], [1]];
a.remove_index(Axis(2), 0);
}

#[test]
fn test_split_re_im_view() {
let a = Array3::from_shape_fn((3, 4, 5), |(i, j, k)| {
Complex::<f32>::new(i as f32 * j as f32, k as f32)
});
let Complex { re, im } = a.view().split_re_im();
assert_relative_eq!(re.sum(), 90.);
assert_relative_eq!(im.sum(), 120.);
}

#[test]
fn test_split_re_im_view_roundtrip() {
let a_re = Array3::from_shape_fn((3,1,5), |(i, j, _k)| {
i * j
});
let a_im = Array3::from_shape_fn((3,1,5), |(_i, _j, k)| {
k
});
let a = Array3::from_shape_fn((3,1,5), |(i,j,k)| {
Complex::new(a_re[[i,j,k]], a_im[[i,j,k]])
});
let Complex { re, im } = a.view().split_re_im();
assert_eq!(a_re, re);
assert_eq!(a_im, im);
}

#[test]
fn test_split_re_im_view_mut() {
let eye_scalar = Array2::<u32>::eye(4);
let eye_complex = Array2::<Complex<u32>>::eye(4);
let mut a = Array2::<Complex<u32>>::zeros((4, 4));
let Complex { mut re, im } = a.view_mut().split_re_im();
re.assign(&eye_scalar);
assert_eq!(im.sum(), 0);
assert_eq!(a, eye_complex);
}

#[test]
fn test_split_re_im_zerod() {
let mut a = Array0::from_elem((), Complex::new(42, 32));
let Complex { re, im } = a.view().split_re_im();
assert_eq!(re.get(()), Some(&42));
assert_eq!(im.get(()), Some(&32));
let cmplx = a.view_mut().split_re_im();
cmplx.re.assign_to(cmplx.im);
assert_eq!(a.get(()).unwrap().im, 42);
}

#[test]
fn test_split_re_im_permuted() {
let a = Array3::from_shape_fn((3, 4, 5), |(i, j, k)| {
Complex::new(i * k + j, k)
});
let permuted = a.view().permuted_axes([1,0,2]);
let Complex { re, im } = permuted.split_re_im();
assert_eq!(re.get((3,2,4)).unwrap(), &11);
assert_eq!(im.get((3,2,4)).unwrap(), &4);
}

#[test]
fn test_split_re_im_invert_axis() {
let mut a = Array::from_shape_fn((2, 3, 2), |(i, j, k)| Complex::new(i as f64 + j as f64, i as f64 + k as f64));
a.invert_axis(Axis(1));
let cmplx = a.view().split_re_im();
assert_eq!(cmplx.re, a.mapv(|z| z.re));
assert_eq!(cmplx.im, a.mapv(|z| z.im));
}