Skip to content

Commit 984cf22

Browse files
committed
Implement approx traits for ArrayBase
1 parent 03552e2 commit 984cf22

File tree

4 files changed

+147
-0
lines changed

4 files changed

+147
-0
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ bench = false
2828
test = true
2929

3030
[dependencies]
31+
approx = "0.3"
3132
num-integer = "0.1.39"
3233
num-traits = "0.2"
3334
num-complex = "0.2"

src/arraytraits.rs

+86
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,92 @@ impl<S, D> Eq for ArrayBase<S, D>
113113
S::Elem: Eq,
114114
{ }
115115

116+
impl<A, S, D> approx::AbsDiffEq for ArrayBase<S, D>
117+
where
118+
A: approx::AbsDiffEq,
119+
A::Epsilon: Clone,
120+
S: Data<Elem = A>,
121+
D: Dimension,
122+
{
123+
type Epsilon = A::Epsilon;
124+
125+
fn default_epsilon() -> A::Epsilon {
126+
A::default_epsilon()
127+
}
128+
129+
fn abs_diff_eq(&self, other: &ArrayBase<S, D>, epsilon: A::Epsilon) -> bool {
130+
if self.shape() != other.shape() {
131+
return false;
132+
}
133+
Zip::from(self)
134+
.and(other)
135+
.fold_while(true, |_, a, b| {
136+
if A::abs_diff_ne(a, b, epsilon.clone()) {
137+
FoldWhile::Done(false)
138+
} else {
139+
FoldWhile::Continue(true)
140+
}
141+
})
142+
.into_inner()
143+
}
144+
}
145+
146+
impl<A, S, D> approx::RelativeEq for ArrayBase<S, D>
147+
where
148+
A: approx::RelativeEq,
149+
A::Epsilon: Clone,
150+
S: Data<Elem = A>,
151+
D: Dimension,
152+
{
153+
fn default_max_relative() -> A::Epsilon {
154+
A::default_max_relative()
155+
}
156+
157+
fn relative_eq(&self, other: &ArrayBase<S, D>, epsilon: A::Epsilon, max_relative: A::Epsilon) -> bool {
158+
if self.shape() != other.shape() {
159+
return false;
160+
}
161+
Zip::from(self)
162+
.and(other)
163+
.fold_while(true, |_, a, b| {
164+
if A::relative_ne(a, b, epsilon.clone(), max_relative.clone()) {
165+
FoldWhile::Done(false)
166+
} else {
167+
FoldWhile::Continue(true)
168+
}
169+
})
170+
.into_inner()
171+
}
172+
}
173+
174+
impl<A, S, D> approx::UlpsEq for ArrayBase<S, D>
175+
where
176+
A: approx::UlpsEq,
177+
A::Epsilon: Clone,
178+
S: Data<Elem = A>,
179+
D: Dimension,
180+
{
181+
fn default_max_ulps() -> u32 {
182+
A::default_max_ulps()
183+
}
184+
185+
fn ulps_eq(&self, other: &ArrayBase<S, D>, epsilon: A::Epsilon, max_ulps: u32) -> bool {
186+
if self.shape() != other.shape() {
187+
return false;
188+
}
189+
Zip::from(self)
190+
.and(other)
191+
.fold_while(true, |_, a, b| {
192+
if A::ulps_ne(a, b, epsilon.clone(), max_ulps) {
193+
FoldWhile::Done(false)
194+
} else {
195+
FoldWhile::Continue(true)
196+
}
197+
})
198+
.into_inner()
199+
}
200+
}
201+
116202
impl<A, S> FromIterator<A> for ArrayBase<S, Ix1>
117203
where S: DataOwned<Elem=A>
118204
{

src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ extern crate blas_src;
9797

9898
extern crate matrixmultiply;
9999

100+
extern crate approx;
100101
extern crate itertools;
101102
extern crate num_traits;
102103
extern crate num_complex;

tests/array.rs

+59
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#![allow(non_snake_case)]
22

3+
extern crate approx;
34
extern crate ndarray;
45
extern crate defmac;
56
extern crate itertools;
@@ -12,6 +13,10 @@ use ndarray::{
1213
multislice,
1314
};
1415
use ndarray::indices;
16+
use approx::{
17+
assert_abs_diff_eq, assert_abs_diff_ne, assert_relative_eq, assert_relative_ne, assert_ulps_eq,
18+
assert_ulps_ne,
19+
};
1520
use defmac::defmac;
1621
use itertools::{enumerate, zip};
1722

@@ -1163,6 +1168,60 @@ fn equality()
11631168
assert!(a != c);
11641169
}
11651170

1171+
#[test]
1172+
fn abs_diff_eq()
1173+
{
1174+
let a: Array2<f32> = array![[0., 2.], [-0.000010001, 100000000.]];
1175+
let mut b: Array2<f32> = array![[0., 1.], [-0.000010002, 100000001.]];
1176+
assert_abs_diff_ne!(a, b);
1177+
b[(0, 1)] = 2.;
1178+
assert_abs_diff_eq!(a, b);
1179+
1180+
// Check epsilon.
1181+
assert_abs_diff_eq!(array![0.0f32], array![1e-40f32], epsilon = 1e-40f32);
1182+
assert_abs_diff_ne!(array![0.0f32], array![1e-40f32], epsilon = 1e-41f32);
1183+
1184+
// Make sure we can compare different shapes without failure.
1185+
let c = array![[1., 2.]];
1186+
assert_abs_diff_ne!(a, c);
1187+
}
1188+
1189+
#[test]
1190+
fn relative_eq()
1191+
{
1192+
let a: Array2<f32> = array![[1., 2.], [-0.000010001, 100000000.]];
1193+
let mut b: Array2<f32> = array![[1., 1.], [-0.000010002, 100000001.]];
1194+
assert_relative_ne!(a, b);
1195+
b[(0, 1)] = 2.;
1196+
assert_relative_eq!(a, b);
1197+
1198+
// Check epsilon.
1199+
assert_relative_eq!(array![0.0f32], array![1e-40f32], epsilon = 1e-40f32);
1200+
assert_relative_ne!(array![0.0f32], array![1e-40f32], epsilon = 1e-41f32);
1201+
1202+
// Make sure we can compare different shapes without failure.
1203+
let c = array![[1., 2.]];
1204+
assert_relative_ne!(a, c);
1205+
}
1206+
1207+
#[test]
1208+
fn ulps_eq()
1209+
{
1210+
let a: Array2<f32> = array![[1., 2.], [-0.000010001, 100000000.]];
1211+
let mut b: Array2<f32> = array![[1., 1.], [-0.000010002, 100000001.]];
1212+
assert_ulps_ne!(a, b);
1213+
b[(0, 1)] = 2.;
1214+
assert_ulps_eq!(a, b);
1215+
1216+
// Check epsilon.
1217+
assert_ulps_eq!(array![0.0f32], array![1e-40f32], epsilon = 1e-40f32);
1218+
assert_ulps_ne!(array![0.0f32], array![1e-40f32], epsilon = 1e-41f32);
1219+
1220+
// Make sure we can compare different shapes without failure.
1221+
let c = array![[1., 2.]];
1222+
assert_ulps_ne!(a, c);
1223+
}
1224+
11661225
#[test]
11671226
fn map1()
11681227
{

0 commit comments

Comments
 (0)