Skip to content

Commit

Permalink
Merge pull request #1289 from dhardy/uniform-float
Browse files Browse the repository at this point in the history
Uniform float improvements
  • Loading branch information
vks authored May 1, 2023
2 parents 1464b88 + 026292d commit d4a2945
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 0 deletions.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,8 @@ harness = false
name = "shuffle"
path = "benches/shuffle.rs"
harness = false

[[bench]]
name = "uniform_float"
path = "benches/uniform_float.rs"
harness = false
103 changes: 103 additions & 0 deletions benches/uniform_float.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Copyright 2023 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Implement benchmarks for uniform distributions over FP types
//!
//! Sampling methods compared:
//!
//! - sample: current method: (x12 - 1.0) * (b - a) + a
use core::time::Duration;
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use rand::distributions::uniform::{SampleUniform, Uniform, UniformSampler};
use rand::prelude::*;
use rand_chacha::ChaCha8Rng;
use rand_pcg::{Pcg32, Pcg64};

const WARM_UP_TIME: Duration = Duration::from_millis(1000);
const MEASUREMENT_TIME: Duration = Duration::from_secs(3);
const SAMPLE_SIZE: usize = 100_000;
const N_RESAMPLES: usize = 10_000;

macro_rules! single_random {
($R:ty, $T:ty, $g:expr) => {
$g.bench_function(BenchmarkId::new(stringify!($T), stringify!($R)), |b| {
let mut rng = <$R>::from_entropy();
let (mut low, mut high);
loop {
low = <$T>::from_bits(rng.gen());
high = <$T>::from_bits(rng.gen());
if (low < high) && (high - low).is_normal() {
break;
}
}

b.iter(|| <$T as SampleUniform>::Sampler::sample_single_inclusive(low, high, &mut rng));
});
};

($c:expr, $T:ty) => {{
let mut g = $c.benchmark_group("uniform_single");
g.sample_size(SAMPLE_SIZE);
g.warm_up_time(WARM_UP_TIME);
g.measurement_time(MEASUREMENT_TIME);
g.nresamples(N_RESAMPLES);
single_random!(SmallRng, $T, g);
single_random!(ChaCha8Rng, $T, g);
single_random!(Pcg32, $T, g);
single_random!(Pcg64, $T, g);
g.finish();
}};
}

fn single_random(c: &mut Criterion) {
single_random!(c, f32);
single_random!(c, f64);
}

macro_rules! distr_random {
($R:ty, $T:ty, $g:expr) => {
$g.bench_function(BenchmarkId::new(stringify!($T), stringify!($R)), |b| {
let mut rng = <$R>::from_entropy();
let dist = loop {
let low = <$T>::from_bits(rng.gen());
let high = <$T>::from_bits(rng.gen());
if let Ok(dist) = Uniform::<$T>::new_inclusive(low, high) {
break dist;
}
};

b.iter(|| dist.sample(&mut rng));
});
};

($c:expr, $T:ty) => {{
let mut g = $c.benchmark_group("uniform_distribution");
g.sample_size(SAMPLE_SIZE);
g.warm_up_time(WARM_UP_TIME);
g.measurement_time(MEASUREMENT_TIME);
g.nresamples(N_RESAMPLES);
distr_random!(SmallRng, $T, g);
distr_random!(ChaCha8Rng, $T, g);
distr_random!(Pcg32, $T, g);
distr_random!(Pcg64, $T, g);
g.finish();
}};
}

fn distr_random(c: &mut Criterion) {
distr_random!(c, f32);
distr_random!(c, f64);
}

criterion_group! {
name = benches;
config = Criterion::default();
targets = single_random, distr_random
}
criterion_main!(benches);
46 changes: 46 additions & 0 deletions src/distributions/uniform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1038,6 +1038,38 @@ macro_rules! uniform_float_impl {
}
}
}

#[inline]
fn sample_single_inclusive<R: Rng + ?Sized, B1, B2>(low_b: B1, high_b: B2, rng: &mut R) -> Result<Self::X, Error>
where
B1: SampleBorrow<Self::X> + Sized,
B2: SampleBorrow<Self::X> + Sized,
{
let low = *low_b.borrow();
let high = *high_b.borrow();
#[cfg(debug_assertions)]
if !low.all_finite() || !high.all_finite() {
return Err(Error::NonFinite);
}
if !low.all_le(high) {
return Err(Error::EmptyRange);
}
let scale = high - low;
if !scale.all_finite() {
return Err(Error::NonFinite);
}

// Generate a value in the range [1, 2)
let value1_2 =
(rng.gen::<$uty>() >> $uty::splat($bits_to_discard)).into_float_with_exponent(0);

// Get a value in the range [0, 1) to avoid overflow when multiplying by scale
let value0_1 = value1_2 - <$ty>::splat(1.0);

// Doing multiply before addition allows some architectures
// to use a single instruction.
Ok(value0_1 * scale + low)
}
}
};
}
Expand Down Expand Up @@ -1380,6 +1412,9 @@ mod tests {
let v = <$ty as SampleUniform>::Sampler
::sample_single(low, high, &mut rng).unwrap().extract(lane);
assert!(low_scalar <= v && v < high_scalar);
let v = <$ty as SampleUniform>::Sampler
::sample_single_inclusive(low, high, &mut rng).unwrap().extract(lane);
assert!(low_scalar <= v && v <= high_scalar);
}

assert_eq!(
Expand All @@ -1392,8 +1427,19 @@ mod tests {
assert_eq!(<$ty as SampleUniform>::Sampler
::sample_single(low, high, &mut zero_rng).unwrap()
.extract(lane), low_scalar);
assert_eq!(<$ty as SampleUniform>::Sampler
::sample_single_inclusive(low, high, &mut zero_rng).unwrap()
.extract(lane), low_scalar);

assert!(max_rng.sample(my_uniform).extract(lane) < high_scalar);
assert!(max_rng.sample(my_incl_uniform).extract(lane) <= high_scalar);
// sample_single cannot cope with max_rng:
// assert!(<$ty as SampleUniform>::Sampler
// ::sample_single(low, high, &mut max_rng).unwrap()
// .extract(lane) < high_scalar);
assert!(<$ty as SampleUniform>::Sampler
::sample_single_inclusive(low, high, &mut max_rng).unwrap()
.extract(lane) <= high_scalar);

// Don't run this test for really tiny differences between high and low
// since for those rounding might result in selecting high for a very
Expand Down

0 comments on commit d4a2945

Please sign in to comment.