Skip to content

Commit 1bb8849

Browse files
authored
Merge pull request #216 from bluss/broadcast
Work around compiler issues with .broadcast() and .remove_axis()
2 parents 1d7082d + 398215a commit 1bb8849

File tree

6 files changed

+60
-11
lines changed

6 files changed

+60
-11
lines changed

.travis.yml

+3-4
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ addons:
2323
- libopenblas-dev
2424
script:
2525
- |
26-
cargo build --verbose &&
27-
cargo test --verbose &&
28-
([ -z "$FEATURES" ] || cargo build --verbose --features "$FEATURES") &&
29-
([ -z "$FEATURES" ] || cargo test --verbose --features "$FEATURES") &&
26+
cargo build --verbose --features "$FEATURES" &&
27+
cargo test --verbose --features "$FEATURES" &&
28+
cargo test --release --verbose --features "" &&
3029
([ "$BENCH" != 1 ] || cargo bench --no-run --verbose --features "$FEATURES")

ndarray-tests/tests/accuracy.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ fn reference_mat_mul<A, S, S2>(lhs: &ArrayBase<S, (Ix, Ix)>, rhs: &ArrayBase<S2,
4747
}
4848
}
4949
unsafe {
50-
ArrayBase::from_vec_dim_unchecked((m, n), res_elems)
50+
ArrayBase::from_shape_vec_unchecked((m, n), res_elems)
5151
}
5252
}
5353

src/arraytraits.rs

-2
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
8-
#[cfg(feature = "rustc-serialize")]
9-
use serialize::{Encodable, Encoder, Decodable, Decoder};
108

119
use std::hash;
1210
use std::iter::FromIterator;

src/dimension.rs

+18
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,8 @@ pub trait RemoveAxis : Dimension {
639639
}
640640

641641
macro_rules! impl_shrink(
642+
($_a:ident, ) => {}; // implement this case manually below
643+
($_a:ident, $_b:ident, ) => {}; // implement this case manually below
642644
($from:ident, $($more:ident,)*) => (
643645
impl RemoveAxis for ($from $(,$more)*)
644646
{
@@ -665,6 +667,22 @@ impl RemoveAxis for ($from $(,$more)*)
665667
)
666668
);
667669

670+
impl RemoveAxis for Ix {
671+
type Smaller = ();
672+
#[inline]
673+
fn remove_axis(&self, _: Axis) { }
674+
}
675+
676+
impl RemoveAxis for (Ix, Ix) {
677+
type Smaller = Ix;
678+
#[inline]
679+
fn remove_axis(&self, axis: Axis) -> Ix {
680+
let axis = axis.axis();
681+
debug_assert!(axis < self.ndim());
682+
if axis == 0 { self.1 } else { self.0 }
683+
}
684+
}
685+
668686
macro_rules! impl_shrink_recursive(
669687
($ix:ident, ) => (impl_shrink!($ix,););
670688
($ix1:ident, $($ix:ident,)*) => (

src/impl_methods.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -896,10 +896,9 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
896896
}
897897

898898
{
899-
let mut new_stride_iter = new_stride.slice_mut().iter_mut().rev();
900899
for ((er, es), dr) in from.slice().iter().rev()
901900
.zip(stride.slice().iter().rev())
902-
.zip(new_stride_iter.by_ref())
901+
.zip(new_stride.slice_mut().iter_mut().rev())
903902
{
904903
/* update strides */
905904
if *dr == *er {
@@ -914,7 +913,8 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
914913
}
915914

916915
/* set remaining strides to zero */
917-
for dr in new_stride_iter {
916+
let tail_len = to.ndim() - from.ndim();
917+
for dr in &mut new_stride.slice_mut()[..tail_len] {
918918
*dr = 0;
919919
}
920920
}

tests/broadcast.rs

+35-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
extern crate ndarray;
33

4-
use ndarray::{RcArray, Dimension};
4+
use ndarray::prelude::*;
55

66
#[test]
77
fn broadcast_1()
@@ -49,3 +49,37 @@ fn test_add_incompat()
4949
let incompat = RcArray::from_elem(3, 1.0f32);
5050
a += &incompat;
5151
}
52+
53+
#[test]
54+
fn test_broadcast() {
55+
let (_, n, k) = (16, 16, 16);
56+
let x1 = 1.;
57+
// b0 broadcast 1 -> n, k
58+
let x = Array::from_vec(vec![x1]);
59+
let b0 = x.broadcast((n, k)).unwrap();
60+
// b1 broadcast n -> n, k
61+
let b1 = Array::from_elem(n, x1);
62+
let b1 = b1.broadcast((n, k)).unwrap();
63+
// b2 is n, k
64+
let b2 = Array::from_elem((n, k), x1);
65+
66+
println!("b0=\n{:?}", b0);
67+
println!("b1=\n{:?}", b1);
68+
println!("b2=\n{:?}", b2);
69+
assert_eq!(b0, b1);
70+
assert_eq!(b0, b2);
71+
}
72+
73+
#[test]
74+
fn test_broadcast_1d() {
75+
let n = 16;
76+
let x1 = 1.;
77+
// b0 broadcast 1 -> n
78+
let x = Array::from_vec(vec![x1]);
79+
let b0 = x.broadcast(n).unwrap();
80+
let b2 = Array::from_elem(n, x1);
81+
82+
println!("b0=\n{:?}", b0);
83+
println!("b2=\n{:?}", b2);
84+
assert_eq!(b0, b2);
85+
}

0 commit comments

Comments
 (0)