Skip to content

Commit 4deeff2

Browse files
committed
Introduce trait Weight
1 parent e5e5b1f commit 4deeff2

File tree

3 files changed

+64
-28
lines changed

3 files changed

+64
-28
lines changed

src/distributions/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ pub use self::slice::Slice;
126126
#[doc(inline)]
127127
pub use self::uniform::Uniform;
128128
#[cfg(feature = "alloc")]
129-
pub use self::weighted_index::{WeightedError, WeightedIndex};
129+
pub use self::weighted_index::{Weight, WeightedError, WeightedIndex};
130130

131131
#[allow(unused)]
132132
use crate::Rng;

src/distributions/weighted_index.rs

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,12 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
9999
where
100100
I: IntoIterator,
101101
I::Item: SampleBorrow<X>,
102-
X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
102+
X: Weight,
103103
{
104104
let mut iter = weights.into_iter();
105105
let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone();
106106

107-
let zero = <X as Default>::default();
107+
let zero = X::ZERO;
108108
if !(total_weight >= zero) {
109109
return Err(WeightedError::InvalidWeight);
110110
}
@@ -118,8 +118,7 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
118118
}
119119
weights.push(total_weight.clone());
120120

121-
total_weight += w.borrow();
122-
if total_weight < *w.borrow() {
121+
if let Err(()) = total_weight.checked_add_assign(w.borrow()) {
123122
return Err(WeightedError::Overflow);
124123
}
125124
}
@@ -240,6 +239,60 @@ where X: SampleUniform + PartialOrd
240239
}
241240
}
242241

242+
/// Bounds on a weight
243+
///
244+
/// See usage in [`WeightedIndex`].
245+
pub trait Weight: Clone {
246+
/// Representation of 0
247+
const ZERO: Self;
248+
249+
/// Checked addition
250+
///
251+
/// - `Result::Ok`: On success, `v` is added to `self`
252+
/// - `Result::Err`: Returns an error when `Self` cannot represent the
253+
/// result of `self + v` (i.e. overflow). The value of `self` should be
254+
/// discarded.
255+
fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()>;
256+
}
257+
258+
macro_rules! impl_weight_int {
259+
($t:ty) => {
260+
impl Weight for $t {
261+
const ZERO: Self = 0;
262+
fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> {
263+
match self.checked_add(*v) {
264+
Some(sum) => {
265+
*self = sum;
266+
Ok(())
267+
}
268+
None => Err(()),
269+
}
270+
}
271+
}
272+
};
273+
($t:ty, $($tt:ty),*) => {
274+
impl_weight_int!($t);
275+
impl_weight_int!($($tt),*);
276+
}
277+
}
278+
impl_weight_int!(i8, i16, i32, i64, i128, isize);
279+
impl_weight_int!(u8, u16, u32, u64, u128, usize);
280+
281+
macro_rules! impl_weight_float {
282+
($t:ty) => {
283+
impl Weight for $t {
284+
const ZERO: Self = 0.0;
285+
fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> {
286+
// Floats have an explicit representation for overflow
287+
*self += *v;
288+
Ok(())
289+
}
290+
}
291+
}
292+
}
293+
impl_weight_float!(f32);
294+
impl_weight_float!(f64);
295+
243296
#[cfg(test)]
244297
mod test {
245298
use super::*;
@@ -392,12 +445,11 @@ mod test {
392445

393446
#[test]
394447
fn value_stability() {
395-
fn test_samples<X: SampleUniform + PartialOrd, I>(
448+
fn test_samples<X: Weight + SampleUniform + PartialOrd, I>(
396449
weights: I, buf: &mut [usize], expected: &[usize],
397450
) where
398451
I: IntoIterator,
399452
I::Item: SampleBorrow<X>,
400-
X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
401453
{
402454
assert_eq!(buf.len(), expected.len());
403455
let distr = WeightedIndex::new(weights).unwrap();

src/seq/mod.rs

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ use alloc::vec::Vec;
4040
#[cfg(feature = "alloc")]
4141
use crate::distributions::uniform::{SampleBorrow, SampleUniform};
4242
#[cfg(feature = "alloc")]
43-
use crate::distributions::WeightedError;
43+
use crate::distributions::{Weight, WeightedError};
4444
use crate::Rng;
4545

4646
use self::coin_flipper::CoinFlipper;
@@ -170,11 +170,7 @@ pub trait SliceRandom {
170170
R: Rng + ?Sized,
171171
F: Fn(&Self::Item) -> B,
172172
B: SampleBorrow<X>,
173-
X: SampleUniform
174-
+ for<'a> ::core::ops::AddAssign<&'a X>
175-
+ ::core::cmp::PartialOrd<X>
176-
+ Clone
177-
+ Default;
173+
X: SampleUniform + Weight + ::core::cmp::PartialOrd<X>;
178174

179175
/// Biased sampling for one element (mut)
180176
///
@@ -203,11 +199,7 @@ pub trait SliceRandom {
203199
R: Rng + ?Sized,
204200
F: Fn(&Self::Item) -> B,
205201
B: SampleBorrow<X>,
206-
X: SampleUniform
207-
+ for<'a> ::core::ops::AddAssign<&'a X>
208-
+ ::core::cmp::PartialOrd<X>
209-
+ Clone
210-
+ Default;
202+
X: SampleUniform + Weight + ::core::cmp::PartialOrd<X>;
211203

212204
/// Biased sampling of `amount` distinct elements
213205
///
@@ -585,11 +577,7 @@ impl<T> SliceRandom for [T] {
585577
R: Rng + ?Sized,
586578
F: Fn(&Self::Item) -> B,
587579
B: SampleBorrow<X>,
588-
X: SampleUniform
589-
+ for<'a> ::core::ops::AddAssign<&'a X>
590-
+ ::core::cmp::PartialOrd<X>
591-
+ Clone
592-
+ Default,
580+
X: SampleUniform + Weight + ::core::cmp::PartialOrd<X>,
593581
{
594582
use crate::distributions::{Distribution, WeightedIndex};
595583
let distr = WeightedIndex::new(self.iter().map(weight))?;
@@ -604,11 +592,7 @@ impl<T> SliceRandom for [T] {
604592
R: Rng + ?Sized,
605593
F: Fn(&Self::Item) -> B,
606594
B: SampleBorrow<X>,
607-
X: SampleUniform
608-
+ for<'a> ::core::ops::AddAssign<&'a X>
609-
+ ::core::cmp::PartialOrd<X>
610-
+ Clone
611-
+ Default,
595+
X: SampleUniform + Weight + ::core::cmp::PartialOrd<X>,
612596
{
613597
use crate::distributions::{Distribution, WeightedIndex};
614598
let distr = WeightedIndex::new(self.iter().map(weight))?;

0 commit comments

Comments
 (0)