Skip to content

Commit 15aa626

Browse files
Merge pull request #2 from LukeMathWalker/add-more-interpolation-strategies
Add more interpolation strategies
2 parents 00ddf37 + ff51479 commit 15aa626

File tree

1 file changed

+71
-9
lines changed

1 file changed

+71
-9
lines changed

src/numeric/impl_numeric.rs

Lines changed: 71 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
88

9-
use std::ops::Add;
10-
use libnum::{self, Zero, Float};
9+
use std::ops::{Add, Sub, Div, Mul};
10+
use libnum::{self, Zero, Float, FromPrimitive};
1111
use itertools::free::enumerate;
1212

1313
use imp_prelude::*;
@@ -34,14 +34,25 @@ pub trait Interpolate<T> {
3434
fn upper_index(q: f64, len: usize) -> usize {
3535
Self::float_percentile_index(q, len).ceil() as usize
3636
}
37+
38+
fn float_percentile_index_fraction(q: f64, len: usize) -> f64 {
39+
Self::float_percentile_index(q, len) - (Self::lower_index(q, len) as f64)
40+
}
41+
3742
fn needs_lower(q: f64, len: usize) -> bool;
3843
fn needs_upper(q: f64, len: usize) -> bool;
39-
fn interpolate(lower: Option<T>, upper: Option<T>, q: f64, len: usize) -> T;
44+
fn interpolate<D>(lower: Option<Array<T, D>>,
45+
upper: Option<Array<T, D>>,
46+
q: f64,
47+
len: usize) -> Array<T, D>
48+
where D: Dimension;
4049
}
4150

4251
pub struct Upper;
4352
pub struct Lower;
4453
pub struct Nearest;
54+
pub struct Midpoint;
55+
pub struct Linear;
4556

4657
impl<T> Interpolate<T> for Upper {
4758
fn needs_lower(_q: f64, _len: usize) -> bool {
@@ -50,7 +61,10 @@ impl<T> Interpolate<T> for Upper {
5061
fn needs_upper(_q: f64, _len: usize) -> bool {
5162
true
5263
}
53-
fn interpolate(_lower: Option<T>, upper: Option<T>, _q: f64, _len: usize) -> T {
64+
fn interpolate<D>(_lower: Option<Array<T, D>>,
65+
upper: Option<Array<T, D>>,
66+
_q: f64,
67+
_len: usize) -> Array<T, D> {
5468
upper.unwrap()
5569
}
5670
}
@@ -62,7 +76,10 @@ impl<T> Interpolate<T> for Lower {
6276
fn needs_upper(_q: f64, _len: usize) -> bool {
6377
false
6478
}
65-
fn interpolate(lower: Option<T>, _upper: Option<T>, _q: f64, _len: usize) -> T {
79+
fn interpolate<D>(lower: Option<Array<T, D>>,
80+
_upper: Option<Array<T, D>>,
81+
_q: f64,
82+
_len: usize) -> Array<T, D> {
6683
lower.unwrap()
6784
}
6885
}
@@ -75,7 +92,10 @@ impl<T> Interpolate<T> for Nearest {
7592
fn needs_upper(q: f64, len: usize) -> bool {
7693
!<Self as Interpolate<T>>::needs_lower(q, len)
7794
}
78-
fn interpolate(lower: Option<T>, upper: Option<T>, q: f64, len: usize) -> T {
95+
fn interpolate<D>(lower: Option<Array<T, D>>,
96+
upper: Option<Array<T, D>>,
97+
q: f64,
98+
len: usize) -> Array<T, D> {
7999
if <Self as Interpolate<T>>::needs_lower(q, len) {
80100
lower.unwrap()
81101
} else {
@@ -84,6 +104,48 @@ impl<T> Interpolate<T> for Nearest {
84104
}
85105
}
86106

107+
impl<T> Interpolate<T> for Midpoint
108+
where T: Add<T, Output = T> + Div<T, Output = T> + Clone + FromPrimitive
109+
{
110+
fn needs_lower(_q: f64, _len: usize) -> bool {
111+
true
112+
}
113+
fn needs_upper(_q: f64, _len: usize) -> bool {
114+
true
115+
}
116+
fn interpolate<D>(lower: Option<Array<T, D>>,
117+
upper: Option<Array<T, D>>,
118+
_q: f64, _len: usize) -> Array<T, D>
119+
where D: Dimension
120+
{
121+
let denom = T::from_u8(2).unwrap();
122+
(lower.unwrap() + upper.unwrap()).mapv_into(|x| x / denom.clone())
123+
}
124+
}
125+
126+
impl<T> Interpolate<T> for Linear
127+
where T: Add<T, Output = T> + Sub<T, Output = T> + Mul<T, Output = T> + Clone + FromPrimitive
128+
{
129+
fn needs_lower(_q: f64, _len: usize) -> bool {
130+
true
131+
}
132+
fn needs_upper(_q: f64, _len: usize) -> bool {
133+
true
134+
}
135+
fn interpolate<D>(lower: Option<Array<T, D>>,
136+
upper: Option<Array<T, D>>,
137+
q: f64, len: usize) -> Array<T, D>
138+
where D: Dimension
139+
{
140+
let fraction = T::from_f64(
141+
<Self as Interpolate<T>>::float_percentile_index_fraction(q, len)
142+
).unwrap();
143+
let a = lower.unwrap().mapv_into(|x| x * fraction.clone());
144+
let b = upper.unwrap().mapv_into(|x| x * (T::from_u8(1).unwrap() - fraction.clone()));
145+
a + b
146+
}
147+
}
148+
87149
/// Numerical methods for arrays.
88150
impl<A, S, D> ArrayBase<S, D>
89151
where S: Data<Elem=A>,
@@ -187,8 +249,8 @@ impl<A, S, D> ArrayBase<S, D>
187249
/// as the element that would be indexed as `(N-1)q` if the lane were to be sorted
188250
/// in increasing order.
189251
/// If `(N-1)q` is not an integer the desired percentile lies between
190-
/// two data points: we return the lower, nearest or higher datapoint depending
191-
/// on `interpolation_strategy`.
252+
/// two data points: we return the lower, nearest, higher or interpolated
253+
/// value depending on the type `Interpolate` bound `I`.
192254
///
193255
/// Some examples:
194256
/// - `q=0.` returns the minimum along each 1-dimensional lane;
@@ -214,7 +276,7 @@ impl<A, S, D> ArrayBase<S, D>
214276
where D: RemoveAxis,
215277
A: Ord + Clone + Zero,
216278
S: DataMut,
217-
I: Interpolate<Array<A, D::Smaller>>,
279+
I: Interpolate<A>,
218280
{
219281
assert!((0. <= q) && (q <= 1.));
220282
let mut lower = None;

0 commit comments

Comments
 (0)