From b4b1eb7579c0a47c1d71560ada0acffd647c9370 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Tue, 14 Jan 2025 07:37:46 +0000 Subject: [PATCH] Re-org with distr::slice, distr::weighted modules (#1548) - Move `Slice` -> `slice::Choose`, `EmptySlice` -> `slice::Empty` - Rename trait `DistString` -> `SampleString` - Rename `DistIter` -> `Iter`, `DistMap` -> `Map` - Move `{Weight, WeightError, WeightedIndex}` -> `weighted::{Weight, Error, WeightedIndex}` - Move `weighted_alias::{AliasableWeight, WeightedAliasIndex}` -> `weighted::{..}` - Move `weighted_tree::WeightedTreeIndex` -> `weighted::WeightedTreeIndex` --- .github/workflows/benches.yml | 4 +- .github/workflows/distr_test.yml | 4 +- .github/workflows/test.yml | 2 +- CHANGELOG.md | 4 + benches/benches/distr.rs | 1 + benches/benches/weighted.rs | 2 +- distr_test/tests/weighted.rs | 4 +- rand_distr/CHANGELOG.md | 6 + rand_distr/src/lib.rs | 24 +-- rand_distr/src/weighted/mod.rs | 28 +++ .../src/{ => weighted}/weighted_alias.rs | 42 ++-- .../src/{ => weighted}/weighted_tree.rs | 61 +++--- src/distr/distribution.rs | 58 +++--- src/distr/mod.rs | 16 +- src/distr/other.rs | 14 +- src/distr/slice.rs | 82 ++++---- src/distr/uniform_other.rs | 7 +- src/distr/weighted/mod.rs | 115 +++++++++++ src/distr/{ => weighted}/weighted_index.rs | 190 ++++-------------- src/lib.rs | 2 +- src/rng.rs | 4 +- src/seq/mod.rs | 4 +- src/seq/slice.rs | 14 +- 23 files changed, 354 insertions(+), 334 deletions(-) create mode 100644 rand_distr/src/weighted/mod.rs rename rand_distr/src/{ => weighted}/weighted_alias.rs (94%) rename rand_distr/src/{ => weighted}/weighted_tree.rs (87%) create mode 100644 src/distr/weighted/mod.rs rename src/distr/{ => weighted}/weighted_index.rs (77%) diff --git a/.github/workflows/benches.yml b/.github/workflows/benches.yml index 4be504fb67..22b4baa8dc 100644 --- a/.github/workflows/benches.yml +++ b/.github/workflows/benches.yml @@ -20,7 +20,7 @@ defaults: jobs: clippy-fmt: - name: Check Clippy and rustfmt + name: "Benches: Check Clippy and rustfmt" runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -33,7 +33,7 @@ jobs: - name: Clippy run: cargo clippy --all-targets -- -D warnings benches: - name: Test benchmarks + name: "Benches: Test" runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/distr_test.yml b/.github/workflows/distr_test.yml index ad0c0ba1e8..f2b7f814c9 100644 --- a/.github/workflows/distr_test.yml +++ b/.github/workflows/distr_test.yml @@ -20,7 +20,7 @@ defaults: jobs: clippy-fmt: - name: Check Clippy and rustfmt + name: "distr_test: Check Clippy and rustfmt" runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -33,7 +33,7 @@ jobs: - name: Clippy run: cargo clippy --all-targets -- -D warnings ks-tests: - name: Run Komogorov Smirnov tests + name: "distr_test: Run Komogorov Smirnov tests" runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9858b0f41a..293d5f4942 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -28,7 +28,7 @@ jobs: toolchain: stable components: clippy, rustfmt - name: Check Clippy - run: cargo clippy --all --all-targets -- -D warnings + run: cargo clippy --workspace -- -D warnings - name: Check rustfmt run: cargo fmt --all -- --check diff --git a/CHANGELOG.md b/CHANGELOG.md index bc9ecdf603..8f2e62e9bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update. ## [0.9.0-beta.3] - 2025-01-03 - Add feature `thread_rng` (#1547) +- Move `distr::Slice` -> `distr::slice::Choose`, `distr::EmptySlice` -> `distr::slice::Empty` (#1548) +- Rename trait `distr::DistString` -> `distr::SampleString` (#1548) +- Rename `distr::DistIter` -> `distr::Iter`, `distr::DistMap` -> `distr::Map` (#1548) +- Move `distr::{Weight, WeightError, WeightedIndex}` -> `distr::weighted::{Weight, Error, WeightedIndex}` (#1548) ## [0.9.0-beta.1] - 2024-11-30 - Bump `rand_core` version diff --git a/benches/benches/distr.rs b/benches/benches/distr.rs index fccfb1e0e9..3a76211972 100644 --- a/benches/benches/distr.rs +++ b/benches/benches/distr.rs @@ -10,6 +10,7 @@ use criterion::{criterion_group, criterion_main, Criterion, Throughput}; use criterion_cycles_per_byte::CyclesPerByte; use rand::prelude::*; +use rand_distr::weighted::*; use rand_distr::*; // At this time, distributions are optimised for 64-bit platforms. diff --git a/benches/benches/weighted.rs b/benches/benches/weighted.rs index d7af914736..69576b3608 100644 --- a/benches/benches/weighted.rs +++ b/benches/benches/weighted.rs @@ -7,7 +7,7 @@ // except according to those terms. use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use rand::distr::WeightedIndex; +use rand::distr::weighted::WeightedIndex; use rand::prelude::*; use rand::seq::index::sample_weighted; diff --git a/distr_test/tests/weighted.rs b/distr_test/tests/weighted.rs index cf87b3ee63..73df7beb9b 100644 --- a/distr_test/tests/weighted.rs +++ b/distr_test/tests/weighted.rs @@ -8,9 +8,9 @@ mod ks; use ks::test_discrete; -use rand::distr::{Distribution, WeightedIndex}; +use rand::distr::Distribution; use rand::seq::{IndexedRandom, IteratorRandom}; -use rand_distr::{WeightedAliasIndex, WeightedTreeIndex}; +use rand_distr::weighted::*; /// Takes the unnormalized pdf and creates the cdf of a discrete distribution fn make_cdf(num: usize, f: impl Fn(i64) -> f64) -> impl Fn(i64) -> f64 { diff --git a/rand_distr/CHANGELOG.md b/rand_distr/CHANGELOG.md index 155b5ce845..ee3490ca30 100644 --- a/rand_distr/CHANGELOG.md +++ b/rand_distr/CHANGELOG.md @@ -6,6 +6,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.5.0-beta.3] - 2025-01-03 - Bump `rand` version (#1547) +- Move `Slice` -> `slice::Choose`, `EmptySlice` -> `slice::Empty` (#1548) +- Rename trait `DistString` -> `SampleString` (#1548) +- Rename `DistIter` -> `Iter`, `DistMap` -> `Map` (#1548) +- Move `{Weight, WeightError, WeightedIndex}` -> `weighted::{Weight, Error, WeightedIndex}` (#1548) +- Move `weighted_alias::{AliasableWeight, WeightedAliasIndex}` -> `weighted::{..}` (#1548) +- Move `weighted_tree::WeightedTreeIndex` -> `weighted::WeightedTreeIndex` (#1548) ## [0.5.0-beta.2] - 2024-11-30 - Bump `rand` version diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index efd316b09c..ef1109b7d6 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -33,9 +33,10 @@ //! //! The following are re-exported: //! -//! - The [`Distribution`] trait and [`DistIter`] helper type +//! - The [`Distribution`] trait and [`Iter`] helper type //! - The [`StandardUniform`], [`Alphanumeric`], [`Uniform`], [`OpenClosed01`], -//! [`Open01`], [`Bernoulli`], and [`WeightedIndex`] distributions +//! [`Open01`], [`Bernoulli`] distributions +//! - The [`weighted`] module //! //! ## Distributions //! @@ -76,9 +77,6 @@ //! - [`UnitBall`] distribution //! - [`UnitCircle`] distribution //! - [`UnitDisc`] distribution -//! - Alternative implementations for weighted index sampling -//! - [`WeightedAliasIndex`] distribution -//! - [`WeightedTreeIndex`] distribution //! - Misc. distributions //! - [`InverseGaussian`] distribution //! - [`NormalInverseGaussian`] distribution @@ -94,7 +92,7 @@ extern crate std; use rand::Rng; pub use rand::distr::{ - uniform, Alphanumeric, Bernoulli, BernoulliError, DistIter, Distribution, Open01, OpenClosed01, + uniform, Alphanumeric, Bernoulli, BernoulliError, Distribution, Iter, Open01, OpenClosed01, StandardUniform, Uniform, }; @@ -128,16 +126,13 @@ pub use self::unit_sphere::UnitSphere; pub use self::weibull::{Error as WeibullError, Weibull}; pub use self::zeta::{Error as ZetaError, Zeta}; pub use self::zipf::{Error as ZipfError, Zipf}; -#[cfg(feature = "alloc")] -pub use rand::distr::{WeightError, WeightedIndex}; pub use student_t::StudentT; -#[cfg(feature = "alloc")] -pub use weighted_alias::WeightedAliasIndex; -#[cfg(feature = "alloc")] -pub use weighted_tree::WeightedTreeIndex; pub use num_traits; +#[cfg(feature = "alloc")] +pub mod weighted; + #[cfg(test)] #[macro_use] mod test { @@ -189,11 +184,6 @@ mod test { } } -#[cfg(feature = "alloc")] -pub mod weighted_alias; -#[cfg(feature = "alloc")] -pub mod weighted_tree; - mod beta; mod binomial; mod cauchy; diff --git a/rand_distr/src/weighted/mod.rs b/rand_distr/src/weighted/mod.rs new file mode 100644 index 0000000000..1c54e48e69 --- /dev/null +++ b/rand_distr/src/weighted/mod.rs @@ -0,0 +1,28 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Weighted (index) sampling +//! +//! This module is a superset of [`rand::distr::weighted`]. +//! +//! Multiple implementations of weighted index sampling are provided: +//! +//! - [`WeightedIndex`] (a re-export from [`rand`]) supports fast construction +//! and `O(log N)` sampling over `N` weights. +//! It also supports updating weights with `O(N)` time. +//! - [`WeightedAliasIndex`] supports `O(1)` sampling, but due to high +//! construction time many samples are required to outperform [`WeightedIndex`]. +//! - [`WeightedTreeIndex`] supports `O(log N)` sampling and +//! update/insertion/removal of weights with `O(log N)` time. + +mod weighted_alias; +mod weighted_tree; + +pub use rand::distr::weighted::*; +pub use weighted_alias::*; +pub use weighted_tree::*; diff --git a/rand_distr/src/weighted_alias.rs b/rand_distr/src/weighted/weighted_alias.rs similarity index 94% rename from rand_distr/src/weighted_alias.rs rename to rand_distr/src/weighted/weighted_alias.rs index 676689f2ad..862f2b70b3 100644 --- a/rand_distr/src/weighted_alias.rs +++ b/rand_distr/src/weighted/weighted_alias.rs @@ -9,7 +9,7 @@ //! This module contains an implementation of alias method for sampling random //! indices with probabilities proportional to a collection of weights. -use super::WeightError; +use super::Error; use crate::{uniform::SampleUniform, Distribution, Uniform}; use alloc::{boxed::Box, vec, vec::Vec}; use core::fmt; @@ -41,7 +41,7 @@ use serde::{Deserialize, Serialize}; /// # Example /// /// ``` -/// use rand_distr::WeightedAliasIndex; +/// use rand_distr::weighted::WeightedAliasIndex; /// use rand::prelude::*; /// /// let choices = vec!['a', 'b', 'c']; @@ -85,14 +85,14 @@ impl WeightedAliasIndex { /// Creates a new [`WeightedAliasIndex`]. /// /// Error cases: - /// - [`WeightError::InvalidInput`] when `weights.len()` is zero or greater than `u32::MAX`. - /// - [`WeightError::InvalidWeight`] when a weight is not-a-number, + /// - [`Error::InvalidInput`] when `weights.len()` is zero or greater than `u32::MAX`. + /// - [`Error::InvalidWeight`] when a weight is not-a-number, /// negative or greater than `max = W::MAX / weights.len()`. - /// - [`WeightError::InsufficientNonZero`] when the sum of all weights is zero. - pub fn new(weights: Vec) -> Result { + /// - [`Error::InsufficientNonZero`] when the sum of all weights is zero. + pub fn new(weights: Vec) -> Result { let n = weights.len(); if n == 0 || n > u32::MAX as usize { - return Err(WeightError::InvalidInput); + return Err(Error::InvalidInput); } let n = n as u32; @@ -103,7 +103,7 @@ impl WeightedAliasIndex { .iter() .all(|&w| W::ZERO <= w && w <= max_weight_size) { - return Err(WeightError::InvalidWeight); + return Err(Error::InvalidWeight); } // The sum of weights will represent 100% of no alias odds. @@ -115,7 +115,7 @@ impl WeightedAliasIndex { weight_sum }; if weight_sum == W::ZERO { - return Err(WeightError::InsufficientNonZero); + return Err(Error::InsufficientNonZero); } // `weight_sum` would have been zero if `try_from_lossy` causes an error here. @@ -384,23 +384,23 @@ mod test { // Floating point special cases assert_eq!( WeightedAliasIndex::new(vec![f32::INFINITY]).unwrap_err(), - WeightError::InvalidWeight + Error::InvalidWeight ); assert_eq!( WeightedAliasIndex::new(vec![-0_f32]).unwrap_err(), - WeightError::InsufficientNonZero + Error::InsufficientNonZero ); assert_eq!( WeightedAliasIndex::new(vec![-1_f32]).unwrap_err(), - WeightError::InvalidWeight + Error::InvalidWeight ); assert_eq!( WeightedAliasIndex::new(vec![f32::NEG_INFINITY]).unwrap_err(), - WeightError::InvalidWeight + Error::InvalidWeight ); assert_eq!( WeightedAliasIndex::new(vec![f32::NAN]).unwrap_err(), - WeightError::InvalidWeight + Error::InvalidWeight ); } @@ -418,11 +418,11 @@ mod test { // Signed integer special cases assert_eq!( WeightedAliasIndex::new(vec![-1_i128]).unwrap_err(), - WeightError::InvalidWeight + Error::InvalidWeight ); assert_eq!( WeightedAliasIndex::new(vec![i128::MIN]).unwrap_err(), - WeightError::InvalidWeight + Error::InvalidWeight ); } @@ -440,11 +440,11 @@ mod test { // Signed integer special cases assert_eq!( WeightedAliasIndex::new(vec![-1_i8]).unwrap_err(), - WeightError::InvalidWeight + Error::InvalidWeight ); assert_eq!( WeightedAliasIndex::new(vec![i8::MIN]).unwrap_err(), - WeightError::InvalidWeight + Error::InvalidWeight ); } @@ -491,15 +491,15 @@ mod test { assert_eq!( WeightedAliasIndex::::new(vec![]).unwrap_err(), - WeightError::InvalidInput + Error::InvalidInput ); assert_eq!( WeightedAliasIndex::new(vec![W::ZERO]).unwrap_err(), - WeightError::InsufficientNonZero + Error::InsufficientNonZero ); assert_eq!( WeightedAliasIndex::new(vec![W::MAX, W::MAX]).unwrap_err(), - WeightError::InvalidWeight + Error::InvalidWeight ); } diff --git a/rand_distr/src/weighted_tree.rs b/rand_distr/src/weighted/weighted_tree.rs similarity index 87% rename from rand_distr/src/weighted_tree.rs rename to rand_distr/src/weighted/weighted_tree.rs index 355373a1b5..dd315aa5f8 100644 --- a/rand_distr/src/weighted_tree.rs +++ b/rand_distr/src/weighted/weighted_tree.rs @@ -11,11 +11,10 @@ use core::ops::SubAssign; -use super::WeightError; +use super::{Error, Weight}; use crate::Distribution; use alloc::vec::Vec; use rand::distr::uniform::{SampleBorrow, SampleUniform}; -use rand::distr::Weight; use rand::Rng; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -30,7 +29,7 @@ use serde::{Deserialize, Serialize}; /// /// # Key differences /// -/// The main distinction between [`WeightedTreeIndex`] and [`rand::distr::WeightedIndex`] +/// The main distinction between [`WeightedTreeIndex`] and [`WeightedIndex`] /// lies in the internal representation of weights. In [`WeightedTreeIndex`], /// weights are structured as a tree, which is optimized for frequent updates of the weights. /// @@ -58,7 +57,7 @@ use serde::{Deserialize, Serialize}; /// # Example /// /// ``` -/// use rand_distr::WeightedTreeIndex; +/// use rand_distr::weighted::WeightedTreeIndex; /// use rand::prelude::*; /// /// let choices = vec!['a', 'b', 'c']; @@ -77,6 +76,7 @@ use serde::{Deserialize, Serialize}; /// ``` /// /// [`WeightedTreeIndex`]: WeightedTreeIndex +/// [`WeightedIndex`]: super::WeightedIndex #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr( feature = "serde", @@ -99,9 +99,9 @@ impl + Weight> /// Creates a new [`WeightedTreeIndex`] from a slice of weights. /// /// Error cases: - /// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. - /// - [`WeightError::Overflow`] when the sum of all weights overflows. - pub fn new(weights: I) -> Result + /// - [`Error::InvalidWeight`] when a weight is not-a-number or negative. + /// - [`Error::Overflow`] when the sum of all weights overflows. + pub fn new(weights: I) -> Result where I: IntoIterator, I::Item: SampleBorrow, @@ -109,7 +109,7 @@ impl + Weight> let mut subtotals: Vec = weights.into_iter().map(|x| x.borrow().clone()).collect(); for weight in subtotals.iter() { if !(*weight >= W::ZERO) { - return Err(WeightError::InvalidWeight); + return Err(Error::InvalidWeight); } } let n = subtotals.len(); @@ -118,7 +118,7 @@ impl + Weight> let parent = (i - 1) / 2; subtotals[parent] .checked_add_assign(&w) - .map_err(|()| WeightError::Overflow)?; + .map_err(|()| Error::Overflow)?; } Ok(Self { subtotals }) } @@ -169,16 +169,16 @@ impl + Weight> /// Appends a new weight at the end. /// /// Error cases: - /// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. - /// - [`WeightError::Overflow`] when the sum of all weights overflows. - pub fn push(&mut self, weight: W) -> Result<(), WeightError> { + /// - [`Error::InvalidWeight`] when a weight is not-a-number or negative. + /// - [`Error::Overflow`] when the sum of all weights overflows. + pub fn push(&mut self, weight: W) -> Result<(), Error> { if !(weight >= W::ZERO) { - return Err(WeightError::InvalidWeight); + return Err(Error::InvalidWeight); } if let Some(total) = self.subtotals.first() { let mut total = total.clone(); if total.checked_add_assign(&weight).is_err() { - return Err(WeightError::Overflow); + return Err(Error::Overflow); } } let mut index = self.len(); @@ -193,11 +193,11 @@ impl + Weight> /// Updates the weight at an index. /// /// Error cases: - /// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. - /// - [`WeightError::Overflow`] when the sum of all weights overflows. - pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightError> { + /// - [`Error::InvalidWeight`] when a weight is not-a-number or negative. + /// - [`Error::Overflow`] when the sum of all weights overflows. + pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), Error> { if !(weight >= W::ZERO) { - return Err(WeightError::InvalidWeight); + return Err(Error::InvalidWeight); } let old_weight = self.get(index); if weight > old_weight { @@ -206,7 +206,7 @@ impl + Weight> if let Some(total) = self.subtotals.first() { let mut total = total.clone(); if total.checked_add_assign(&difference).is_err() { - return Err(WeightError::Overflow); + return Err(Error::Overflow); } } self.subtotals[index] @@ -246,10 +246,10 @@ impl + Weight> /// /// Returns an error if there are no elements or all weights are zero. This /// is unlike [`Distribution::sample`], which panics in those cases. - pub fn try_sample(&self, rng: &mut R) -> Result { + pub fn try_sample(&self, rng: &mut R) -> Result { let total_weight = self.subtotals.first().cloned().unwrap_or(W::ZERO); if total_weight == W::ZERO { - return Err(WeightError::InsufficientNonZero); + return Err(Error::InsufficientNonZero); } let mut target_weight = rng.random_range(W::ZERO..total_weight); let mut index = 0; @@ -306,19 +306,16 @@ mod test { let tree = WeightedTreeIndex::::new(&[]).unwrap(); assert_eq!( tree.try_sample(&mut rng).unwrap_err(), - WeightError::InsufficientNonZero + Error::InsufficientNonZero ); } #[test] fn test_overflow_error() { - assert_eq!( - WeightedTreeIndex::new([i32::MAX, 2]), - Err(WeightError::Overflow) - ); + assert_eq!(WeightedTreeIndex::new([i32::MAX, 2]), Err(Error::Overflow)); let mut tree = WeightedTreeIndex::new([i32::MAX - 2, 1]).unwrap(); - assert_eq!(tree.push(3), Err(WeightError::Overflow)); - assert_eq!(tree.update(1, 4), Err(WeightError::Overflow)); + assert_eq!(tree.push(3), Err(Error::Overflow)); + assert_eq!(tree.update(1, 4), Err(Error::Overflow)); tree.update(1, 2).unwrap(); } @@ -328,7 +325,7 @@ mod test { let mut rng = crate::test::rng(0x9c9fa0b0580a7031); assert_eq!( tree.try_sample(&mut rng).unwrap_err(), - WeightError::InsufficientNonZero + Error::InsufficientNonZero ); } @@ -336,13 +333,13 @@ mod test { fn test_invalid_weight_error() { assert_eq!( WeightedTreeIndex::::new([1, -1]).unwrap_err(), - WeightError::InvalidWeight + Error::InvalidWeight ); #[allow(clippy::needless_borrows_for_generic_args)] let mut tree = WeightedTreeIndex::::new(&[]).unwrap(); - assert_eq!(tree.push(-1).unwrap_err(), WeightError::InvalidWeight); + assert_eq!(tree.push(-1).unwrap_err(), Error::InvalidWeight); tree.push(1).unwrap(); - assert_eq!(tree.update(0, -1).unwrap_err(), WeightError::InvalidWeight); + assert_eq!(tree.update(0, -1).unwrap_err(), Error::InvalidWeight); } #[test] diff --git a/src/distr/distribution.rs b/src/distr/distribution.rs index f9385ec617..6f4e202647 100644 --- a/src/distr/distribution.rs +++ b/src/distr/distribution.rs @@ -69,40 +69,37 @@ pub trait Distribution { /// println!("Not a 6; rolling again!"); /// } /// ``` - fn sample_iter(self, rng: R) -> DistIter + fn sample_iter(self, rng: R) -> Iter where R: Rng, Self: Sized, { - DistIter { + Iter { distr: self, rng, phantom: core::marker::PhantomData, } } - /// Create a distribution of values of 'S' by mapping the output of `Self` - /// through the closure `F` + /// Map sampled values to type `S` /// /// # Example /// /// ``` /// use rand::distr::{Distribution, Uniform}; /// - /// let mut rng = rand::rng(); - /// /// let die = Uniform::new_inclusive(1, 6).unwrap(); /// let even_number = die.map(|num| num % 2 == 0); - /// while !even_number.sample(&mut rng) { + /// while !even_number.sample(&mut rand::rng()) { /// println!("Still odd; rolling again!"); /// } /// ``` - fn map(self, func: F) -> DistMap + fn map(self, func: F) -> Map where F: Fn(T) -> S, Self: Sized, { - DistMap { + Map { distr: self, func, phantom: core::marker::PhantomData, @@ -116,21 +113,22 @@ impl + ?Sized> Distribution for &D { } } -/// An iterator that generates random values of `T` with distribution `D`, -/// using `R` as the source of randomness. +/// An iterator over a [`Distribution`] /// -/// This `struct` is created by the [`sample_iter`] method on [`Distribution`]. -/// See its documentation for more. +/// This iterator yields random values of type `T` with distribution `D` +/// from a random generator of type `R`. /// -/// [`sample_iter`]: Distribution::sample_iter +/// Construct this `struct` using [`Distribution::sample_iter`] or +/// [`Rng::sample_iter`]. It is also used by [`Rng::random_iter`] and +/// [`crate::random_iter`]. #[derive(Debug)] -pub struct DistIter { +pub struct Iter { distr: D, rng: R, phantom: core::marker::PhantomData, } -impl Iterator for DistIter +impl Iterator for Iter where D: Distribution, R: Rng, @@ -150,26 +148,25 @@ where } } -impl iter::FusedIterator for DistIter +impl iter::FusedIterator for Iter where D: Distribution, R: Rng, { } -/// A distribution of values of type `S` derived from the distribution `D` -/// by mapping its output of type `T` through the closure `F`. +/// A [`Distribution`] which maps sampled values to type `S` /// /// This `struct` is created by the [`Distribution::map`] method. /// See its documentation for more. #[derive(Debug)] -pub struct DistMap { +pub struct Map { distr: D, func: F, phantom: core::marker::PhantomData S>, } -impl Distribution for DistMap +impl Distribution for Map where D: Distribution, F: Fn(T) -> S, @@ -179,16 +176,23 @@ where } } -/// `String` sampler +/// Sample or extend a [`String`] /// -/// Sampling a `String` of random characters is not quite the same as collecting -/// a sequence of chars. This trait contains some helpers. +/// Helper methods to extend a [`String`] or sample a new [`String`]. #[cfg(feature = "alloc")] -pub trait DistString { +pub trait SampleString { /// Append `len` random chars to `string` + /// + /// Note: implementations may leave `string` with excess capacity. If this + /// is undesirable, consider calling [`String::shrink_to_fit`] after this + /// method. fn append_string(&self, rng: &mut R, string: &mut String, len: usize); - /// Generate a `String` of `len` random chars + /// Generate a [`String`] of `len` random chars + /// + /// Note: implementations may leave the string with excess capacity. If this + /// is undesirable, consider calling [`String::shrink_to_fit`] after this + /// method. #[inline] fn sample_string(&self, rng: &mut R, len: usize) -> String { let mut s = String::new(); @@ -246,7 +250,7 @@ mod tests { #[test] #[cfg(feature = "alloc")] fn test_dist_string() { - use crate::distr::{Alphanumeric, DistString, StandardUniform}; + use crate::distr::{Alphanumeric, SampleString, StandardUniform}; use core::str; let mut rng = crate::test::rng(213); diff --git a/src/distr/mod.rs b/src/distr/mod.rs index 84bf4925a2..10016119ba 100644 --- a/src/distr/mod.rs +++ b/src/distr/mod.rs @@ -69,8 +69,7 @@ //! Sampling a simple true/false outcome with a given probability has a name: //! the [`Bernoulli`] distribution (this is used by [`Rng::random_bool`]). //! -//! For weighted sampling from a sequence of discrete values, use the -//! [`WeightedIndex`] distribution. +//! For weighted sampling of discrete values see the [`weighted`] module. //! //! This crate no longer includes other non-uniform distributions; instead //! it is recommended that you use either [`rand_distr`] or [`statrs`]. @@ -89,28 +88,25 @@ mod distribution; mod float; mod integer; mod other; -mod slice; mod utils; -#[cfg(feature = "alloc")] -mod weighted_index; #[doc(hidden)] pub mod hidden_export { pub use super::float::IntoFloat; // used by rand_distr } +pub mod slice; pub mod uniform; +#[cfg(feature = "alloc")] +pub mod weighted; pub use self::bernoulli::{Bernoulli, BernoulliError}; #[cfg(feature = "alloc")] -pub use self::distribution::DistString; -pub use self::distribution::{DistIter, DistMap, Distribution}; +pub use self::distribution::SampleString; +pub use self::distribution::{Distribution, Iter, Map}; pub use self::float::{Open01, OpenClosed01}; pub use self::other::Alphanumeric; -pub use self::slice::Slice; #[doc(inline)] pub use self::uniform::Uniform; -#[cfg(feature = "alloc")] -pub use self::weighted_index::{Weight, WeightError, WeightedIndex}; #[allow(unused)] use crate::Rng; diff --git a/src/distr/other.rs b/src/distr/other.rs index 8e957f0744..9890bdafe6 100644 --- a/src/distr/other.rs +++ b/src/distr/other.rs @@ -14,7 +14,7 @@ use core::char; use core::num::Wrapping; #[cfg(feature = "alloc")] -use crate::distr::DistString; +use crate::distr::SampleString; use crate::distr::{Distribution, StandardUniform, Uniform}; use crate::Rng; @@ -42,10 +42,10 @@ use serde::{Deserialize, Serialize}; /// println!("Random chars: {}", chars); /// ``` /// -/// The [`DistString`] trait provides an easier method of generating -/// a random `String`, and offers more efficient allocation: +/// The [`SampleString`] trait provides an easier method of generating +/// a random [`String`], and offers more efficient allocation: /// ``` -/// use rand::distr::{Alphanumeric, DistString}; +/// use rand::distr::{Alphanumeric, SampleString}; /// let string = Alphanumeric.sample_string(&mut rand::rng(), 16); /// println!("Random string: {}", string); /// ``` @@ -93,10 +93,8 @@ impl Distribution for StandardUniform { } } -/// Note: the `String` is potentially left with excess capacity; optionally the -/// user may call `string.shrink_to_fit()` afterwards. #[cfg(feature = "alloc")] -impl DistString for StandardUniform { +impl SampleString for StandardUniform { fn append_string(&self, rng: &mut R, s: &mut String, len: usize) { // A char is encoded with at most four bytes, thus this reservation is // guaranteed to be sufficient. We do not shrink_to_fit afterwards so @@ -126,7 +124,7 @@ impl Distribution for Alphanumeric { } #[cfg(feature = "alloc")] -impl DistString for Alphanumeric { +impl SampleString for Alphanumeric { fn append_string(&self, rng: &mut R, string: &mut String, len: usize) { unsafe { let v = string.as_mut_vec(); diff --git a/src/distr/slice.rs b/src/distr/slice.rs index 3eee65a92c..07e243fec5 100644 --- a/src/distr/slice.rs +++ b/src/distr/slice.rs @@ -6,6 +6,8 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. +//! Distributions over slices + use core::num::NonZeroUsize; use crate::distr::uniform::{UniformSampler, UniformUsize}; @@ -13,36 +15,26 @@ use crate::distr::Distribution; #[cfg(feature = "alloc")] use alloc::string::String; -/// A distribution to sample items uniformly from a slice. -/// -/// [`Slice::new`] constructs a distribution referencing a slice and uniformly -/// samples references from the items in the slice. It may do extra work up -/// front to make sampling of multiple values faster; if only one sample from -/// the slice is required, [`IndexedRandom::choose`] can be more efficient. -/// -/// Steps are taken to avoid bias which might be present in naive -/// implementations; for example `slice[rng.gen() % slice.len()]` samples from -/// the slice, but may be more likely to select numbers in the low range than -/// other values. +/// A distribution to uniformly sample elements of a slice /// -/// This distribution samples with replacement; each sample is independent. -/// Sampling without replacement requires state to be retained, and therefore -/// cannot be handled by a distribution; you should instead consider methods -/// on [`IndexedRandom`], such as [`IndexedRandom::choose_multiple`]. +/// Like [`IndexedRandom::choose`], this uniformly samples elements of a slice +/// without modification of the slice (so called "sampling with replacement"). +/// This distribution object may be a little faster for repeated sampling (but +/// slower for small numbers of samples). /// -/// # Example +/// ## Examples /// +/// Since this is a distribution, [`Rng::sample_iter`] and +/// [`Distribution::sample_iter`] may be used, for example: /// ``` -/// use rand::Rng; -/// use rand::distr::Slice; +/// use rand::distr::{Distribution, slice::Choose}; /// /// let vowels = ['a', 'e', 'i', 'o', 'u']; -/// let vowels_dist = Slice::new(&vowels).unwrap(); -/// let rng = rand::rng(); +/// let vowels_dist = Choose::new(&vowels).unwrap(); /// /// // build a string of 10 vowels -/// let vowel_string: String = rng -/// .sample_iter(&vowels_dist) +/// let vowel_string: String = vowels_dist +/// .sample_iter(&mut rand::rng()) /// .take(10) /// .collect(); /// @@ -51,33 +43,31 @@ use alloc::string::String; /// assert!(vowel_string.chars().all(|c| vowels.contains(&c))); /// ``` /// -/// For a single sample, [`IndexedRandom::choose`][crate::seq::IndexedRandom::choose] -/// may be preferred: -/// +/// For a single sample, [`IndexedRandom::choose`] may be preferred: /// ``` /// use rand::seq::IndexedRandom; /// /// let vowels = ['a', 'e', 'i', 'o', 'u']; /// let mut rng = rand::rng(); /// -/// println!("{}", vowels.choose(&mut rng).unwrap()) +/// println!("{}", vowels.choose(&mut rng).unwrap()); /// ``` /// -/// [`IndexedRandom`]: crate::seq::IndexedRandom /// [`IndexedRandom::choose`]: crate::seq::IndexedRandom::choose -/// [`IndexedRandom::choose_multiple`]: crate::seq::IndexedRandom::choose_multiple +/// [`Rng::sample_iter`]: crate::Rng::sample_iter #[derive(Debug, Clone, Copy)] -pub struct Slice<'a, T> { +pub struct Choose<'a, T> { slice: &'a [T], range: UniformUsize, num_choices: NonZeroUsize, } -impl<'a, T> Slice<'a, T> { - /// Create a new `Slice` instance which samples uniformly from the slice. - /// Returns `Err` if the slice is empty. - pub fn new(slice: &'a [T]) -> Result { - let num_choices = NonZeroUsize::new(slice.len()).ok_or(EmptySlice)?; +impl<'a, T> Choose<'a, T> { + /// Create a new `Choose` instance which samples uniformly from the slice. + /// + /// Returns error [`Empty`] if the slice is empty. + pub fn new(slice: &'a [T]) -> Result { + let num_choices = NonZeroUsize::new(slice.len()).ok_or(Empty)?; Ok(Self { slice, @@ -92,7 +82,7 @@ impl<'a, T> Slice<'a, T> { } } -impl<'a, T> Distribution<&'a T> for Slice<'a, T> { +impl<'a, T> Distribution<&'a T> for Choose<'a, T> { fn sample(&self, rng: &mut R) -> &'a T { let idx = self.range.sample(rng); @@ -110,24 +100,26 @@ impl<'a, T> Distribution<&'a T> for Slice<'a, T> { } } -/// Error type indicating that a [`Slice`] distribution was improperly -/// constructed with an empty slice. +/// Error: empty slice +/// +/// This error is returned when [`Choose::new`] is given an empty slice. #[derive(Debug, Clone, Copy)] -pub struct EmptySlice; +pub struct Empty; -impl core::fmt::Display for EmptySlice { +impl core::fmt::Display for Empty { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "Tried to create a `distr::Slice` with an empty slice") + write!( + f, + "Tried to create a `rand::distr::slice::Choose` with an empty slice" + ) } } #[cfg(feature = "std")] -impl std::error::Error for EmptySlice {} +impl std::error::Error for Empty {} -/// Note: the `String` is potentially left with excess capacity; optionally the -/// user may call `string.shrink_to_fit()` afterwards. #[cfg(feature = "alloc")] -impl super::DistString for Slice<'_, char> { +impl super::SampleString for Choose<'_, char> { fn append_string(&self, rng: &mut R, string: &mut String, len: usize) { // Get the max char length to minimize extra space. // Limit this check to avoid searching for long slice. @@ -168,7 +160,7 @@ mod test { #[test] fn value_stability() { let rng = crate::test::rng(651); - let slice = Slice::new(b"escaped emus explore extensively").unwrap(); + let slice = Choose::new(b"escaped emus explore extensively").unwrap(); let expected = b"eaxee"; assert!(iter::zip(slice.sample_iter(rng), expected).all(|(a, b)| a == b)); } diff --git a/src/distr/uniform_other.rs b/src/distr/uniform_other.rs index 42a7ff7813..03533debcd 100644 --- a/src/distr/uniform_other.rs +++ b/src/distr/uniform_other.rs @@ -90,11 +90,8 @@ impl UniformSampler for UniformChar { } } -/// Note: the `String` is potentially left with excess capacity if the range -/// includes non ascii chars; optionally the user may call -/// `string.shrink_to_fit()` afterwards. #[cfg(feature = "alloc")] -impl crate::distr::DistString for Uniform { +impl crate::distr::SampleString for Uniform { fn append_string( &self, rng: &mut R, @@ -281,7 +278,7 @@ mod tests { } #[cfg(feature = "alloc")] { - use crate::distr::DistString; + use crate::distr::SampleString; let string1 = d.sample_string(&mut rng, 100); assert_eq!(string1.capacity(), 300); let string2 = Uniform::new( diff --git a/src/distr/weighted/mod.rs b/src/distr/weighted/mod.rs new file mode 100644 index 0000000000..368c5b0703 --- /dev/null +++ b/src/distr/weighted/mod.rs @@ -0,0 +1,115 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Weighted (index) sampling +//! +//! Primarily, this module houses the [`WeightedIndex`] distribution. +//! See also [`rand_distr::weighted`] for alternative implementations supporting +//! potentially-faster sampling or a more easily modifiable tree structure. +//! +//! [`rand_distr::weighted`]: https://docs.rs/rand_distr/latest/rand_distr/weighted/index.html + +use core::fmt; +mod weighted_index; + +pub use weighted_index::WeightedIndex; + +/// Bounds on a weight +/// +/// See usage in [`WeightedIndex`]. +pub trait Weight: Clone { + /// Representation of 0 + const ZERO: Self; + + /// Checked addition + /// + /// - `Result::Ok`: On success, `v` is added to `self` + /// - `Result::Err`: Returns an error when `Self` cannot represent the + /// result of `self + v` (i.e. overflow). The value of `self` should be + /// discarded. + #[allow(clippy::result_unit_err)] + fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()>; +} + +macro_rules! impl_weight_int { + ($t:ty) => { + impl Weight for $t { + const ZERO: Self = 0; + fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> { + match self.checked_add(*v) { + Some(sum) => { + *self = sum; + Ok(()) + } + None => Err(()), + } + } + } + }; + ($t:ty, $($tt:ty),*) => { + impl_weight_int!($t); + impl_weight_int!($($tt),*); + } +} +impl_weight_int!(i8, i16, i32, i64, i128, isize); +impl_weight_int!(u8, u16, u32, u64, u128, usize); + +macro_rules! impl_weight_float { + ($t:ty) => { + impl Weight for $t { + const ZERO: Self = 0.0; + + fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> { + // Floats have an explicit representation for overflow + *self += *v; + Ok(()) + } + } + }; +} +impl_weight_float!(f32); +impl_weight_float!(f64); + +/// Invalid weight errors +/// +/// This type represents errors from [`WeightedIndex::new`], +/// [`WeightedIndex::update_weights`] and other weighted distributions. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +// Marked non_exhaustive to allow a new error code in the solution to #1476. +#[non_exhaustive] +pub enum Error { + /// The input weight sequence is empty, too long, or wrongly ordered + InvalidInput, + + /// A weight is negative, too large for the distribution, or not a valid number + InvalidWeight, + + /// Not enough non-zero weights are available to sample values + /// + /// When attempting to sample a single value this implies that all weights + /// are zero. When attempting to sample `amount` values this implies that + /// less than `amount` weights are greater than zero. + InsufficientNonZero, + + /// Overflow when calculating the sum of weights + Overflow, +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(match *self { + Error::InvalidInput => "Weights sequence is empty/too long/unordered", + Error::InvalidWeight => "A weight is negative, too large or not a valid number", + Error::InsufficientNonZero => "Not enough weights > zero", + Error::Overflow => "Overflow when summing weights", + }) + } +} diff --git a/src/distr/weighted_index.rs b/src/distr/weighted/weighted_index.rs similarity index 77% rename from src/distr/weighted_index.rs rename to src/distr/weighted/weighted_index.rs index fef5728e41..4bb9d141fb 100644 --- a/src/distr/weighted_index.rs +++ b/src/distr/weighted/weighted_index.rs @@ -6,16 +6,14 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! Weighted index sampling - +use super::{Error, Weight}; use crate::distr::uniform::{SampleBorrow, SampleUniform, UniformSampler}; use crate::distr::Distribution; use crate::Rng; -use core::fmt; // Note that this whole module is only imported if feature="alloc" is enabled. use alloc::vec::Vec; -use core::fmt::Debug; +use core::fmt::{self, Debug}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -33,12 +31,9 @@ use serde::{Deserialize, Serialize}; /// # Performance /// /// Time complexity of sampling from `WeightedIndex` is `O(log N)` where -/// `N` is the number of weights. There are two alternative implementations with -/// different runtimes characteristics: -/// * [`rand_distr::weighted_alias`] supports `O(1)` sampling, but with much higher -/// initialisation cost. -/// * [`rand_distr::weighted_tree`] keeps the weights in a tree structure where sampling -/// and updating is `O(log N)`. +/// `N` is the number of weights. +/// See also [`rand_distr::weighted`] for alternative implementations supporting +/// potentially-faster sampling or a more easily modifiable tree structure. /// /// A `WeightedIndex` contains a `Vec` and a [`Uniform`] and so its /// size is the sum of the size of those objects, possibly plus some alignment. @@ -59,7 +54,7 @@ use serde::{Deserialize, Serialize}; /// /// ``` /// use rand::prelude::*; -/// use rand::distr::WeightedIndex; +/// use rand::distr::weighted::WeightedIndex; /// /// let choices = ['a', 'b', 'c']; /// let weights = [2, 1, 1]; @@ -80,8 +75,7 @@ use serde::{Deserialize, Serialize}; /// /// [`Uniform`]: crate::distr::Uniform /// [`RngCore`]: crate::RngCore -/// [`rand_distr::weighted_alias`]: https://docs.rs/rand_distr/*/rand_distr/weighted_alias/index.html -/// [`rand_distr::weighted_tree`]: https://docs.rs/rand_distr/*/rand_distr/weighted_tree/index.html +/// [`rand_distr::weighted`]: https://docs.rs/rand_distr/latest/rand_distr/weighted/index.html #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct WeightedIndex { @@ -96,28 +90,24 @@ impl WeightedIndex { /// implementation of [`Uniform`] exists. /// /// Error cases: - /// - [`WeightError::InvalidInput`] when the iterator `weights` is empty. - /// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. - /// - [`WeightError::InsufficientNonZero`] when the sum of all weights is zero. - /// - [`WeightError::Overflow`] when the sum of all weights overflows. + /// - [`Error::InvalidInput`] when the iterator `weights` is empty. + /// - [`Error::InvalidWeight`] when a weight is not-a-number or negative. + /// - [`Error::InsufficientNonZero`] when the sum of all weights is zero. + /// - [`Error::Overflow`] when the sum of all weights overflows. /// /// [`Uniform`]: crate::distr::uniform::Uniform - pub fn new(weights: I) -> Result, WeightError> + pub fn new(weights: I) -> Result, Error> where I: IntoIterator, I::Item: SampleBorrow, X: Weight, { let mut iter = weights.into_iter(); - let mut total_weight: X = iter - .next() - .ok_or(WeightError::InvalidInput)? - .borrow() - .clone(); + let mut total_weight: X = iter.next().ok_or(Error::InvalidInput)?.borrow().clone(); let zero = X::ZERO; if !(total_weight >= zero) { - return Err(WeightError::InvalidWeight); + return Err(Error::InvalidWeight); } let mut weights = Vec::::with_capacity(iter.size_hint().0); @@ -125,17 +115,17 @@ impl WeightedIndex { // Note that `!(w >= x)` is not equivalent to `w < x` for partially // ordered types due to NaNs which are equal to nothing. if !(w.borrow() >= &zero) { - return Err(WeightError::InvalidWeight); + return Err(Error::InvalidWeight); } weights.push(total_weight.clone()); if let Err(()) = total_weight.checked_add_assign(w.borrow()) { - return Err(WeightError::Overflow); + return Err(Error::Overflow); } } if total_weight == zero { - return Err(WeightError::InsufficientNonZero); + return Err(Error::InsufficientNonZero); } let distr = X::Sampler::new(zero, total_weight.clone()).unwrap(); @@ -155,10 +145,10 @@ impl WeightedIndex { /// allocation internally. /// /// In case of error, `self` is not modified. Error cases: - /// - [`WeightError::InvalidInput`] when `new_weights` are not ordered by + /// - [`Error::InvalidInput`] when `new_weights` are not ordered by /// index or an index is too large. - /// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. - /// - [`WeightError::InsufficientNonZero`] when the sum of all weights is zero. + /// - [`Error::InvalidWeight`] when a weight is not-a-number or negative. + /// - [`Error::InsufficientNonZero`] when the sum of all weights is zero. /// Note that due to floating-point loss of precision, this case is not /// always correctly detected; usage of a fixed-point weight type may be /// preferred. @@ -166,7 +156,7 @@ impl WeightedIndex { /// Updates take `O(N)` time. If you need to frequently update weights, consider /// [`rand_distr::weighted_tree`](https://docs.rs/rand_distr/*/rand_distr/weighted_tree/index.html) /// as an alternative where an update is `O(log N)`. - pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightError> + pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), Error> where X: for<'a> core::ops::AddAssign<&'a X> + for<'a> core::ops::SubAssign<&'a X> @@ -187,14 +177,14 @@ impl WeightedIndex { for &(i, w) in new_weights { if let Some(old_i) = prev_i { if old_i >= i { - return Err(WeightError::InvalidInput); + return Err(Error::InvalidInput); } } if !(*w >= zero) { - return Err(WeightError::InvalidWeight); + return Err(Error::InvalidWeight); } if i > self.cumulative_weights.len() { - return Err(WeightError::InvalidInput); + return Err(Error::InvalidInput); } let mut old_w = if i < self.cumulative_weights.len() { @@ -211,7 +201,7 @@ impl WeightedIndex { prev_i = Some(i); } if total_weight <= zero { - return Err(WeightError::InsufficientNonZero); + return Err(Error::InsufficientNonZero); } // Update the weights. Because we checked all the preconditions in the @@ -306,7 +296,7 @@ impl WeightedIndex { /// # Example /// /// ``` - /// use rand::distr::WeightedIndex; + /// use rand::distr::weighted::WeightedIndex; /// /// let weights = [0, 1, 2]; /// let dist = WeightedIndex::new(&weights).unwrap(); @@ -341,7 +331,7 @@ impl WeightedIndex { /// # Example /// /// ``` - /// use rand::distr::WeightedIndex; + /// use rand::distr::weighted::WeightedIndex; /// /// let weights = [1, 2, 3]; /// let mut dist = WeightedIndex::new(&weights).unwrap(); @@ -377,62 +367,6 @@ where } } -/// Bounds on a weight -/// -/// See usage in [`WeightedIndex`]. -pub trait Weight: Clone { - /// Representation of 0 - const ZERO: Self; - - /// Checked addition - /// - /// - `Result::Ok`: On success, `v` is added to `self` - /// - `Result::Err`: Returns an error when `Self` cannot represent the - /// result of `self + v` (i.e. overflow). The value of `self` should be - /// discarded. - #[allow(clippy::result_unit_err)] - fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()>; -} - -macro_rules! impl_weight_int { - ($t:ty) => { - impl Weight for $t { - const ZERO: Self = 0; - fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> { - match self.checked_add(*v) { - Some(sum) => { - *self = sum; - Ok(()) - } - None => Err(()), - } - } - } - }; - ($t:ty, $($tt:ty),*) => { - impl_weight_int!($t); - impl_weight_int!($($tt),*); - } -} -impl_weight_int!(i8, i16, i32, i64, i128, isize); -impl_weight_int!(u8, u16, u32, u64, u128, usize); - -macro_rules! impl_weight_float { - ($t:ty) => { - impl Weight for $t { - const ZERO: Self = 0.0; - - fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> { - // Floats have an explicit representation for overflow - *self += *v; - Ok(()) - } - } - }; -} -impl_weight_float!(f32); -impl_weight_float!(f64); - #[cfg(test)] mod test { use super::*; @@ -457,15 +391,15 @@ mod test { fn test_accepting_nan() { assert_eq!( WeightedIndex::new([f32::NAN, 0.5]).unwrap_err(), - WeightError::InvalidWeight, + Error::InvalidWeight, ); assert_eq!( WeightedIndex::new([f32::NAN]).unwrap_err(), - WeightError::InvalidWeight, + Error::InvalidWeight, ); assert_eq!( WeightedIndex::new([0.5, f32::NAN]).unwrap_err(), - WeightError::InvalidWeight, + Error::InvalidWeight, ); assert_eq!( @@ -473,7 +407,7 @@ mod test { .unwrap() .update_weights(&[(0, &f32::NAN)]) .unwrap_err(), - WeightError::InvalidWeight, + Error::InvalidWeight, ) } @@ -533,24 +467,21 @@ mod test { assert_eq!( WeightedIndex::new(&[10][0..0]).unwrap_err(), - WeightError::InvalidInput + Error::InvalidInput ); assert_eq!( WeightedIndex::new([0]).unwrap_err(), - WeightError::InsufficientNonZero + Error::InsufficientNonZero ); assert_eq!( WeightedIndex::new([10, 20, -1, 30]).unwrap_err(), - WeightError::InvalidWeight + Error::InvalidWeight ); assert_eq!( WeightedIndex::new([-10, 20, 1, 30]).unwrap_err(), - WeightError::InvalidWeight - ); - assert_eq!( - WeightedIndex::new([-10]).unwrap_err(), - WeightError::InvalidWeight + Error::InvalidWeight ); + assert_eq!(WeightedIndex::new([-10]).unwrap_err(), Error::InvalidWeight); } #[test] @@ -588,22 +519,22 @@ mod test { ( &[1i32, 0, 0][..], &[(0, &0)][..], - WeightError::InsufficientNonZero, + Error::InsufficientNonZero, ), ( &[10, 10, 10, 10][..], &[(1, &-11)][..], - WeightError::InvalidWeight, // A weight is negative + Error::InvalidWeight, // A weight is negative ), ( &[1, 2, 3, 4, 5][..], &[(1, &5), (0, &5)][..], // Wrong order - WeightError::InvalidInput, + Error::InvalidInput, ), ( &[1][..], &[(1, &1)][..], // Index too large - WeightError::InvalidInput, + Error::InvalidInput, ), ]; @@ -695,45 +626,6 @@ mod test { #[test] fn overflow() { - assert_eq!( - WeightedIndex::new([2, usize::MAX]), - Err(WeightError::Overflow) - ); - } -} - -/// Errors returned by [`WeightedIndex::new`], [`WeightedIndex::update_weights`] and other weighted distributions -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -// Marked non_exhaustive to allow a new error code in the solution to #1476. -#[non_exhaustive] -pub enum WeightError { - /// The input weight sequence is empty, too long, or wrongly ordered - InvalidInput, - - /// A weight is negative, too large for the distribution, or not a valid number - InvalidWeight, - - /// Not enough non-zero weights are available to sample values - /// - /// When attempting to sample a single value this implies that all weights - /// are zero. When attempting to sample `amount` values this implies that - /// less than `amount` weights are greater than zero. - InsufficientNonZero, - - /// Overflow when calculating the sum of weights - Overflow, -} - -#[cfg(feature = "std")] -impl std::error::Error for WeightError {} - -impl fmt::Display for WeightError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str(match *self { - WeightError::InvalidInput => "Weights sequence is empty/too long/unordered", - WeightError::InvalidWeight => "A weight is negative, too large or not a valid number", - WeightError::InsufficientNonZero => "Not enough weights > zero", - WeightError::Overflow => "Overflow when summing weights", - }) + assert_eq!(WeightedIndex::new([2, usize::MAX]), Err(Error::Overflow)); } } diff --git a/src/lib.rs b/src/lib.rs index b5bb4fcb2f..54ae884025 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -181,7 +181,7 @@ where /// ``` #[cfg(feature = "thread_rng")] #[inline] -pub fn random_iter() -> distr::DistIter +pub fn random_iter() -> distr::Iter where StandardUniform: Distribution, { diff --git a/src/rng.rs b/src/rng.rs index 04b71f74b7..258c87de27 100644 --- a/src/rng.rs +++ b/src/rng.rs @@ -117,7 +117,7 @@ pub trait Rng: RngCore { /// assert_eq!(&v, &[1, 2, 3, 4, 5]); /// ``` #[inline] - fn random_iter(self) -> distr::DistIter + fn random_iter(self) -> distr::Iter where Self: Sized, StandardUniform: Distribution, @@ -283,7 +283,7 @@ pub trait Rng: RngCore { /// println!("Not a 6; rolling again!"); /// } /// ``` - fn sample_iter(self, distr: D) -> distr::DistIter + fn sample_iter(self, distr: D) -> distr::Iter where D: Distribution, Self: Sized, diff --git a/src/seq/mod.rs b/src/seq/mod.rs index 82601304da..91d634d865 100644 --- a/src/seq/mod.rs +++ b/src/seq/mod.rs @@ -19,7 +19,7 @@ //! //! Also see: //! -//! * [`crate::distr::WeightedIndex`] distribution which provides +//! * [`crate::distr::weighted::WeightedIndex`] distribution which provides //! weighted index sampling. //! //! In order to make results reproducible across 32-64 bit architectures, all @@ -37,7 +37,7 @@ mod index_; #[cfg(feature = "alloc")] #[doc(no_inline)] -pub use crate::distr::WeightError; +pub use crate::distr::weighted::Error as WeightError; pub use iterator::IteratorRandom; #[cfg(feature = "alloc")] pub use slice::SliceChooseIter; diff --git a/src/seq/slice.rs b/src/seq/slice.rs index 1fc10c0985..d48d9d2e9f 100644 --- a/src/seq/slice.rs +++ b/src/seq/slice.rs @@ -13,7 +13,7 @@ use super::index; #[cfg(feature = "alloc")] use crate::distr::uniform::{SampleBorrow, SampleUniform}; #[cfg(feature = "alloc")] -use crate::distr::{Weight, WeightError}; +use crate::distr::weighted::{Error as WeightError, Weight}; use crate::Rng; use core::ops::{Index, IndexMut}; @@ -136,7 +136,7 @@ pub trait IndexedRandom: Index { /// /// For slices of length `n`, complexity is `O(n)`. /// For more information about the underlying algorithm, - /// see [`distr::WeightedIndex`]. + /// see the [`WeightedIndex`] distribution. /// /// See also [`choose_weighted_mut`]. /// @@ -153,7 +153,7 @@ pub trait IndexedRandom: Index { /// ``` /// [`choose`]: IndexedRandom::choose /// [`choose_weighted_mut`]: IndexedMutRandom::choose_weighted_mut - /// [`distr::WeightedIndex`]: crate::distr::WeightedIndex + /// [`WeightedIndex`]: crate::distr::weighted::WeightedIndex #[cfg(feature = "alloc")] fn choose_weighted( &self, @@ -166,7 +166,7 @@ pub trait IndexedRandom: Index { B: SampleBorrow, X: SampleUniform + Weight + PartialOrd, { - use crate::distr::{Distribution, WeightedIndex}; + use crate::distr::{weighted::WeightedIndex, Distribution}; let distr = WeightedIndex::new((0..self.len()).map(|idx| weight(&self[idx])))?; Ok(&self[distr.sample(rng)]) } @@ -273,13 +273,13 @@ pub trait IndexedMutRandom: IndexedRandom + IndexMut { /// /// For slices of length `n`, complexity is `O(n)`. /// For more information about the underlying algorithm, - /// see [`distr::WeightedIndex`]. + /// see the [`WeightedIndex`] distribution. /// /// See also [`choose_weighted`]. /// /// [`choose_mut`]: IndexedMutRandom::choose_mut /// [`choose_weighted`]: IndexedRandom::choose_weighted - /// [`distr::WeightedIndex`]: crate::distr::WeightedIndex + /// [`WeightedIndex`]: crate::distr::weighted::WeightedIndex #[cfg(feature = "alloc")] fn choose_weighted_mut( &mut self, @@ -292,7 +292,7 @@ pub trait IndexedMutRandom: IndexedRandom + IndexMut { B: SampleBorrow, X: SampleUniform + Weight + PartialOrd, { - use crate::distr::{Distribution, WeightedIndex}; + use crate::distr::{weighted::WeightedIndex, Distribution}; let distr = WeightedIndex::new((0..self.len()).map(|idx| weight(&self[idx])))?; let index = distr.sample(rng); Ok(&mut self[index])