Skip to content
13 changes: 9 additions & 4 deletions src/impl_1d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,18 @@ impl<A, S> ArrayBase<S, Ix1>
}
}

/// Return the element that would occupy the `i`-th position if the array
/// were sorted in increasing order.
/// Return the element that would occupy the `i`-th position if
/// the array were sorted in increasing order.
///
/// The array is shuffled **in place** to retrieve the desired element:
/// no copy of the array is allocated.
/// No assumptions should be made on the ordering of elements
/// after this computation.
/// After the shuffling, all elements with an index smaller than `i`
/// are smaller than the desired element, while all elements with
/// an index greater or equal than `i` are greater than or equal
/// to the desired element.
///
/// No other assumptions should be made on the ordering of the
/// elements after this computation.
///
/// Complexity ([quickselect](https://en.wikipedia.org/wiki/Quickselect)):
/// - average case: O(`n`);
Expand Down
94 changes: 82 additions & 12 deletions src/numeric/impl_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,69 @@ use {
Zip,
};

/// Used to choose the interpolation strategy in [`percentile_axis_mut`].
/// Used to provide an interpolation strategy to [`percentile_axis_mut`].
///
/// [`percentile_axis_mut`]: struct.ArrayBase.html#method.percentile_axis_mut
pub enum InterpolationStrategy {
Lower,
Nearest,
Higher,
pub trait Interpolate<T> {
fn float_percentile_index(q: f64, len: usize) -> f64 {
((len - 1) as f64) * q
}

fn lower_index(q: f64, len: usize) -> usize {
Self::float_percentile_index(q, len).floor() as usize
}

fn upper_index(q: f64, len: usize) -> usize {
Self::float_percentile_index(q, len).ceil() as usize
}
fn needs_lower(q: f64, len: usize) -> bool;
fn needs_upper(q: f64, len: usize) -> bool;
fn interpolate(lower: Option<T>, upper: Option<T>, q: f64, len: usize) -> T;
}

pub struct Upper;
pub struct Lower;
pub struct Nearest;

impl<T> Interpolate<T> for Upper {
fn needs_lower(_q: f64, _len: usize) -> bool {
false
}
fn needs_upper(_q: f64, _len: usize) -> bool {
true
}
fn interpolate(_lower: Option<T>, upper: Option<T>, _q: f64, _len: usize) -> T {
upper.unwrap()
}
}

impl<T> Interpolate<T> for Lower {
fn needs_lower(_q: f64, _len: usize) -> bool {
true
}
fn needs_upper(_q: f64, _len: usize) -> bool {
false
}
fn interpolate(lower: Option<T>, _upper: Option<T>, _q: f64, _len: usize) -> T {
lower.unwrap()
}
}

impl<T> Interpolate<T> for Nearest {
fn needs_lower(q: f64, len: usize) -> bool {
let lower = <Self as Interpolate<T>>::lower_index(q, len);
((lower as f64) - <Self as Interpolate<T>>::float_percentile_index(q, len)) <= 0.
}
fn needs_upper(q: f64, len: usize) -> bool {
!<Self as Interpolate<T>>::needs_lower(q, len)
}
fn interpolate(lower: Option<T>, upper: Option<T>, q: f64, len: usize) -> T {
if <Self as Interpolate<T>>::needs_lower(q, len) {
lower.unwrap()
} else {
upper.unwrap()
}
}
}

/// Numerical methods for arrays.
Expand Down Expand Up @@ -154,19 +210,33 @@ impl<A, S, D> ArrayBase<S, D>
///
/// **Panics** if `axis` is out of bounds or if `q` is not between
/// `0.` and `1.` (inclusive).
pub fn percentile_axis_mut(&mut self, axis: Axis, q: f64, interpolation_strategy: InterpolationStrategy) -> Array<A, D::Smaller>
pub fn percentile_axis_mut<I>(&mut self, axis: Axis, q: f64) -> Array<A, D::Smaller>
where D: RemoveAxis,
A: Ord + Clone + Zero,
S: DataMut,
I: Interpolate<Array<A, D::Smaller>>,
{
assert!((0. <= q) && (q <= 1.));
let float_percentile_index = ((self.len_of(axis) - 1) as f64) * q;
let percentile_index = match interpolation_strategy {
InterpolationStrategy::Lower => float_percentile_index.floor() as usize,
InterpolationStrategy::Nearest => float_percentile_index.round() as usize,
InterpolationStrategy::Higher => float_percentile_index.ceil() as usize,
let mut lower = None;
let mut upper = None;
let axis_len = self.len_of(axis);
if I::needs_lower(q, axis_len) {
lower = Some(
self.map_axis_mut(
axis,
|mut x| x.sorted_get_mut(I::lower_index(q, axis_len))
)
);
};
if I::needs_upper(q, axis_len) {
upper = Some(
self.map_axis_mut(
axis,
|mut x| x.sorted_get_mut(I::upper_index(q, axis_len))
)
);
};
self.map_axis_mut(axis, |mut x| x.sorted_get_mut(percentile_index))
I::interpolate(lower, upper, q, axis_len)
}

/// Return variance along `axis`.
Expand Down
4 changes: 2 additions & 2 deletions src/numeric/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
pub use self::impl_numeric::*;

pub mod impl_numeric;
pub use self::impl_numeric::InterpolationStrategy;
mod impl_numeric;
10 changes: 5 additions & 5 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ extern crate itertools;

use ndarray::{Slice, SliceInfo, SliceOrIndex};
use ndarray::prelude::*;
use ndarray::numeric::InterpolationStrategy;
use ndarray::numeric::{Lower};
use ndarray::{
rcarr2,
arr3,
Expand Down Expand Up @@ -1759,7 +1759,7 @@ fn test_percentile_axis_mut_with_odd_axis_length() {
[3, 5, 6, 12]
]
);
let p = a.percentile_axis_mut(Axis(0), 0.5, InterpolationStrategy::Lower);
let p = a.percentile_axis_mut::<Lower>(Axis(0), 0.5);
assert!(p == a.subview(Axis(0), 1));
}

Expand All @@ -1773,20 +1773,20 @@ fn test_percentile_axis_mut_with_even_axis_length() {
[4, 6, 7, 13]
]
);
let q = b.percentile_axis_mut(Axis(0), 0.5, InterpolationStrategy::Lower);
let q = b.percentile_axis_mut::<Lower>(Axis(0), 0.5);
assert!(q == b.subview(Axis(0), 1));
}

#[test]
fn test_percentile_axis_mut_to_get_minimum() {
let mut b = arr2(&[[1, 3, 22, 10]]);
let q = b.percentile_axis_mut(Axis(1), 0., InterpolationStrategy::Lower);
let q = b.percentile_axis_mut::<Lower>(Axis(1), 0.);
assert!(q == arr1(&[1]));
}

#[test]
fn test_percentile_axis_mut_to_get_maximum() {
let mut b = arr1(&[1, 3, 22, 10]);
let q = b.percentile_axis_mut(Axis(0), 1., InterpolationStrategy::Lower);
let q = b.percentile_axis_mut::<Lower>(Axis(0), 1.);
assert!(q == arr0(22));
}