Skip to content

Commit 53cd653

Browse files
committed
FEAT: Implement order optimizations for Baseiter
Implement axis merging: this preserves the order of elements in the iteration but can simplify the iteration itself. For example, in a contiguous matrix, a shape like [3, 4] can be merged into [1, 12]. Also allow arbitrary-order optimization: for now, this means we try to iterate in memory order by sorting all axes.
1 parent 6560e3b commit 53cd653

File tree

1 file changed

+217
-7
lines changed

1 file changed

+217
-7
lines changed

src/iterators/mod.rs

Lines changed: 217 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,58 @@ mod windows;
1717
use std::iter::FromIterator;
1818
use std::marker::PhantomData;
1919
use std::ptr;
20+
use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};
2021
use alloc::vec::Vec;
2122

23+
use crate::imp_prelude::*;
2224
use crate::Ix1;
2325

24-
use super::{ArrayBase, ArrayView, ArrayViewMut, Axis, Data, NdProducer, RemoveAxis};
25-
use super::{Dimension, Ix, Ixs};
26+
use super::{NdProducer, RemoveAxis};
2627

2728
pub use self::chunks::{ExactChunks, ExactChunksIter, ExactChunksIterMut, ExactChunksMut};
2829
pub use self::lanes::{Lanes, LanesMut};
2930
pub use self::windows::Windows;
3031
pub use self::into_iter::IntoIter;
3132

32-
use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};
33+
use crate::dimension;
34+
35+
/// No traversal optimizations that would change element order or axis dimensions are permitted.
36+
///
37+
/// This option is suitable for example for the indexed iterator.
38+
pub(crate) enum NoOptimization { }
39+
40+
/// Preserve element iteration order, but modify dimensions if profitable; for example we can
41+
/// change from shape [10, 1] to [1, 10], because that axis has len == 1, without consequence here.
42+
///
43+
/// This option is suitable for example for the default .iter() iterator.
44+
pub(crate) enum PreserveOrder { }
45+
46+
/// Allow use of arbitrary element iteration order
47+
///
48+
/// This option is suitable for example for an arbitrary order iterator.
49+
pub(crate) enum ArbitraryOrder { }
50+
51+
pub(crate) trait OrderOption {
52+
const ALLOW_REMOVE_REDUNDANT_AXES: bool = false;
53+
const ALLOW_ARBITRARY_ORDER: bool = false;
54+
}
55+
56+
impl OrderOption for NoOptimization { }
57+
58+
impl OrderOption for PreserveOrder {
59+
const ALLOW_REMOVE_REDUNDANT_AXES: bool = true;
60+
}
61+
62+
impl OrderOption for ArbitraryOrder {
63+
const ALLOW_REMOVE_REDUNDANT_AXES: bool = true;
64+
const ALLOW_ARBITRARY_ORDER: bool = true;
65+
}
3366

3467
/// Base for iterators over all axes.
3568
///
3669
/// Iterator element type is `*mut A`.
70+
///
71+
/// Layout/iteration order flags are selected through the `Flags: OrderOption` type parameter of `new_with_order`.
3772
pub(crate) struct Baseiter<A, D> {
3873
ptr: *mut A,
3974
dim: D,
@@ -46,12 +81,43 @@ impl<A, D: Dimension> Baseiter<A, D> {
4681
/// to be correct to avoid performing an unsafe pointer offset while
4782
/// iterating.
4883
#[inline]
49-
pub unsafe fn new(ptr: *mut A, len: D, stride: D) -> Baseiter<A, D> {
84+
pub unsafe fn new(ptr: *mut A, dim: D, strides: D) -> Baseiter<A, D> {
85+
Self::new_with_order::<NoOptimization>(ptr, dim, strides)
86+
}
87+
}
88+
89+
impl<A, D: Dimension> Baseiter<A, D> {
90+
/// Creating a Baseiter is unsafe because shape and stride parameters need
91+
/// to be correct to avoid performing an unsafe pointer offset while
92+
/// iterating.
93+
#[inline]
94+
pub unsafe fn new_with_order<Flags: OrderOption>(mut ptr: *mut A, mut dim: D, mut strides: D)
95+
-> Baseiter<A, D>
96+
{
97+
debug_assert_eq!(dim.ndim(), strides.ndim());
98+
if Flags::ALLOW_ARBITRARY_ORDER {
99+
// iterate in memory order; merge axes if possible
100+
// make all strides positive and move the pointer back to the first element in memory
101+
let offset = dimension::offset_from_low_addr_ptr_to_logical_ptr(&dim, &strides);
102+
ptr = ptr.sub(offset);
103+
for i in 0..strides.ndim() {
104+
let s = strides.get_stride(Axis(i));
105+
if s < 0 {
106+
strides.set_stride(Axis(i), -s);
107+
}
108+
}
109+
dimension::sort_axes_to_standard(&mut dim, &mut strides);
110+
}
111+
if Flags::ALLOW_REMOVE_REDUNDANT_AXES {
112+
// preserve element order but shift dimensions
113+
dimension::merge_axes_from_the_back(&mut dim, &mut strides);
114+
dimension::squeeze(&mut dim, &mut strides);
115+
}
50116
Baseiter {
51117
ptr,
52-
index: len.first_index(),
53-
dim: len,
54-
strides: stride,
118+
index: dim.first_index(),
119+
dim,
120+
strides,
55121
}
56122
}
57123
}
@@ -1499,3 +1565,147 @@ where
14991565
debug_assert_eq!(size, result.len());
15001566
result
15011567
}
1568+
1569+
#[cfg(test)]
1570+
#[cfg(feature = "std")]
1571+
mod tests {
1572+
use crate::prelude::*;
1573+
use super::Baseiter;
1574+
use super::{ArbitraryOrder, PreserveOrder, NoOptimization};
1575+
use itertools::assert_equal;
1576+
use itertools::Itertools;
1577+
1578+
// 3-d axis swaps
1579+
fn swaps() -> impl Iterator<Item=Vec<(usize, usize)>> {
1580+
vec![
1581+
vec![],
1582+
vec![(0, 1)],
1583+
vec![(0, 2)],
1584+
vec![(1, 2)],
1585+
vec![(0, 1), (1, 2)],
1586+
vec![(0, 1), (0, 2)],
1587+
].into_iter()
1588+
}
1589+
1590+
// 3-d axis inverts
1591+
fn inverts() -> impl Iterator<Item=Vec<Axis>> {
1592+
vec![
1593+
vec![],
1594+
vec![Axis(0)],
1595+
vec![Axis(1)],
1596+
vec![Axis(2)],
1597+
vec![Axis(0), Axis(1)],
1598+
vec![Axis(0), Axis(2)],
1599+
vec![Axis(1), Axis(2)],
1600+
vec![Axis(0), Axis(1), Axis(2)],
1601+
].into_iter()
1602+
}
1603+
1604+
#[test]
1605+
fn test_arbitrary_order() {
1606+
for swap in swaps() {
1607+
for invert in inverts() {
1608+
for &slice in &[false, true] {
1609+
// when sliced, the remaining element pattern is 0, 1; 4, 5; 8, 9; etc.
1610+
let mut a = Array::from_iter(0..24).into_shape((3, 4, 2)).unwrap();
1611+
if slice {
1612+
a.slice_collapse(s![.., ..;2, ..]);
1613+
}
1614+
for &(i, j) in &swap {
1615+
a.swap_axes(i, j);
1616+
}
1617+
for &i in &invert {
1618+
a.invert_axis(i);
1619+
}
1620+
unsafe {
1621+
// Should have in-memory order for arbitrary order
1622+
let iter = Baseiter::new_with_order::<ArbitraryOrder>(a.as_mut_ptr(),
1623+
a.dim, a.strides);
1624+
if !slice {
1625+
assert_equal(iter.map(|ptr| *ptr), 0..a.len());
1626+
} else {
1627+
assert_eq!(iter.map(|ptr| *ptr).collect_vec(),
1628+
(0..a.len() * 2).filter(|&x| (x / 2) % 2 == 0).collect_vec());
1629+
}
1630+
}
1631+
}
1632+
}
1633+
}
1634+
}
1635+
1636+
#[test]
1637+
fn test_logical_order() {
1638+
for swap in swaps() {
1639+
for invert in inverts() {
1640+
for &slice in &[false, true] {
1641+
let mut a = Array::from_iter(0..24).into_shape((3, 4, 2)).unwrap();
1642+
for &(i, j) in &swap {
1643+
a.swap_axes(i, j);
1644+
}
1645+
for &i in &invert {
1646+
a.invert_axis(i);
1647+
}
1648+
if slice {
1649+
a.slice_collapse(s![.., ..;2, ..]);
1650+
}
1651+
1652+
unsafe {
1653+
let mut iter = Baseiter::new_with_order::<NoOptimization>(a.as_mut_ptr(),
1654+
a.dim, a.strides);
1655+
let mut index = Dim([0, 0, 0]);
1656+
let mut elts = 0;
1657+
while let Some(elt) = iter.next() {
1658+
assert_eq!(*elt, a[index]);
1659+
if let Some(index_) = a.raw_dim().next_for(index) {
1660+
index = index_;
1661+
}
1662+
elts += 1;
1663+
}
1664+
assert_eq!(elts, a.len());
1665+
}
1666+
}
1667+
}
1668+
}
1669+
}
1670+
1671+
#[test]
1672+
fn test_preserve_order() {
1673+
for swap in swaps() {
1674+
for invert in inverts() {
1675+
for &slice in &[false, true] {
1676+
let mut a = Array::from_iter(0..20).into_shape((2, 10, 1)).unwrap();
1677+
for &(i, j) in &swap {
1678+
a.swap_axes(i, j);
1679+
}
1680+
for &i in &invert {
1681+
a.invert_axis(i);
1682+
}
1683+
if slice {
1684+
a.slice_collapse(s![.., ..;2, ..]);
1685+
}
1686+
1687+
unsafe {
1688+
let mut iter = Baseiter::new_with_order::<PreserveOrder>(
1689+
a.as_mut_ptr(), a.dim, a.strides);
1690+
1691+
// check that axes have been merged (when it's easy to check)
1692+
if a.shape() == &[2, 10, 1] && invert.is_empty() {
1693+
assert_eq!(iter.dim, Dim([1, 1, 20]));
1694+
}
1695+
1696+
let mut index = Dim([0, 0, 0]);
1697+
let mut elts = 0;
1698+
while let Some(elt) = iter.next() {
1699+
assert_eq!(*elt, a[index]);
1700+
if let Some(index_) = a.raw_dim().next_for(index) {
1701+
index = index_;
1702+
}
1703+
elts += 1;
1704+
}
1705+
assert_eq!(elts, a.len());
1706+
}
1707+
}
1708+
}
1709+
}
1710+
}
1711+
}

0 commit comments

Comments
 (0)