Skip to content

Commit c73a0b1

Browse files
committed
Use rayon-git to implement IntoParallelIterator for array views
1 parent 300c157 commit c73a0b1

File tree

3 files changed

+67
-2
lines changed

3 files changed

+67
-2
lines changed

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ optional = true
3939
blas-sys = { version = "0.6.3", optional = true, default-features = false }
4040
matrixmultiply = { version = "0.1.11" }
4141

42-
#rayon = { git = "https://github.com/nikomatsakis/rayon" }
43-
rayon = { version = "*" }
42+
rayon = { git = "https://github.com/nikomatsakis/rayon" }
43+
#rayon = { version = "*" }
4444

4545
[dependencies.serde]
4646
version = "0.8"

src/iterators/par.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ use rayon::par_iter::ExactParallelIterator;
77
use rayon::par_iter::BoundedParallelIterator;
88
use rayon::par_iter::internal::{Consumer, UnindexedConsumer};
99
use rayon::par_iter::internal::bridge;
10+
use rayon::par_iter::internal::bridge_unindexed;
1011
use rayon::par_iter::internal::ProducerCallback;
1112
use rayon::par_iter::internal::Producer;
13+
use rayon::par_iter::internal::UnindexedProducer;
1214

1315
use super::AxisIter;
1416
use super::AxisIterMut;
@@ -112,3 +114,55 @@ macro_rules! par_iter_wrapper {
112114

113115
par_iter_wrapper!(AxisIter, [Sync]);
114116
par_iter_wrapper!(AxisIterMut, [Send + Sync]);
117+
118+
macro_rules! par_iter_view_wrapper {
119+
// thread_bounds are either Sync or Send + Sync
120+
($view_name:ident, [$($thread_bounds:tt)*]) => {
121+
impl<'a, A, D> IntoParallelIterator for $view_name<'a, A, D>
122+
where D: Dimension,
123+
A: $($thread_bounds)*,
124+
{
125+
type Item = <Self as IntoIterator>::Item;
126+
type Iter = Parallel<Self>;
127+
fn into_par_iter(self) -> Self::Iter {
128+
Parallel {
129+
iter: self,
130+
}
131+
}
132+
}
133+
134+
135+
impl<'a, A, D> ParallelIterator for Parallel<$view_name<'a, A, D>>
136+
where D: Dimension,
137+
A: $($thread_bounds)*,
138+
{
139+
type Item = <$view_name<'a, A, D> as IntoIterator>::Item;
140+
fn drive_unindexed<C>(self, consumer: C) -> C::Result
141+
where C: UnindexedConsumer<Self::Item>
142+
{
143+
bridge_unindexed(self.iter, consumer)
144+
}
145+
}
146+
147+
impl<'a, A, D> UnindexedProducer for $view_name<'a, A, D>
148+
where D: Dimension,
149+
A: $($thread_bounds)*,
150+
{
151+
fn can_split(&self) -> bool {
152+
self.len() > 1
153+
}
154+
155+
fn split(self) -> (Self, Self) {
156+
let max_axis = self.max_stride_axis();
157+
let mid = self.len_of(max_axis);
158+
self.split_at(max_axis, mid)
159+
}
160+
}
161+
162+
}
163+
}
164+
165+
use super::Iter;
166+
167+
par_iter_view_wrapper!(ArrayView, [Sync]);
168+
par_iter_view_wrapper!(ArrayViewMut, [Sync + Send]);

tests/rayon.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,14 @@ fn test_axis_iter_mut() {
3030
println!("{:?}", a.slice(s![..10, ..5]));
3131
assert!(a.all_close(&b, 0.001));
3232
}
33+
34+
#[test]
35+
fn test_regular_iter() {
36+
let mut a = Array2::<f64>::zeros((M, N));
37+
for (i, mut v) in a.axis_iter_mut(Axis(0)).enumerate() {
38+
v.fill(i as _);
39+
}
40+
let s = a.view().into_par_iter().map(|&x| x).sum();
41+
println!("{:?}", a.slice(s![..10, ..5]));
42+
assert_eq!(s, a.scalar_sum());
43+
}

0 commit comments

Comments
 (0)