diff --git a/Cargo.toml b/Cargo.toml index 2c56066c..b9473728 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,6 @@ license = "Apache-2.0" homepage = "https://github.com/CDCgov/ixa" [dependencies] -fxhash = "^0.2.1" rand = "^0.8.5" csv = "^1.3.1" serde = { version = "^1.0.217", features = ["derive"] } @@ -32,6 +31,7 @@ reqwest = { version = "0.12.12", features = ["blocking", "json"] } uuid = "1.12.1" tower-http = { version = "0.6.2", features = ["full"] } mime = "0.3.17" +rustc-hash = "^2.1.1" [dev-dependencies] rand_distr = "^0.4.3" @@ -47,6 +47,7 @@ ixa_example_births_deaths = { path = "examples/births-deaths" } [lints.clippy] pedantic = { level = "warn", priority = -1 } module-name-repetitions = "allow" +implicit_hasher = "allow" [lib] # Prevent Cargo from implicitly linking `libtest` for Criterion.rs compatibility. diff --git a/examples/births-deaths/src/infection_manager.rs b/examples/births-deaths/src/infection_manager.rs index dffbec19..7ef4aacb 100644 --- a/examples/births-deaths/src/infection_manager.rs +++ b/examples/births-deaths/src/infection_manager.rs @@ -9,8 +9,8 @@ use ixa::global_properties::ContextGlobalPropertiesExt; use ixa::people::{ContextPeopleExt, PersonId, PersonPropertyChangeEvent}; use ixa::plan::PlanId; use ixa::random::ContextRandomExt; +use ixa::{HashMap, HashMapExt, HashSet, HashSetExt}; use rand_distr::Exp; -use std::collections::{HashMap, HashSet}; define_rng!(InfectionRng); define_data_plugin!( diff --git a/examples/births-deaths/src/parameters_loader.rs b/examples/births-deaths/src/parameters_loader.rs index 322fb5f9..bd002b02 100644 --- a/examples/births-deaths/src/parameters_loader.rs +++ b/examples/births-deaths/src/parameters_loader.rs @@ -2,8 +2,8 @@ use ixa::context::Context; use ixa::define_global_property; use ixa::error::IxaError; use ixa::global_properties::ContextGlobalPropertiesExt; +use ixa::HashMap; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; use std::fmt::Debug; use std::path::Path; diff --git a/examples/load-people/population_loader.rs b/examples/load-people/population_loader.rs index eb3929c3..4a1eed87 100644 --- a/examples/load-people/population_loader.rs +++ b/examples/load-people/population_loader.rs @@ -74,10 +74,12 @@ mod tests { fn test_creation_event_access_properties() { let flag = Rc::new(RefCell::new(false)); - // Define expected computed values for each person + // Define expected computed values for each person. The value for dosage will change for + // any change in the deterministic RNG. let expected_computed = vec![ - (20, RiskCategoryValue::Low, VaccineTypeValue::B, 0.8, 1), - (80, RiskCategoryValue::High, VaccineTypeValue::A, 0.9, 2), + // (age, risk_category, vaccine_type, efficacy, doses) + (20, RiskCategoryValue::Low, VaccineTypeValue::B, 0.8, 3), + (80, RiskCategoryValue::High, VaccineTypeValue::A, 0.9, 1), ]; let mut context = Context::new(); @@ -117,6 +119,7 @@ mod tests { context.get_person_property(person, VaccineEfficacy), efficacy ); + // This assert will break for any change that affects the deterministic hasher. assert_eq!(context.get_person_property(person, VaccineDoses), doses); *counter.borrow_mut() += 1; diff --git a/examples/time-varying-infection/exposure_manager.rs b/examples/time-varying-infection/exposure_manager.rs index 86416952..30572984 100644 --- a/examples/time-varying-infection/exposure_manager.rs +++ b/examples/time-varying-infection/exposure_manager.rs @@ -149,6 +149,8 @@ mod test { let hazard_fcn = func!(move |t| foi_t(t, parameters.foi, parameters.foi_sin_shift)); let survival_fcn = func!(move |t| f64::exp(-integrate(&hazard_fcn, 0.0, t))); let theoretical_mean = integrate(&survival_fcn, 0.0, 10000.0); // large enough upper bound - assert!((mean - theoretical_mean).abs() < 0.1); + + // This can break with any change that affects the deterministic RNG. + assert!((mean - theoretical_mean).abs() < 0.2); } } diff --git a/src/context.rs b/src/context.rs index ed074fba..08252f8e 100644 --- a/src/context.rs +++ b/src/context.rs @@ -2,9 +2,10 @@ //! //! Defines a `Context` that is intended to provide the foundational mechanism //! for storing and manipulating the state of a given simulation. +use crate::{HashMap, HashMapExt}; use std::{ any::{Any, TypeId}, - collections::{HashMap, VecDeque}, + collections::VecDeque, rc::Rc, }; diff --git a/src/debugger.rs b/src/debugger.rs index 12de8bac..e513f75f 100644 --- a/src/debugger.rs +++ b/src/debugger.rs @@ -6,8 +6,8 @@ use crate::IxaError; use clap::{ArgMatches, Command, FromArgMatches, Parser, Subcommand}; use rustyline; +use crate::{HashMap, HashMapExt}; use log::trace; -use std::collections::HashMap; use std::io::Write; trait DebuggerCommand { diff --git a/src/global_properties.rs b/src/global_properties.rs index a3dcb6e1..eab4cbbe 100644 --- a/src/global_properties.rs +++ b/src/global_properties.rs @@ -17,11 +17,12 @@ //! Global properties can be read with [`Context::get_global_property_value()`] use crate::context::Context; use crate::error::IxaError; -use log::trace; +use crate::trace; +use crate::{HashMap, HashMapExt}; use serde::de::DeserializeOwned; use std::any::{Any, TypeId}; use std::cell::RefCell; -use std::collections::{hash_map::Entry, HashMap}; +use std::collections::hash_map::Entry; use std::fmt::Debug; use std::fs; use std::io::BufReader; diff --git a/src/hashing.rs b/src/hashing.rs new file mode 100644 index 00000000..589429c9 --- /dev/null +++ b/src/hashing.rs @@ -0,0 +1,50 @@ +//! This module provides a deterministic hasher and `HashMap` and `HashSet` variants that use +//! it. The hashing data structures in the standard library are not deterministic: +//! +//! > By default, HashMap uses a hashing algorithm selected to provide +//! > resistance against HashDoS attacks. The algorithm is randomly seeded, and a +//! > reasonable best-effort is made to generate this seed from a high quality, +//! > secure source of randomness provided by the host without blocking the program. +//! +//! The standard library `HashMap` has a `new` method, but `HashMap` does not have a `new` +//! method by default. Use `HashMap::default()` instead to create a new hashmap with the default +//! hasher. If you really need to keep the API the same across implementations, we provide the +//! `HashMapExt` trait extension. Similarly, for `HashSet` and `HashSetExt`.The traits need only be +//! in scope. +//! +//! The `hash_usize` free function is a convenience function used in `crate::random::get_rng`. + +pub use rustc_hash::FxHashMap as HashMap; +pub use rustc_hash::FxHashSet as HashSet; +use std::hash::Hasher; + +/// Provides API parity with `std::collections::HashMap`. +pub trait HashMapExt { + fn new() -> Self; +} + +impl HashMapExt for HashMap { + fn new() -> Self { + HashMap::default() + } +} + +// Note that trait aliases are not yet stabilized in rustc. +// See https://github.com/rust-lang/rust/issues/41517 +/// Provides API parity with `std::collections::HashSet`. +pub trait HashSetExt { + fn new() -> Self; +} + +impl HashSetExt for HashSet { + fn new() -> Self { + HashSet::default() + } +} + +/// A convenience method to compute the hash of a `&str`. +pub fn hash_str(data: &str) -> u64 { + let mut hasher = rustc_hash::FxHasher::default(); + hasher.write(data.as_bytes()); + hasher.finish() +} diff --git a/src/lib.rs b/src/lib.rs index a2e10650..d71bdfc8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -65,6 +65,7 @@ pub use log::{ }; pub mod external_api; +mod hashing; pub mod web_api; // Re-export for macros @@ -72,6 +73,9 @@ pub use ctor; pub use paste; pub use rand; +// Deterministic hashing data structures +pub use crate::hashing::{HashMap, HashMapExt, HashSet, HashSetExt}; + #[cfg(test)] mod tests { use assert_cmd::cargo::CargoError; diff --git a/src/log.rs b/src/log.rs index 8634db45..e91007a1 100644 --- a/src/log.rs +++ b/src/log.rs @@ -42,6 +42,7 @@ pub use log::{debug, error, info, trace, warn, LevelFilter}; +use crate::HashMap; use log4rs; use log4rs::append::console::ConsoleAppender; use log4rs::config::runtime::ConfigBuilder; @@ -49,7 +50,6 @@ use log4rs::config::{Appender, Logger, Root}; use log4rs::encode::pattern::PatternEncoder; use log4rs::{Config, Handle}; use std::collections::hash_map::Entry; -use std::collections::HashMap; use std::sync::LazyLock; use std::sync::{Mutex, MutexGuard}; diff --git a/src/network.rs b/src/network.rs index ccc88231..c7579ece 100644 --- a/src/network.rs +++ b/src/network.rs @@ -5,15 +5,13 @@ //! arbitrary number of outgoing edges of a given type, with each edge //! having a weight. Edge types can also specify their own per-type //! data which will be stored along with the edge. +use crate::HashMap; use crate::{ context::Context, define_data_plugin, error::IxaError, people::PersonId, random::ContextRandomExt, random::RngId, }; use rand::Rng; -use std::{ - any::{Any, TypeId}, - collections::HashMap, -}; +use std::any::{Any, TypeId}; #[derive(Copy, Clone, Debug, PartialEq)] /// An edge in network graph. Edges are directed, so the diff --git a/src/people/context_extension.rs b/src/people/context_extension.rs index d9aea92f..eb34d5a6 100644 --- a/src/people/context_extension.rs +++ b/src/people/context_extension.rs @@ -5,10 +5,10 @@ use crate::{ Context, ContextRandomExt, IxaError, PersonCreatedEvent, PersonId, PersonProperty, PersonPropertyChangeEvent, RngId, Tabulator, }; +use crate::{HashMap, HashMapExt, HashSet, HashSetExt}; use rand::Rng; use std::any::TypeId; use std::cell::Ref; -use std::collections::{HashMap, HashSet}; use crate::people::methods::Methods; diff --git a/src/people/data.rs b/src/people/data.rs index 76a1a5e6..2bc1e349 100644 --- a/src/people/data.rs +++ b/src/people/data.rs @@ -3,9 +3,9 @@ use crate::people::index::Index; use crate::people::methods::Methods; use crate::people::InitializationList; use crate::{Context, IxaError, PersonId, PersonProperty, PersonPropertyChangeEvent}; +use crate::{HashMap, HashSet, HashSetExt}; use std::any::{Any, TypeId}; use std::cell::{Ref, RefCell, RefMut}; -use std::collections::{HashMap, HashSet}; type ContextCallback = dyn FnOnce(&mut Context); diff --git a/src/people/external_api.rs b/src/people/external_api.rs index f45dcb3f..3d8192c2 100644 --- a/src/people/external_api.rs +++ b/src/people/external_api.rs @@ -1,11 +1,10 @@ -use std::any::TypeId; -use std::collections::HashMap; - use crate::people::ContextPeopleExt; use crate::people::PeoplePlugin; use crate::Context; use crate::IxaError; use crate::PersonId; +use crate::{HashMap, HashMapExt}; +use std::any::TypeId; pub(crate) trait ContextPeopleExtCrate { fn get_person_property_by_name( diff --git a/src/people/index.rs b/src/people/index.rs index 615506e5..6f96791d 100644 --- a/src/people/index.rs +++ b/src/people/index.rs @@ -1,8 +1,8 @@ use super::methods::Methods; use crate::{Context, ContextPeopleExt, PersonId, PersonProperty}; +use crate::{HashMap, HashSet, HashSetExt}; use bincode::serialize; use serde::Serialize; -use std::collections::{HashMap, HashSet}; #[derive(Clone, PartialEq, Eq, Hash, Debug)] // The lookup key for entries in the index. This is a serialized version of the value. diff --git a/src/people/mod.rs b/src/people/mod.rs index 45c4c7c1..bea6ab9c 100644 --- a/src/people/mod.rs +++ b/src/people/mod.rs @@ -88,15 +88,12 @@ pub use property::{ PersonProperty, }; +use crate::{HashMap, HashMapExt, HashSet, HashSetExt}; use seq_macro::seq; use serde::{Deserialize, Serialize}; use std::cell::RefCell; use std::fmt::{Debug, Display, Formatter}; -use std::{ - any::TypeId, - collections::{HashMap, HashSet}, - hash::Hash, -}; +use std::{any::TypeId, hash::Hash}; define_data_plugin!( PeoplePlugin, diff --git a/src/plan.rs b/src/plan.rs index 8c5f2fbf..fafe26a4 100644 --- a/src/plan.rs +++ b/src/plan.rs @@ -9,11 +9,9 @@ //! This queue is used by `Context` to store future events where some callback //! closure `FnOnce(&mut Context)` will be executed at a given point in time. -use log::trace; -use std::{ - cmp::Ordering, - collections::{BinaryHeap, HashMap}, -}; +use crate::trace; +use crate::{HashMap, HashMapExt}; +use std::{cmp::Ordering, collections::BinaryHeap}; /// A priority queue that stores arbitrary data sorted by time /// diff --git a/src/random.rs b/src/random.rs index 2f0c6c8c..4c5dd98c 100644 --- a/src/random.rs +++ b/src/random.rs @@ -1,4 +1,6 @@ use crate::context::Context; +use crate::hashing::hash_str; +use crate::{HashMap, HashMapExt}; use log::trace; use rand::distributions::uniform::{SampleRange, SampleUniform}; use rand::distributions::WeightedIndex; @@ -6,7 +8,6 @@ use rand::prelude::Distribution; use rand::{Rng, SeedableRng}; use std::any::{Any, TypeId}; use std::cell::{RefCell, RefMut}; -use std::collections::HashMap; /// Use this to define a unique type which will be used as a key to retrieve /// an independent rng instance when calling `.get_rng`. @@ -86,9 +87,11 @@ fn get_rng(context: &Context) -> RefMut { TypeId::of::() ); let base_seed = data_container.base_seed; - let seed_offset = fxhash::hash64(R::get_name()); + let seed_offset = hash_str(R::get_name()); RngHolder { - rng: Box::new(R::RngType::seed_from_u64(base_seed + seed_offset)), + rng: Box::new(R::RngType::seed_from_u64( + base_seed.wrapping_add(seed_offset), + )), } }) .rng @@ -279,7 +282,9 @@ mod test { fn sampler_function_closure_capture() { let mut context = Context::new(); context.init_random(42); - // Initialize weighted sampler + + // Initialize weighted sampler. Zero is selected with probability 1/3, one with a + // probability of 2/3. *context.get_data_container_mut(SamplerData) = WeightedIndex::new(vec![1.0, 2.0]).unwrap(); let parameters = context.get_data_container(SamplerData).unwrap(); @@ -291,7 +296,8 @@ mod test { zero_counter += 1; } } - assert!((zero_counter - 1000_i32).abs() < 30); + // The expected value of `zero_counter` is 1000. + assert!((zero_counter - 1000_i32).abs() < 100); } #[test] @@ -299,7 +305,8 @@ mod test { let mut context = Context::new(); context.init_random(42); - // Initialize weighted sampler + // Initialize weighted sampler. Zero is selected with probability 1/3, one with a + // probability of 2/3. *context.get_data_container_mut(SamplerData) = WeightedIndex::new(vec![1.0, 2.0]).unwrap(); let parameters = context.get_data_container(SamplerData).unwrap(); @@ -311,7 +318,8 @@ mod test { zero_counter += 1; } } - assert!((zero_counter - 1000_i32).abs() < 30); + // The expected value of `zero_counter` is 1000. + assert!((zero_counter - 1000_i32).abs() < 100); } #[test] diff --git a/src/report.rs b/src/report.rs index 41a2518d..5ac9bf55 100644 --- a/src/report.rs +++ b/src/report.rs @@ -3,11 +3,11 @@ use crate::error::IxaError; use crate::people::ContextPeopleExt; use crate::Tabulator; use crate::{error, trace}; +use crate::{HashMap, HashMapExt}; use csv::Writer; use serde::Serializer; use std::any::TypeId; use std::cell::{RefCell, RefMut}; -use std::collections::HashMap; use std::env; use std::fs::File; use std::path::PathBuf; @@ -273,7 +273,7 @@ impl ContextReportExt for Context { #[cfg(test)] mod test { - use crate::define_person_property_with_default; + use crate::{define_person_property_with_default, info}; use super::*; use core::convert::TryInto; @@ -525,6 +525,7 @@ mod test { let mut context2 = Context::new(); let config = context2.report_options(); config.file_prefix("prefix1_".to_string()).directory(path); + info!("The next 'file already exists' error is intended for a passing test."); let result = context2.add_report::("sample_report"); assert!(result.is_err()); let error = result.err().unwrap(); diff --git a/src/web_api.rs b/src/web_api.rs index 2c7d6fc0..cbd1bf56 100644 --- a/src/web_api.rs +++ b/src/web_api.rs @@ -4,13 +4,13 @@ use crate::error::IxaError; use crate::external_api::{ global_properties, next, people, population, run_ext_api, time, EmptyArgs, }; +use crate::{HashMap, HashMapExt}; use axum::extract::{Json, Path, State}; use axum::response::Redirect; use axum::routing::get; use axum::{http::StatusCode, routing::post, Router}; use rand::RngCore; use serde_json::json; -use std::collections::HashMap; use std::thread; use tokio::sync::mpsc; use tokio::sync::oneshot;