Skip to content

Use BLAS acceleration in .dot() when possible #92

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 28, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions benches/bench1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use rblas::matrix::Matrix;

use ndarray::{
OwnedArray,
zeros,
};
use ndarray::{arr0, arr1, arr2};

Expand Down Expand Up @@ -398,7 +397,6 @@ fn muladd_2d_f32_blas(bench: &mut test::Bencher)
});
}


#[bench]
fn assign_scalar_2d_large(bench: &mut test::Bencher)
{
Expand Down Expand Up @@ -506,14 +504,14 @@ fn create_iter_4d(bench: &mut test::Bencher)
#[bench]
fn bench_to_owned_n(bench: &mut test::Bencher)
{
let a = zeros::<f32, _>((32, 32));
let a = OwnedArray::<f32, _>::zeros((32, 32));
bench.iter(|| a.to_owned());
}

#[bench]
fn bench_to_owned_t(bench: &mut test::Bencher)
{
let mut a = zeros::<f32, _>((32, 32));
let mut a = OwnedArray::<f32, _>::zeros((32, 32));
a.swap_axes(0, 1);
bench.iter(|| a.to_owned());
}
Expand All @@ -535,13 +533,31 @@ fn equality_f32(bench: &mut test::Bencher)
}

#[bench]
fn dot(bench: &mut test::Bencher)
fn dot_f32_16(bench: &mut test::Bencher)
{
let a = OwnedArray::<f32, _>::zeros(16);
let b = OwnedArray::<f32, _>::zeros(16);
bench.iter(|| a.dot(&b));
}

/// Benchmark `dot` on two zero-filled `f32` vectors of length 256.
#[bench]
fn dot_f32_256(bench: &mut test::Bencher) {
    let lhs = OwnedArray::<f32, _>::zeros(256);
    let rhs = OwnedArray::<f32, _>::zeros(256);
    bench.iter(|| {
        lhs.dot(&rhs)
    });
}

/// Benchmark `dot` on two zero-filled `f32` vectors of length 1024.
#[bench]
fn dot_f32_1024(bench: &mut test::Bencher) {
    let lhs = OwnedArray::<f32, _>::zeros(1024);
    let rhs = OwnedArray::<f32, _>::zeros(1024);
    bench.iter(|| lhs.dot(&rhs));
}

#[bench]
fn means(bench: &mut test::Bencher) {
let a = OwnedArray::from_iter(0..100_000i64);
Expand Down
33 changes: 13 additions & 20 deletions src/blas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,26 +48,18 @@
//! I know), instead output its own error conditions, for example on dimension
//! mismatch in a matrix multiplication.
//!
extern crate rblas;

use std::os::raw::{c_int};

use self::rblas::{
use rblas::{
Matrix,
Vector,
};
use super::{
ArrayBase,
ArrayView,
ArrayViewMut,
Ix, Ixs,
ShapeError,
Data,
DataMut,
DataOwned,
Dimension,
zipsl,
};
use imp_prelude::*;


/// ***Requires `features = "rblas"`***
Expand Down Expand Up @@ -108,22 +100,23 @@ impl<S, D> ArrayBase<S, D>
}
}

impl<'a, A, D> ArrayView<'a, A, D>
impl<'a, A, D> Priv<ArrayView<'a, A, D>>
where D: Dimension
{
fn into_matrix(self) -> Result<BlasArrayView<'a, A, D>, ShapeError> {
if self.dim.ndim() > 1 {
try!(self.contiguous_check());
pub fn into_blas_view(self) -> Result<BlasArrayView<'a, A, D>, ShapeError> {
let self_ = self.0;
if self_.dim.ndim() > 1 {
try!(self_.contiguous_check());
}
try!(self.size_check());
Ok(BlasArrayView(self))
try!(self_.size_check());
Ok(BlasArrayView(self_))
}
}

impl<'a, A, D> ArrayViewMut<'a, A, D>
where D: Dimension
{
fn into_matrix_mut(self) -> Result<BlasArrayViewMut<'a, A, D>, ShapeError> {
fn into_blas_view_mut(self) -> Result<BlasArrayViewMut<'a, A, D>, ShapeError> {
if self.dim.ndim() > 1 {
try!(self.contiguous_check());
}
Expand Down Expand Up @@ -241,19 +234,19 @@ impl<A, S, D> AsBlas<A, S, D> for ArrayBase<S, D>
}
_n => self.ensure_standard_layout(),
}
self.view_mut().into_matrix_mut()
self.view_mut().into_blas_view_mut()
}

fn blas_view_checked(&self) -> Result<BlasArrayView<A, D>, ShapeError>
where S: Data
{
self.view().into_matrix()
Priv(self.view()).into_blas_view()
}

fn blas_view_mut_checked(&mut self) -> Result<BlasArrayViewMut<A, D>, ShapeError>
where S: DataMut,
{
self.view_mut().into_matrix_mut()
self.view_mut().into_blas_view_mut()
}

/*
Expand Down
82 changes: 79 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ extern crate serde;
#[cfg(feature = "rustc-serialize")]
extern crate rustc_serialize as serialize;

#[cfg(feature = "rblas")]
extern crate rblas;

extern crate itertools;
extern crate num as libnum;

Expand Down Expand Up @@ -115,6 +118,25 @@ mod si;
mod shape_error;
mod stride_error;

/// Implementation's prelude. Common types used everywhere.
///
/// The module itself is private (declared without `pub`); implementation
/// modules bring its contents into scope with `use imp_prelude::*;`.
mod imp_prelude {
// Re-export the crate's array types, index types and data traits under
// one glob-importable path.
pub use {
ArrayBase,
ArrayView,
ArrayViewMut,
OwnedArray,
RcArray,
Ix, Ixs,
Dimension,
Data,
DataMut,
DataOwned,
};
/// Wrapper type for private methods
///
/// Internal helpers are implemented on `Priv<T>` instead of on `T` itself
/// and are called as `Priv(value).method()`; the wrapped value is reachable
/// through the public field `.0`. Presumably this keeps such helpers from
/// appearing as inherent methods on the public array types.
#[derive(Copy, Clone, Debug)]
pub struct Priv<T>(pub T);
}

// NOTE: In theory, the whole library should compile
// and pass tests even if you change Ix and Ixs.
/// Array index type
Expand Down Expand Up @@ -1989,7 +2011,10 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
}
}

/// ***Deprecated: Use `ArrayBase::zeros` instead.***
///
/// Return an array filled with zeros
#[cfg_attr(has_deprecated, deprecated(note="Use `ArrayBase::zeros` instead."))]
pub fn zeros<A, D>(dim: D) -> OwnedArray<A, D>
where A: Clone + libnum::Zero, D: Dimension,
{
Expand Down Expand Up @@ -2249,7 +2274,6 @@ impl<A, S, D> ArrayBase<S, D>
///
///
/// **Panics** if `axis` is out of bounds.
#[allow(deprecated)]
pub fn mean(&self, axis: usize) -> OwnedArray<A, <D as RemoveAxis>::Smaller>
where A: LinalgScalar,
D: RemoveAxis,
Expand Down Expand Up @@ -2289,6 +2313,13 @@ impl<A, S> ArrayBase<S, Ix>
pub fn dot<S2>(&self, rhs: &ArrayBase<S2, Ix>) -> A
where S2: Data<Elem=A>,
A: LinalgScalar,
{
// Delegate to `dot_impl`; which definition of `dot_impl` is compiled in
// depends on whether the `rblas` cargo feature is enabled.
self.dot_impl(rhs)
}

fn dot_generic<S2>(&self, rhs: &ArrayBase<S2, Ix>) -> A
where S2: Data<Elem=A>,
A: LinalgScalar,
{
assert_eq!(self.len(), rhs.len());
if let Some(self_s) = self.as_slice() {
Expand All @@ -2304,8 +2335,54 @@ impl<A, S> ArrayBase<S, Ix>
}
sum
}

/// Dot product backend used when the `rblas` feature is disabled.
///
/// Without BLAS there is nothing to dispatch on, so this simply forwards
/// to the portable `dot_generic` implementation.
#[cfg(not(feature="rblas"))]
fn dot_impl<S2>(&self, rhs: &ArrayBase<S2, Ix>) -> A
    where S2: Data<Elem=A>,
          A: LinalgScalar,
{
    self.dot_generic(rhs)
}

/// Dot product backend used when the `rblas` feature is enabled.
///
/// For vectors of at least 32 elements whose element type is `f32` or
/// `f64`, and for which both operands can be turned into BLAS views, the
/// product is computed through `rblas`' `Dot`. In every other case the
/// portable `dot_generic` implementation is used.
///
/// **Panics** if the two vectors differ in length.
#[cfg(feature="rblas")]
fn dot_impl<S2>(&self, rhs: &ArrayBase<S2, Ix>) -> A
    where S2: Data<Elem=A>,
          A: LinalgScalar,
{
    use std::any::{Any, TypeId};
    use rblas::vector::ops::Dot;
    use linalg::AsBlasAny;

    // Shortest vector length for which the BLAS call is attempted.
    const BLAS_MIN_LEN: usize = 32;

    // Reinterpret a value of type `Src` as type `Dst`.
    //
    // **Panics** if `Src` and `Dst` are not the same type
    fn cast_as<Src: Any + Copy, Dst: Any + Copy>(value: &Src) -> Dst {
        assert_eq!(TypeId::of::<Src>(), TypeId::of::<Dst>());
        // SAFETY: the assertion above established that `Src` and `Dst` are
        // the same type, so this reads a `Copy` value through a pointer to
        // its own type.
        unsafe {
            ::std::ptr::read(value as *const Src as *const Dst)
        }
    }

    // Small vectors are not worth the call into BLAS.
    if self.len() < BLAS_MIN_LEN {
        return self.dot_generic(rhs);
    }

    assert_eq!(self.len(), rhs.len());

    // NOTE(review): the layout requirements checked here are whatever
    // `blas_view_as_type` enforces; how strided (e.g. negative-stride)
    // views are passed on to BLAS is not visible from this function.
    // `blas_view_as_type::<T>()` fails unless `A` is exactly `T`, so at
    // most one of the two branches below can be taken.
    if let Ok(lhs_view) = self.blas_view_as_type::<f32>() {
        if let Ok(rhs_view) = rhs.blas_view_as_type::<f32>() {
            let product = f32::dot(&lhs_view, &rhs_view);
            return cast_as::<f32, A>(&product);
        }
    }
    if let Ok(lhs_view) = self.blas_view_as_type::<f64>() {
        if let Ok(rhs_view) = rhs.blas_view_as_type::<f64>() {
            let product = f64::dot(&lhs_view, &rhs_view);
            return cast_as::<f64, A>(&product);
        }
    }

    // Element type or layout unsuitable for BLAS: fall back.
    self.dot_generic(rhs)
}
}


impl<A, S> ArrayBase<S, (Ix, Ix)>
where S: Data<Elem=A>,
{
Expand Down Expand Up @@ -2366,7 +2443,6 @@ impl<A, S> ArrayBase<S, (Ix, Ix)>
/// );
/// ```
///
#[allow(deprecated)]
pub fn mat_mul(&self, rhs: &ArrayBase<S, (Ix, Ix)>) -> OwnedArray<A, (Ix, Ix)>
where A: LinalgScalar,
{
Expand Down Expand Up @@ -2411,7 +2487,6 @@ impl<A, S> ArrayBase<S, (Ix, Ix)>
/// Return a result array with shape *M*.
///
/// **Panics** if shapes are incompatible.
#[allow(deprecated)]
pub fn mat_mul_col(&self, rhs: &ArrayBase<S, Ix>) -> OwnedArray<A, Ix>
where A: LinalgScalar,
{
Expand Down Expand Up @@ -2894,3 +2969,4 @@ enum ElementsRepr<S, C> {
Slice(S),
Counted(C),
}

42 changes: 41 additions & 1 deletion src/linalg.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
use libnum::{Zero, One, Float};
use libnum::{Zero, One};
use std::ops::{Add, Sub, Mul, Div};
use std::any::Any;

#[cfg(feature="rblas")]
use std::any::TypeId;

#[cfg(feature="rblas")]
use ShapeError;

#[cfg(feature="rblas")]
use blas::{AsBlas, BlasArrayView};

#[cfg(feature="rblas")]
use imp_prelude::*;

/// Trait union for scalars (array elements) that support linear algebra operations.
///
/// `Any` for type-based specialization, `Copy` so that they don't need move
Expand All @@ -26,3 +38,31 @@ impl<T> LinalgScalar for T
Mul<Output=T> +
Div<Output=T>
{ }

#[cfg(feature = "rblas")]
/// ***Requires `features = "rblas"`***
///
/// Extension of `AsBlas` for obtaining a BLAS view whose element type `T`
/// is chosen by the caller instead of being fixed to the array's own
/// element type.
pub trait AsBlasAny<A, S, D> : AsBlas<A, S, D> {
/// Return a BLAS view of `self` with elements typed as `T`.
///
/// **Errors** in the implementation for `ArrayBase` in this file: returns
/// `ShapeError::IncompatibleLayout` when `T` is not the same type as `A`;
/// otherwise returns whatever the conversion into a BLAS view reports.
fn blas_view_as_type<T: Any>(&self) -> Result<BlasArrayView<T, D>, ShapeError>
where S: Data;
}

#[cfg(feature = "rblas")]
/// ***Requires `features = "rblas"`***
///
/// Lets generic code ask "is this an array of `T`?" at run time and, if so,
/// obtain a BLAS view typed as `T`.
impl<A, S, D> AsBlasAny<A, S, D> for ArrayBase<S, D>
    where S: Data<Elem=A>,
          D: Dimension,
          A: Any,
{
    /// Return a BLAS view of `self` with elements typed as `T`.
    ///
    /// **Errors** with `ShapeError::IncompatibleLayout` if `T` is not the
    /// same type as `A`; otherwise returns the result of `into_blas_view`
    /// on the retyped view.
    fn blas_view_as_type<T: Any>(&self) -> Result<BlasArrayView<T, D>, ShapeError>
        where S: Data
    {
        // Only an exact type match can be reinterpreted.
        if TypeId::of::<A>() != TypeId::of::<T>() {
            return Err(ShapeError::IncompatibleLayout);
        }

        let view = self.view();
        // SAFETY: `A` and `T` were just checked to be the same type, so the
        // pointer cast does not change the pointee type; the dimension and
        // strides are taken unchanged from a view of `self`.
        let retyped = unsafe {
            ArrayView::new_(view.ptr as *const T, view.dim, view.strides)
        };
        Priv(retyped).into_blas_view()
    }
}
Loading