Skip to content

Commit

Permalink
Replaces the Poisson rejection method implementation (#1560)
Browse files Browse the repository at this point in the history
- [x] Added a `CHANGELOG.md` entry

# Summary

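As discussed in #1515, this PR replaces the implementation of
`poisson::RejectionMethod` with a new algorithm based on the paper
[Computer Generation of Poisson Deviates from Modified Normal Distributions](https://dl.acm.org/doi/10.1145/355993.355997)
(Ahrens and Dieter, 1982).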

# Motivation

The new implementation is faster and samples the distribution more
accurately, especially for extreme values of lambda (> 1e9).

# Details

In terms of performance, here are the benchmarks I ran, with the current
implementation as the baseline:

```text
poisson/100             time:   [45.5242 cycles 45.6734 cycles 45.8337 cycles]
                        change: [-86.572% -86.507% -86.438%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 5 outliers among 100 measurements (5.00%)
  2 (2.00%) low mild
  2 (2.00%) high mild
  1 (1.00%) high severe
poisson/variable        time:   [5494.6626 cycles 5508.2882 cycles 5523.2298 cycles]
                        thrpt:  [5523.2298 cycles/100 5508.2882 cycles/100 5494.6626 cycles/100]
                 change:
                        time:   [-76.728% -76.573% -76.430%] (p = 0.00 < 0.05)
                        thrpt:  [+324.27% +326.85% +329.69%]
                        Performance has improved.
Found 5 outliers among 100 measurements (5.00%)
  1 (1.00%) low mild
  3 (3.00%) high mild
  1 (1.00%) high severe
```
  • Loading branch information
JamboChen authored Jan 27, 2025
1 parent 67fd92e commit a855547
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 120 deletions.
6 changes: 3 additions & 3 deletions distr_test/tests/cdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,9 +427,9 @@ fn hypergeometric() {
fn poisson() {
use rand_distr::Poisson;
let parameters = [
0.1, 1.0, 7.5,
45.0, // 1e9, passed case but too slow
// 1.844E+19, // fail case
0.1, 1.0, 7.5, 15.0, 45.0, 98.0, 230.0, 4567.5,
4.4541e7, // 1e10, //passed case but too slow
// 1.844E+19, // fail case
];

for (seed, lambda) in parameters.into_iter().enumerate() {
Expand Down
1 change: 1 addition & 0 deletions rand_distr/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
This breaks serialization compatibility with older versions.
- Add plots for `rand_distr` distributions to documentation (#1434)
- Move some of the computations in Binomial from `sample` to `new` (#1484)
- Reimplement `Poisson`'s rejection method to improve performance and correct sampling inaccuracies for large lambda values. This is a value-breaking change (#1560)

## [0.4.3] - 2021-12-30
- Fix `no_std` build (#1208)
Expand Down
193 changes: 120 additions & 73 deletions rand_distr/src/poisson.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

//! The Poisson distribution `Poisson(λ)`.
use crate::{Cauchy, Distribution, StandardUniform};
use crate::{Distribution, Exp1, Normal, StandardNormal, StandardUniform};
use core::fmt;
use num_traits::{Float, FloatConst};
use rand::Rng;
Expand Down Expand Up @@ -101,21 +101,37 @@ impl<F: Float> KnuthMethod<F> {
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct RejectionMethod<F> {
lambda: F,
log_lambda: F,
sqrt_2lambda: F,
magic_val: F,
s: F,
d: F,
l: F,
c: F,
c0: F,
c1: F,
c2: F,
c3: F,
omega: F,
}

impl<F: Float> RejectionMethod<F> {
impl<F: Float + FloatConst> RejectionMethod<F> {
pub(crate) fn new(lambda: F) -> Self {
let log_lambda = lambda.ln();
let sqrt_2lambda = (F::from(2.0).unwrap() * lambda).sqrt();
let magic_val = lambda * log_lambda - crate::utils::log_gamma(F::one() + lambda);
let b1 = F::from(1.0 / 24.0).unwrap() / lambda;
let b2 = F::from(0.3).unwrap() * b1 * b1;
let c3 = F::from(1.0 / 7.0).unwrap() * b1 * b2;
let c2 = b2 - F::from(15).unwrap() * c3;
let c1 = b1 - F::from(6).unwrap() * b2 + F::from(45).unwrap() * c3;
let c0 = F::one() - b1 + F::from(3).unwrap() * b2 - F::from(15).unwrap() * c3;

RejectionMethod {
lambda,
log_lambda,
sqrt_2lambda,
magic_val,
s: lambda.sqrt(),
d: F::from(6.0).unwrap() * lambda.powi(2),
l: (lambda - F::from(1.1484).unwrap()).floor(),
c: F::from(0.1069).unwrap() / lambda,
c0,
c1,
c2,
c3,
omega: F::one() / (F::from(2).unwrap() * F::PI()).sqrt() / lambda.sqrt(),
}
}
}
Expand Down Expand Up @@ -189,56 +205,114 @@ impl<F> Distribution<F> for RejectionMethod<F>
where
F: Float + FloatConst,
StandardUniform: Distribution<F>,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
// The algorithm from Numerical Recipes in C
// The algorithm is based on:
// J. H. Ahrens and U. Dieter. 1982.
// Computer Generation of Poisson Deviates from Modified Normal Distributions.
// ACM Trans. Math. Softw. 8, 2 (June 1982), 163–179. https://doi.org/10.1145/355993.355997

// Step F
let f = |k: F| {
const FACT: [f64; 10] = [
1.0, 1.0, 2.0, 6.0, 24.0, 120.0, 720.0, 5040.0, 40320.0, 362880.0,
]; // factorial of 0..10
const A: [f64; 10] = [
-0.5000000002,
0.3333333343,
-0.2499998565,
0.1999997049,
-0.1666848753,
0.1428833286,
-0.1241963125,
0.1101687109,
-0.1142650302,
0.1055093006,
]; // coefficients from Table 1
let (px, py) = if k < F::from(10.0).unwrap() {
let px = -self.lambda;
let py = self.lambda.powf(k) / F::from(FACT[k.to_usize().unwrap()]).unwrap();

(px, py)
} else {
let delta = (F::from(12.0).unwrap() * k).recip();
let delta = delta - F::from(4.8).unwrap() * delta.powi(3);
let v = (self.lambda - k) / k;

let px = if v.abs() <= F::from(0.25).unwrap() {
k * v.powi(2)
* A.iter()
.rev()
.fold(F::zero(), |acc, &a| {
acc * v + F::from(a).unwrap()
}) // Σ a_i * v^i
- delta
} else {
k * (F::one() + v).ln() - (self.lambda - k) - delta
};

let py = F::one() / (F::from(2.0).unwrap() * F::PI()).sqrt() / k.sqrt();

(px, py)
};

let x = (k - self.lambda + F::from(0.5).unwrap()) / self.s;
let fx = -F::from(0.5).unwrap() * x * x;
let fy =
self.omega * (((self.c3 * x * x + self.c2) * x * x + self.c1) * x * x + self.c0);

(px, py, fx, fy)
};

// Step N
let normal = Normal::new(self.lambda, self.s).unwrap();
let g = normal.sample(rng);
if g >= F::zero() {
let k1 = g.floor();

// Step I
if k1 >= self.l {
return k1;
}

// we use the Cauchy distribution as the comparison distribution
// f(x) ~ 1/(1+x^2)
let cauchy = Cauchy::new(F::zero(), F::one()).unwrap();
let mut result;
// Step S
let u: F = rng.random();
if self.d * u >= (self.lambda - k1).powi(3) {
return k1;
}

let (px, py, fx, fy) = f(k1);

if fy * (F::one() - u) <= py * (px - fx).exp() {
return k1;
}
}

loop {
let mut comp_dev;

loop {
// draw from the Cauchy distribution
comp_dev = rng.sample(cauchy);
// shift the peak of the comparison distribution
result = self.sqrt_2lambda * comp_dev + self.lambda;
// repeat the drawing until we are in the range of possible values
if result >= F::zero() {
break;
// Step E
let e = Exp1.sample(rng);
let u: F = rng.random() * F::from(2.0).unwrap() - F::one();
let t = F::from(1.8).unwrap() + e * u.signum();
if t > F::from(-0.6744).unwrap() {
let k2 = (self.lambda + self.s * t).floor();
let (px, py, fx, fy) = f(k2);
// Step H
if self.c * u.abs() <= py * (px + e).exp() - fy * (fx + e).exp() {
return k2;
}
}
// now the result is a random variable greater than 0 with Cauchy distribution
// the result should be an integer value
result = result.floor();

// this is the ratio of the Poisson distribution to the comparison distribution
// the magic value scales the distribution function to a range of approximately 0-1
// since it is not exact, we multiply the ratio by 0.9 to avoid ratios greater than 1
// this doesn't change the resulting distribution, only increases the rate of failed drawings
let check = F::from(0.9).unwrap()
* (F::one() + comp_dev * comp_dev)
* (result * self.log_lambda
- crate::utils::log_gamma(F::one() + result)
- self.magic_val)
.exp();

// check with uniform random value - if below the threshold, we are within the target distribution
if rng.random::<F>() <= check {
break;
}
}
result
}
}

impl<F> Distribution<F> for Poisson<F>
where
F: Float + FloatConst,
StandardUniform: Distribution<F>,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
{
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
Expand All @@ -253,33 +327,6 @@ where
mod test {
use super::*;

/// Draws 1000 samples from `Poisson(lambda)` using a fixed-seed test RNG and
/// asserts that their arithmetic mean differs from `lambda` by less than `tol`.
fn test_poisson_avg_gen<F: Float + FloatConst>(lambda: F, tol: F)
where
    StandardUniform: Distribution<F>,
{
    const SAMPLES: u32 = 1000;

    let distr = Poisson::new(lambda).unwrap();
    let mut rng = crate::test::rng(123);

    // Accumulate in draw order so the running sum is built exactly as a plain loop would.
    let total = (0..SAMPLES).fold(F::zero(), |acc, _| acc + distr.sample(&mut rng));
    let mean = total / F::from(SAMPLES).unwrap();

    assert!((mean - lambda).abs() < tol);
}

/// Checks the sample mean of `Poisson` for a few rates in both `f64` and `f32`.
#[test]
fn test_poisson_avg() {
    for lambda in [10.0, 15.0] {
        test_poisson_avg_gen::<f64>(lambda, 0.1);
    }
    for lambda in [10.0, 15.0] {
        test_poisson_avg_gen::<f32>(lambda, 0.1);
    }

    // Small lambda will use Knuth's method with exp_lambda == 1.0
    let tiny = 0.00000000000000005;
    test_poisson_avg_gen::<f32>(tiny as f32, 0.1);
    test_poisson_avg_gen::<f64>(tiny, 0.1);
}

#[test]
#[should_panic]
fn test_poisson_invalid_lambda_zero() {
Expand Down
43 changes: 0 additions & 43 deletions rand_distr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,52 +9,9 @@
//! Math helper functions
use crate::ziggurat_tables;
use num_traits::Float;
use rand::distr::hidden_export::IntoFloat;
use rand::Rng;

/// Returns an approximation of `ln(gamma(x))`, the natural logarithm of the
/// gamma function, via the Lanczos approximation with `g = 5`.
///
/// The Lanczos formula writes the gamma function as
/// `gamma(z+1) = sqrt(2*pi)*(z+g+0.5)^(z+0.5)*exp(-z-g-0.5)*Ag(z)`.
/// Using `gamma(z+1) = z*gamma(z)` and taking logarithms gives
/// `ln(gamma(z)) = (z+0.5)*ln(z+g+0.5)-(z+g+0.5) + ln(sqrt(2*pi)*Ag(z)/z)`,
/// which is the expression evaluated here.
///
/// `Ag(z)` is an infinite series; only the constant term and the first six
/// precomputed coefficients are used.
pub(crate) fn log_gamma<F: Float>(x: F) -> F {
    // Coefficients of the first six terms of the series for Ag(x).
    const SERIES_COEFFS: [f64; 6] = [
        76.18009172947146,
        -86.50532032941677,
        24.01409824083091,
        -1.231739572450155,
        0.1208650973866179e-2,
        -0.5395239384953e-5,
    ];

    let lit = |value: f64| F::from(value).unwrap();

    // (x+0.5)*ln(x+g+0.5)-(x+g+0.5) with g = 5
    let shifted = x + lit(5.5);
    let power_term = (x + lit(0.5)) * shifted.ln() - shifted;

    // Ag(x): constant term plus coeff_i / (x + i) for i = 1..=6, summed in order.
    let (series, _) = SERIES_COEFFS.iter().fold(
        (lit(1.000000000190015), x),
        |(sum, denom), &coeff| {
            let denom = denom + F::one();
            (sum + lit(coeff) / denom, denom)
        },
    );

    // 2.5066... is sqrt(2*pi)
    power_term + (lit(2.5066282746310005) * series / x).ln()
}

/// Sample a random number using the Ziggurat method (specifically the
/// ZIGNOR variant from Doornik 2005). Most of the arguments are
/// directly from the paper:
Expand Down
2 changes: 1 addition & 1 deletion rand_distr/tests/value_stability.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ fn poisson_stability() {
test_samples(
223,
Poisson::new(27.0).unwrap(),
&[28.0f32, 32.0, 36.0, 36.0],
&[30.0f32, 33.0, 23.0, 25.0],
);
}

Expand Down

0 comments on commit a855547

Please sign in to comment.