Skip to content

Commit 74b2bca

Browse files
committed
Use BLAS in dot only for "long" vectors.
For some measure of "long": it seems the smallest vectors (fewer than 32 elements) benefit from using the plain generic dot product, so BLAS is used only for vectors of length 32 or more.
1 parent c9f2937 commit 74b2bca

File tree

2 files changed

+32
-22
lines changed

2 files changed

+32
-22
lines changed

benches/bench1.rs

+19-12
Original file line numberDiff line numberDiff line change
@@ -397,17 +397,6 @@ fn muladd_2d_f32_blas(bench: &mut test::Bencher)
397397
});
398398
}
399399

400-
#[bench]
401-
fn dot_f32_regular(bench: &mut test::Bencher)
402-
{
403-
let av = OwnedArray::<f32, _>::zeros(1024);
404-
let bv = OwnedArray::<f32, _>::zeros(1024);
405-
bench.iter(|| {
406-
av.dot(&bv)
407-
});
408-
}
409-
410-
411400
#[bench]
412401
fn assign_scalar_2d_large(bench: &mut test::Bencher)
413402
{
@@ -544,13 +533,31 @@ fn equality_f32(bench: &mut test::Bencher)
544533
}
545534

546535
#[bench]
547-
fn dot(bench: &mut test::Bencher)
536+
fn dot_f32_16(bench: &mut test::Bencher)
537+
{
538+
let a = OwnedArray::<f32, _>::zeros(16);
539+
let b = OwnedArray::<f32, _>::zeros(16);
540+
bench.iter(|| a.dot(&b));
541+
}
542+
543+
#[bench]
544+
fn dot_f32_256(bench: &mut test::Bencher)
548545
{
549546
let a = OwnedArray::<f32, _>::zeros(256);
550547
let b = OwnedArray::<f32, _>::zeros(256);
551548
bench.iter(|| a.dot(&b));
552549
}
553550

551+
#[bench]
552+
fn dot_f32_1024(bench: &mut test::Bencher)
553+
{
554+
let av = OwnedArray::<f32, _>::zeros(1024);
555+
let bv = OwnedArray::<f32, _>::zeros(1024);
556+
bench.iter(|| {
557+
av.dot(&bv)
558+
});
559+
}
560+
554561
#[bench]
555562
fn means(bench: &mut test::Bencher) {
556563
let a = OwnedArray::from_iter(0..100_000i64);

src/lib.rs

+13-10
Original file line numberDiff line numberDiff line change
@@ -2361,17 +2361,20 @@ impl<A, S> ArrayBase<S, Ix>
23612361
::std::ptr::read(a as *const _ as *const B)
23622362
}
23632363
}
2364-
assert_eq!(self.len(), rhs.len());
2365-
if let Ok(self_v) = self.blas_view_as_type::<f32>() {
2366-
if let Ok(rhs_v) = rhs.blas_view_as_type::<f32>() {
2367-
let f_ret = f32::dot(&self_v, &rhs_v);
2368-
return cast_as::<f32, A>(&f_ret);
2364+
// Use BLAS only if the vector is long enough (at least 32 elements) to be worth it
2365+
if self.len() >= 32 {
2366+
assert_eq!(self.len(), rhs.len());
2367+
if let Ok(self_v) = self.blas_view_as_type::<f32>() {
2368+
if let Ok(rhs_v) = rhs.blas_view_as_type::<f32>() {
2369+
let f_ret = f32::dot(&self_v, &rhs_v);
2370+
return cast_as::<f32, A>(&f_ret);
2371+
}
23692372
}
2370-
}
2371-
if let Ok(self_v) = self.blas_view_as_type::<f64>() {
2372-
if let Ok(rhs_v) = rhs.blas_view_as_type::<f64>() {
2373-
let f_ret = f64::dot(&self_v, &rhs_v);
2374-
return cast_as::<f64, A>(&f_ret);
2373+
if let Ok(self_v) = self.blas_view_as_type::<f64>() {
2374+
if let Ok(rhs_v) = rhs.blas_view_as_type::<f64>() {
2375+
let f_ret = f64::dot(&self_v, &rhs_v);
2376+
return cast_as::<f64, A>(&f_ret);
2377+
}
23752378
}
23762379
}
23772380
self.dot_generic(rhs)

0 commit comments

Comments
 (0)