Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement SecondaryMap for HashSet and HashMap #60

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ bitvec = "1.0.1"
serde = { version = "1.0.152", features = ["derive"], optional = true}
proptest = { version = "1.1.0", optional = true }
rand = { version = "0.8.5", optional = true }
const-default = { version = "1.0.0", default-features = false, features = ["derive"]}

[features]
pyo3 = ["dep:pyo3"]
Expand Down
20 changes: 2 additions & 18 deletions src/hierarchy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
//! hierarchy.shrink_to(graph.node_count());
//! ```

use const_default::ConstDefault;
use std::iter::FusedIterator;
use std::mem::{replace, take};
use thiserror::Error;
Expand Down Expand Up @@ -457,7 +458,7 @@ impl Hierarchy {
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq, Default, ConstDefault)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
struct NodeData {
/// The first and last child of the node, if any.
Expand All @@ -470,23 +471,6 @@ struct NodeData {
siblings: [Option<NodeIndex>; 2],
}

impl NodeData {
pub const fn new() -> Self {
Self {
children: None,
children_count: 0u32,
parent: None,
siblings: [None; 2],
}
}
}

impl Default for NodeData {
fn default() -> Self {
Self::new()
}
}

/// Iterator created by [`Hierarchy::children`].
#[derive(Clone)]
pub struct Children<'a> {
Expand Down
67 changes: 50 additions & 17 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
//! graph component structures.
//! - `pyo3` enables Python bindings.
//!
use const_default::ConstDefault;
use std::num::{NonZeroU16, NonZeroU32};
use thiserror::Error;

Expand Down Expand Up @@ -162,6 +163,17 @@ impl NodeIndex {
pub fn index(self) -> usize {
self.into()
}

/// Constant implementation of TryFrom<usize>
#[inline]
const fn try_from_usize(index: usize) -> Result<Self, IndexError> {
if index > Self::MAX {
Err(IndexError { index })
} else {
// SAFETY: The value cannot be zero
Ok(Self(unsafe { NonZeroU32::new_unchecked(1 + index as u32) }))
}
}
}

impl From<NodeIndex> for usize {
Expand All @@ -176,12 +188,7 @@ impl TryFrom<usize> for NodeIndex {

#[inline]
fn try_from(index: usize) -> Result<Self, Self::Error> {
if index > Self::MAX {
Err(IndexError { index })
} else {
// SAFETY: The value cannot be zero
Ok(Self(unsafe { NonZeroU32::new_unchecked(1 + index as u32) }))
}
Self::try_from_usize(index)
}
}

Expand All @@ -192,6 +199,14 @@ impl std::fmt::Debug for NodeIndex {
}
}

impl ConstDefault for NodeIndex {
const DEFAULT: Self = match NodeIndex::try_from_usize(0) {
Ok(index) => index,
// Zero is always a valid index
Err(_) => unreachable!(),
};
}

/// Index of a port within a `PortGraph`.
///
/// Restricted to be at most `2^31 - 1` to allow more efficient encodings of the port graph structure.
Expand Down Expand Up @@ -221,6 +236,17 @@ impl PortIndex {
pub fn index(self) -> usize {
self.into()
}

/// Constant implementation of TryFrom<usize>
#[inline]
const fn try_from_usize(index: usize) -> Result<Self, IndexError> {
if index > Self::MAX {
Err(IndexError { index })
} else {
// SAFETY: The value cannot be zero
Ok(Self(unsafe { NonZeroU32::new_unchecked(1 + index as u32) }))
}
}
}

impl From<PortIndex> for usize {
Expand All @@ -235,12 +261,7 @@ impl TryFrom<usize> for PortIndex {

#[inline]
fn try_from(index: usize) -> Result<Self, Self::Error> {
if index > Self::MAX {
Err(IndexError { index })
} else {
// SAFETY: The value cannot be zero
Ok(Self(unsafe { NonZeroU32::new_unchecked(1 + index as u32) }))
}
Self::try_from_usize(index)
}
}

Expand All @@ -257,6 +278,14 @@ impl Default for PortIndex {
}
}

impl ConstDefault for PortIndex {
const DEFAULT: Self = match PortIndex::try_from_usize(0) {
Ok(index) => index,
// Zero is always a valid index
Err(_) => unreachable!(),
};
}

/// Error indicating a `NodeIndex`, `PortIndex`, or `Direction` is too large.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[error("the index {index} is too large.")]
Expand All @@ -281,7 +310,7 @@ pub enum PortOffset {
impl PortOffset {
/// Creates a new port offset.
#[inline(always)]
pub fn new(direction: Direction, offset: usize) -> Self {
pub const fn new(direction: Direction, offset: usize) -> Self {
match direction {
Direction::Incoming => Self::new_incoming(offset),
Direction::Outgoing => Self::new_outgoing(offset),
Expand All @@ -290,7 +319,7 @@ impl PortOffset {

/// Creates a new incoming port offset.
#[inline(always)]
pub fn new_incoming(offset: usize) -> Self {
pub const fn new_incoming(offset: usize) -> Self {
assert!(offset < u16::MAX as usize);
// SAFETY: The value cannot be zero
let offset = unsafe { NonZeroU16::new_unchecked(offset.saturating_add(1) as u16) };
Expand All @@ -299,14 +328,14 @@ impl PortOffset {

/// Creates a new outgoing port offset.
#[inline(always)]
pub fn new_outgoing(offset: usize) -> Self {
pub const fn new_outgoing(offset: usize) -> Self {
assert!(offset <= u16::MAX as usize);
PortOffset::Outgoing(offset as u16)
}

/// Returns the direction of the port.
#[inline(always)]
pub fn direction(self) -> Direction {
pub const fn direction(self) -> Direction {
match self {
PortOffset::Incoming(_) => Direction::Incoming,
PortOffset::Outgoing(_) => Direction::Outgoing,
Expand All @@ -315,7 +344,7 @@ impl PortOffset {

/// Returns the offset of the port.
#[inline(always)]
pub fn index(self) -> usize {
pub const fn index(self) -> usize {
match self {
PortOffset::Incoming(offset) => (offset.get() - 1) as usize,
PortOffset::Outgoing(offset) => offset as usize,
Expand All @@ -337,3 +366,7 @@ impl std::fmt::Debug for PortOffset {
}
}
}

impl ConstDefault for PortOffset {
const DEFAULT: Self = PortOffset::new_outgoing(0);
}
101 changes: 100 additions & 1 deletion src/secondary.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
//! Trait definition for secondary maps from keys to values with default elements.

use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
use std::{hash::Hash, iter::FusedIterator};

use bitvec::{
slice::{BitSlice, IterOnes},
vec::BitVec,
};
use const_default::ConstDefault;

/// A map from keys to values with default elements.
///
Expand Down Expand Up @@ -333,3 +334,101 @@ where
self.iter.size_hint()
}
}

impl<K, V> SecondaryMap<K, V> for HashMap<K, V>
where
K: Hash + Eq + Clone,
V: Eq + Clone,
for<'a> &'a V: ConstDefault,
{
type Iter<'a> = HashMapIter<'a, K, V> where Self: 'a, K: 'a;

#[inline]
fn new() -> Self {
HashMap::new()
}

#[inline]
fn with_capacity(capacity: usize) -> Self {
HashMap::with_capacity(capacity)
}

#[inline]
fn default_value(&self) -> V {
<&V>::DEFAULT.clone()
}

#[inline]
fn ensure_capacity(&mut self, capacity: usize) {
HashMap::reserve(self, capacity.saturating_sub(self.capacity()));
}

#[inline]
fn resize(&mut self, _new_len: usize) {}

#[inline]
fn capacity(&self) -> usize {
HashMap::capacity(self)
}

#[inline]
fn get(&self, key: K) -> &V {
HashMap::get(self, &key).unwrap_or(<&V>::DEFAULT)
}

#[inline]
fn set(&mut self, key: K, val: V) {
match &val == <&V>::DEFAULT {
true => HashMap::insert(self, key, val),
false => HashMap::remove(self, &key),
};
}

#[inline]
fn take(&mut self, key: K) -> V {
HashMap::remove(self, &key).unwrap_or(self.default_value())
}

#[inline]
fn iter<'a>(&'a self) -> Self::Iter<'a>
where
K: 'a,
{
HashMapIter {
iter: HashMap::iter(self),
}
}
}

/// Iterator over non-default entries of a bit vector secondary map.
#[derive(Debug, Clone)]
pub struct HashMapIter<'a, K, V> {
iter: std::collections::hash_map::Iter<'a, K, V>,
}

impl<'a, K, V> Iterator for HashMapIter<'a, K, V>
where
K: Clone,
{
type Item = (K, &'a V);

#[inline]
fn next(&mut self) -> Option<Self::Item> {
self.iter.next().map(|(k, v)| (k.clone(), v))
}

#[inline]
fn nth(&mut self, n: usize) -> Option<Self::Item> {
self.iter.nth(n).map(|(k, v)| (k.clone(), v))
}

#[inline]
fn count(self) -> usize {
self.iter.count()
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
self.iter.size_hint()
}
}