From d1a746289645f87ca78c9cba997d7003ca8cbfe4 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Mon, 5 Jun 2023 11:18:42 +0100 Subject: [PATCH 1/2] SecondaryMap impl for HashSet and HashMap --- src/hierarchy.rs | 4 +- src/lib.rs | 66 ++++++++++++++++++------- src/secondary.rs | 124 ++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 175 insertions(+), 19 deletions(-) diff --git a/src/hierarchy.rs b/src/hierarchy.rs index 095c71e..28e13bf 100644 --- a/src/hierarchy.rs +++ b/src/hierarchy.rs @@ -54,7 +54,7 @@ use std::mem::{replace, take}; use thiserror::Error; use crate::unmanaged::UnmanagedDenseMap; -use crate::NodeIndex; +use crate::{impl_static_default, NodeIndex}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -487,6 +487,8 @@ impl Default for NodeData { } } +impl_static_default!(NodeData, NodeData::new()); + /// Iterator created by [`Hierarchy::children`]. #[derive(Clone)] pub struct Children<'a> { diff --git a/src/lib.rs b/src/lib.rs index 74c2ae4..6f11404 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -162,6 +162,17 @@ impl NodeIndex { pub fn index(self) -> usize { self.into() } + + /// Constant implementation of TryFrom + #[inline] + const fn try_from_usize(index: usize) -> Result { + if index > Self::MAX { + Err(IndexError { index }) + } else { + // SAFETY: The value cannot be zero + Ok(Self(unsafe { NonZeroU32::new_unchecked(1 + index as u32) })) + } + } } impl From for usize { @@ -176,12 +187,7 @@ impl TryFrom for NodeIndex { #[inline] fn try_from(index: usize) -> Result { - if index > Self::MAX { - Err(IndexError { index }) - } else { - // SAFETY: The value cannot be zero - Ok(Self(unsafe { NonZeroU32::new_unchecked(1 + index as u32) })) - } + Self::try_from_usize(index) } } @@ -192,6 +198,15 @@ impl std::fmt::Debug for NodeIndex { } } +impl_static_default!( + NodeIndex, + match NodeIndex::try_from_usize(0) { + Ok(index) => index, + // Zero is always a valid index + Err(_) => unreachable!(), + } +); + /// Index of a port within a `PortGraph`. /// /// Restricted to be at most `2^31 - 1` to allow more efficient encodings of the port graph structure. @@ -221,6 +236,17 @@ impl PortIndex { pub fn index(self) -> usize { self.into() } + + /// Constant implementation of TryFrom + #[inline] + const fn try_from_usize(index: usize) -> Result { + if index > Self::MAX { + Err(IndexError { index }) + } else { + // SAFETY: The value cannot be zero + Ok(Self(unsafe { NonZeroU32::new_unchecked(1 + index as u32) })) + } + } } impl From for usize { @@ -235,12 +261,7 @@ impl TryFrom for PortIndex { #[inline] fn try_from(index: usize) -> Result { - if index > Self::MAX { - Err(IndexError { index }) - } else { - // SAFETY: The value cannot be zero - Ok(Self(unsafe { NonZeroU32::new_unchecked(1 + index as u32) })) - } + Self::try_from_usize(index) } } @@ -257,6 +278,15 @@ impl Default for PortIndex { } } +impl_static_default!( + PortIndex, + match PortIndex::try_from_usize(0) { + Ok(index) => index, + // Zero is always a valid index + Err(_) => unreachable!(), + } +); + /// Error indicating a `NodeIndex`, `PortIndex`, or `Direction` is too large. #[derive(Debug, Clone, Error, PartialEq, Eq)] #[error("the index {index} is too large.")] @@ -281,7 +311,7 @@ pub enum PortOffset { impl PortOffset { /// Creates a new port offset. #[inline(always)] - pub fn new(direction: Direction, offset: usize) -> Self { + pub const fn new(direction: Direction, offset: usize) -> Self { match direction { Direction::Incoming => Self::new_incoming(offset), Direction::Outgoing => Self::new_outgoing(offset), @@ -290,7 +320,7 @@ impl PortOffset { /// Creates a new incoming port offset. #[inline(always)] - pub fn new_incoming(offset: usize) -> Self { + pub const fn new_incoming(offset: usize) -> Self { assert!(offset < u16::MAX as usize); // SAFETY: The value cannot be zero let offset = unsafe { NonZeroU16::new_unchecked(offset.saturating_add(1) as u16) }; @@ -299,14 +329,14 @@ impl PortOffset { /// Creates a new outgoing port offset. #[inline(always)] - pub fn new_outgoing(offset: usize) -> Self { + pub const fn new_outgoing(offset: usize) -> Self { assert!(offset <= u16::MAX as usize); PortOffset::Outgoing(offset as u16) } /// Returns the direction of the port. #[inline(always)] - pub fn direction(self) -> Direction { + pub const fn direction(self) -> Direction { match self { PortOffset::Incoming(_) => Direction::Incoming, PortOffset::Outgoing(_) => Direction::Outgoing, @@ -315,7 +345,7 @@ impl PortOffset { /// Returns the offset of the port. #[inline(always)] - pub fn index(self) -> usize { + pub const fn index(self) -> usize { match self { PortOffset::Incoming(offset) => (offset.get() - 1) as usize, PortOffset::Outgoing(offset) => offset as usize, @@ -337,3 +367,5 @@ impl std::fmt::Debug for PortOffset { } } } + +impl_static_default!(PortOffset, PortOffset::new_outgoing(0)); diff --git a/src/secondary.rs b/src/secondary.rs index 2f2b0de..1ceb0ed 100644 --- a/src/secondary.rs +++ b/src/secondary.rs @@ -1,6 +1,6 @@ //! Trait definition for secondary maps from keys to values with default elements. -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::{hash::Hash, iter::FusedIterator}; use bitvec::{ @@ -333,3 +333,125 @@ where self.iter.size_hint() } } + +impl SecondaryMap for HashMap +where + K: Hash + Eq + Clone, + V: StaticDefault + Eq + Clone, +{ + type Iter<'a> = HashMapIter<'a, K, V> where Self: 'a, K: 'a; + + #[inline] + fn new() -> Self { + HashMap::new() + } + + #[inline] + fn with_capacity(capacity: usize) -> Self { + HashMap::with_capacity(capacity) + } + + #[inline] + fn default_value(&self) -> V { + V::default_ref().clone() + } + + #[inline] + fn ensure_capacity(&mut self, capacity: usize) { + HashMap::reserve(self, capacity.saturating_sub(self.capacity())); + } + + #[inline] + fn resize(&mut self, _new_len: usize) {} + + #[inline] + fn capacity(&self) -> usize { + HashMap::capacity(self) + } + + #[inline] + fn get(&self, key: K) -> &V { + HashMap::get(self, &key).unwrap_or(V::default_ref()) + } + + #[inline] + fn set(&mut self, key: K, val: V) { + match &val == V::default_ref() { + true => HashMap::insert(self, key, val), + false => HashMap::remove(self, &key), + }; + } + + #[inline] + fn take(&mut self, key: K) -> V { + HashMap::remove(self, &key).unwrap_or(self.default_value()) + } + + #[inline] + fn iter<'a>(&'a self) -> Self::Iter<'a> + where + K: 'a, + { + HashMapIter { + iter: HashMap::iter(self), + } + } +} + +/// Iterator over non-default entries of a bit vector secondary map. +#[derive(Debug, Clone)] +pub struct HashMapIter<'a, K, V> { + iter: std::collections::hash_map::Iter<'a, K, V>, +} + +impl<'a, K, V> Iterator for HashMapIter<'a, K, V> +where + K: Clone, +{ + type Item = (K, &'a V); + + #[inline] + fn next(&mut self) -> Option { + self.iter.next().map(|(k, v)| (k.clone(), v)) + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + self.iter.nth(n).map(|(k, v)| (k.clone(), v)) + } + + #[inline] + fn count(self) -> usize { + self.iter.count() + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } +} + +/// A trait for secondary map values that can provide a static reference to a default value. +pub trait StaticDefault: 'static { + /// Returns a static reference to the default value + fn default_ref<'a>() -> &'a Self; +} + +/// Implements the `StaticDefault` trait for a type, using a const element. +#[allow(unused_macros)] +#[macro_export] +macro_rules! impl_static_default { + ($name:ident, $default:expr) => { + impl $crate::secondary::StaticDefault for $name { + fn default_ref<'a>() -> &'a Self { + static DEFAULT: $name = $default; + &DEFAULT + } + } + }; +} +#[allow(unused_imports)] +pub use impl_static_default; + +impl_static_default!(bool, false); +impl_static_default!(usize, 0); From 94dc508f83d50a98b6ae5af97d978c872567eb87 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Tue, 6 Jun 2023 12:46:04 +0100 Subject: [PATCH 2/2] Replace bespoke trait with ConstDefault --- Cargo.toml | 1 + src/hierarchy.rs | 24 +++--------------------- src/lib.rs | 23 ++++++++++++----------- src/secondary.rs | 35 ++++++----------------------------- 4 files changed, 22 insertions(+), 61 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d795e0c..a081a6b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ bitvec = "1.0.1" serde = { version = "1.0.152", features = ["derive"], optional = true} proptest = { version = "1.1.0", optional = true } rand = { version = "0.8.5", optional = true } +const-default = { version = "1.0.0", default-features = false, features = ["derive"]} [features] pyo3 = ["dep:pyo3"] diff --git a/src/hierarchy.rs b/src/hierarchy.rs index 28e13bf..c52a604 100644 --- a/src/hierarchy.rs +++ b/src/hierarchy.rs @@ -49,12 +49,13 @@ //! hierarchy.shrink_to(graph.node_count()); //! ``` +use const_default::ConstDefault; use std::iter::FusedIterator; use std::mem::{replace, take}; use thiserror::Error; use crate::unmanaged::UnmanagedDenseMap; -use crate::{impl_static_default, NodeIndex}; +use crate::NodeIndex; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -457,7 +458,7 @@ impl Hierarchy { } } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Default, ConstDefault)] #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] struct NodeData { /// The first and last child of the node, if any. @@ -470,25 +471,6 @@ struct NodeData { siblings: [Option; 2], } -impl NodeData { - pub const fn new() -> Self { - Self { - children: None, - children_count: 0u32, - parent: None, - siblings: [None; 2], - } - } -} - -impl Default for NodeData { - fn default() -> Self { - Self::new() - } -} - -impl_static_default!(NodeData, NodeData::new()); - /// Iterator created by [`Hierarchy::children`]. #[derive(Clone)] pub struct Children<'a> { diff --git a/src/lib.rs b/src/lib.rs index 6f11404..613c2ef 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -53,6 +53,7 @@ //! graph component structures. //! - `pyo3` enables Python bindings. //! +use const_default::ConstDefault; use std::num::{NonZeroU16, NonZeroU32}; use thiserror::Error; @@ -198,14 +199,13 @@ impl std::fmt::Debug for NodeIndex { } } -impl_static_default!( - NodeIndex, - match NodeIndex::try_from_usize(0) { +impl ConstDefault for NodeIndex { + const DEFAULT: Self = match NodeIndex::try_from_usize(0) { Ok(index) => index, // Zero is always a valid index Err(_) => unreachable!(), - } -); + }; +} /// Index of a port within a `PortGraph`. /// @@ -278,14 +278,13 @@ impl Default for PortIndex { } } -impl_static_default!( - PortIndex, - match PortIndex::try_from_usize(0) { +impl ConstDefault for PortIndex { + const DEFAULT: Self = match PortIndex::try_from_usize(0) { Ok(index) => index, // Zero is always a valid index Err(_) => unreachable!(), - } -); + }; +} /// Error indicating a `NodeIndex`, `PortIndex`, or `Direction` is too large. #[derive(Debug, Clone, Error, PartialEq, Eq)] @@ -368,4 +367,6 @@ impl std::fmt::Debug for PortOffset { } } -impl_static_default!(PortOffset, PortOffset::new_outgoing(0)); +impl ConstDefault for PortOffset { + const DEFAULT: Self = PortOffset::new_outgoing(0); +} diff --git a/src/secondary.rs b/src/secondary.rs index 1ceb0ed..98af365 100644 --- a/src/secondary.rs +++ b/src/secondary.rs @@ -7,6 +7,7 @@ use bitvec::{ slice::{BitSlice, IterOnes}, vec::BitVec, }; +use const_default::ConstDefault; /// A map from keys to values with default elements. /// @@ -337,7 +338,8 @@ where impl SecondaryMap for HashMap where K: Hash + Eq + Clone, - V: StaticDefault + Eq + Clone, + V: Eq + Clone, + for<'a> &'a V: ConstDefault, { type Iter<'a> = HashMapIter<'a, K, V> where Self: 'a, K: 'a; @@ -353,7 +355,7 @@ where #[inline] fn default_value(&self) -> V { - V::default_ref().clone() + <&V>::DEFAULT.clone() } #[inline] @@ -371,12 +373,12 @@ where #[inline] fn get(&self, key: K) -> &V { - HashMap::get(self, &key).unwrap_or(V::default_ref()) + HashMap::get(self, &key).unwrap_or(<&V>::DEFAULT) } #[inline] fn set(&mut self, key: K, val: V) { - match &val == V::default_ref() { + match &val == <&V>::DEFAULT { true => HashMap::insert(self, key, val), false => HashMap::remove(self, &key), }; @@ -430,28 +432,3 @@ where self.iter.size_hint() } } - -/// A trait for secondary map values that can provide a static reference to a default value. -pub trait StaticDefault: 'static { - /// Returns a static reference to the default value - fn default_ref<'a>() -> &'a Self; -} - -/// Implements the `StaticDefault` trait for a type, using a const element. -#[allow(unused_macros)] -#[macro_export] -macro_rules! impl_static_default { - ($name:ident, $default:expr) => { - impl $crate::secondary::StaticDefault for $name { - fn default_ref<'a>() -> &'a Self { - static DEFAULT: $name = $default; - &DEFAULT - } - } - }; -} -#[allow(unused_imports)] -pub use impl_static_default; - -impl_static_default!(bool, false); -impl_static_default!(usize, 0);