@@ -19,13 +19,69 @@ use {
19
19
Zip ,
20
20
} ;
21
21
22
- /// Used to choose the interpolation strategy in [`percentile_axis_mut`].
22
+ /// Used to provide an interpolation strategy to [`percentile_axis_mut`].
23
23
///
24
24
/// [`percentile_axis_mut`]: struct.ArrayBase.html#method.percentile_axis_mut
25
- pub enum InterpolationStrategy {
26
- Lower ,
27
- Nearest ,
28
- Higher ,
25
+ pub trait Interpolate < T > {
26
+ fn float_percentile_index ( q : f64 , len : usize ) -> f64 {
27
+ ( ( len - 1 ) as f64 ) * q
28
+ }
29
+
30
+ fn lower_index ( q : f64 , len : usize ) -> usize {
31
+ Self :: float_percentile_index ( q, len) . floor ( ) as usize
32
+ }
33
+
34
+ fn upper_index ( q : f64 , len : usize ) -> usize {
35
+ Self :: float_percentile_index ( q, len) . ceil ( ) as usize
36
+ }
37
+ fn needs_lower ( q : f64 , len : usize ) -> bool ;
38
+ fn needs_upper ( q : f64 , len : usize ) -> bool ;
39
+ fn interpolate ( lower : Option < T > , upper : Option < T > , q : f64 , len : usize ) -> T ;
40
+ }
41
+
42
+ pub struct Upper ;
43
+ pub struct Lower ;
44
+ pub struct Nearest ;
45
+
46
+ impl < T > Interpolate < T > for Upper {
47
+ fn needs_lower ( _q : f64 , _len : usize ) -> bool {
48
+ false
49
+ }
50
+ fn needs_upper ( _q : f64 , _len : usize ) -> bool {
51
+ true
52
+ }
53
+ fn interpolate ( _lower : Option < T > , upper : Option < T > , _q : f64 , _len : usize ) -> T {
54
+ upper. unwrap ( )
55
+ }
56
+ }
57
+
58
+ impl < T > Interpolate < T > for Lower {
59
+ fn needs_lower ( _q : f64 , _len : usize ) -> bool {
60
+ true
61
+ }
62
+ fn needs_upper ( _q : f64 , _len : usize ) -> bool {
63
+ false
64
+ }
65
+ fn interpolate ( lower : Option < T > , _upper : Option < T > , _q : f64 , _len : usize ) -> T {
66
+ lower. unwrap ( )
67
+ }
68
+ }
69
+
70
+ impl < T > Interpolate < T > for Nearest {
71
+ fn needs_lower ( q : f64 , len : usize ) -> bool {
72
+ let lower = <Self as Interpolate < T > >:: lower_index ( q, len) ;
73
+ ( ( lower as f64 ) - <Self as Interpolate < T > >:: float_percentile_index ( q, len) ) <= 0.
74
+ }
75
+ fn needs_upper ( q : f64 , len : usize ) -> bool {
76
+ !<Self as Interpolate < T > >:: needs_lower ( q, len)
77
+ }
78
+ fn interpolate ( lower : Option < T > , upper : Option < T > , q : f64 , len : usize ) -> T {
79
+ if <Self as Interpolate < T > >:: needs_lower ( q, len) {
80
+ lower. unwrap ( )
81
+ } else {
82
+ upper. unwrap ( )
83
+ }
84
+ }
29
85
}
30
86
31
87
/// Numerical methods for arrays.
@@ -154,19 +210,33 @@ impl<A, S, D> ArrayBase<S, D>
154
210
///
155
211
/// **Panics** if `axis` is out of bounds or if `q` is not between
156
212
/// `0.` and `1.` (inclusive).
157
- pub fn percentile_axis_mut ( & mut self , axis : Axis , q : f64 , interpolation_strategy : InterpolationStrategy ) -> Array < A , D :: Smaller >
213
+ pub fn percentile_axis_mut < I > ( & mut self , axis : Axis , q : f64 ) -> Array < A , D :: Smaller >
158
214
where D : RemoveAxis ,
159
215
A : Ord + Clone + Zero ,
160
216
S : DataMut ,
217
+ I : Interpolate < Array < A , D :: Smaller > > ,
161
218
{
162
219
assert ! ( ( 0. <= q) && ( q <= 1. ) ) ;
163
- let float_percentile_index = ( ( self . len_of ( axis) - 1 ) as f64 ) * q;
164
- let percentile_index = match interpolation_strategy {
165
- InterpolationStrategy :: Lower => float_percentile_index. floor ( ) as usize ,
166
- InterpolationStrategy :: Nearest => float_percentile_index. round ( ) as usize ,
167
- InterpolationStrategy :: Higher => float_percentile_index. ceil ( ) as usize ,
220
+ let mut lower = None ;
221
+ let mut upper = None ;
222
+ let axis_len = self . len_of ( axis) ;
223
+ if I :: needs_lower ( q, axis_len) {
224
+ lower = Some (
225
+ self . map_axis_mut (
226
+ axis,
227
+ |mut x| x. sorted_get_mut ( I :: lower_index ( q, axis_len) )
228
+ )
229
+ ) ;
230
+ } ;
231
+ if I :: needs_upper ( q, axis_len) {
232
+ upper = Some (
233
+ self . map_axis_mut (
234
+ axis,
235
+ |mut x| x. sorted_get_mut ( I :: upper_index ( q, axis_len) )
236
+ )
237
+ ) ;
168
238
} ;
169
- self . map_axis_mut ( axis , | mut x| x . sorted_get_mut ( percentile_index ) )
239
+ I :: interpolate ( lower , upper , q , axis_len )
170
240
}
171
241
172
242
/// Return variance along `axis`.
0 commit comments