diff --git a/Cargo.toml b/Cargo.toml
index 10a9f43b9..9e0d93192 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -28,6 +28,7 @@ num-traits = "0.2"
 num-complex = "0.2"
 rustc-serialize = { version = "0.3.20", optional = true }
 itertools = { version = "0.7.0", default-features = false }
+rand = { version = "0.5" }
 
 # Use via the `blas` crate feature!
 cblas-sys = { version = "0.1.4", optional = true, default-features = false }
diff --git a/src/impl_1d.rs b/src/impl_1d.rs
index 4d013e7d5..e69162259 100644
--- a/src/impl_1d.rs
+++ b/src/impl_1d.rs
@@ -10,6 +10,9 @@
 //! Methods for one-dimensional arrays.
 use imp_prelude::*;
 
+use rand::prelude::*;
+use rand::thread_rng;
+
 impl<A, S> ArrayBase<S, Ix1>
     where S: Data<Elem=A>,
 {
@@ -23,5 +26,95 @@ impl<A, S> ArrayBase<S, Ix1>
             ::iterators::to_vec(self.iter().map(|x| x.clone()))
         }
     }
-}
 
+    /// Return the element that would occupy the `i`-th position if
+    /// the array were sorted in increasing order.
+    ///
+    /// The array is shuffled **in place** to retrieve the desired element:
+    /// no copy of the array is allocated.
+    /// After the shuffling, all elements with an index smaller than `i`
+    /// are smaller than or equal to the desired element, while all elements
+    /// with an index greater than or equal to `i` are greater than or equal
+    /// to the desired element.
+    ///
+    /// No other assumptions should be made on the ordering of the
+    /// elements after this computation.
+    ///
+    /// Complexity ([quickselect](https://en.wikipedia.org/wiki/Quickselect)):
+    /// - average case: O(`n`);
+    /// - worst case: O(`n`^2);
+    /// where `n` is the number of elements in the array.
+    ///
+    /// **Panics** if `i` is greater than or equal to `n`.
+    pub fn sorted_get_mut(&mut self, i: usize) -> A
+        where A: Ord + Clone,
+              S: DataMut,
+    {
+        let n = self.len();
+        if n == 1 {
+            self[i].clone() // base case; indexing with `i` keeps the documented out-of-bounds panic
+        } else {
+            let mut rng = thread_rng();
+            let pivot_index = rng.gen_range(0, n); // random pivot
+            let partition_index = self.partition_mut(pivot_index);
+            if i < partition_index {
+                self.slice_mut(s![..partition_index]).sorted_get_mut(i)
+            } else if i == partition_index {
+                self[i].clone()
+            } else {
+                self.slice_mut(s![partition_index+1..]).sorted_get_mut(i - (partition_index+1))
+            }
+        }
+    }
+
+    /// Return the index `self[pivot_index]` would have if `self` were to be sorted
+    /// in increasing order.
+    ///
+    /// `self` elements are rearranged in such a way that the pivot value
+    /// is in the position it would be in an array sorted in increasing order.
+    /// All elements smaller than the pivot value are moved to its
+    /// left and all elements equal or greater than the pivot value
+    /// are moved to its right.
+    /// The ordering of the elements in the two partitions is undefined.
+    ///
+    /// `self` is shuffled **in place** to operate the desired partition:
+    /// no copy of the array is allocated.
+    ///
+    /// The method uses Hoare's partition algorithm.
+    /// Complexity: O(`n`), where `n` is the number of elements in the array.
+    /// Average number of element swaps: n/6 - 1/3 (see
+    /// [link](https://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto/11550))
+    ///
+    /// **Panics** if `pivot_index` is greater than or equal to `n`.
+    pub fn partition_mut(&mut self, pivot_index: usize) -> usize
+        where A: Ord + Clone,
+              S: DataMut
+    {
+        let pivot_value = self[pivot_index].clone();
+        self.swap(pivot_index, 0); // park the pivot at the front while `1..n` is partitioned
+
+        let n = self.len();
+        let mut i = 1;
+        let mut j = n - 1;
+        loop {
+            loop {
+                if i > j { break }
+                if self[i] >= pivot_value { break }
+                i += 1;
+            }
+            while pivot_value <= self[j] {
+                if j <= 1 { break } // `<=`, not `==`: `j` starts at 0 for a one-element array
+                j -= 1;
+            }
+            if i >= j {
+                break
+            } else {
+                self.swap(i, j);
+                i += 1;
+                j -= 1;
+            }
+        }
+        self.swap(0, i-1); // drop the pivot between the two partitions
+        i-1
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 9868b0c53..bdf2cb352 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -100,6 +100,8 @@ extern crate matrixmultiply;
 extern crate num_traits as libnum;
 extern crate num_complex;
 
+extern crate rand;
+
 #[cfg(feature = "docs")]
 pub mod doc;
 
@@ -158,6 +160,7 @@ mod free_functions;
 pub use free_functions::*;
 pub use iterators::iter;
 
+#[macro_use]
 mod slice;
 mod layout;
 mod indexes;
@@ -893,7 +896,7 @@ impl<A, S, D> ArrayBase<S, D>
 mod impl_1d;
 mod impl_2d;
 
-mod numeric;
+pub mod numeric;
 
 pub mod linalg;
 
diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs
index 2f0b1b054..8ffdc354c 100644
--- a/src/numeric/impl_numeric.rs
+++ b/src/numeric/impl_numeric.rs
@@ -6,8 +6,8 @@
 // option. This file may not be copied, modified, or distributed
 // except according to those terms.
 
-use std::ops::Add;
-use libnum::{self, Zero, Float};
+use std::ops::{Add, Sub, Div, Mul};
+use libnum::{self, Zero, Float, FromPrimitive};
 use itertools::free::enumerate;
 
 use imp_prelude::*;
@@ -19,6 +19,133 @@ use {
     Zip,
 };
 
+/// Used to provide an interpolation strategy to [`percentile_axis_mut`].
+///
+/// [`percentile_axis_mut`]: ../struct.ArrayBase.html#method.percentile_axis_mut
+pub trait Interpolate<T> {
+    fn float_percentile_index(q: f64, len: usize) -> f64 {
+        ((len - 1) as f64) * q // NOTE: `len - 1` underflows when `len == 0`
+    }
+
+    fn lower_index(q: f64, len: usize) -> usize {
+        Self::float_percentile_index(q, len).floor() as usize
+    }
+
+    fn upper_index(q: f64, len: usize) -> usize {
+        Self::float_percentile_index(q, len).ceil() as usize
+    }
+
+    fn float_percentile_index_fraction(q: f64, len: usize) -> f64 {
+        Self::float_percentile_index(q, len) - (Self::lower_index(q, len) as f64)
+    }
+
+    fn needs_lower(q: f64, len: usize) -> bool; // is the value at `lower_index` required?
+    fn needs_upper(q: f64, len: usize) -> bool; // is the value at `upper_index` required?
+    fn interpolate<D>(lower: Option<Array<T, D>>,
+                      upper: Option<Array<T, D>>,
+                      q: f64,
+                      len: usize) -> Array<T, D>
+        where D: Dimension;
+}
+
+pub struct Upper;
+pub struct Lower;
+pub struct Nearest;
+pub struct Midpoint;
+pub struct Linear;
+
+impl<T> Interpolate<T> for Upper {
+    fn needs_lower(_q: f64, _len: usize) -> bool {
+        false
+    }
+    fn needs_upper(_q: f64, _len: usize) -> bool {
+        true
+    }
+    fn interpolate<D>(_lower: Option<Array<T, D>>,
+                      upper: Option<Array<T, D>>,
+                      _q: f64,
+                      _len: usize) -> Array<T, D> {
+        upper.unwrap()
+    }
+}
+
+impl<T> Interpolate<T> for Lower {
+    fn needs_lower(_q: f64, _len: usize) -> bool {
+        true
+    }
+    fn needs_upper(_q: f64, _len: usize) -> bool {
+        false
+    }
+    fn interpolate<D>(lower: Option<Array<T, D>>,
+                      _upper: Option<Array<T, D>>,
+                      _q: f64,
+                      _len: usize) -> Array<T, D> {
+        lower.unwrap()
+    }
+}
+
+impl<T> Interpolate<T> for Nearest {
+    fn needs_lower(q: f64, len: usize) -> bool {
+        let fraction = <Self as Interpolate<T>>::float_percentile_index_fraction(q, len);
+        fraction < 0.5 // closer to the lower neighbour; an exact tie goes to the upper one
+    }
+    fn needs_upper(q: f64, len: usize) -> bool {
+        !<Self as Interpolate<T>>::needs_lower(q, len)
+    }
+    fn interpolate<D>(lower: Option<Array<T, D>>,
+                      upper: Option<Array<T, D>>,
+                      q: f64,
+                      len: usize) -> Array<T, D> {
+        if <Self as Interpolate<T>>::needs_lower(q, len) {
+            lower.unwrap()
+        } else {
+            upper.unwrap()
+        }
+    }
+}
+
+impl<T> Interpolate<T> for Midpoint
+    where T: Add<T, Output = T> + Div<T, Output = T> + Clone + FromPrimitive
+{
+    fn needs_lower(_q: f64, _len: usize) -> bool {
+        true
+    }
+    fn needs_upper(_q: f64, _len: usize) -> bool {
+        true
+    }
+    fn interpolate<D>(lower: Option<Array<T, D>>,
+                      upper: Option<Array<T, D>>,
+                      _q: f64, _len: usize) -> Array<T, D>
+        where D: Dimension
+    {
+        let denom = T::from_u8(2).unwrap();
+        (lower.unwrap() + upper.unwrap()).mapv_into(|x| x / denom.clone())
+    }
+}
+
+impl<T> Interpolate<T> for Linear
+    where T: Add<T, Output = T> + Sub<T, Output = T> + Mul<T, Output = T> + Clone + FromPrimitive
+{
+    fn needs_lower(_q: f64, _len: usize) -> bool {
+        true
+    }
+    fn needs_upper(_q: f64, _len: usize) -> bool {
+        true
+    }
+    fn interpolate<D>(lower: Option<Array<T, D>>,
+                      upper: Option<Array<T, D>>,
+                      q: f64, len: usize) -> Array<T, D>
+        where D: Dimension
+    {
+        let fraction = T::from_f64(
+            <Self as Interpolate<T>>::float_percentile_index_fraction(q, len)
+        ).unwrap();
+        let a = lower.unwrap().mapv_into(|x| x * (T::from_u8(1).unwrap() - fraction.clone()));
+        let b = upper.unwrap().mapv_into(|x| x * fraction.clone()); // weight grows towards `upper`
+        a + b
+    }
+}
+
 /// Numerical methods for arrays.
 impl<A, S, D> ArrayBase<S, D>
     where S: Data<Elem=A>,
@@ -115,6 +242,75 @@ impl<A, S, D> ArrayBase<S, D>
         sum / &aview0(&cnt)
     }
 
+    /// Return the qth percentile of the data along the specified axis.
+    ///
+    /// `q` needs to be a float between 0 and 1, bounds included.
+    /// The qth percentile for a 1-dimensional lane of length `N` is defined
+    /// as the element that would be indexed as `(N-1)q` if the lane were to be sorted
+    /// in increasing order.
+    /// If `(N-1)q` is not an integer the desired percentile lies between
+    /// two data points: we return the lower, nearest, higher or interpolated
+    /// value depending on the type `Interpolate` bound `I`.
+    ///
+    /// Some examples:
+    /// - `q=0.` returns the minimum along each 1-dimensional lane;
+    /// - `q=0.5` returns the median along each 1-dimensional lane;
+    /// - `q=1.` returns the maximum along each 1-dimensional lane.
+    /// (`q=0.` and `q=1.` are considered improper percentiles)
+    ///
+    /// The array is shuffled **in place** along each 1-dimensional lane in
+    /// order to produce the required percentile without allocating a copy
+    /// of the original array. Each 1-dimensional lane is shuffled independently
+    /// from the others.
+    /// No assumptions should be made on the ordering of the array elements
+    /// after this computation.
+    ///
+    /// Complexity ([quickselect](https://en.wikipedia.org/wiki/Quickselect)):
+    /// - average case: O(`m`);
+    /// - worst case: O(`m`^2);
+    /// where `m` is the number of elements in the array.
+    ///
+    /// **Panics** if `axis` is out of bounds or if `q` is not between
+    /// `0.` and `1.` (inclusive).
+    pub fn percentile_axis_mut<I>(&mut self, axis: Axis, q: f64) -> Array<A, D::Smaller>
+        where D: RemoveAxis,
+              A: Ord + Clone,
+              S: DataMut,
+              I: Interpolate<A>,
+    {
+        assert!((0. <= q) && (q <= 1.));
+        let mut lower = None;
+        let mut upper = None;
+        let axis_len = self.len_of(axis);
+        if I::needs_lower(q, axis_len) {
+            let lower_index = I::lower_index(q, axis_len);
+            lower = Some(
+                self.map_axis_mut(
+                    axis,
+                    |mut x| x.sorted_get_mut(lower_index)
+                )
+            );
+            if I::needs_upper(q, axis_len) {
+                let relative_upper_index = I::upper_index(q, axis_len) - lower_index; // within the tail
+                upper = Some(
+                    self.map_axis_mut(
+                        axis,
+                        |mut x| x.slice_mut(s![lower_index..]).sorted_get_mut(relative_upper_index)
+                    )
+                );
+            };
+        } else {
+            let upper_index = I::upper_index(q, axis_len); // same for every lane: compute it once
+            upper = Some(
+                self.map_axis_mut(
+                    axis,
+                    |mut x| x.sorted_get_mut(upper_index)
+                )
+            );
+        };
+        I::interpolate(lower, upper, q, axis_len)
+    }
+
     /// Return variance along `axis`.
     ///
     /// The variance is computed using the [Welford one-pass
@@ -201,4 +397,3 @@ impl<A, S, D> ArrayBase<S, D>
         }).is_done()
     }
 }
-
diff --git a/src/numeric/mod.rs b/src/numeric/mod.rs
index 60c5e39f6..30024d902 100644
--- a/src/numeric/mod.rs
+++ b/src/numeric/mod.rs
@@ -1,3 +1,3 @@
+pub use self::impl_numeric::*;
 
 mod impl_numeric;
-
diff --git a/tests/array.rs b/tests/array.rs
index 01f9041ec..75a2610fa 100644
--- a/tests/array.rs
+++ b/tests/array.rs
@@ -8,6 +8,7 @@ extern crate itertools;
 
 use ndarray::{Slice, SliceInfo, SliceOrIndex};
 use ndarray::prelude::*;
+use ndarray::numeric::Lower;
 use ndarray::{
     rcarr2,
     arr3,
@@ -1710,3 +1711,82 @@ fn array_macros() {
     let empty2: Array<f32, Ix2> = array![[]];
     assert_eq!(empty2, array![[]]);
 }
+
+#[test]
+fn test_partition_mut() {
+    let mut l = vec![
+        arr1(&[1, 1, 1, 1, 1]),
+        arr1(&[1, 3, 2, 10, 10]),
+        arr1(&[2, 3, 4, 1]),
+        arr1(&[355, 453, 452, 391, 289, 343, 44, 154, 271, 44, 314, 276, 160,
+               469, 191, 138, 163, 308, 395, 3, 416, 391, 210, 354, 200]),
+        arr1(&[84, 192, 216, 159, 89, 296, 35, 213, 456, 278, 98, 52, 308,
+               418, 329, 173, 286, 106, 366, 129, 125, 450, 23, 463, 151]),
+    ];
+
+    for a in l.iter_mut() {
+        let n = a.len();
+        let pivot_index = n-1;
+        let pivot_value = a[pivot_index].clone();
+        let partition_index = a.partition_mut(pivot_index);
+        for i in 0..partition_index {
+            assert!(a[i] < pivot_value); // strictly smaller on the left
+        }
+        assert_eq!(a[partition_index], pivot_value);
+        for j in (partition_index+1)..n {
+            assert!(pivot_value <= a[j]); // greater or equal on the right
+        }
+    }
+}
+
+#[test]
+fn test_sorted_get_mut() {
+    let a = arr1(&[1, 3, 2, 10]);
+    let j = a.clone().view_mut().sorted_get_mut(2);
+    assert_eq!(j, 3);
+    let j = a.clone().view_mut().sorted_get_mut(1);
+    assert_eq!(j, 2);
+    let j = a.clone().view_mut().sorted_get_mut(3);
+    assert_eq!(j, 10);
+}
+
+#[test]
+fn test_percentile_axis_mut_with_odd_axis_length() {
+    let mut a = arr2(
+        &[
+            [1, 3, 2, 10],
+            [2, 4, 3, 11],
+            [3, 5, 6, 12]
+        ]
+    );
+    let p = a.percentile_axis_mut::<Lower>(Axis(0), 0.5);
+    assert_eq!(p, arr1(&[2, 4, 3, 11])); // column-wise medians, independent of the shuffled `a`
+}
+
+#[test]
+fn test_percentile_axis_mut_with_even_axis_length() {
+    let mut b = arr2(
+        &[
+            [1, 3, 2, 10],
+            [2, 4, 3, 11],
+            [3, 5, 6, 12],
+            [4, 6, 7, 13]
+        ]
+    );
+    let q = b.percentile_axis_mut::<Lower>(Axis(0), 0.5);
+    assert_eq!(q, arr1(&[2, 4, 3, 11])); // `Lower` picks sorted index floor(1.5) = 1 per column
+}
+
+#[test]
+fn test_percentile_axis_mut_to_get_minimum() {
+    let mut b = arr2(&[[1, 3, 22, 10]]);
+    let q = b.percentile_axis_mut::<Lower>(Axis(1), 0.);
+    assert_eq!(q, arr1(&[1]));
+}
+
+#[test]
+fn test_percentile_axis_mut_to_get_maximum() {
+    let mut b = arr1(&[1, 3, 22, 10]);
+    let q = b.percentile_axis_mut::<Lower>(Axis(0), 1.);
+    assert_eq!(q, arr0(22));
+}