Skip to content

Commit c9f2937

Browse files
committed
Simplify unrolled_dot
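Simplify based on the new restriction: elements are now `Copy` (implied by the `LinalgScalar` bound), so the explicit `.clone()` calls and the separate `Add`/`Mul`/`Zero` bounds in `unrolled_dot` are no longer needed.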
1 parent 0e61653 commit c9f2937

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

src/numeric_util.rs

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@ use libnum;
33
use std::cmp;
44
use std::ops::{
55
Add,
6-
Mul,
76
};
87

8+
use linalg::LinalgScalar;
9+
910
/// Compute the sum of the values in `xs`
1011
pub fn unrolled_sum<A>(mut xs: &[A]) -> A
1112
where A: Clone + Add<Output=A> + libnum::Zero,
@@ -44,7 +45,7 @@ pub fn unrolled_sum<A>(mut xs: &[A]) -> A
4445
///
4546
/// `xs` and `ys` must be the same length
4647
pub fn unrolled_dot<A>(xs: &[A], ys: &[A]) -> A
47-
where A: Clone + Add<Output=A> + Mul<Output=A> + libnum::Zero,
48+
where A: LinalgScalar,
4849
{
4950
debug_assert_eq!(xs.len(), ys.len());
5051
// eightfold unrolled so that floating point can be vectorized
@@ -58,24 +59,24 @@ pub fn unrolled_dot<A>(xs: &[A], ys: &[A]) -> A
5859
(A::zero(), A::zero(), A::zero(), A::zero(),
5960
A::zero(), A::zero(), A::zero(), A::zero());
6061
while xs.len() >= 8 {
61-
p0 = p0 + xs[0].clone() * ys[0].clone();
62-
p1 = p1 + xs[1].clone() * ys[1].clone();
63-
p2 = p2 + xs[2].clone() * ys[2].clone();
64-
p3 = p3 + xs[3].clone() * ys[3].clone();
65-
p4 = p4 + xs[4].clone() * ys[4].clone();
66-
p5 = p5 + xs[5].clone() * ys[5].clone();
67-
p6 = p6 + xs[6].clone() * ys[6].clone();
68-
p7 = p7 + xs[7].clone() * ys[7].clone();
62+
p0 = p0 + xs[0] * ys[0];
63+
p1 = p1 + xs[1] * ys[1];
64+
p2 = p2 + xs[2] * ys[2];
65+
p3 = p3 + xs[3] * ys[3];
66+
p4 = p4 + xs[4] * ys[4];
67+
p5 = p5 + xs[5] * ys[5];
68+
p6 = p6 + xs[6] * ys[6];
69+
p7 = p7 + xs[7] * ys[7];
6970

7071
xs = &xs[8..];
7172
ys = &ys[8..];
7273
}
73-
sum = sum.clone() + (p0 + p4);
74-
sum = sum.clone() + (p1 + p5);
75-
sum = sum.clone() + (p2 + p6);
76-
sum = sum.clone() + (p3 + p7);
74+
sum = sum + (p0 + p4);
75+
sum = sum + (p1 + p5);
76+
sum = sum + (p2 + p6);
77+
sum = sum + (p3 + p7);
7778
for i in 0..xs.len() {
78-
sum = sum.clone() + xs[i].clone() * ys[i].clone();
79+
sum = sum + xs[i] * ys[i];
7980
}
8081
sum
8182
}

0 commit comments

Comments
 (0)