Skip to content

Commit 9bae4eb

Browse files
committed
Merge pull request #92 from bluss/specialize-dot
Use BLAS acceleration in .dot() when possible
2 parents a08b1fd + 86b1673 commit 9bae4eb

File tree

6 files changed

+193
-44
lines changed

6 files changed

+193
-44
lines changed

benches/bench1.rs

+21-5
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ use rblas::matrix::Matrix;
1212

1313
use ndarray::{
1414
OwnedArray,
15-
zeros,
1615
};
1716
use ndarray::{arr0, arr1, arr2};
1817

@@ -398,7 +397,6 @@ fn muladd_2d_f32_blas(bench: &mut test::Bencher)
398397
});
399398
}
400399

401-
402400
#[bench]
403401
fn assign_scalar_2d_large(bench: &mut test::Bencher)
404402
{
@@ -506,14 +504,14 @@ fn create_iter_4d(bench: &mut test::Bencher)
506504
#[bench]
507505
fn bench_to_owned_n(bench: &mut test::Bencher)
508506
{
509-
let a = zeros::<f32, _>((32, 32));
507+
let a = OwnedArray::<f32, _>::zeros((32, 32));
510508
bench.iter(|| a.to_owned());
511509
}
512510

513511
#[bench]
514512
fn bench_to_owned_t(bench: &mut test::Bencher)
515513
{
516-
let mut a = zeros::<f32, _>((32, 32));
514+
let mut a = OwnedArray::<f32, _>::zeros((32, 32));
517515
a.swap_axes(0, 1);
518516
bench.iter(|| a.to_owned());
519517
}
@@ -535,13 +533,31 @@ fn equality_f32(bench: &mut test::Bencher)
535533
}
536534

537535
#[bench]
538-
fn dot(bench: &mut test::Bencher)
536+
fn dot_f32_16(bench: &mut test::Bencher)
537+
{
538+
let a = OwnedArray::<f32, _>::zeros(16);
539+
let b = OwnedArray::<f32, _>::zeros(16);
540+
bench.iter(|| a.dot(&b));
541+
}
542+
543+
#[bench]
544+
fn dot_f32_256(bench: &mut test::Bencher)
539545
{
540546
let a = OwnedArray::<f32, _>::zeros(256);
541547
let b = OwnedArray::<f32, _>::zeros(256);
542548
bench.iter(|| a.dot(&b));
543549
}
544550

551+
#[bench]
552+
fn dot_f32_1024(bench: &mut test::Bencher)
553+
{
554+
let av = OwnedArray::<f32, _>::zeros(1024);
555+
let bv = OwnedArray::<f32, _>::zeros(1024);
556+
bench.iter(|| {
557+
av.dot(&bv)
558+
});
559+
}
560+
545561
#[bench]
546562
fn means(bench: &mut test::Bencher) {
547563
let a = OwnedArray::from_iter(0..100_000i64);

src/blas.rs

+13-20
Original file line numberDiff line numberDiff line change
@@ -48,26 +48,18 @@
4848
//! I know), instead output its own error conditions, for example on dimension
4949
//! mismatch in a matrix multiplication.
5050
//!
51-
extern crate rblas;
5251
5352
use std::os::raw::{c_int};
5453

55-
use self::rblas::{
54+
use rblas::{
5655
Matrix,
5756
Vector,
5857
};
5958
use super::{
60-
ArrayBase,
61-
ArrayView,
62-
ArrayViewMut,
63-
Ix, Ixs,
6459
ShapeError,
65-
Data,
66-
DataMut,
67-
DataOwned,
68-
Dimension,
6960
zipsl,
7061
};
62+
use imp_prelude::*;
7163

7264

7365
/// ***Requires crate feature `"rblas"`***
@@ -108,22 +100,23 @@ impl<S, D> ArrayBase<S, D>
108100
}
109101
}
110102

111-
impl<'a, A, D> ArrayView<'a, A, D>
103+
impl<'a, A, D> Priv<ArrayView<'a, A, D>>
112104
where D: Dimension
113105
{
114-
fn into_matrix(self) -> Result<BlasArrayView<'a, A, D>, ShapeError> {
115-
if self.dim.ndim() > 1 {
116-
try!(self.contiguous_check());
106+
pub fn into_blas_view(self) -> Result<BlasArrayView<'a, A, D>, ShapeError> {
107+
let self_ = self.0;
108+
if self_.dim.ndim() > 1 {
109+
try!(self_.contiguous_check());
117110
}
118-
try!(self.size_check());
119-
Ok(BlasArrayView(self))
111+
try!(self_.size_check());
112+
Ok(BlasArrayView(self_))
120113
}
121114
}
122115

123116
impl<'a, A, D> ArrayViewMut<'a, A, D>
124117
where D: Dimension
125118
{
126-
fn into_matrix_mut(self) -> Result<BlasArrayViewMut<'a, A, D>, ShapeError> {
119+
fn into_blas_view_mut(self) -> Result<BlasArrayViewMut<'a, A, D>, ShapeError> {
127120
if self.dim.ndim() > 1 {
128121
try!(self.contiguous_check());
129122
}
@@ -241,19 +234,19 @@ impl<A, S, D> AsBlas<A, S, D> for ArrayBase<S, D>
241234
}
242235
_n => self.ensure_standard_layout(),
243236
}
244-
self.view_mut().into_matrix_mut()
237+
self.view_mut().into_blas_view_mut()
245238
}
246239

247240
fn blas_view_checked(&self) -> Result<BlasArrayView<A, D>, ShapeError>
248241
where S: Data
249242
{
250-
self.view().into_matrix()
243+
Priv(self.view()).into_blas_view()
251244
}
252245

253246
fn blas_view_mut_checked(&mut self) -> Result<BlasArrayViewMut<A, D>, ShapeError>
254247
where S: DataMut,
255248
{
256-
self.view_mut().into_matrix_mut()
249+
self.view_mut().into_blas_view_mut()
257250
}
258251

259252
/*

src/lib.rs

+79-3
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ extern crate serde;
5959
#[cfg(feature = "rustc-serialize")]
6060
extern crate rustc_serialize as serialize;
6161

62+
#[cfg(feature = "rblas")]
63+
extern crate rblas;
64+
6265
extern crate itertools;
6366
extern crate num as libnum;
6467

@@ -118,6 +121,25 @@ mod si;
118121
mod shape_error;
119122
mod stride_error;
120123

124+
/// Implementation's prelude. Common types used everywhere.
125+
mod imp_prelude {
126+
pub use {
127+
ArrayBase,
128+
ArrayView,
129+
ArrayViewMut,
130+
OwnedArray,
131+
RcArray,
132+
Ix, Ixs,
133+
Dimension,
134+
Data,
135+
DataMut,
136+
DataOwned,
137+
};
138+
/// Wrapper type for private methods
139+
#[derive(Copy, Clone, Debug)]
140+
pub struct Priv<T>(pub T);
141+
}
142+
121143
// NOTE: In theory, the whole library should compile
122144
// and pass tests even if you change Ix and Ixs.
123145
/// Array index type
@@ -2004,7 +2026,10 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
20042026
}
20052027
}
20062028

2029+
/// ***Deprecated: Use `ArrayBase::zeros` instead.***
2030+
///
20072031
/// Return an array filled with zeros
2032+
#[cfg_attr(has_deprecated, deprecated(note="Use `ArrayBase::zeros` instead."))]
20082033
pub fn zeros<A, D>(dim: D) -> OwnedArray<A, D>
20092034
where A: Clone + libnum::Zero, D: Dimension,
20102035
{
@@ -2264,7 +2289,6 @@ impl<A, S, D> ArrayBase<S, D>
22642289
///
22652290
///
22662291
/// **Panics** if `axis` is out of bounds.
2267-
#[allow(deprecated)]
22682292
pub fn mean(&self, axis: usize) -> OwnedArray<A, <D as RemoveAxis>::Smaller>
22692293
where A: LinalgScalar,
22702294
D: RemoveAxis,
@@ -2304,6 +2328,13 @@ impl<A, S> ArrayBase<S, Ix>
23042328
pub fn dot<S2>(&self, rhs: &ArrayBase<S2, Ix>) -> A
23052329
where S2: Data<Elem=A>,
23062330
A: LinalgScalar,
2331+
{
2332+
self.dot_impl(rhs)
2333+
}
2334+
2335+
fn dot_generic<S2>(&self, rhs: &ArrayBase<S2, Ix>) -> A
2336+
where S2: Data<Elem=A>,
2337+
A: LinalgScalar,
23072338
{
23082339
assert_eq!(self.len(), rhs.len());
23092340
if let Some(self_s) = self.as_slice() {
@@ -2319,8 +2350,54 @@ impl<A, S> ArrayBase<S, Ix>
23192350
}
23202351
sum
23212352
}
2353+
2354+
#[cfg(not(feature="rblas"))]
2355+
fn dot_impl<S2>(&self, rhs: &ArrayBase<S2, Ix>) -> A
2356+
where S2: Data<Elem=A>,
2357+
A: LinalgScalar,
2358+
{
2359+
self.dot_generic(rhs)
2360+
}
2361+
2362+
#[cfg(feature="rblas")]
2363+
fn dot_impl<S2>(&self, rhs: &ArrayBase<S2, Ix>) -> A
2364+
where S2: Data<Elem=A>,
2365+
A: LinalgScalar,
2366+
{
2367+
use std::any::{Any, TypeId};
2368+
use rblas::vector::ops::Dot;
2369+
use linalg::AsBlasAny;
2370+
2371+
// Read pointer to type `A` as type `B`.
2372+
//
2373+
// **Panics** if `A` and `B` are not the same type
2374+
fn cast_as<A: Any + Copy, B: Any + Copy>(a: &A) -> B {
2375+
assert_eq!(TypeId::of::<A>(), TypeId::of::<B>());
2376+
unsafe {
2377+
::std::ptr::read(a as *const _ as *const B)
2378+
}
2379+
}
2380+
// Use only if the vector is large enough to be worth it
2381+
if self.len() >= 32 {
2382+
assert_eq!(self.len(), rhs.len());
2383+
if let Ok(self_v) = self.blas_view_as_type::<f32>() {
2384+
if let Ok(rhs_v) = rhs.blas_view_as_type::<f32>() {
2385+
let f_ret = f32::dot(&self_v, &rhs_v);
2386+
return cast_as::<f32, A>(&f_ret);
2387+
}
2388+
}
2389+
if let Ok(self_v) = self.blas_view_as_type::<f64>() {
2390+
if let Ok(rhs_v) = rhs.blas_view_as_type::<f64>() {
2391+
let f_ret = f64::dot(&self_v, &rhs_v);
2392+
return cast_as::<f64, A>(&f_ret);
2393+
}
2394+
}
2395+
}
2396+
self.dot_generic(rhs)
2397+
}
23222398
}
23232399

2400+
23242401
impl<A, S> ArrayBase<S, (Ix, Ix)>
23252402
where S: Data<Elem=A>,
23262403
{
@@ -2381,7 +2458,6 @@ impl<A, S> ArrayBase<S, (Ix, Ix)>
23812458
/// );
23822459
/// ```
23832460
///
2384-
#[allow(deprecated)]
23852461
pub fn mat_mul(&self, rhs: &ArrayBase<S, (Ix, Ix)>) -> OwnedArray<A, (Ix, Ix)>
23862462
where A: LinalgScalar,
23872463
{
@@ -2426,7 +2502,6 @@ impl<A, S> ArrayBase<S, (Ix, Ix)>
24262502
/// Return a result array with shape *M*.
24272503
///
24282504
/// **Panics** if shapes are incompatible.
2429-
#[allow(deprecated)]
24302505
pub fn mat_mul_col(&self, rhs: &ArrayBase<S, Ix>) -> OwnedArray<A, Ix>
24312506
where A: LinalgScalar,
24322507
{
@@ -2913,3 +2988,4 @@ enum ElementsRepr<S, C> {
29132988
Slice(S),
29142989
Counted(C),
29152990
}
2991+

src/linalg.rs

+41-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,19 @@
1-
use libnum::{Zero, One, Float};
1+
use libnum::{Zero, One};
22
use std::ops::{Add, Sub, Mul, Div};
33
use std::any::Any;
44

5+
#[cfg(feature="rblas")]
6+
use std::any::TypeId;
7+
8+
#[cfg(feature="rblas")]
9+
use ShapeError;
10+
11+
#[cfg(feature="rblas")]
12+
use blas::{AsBlas, BlasArrayView};
13+
14+
#[cfg(feature="rblas")]
15+
use imp_prelude::*;
16+
517
/// Trait union for scalars (array elements) that support linear algebra operations.
618
///
719
/// `Any` for type-based specialization, `Copy` so that they don't need move
@@ -26,3 +38,31 @@ impl<T> LinalgScalar for T
2638
Mul<Output=T> +
2739
Div<Output=T>
2840
{ }
41+
42+
#[cfg(feature = "rblas")]
43+
pub trait AsBlasAny<A, S, D> : AsBlas<A, S, D> {
44+
fn blas_view_as_type<T: Any>(&self) -> Result<BlasArrayView<T, D>, ShapeError>
45+
where S: Data;
46+
}
47+
48+
#[cfg(feature = "rblas")]
49+
/// ***Requires `features = "rblas"`***
50+
impl<A, S, D> AsBlasAny<A, S, D> for ArrayBase<S, D>
51+
where S: Data<Elem=A>,
52+
D: Dimension,
53+
A: Any,
54+
{
55+
fn blas_view_as_type<T: Any>(&self) -> Result<BlasArrayView<T, D>, ShapeError>
56+
where S: Data
57+
{
58+
if TypeId::of::<A>() == TypeId::of::<T>() {
59+
unsafe {
60+
let v = self.view();
61+
let u = ArrayView::new_(v.ptr as *const T, v.dim, v.strides);
62+
Priv(u).into_blas_view()
63+
}
64+
} else {
65+
Err(ShapeError::IncompatibleLayout)
66+
}
67+
}
68+
}

0 commit comments

Comments
 (0)