diff --git a/Cargo.toml b/Cargo.toml index d795e0c..a081a6b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ bitvec = "1.0.1" serde = { version = "1.0.152", features = ["derive"], optional = true} proptest = { version = "1.1.0", optional = true } rand = { version = "0.8.5", optional = true } +const-default = { version = "1.0.0", default-features = false, features = ["derive"]} [features] pyo3 = ["dep:pyo3"] diff --git a/src/hierarchy.rs b/src/hierarchy.rs index 095c71e..c52a604 100644 --- a/src/hierarchy.rs +++ b/src/hierarchy.rs @@ -49,6 +49,7 @@ //! hierarchy.shrink_to(graph.node_count()); //! ``` +use const_default::ConstDefault; use std::iter::FusedIterator; use std::mem::{replace, take}; use thiserror::Error; @@ -457,7 +458,7 @@ impl Hierarchy { } } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Default, ConstDefault)] #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] struct NodeData { /// The first and last child of the node, if any. @@ -470,23 +471,6 @@ struct NodeData { siblings: [Option; 2], } -impl NodeData { - pub const fn new() -> Self { - Self { - children: None, - children_count: 0u32, - parent: None, - siblings: [None; 2], - } - } -} - -impl Default for NodeData { - fn default() -> Self { - Self::new() - } -} - /// Iterator created by [`Hierarchy::children`]. #[derive(Clone)] pub struct Children<'a> { diff --git a/src/lib.rs b/src/lib.rs index 74c2ae4..613c2ef 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -53,6 +53,7 @@ //! graph component structures. //! - `pyo3` enables Python bindings. //! +use const_default::ConstDefault; use std::num::{NonZeroU16, NonZeroU32}; use thiserror::Error; @@ -162,6 +163,17 @@ impl NodeIndex { pub fn index(self) -> usize { self.into() } + + /// Constant implementation of TryFrom + #[inline] + const fn try_from_usize(index: usize) -> Result { + if index > Self::MAX { + Err(IndexError { index }) + } else { + // SAFETY: The value cannot be zero + Ok(Self(unsafe { NonZeroU32::new_unchecked(1 + index as u32) })) + } + } } impl From for usize { @@ -176,12 +188,7 @@ impl TryFrom for NodeIndex { #[inline] fn try_from(index: usize) -> Result { - if index > Self::MAX { - Err(IndexError { index }) - } else { - // SAFETY: The value cannot be zero - Ok(Self(unsafe { NonZeroU32::new_unchecked(1 + index as u32) })) - } + Self::try_from_usize(index) } } @@ -192,6 +199,14 @@ impl std::fmt::Debug for NodeIndex { } } +impl ConstDefault for NodeIndex { + const DEFAULT: Self = match NodeIndex::try_from_usize(0) { + Ok(index) => index, + // Zero is always a valid index + Err(_) => unreachable!(), + }; +} + /// Index of a port within a `PortGraph`. /// /// Restricted to be at most `2^31 - 1` to allow more efficient encodings of the port graph structure. @@ -221,6 +236,17 @@ impl PortIndex { pub fn index(self) -> usize { self.into() } + + /// Constant implementation of TryFrom + #[inline] + const fn try_from_usize(index: usize) -> Result { + if index > Self::MAX { + Err(IndexError { index }) + } else { + // SAFETY: The value cannot be zero + Ok(Self(unsafe { NonZeroU32::new_unchecked(1 + index as u32) })) + } + } } impl From for usize { @@ -235,12 +261,7 @@ impl TryFrom for PortIndex { #[inline] fn try_from(index: usize) -> Result { - if index > Self::MAX { - Err(IndexError { index }) - } else { - // SAFETY: The value cannot be zero - Ok(Self(unsafe { NonZeroU32::new_unchecked(1 + index as u32) })) - } + Self::try_from_usize(index) } } @@ -257,6 +278,14 @@ impl Default for PortIndex { } } +impl ConstDefault for PortIndex { + const DEFAULT: Self = match PortIndex::try_from_usize(0) { + Ok(index) => index, + // Zero is always a valid index + Err(_) => unreachable!(), + }; +} + /// Error indicating a `NodeIndex`, `PortIndex`, or `Direction` is too large. #[derive(Debug, Clone, Error, PartialEq, Eq)] #[error("the index {index} is too large.")] @@ -281,7 +310,7 @@ pub enum PortOffset { impl PortOffset { /// Creates a new port offset. #[inline(always)] - pub fn new(direction: Direction, offset: usize) -> Self { + pub const fn new(direction: Direction, offset: usize) -> Self { match direction { Direction::Incoming => Self::new_incoming(offset), Direction::Outgoing => Self::new_outgoing(offset), @@ -290,7 +319,7 @@ impl PortOffset { /// Creates a new incoming port offset. #[inline(always)] - pub fn new_incoming(offset: usize) -> Self { + pub const fn new_incoming(offset: usize) -> Self { assert!(offset < u16::MAX as usize); // SAFETY: The value cannot be zero let offset = unsafe { NonZeroU16::new_unchecked(offset.saturating_add(1) as u16) }; @@ -299,14 +328,14 @@ impl PortOffset { /// Creates a new outgoing port offset. #[inline(always)] - pub fn new_outgoing(offset: usize) -> Self { + pub const fn new_outgoing(offset: usize) -> Self { assert!(offset <= u16::MAX as usize); PortOffset::Outgoing(offset as u16) } /// Returns the direction of the port. #[inline(always)] - pub fn direction(self) -> Direction { + pub const fn direction(self) -> Direction { match self { PortOffset::Incoming(_) => Direction::Incoming, PortOffset::Outgoing(_) => Direction::Outgoing, @@ -315,7 +344,7 @@ impl PortOffset { /// Returns the offset of the port. #[inline(always)] - pub fn index(self) -> usize { + pub const fn index(self) -> usize { match self { PortOffset::Incoming(offset) => (offset.get() - 1) as usize, PortOffset::Outgoing(offset) => offset as usize, @@ -337,3 +366,7 @@ impl std::fmt::Debug for PortOffset { } } } + +impl ConstDefault for PortOffset { + const DEFAULT: Self = PortOffset::new_outgoing(0); +} diff --git a/src/secondary.rs b/src/secondary.rs index 2f2b0de..98af365 100644 --- a/src/secondary.rs +++ b/src/secondary.rs @@ -1,12 +1,13 @@ //! Trait definition for secondary maps from keys to values with default elements. -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::{hash::Hash, iter::FusedIterator}; use bitvec::{ slice::{BitSlice, IterOnes}, vec::BitVec, }; +use const_default::ConstDefault; /// A map from keys to values with default elements. /// @@ -333,3 +334,101 @@ where self.iter.size_hint() } } + +impl SecondaryMap for HashMap +where + K: Hash + Eq + Clone, + V: Eq + Clone, + for<'a> &'a V: ConstDefault, +{ + type Iter<'a> = HashMapIter<'a, K, V> where Self: 'a, K: 'a; + + #[inline] + fn new() -> Self { + HashMap::new() + } + + #[inline] + fn with_capacity(capacity: usize) -> Self { + HashMap::with_capacity(capacity) + } + + #[inline] + fn default_value(&self) -> V { + <&V>::DEFAULT.clone() + } + + #[inline] + fn ensure_capacity(&mut self, capacity: usize) { + HashMap::reserve(self, capacity.saturating_sub(self.capacity())); + } + + #[inline] + fn resize(&mut self, _new_len: usize) {} + + #[inline] + fn capacity(&self) -> usize { + HashMap::capacity(self) + } + + #[inline] + fn get(&self, key: K) -> &V { + HashMap::get(self, &key).unwrap_or(<&V>::DEFAULT) + } + + #[inline] + fn set(&mut self, key: K, val: V) { + match &val == <&V>::DEFAULT { + true => HashMap::insert(self, key, val), + false => HashMap::remove(self, &key), + }; + } + + #[inline] + fn take(&mut self, key: K) -> V { + HashMap::remove(self, &key).unwrap_or(self.default_value()) + } + + #[inline] + fn iter<'a>(&'a self) -> Self::Iter<'a> + where + K: 'a, + { + HashMapIter { + iter: HashMap::iter(self), + } + } +} + +/// Iterator over non-default entries of a bit vector secondary map. +#[derive(Debug, Clone)] +pub struct HashMapIter<'a, K, V> { + iter: std::collections::hash_map::Iter<'a, K, V>, +} + +impl<'a, K, V> Iterator for HashMapIter<'a, K, V> +where + K: Clone, +{ + type Item = (K, &'a V); + + #[inline] + fn next(&mut self) -> Option { + self.iter.next().map(|(k, v)| (k.clone(), v)) + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + self.iter.nth(n).map(|(k, v)| (k.clone(), v)) + } + + #[inline] + fn count(self) -> usize { + self.iter.count() + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } +}