@@ -3,9 +3,10 @@ use libnum;
3
3
use std:: cmp;
4
4
use std:: ops:: {
5
5
Add ,
6
- Mul ,
7
6
} ;
8
7
8
+ use linalg:: LinalgScalar ;
9
+
9
10
/// Compute the sum of the values in `xs`
10
11
pub fn unrolled_sum < A > ( mut xs : & [ A ] ) -> A
11
12
where A : Clone + Add < Output =A > + libnum:: Zero ,
@@ -44,7 +45,7 @@ pub fn unrolled_sum<A>(mut xs: &[A]) -> A
44
45
///
45
46
/// `xs` and `ys` must be the same length
46
47
pub fn unrolled_dot < A > ( xs : & [ A ] , ys : & [ A ] ) -> A
47
- where A : Clone + Add < Output = A > + Mul < Output = A > + libnum :: Zero ,
48
+ where A : LinalgScalar ,
48
49
{
49
50
debug_assert_eq ! ( xs. len( ) , ys. len( ) ) ;
50
51
// eightfold unrolled so that floating point can be vectorized
@@ -58,24 +59,24 @@ pub fn unrolled_dot<A>(xs: &[A], ys: &[A]) -> A
58
59
( A :: zero ( ) , A :: zero ( ) , A :: zero ( ) , A :: zero ( ) ,
59
60
A :: zero ( ) , A :: zero ( ) , A :: zero ( ) , A :: zero ( ) ) ;
60
61
while xs. len ( ) >= 8 {
61
- p0 = p0 + xs[ 0 ] . clone ( ) * ys[ 0 ] . clone ( ) ;
62
- p1 = p1 + xs[ 1 ] . clone ( ) * ys[ 1 ] . clone ( ) ;
63
- p2 = p2 + xs[ 2 ] . clone ( ) * ys[ 2 ] . clone ( ) ;
64
- p3 = p3 + xs[ 3 ] . clone ( ) * ys[ 3 ] . clone ( ) ;
65
- p4 = p4 + xs[ 4 ] . clone ( ) * ys[ 4 ] . clone ( ) ;
66
- p5 = p5 + xs[ 5 ] . clone ( ) * ys[ 5 ] . clone ( ) ;
67
- p6 = p6 + xs[ 6 ] . clone ( ) * ys[ 6 ] . clone ( ) ;
68
- p7 = p7 + xs[ 7 ] . clone ( ) * ys[ 7 ] . clone ( ) ;
62
+ p0 = p0 + xs[ 0 ] * ys[ 0 ] ;
63
+ p1 = p1 + xs[ 1 ] * ys[ 1 ] ;
64
+ p2 = p2 + xs[ 2 ] * ys[ 2 ] ;
65
+ p3 = p3 + xs[ 3 ] * ys[ 3 ] ;
66
+ p4 = p4 + xs[ 4 ] * ys[ 4 ] ;
67
+ p5 = p5 + xs[ 5 ] * ys[ 5 ] ;
68
+ p6 = p6 + xs[ 6 ] * ys[ 6 ] ;
69
+ p7 = p7 + xs[ 7 ] * ys[ 7 ] ;
69
70
70
71
xs = & xs[ 8 ..] ;
71
72
ys = & ys[ 8 ..] ;
72
73
}
73
- sum = sum. clone ( ) + ( p0 + p4) ;
74
- sum = sum. clone ( ) + ( p1 + p5) ;
75
- sum = sum. clone ( ) + ( p2 + p6) ;
76
- sum = sum. clone ( ) + ( p3 + p7) ;
74
+ sum = sum + ( p0 + p4) ;
75
+ sum = sum + ( p1 + p5) ;
76
+ sum = sum + ( p2 + p6) ;
77
+ sum = sum + ( p3 + p7) ;
77
78
for i in 0 ..xs. len ( ) {
78
- sum = sum. clone ( ) + xs[ i] . clone ( ) * ys[ i] . clone ( ) ;
79
+ sum = sum + xs[ i] * ys[ i] ;
79
80
}
80
81
sum
81
82
}
0 commit comments