diff --git a/benches/bench1.rs b/benches/bench1.rs index 3b1e7d7cb..c37be0b74 100644 --- a/benches/bench1.rs +++ b/benches/bench1.rs @@ -12,7 +12,6 @@ use rblas::matrix::Matrix; use ndarray::{ OwnedArray, - zeros, }; use ndarray::{arr0, arr1, arr2}; @@ -398,7 +397,6 @@ fn muladd_2d_f32_blas(bench: &mut test::Bencher) }); } - #[bench] fn assign_scalar_2d_large(bench: &mut test::Bencher) { @@ -506,14 +504,14 @@ fn create_iter_4d(bench: &mut test::Bencher) #[bench] fn bench_to_owned_n(bench: &mut test::Bencher) { - let a = zeros::((32, 32)); + let a = OwnedArray::::zeros((32, 32)); bench.iter(|| a.to_owned()); } #[bench] fn bench_to_owned_t(bench: &mut test::Bencher) { - let mut a = zeros::((32, 32)); + let mut a = OwnedArray::::zeros((32, 32)); a.swap_axes(0, 1); bench.iter(|| a.to_owned()); } @@ -535,13 +533,31 @@ fn equality_f32(bench: &mut test::Bencher) } #[bench] -fn dot(bench: &mut test::Bencher) +fn dot_f32_16(bench: &mut test::Bencher) +{ + let a = OwnedArray::::zeros(16); + let b = OwnedArray::::zeros(16); + bench.iter(|| a.dot(&b)); +} + +#[bench] +fn dot_f32_256(bench: &mut test::Bencher) { let a = OwnedArray::::zeros(256); let b = OwnedArray::::zeros(256); bench.iter(|| a.dot(&b)); } +#[bench] +fn dot_f32_1024(bench: &mut test::Bencher) +{ + let av = OwnedArray::::zeros(1024); + let bv = OwnedArray::::zeros(1024); + bench.iter(|| { + av.dot(&bv) + }); +} + #[bench] fn means(bench: &mut test::Bencher) { let a = OwnedArray::from_iter(0..100_000i64); diff --git a/src/blas.rs b/src/blas.rs index 4fe20d2c2..ef8db37e7 100644 --- a/src/blas.rs +++ b/src/blas.rs @@ -48,26 +48,18 @@ //! I know), instead output its own error conditions, for example on dimension //! mismatch in a matrix multiplication. //! -extern crate rblas; use std::os::raw::{c_int}; -use self::rblas::{ +use rblas::{ Matrix, Vector, }; use super::{ - ArrayBase, - ArrayView, - ArrayViewMut, - Ix, Ixs, ShapeError, - Data, - DataMut, - DataOwned, - Dimension, zipsl, }; +use imp_prelude::*; /// ***Requires `features = "rblas"`*** @@ -108,22 +100,23 @@ impl ArrayBase } } -impl<'a, A, D> ArrayView<'a, A, D> +impl<'a, A, D> Priv> where D: Dimension { - fn into_matrix(self) -> Result, ShapeError> { - if self.dim.ndim() > 1 { - try!(self.contiguous_check()); + pub fn into_blas_view(self) -> Result, ShapeError> { + let self_ = self.0; + if self_.dim.ndim() > 1 { + try!(self_.contiguous_check()); } - try!(self.size_check()); - Ok(BlasArrayView(self)) + try!(self_.size_check()); + Ok(BlasArrayView(self_)) } } impl<'a, A, D> ArrayViewMut<'a, A, D> where D: Dimension { - fn into_matrix_mut(self) -> Result, ShapeError> { + fn into_blas_view_mut(self) -> Result, ShapeError> { if self.dim.ndim() > 1 { try!(self.contiguous_check()); } @@ -241,19 +234,19 @@ impl AsBlas for ArrayBase } _n => self.ensure_standard_layout(), } - self.view_mut().into_matrix_mut() + self.view_mut().into_blas_view_mut() } fn blas_view_checked(&self) -> Result, ShapeError> where S: Data { - self.view().into_matrix() + Priv(self.view()).into_blas_view() } fn blas_view_mut_checked(&mut self) -> Result, ShapeError> where S: DataMut, { - self.view_mut().into_matrix_mut() + self.view_mut().into_blas_view_mut() } /* diff --git a/src/lib.rs b/src/lib.rs index bde34026b..abeb7caad 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -56,6 +56,9 @@ extern crate serde; #[cfg(feature = "rustc-serialize")] extern crate rustc_serialize as serialize; +#[cfg(feature = "rblas")] +extern crate rblas; + extern crate itertools; extern crate num as libnum; @@ -115,6 +118,25 @@ mod si; mod shape_error; mod stride_error; +/// Implementation's prelude. Common types used everywhere. +mod imp_prelude { + pub use { + ArrayBase, + ArrayView, + ArrayViewMut, + OwnedArray, + RcArray, + Ix, Ixs, + Dimension, + Data, + DataMut, + DataOwned, + }; + /// Wrapper type for private methods + #[derive(Copy, Clone, Debug)] + pub struct Priv(pub T); +} + // NOTE: In theory, the whole library should compile // and pass tests even if you change Ix and Ixs. /// Array index type @@ -1989,7 +2011,10 @@ impl ArrayBase where S: Data, D: Dimension } } +/// ***Deprecated: Use `ArrayBase::zeros` instead.*** +/// /// Return an array filled with zeros +#[cfg_attr(has_deprecated, deprecated(note="Use `ArrayBase::zeros` instead."))] pub fn zeros(dim: D) -> OwnedArray where A: Clone + libnum::Zero, D: Dimension, { @@ -2249,7 +2274,6 @@ impl ArrayBase /// /// /// **Panics** if `axis` is out of bounds. - #[allow(deprecated)] pub fn mean(&self, axis: usize) -> OwnedArray::Smaller> where A: LinalgScalar, D: RemoveAxis, @@ -2289,6 +2313,13 @@ impl ArrayBase pub fn dot(&self, rhs: &ArrayBase) -> A where S2: Data, A: LinalgScalar, + { + self.dot_impl(rhs) + } + + fn dot_generic(&self, rhs: &ArrayBase) -> A + where S2: Data, + A: LinalgScalar, { assert_eq!(self.len(), rhs.len()); if let Some(self_s) = self.as_slice() { @@ -2304,8 +2335,54 @@ impl ArrayBase } sum } + + #[cfg(not(feature="rblas"))] + fn dot_impl(&self, rhs: &ArrayBase) -> A + where S2: Data, + A: LinalgScalar, + { + self.dot_generic(rhs) + } + + #[cfg(feature="rblas")] + fn dot_impl(&self, rhs: &ArrayBase) -> A + where S2: Data, + A: LinalgScalar, + { + use std::any::{Any, TypeId}; + use rblas::vector::ops::Dot; + use linalg::AsBlasAny; + + // Read pointer to type `A` as type `B`. + // + // **Panics** if `A` and `B` are not the same type + fn cast_as(a: &A) -> B { + assert_eq!(TypeId::of::(), TypeId::of::()); + unsafe { + ::std::ptr::read(a as *const _ as *const B) + } + } + // Use only if the vector is large enough to be worth it + if self.len() >= 32 { + assert_eq!(self.len(), rhs.len()); + if let Ok(self_v) = self.blas_view_as_type::() { + if let Ok(rhs_v) = rhs.blas_view_as_type::() { + let f_ret = f32::dot(&self_v, &rhs_v); + return cast_as::(&f_ret); + } + } + if let Ok(self_v) = self.blas_view_as_type::() { + if let Ok(rhs_v) = rhs.blas_view_as_type::() { + let f_ret = f64::dot(&self_v, &rhs_v); + return cast_as::(&f_ret); + } + } + } + self.dot_generic(rhs) + } } + impl ArrayBase where S: Data, { @@ -2366,7 +2443,6 @@ impl ArrayBase /// ); /// ``` /// - #[allow(deprecated)] pub fn mat_mul(&self, rhs: &ArrayBase) -> OwnedArray where A: LinalgScalar, { @@ -2411,7 +2487,6 @@ impl ArrayBase /// Return a result array with shape *M*. /// /// **Panics** if shapes are incompatible. - #[allow(deprecated)] pub fn mat_mul_col(&self, rhs: &ArrayBase) -> OwnedArray where A: LinalgScalar, { @@ -2894,3 +2969,4 @@ enum ElementsRepr { Slice(S), Counted(C), } + diff --git a/src/linalg.rs b/src/linalg.rs index a183f833a..ae979c29b 100644 --- a/src/linalg.rs +++ b/src/linalg.rs @@ -1,7 +1,19 @@ -use libnum::{Zero, One, Float}; +use libnum::{Zero, One}; use std::ops::{Add, Sub, Mul, Div}; use std::any::Any; +#[cfg(feature="rblas")] +use std::any::TypeId; + +#[cfg(feature="rblas")] +use ShapeError; + +#[cfg(feature="rblas")] +use blas::{AsBlas, BlasArrayView}; + +#[cfg(feature="rblas")] +use imp_prelude::*; + /// Trait union for scalars (array elements) that support linear algebra operations. /// /// `Any` for type-based specialization, `Copy` so that they don't need move @@ -26,3 +38,31 @@ impl LinalgScalar for T Mul + Div { } + +#[cfg(feature = "rblas")] +pub trait AsBlasAny : AsBlas { + fn blas_view_as_type(&self) -> Result, ShapeError> + where S: Data; +} + +#[cfg(feature = "rblas")] +/// ***Requires `features = "rblas"`*** +impl AsBlasAny for ArrayBase + where S: Data, + D: Dimension, + A: Any, +{ + fn blas_view_as_type(&self) -> Result, ShapeError> + where S: Data + { + if TypeId::of::() == TypeId::of::() { + unsafe { + let v = self.view(); + let u = ArrayView::new_(v.ptr as *const T, v.dim, v.strides); + Priv(u).into_blas_view() + } + } else { + Err(ShapeError::IncompatibleLayout) + } + } +} diff --git a/src/numeric_util.rs b/src/numeric_util.rs index 142bfa225..7752fa1cb 100644 --- a/src/numeric_util.rs +++ b/src/numeric_util.rs @@ -3,9 +3,10 @@ use libnum; use std::cmp; use std::ops::{ Add, - Mul, }; +use linalg::LinalgScalar; + /// Compute the sum of the values in `xs` pub fn unrolled_sum(mut xs: &[A]) -> A where A: Clone + Add + libnum::Zero, @@ -44,7 +45,7 @@ pub fn unrolled_sum(mut xs: &[A]) -> A /// /// `xs` and `ys` must be the same length pub fn unrolled_dot(xs: &[A], ys: &[A]) -> A - where A: Clone + Add + Mul + libnum::Zero, + where A: LinalgScalar, { debug_assert_eq!(xs.len(), ys.len()); // eightfold unrolled so that floating point can be vectorized @@ -58,24 +59,24 @@ pub fn unrolled_dot(xs: &[A], ys: &[A]) -> A (A::zero(), A::zero(), A::zero(), A::zero(), A::zero(), A::zero(), A::zero(), A::zero()); while xs.len() >= 8 { - p0 = p0 + xs[0].clone() * ys[0].clone(); - p1 = p1 + xs[1].clone() * ys[1].clone(); - p2 = p2 + xs[2].clone() * ys[2].clone(); - p3 = p3 + xs[3].clone() * ys[3].clone(); - p4 = p4 + xs[4].clone() * ys[4].clone(); - p5 = p5 + xs[5].clone() * ys[5].clone(); - p6 = p6 + xs[6].clone() * ys[6].clone(); - p7 = p7 + xs[7].clone() * ys[7].clone(); + p0 = p0 + xs[0] * ys[0]; + p1 = p1 + xs[1] * ys[1]; + p2 = p2 + xs[2] * ys[2]; + p3 = p3 + xs[3] * ys[3]; + p4 = p4 + xs[4] * ys[4]; + p5 = p5 + xs[5] * ys[5]; + p6 = p6 + xs[6] * ys[6]; + p7 = p7 + xs[7] * ys[7]; xs = &xs[8..]; ys = &ys[8..]; } - sum = sum.clone() + (p0 + p4); - sum = sum.clone() + (p1 + p5); - sum = sum.clone() + (p2 + p6); - sum = sum.clone() + (p3 + p7); + sum = sum + (p0 + p4); + sum = sum + (p1 + p5); + sum = sum + (p2 + p6); + sum = sum + (p3 + p7); for i in 0..xs.len() { - sum = sum.clone() + xs[i].clone() * ys[i].clone(); + sum = sum + xs[i] * ys[i]; } sum } diff --git a/tests/oper.rs b/tests/oper.rs index 214c1399b..76fc3e648 100644 --- a/tests/oper.rs +++ b/tests/oper.rs @@ -3,6 +3,9 @@ extern crate num as libnum; use ndarray::RcArray; use ndarray::{arr0, rcarr1, rcarr2}; +use ndarray::{ + OwnedArray, +}; use std::fmt; use libnum::Float; @@ -108,3 +111,23 @@ fn scalar_operations() assert_eq!(x, y); } } + +fn assert_approx_eq(f: F, g: F, tol: F) -> bool { + assert!((f - g).abs() <= tol, "{:?} approx== {:?} (tol={:?})", + f, g, tol); + true +} + +#[test] +fn dot_product() { + let a = OwnedArray::linspace(0., 63., 64); + let b = OwnedArray::linspace(0., 63., 64); + let dot = 85344.; + assert_approx_eq(a.dot(&b), dot, 1e-5); + let a = a.map(|f| *f as f32); + let b = a.map(|f| *f as f32); + assert_approx_eq(a.dot(&b), dot as f32, 1e-5); + let a = a.map(|f| *f as i32); + let b = a.map(|f| *f as i32); + assert_eq!(a.dot(&b), dot as i32); +}