diff --git a/.travis.yml b/.travis.yml index bc9a7c5e7..eee94920e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -23,8 +23,7 @@ addons: - libopenblas-dev script: - | - cargo build --verbose && - cargo test --verbose && - ([ -z "$FEATURES" ] || cargo build --verbose --features "$FEATURES") && - ([ -z "$FEATURES" ] || cargo test --verbose --features "$FEATURES") && + cargo build --verbose --features "$FEATURES" && + cargo test --verbose --features "$FEATURES" && + cargo test --release --verbose --features "" && ([ "$BENCH" != 1 ] || cargo bench --no-run --verbose --features "$FEATURES") diff --git a/ndarray-tests/tests/accuracy.rs b/ndarray-tests/tests/accuracy.rs index 5bd8bc3b5..2a55989c2 100644 --- a/ndarray-tests/tests/accuracy.rs +++ b/ndarray-tests/tests/accuracy.rs @@ -47,7 +47,7 @@ fn reference_mat_mul(lhs: &ArrayBase, rhs: &ArrayBase, at your // option. This file may not be copied, modified, or distributed // except according to those terms. -#[cfg(feature = "rustc-serialize")] -use serialize::{Encodable, Encoder, Decodable, Decoder}; use std::hash; use std::iter::FromIterator; diff --git a/src/dimension.rs b/src/dimension.rs index 09ad21b62..c39dd8fa4 100644 --- a/src/dimension.rs +++ b/src/dimension.rs @@ -639,6 +639,8 @@ pub trait RemoveAxis : Dimension { } macro_rules! impl_shrink( + ($_a:ident, ) => {}; // implement this case manually below + ($_a:ident, $_b:ident, ) => {}; // implement this case manually below ($from:ident, $($more:ident,)*) => ( impl RemoveAxis for ($from $(,$more)*) { @@ -665,6 +667,22 @@ impl RemoveAxis for ($from $(,$more)*) ) ); +impl RemoveAxis for Ix { + type Smaller = (); + #[inline] + fn remove_axis(&self, _: Axis) { } +} + +impl RemoveAxis for (Ix, Ix) { + type Smaller = Ix; + #[inline] + fn remove_axis(&self, axis: Axis) -> Ix { + let axis = axis.axis(); + debug_assert!(axis < self.ndim()); + if axis == 0 { self.1 } else { self.0 } + } +} + macro_rules! impl_shrink_recursive( ($ix:ident, ) => (impl_shrink!($ix,);); ($ix1:ident, $($ix:ident,)*) => ( diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 266d1115b..983a1b3ad 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -896,10 +896,9 @@ impl ArrayBase where S: Data, D: Dimension } { - let mut new_stride_iter = new_stride.slice_mut().iter_mut().rev(); for ((er, es), dr) in from.slice().iter().rev() .zip(stride.slice().iter().rev()) - .zip(new_stride_iter.by_ref()) + .zip(new_stride.slice_mut().iter_mut().rev()) { /* update strides */ if *dr == *er { @@ -914,7 +913,8 @@ impl ArrayBase where S: Data, D: Dimension } /* set remaining strides to zero */ - for dr in new_stride_iter { + let tail_len = to.ndim() - from.ndim(); + for dr in &mut new_stride.slice_mut()[..tail_len] { *dr = 0; } } diff --git a/tests/broadcast.rs b/tests/broadcast.rs index fa66bd192..daaa3cd9d 100644 --- a/tests/broadcast.rs +++ b/tests/broadcast.rs @@ -1,7 +1,7 @@ extern crate ndarray; -use ndarray::{RcArray, Dimension}; +use ndarray::prelude::*; #[test] fn broadcast_1() @@ -49,3 +49,37 @@ fn test_add_incompat() let incompat = RcArray::from_elem(3, 1.0f32); a += &incompat; } + +#[test] +fn test_broadcast() { + let (_, n, k) = (16, 16, 16); + let x1 = 1.; + // b0 broadcast 1 -> n, k + let x = Array::from_vec(vec![x1]); + let b0 = x.broadcast((n, k)).unwrap(); + // b1 broadcast n -> n, k + let b1 = Array::from_elem(n, x1); + let b1 = b1.broadcast((n, k)).unwrap(); + // b2 is n, k + let b2 = Array::from_elem((n, k), x1); + + println!("b0=\n{:?}", b0); + println!("b1=\n{:?}", b1); + println!("b2=\n{:?}", b2); + assert_eq!(b0, b1); + assert_eq!(b0, b2); +} + +#[test] +fn test_broadcast_1d() { + let n = 16; + let x1 = 1.; + // b0 broadcast 1 -> n + let x = Array::from_vec(vec![x1]); + let b0 = x.broadcast(n).unwrap(); + let b2 = Array::from_elem(n, x1); + + println!("b0=\n{:?}", b0); + println!("b2=\n{:?}", b2); + assert_eq!(b0, b2); +}