Skip to content

Commit 6f16dd8

Browse files
munckymagik authored and LukeMathWalker committed
Add deviation functions (#41)
* Port deviation functions from StatsBase.jl
* SQUASH return type doc fixes
* SQUASH try parenthesis to highlight the square root
* SQUASH fix copy and paste error in docs
* SQUASH add link from package summary to DeviationExt
1 parent 6f898f6 commit 6f16dd8

File tree

5 files changed

+635
-2
lines changed

5 files changed

+635
-2
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ quickcheck = { version = "0.8.1", default-features = false }
3030
ndarray-rand = "0.9"
3131
approx = "0.3"
3232
quickcheck_macros = "0.8"
33+
num-bigint = "0.2.2"
3334

3435
[[bench]]
3536
name = "sort"

src/deviation.rs

+376
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,376 @@
1+
use ndarray::{ArrayBase, Data, Dimension, Zip};
2+
use num_traits::{Signed, ToPrimitive};
3+
use std::convert::Into;
4+
use std::ops::AddAssign;
5+
6+
use crate::errors::{MultiInputError, ShapeMismatch};
7+
8+
/// An extension trait for `ArrayBase` providing functions
9+
/// to compute different deviation measures.
10+
pub trait DeviationExt<A, S, D>
11+
where
12+
S: Data<Elem = A>,
13+
D: Dimension,
14+
{
15+
/// Counts the number of indices at which the elements of the arrays `self`
16+
/// and `other` are equal.
17+
///
18+
/// The following **errors** may be returned:
19+
///
20+
/// * `MultiInputError::EmptyInput` if `self` is empty
21+
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
22+
fn count_eq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError>
23+
where
24+
A: PartialEq;
25+
26+
/// Counts the number of indices at which the elements of the arrays `self`
27+
/// and `other` are not equal.
28+
///
29+
/// The following **errors** may be returned:
30+
///
31+
/// * `MultiInputError::EmptyInput` if `self` is empty
32+
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
33+
fn count_neq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError>
34+
where
35+
A: PartialEq;
36+
37+
/// Computes the [squared L2 distance] between `self` and `other`.
38+
///
39+
/// ```text
40+
/// n
41+
/// ∑ |aᵢ - bᵢ|²
42+
/// i=1
43+
/// ```
44+
///
45+
/// where `self` is `a` and `other` is `b`.
46+
///
47+
/// The following **errors** may be returned:
48+
///
49+
/// * `MultiInputError::EmptyInput` if `self` is empty
50+
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
51+
///
52+
/// [squared L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance
53+
fn sq_l2_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
54+
where
55+
A: AddAssign + Clone + Signed;
56+
57+
/// Computes the [L2 distance] between `self` and `other`.
58+
///
59+
/// ```text
60+
/// n
61+
/// √ ( ∑ |aᵢ - bᵢ|² )
62+
/// i=1
63+
/// ```
64+
///
65+
/// where `self` is `a` and `other` is `b`.
66+
///
67+
/// The following **errors** may be returned:
68+
///
69+
/// * `MultiInputError::EmptyInput` if `self` is empty
70+
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
71+
///
72+
/// **Panics** if the type cast from `A` to `f64` fails.
73+
///
74+
/// [L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance
75+
fn l2_dist(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
76+
where
77+
A: AddAssign + Clone + Signed + ToPrimitive;
78+
79+
/// Computes the [L1 distance] between `self` and `other`.
80+
///
81+
/// ```text
82+
/// n
83+
/// ∑ |aᵢ - bᵢ|
84+
/// i=1
85+
/// ```
86+
///
87+
/// where `self` is `a` and `other` is `b`.
88+
///
89+
/// The following **errors** may be returned:
90+
///
91+
/// * `MultiInputError::EmptyInput` if `self` is empty
92+
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
93+
///
94+
/// [L1 distance]: https://en.wikipedia.org/wiki/Taxicab_geometry
95+
fn l1_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
96+
where
97+
A: AddAssign + Clone + Signed;
98+
99+
/// Computes the [L∞ distance] between `self` and `other`.
100+
///
101+
/// ```text
102+
/// max(|aᵢ - bᵢ|)
103+
/// ᵢ
104+
/// ```
105+
///
106+
/// where `self` is `a` and `other` is `b`.
107+
///
108+
/// The following **errors** may be returned:
109+
///
110+
/// * `MultiInputError::EmptyInput` if `self` is empty
111+
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
112+
///
113+
/// [L∞ distance]: https://en.wikipedia.org/wiki/Chebyshev_distance
114+
fn linf_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
115+
where
116+
A: Clone + PartialOrd + Signed;
117+
118+
/// Computes the [mean absolute error] between `self` and `other`.
119+
///
120+
/// ```text
121+
/// n
122+
/// 1/n * ∑ |aᵢ - bᵢ|
123+
/// i=1
124+
/// ```
125+
///
126+
/// where `self` is `a` and `other` is `b`.
127+
///
128+
/// The following **errors** may be returned:
129+
///
130+
/// * `MultiInputError::EmptyInput` if `self` is empty
131+
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
132+
///
133+
/// **Panics** if the type cast from `A` to `f64` fails.
134+
///
135+
/// [mean absolute error]: https://en.wikipedia.org/wiki/Mean_absolute_error
136+
fn mean_abs_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
137+
where
138+
A: AddAssign + Clone + Signed + ToPrimitive;
139+
140+
/// Computes the [mean squared error] between `self` and `other`.
141+
///
142+
/// ```text
143+
/// n
144+
/// 1/n * ∑ |aᵢ - bᵢ|²
145+
/// i=1
146+
/// ```
147+
///
148+
/// where `self` is `a` and `other` is `b`.
149+
///
150+
/// The following **errors** may be returned:
151+
///
152+
/// * `MultiInputError::EmptyInput` if `self` is empty
153+
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
154+
///
155+
/// **Panics** if the type cast from `A` to `f64` fails.
156+
///
157+
/// [mean squared error]: https://en.wikipedia.org/wiki/Mean_squared_error
158+
fn mean_sq_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
159+
where
160+
A: AddAssign + Clone + Signed + ToPrimitive;
161+
162+
/// Computes the unnormalized [root-mean-square error] between `self` and `other`.
163+
///
164+
/// ```text
165+
/// √ mse(a, b)
166+
/// ```
167+
///
168+
/// where `self` is `a`, `other` is `b` and `mse` is the mean-squared-error.
169+
///
170+
/// The following **errors** may be returned:
171+
///
172+
/// * `MultiInputError::EmptyInput` if `self` is empty
173+
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
174+
///
175+
/// **Panics** if the type cast from `A` to `f64` fails.
176+
///
177+
/// [root-mean-square error]: https://en.wikipedia.org/wiki/Root-mean-square_deviation
178+
fn root_mean_sq_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
179+
where
180+
A: AddAssign + Clone + Signed + ToPrimitive;
181+
182+
/// Computes the [peak signal-to-noise ratio] between `self` and `other`.
183+
///
184+
/// ```text
185+
/// 10 * log10(maxv^2 / mse(a, b))
186+
/// ```
187+
///
188+
/// where `self` is `a`, `other` is `b`, `mse` is the mean-squared-error
189+
/// and `maxv` is the maximum possible value either array can take.
190+
///
191+
/// The following **errors** may be returned:
192+
///
193+
/// * `MultiInputError::EmptyInput` if `self` is empty
194+
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
195+
///
196+
/// **Panics** if the type cast from `A` to `f64` fails.
197+
///
198+
/// [peak signal-to-noise ratio]: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
199+
fn peak_signal_to_noise_ratio(
200+
&self,
201+
other: &ArrayBase<S, D>,
202+
maxv: A,
203+
) -> Result<f64, MultiInputError>
204+
where
205+
A: AddAssign + Clone + Signed + ToPrimitive;
206+
207+
private_decl! {}
208+
}
209+
210+
// Early-returns `Err(MultiInputError::EmptyInput)` from the enclosing function
// when `$arr.len()` is zero. Must be invoked inside a function whose return
// type is a `Result` with `MultiInputError` as its error type.
macro_rules! return_err_if_empty {
    ($arr:expr) => {
        if $arr.len() == 0 {
            return Err(MultiInputError::EmptyInput);
        }
    };
}
217+
// Early-returns a shape-mismatch error from the enclosing function when the
// two arguments report different `.shape()`s. The `ShapeMismatch` payload
// records both shapes (first argument, then second) and is converted into the
// enclosing function's error type with `.into()`.
macro_rules! return_err_unless_same_shape {
    ($arr_a:expr, $arr_b:expr) => {
        if $arr_a.shape() != $arr_b.shape() {
            return Err(ShapeMismatch {
                first_shape: $arr_a.shape().to_vec(),
                second_shape: $arr_b.shape().to_vec(),
            }
            .into());
        }
    };
}
228+
229+
impl<A, S, D> DeviationExt<A, S, D> for ArrayBase<S, D>
230+
where
231+
S: Data<Elem = A>,
232+
D: Dimension,
233+
{
234+
fn count_eq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError>
235+
where
236+
A: PartialEq,
237+
{
238+
return_err_if_empty!(self);
239+
return_err_unless_same_shape!(self, other);
240+
241+
let mut count = 0;
242+
243+
Zip::from(self).and(other).apply(|a, b| {
244+
if a == b {
245+
count += 1;
246+
}
247+
});
248+
249+
Ok(count)
250+
}
251+
252+
fn count_neq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError>
253+
where
254+
A: PartialEq,
255+
{
256+
self.count_eq(other).map(|n_eq| self.len() - n_eq)
257+
}
258+
259+
fn sq_l2_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
260+
where
261+
A: AddAssign + Clone + Signed,
262+
{
263+
return_err_if_empty!(self);
264+
return_err_unless_same_shape!(self, other);
265+
266+
let mut result = A::zero();
267+
268+
Zip::from(self).and(other).apply(|self_i, other_i| {
269+
let (a, b) = (self_i.clone(), other_i.clone());
270+
let abs_diff = (a - b).abs();
271+
result += abs_diff.clone() * abs_diff;
272+
});
273+
274+
Ok(result)
275+
}
276+
277+
fn l2_dist(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
278+
where
279+
A: AddAssign + Clone + Signed + ToPrimitive,
280+
{
281+
let sq_l2_dist = self
282+
.sq_l2_dist(other)?
283+
.to_f64()
284+
.expect("failed cast from type A to f64");
285+
286+
Ok(sq_l2_dist.sqrt())
287+
}
288+
289+
fn l1_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
290+
where
291+
A: AddAssign + Clone + Signed,
292+
{
293+
return_err_if_empty!(self);
294+
return_err_unless_same_shape!(self, other);
295+
296+
let mut result = A::zero();
297+
298+
Zip::from(self).and(other).apply(|self_i, other_i| {
299+
let (a, b) = (self_i.clone(), other_i.clone());
300+
result += (a - b).abs();
301+
});
302+
303+
Ok(result)
304+
}
305+
306+
fn linf_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
307+
where
308+
A: Clone + PartialOrd + Signed,
309+
{
310+
return_err_if_empty!(self);
311+
return_err_unless_same_shape!(self, other);
312+
313+
let mut max = A::zero();
314+
315+
Zip::from(self).and(other).apply(|self_i, other_i| {
316+
let (a, b) = (self_i.clone(), other_i.clone());
317+
let diff = (a - b).abs();
318+
if diff > max {
319+
max = diff;
320+
}
321+
});
322+
323+
Ok(max)
324+
}
325+
326+
fn mean_abs_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
327+
where
328+
A: AddAssign + Clone + Signed + ToPrimitive,
329+
{
330+
let l1_dist = self
331+
.l1_dist(other)?
332+
.to_f64()
333+
.expect("failed cast from type A to f64");
334+
let n = self.len() as f64;
335+
336+
Ok(l1_dist / n)
337+
}
338+
339+
fn mean_sq_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
340+
where
341+
A: AddAssign + Clone + Signed + ToPrimitive,
342+
{
343+
let sq_l2_dist = self
344+
.sq_l2_dist(other)?
345+
.to_f64()
346+
.expect("failed cast from type A to f64");
347+
let n = self.len() as f64;
348+
349+
Ok(sq_l2_dist / n)
350+
}
351+
352+
fn root_mean_sq_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
353+
where
354+
A: AddAssign + Clone + Signed + ToPrimitive,
355+
{
356+
let msd = self.mean_sq_err(other)?;
357+
Ok(msd.sqrt())
358+
}
359+
360+
fn peak_signal_to_noise_ratio(
361+
&self,
362+
other: &ArrayBase<S, D>,
363+
maxv: A,
364+
) -> Result<f64, MultiInputError>
365+
where
366+
A: AddAssign + Clone + Signed + ToPrimitive,
367+
{
368+
let maxv_f = maxv.to_f64().expect("failed cast from type A to f64");
369+
let msd = self.mean_sq_err(&other)?;
370+
let psnr = 10. * f64::log10(maxv_f * maxv_f / msd);
371+
372+
Ok(psnr)
373+
}
374+
375+
private_impl! {}
376+
}

0 commit comments

Comments
 (0)