Skip to content

Commit b660e0f

Browse files
committed
Merge pull request #140 from bluss/no-order
Add .as_slice_memory_order(), improve scalar_sum, and fix bugs in from_vec_dim_stride
2 parents cc415cc + d363b0a commit b660e0f

File tree

9 files changed

+367
-87
lines changed

9 files changed

+367
-87
lines changed

benches/bench1.rs

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ fn sum_1d_raw(bench: &mut test::Bencher)
4848
let a = black_box(a);
4949
bench.iter(|| {
5050
let mut sum = 0;
51-
for &elt in a.raw_data() {
51+
for &elt in a.as_slice_memory_order().unwrap() {
5252
sum += elt;
5353
}
5454
sum
@@ -93,7 +93,7 @@ fn sum_2d_raw(bench: &mut test::Bencher)
9393
let a = black_box(a);
9494
bench.iter(|| {
9595
let mut sum = 0;
96-
for &elt in a.raw_data() {
96+
for &elt in a.as_slice_memory_order().unwrap() {
9797
sum += elt;
9898
}
9999
sum
@@ -373,12 +373,12 @@ fn muladd_2d_f32_blas(bench: &mut test::Bencher)
373373
}
374374

375375
#[bench]
376-
fn assign_scalar_2d_large(bench: &mut test::Bencher)
376+
fn assign_scalar_2d_corder(bench: &mut test::Bencher)
377377
{
378378
let a = OwnedArray::zeros((64, 64));
379379
let mut a = black_box(a);
380380
let s = 3.;
381-
bench.iter(|| a.assign_scalar(&s))
381+
bench.iter(move || a.assign_scalar(&s))
382382
}
383383

384384
#[bench]
@@ -388,26 +388,43 @@ fn assign_scalar_2d_cutout(bench: &mut test::Bencher)
388388
let a = a.slice_mut(s![1..-1, 1..-1]);
389389
let mut a = black_box(a);
390390
let s = 3.;
391-
bench.iter(|| a.assign_scalar(&s))
391+
bench.iter(move || a.assign_scalar(&s))
392392
}
393393

394394
#[bench]
395-
fn assign_scalar_2d_transposed_large(bench: &mut test::Bencher)
395+
fn assign_scalar_2d_forder(bench: &mut test::Bencher)
396396
{
397397
let mut a = OwnedArray::zeros((64, 64));
398398
a.swap_axes(0, 1);
399399
let mut a = black_box(a);
400400
let s = 3.;
401-
bench.iter(|| a.assign_scalar(&s))
401+
bench.iter(move || a.assign_scalar(&s))
402402
}
403403

404404
#[bench]
405-
fn assign_scalar_2d_raw_large(bench: &mut test::Bencher)
405+
fn assign_zero_2d_corder(bench: &mut test::Bencher)
406406
{
407407
let a = OwnedArray::zeros((64, 64));
408408
let mut a = black_box(a);
409-
let s = 3.;
410-
bench.iter(|| for elt in a.raw_data_mut() { *elt = s; });
409+
bench.iter(|| a.assign_scalar(&0.))
410+
}
411+
412+
#[bench]
413+
fn assign_zero_2d_cutout(bench: &mut test::Bencher)
414+
{
415+
let mut a = OwnedArray::zeros((66, 66));
416+
let a = a.slice_mut(s![1..-1, 1..-1]);
417+
let mut a = black_box(a);
418+
bench.iter(|| a.assign_scalar(&0.))
419+
}
420+
421+
#[bench]
422+
fn assign_zero_2d_forder(bench: &mut test::Bencher)
423+
{
424+
let mut a = OwnedArray::zeros((64, 64));
425+
a.swap_axes(0, 1);
426+
let mut a = black_box(a);
427+
bench.iter(|| a.assign_scalar(&0.))
411428
}
412429

413430
#[bench]

src/dimension.rs

Lines changed: 84 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
use std::cmp::Ordering;
99
use std::fmt::Debug;
1010
use std::slice;
11+
use itertools::free::enumerate;
1112

1213
use super::{Si, Ix, Ixs};
1314
use super::zipsl;
@@ -19,30 +20,6 @@ pub fn stride_offset(n: Ix, stride: Ix) -> isize {
1920
(n as isize) * ((stride as Ixs) as isize)
2021
}
2122

22-
/// Check whether `stride` is strictly positive
23-
#[inline]
24-
fn stride_is_positive(stride: Ix) -> bool {
25-
(stride as Ixs) > 0
26-
}
27-
28-
/// Return the axis ordering corresponding to the fastest variation
29-
///
30-
/// Assumes that no stride value appears twice. This cannot yield the correct
31-
/// result the strides are not positive.
32-
fn fastest_varying_order<D: Dimension>(strides: &D) -> D {
33-
let mut sorted = strides.clone();
34-
sorted.slice_mut().sort();
35-
let mut res = strides.clone();
36-
for (index, &val) in strides.slice().iter().enumerate() {
37-
let sorted_ind = sorted.slice()
38-
.iter()
39-
.position(|&x| x == val)
40-
.unwrap(); // cannot panic by construction
41-
res.slice_mut()[sorted_ind] = index;
42-
}
43-
res
44-
}
45-
4623
/// Check whether the given `dim` and `stride` lead to overlapping indices
4724
///
4825
/// There is overlap if, when iterating through the dimensions in the order
@@ -51,15 +28,19 @@ fn fastest_varying_order<D: Dimension>(strides: &D) -> D {
5128
///
5229
/// The current implementation assumes strides to be positive
5330
pub fn dim_stride_overlap<D: Dimension>(dim: &D, strides: &D) -> bool {
54-
let order = fastest_varying_order(strides);
31+
let order = strides._fastest_varying_stride_order();
5532

33+
let dim = dim.slice();
34+
let strides = strides.slice();
5635
let mut prev_offset = 1;
57-
for &index in order.slice().iter() {
58-
let s = strides.slice()[index];
59-
if (s as isize) < prev_offset {
36+
for &index in order.slice() {
37+
let d = dim[index];
38+
let s = strides[index];
39+
// any stride is ok if dimension is 1
40+
if d != 1 && (s as isize) < prev_offset {
6041
return true;
6142
}
62-
prev_offset = stride_offset(dim.slice()[index], s);
43+
prev_offset = stride_offset(d, s);
6344
}
6445
false
6546
}
@@ -74,33 +55,42 @@ pub fn dim_stride_overlap<D: Dimension>(dim: &D, strides: &D) -> bool {
7455
pub fn can_index_slice<A, D: Dimension>(data: &[A], dim: &D, strides: &D)
7556
-> Result<(), ShapeError>
7657
{
77-
if strides.slice().iter().cloned().all(stride_is_positive) {
78-
if dim.size_checked().is_none() {
79-
return Err(from_kind(ErrorKind::OutOfBounds));
58+
// check lengths of axes.
59+
let len = match dim.size_checked() {
60+
Some(l) => l,
61+
None => return Err(from_kind(ErrorKind::OutOfBounds)),
62+
};
63+
// check if strides are strictly positive (zero ok for len 0)
64+
for &s in strides.slice() {
65+
let s = s as Ixs;
66+
if s < 1 && (len != 0 || s < 0) {
67+
return Err(from_kind(ErrorKind::Unsupported));
8068
}
81-
let mut last_index = dim.clone();
82-
for mut index in last_index.slice_mut().iter_mut() {
83-
*index -= 1;
84-
}
85-
if let Some(offset) = stride_offset_checked_arithmetic(dim,
86-
strides,
87-
&last_index)
88-
{
89-
// offset is guaranteed to be positive so no issue converting
90-
// to usize here
91-
if (offset as usize) >= data.len() {
92-
return Err(from_kind(ErrorKind::OutOfBounds));
93-
}
94-
if dim_stride_overlap(dim, strides) {
95-
return Err(from_kind(ErrorKind::Unsupported));
96-
}
97-
} else {
69+
}
70+
if len == 0 {
71+
return Ok(());
72+
}
73+
// check that the maximum index is in bounds
74+
let mut last_index = dim.clone();
75+
for mut index in last_index.slice_mut().iter_mut() {
76+
*index -= 1;
77+
}
78+
if let Some(offset) = stride_offset_checked_arithmetic(dim,
79+
strides,
80+
&last_index)
81+
{
82+
// offset is guaranteed to be positive so no issue converting
83+
// to usize here
84+
if (offset as usize) >= data.len() {
9885
return Err(from_kind(ErrorKind::OutOfBounds));
9986
}
100-
Ok(())
87+
if dim_stride_overlap(dim, strides) {
88+
return Err(from_kind(ErrorKind::Unsupported));
89+
}
10190
} else {
102-
return Err(from_kind(ErrorKind::Unsupported));
91+
return Err(from_kind(ErrorKind::OutOfBounds));
10392
}
93+
Ok(())
10494
}
10595

10696
/// Return stride offset for this dimension and index.
@@ -335,6 +325,21 @@ pub unsafe trait Dimension : Clone + Eq + Debug + Send + Sync {
335325
offset
336326
}
337327

328+
/// Return the axis ordering corresponding to the fastest variation
329+
/// (in ascending order).
330+
///
331+
/// Assumes that no stride value appears twice. This cannot yield the correct
332+
/// result the strides are not positive.
333+
#[doc(hidden)]
334+
fn _fastest_varying_stride_order(&self) -> Self {
335+
let mut indices = self.clone();
336+
for (i, elt) in enumerate(indices.slice_mut()) {
337+
*elt = i;
338+
}
339+
let strides = self.slice();
340+
indices.slice_mut().sort_by_key(|&i| strides[i]);
341+
indices
342+
}
338343
}
339344

340345
/// Implementation-specific extensions to `Dimension`
@@ -484,6 +489,11 @@ unsafe impl Dimension for (Ix, Ix) {
484489
(self.1, 1)
485490
}
486491

492+
#[inline]
493+
fn _fastest_varying_stride_order(&self) -> Self {
494+
if self.0 as Ixs <= self.1 as Ixs { (0, 1) } else { (1, 0) }
495+
}
496+
487497
#[inline]
488498
fn first_index(&self) -> Option<(Ix, Ix)> {
489499
let (m, n) = *self;
@@ -563,6 +573,29 @@ unsafe impl Dimension for (Ix, Ix, Ix) {
563573
let (s, t, u) = *strides;
564574
stride_offset(i, s) + stride_offset(j, t) + stride_offset(k, u)
565575
}
576+
577+
#[inline]
578+
fn _fastest_varying_stride_order(&self) -> Self {
579+
let mut stride = *self;
580+
let mut order = (0, 1, 2);
581+
macro_rules! swap {
582+
($stride:expr, $order:expr, $x:expr, $y:expr) => {
583+
if $stride[$x] > $stride[$y] {
584+
$stride.swap($x, $y);
585+
$order.swap($x, $y);
586+
}
587+
}
588+
}
589+
{
590+
// stable sorting network for 3 elements
591+
let order = order.slice_mut();
592+
let strides = stride.slice_mut();
593+
swap![strides, order, 1, 2];
594+
swap![strides, order, 0, 1];
595+
swap![strides, order, 1, 2];
596+
}
597+
order
598+
}
566599
}
567600

568601
macro_rules! large_dim {
@@ -742,13 +775,6 @@ mod test {
742775
use super::Dimension;
743776
use error::StrideError;
744777

745-
#[test]
746-
fn fastest_varying_order() {
747-
let strides = (2, 8, 4, 1);
748-
let order = super::fastest_varying_order(&strides);
749-
assert_eq!(order.slice(), &[3, 0, 2, 1]);
750-
}
751-
752778
#[test]
753779
fn slice_indexing_uncommon_strides() {
754780
let v: Vec<_> = (0..12).collect();

src/error.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,13 @@ impl Error for ShapeError {
8787

8888
impl fmt::Display for ShapeError {
8989
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
90-
self.description().fmt(f)
90+
write!(f, "ShapeError/{:?}: {}", self.kind(), self.description())
9191
}
9292
}
9393

9494
impl fmt::Debug for ShapeError {
9595
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
96-
write!(f, "ShapeError {:?}: {}", self.kind(), self.description())
96+
write!(f, "ShapeError/{:?}: {}", self.kind(), self.description())
9797
}
9898
}
9999

0 commit comments

Comments
 (0)