Skip to content

Commit 1e88385

Browse files
Merge pull request #4 from jturner314/pairwise-summation
Improve pairwise summation
2 parents 4a63cb3 + 1ed1a63 commit 1e88385

File tree

8 files changed

+168
-151
lines changed

8 files changed

+168
-151
lines changed

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,8 @@ serde = { version = "1.0", optional = true }
4545

4646
[dev-dependencies]
4747
defmac = "0.2"
48-
quickcheck = { version = "0.7.2", default-features = false }
48+
quickcheck = { version = "0.8.1", default-features = false }
49+
quickcheck_macros = "0.8"
4950
rawpointer = "0.1"
5051
rand = "0.5.5"
5152

src/dimension/dimension_trait.rs

Lines changed: 23 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -291,8 +291,8 @@ pub trait Dimension : Clone + Eq + Debug + Send + Sync + Default +
291291
indices
292292
}
293293

294-
/// Compute the minimum stride axis (absolute value), under the constraint
295-
/// that the length of the axis is > 1;
294+
/// Compute the minimum stride axis (absolute value), preferring axes with
295+
/// length > 1.
296296
#[doc(hidden)]
297297
fn min_stride_axis(&self, strides: &Self) -> Axis {
298298
let n = match self.ndim() {
@@ -301,7 +301,7 @@ pub trait Dimension : Clone + Eq + Debug + Send + Sync + Default +
301301
n => n,
302302
};
303303
axes_of(self, strides)
304-
.rev()
304+
.filter(|ax| ax.len() > 1)
305305
.min_by_key(|ax| ax.stride().abs())
306306
.map_or(Axis(n - 1), |ax| ax.axis())
307307
}
@@ -588,9 +588,9 @@ impl Dimension for Dim<[Ix; 2]> {
588588

589589
#[inline]
590590
fn min_stride_axis(&self, strides: &Self) -> Axis {
591-
let s = get!(strides, 0) as Ixs;
592-
let t = get!(strides, 1) as Ixs;
593-
if s.abs() < t.abs() {
591+
let s = (get!(strides, 0) as isize).abs();
592+
let t = (get!(strides, 1) as isize).abs();
593+
if s < t && get!(self, 0) > 1 {
594594
Axis(0)
595595
} else {
596596
Axis(1)
@@ -697,6 +697,23 @@ impl Dimension for Dim<[Ix; 3]> {
697697
Some(Ix3(i, j, k))
698698
}
699699

700+
#[inline]
701+
fn min_stride_axis(&self, strides: &Self) -> Axis {
702+
let s = (get!(strides, 0) as isize).abs();
703+
let t = (get!(strides, 1) as isize).abs();
704+
let u = (get!(strides, 2) as isize).abs();
705+
let (argmin, min) = if t < u && get!(self, 1) > 1 {
706+
(Axis(1), t)
707+
} else {
708+
(Axis(2), u)
709+
};
710+
if s < min && get!(self, 0) > 1 {
711+
Axis(0)
712+
} else {
713+
argmin
714+
}
715+
}
716+
700717
/// Self is an index, return the stride offset
701718
#[inline]
702719
fn stride_offset(index: &Self, strides: &Self) -> isize {

src/dimension/mod.rs

Lines changed: 78 additions & 79 deletions
Original file line number | Diff line number | Diff line change
@@ -629,7 +629,8 @@ mod test {
629629
use crate::error::{from_kind, ErrorKind};
630630
use crate::slice::Slice;
631631
use num_integer::gcd;
632-
use quickcheck::{quickcheck, TestResult};
632+
use quickcheck::TestResult;
633+
use quickcheck_macros::quickcheck;
633634

634635
#[test]
635636
fn slice_indexing_uncommon_strides() {
@@ -738,30 +739,29 @@ mod test {
738739
can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(2, 1)).unwrap_err();
739740
}
740741

741-
quickcheck! {
742-
fn can_index_slice_not_custom_same_as_can_index_slice(data: Vec<u8>, dim: Vec<usize>) -> bool {
743-
let dim = IxDyn(&dim);
744-
let result = can_index_slice_not_custom(&data, &dim);
745-
if dim.size_checked().is_none() {
746-
// Avoid overflow `dim.default_strides()` or `dim.fortran_strides()`.
747-
result.is_err()
748-
} else {
749-
result == can_index_slice(&data, &dim, &dim.default_strides()) &&
750-
result == can_index_slice(&data, &dim, &dim.fortran_strides())
751-
}
742+
#[quickcheck]
743+
fn can_index_slice_not_custom_same_as_can_index_slice(data: Vec<u8>, dim: Vec<usize>) -> bool {
744+
let dim = IxDyn(&dim);
745+
let result = can_index_slice_not_custom(&data, &dim);
746+
if dim.size_checked().is_none() {
747+
// Avoid overflow `dim.default_strides()` or `dim.fortran_strides()`.
748+
result.is_err()
749+
} else {
750+
result == can_index_slice(&data, &dim, &dim.default_strides()) &&
751+
result == can_index_slice(&data, &dim, &dim.fortran_strides())
752752
}
753753
}
754754

755-
quickcheck! {
756-
fn extended_gcd_solves_eq(a: isize, b: isize) -> bool {
757-
let (g, (x, y)) = extended_gcd(a, b);
758-
a * x + b * y == g
759-
}
755+
#[quickcheck]
756+
fn extended_gcd_solves_eq(a: isize, b: isize) -> bool {
757+
let (g, (x, y)) = extended_gcd(a, b);
758+
a * x + b * y == g
759+
}
760760

761-
fn extended_gcd_correct_gcd(a: isize, b: isize) -> bool {
762-
let (g, _) = extended_gcd(a, b);
763-
g == gcd(a, b)
764-
}
761+
#[quickcheck]
762+
fn extended_gcd_correct_gcd(a: isize, b: isize) -> bool {
763+
let (g, _) = extended_gcd(a, b);
764+
g == gcd(a, b)
765765
}
766766

767767
#[test]
@@ -773,73 +773,72 @@ mod test {
773773
assert_eq!(extended_gcd(-5, 0), (5, (-1, 0)));
774774
}
775775

776-
quickcheck! {
777-
fn solve_linear_diophantine_eq_solution_existence(
778-
a: isize, b: isize, c: isize
779-
) -> TestResult {
780-
if a == 0 || b == 0 {
781-
TestResult::discard()
782-
} else {
783-
TestResult::from_bool(
784-
(c % gcd(a, b) == 0) == solve_linear_diophantine_eq(a, b, c).is_some()
785-
)
786-
}
776+
#[quickcheck]
777+
fn solve_linear_diophantine_eq_solution_existence(
778+
a: isize, b: isize, c: isize
779+
) -> TestResult {
780+
if a == 0 || b == 0 {
781+
TestResult::discard()
782+
} else {
783+
TestResult::from_bool(
784+
(c % gcd(a, b) == 0) == solve_linear_diophantine_eq(a, b, c).is_some()
785+
)
787786
}
787+
}
788788

789-
fn solve_linear_diophantine_eq_correct_solution(
790-
a: isize, b: isize, c: isize, t: isize
791-
) -> TestResult {
792-
if a == 0 || b == 0 {
793-
TestResult::discard()
794-
} else {
795-
match solve_linear_diophantine_eq(a, b, c) {
796-
Some((x0, xd)) => {
797-
let x = x0 + xd * t;
798-
let y = (c - a * x) / b;
799-
TestResult::from_bool(a * x + b * y == c)
800-
}
801-
None => TestResult::discard(),
789+
#[quickcheck]
790+
fn solve_linear_diophantine_eq_correct_solution(
791+
a: isize, b: isize, c: isize, t: isize
792+
) -> TestResult {
793+
if a == 0 || b == 0 {
794+
TestResult::discard()
795+
} else {
796+
match solve_linear_diophantine_eq(a, b, c) {
797+
Some((x0, xd)) => {
798+
let x = x0 + xd * t;
799+
let y = (c - a * x) / b;
800+
TestResult::from_bool(a * x + b * y == c)
802801
}
802+
None => TestResult::discard(),
803803
}
804804
}
805805
}
806806

807-
quickcheck! {
808-
fn arith_seq_intersect_correct(
809-
first1: isize, len1: isize, step1: isize,
810-
first2: isize, len2: isize, step2: isize
811-
) -> TestResult {
812-
use std::cmp;
807+
#[quickcheck]
808+
fn arith_seq_intersect_correct(
809+
first1: isize, len1: isize, step1: isize,
810+
first2: isize, len2: isize, step2: isize
811+
) -> TestResult {
812+
use std::cmp;
813813

814-
if len1 == 0 || len2 == 0 {
815-
// This case is impossible to reach in `arith_seq_intersect()`
816-
// because the `min*` and `max*` arguments are inclusive.
817-
return TestResult::discard();
818-
}
819-
let len1 = len1.abs();
820-
let len2 = len2.abs();
821-
822-
// Convert to `min*` and `max*` arguments for `arith_seq_intersect()`.
823-
let last1 = first1 + step1 * (len1 - 1);
824-
let (min1, max1) = (cmp::min(first1, last1), cmp::max(first1, last1));
825-
let last2 = first2 + step2 * (len2 - 1);
826-
let (min2, max2) = (cmp::min(first2, last2), cmp::max(first2, last2));
827-
828-
// Naively determine if the sequences intersect.
829-
let seq1: Vec<_> = (0..len1)
830-
.map(|n| first1 + step1 * n)
831-
.collect();
832-
let intersects = (0..len2)
833-
.map(|n| first2 + step2 * n)
834-
.any(|elem2| seq1.contains(&elem2));
835-
836-
TestResult::from_bool(
837-
arith_seq_intersect(
838-
(min1, max1, if step1 == 0 { 1 } else { step1 }),
839-
(min2, max2, if step2 == 0 { 1 } else { step2 })
840-
) == intersects
841-
)
814+
if len1 == 0 || len2 == 0 {
815+
// This case is impossible to reach in `arith_seq_intersect()`
816+
// because the `min*` and `max*` arguments are inclusive.
817+
return TestResult::discard();
842818
}
819+
let len1 = len1.abs();
820+
let len2 = len2.abs();
821+
822+
// Convert to `min*` and `max*` arguments for `arith_seq_intersect()`.
823+
let last1 = first1 + step1 * (len1 - 1);
824+
let (min1, max1) = (cmp::min(first1, last1), cmp::max(first1, last1));
825+
let last2 = first2 + step2 * (len2 - 1);
826+
let (min2, max2) = (cmp::min(first2, last2), cmp::max(first2, last2));
827+
828+
// Naively determine if the sequences intersect.
829+
let seq1: Vec<_> = (0..len1)
830+
.map(|n| first1 + step1 * n)
831+
.collect();
832+
let intersects = (0..len2)
833+
.map(|n| first2 + step2 * n)
834+
.any(|elem2| seq1.contains(&elem2));
835+
836+
TestResult::from_bool(
837+
arith_seq_intersect(
838+
(min1, max1, if step1 == 0 { 1 } else { step1 }),
839+
(min2, max2, if step2 == 0 { 1 } else { step2 })
840+
) == intersects
841+
)
843842
}
844843

845844
#[test]

src/impl_methods.rs

Lines changed: 47 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -1619,12 +1619,11 @@ where
16191619
axes_of(&self.dim, &self.strides)
16201620
}
16211621

1622-
/*
1623-
/// Return the axis with the least stride (by absolute value)
1624-
pub fn min_stride_axis(&self) -> Axis {
1622+
/// Return the axis with the least stride (by absolute value),
1623+
/// preferring axes with len > 1.
1624+
pub(crate) fn min_stride_axis(&self) -> Axis {
16251625
self.dim.min_stride_axis(&self.strides)
16261626
}
1627-
*/
16281627

16291628
/// Return the axis with the greatest stride (by absolute value),
16301629
/// preferring axes with len > 1.
@@ -1854,25 +1853,11 @@ where
18541853
} else {
18551854
let mut v = self.view();
18561855
// put the narrowest axis at the last position
1857-
match v.ndim() {
1858-
0 | 1 => {}
1859-
2 => {
1860-
if self.len_of(Axis(1)) <= 1
1861-
|| self.len_of(Axis(0)) > 1
1862-
&& self.stride_of(Axis(0)).abs() < self.stride_of(Axis(1)).abs()
1863-
{
1864-
v.swap_axes(0, 1);
1865-
}
1866-
}
1867-
n => {
1868-
let last = n - 1;
1869-
let narrow_axis = v
1870-
.axes()
1871-
.filter(|ax| ax.len() > 1)
1872-
.min_by_key(|ax| ax.stride().abs())
1873-
.map_or(last, |ax| ax.axis().index());
1874-
v.swap_axes(last, narrow_axis);
1875-
}
1856+
let n = v.ndim();
1857+
if n > 1 {
1858+
let last = n - 1;
1859+
let narrow_axis = self.min_stride_axis();
1860+
v.swap_axes(last, narrow_axis.index());
18761861
}
18771862
v.into_elements_base().fold(init, f)
18781863
}
@@ -2103,3 +2088,42 @@ where
21032088
})
21042089
}
21052090
}
2091+
2092+
#[cfg(test)]
2093+
mod tests {
2094+
use crate::prelude::*;
2095+
2096+
#[test]
2097+
fn min_stride_axis() {
2098+
let a = Array1::<u8>::zeros(10);
2099+
assert_eq!(a.min_stride_axis(), Axis(0));
2100+
2101+
let a = Array2::<u8>::zeros((3, 3));
2102+
assert_eq!(a.min_stride_axis(), Axis(1));
2103+
assert_eq!(a.t().min_stride_axis(), Axis(0));
2104+
2105+
let a = ArrayD::<u8>::zeros(vec![3, 3]);
2106+
assert_eq!(a.min_stride_axis(), Axis(1));
2107+
assert_eq!(a.t().min_stride_axis(), Axis(0));
2108+
2109+
let min_axis = a.axes().min_by_key(|t| t.2.abs()).unwrap().axis();
2110+
assert_eq!(min_axis, Axis(1));
2111+
2112+
let mut b = ArrayD::<u8>::zeros(vec![2, 3, 4, 5]);
2113+
assert_eq!(b.min_stride_axis(), Axis(3));
2114+
for ax in 0..3 {
2115+
b.swap_axes(3, ax);
2116+
assert_eq!(b.min_stride_axis(), Axis(ax));
2117+
b.swap_axes(3, ax);
2118+
}
2119+
let mut v = b.view();
2120+
v.collapse_axis(Axis(3), 0);
2121+
assert_eq!(v.min_stride_axis(), Axis(2));
2122+
2123+
let a = Array2::<u8>::zeros((3, 3));
2124+
let v = a.broadcast((8, 3, 3)).unwrap();
2125+
assert_eq!(v.min_stride_axis(), Axis(0));
2126+
let v2 = a.broadcast((1, 3, 3)).unwrap();
2127+
assert_eq!(v2.min_stride_axis(), Axis(2));
2128+
}
2129+
}

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -105,6 +105,8 @@ extern crate num_integer;
105105
#[cfg(test)]
106106
extern crate quickcheck;
107107
#[cfg(test)]
108+
extern crate quickcheck_macros;
109+
#[cfg(test)]
108110
extern crate rand;
109111

110112
#[cfg(feature = "docs")]

src/numeric/impl_numeric.rs

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -32,10 +32,17 @@ impl<A, S, D> ArrayBase<S, D>
3232
where A: Clone + Add<Output=A> + num_traits::Zero,
3333
{
3434
if let Some(slc) = self.as_slice_memory_order() {
35-
numeric_util::pairwise_sum(&slc)
36-
} else {
37-
numeric_util::iterator_pairwise_sum(self.iter())
35+
return numeric_util::pairwise_sum(&slc);
36+
}
37+
if self.ndim() > 1 {
38+
let ax = self.dim.min_stride_axis(&self.strides);
39+
if self.len_of(ax) >= numeric_util::UNROLL_SIZE && self.stride_of(ax) == 1 {
40+
let partial_sums: Vec<_> =
41+
self.lanes(ax).into_iter().map(|lane| lane.sum()).collect();
42+
return numeric_util::pure_pairwise_sum(&partial_sums);
43+
}
3844
}
45+
numeric_util::iterator_pairwise_sum(self.iter())
3946
}
4047

4148
/// Return the sum of all elements in the array.

0 commit comments

Comments
 (0)