Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove drain-on-drop behavior from DrainFilter #374

Merged
merged 2 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 30 additions & 57 deletions src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -972,15 +972,12 @@ impl<K, V, S, A: Allocator + Clone> HashMap<K, V, S, A> {
/// In other words, move all pairs `(k, v)` such that `f(&k, &mut v)` returns `true` out
/// into another iterator.
///
/// Note that `drain_filter` lets you mutate every value in the filter closure, regardless of
/// Note that `extract_if` lets you mutate every value in the filter closure, regardless of
/// whether you choose to keep or remove it.
///
/// When the returned DrainedFilter is dropped, any remaining elements that satisfy
/// the predicate are dropped from the table.
///
/// It is unspecified how many more elements will be subjected to the closure
/// if a panic occurs in the closure, or a panic occurs while dropping an element,
/// or if the `DrainFilter` value is leaked.
/// If the returned `ExtractIf` is not exhausted, e.g. because it is dropped without iterating
/// or the iteration short-circuits, then the remaining elements will be retained.
/// Use [`retain()`] with a negated predicate if you do not need the returned iterator.
///
/// Keeps the allocated memory for reuse.
///
Expand All @@ -991,7 +988,7 @@ impl<K, V, S, A: Allocator + Clone> HashMap<K, V, S, A> {
///
/// let mut map: HashMap<i32, i32> = (0..8).map(|x| (x, x)).collect();
///
/// let drained: HashMap<i32, i32> = map.drain_filter(|k, _v| k % 2 == 0).collect();
/// let drained: HashMap<i32, i32> = map.extract_if(|k, _v| k % 2 == 0).collect();
///
/// let mut evens = drained.keys().cloned().collect::<Vec<_>>();
/// let mut odds = map.keys().cloned().collect::<Vec<_>>();
Expand All @@ -1004,21 +1001,20 @@ impl<K, V, S, A: Allocator + Clone> HashMap<K, V, S, A> {
/// let mut map: HashMap<i32, i32> = (0..8).map(|x| (x, x)).collect();
///
/// { // Iterator is dropped without being consumed.
/// let d = map.drain_filter(|k, _v| k % 2 != 0);
/// let d = map.extract_if(|k, _v| k % 2 != 0);
/// }
///
/// // But the map lens have been reduced by half
/// // even if we do not use DrainFilter iterator.
/// assert_eq!(map.len(), 4);
/// // ExtractIf was not exhausted, therefore no elements were drained.
/// assert_eq!(map.len(), 8);
/// ```
#[cfg_attr(feature = "inline-more", inline)]
pub fn drain_filter<F>(&mut self, f: F) -> DrainFilter<'_, K, V, F, A>
pub fn extract_if<F>(&mut self, f: F) -> ExtractIf<'_, K, V, F, A>
where
F: FnMut(&K, &mut V) -> bool,
{
DrainFilter {
ExtractIf {
f,
inner: DrainFilterInner {
inner: ExtractIfInner {
iter: unsafe { self.table.iter() },
table: &mut self.table,
},
Expand Down Expand Up @@ -2732,10 +2728,10 @@ impl<K, V, A: Allocator + Clone> Drain<'_, K, V, A> {
/// A draining iterator over entries of a `HashMap` which don't satisfy the predicate
/// `f(&k, &mut v)` in arbitrary order. The iterator element type is `(K, V)`.
///
/// This `struct` is created by the [`drain_filter`] method on [`HashMap`]. See its
/// This `struct` is created by the [`extract_if`] method on [`HashMap`]. See its
/// documentation for more.
///
/// [`drain_filter`]: struct.HashMap.html#method.drain_filter
/// [`extract_if`]: struct.HashMap.html#method.extract_if
/// [`HashMap`]: struct.HashMap.html
///
/// # Examples
Expand All @@ -2745,54 +2741,31 @@ impl<K, V, A: Allocator + Clone> Drain<'_, K, V, A> {
///
/// let mut map: HashMap<i32, &str> = [(1, "a"), (2, "b"), (3, "c")].into();
///
/// let mut drain_filter = map.drain_filter(|k, _v| k % 2 != 0);
/// let mut vec = vec![drain_filter.next(), drain_filter.next()];
/// let mut extract_if = map.extract_if(|k, _v| k % 2 != 0);
/// let mut vec = vec![extract_if.next(), extract_if.next()];
///
/// // The `DrainFilter` iterator produces items in arbitrary order, so the
/// // The `ExtractIf` iterator produces items in arbitrary order, so the
/// // items must be sorted to test them against a sorted array.
/// vec.sort_unstable();
/// assert_eq!(vec, [Some((1, "a")),Some((3, "c"))]);
///
/// // It is fused iterator
/// assert_eq!(drain_filter.next(), None);
/// assert_eq!(drain_filter.next(), None);
/// drop(drain_filter);
/// assert_eq!(extract_if.next(), None);
/// assert_eq!(extract_if.next(), None);
/// drop(extract_if);
///
/// assert_eq!(map.len(), 1);
/// ```
pub struct DrainFilter<'a, K, V, F, A: Allocator + Clone = Global>
#[must_use = "Iterators are lazy unless consumed"]
pub struct ExtractIf<'a, K, V, F, A: Allocator + Clone = Global>
where
F: FnMut(&K, &mut V) -> bool,
{
f: F,
inner: DrainFilterInner<'a, K, V, A>,
}

impl<'a, K, V, F, A> Drop for DrainFilter<'a, K, V, F, A>
where
F: FnMut(&K, &mut V) -> bool,
A: Allocator + Clone,
{
#[cfg_attr(feature = "inline-more", inline)]
fn drop(&mut self) {
while let Some(item) = self.next() {
let guard = ConsumeAllOnDrop(self);
drop(item);
mem::forget(guard);
}
}
}

pub(super) struct ConsumeAllOnDrop<'a, T: Iterator>(pub &'a mut T);

impl<T: Iterator> Drop for ConsumeAllOnDrop<'_, T> {
#[cfg_attr(feature = "inline-more", inline)]
fn drop(&mut self) {
self.0.for_each(drop);
}
inner: ExtractIfInner<'a, K, V, A>,
}

impl<K, V, F, A> Iterator for DrainFilter<'_, K, V, F, A>
impl<K, V, F, A> Iterator for ExtractIf<'_, K, V, F, A>
where
F: FnMut(&K, &mut V) -> bool,
A: Allocator + Clone,
Expand All @@ -2810,15 +2783,15 @@ where
}
}

impl<K, V, F> FusedIterator for DrainFilter<'_, K, V, F> where F: FnMut(&K, &mut V) -> bool {}
impl<K, V, F> FusedIterator for ExtractIf<'_, K, V, F> where F: FnMut(&K, &mut V) -> bool {}

/// Portions of `DrainFilter` shared with `set::DrainFilter`
pub(super) struct DrainFilterInner<'a, K, V, A: Allocator + Clone> {
/// Portions of `ExtractIf` shared with `set::ExtractIf`
pub(super) struct ExtractIfInner<'a, K, V, A: Allocator + Clone> {
pub iter: RawIter<(K, V)>,
pub table: &'a mut RawTable<(K, V), A>,
}

impl<K, V, A: Allocator + Clone> DrainFilterInner<'_, K, V, A> {
impl<K, V, A: Allocator + Clone> ExtractIfInner<'_, K, V, A> {
#[cfg_attr(feature = "inline-more", inline)]
pub(super) fn next<F>(&mut self, f: &mut F) -> Option<(K, V)>
where
Expand Down Expand Up @@ -8169,18 +8142,18 @@ mod test_map {
}

#[test]
fn test_drain_filter() {
fn test_extract_if() {
{
let mut map: HashMap<i32, i32> = (0..8).map(|x| (x, x * 10)).collect();
let drained = map.drain_filter(|&k, _| k % 2 == 0);
let drained = map.extract_if(|&k, _| k % 2 == 0);
let mut out = drained.collect::<Vec<_>>();
out.sort_unstable();
assert_eq!(vec![(0, 0), (2, 20), (4, 40), (6, 60)], out);
assert_eq!(map.len(), 4);
}
{
let mut map: HashMap<i32, i32> = (0..8).map(|x| (x, x * 10)).collect();
drop(map.drain_filter(|&k, _| k % 2 == 0));
map.extract_if(|&k, _| k % 2 == 0).for_each(drop);
assert_eq!(map.len(), 4);
}
}
Expand Down
52 changes: 18 additions & 34 deletions src/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ use alloc::borrow::ToOwned;
use core::fmt;
use core::hash::{BuildHasher, Hash};
use core::iter::{Chain, FromIterator, FusedIterator};
use core::mem;
use core::ops::{BitAnd, BitOr, BitXor, Sub};

use super::map::{self, ConsumeAllOnDrop, DefaultHashBuilder, DrainFilterInner, HashMap, Keys};
use super::map::{self, DefaultHashBuilder, ExtractIfInner, HashMap, Keys};
use crate::raw::{Allocator, Global};

// Future Optimization (FIXME!)
Expand Down Expand Up @@ -380,16 +379,17 @@ impl<T, S, A: Allocator + Clone> HashSet<T, S, A> {
/// In other words, move all elements `e` such that `f(&e)` returns `true` out
/// into another iterator.
///
/// When the returned DrainedFilter is dropped, any remaining elements that satisfy
/// the predicate are dropped from the set.
/// If the returned `ExtractIf` is not exhausted, e.g. because it is dropped without iterating
/// or the iteration short-circuits, then the remaining elements will be retained.
/// Use [`retain()`] with a negated predicate if you do not need the returned iterator.
///
/// # Examples
///
/// ```
/// use hashbrown::HashSet;
///
/// let mut set: HashSet<i32> = (0..8).collect();
/// let drained: HashSet<i32> = set.drain_filter(|v| v % 2 == 0).collect();
/// let drained: HashSet<i32> = set.extract_if(|v| v % 2 == 0).collect();
///
/// let mut evens = drained.into_iter().collect::<Vec<_>>();
/// let mut odds = set.into_iter().collect::<Vec<_>>();
Expand All @@ -400,13 +400,13 @@ impl<T, S, A: Allocator + Clone> HashSet<T, S, A> {
/// assert_eq!(odds, vec![1, 3, 5, 7]);
/// ```
#[cfg_attr(feature = "inline-more", inline)]
pub fn drain_filter<F>(&mut self, f: F) -> DrainFilter<'_, T, F, A>
pub fn extract_if<F>(&mut self, f: F) -> ExtractIf<'_, T, F, A>
where
F: FnMut(&T) -> bool,
{
DrainFilter {
ExtractIf {
f,
inner: DrainFilterInner {
inner: ExtractIfInner {
iter: unsafe { self.map.table.iter() },
table: &mut self.map.table,
},
Expand Down Expand Up @@ -1567,17 +1567,18 @@ pub struct Drain<'a, K, A: Allocator + Clone = Global> {

/// A draining iterator over entries of a `HashSet` which don't satisfy the predicate `f`.
///
/// This `struct` is created by the [`drain_filter`] method on [`HashSet`]. See its
/// This `struct` is created by the [`extract_if`] method on [`HashSet`]. See its
/// documentation for more.
///
/// [`drain_filter`]: struct.HashSet.html#method.drain_filter
/// [`extract_if`]: struct.HashSet.html#method.extract_if
/// [`HashSet`]: struct.HashSet.html
pub struct DrainFilter<'a, K, F, A: Allocator + Clone = Global>
#[must_use = "Iterators are lazy unless consumed"]
pub struct ExtractIf<'a, K, F, A: Allocator + Clone = Global>
where
F: FnMut(&K) -> bool,
{
f: F,
inner: DrainFilterInner<'a, K, (), A>,
inner: ExtractIfInner<'a, K, (), A>,
}

/// A lazy iterator producing elements in the intersection of `HashSet`s.
Expand Down Expand Up @@ -1768,21 +1769,7 @@ impl<K: fmt::Debug, A: Allocator + Clone> fmt::Debug for Drain<'_, K, A> {
}
}

impl<'a, K, F, A: Allocator + Clone> Drop for DrainFilter<'a, K, F, A>
where
F: FnMut(&K) -> bool,
{
#[cfg_attr(feature = "inline-more", inline)]
fn drop(&mut self) {
while let Some(item) = self.next() {
let guard = ConsumeAllOnDrop(self);
drop(item);
mem::forget(guard);
}
}
}

impl<K, F, A: Allocator + Clone> Iterator for DrainFilter<'_, K, F, A>
impl<K, F, A: Allocator + Clone> Iterator for ExtractIf<'_, K, F, A>
where
F: FnMut(&K) -> bool,
{
Expand All @@ -1801,10 +1788,7 @@ where
}
}

impl<K, F, A: Allocator + Clone> FusedIterator for DrainFilter<'_, K, F, A> where
F: FnMut(&K) -> bool
{
}
impl<K, F, A: Allocator + Clone> FusedIterator for ExtractIf<'_, K, F, A> where F: FnMut(&K) -> bool {}

impl<T, S, A: Allocator + Clone> Clone for Intersection<'_, T, S, A> {
#[cfg_attr(feature = "inline-more", inline)]
Expand Down Expand Up @@ -2850,18 +2834,18 @@ mod test_set {
}

#[test]
fn test_drain_filter() {
fn test_extract_if() {
{
let mut set: HashSet<i32> = (0..8).collect();
let drained = set.drain_filter(|&k| k % 2 == 0);
let drained = set.extract_if(|&k| k % 2 == 0);
let mut out = drained.collect::<Vec<_>>();
out.sort_unstable();
assert_eq!(vec![0, 2, 4, 6], out);
assert_eq!(set.len(), 4);
}
{
let mut set: HashSet<i32> = (0..8).collect();
drop(set.drain_filter(|&k| k % 2 == 0));
set.extract_if(|&k| k % 2 == 0).for_each(drop);
assert_eq!(set.len(), 4, "Removes non-matching items on drop");
}
}
Expand Down