@@ -99,12 +99,12 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
9999 where
100100 I : IntoIterator ,
101101 I :: Item : SampleBorrow < X > ,
102- X : for < ' a > :: core :: ops :: AddAssign < & ' a X > + Clone + Default ,
102+ X : Weight ,
103103 {
104104 let mut iter = weights. into_iter ( ) ;
105105 let mut total_weight: X = iter. next ( ) . ok_or ( WeightedError :: NoItem ) ?. borrow ( ) . clone ( ) ;
106106
107- let zero = < X as Default > :: default ( ) ;
107+ let zero = X :: ZERO ;
108108 if !( total_weight >= zero) {
109109 return Err ( WeightedError :: InvalidWeight ) ;
110110 }
@@ -118,8 +118,7 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
118118 }
119119 weights. push ( total_weight. clone ( ) ) ;
120120
121- total_weight += w. borrow ( ) ;
122- if total_weight < * w. borrow ( ) {
121+ if let Err ( ( ) ) = total_weight. checked_add_assign ( w. borrow ( ) ) {
123122 return Err ( WeightedError :: Overflow ) ;
124123 }
125124 }
@@ -240,6 +239,60 @@ where X: SampleUniform + PartialOrd
240239 }
241240}
242241
242+ /// Bounds on a weight
243+ ///
244+ /// See usage in [`WeightedIndex`].
245+ pub trait Weight : Clone {
246+ /// Representation of 0
247+ const ZERO : Self ;
248+
249+ /// Checked addition
250+ ///
251+ /// - `Result::Ok`: On success, `v` is added to `self`
252+ /// - `Result::Err`: Returns an error when `Self` cannot represent the
253+ /// result of `self + v` (i.e. overflow). The value of `self` should be
254+ /// discarded.
255+ fn checked_add_assign ( & mut self , v : & Self ) -> Result < ( ) , ( ) > ;
256+ }
257+
258+ macro_rules! impl_weight_int {
259+ ( $t: ty) => {
260+ impl Weight for $t {
261+ const ZERO : Self = 0 ;
262+ fn checked_add_assign( & mut self , v: & Self ) -> Result <( ) , ( ) > {
263+ match self . checked_add( * v) {
264+ Some ( sum) => {
265+ * self = sum;
266+ Ok ( ( ) )
267+ }
268+ None => Err ( ( ) ) ,
269+ }
270+ }
271+ }
272+ } ;
273+ ( $t: ty, $( $tt: ty) ,* ) => {
274+ impl_weight_int!( $t) ;
275+ impl_weight_int!( $( $tt) ,* ) ;
276+ }
277+ }
278+ impl_weight_int ! ( i8 , i16 , i32 , i64 , i128 , isize ) ;
279+ impl_weight_int ! ( u8 , u16 , u32 , u64 , u128 , usize ) ;
280+
281+ macro_rules! impl_weight_float {
282+ ( $t: ty) => {
283+ impl Weight for $t {
284+ const ZERO : Self = 0.0 ;
285+ fn checked_add_assign( & mut self , v: & Self ) -> Result <( ) , ( ) > {
286+ // Floats have an explicit representation for overflow
287+ * self += * v;
288+ Ok ( ( ) )
289+ }
290+ }
291+ }
292+ }
293+ impl_weight_float ! ( f32 ) ;
294+ impl_weight_float ! ( f64 ) ;
295+
243296#[ cfg( test) ]
244297mod test {
245298 use super :: * ;
@@ -392,12 +445,11 @@ mod test {
392445
393446 #[ test]
394447 fn value_stability ( ) {
395- fn test_samples < X : SampleUniform + PartialOrd , I > (
448+ fn test_samples < X : Weight + SampleUniform + PartialOrd , I > (
396449 weights : I , buf : & mut [ usize ] , expected : & [ usize ] ,
397450 ) where
398451 I : IntoIterator ,
399452 I :: Item : SampleBorrow < X > ,
400- X : for < ' a > :: core:: ops:: AddAssign < & ' a X > + Clone + Default ,
401453 {
402454 assert_eq ! ( buf. len( ) , expected. len( ) ) ;
403455 let distr = WeightedIndex :: new ( weights) . unwrap ( ) ;
0 commit comments