diff --git a/rand_distr/CHANGELOG.md b/rand_distr/CHANGELOG.md index 51bde39e86b..93756eb7055 100644 --- a/rand_distr/CHANGELOG.md +++ b/rand_distr/CHANGELOG.md @@ -6,6 +6,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased - The `serde1` feature has been renamed `serde` (#1477) +- Fix panic in Binomial (#1484) +- Move some of the computations in Binomial from `sample` to `new` (#1484) ### Added - Add plots for `rand_distr` distributions to documentation (#1434) diff --git a/rand_distr/src/binomial.rs b/rand_distr/src/binomial.rs index 885d8b21c3f..3ee0f447b42 100644 --- a/rand_distr/src/binomial.rs +++ b/rand_distr/src/binomial.rs @@ -26,10 +26,6 @@ use rand::Rng; /// /// `f(k) = n!/(k! (n-k)!) p^k (1-p)^(n-k)` for `k >= 0`. /// -/// # Known issues -/// -/// See documentation of [`Binomial::new`]. -/// /// # Plot /// /// The following plot of the binomial distribution illustrates the @@ -50,10 +46,34 @@ use rand::Rng; #[derive(Clone, Copy, Debug, PartialEq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Binomial { - /// Number of trials. + method: Method, +} + +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +enum Method { + Binv(Binv, bool), + Btpe(Btpe, bool), + Poisson(crate::poisson::KnuthMethod), + Constant(u64), +} + +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +struct Binv { + r: f64, + s: f64, + a: f64, + n: u64, +} + +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +struct Btpe { n: u64, - /// Probability of success. p: f64, + m: i64, + p1: f64, } /// Error type returned from [`Binomial::new`]. @@ -82,13 +102,6 @@ impl std::error::Error for Error {} impl Binomial { /// Construct a new `Binomial` with the given shape parameters `n` (number /// of trials) and `p` (probability of success). - /// - /// # Known issues - /// - /// Although this method should return an [`Error`] on invalid parameters, - /// some (extreme) parameter combinations are known to return a [`Binomial`] - /// object which panics when [sampled](Distribution::sample). - /// See [#1378](https://github.com/rust-random/rand/issues/1378). pub fn new(n: u64, p: f64) -> Result { if !(p >= 0.0) { return Err(Error::ProbabilityTooSmall); @@ -96,33 +109,22 @@ impl Binomial { if !(p <= 1.0) { return Err(Error::ProbabilityTooLarge); } - Ok(Binomial { n, p }) - } -} - -/// Convert a `f64` to an `i64`, panicking on overflow. -fn f64_to_i64(x: f64) -> i64 { - assert!(x < (i64::MAX as f64)); - x as i64 -} -impl Distribution for Binomial { - #[allow(clippy::many_single_char_names)] // Same names as in the reference. - fn sample(&self, rng: &mut R) -> u64 { - // Handle these values directly. - if self.p == 0.0 { - return 0; - } else if self.p == 1.0 { - return self.n; + if p == 0.0 { + return Ok(Binomial { + method: Method::Constant(0), + }); } - // The binomial distribution is symmetrical with respect to p -> 1-p, - // k -> n-k switch p so that it is less than 0.5 - this allows for lower - // expected values we will just invert the result at the end - let p = if self.p <= 0.5 { self.p } else { 1.0 - self.p }; + if p == 1.0 { + return Ok(Binomial { + method: Method::Constant(n), + }); + } - let result; - let q = 1. - p; + // The binomial distribution is symmetrical with respect to p -> 1-p + let flipped = p > 0.5; + let p = if flipped { 1.0 - p } else { p }; // For small n * min(p, 1 - p), the BINV algorithm based on the inverse // transformation of the binomial distribution is efficient. Otherwise, @@ -136,204 +138,253 @@ impl Distribution for Binomial { // Ranlib uses 30, and GSL uses 14. const BINV_THRESHOLD: f64 = 10.; - // Same value as in GSL. - // It is possible for BINV to get stuck, so we break if x > BINV_MAX_X and try again. - // It would be safer to set BINV_MAX_X to self.n, but it is extremely unlikely to be relevant. - // When n*p < 10, so is n*p*q which is the variance, so a result > 110 would be 100 / sqrt(10) = 31 standard deviations away. - const BINV_MAX_X: u64 = 110; - - if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (i32::MAX as u64) { - // Use the BINV algorithm. - let s = p / q; - let a = ((self.n + 1) as f64) * s; - - result = 'outer: loop { - let mut r = q.powi(self.n as i32); - let mut u: f64 = rng.random(); - let mut x = 0; - - while u > r { - u -= r; - x += 1; - if x > BINV_MAX_X { - continue 'outer; - } - r *= a / (x as f64) - s; - } - break x; + let np = n as f64 * p; + let method = if np < BINV_THRESHOLD { + let q = 1.0 - p; + if q == 1.0 { + // p is so small that this is extremely close to a Poisson distribution. + // The flipped case cannot occur here. + Method::Poisson(crate::poisson::KnuthMethod::new(np)) + } else { + let s = p / q; + Method::Binv( + Binv { + r: q.powf(n as f64), + s, + a: (n as f64 + 1.0) * s, + n, + }, + flipped, + ) } } else { - // Use the BTPE algorithm. - - // Threshold for using the squeeze algorithm. This can be freely - // chosen based on performance. Ranlib and GSL use 20. - const SQUEEZE_THRESHOLD: i64 = 20; - - // Step 0: Calculate constants as functions of `n` and `p`. - let n = self.n as f64; - let np = n * p; + let q = 1.0 - p; let npq = np * q; + let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5; let f_m = np + p; let m = f64_to_i64(f_m); - // radius of triangle region, since height=1 also area of region - let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5; - // tip of triangle - let x_m = (m as f64) + 0.5; - // left edge of triangle - let x_l = x_m - p1; - // right edge of triangle - let x_r = x_m + p1; - let c = 0.134 + 20.5 / (15.3 + (m as f64)); - // p1 + area of parallelogram region - let p2 = p1 * (1. + 2. * c); - - fn lambda(a: f64) -> f64 { - a * (1. + 0.5 * a) + Method::Btpe(Btpe { n, p, m, p1 }, flipped) + }; + Ok(Binomial { method }) + } +} + +/// Convert a `f64` to an `i64`, panicking on overflow. +fn f64_to_i64(x: f64) -> i64 { + assert!(x < (i64::MAX as f64)); + x as i64 +} + +fn binv(binv: Binv, flipped: bool, rng: &mut R) -> u64 { + // Same value as in GSL. + // It is possible for BINV to get stuck, so we break if x > BINV_MAX_X and try again. + // It would be safer to set BINV_MAX_X to self.n, but it is extremely unlikely to be relevant. + // When n*p < 10, so is n*p*q which is the variance, so a result > 110 would be 100 / sqrt(10) = 31 standard deviations away. + const BINV_MAX_X: u64 = 110; + + let sample = 'outer: loop { + let mut r = binv.r; + let mut u: f64 = rng.random(); + let mut x = 0; + + while u > r { + u -= r; + x += 1; + if x > BINV_MAX_X { + continue 'outer; } + r *= binv.a / (x as f64) - binv.s; + } + break x; + }; - let lambda_l = lambda((f_m - x_l) / (f_m - x_l * p)); - let lambda_r = lambda((x_r - f_m) / (x_r * q)); - // p1 + area of left tail - let p3 = p2 + c / lambda_l; - // p1 + area of right tail - let p4 = p3 + c / lambda_r; - - // return value - let mut y: i64; - - let gen_u = Uniform::new(0., p4).unwrap(); - let gen_v = Uniform::new(0., 1.).unwrap(); - - loop { - // Step 1: Generate `u` for selecting the region. If region 1 is - // selected, generate a triangularly distributed variate. - let u = gen_u.sample(rng); - let mut v = gen_v.sample(rng); - if !(u > p1) { - y = f64_to_i64(x_m - p1 * v + u); - break; - } + if flipped { + binv.n - sample + } else { + sample + } +} - if !(u > p2) { - // Step 2: Region 2, parallelograms. Check if region 2 is - // used. If so, generate `y`. - let x = x_l + (u - p1) / c; - v = v * c + 1.0 - (x - x_m).abs() / p1; - if v > 1. { - continue; - } else { - y = f64_to_i64(x); - } - } else if !(u > p3) { - // Step 3: Region 3, left exponential tail. - y = f64_to_i64(x_l + v.ln() / lambda_l); - if y < 0 { - continue; - } else { - v *= (u - p2) * lambda_l; - } - } else { - // Step 4: Region 4, right exponential tail. - y = f64_to_i64(x_r - v.ln() / lambda_r); - if y > 0 && (y as u64) > self.n { - continue; - } else { - v *= (u - p3) * lambda_r; - } - } +#[allow(clippy::many_single_char_names)] // Same names as in the reference. +fn btpe(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 { + // Threshold for using the squeeze algorithm. This can be freely + // chosen based on performance. Ranlib and GSL use 20. + const SQUEEZE_THRESHOLD: i64 = 20; + + // Step 0: Calculate constants as functions of `n` and `p`. + let n = btpe.n as f64; + let np = n * btpe.p; + let q = 1. - btpe.p; + let npq = np * q; + let f_m = np + btpe.p; + let m = btpe.m; + // radius of triangle region, since height=1 also area of region + let p1 = btpe.p1; + // tip of triangle + let x_m = (m as f64) + 0.5; + // left edge of triangle + let x_l = x_m - p1; + // right edge of triangle + let x_r = x_m + p1; + let c = 0.134 + 20.5 / (15.3 + (m as f64)); + // p1 + area of parallelogram region + let p2 = p1 * (1. + 2. * c); + + fn lambda(a: f64) -> f64 { + a * (1. + 0.5 * a) + } - // Step 5: Acceptance/rejection comparison. - - // Step 5.0: Test for appropriate method of evaluating f(y). - let k = (y - m).abs(); - if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) { - // Step 5.1: Evaluate f(y) via the recursive relationship. Start the - // search from the mode. - let s = p / q; - let a = s * (n + 1.); - let mut f = 1.0; - match m.cmp(&y) { - Ordering::Less => { - let mut i = m; - loop { - i += 1; - f *= a / (i as f64) - s; - if i == y { - break; - } - } - } - Ordering::Greater => { - let mut i = y; - loop { - i += 1; - f /= a / (i as f64) - s; - if i == m { - break; - } - } + let lambda_l = lambda((f_m - x_l) / (f_m - x_l * btpe.p)); + let lambda_r = lambda((x_r - f_m) / (x_r * q)); + + let p3 = p2 + c / lambda_l; + + let p4 = p3 + c / lambda_r; + + // return value + let mut y: i64; + + let gen_u = Uniform::new(0., p4).unwrap(); + let gen_v = Uniform::new(0., 1.).unwrap(); + + loop { + // Step 1: Generate `u` for selecting the region. If region 1 is + // selected, generate a triangularly distributed variate. + let u = gen_u.sample(rng); + let mut v = gen_v.sample(rng); + if !(u > p1) { + y = f64_to_i64(x_m - p1 * v + u); + break; + } + + if !(u > p2) { + // Step 2: Region 2, parallelograms. Check if region 2 is + // used. If so, generate `y`. + let x = x_l + (u - p1) / c; + v = v * c + 1.0 - (x - x_m).abs() / p1; + if v > 1. { + continue; + } else { + y = f64_to_i64(x); + } + } else if !(u > p3) { + // Step 3: Region 3, left exponential tail. + y = f64_to_i64(x_l + v.ln() / lambda_l); + if y < 0 { + continue; + } else { + v *= (u - p2) * lambda_l; + } + } else { + // Step 4: Region 4, right exponential tail. + y = f64_to_i64(x_r - v.ln() / lambda_r); + if y > 0 && (y as u64) > btpe.n { + continue; + } else { + v *= (u - p3) * lambda_r; + } + } + + // Step 5: Acceptance/rejection comparison. + + // Step 5.0: Test for appropriate method of evaluating f(y). + let k = (y - m).abs(); + if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) { + // Step 5.1: Evaluate f(y) via the recursive relationship. Start the + // search from the mode. + let s = btpe.p / q; + let a = s * (n + 1.); + let mut f = 1.0; + match m.cmp(&y) { + Ordering::Less => { + let mut i = m; + loop { + i += 1; + f *= a / (i as f64) - s; + if i == y { + break; } - Ordering::Equal => {} } - if v > f { - continue; - } else { - break; - } - } - - // Step 5.2: Squeezing. Check the value of ln(v) against upper and - // lower bound of ln(f(y)). - let k = k as f64; - let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1. / 6.) / npq + 0.5); - let t = -0.5 * k * k / npq; - let alpha = v.ln(); - if alpha < t - rho { - break; } - if alpha > t + rho { - continue; + Ordering::Greater => { + let mut i = y; + loop { + i += 1; + f /= a / (i as f64) - s; + if i == m { + break; + } + } } + Ordering::Equal => {} + } + if v > f { + continue; + } else { + break; + } + } - // Step 5.3: Final acceptance/rejection test. - let x1 = (y + 1) as f64; - let f1 = (m + 1) as f64; - let z = (f64_to_i64(n) + 1 - m) as f64; - let w = (f64_to_i64(n) - y + 1) as f64; + // Step 5.2: Squeezing. Check the value of ln(v) against upper and + // lower bound of ln(f(y)). + let k = k as f64; + let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1. / 6.) / npq + 0.5); + let t = -0.5 * k * k / npq; + let alpha = v.ln(); + if alpha < t - rho { + break; + } + if alpha > t + rho { + continue; + } - fn stirling(a: f64) -> f64 { - let a2 = a * a; - (13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320. - } + // Step 5.3: Final acceptance/rejection test. + let x1 = (y + 1) as f64; + let f1 = (m + 1) as f64; + let z = (f64_to_i64(n) + 1 - m) as f64; + let w = (f64_to_i64(n) - y + 1) as f64; - if alpha - > x_m * (f1 / x1).ln() - + (n - (m as f64) + 0.5) * (z / w).ln() - + ((y - m) as f64) * (w * p / (x1 * q)).ln() - // We use the signs from the GSL implementation, which are - // different than the ones in the reference. According to - // the GSL authors, the new signs were verified to be - // correct by one of the original designers of the - // algorithm. - + stirling(f1) - + stirling(z) - - stirling(x1) - - stirling(w) - { - continue; - } + fn stirling(a: f64) -> f64 { + let a2 = a * a; + (13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320. + } - break; - } - assert!(y >= 0); - result = y as u64; + if alpha + > x_m * (f1 / x1).ln() + + (n - (m as f64) + 0.5) * (z / w).ln() + + ((y - m) as f64) * (w * btpe.p / (x1 * q)).ln() + // We use the signs from the GSL implementation, which are + // different than the ones in the reference. According to + // the GSL authors, the new signs were verified to be + // correct by one of the original designers of the + // algorithm. + + stirling(f1) + + stirling(z) + - stirling(x1) + - stirling(w) + { + continue; } - // Invert the result for p < 0.5. - if p != self.p { - self.n - result - } else { - result + break; + } + assert!(y >= 0); + let y = y as u64; + + if flipped { + btpe.n - y + } else { + y + } +} + +impl Distribution for Binomial { + fn sample(&self, rng: &mut R) -> u64 { + match self.method { + Method::Binv(binv_para, flipped) => binv(binv_para, flipped, rng), + Method::Btpe(btpe_para, flipped) => btpe(btpe_para, flipped, rng), + Method::Poisson(poisson) => poisson.sample(rng) as u64, + Method::Constant(c) => c, } } } @@ -371,6 +422,8 @@ mod test { test_binomial_mean_and_variance(40, 0.5, &mut rng); test_binomial_mean_and_variance(20, 0.7, &mut rng); test_binomial_mean_and_variance(20, 0.5, &mut rng); + test_binomial_mean_and_variance(1 << 61, 1e-17, &mut rng); + test_binomial_mean_and_variance(u64::MAX, 1e-19, &mut rng); } #[test]