Skip to content

Commit 00ddf37

Browse files
Merge pull request #1 from LukeMathWalker/interpolation-strategy
Interpolation strategy
2 parents c6c459d + 1b39c61 commit 00ddf37

File tree

4 files changed

+98
-23
lines changed

4 files changed

+98
-23
lines changed

src/impl_1d.rs

+9-4
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,18 @@ impl<A, S> ArrayBase<S, Ix1>
2727
}
2828
}
2929

30-
/// Return the element that would occupy the `i`-th position if the array
31-
/// were sorted in increasing order.
30+
/// Return the element that would occupy the `i`-th position if
31+
/// the array were sorted in increasing order.
3232
///
3333
/// The array is shuffled **in place** to retrieve the desired element:
3434
/// no copy of the array is allocated.
35-
/// No assumptions should be made on the ordering of elements
36-
/// after this computation.
35+
/// After the shuffling, all elements with an index smaller than `i`
36+
/// are smaller than the desired element, while all elements with
37+
/// an index greater or equal than `i` are greater than or equal
38+
/// to the desired element.
39+
///
40+
/// No other assumptions should be made on the ordering of the
41+
/// elements after this computation.
3742
///
3843
/// Complexity ([quickselect](https://en.wikipedia.org/wiki/Quickselect)):
3944
/// - average case: O(`n`);

src/numeric/impl_numeric.rs

+82-12
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,69 @@ use {
1919
Zip,
2020
};
2121

22-
/// Used to choose the interpolation strategy in [`percentile_axis_mut`].
22+
/// Used to provide an interpolation strategy to [`percentile_axis_mut`].
2323
///
2424
/// [`percentile_axis_mut`]: struct.ArrayBase.html#method.percentile_axis_mut
25-
pub enum InterpolationStrategy {
26-
Lower,
27-
Nearest,
28-
Higher,
25+
pub trait Interpolate<T> {
26+
fn float_percentile_index(q: f64, len: usize) -> f64 {
27+
((len - 1) as f64) * q
28+
}
29+
30+
fn lower_index(q: f64, len: usize) -> usize {
31+
Self::float_percentile_index(q, len).floor() as usize
32+
}
33+
34+
fn upper_index(q: f64, len: usize) -> usize {
35+
Self::float_percentile_index(q, len).ceil() as usize
36+
}
37+
fn needs_lower(q: f64, len: usize) -> bool;
38+
fn needs_upper(q: f64, len: usize) -> bool;
39+
fn interpolate(lower: Option<T>, upper: Option<T>, q: f64, len: usize) -> T;
40+
}
41+
42+
pub struct Upper;
43+
pub struct Lower;
44+
pub struct Nearest;
45+
46+
impl<T> Interpolate<T> for Upper {
47+
fn needs_lower(_q: f64, _len: usize) -> bool {
48+
false
49+
}
50+
fn needs_upper(_q: f64, _len: usize) -> bool {
51+
true
52+
}
53+
fn interpolate(_lower: Option<T>, upper: Option<T>, _q: f64, _len: usize) -> T {
54+
upper.unwrap()
55+
}
56+
}
57+
58+
impl<T> Interpolate<T> for Lower {
59+
fn needs_lower(_q: f64, _len: usize) -> bool {
60+
true
61+
}
62+
fn needs_upper(_q: f64, _len: usize) -> bool {
63+
false
64+
}
65+
fn interpolate(lower: Option<T>, _upper: Option<T>, _q: f64, _len: usize) -> T {
66+
lower.unwrap()
67+
}
68+
}
69+
70+
impl<T> Interpolate<T> for Nearest {
71+
fn needs_lower(q: f64, len: usize) -> bool {
72+
let lower = <Self as Interpolate<T>>::lower_index(q, len);
73+
((lower as f64) - <Self as Interpolate<T>>::float_percentile_index(q, len)) <= 0.
74+
}
75+
fn needs_upper(q: f64, len: usize) -> bool {
76+
!<Self as Interpolate<T>>::needs_lower(q, len)
77+
}
78+
fn interpolate(lower: Option<T>, upper: Option<T>, q: f64, len: usize) -> T {
79+
if <Self as Interpolate<T>>::needs_lower(q, len) {
80+
lower.unwrap()
81+
} else {
82+
upper.unwrap()
83+
}
84+
}
2985
}
3086

3187
/// Numerical methods for arrays.
@@ -154,19 +210,33 @@ impl<A, S, D> ArrayBase<S, D>
154210
///
155211
/// **Panics** if `axis` is out of bounds or if `q` is not between
156212
/// `0.` and `1.` (inclusive).
157-
pub fn percentile_axis_mut(&mut self, axis: Axis, q: f64, interpolation_strategy: InterpolationStrategy) -> Array<A, D::Smaller>
213+
pub fn percentile_axis_mut<I>(&mut self, axis: Axis, q: f64) -> Array<A, D::Smaller>
158214
where D: RemoveAxis,
159215
A: Ord + Clone + Zero,
160216
S: DataMut,
217+
I: Interpolate<Array<A, D::Smaller>>,
161218
{
162219
assert!((0. <= q) && (q <= 1.));
163-
let float_percentile_index = ((self.len_of(axis) - 1) as f64) * q;
164-
let percentile_index = match interpolation_strategy {
165-
InterpolationStrategy::Lower => float_percentile_index.floor() as usize,
166-
InterpolationStrategy::Nearest => float_percentile_index.round() as usize,
167-
InterpolationStrategy::Higher => float_percentile_index.ceil() as usize,
220+
let mut lower = None;
221+
let mut upper = None;
222+
let axis_len = self.len_of(axis);
223+
if I::needs_lower(q, axis_len) {
224+
lower = Some(
225+
self.map_axis_mut(
226+
axis,
227+
|mut x| x.sorted_get_mut(I::lower_index(q, axis_len))
228+
)
229+
);
230+
};
231+
if I::needs_upper(q, axis_len) {
232+
upper = Some(
233+
self.map_axis_mut(
234+
axis,
235+
|mut x| x.sorted_get_mut(I::upper_index(q, axis_len))
236+
)
237+
);
168238
};
169-
self.map_axis_mut(axis, |mut x| x.sorted_get_mut(percentile_index))
239+
I::interpolate(lower, upper, q, axis_len)
170240
}
171241

172242
/// Return variance along `axis`.

src/numeric/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1+
pub use self::impl_numeric::*;
12

2-
pub mod impl_numeric;
3-
pub use self::impl_numeric::InterpolationStrategy;
3+
mod impl_numeric;

tests/array.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ extern crate itertools;
88

99
use ndarray::{Slice, SliceInfo, SliceOrIndex};
1010
use ndarray::prelude::*;
11-
use ndarray::numeric::InterpolationStrategy;
11+
use ndarray::numeric::{Lower};
1212
use ndarray::{
1313
rcarr2,
1414
arr3,
@@ -1759,7 +1759,7 @@ fn test_percentile_axis_mut_with_odd_axis_length() {
17591759
[3, 5, 6, 12]
17601760
]
17611761
);
1762-
let p = a.percentile_axis_mut(Axis(0), 0.5, InterpolationStrategy::Lower);
1762+
let p = a.percentile_axis_mut::<Lower>(Axis(0), 0.5);
17631763
assert!(p == a.subview(Axis(0), 1));
17641764
}
17651765

@@ -1773,20 +1773,20 @@ fn test_percentile_axis_mut_with_even_axis_length() {
17731773
[4, 6, 7, 13]
17741774
]
17751775
);
1776-
let q = b.percentile_axis_mut(Axis(0), 0.5, InterpolationStrategy::Lower);
1776+
let q = b.percentile_axis_mut::<Lower>(Axis(0), 0.5);
17771777
assert!(q == b.subview(Axis(0), 1));
17781778
}
17791779

17801780
#[test]
17811781
fn test_percentile_axis_mut_to_get_minimum() {
17821782
let mut b = arr2(&[[1, 3, 22, 10]]);
1783-
let q = b.percentile_axis_mut(Axis(1), 0., InterpolationStrategy::Lower);
1783+
let q = b.percentile_axis_mut::<Lower>(Axis(1), 0.);
17841784
assert!(q == arr1(&[1]));
17851785
}
17861786

17871787
#[test]
17881788
fn test_percentile_axis_mut_to_get_maximum() {
17891789
let mut b = arr1(&[1, 3, 22, 10]);
1790-
let q = b.percentile_axis_mut(Axis(0), 1., InterpolationStrategy::Lower);
1790+
let q = b.percentile_axis_mut::<Lower>(Axis(0), 1.);
17911791
assert!(q == arr0(22));
17921792
}

0 commit comments

Comments
 (0)