@@ -10,6 +10,7 @@ use libnum;
10
10
use std:: cmp;
11
11
use std:: ops:: {
12
12
Add ,
13
+ Mul ,
13
14
} ;
14
15
15
16
use LinalgScalar ;
@@ -51,6 +52,43 @@ pub fn unrolled_sum<A>(mut xs: &[A]) -> A
51
52
sum
52
53
}
53
54
55
+ /// Compute the product of the values in `xs`
56
+ pub fn unrolled_prod < A > ( mut xs : & [ A ] ) -> A
57
+ where A : Clone + Mul < Output =A > + libnum:: One ,
58
+ {
59
+ // eightfold unrolled so that floating point can be vectorized
60
+ // (even with strict floating point accuracy semantics)
61
+ let mut prod = A :: one ( ) ;
62
+ let ( mut p0, mut p1, mut p2, mut p3,
63
+ mut p4, mut p5, mut p6, mut p7) =
64
+ ( A :: one ( ) , A :: one ( ) , A :: one ( ) , A :: one ( ) ,
65
+ A :: one ( ) , A :: one ( ) , A :: one ( ) , A :: one ( ) ) ;
66
+ while xs. len ( ) >= 8 {
67
+ p0 = p0 * xs[ 0 ] . clone ( ) ;
68
+ p1 = p1 * xs[ 1 ] . clone ( ) ;
69
+ p2 = p2 * xs[ 2 ] . clone ( ) ;
70
+ p3 = p3 * xs[ 3 ] . clone ( ) ;
71
+ p4 = p4 * xs[ 4 ] . clone ( ) ;
72
+ p5 = p5 * xs[ 5 ] . clone ( ) ;
73
+ p6 = p6 * xs[ 6 ] . clone ( ) ;
74
+ p7 = p7 * xs[ 7 ] . clone ( ) ;
75
+
76
+ xs = & xs[ 8 ..] ;
77
+ }
78
+ prod = prod. clone ( ) * ( p0 * p4) ;
79
+ prod = prod. clone ( ) * ( p1 * p5) ;
80
+ prod = prod. clone ( ) * ( p2 * p6) ;
81
+ prod = prod. clone ( ) * ( p3 * p7) ;
82
+
83
+ // make it clear to the optimizer that this loop is short
84
+ // and can not be autovectorized.
85
+ for i in 0 ..xs. len ( ) {
86
+ if i >= 7 { break ; }
87
+ prod = prod. clone ( ) * xs[ i] . clone ( )
88
+ }
89
+ prod
90
+ }
91
+
54
92
/// Compute the dot product.
55
93
///
56
94
/// `xs` and `ys` must be the same length
0 commit comments