From a8555478b4664cfc51af45712ca5549ed39ac56f Mon Sep 17 00:00:00 2001 From: Jambo Date: Tue, 28 Jan 2025 02:36:02 +0800 Subject: [PATCH] Replaces the Poisson rejection method implementation (#1560) - [x] Added a `CHANGELOG.md` entry # Summary As discussed in #1515, this PR replaces the implementation of `poisson::RejectionMethod` with a new algorithm based on the [paper ](https://dl.acm.org/doi/10.1145/355993.355997). # Motivation The new implementation offers improved performance and maintains better sampling distribution, especially for extreme values of lambda (> 1e9). # Details In terms of performance, here are the benchmarks I ran, with the current implementation as the baseline: ```text poisson/100 time: [45.5242 cycles 45.6734 cycles 45.8337 cycles] change: [-86.572% -86.507% -86.438%] (p = 0.00 < 0.05) Performance has improved. Found 5 outliers among 100 measurements (5.00%) 2 (2.00%) low mild 2 (2.00%) high mild 1 (1.00%) high severe poisson/variable time: [5494.6626 cycles 5508.2882 cycles 5523.2298 cycles] thrpt: [5523.2298 cycles/100 5508.2882 cycles/100 5494.6626 cycles/100] change: time: [-76.728% -76.573% -76.430%] (p = 0.00 < 0.05) thrpt: [+324.27% +326.85% +329.69%] Performance has improved. Found 5 outliers among 100 measurements (5.00%) 1 (1.00%) low mild 3 (3.00%) high mild 1 (1.00%) high severe ``` --- distr_test/tests/cdf.rs | 6 +- rand_distr/CHANGELOG.md | 1 + rand_distr/src/poisson.rs | 193 +++++++++++++++++----------- rand_distr/src/utils.rs | 43 ------- rand_distr/tests/value_stability.rs | 2 +- 5 files changed, 125 insertions(+), 120 deletions(-) diff --git a/distr_test/tests/cdf.rs b/distr_test/tests/cdf.rs index f417c63..9704c44 100644 --- a/distr_test/tests/cdf.rs +++ b/distr_test/tests/cdf.rs @@ -427,9 +427,9 @@ fn hypergeometric() { fn poisson() { use rand_distr::Poisson; let parameters = [ - 0.1, 1.0, 7.5, - 45.0, // 1e9, passed case but too slow - // 1.844E+19, // fail case + 0.1, 1.0, 7.5, 15.0, 45.0, 98.0, 230.0, 4567.5, + 4.4541e7, // 1e10, //passed case but too slow + // 1.844E+19, // fail case ]; for (seed, lambda) in parameters.into_iter().enumerate() { diff --git a/rand_distr/CHANGELOG.md b/rand_distr/CHANGELOG.md index 81fa3a3..a75e125 100644 --- a/rand_distr/CHANGELOG.md +++ b/rand_distr/CHANGELOG.md @@ -48,6 +48,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 This breaks serialization compatibility with older versions. - Add plots for `rand_distr` distributions to documentation (#1434) - Move some of the computations in Binomial from `sample` to `new` (#1484) +- Reimplement `Poisson`'s rejection method to improve performance and correct sampling inaccuracies for large lambda values, this is a Value-breaking change (#1560) ## [0.4.3] - 2021-12-30 - Fix `no_std` build (#1208) diff --git a/rand_distr/src/poisson.rs b/rand_distr/src/poisson.rs index 3e44212..424f32f 100644 --- a/rand_distr/src/poisson.rs +++ b/rand_distr/src/poisson.rs @@ -9,7 +9,7 @@ //! The Poisson distribution `Poisson(λ)`. -use crate::{Cauchy, Distribution, StandardUniform}; +use crate::{Distribution, Exp1, Normal, StandardNormal, StandardUniform}; use core::fmt; use num_traits::{Float, FloatConst}; use rand::Rng; @@ -101,21 +101,37 @@ impl KnuthMethod { #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] struct RejectionMethod { lambda: F, - log_lambda: F, - sqrt_2lambda: F, - magic_val: F, + s: F, + d: F, + l: F, + c: F, + c0: F, + c1: F, + c2: F, + c3: F, + omega: F, } -impl RejectionMethod { +impl RejectionMethod { pub(crate) fn new(lambda: F) -> Self { - let log_lambda = lambda.ln(); - let sqrt_2lambda = (F::from(2.0).unwrap() * lambda).sqrt(); - let magic_val = lambda * log_lambda - crate::utils::log_gamma(F::one() + lambda); + let b1 = F::from(1.0 / 24.0).unwrap() / lambda; + let b2 = F::from(0.3).unwrap() * b1 * b1; + let c3 = F::from(1.0 / 7.0).unwrap() * b1 * b2; + let c2 = b2 - F::from(15).unwrap() * c3; + let c1 = b1 - F::from(6).unwrap() * b2 + F::from(45).unwrap() * c3; + let c0 = F::one() - b1 + F::from(3).unwrap() * b2 - F::from(15).unwrap() * c3; + RejectionMethod { lambda, - log_lambda, - sqrt_2lambda, - magic_val, + s: lambda.sqrt(), + d: F::from(6.0).unwrap() * lambda.powi(2), + l: (lambda - F::from(1.1484).unwrap()).floor(), + c: F::from(0.1069).unwrap() / lambda, + c0, + c1, + c2, + c3, + omega: F::one() / (F::from(2).unwrap() * F::PI()).sqrt() / lambda.sqrt(), } } } @@ -189,49 +205,105 @@ impl Distribution for RejectionMethod where F: Float + FloatConst, StandardUniform: Distribution, + StandardNormal: Distribution, + Exp1: Distribution, { fn sample(&self, rng: &mut R) -> F { - // The algorithm from Numerical Recipes in C + // The algorithm is based on: + // J. H. Ahrens and U. Dieter. 1982. + // Computer Generation of Poisson Deviates from Modified Normal Distributions. + // ACM Trans. Math. Softw. 8, 2 (June 1982), 163–179. https://doi.org/10.1145/355993.355997 + + // Step F + let f = |k: F| { + const FACT: [f64; 10] = [ + 1.0, 1.0, 2.0, 6.0, 24.0, 120.0, 720.0, 5040.0, 40320.0, 362880.0, + ]; // factorial of 0..10 + const A: [f64; 10] = [ + -0.5000000002, + 0.3333333343, + -0.2499998565, + 0.1999997049, + -0.1666848753, + 0.1428833286, + -0.1241963125, + 0.1101687109, + -0.1142650302, + 0.1055093006, + ]; // coefficients from Table 1 + let (px, py) = if k < F::from(10.0).unwrap() { + let px = -self.lambda; + let py = self.lambda.powf(k) / F::from(FACT[k.to_usize().unwrap()]).unwrap(); + + (px, py) + } else { + let delta = (F::from(12.0).unwrap() * k).recip(); + let delta = delta - F::from(4.8).unwrap() * delta.powi(3); + let v = (self.lambda - k) / k; + + let px = if v.abs() <= F::from(0.25).unwrap() { + k * v.powi(2) + * A.iter() + .rev() + .fold(F::zero(), |acc, &a| { + acc * v + F::from(a).unwrap() + }) // Σ a_i * v^i + - delta + } else { + k * (F::one() + v).ln() - (self.lambda - k) - delta + }; + + let py = F::one() / (F::from(2.0).unwrap() * F::PI()).sqrt() / k.sqrt(); + + (px, py) + }; + + let x = (k - self.lambda + F::from(0.5).unwrap()) / self.s; + let fx = -F::from(0.5).unwrap() * x * x; + let fy = + self.omega * (((self.c3 * x * x + self.c2) * x * x + self.c1) * x * x + self.c0); + + (px, py, fx, fy) + }; + + // Step N + let normal = Normal::new(self.lambda, self.s).unwrap(); + let g = normal.sample(rng); + if g >= F::zero() { + let k1 = g.floor(); + + // Step I + if k1 >= self.l { + return k1; + } - // we use the Cauchy distribution as the comparison distribution - // f(x) ~ 1/(1+x^2) - let cauchy = Cauchy::new(F::zero(), F::one()).unwrap(); - let mut result; + // Step S + let u: F = rng.random(); + if self.d * u >= (self.lambda - k1).powi(3) { + return k1; + } + + let (px, py, fx, fy) = f(k1); + + if fy * (F::one() - u) <= py * (px - fx).exp() { + return k1; + } + } loop { - let mut comp_dev; - - loop { - // draw from the Cauchy distribution - comp_dev = rng.sample(cauchy); - // shift the peak of the comparison distribution - result = self.sqrt_2lambda * comp_dev + self.lambda; - // repeat the drawing until we are in the range of possible values - if result >= F::zero() { - break; + // Step E + let e = Exp1.sample(rng); + let u: F = rng.random() * F::from(2.0).unwrap() - F::one(); + let t = F::from(1.8).unwrap() + e * u.signum(); + if t > F::from(-0.6744).unwrap() { + let k2 = (self.lambda + self.s * t).floor(); + let (px, py, fx, fy) = f(k2); + // Step H + if self.c * u.abs() <= py * (px + e).exp() - fy * (fx + e).exp() { + return k2; } } - // now the result is a random variable greater than 0 with Cauchy distribution - // the result should be an integer value - result = result.floor(); - - // this is the ratio of the Poisson distribution to the comparison distribution - // the magic value scales the distribution function to a range of approximately 0-1 - // since it is not exact, we multiply the ratio by 0.9 to avoid ratios greater than 1 - // this doesn't change the resulting distribution, only increases the rate of failed drawings - let check = F::from(0.9).unwrap() - * (F::one() + comp_dev * comp_dev) - * (result * self.log_lambda - - crate::utils::log_gamma(F::one() + result) - - self.magic_val) - .exp(); - - // check with uniform random value - if below the threshold, we are within the target distribution - if rng.random::() <= check { - break; - } } - result } } @@ -239,6 +311,8 @@ impl Distribution for Poisson where F: Float + FloatConst, StandardUniform: Distribution, + StandardNormal: Distribution, + Exp1: Distribution, { #[inline] fn sample(&self, rng: &mut R) -> F { @@ -253,33 +327,6 @@ where mod test { use super::*; - fn test_poisson_avg_gen(lambda: F, tol: F) - where - StandardUniform: Distribution, - { - let poisson = Poisson::new(lambda).unwrap(); - let mut rng = crate::test::rng(123); - let mut sum = F::zero(); - for _ in 0..1000 { - sum = sum + poisson.sample(&mut rng); - } - let avg = sum / F::from(1000.0).unwrap(); - assert!((avg - lambda).abs() < tol); - } - - #[test] - fn test_poisson_avg() { - test_poisson_avg_gen::(10.0, 0.1); - test_poisson_avg_gen::(15.0, 0.1); - - test_poisson_avg_gen::(10.0, 0.1); - test_poisson_avg_gen::(15.0, 0.1); - - // Small lambda will use Knuth's method with exp_lambda == 1.0 - test_poisson_avg_gen::(0.00000000000000005, 0.1); - test_poisson_avg_gen::(0.00000000000000005, 0.1); - } - #[test] #[should_panic] fn test_poisson_invalid_lambda_zero() { diff --git a/rand_distr/src/utils.rs b/rand_distr/src/utils.rs index f0cf2a1..ebc2fb5 100644 --- a/rand_distr/src/utils.rs +++ b/rand_distr/src/utils.rs @@ -9,52 +9,9 @@ //! Math helper functions use crate::ziggurat_tables; -use num_traits::Float; use rand::distr::hidden_export::IntoFloat; use rand::Rng; -/// Calculates ln(gamma(x)) (natural logarithm of the gamma -/// function) using the Lanczos approximation. -/// -/// The approximation expresses the gamma function as: -/// `gamma(z+1) = sqrt(2*pi)*(z+g+0.5)^(z+0.5)*exp(-z-g-0.5)*Ag(z)` -/// `g` is an arbitrary constant; we use the approximation with `g=5`. -/// -/// Noting that `gamma(z+1) = z*gamma(z)` and applying `ln` to both sides: -/// `ln(gamma(z)) = (z+0.5)*ln(z+g+0.5)-(z+g+0.5) + ln(sqrt(2*pi)*Ag(z)/z)` -/// -/// `Ag(z)` is an infinite series with coefficients that can be calculated -/// ahead of time - we use just the first 6 terms, which is good enough -/// for most purposes. -pub(crate) fn log_gamma(x: F) -> F { - // precalculated 6 coefficients for the first 6 terms of the series - let coefficients: [F; 6] = [ - F::from(76.18009172947146).unwrap(), - F::from(-86.50532032941677).unwrap(), - F::from(24.01409824083091).unwrap(), - F::from(-1.231739572450155).unwrap(), - F::from(0.1208650973866179e-2).unwrap(), - F::from(-0.5395239384953e-5).unwrap(), - ]; - - // (x+0.5)*ln(x+g+0.5)-(x+g+0.5) - let tmp = x + F::from(5.5).unwrap(); - let log = (x + F::from(0.5).unwrap()) * tmp.ln() - tmp; - - // the first few terms of the series for Ag(x) - let mut a = F::from(1.000000000190015).unwrap(); - let mut denom = x; - for &coeff in &coefficients { - denom = denom + F::one(); - a = a + (coeff / denom); - } - - // get everything together - // a is Ag(x) - // 2.5066... is sqrt(2pi) - log + (F::from(2.5066282746310005).unwrap() * a / x).ln() -} - /// Sample a random number using the Ziggurat method (specifically the /// ZIGNOR variant from Doornik 2005). Most of the arguments are /// directly from the paper: diff --git a/rand_distr/tests/value_stability.rs b/rand_distr/tests/value_stability.rs index 330119b..2eb263e 100644 --- a/rand_distr/tests/value_stability.rs +++ b/rand_distr/tests/value_stability.rs @@ -207,7 +207,7 @@ fn poisson_stability() { test_samples( 223, Poisson::new(27.0).unwrap(), - &[28.0f32, 32.0, 36.0, 36.0], + &[30.0f32, 33.0, 23.0, 25.0], ); }