Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
sleeepyjack committed Jan 16, 2025
1 parent e79787b commit e414899
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 26 deletions.
13 changes: 8 additions & 5 deletions include/cuco/detail/equal_wrapper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace detail {
* @brief Enum of equality comparison results
*/
// ENUM VALUE MATTERS, DO NOT CHANGE
enum class equal_result : int32_t { UNEQUAL = 0, EQUAL = 1, EMPTY = 2, AVAILABLE = 3 };
enum class equal_result : int32_t { UNEQUAL = 0, EQUAL = 1, EMPTY = 2, ERASED = 3 };

/**
 * @brief Compile-time switch telling the equality wrapper whether the comparison
 * is performed on behalf of an insert operation
 */
enum class is_insert : bool {
  YES = false,  ///< Insert path: the slot key is checked against the empty and erased sentinels
  NO  = true    ///< Non-insert path: the slot key is only checked against the empty sentinel
};

Expand Down Expand Up @@ -97,10 +97,13 @@ struct equal_wrapper {
__device__ constexpr equal_result operator()(LHS const& lhs, RHS const& rhs) const noexcept
{
if constexpr (IsInsert == is_insert::YES) {
return (cuco::detail::bitwise_compare(rhs, empty_sentinel_) or
cuco::detail::bitwise_compare(rhs, erased_sentinel_))
? equal_result::AVAILABLE
: this->equal_to(lhs, rhs);
if (cuco::detail::bitwise_compare(rhs, empty_sentinel_)) {
return equal_result::EMPTY;
} else if (cuco::detail::bitwise_compare(rhs, erased_sentinel_)) {
return equal_result::ERASED;
} else {
return this->equal_to(lhs, rhs);
}
} else {
return cuco::detail::bitwise_compare(rhs, empty_sentinel_) ? equal_result::EMPTY
: this->equal_to(lhs, rhs);
Expand Down
56 changes: 39 additions & 17 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,12 @@ class open_addressing_ref_impl {
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

[[maybe_unused]] auto probing_iter_copy = probing_iter;
[[maybe_unused]] bool erased = false;
[[maybe_unused]] bool empty_after_erased = false;

while (true) {
[[maybe_unused]] continue_after_erased:
auto const bucket_slots = storage_ref_[*probing_iter];

for (auto& slot_content : bucket_slots) {
Expand All @@ -393,21 +398,34 @@ class open_addressing_ref_impl {
if constexpr (not allows_duplicates) {
// If the key is already in the container, return false
if (eq_res == detail::equal_result::EQUAL) { return false; }
if (eq_res == detail::equal_result::ERASED and not erased and not empty_after_erased) {
erased = true;
probing_iter_copy = probing_iter;
}
if (eq_res == detail::equal_result::EMPTY and erased and not empty_after_erased) {
empty_after_erased = true;
probing_iter = probing_iter_copy;
goto continue_after_erased;
}
}
if (eq_res == detail::equal_result::AVAILABLE) {
auto const intra_bucket_index = thrust::distance(bucket_slots.begin(), &slot_content);
switch (attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_bucket_index,
slot_content,
val)) {
case insert_result::DUPLICATE: {
if constexpr (allows_duplicates) {
[[fallthrough]];
} else {
return false;

if (not erased or empty_after_erased) {
if ((eq_res == detail::equal_result::EMPTY) or (eq_res == detail::equal_result::ERASED)) {
auto const intra_bucket_index = thrust::distance(bucket_slots.begin(), &slot_content);
switch (
attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_bucket_index,
slot_content,
val)) {
case insert_result::DUPLICATE: {
if constexpr (allows_duplicates) {
[[fallthrough]];
} else {
return false;
}
}
case insert_result::CONTINUE: continue;
case insert_result::SUCCESS: return true;
}
case insert_result::CONTINUE: continue;
case insert_result::SUCCESS: return true;
}
}
}
Expand Down Expand Up @@ -442,8 +460,10 @@ class open_addressing_ref_impl {
for (auto i = 0; i < bucket_size; ++i) {
switch (
this->predicate_.operator()<is_insert::YES>(key, this->extract_key(bucket_slots[i]))) {
case detail::equal_result::AVAILABLE:
return bucket_probing_results{detail::equal_result::AVAILABLE, i};
case detail::equal_result::EMPTY:
return bucket_probing_results{detail::equal_result::EMPTY, i};
case detail::equal_result::ERASED:
return bucket_probing_results{detail::equal_result::ERASED, i};
case detail::equal_result::EQUAL: {
if constexpr (allows_duplicates) {
continue;
Expand All @@ -463,7 +483,8 @@ class open_addressing_ref_impl {
if (group.any(state == detail::equal_result::EQUAL)) { return false; }
}

auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE);
auto const group_contains_available = group.ballot((state == detail::equal_result::EMPTY) or
(state == detail::equal_result::ERASED));
if (group_contains_available) {
auto const src_lane = __ffs(group_contains_available) - 1;
auto const status =
Expand Down Expand Up @@ -538,7 +559,7 @@ class open_addressing_ref_impl {
}
return {iterator{&bucket_ptr[i]}, false};
}
if (eq_res == detail::equal_result::AVAILABLE) {
if ((eq_res == detail::equal_result::EMPTY) or (eq_res == detail::equal_result::ERASED)) {
switch (this->attempt_insert_stable(bucket_ptr + i, bucket_slots[i], val)) {
case insert_result::SUCCESS: {
if constexpr (has_payload) {
Expand Down Expand Up @@ -626,7 +647,8 @@ class open_addressing_ref_impl {
return {iterator{reinterpret_cast<value_type*>(res)}, false};
}

auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE);
auto const group_contains_available = group.ballot((state == detail::equal_result::EMPTY) or
(state == detail::equal_result::ERASED));
if (group_contains_available) {
auto const src_lane = __ffs(group_contains_available) - 1;
auto const res = group.shfl(reinterpret_cast<intptr_t>(slot_ptr), src_lane);
Expand Down
10 changes: 6 additions & 4 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ class operator_impl<
payload_ref.store(val.second, cuda::memory_order_relaxed);
return;
}
if (eq_res == detail::equal_result::AVAILABLE) {
if ((eq_res == detail::equal_result::EMPTY) or (eq_res == detail::equal_result::ERASED)) {
if (attempt_insert_or_assign(slot_ptr, val)) { return; }
}
}
Expand Down Expand Up @@ -571,7 +571,8 @@ class operator_impl<
return;
}

auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE);
auto const group_contains_available = group.ballot((state == detail::equal_result::EMPTY) or
(state == detail::equal_result::ERASED));
if (group_contains_available) {
auto const src_lane = __ffs(group_contains_available) - 1;
auto const status =
Expand Down Expand Up @@ -883,7 +884,7 @@ class operator_impl<
op(cuda::atomic_ref<T, Scope>{slot_ptr->second}, val.second);
return false;
}
if (eq_res == detail::equal_result::AVAILABLE) {
if ((eq_res == detail::equal_result::EMPTY) or (eq_res == detail::equal_result::ERASED)) {
switch (ref_.attempt_insert_or_apply<UseDirectApply>(slot_ptr, slot_content, val, op)) {
case insert_result::SUCCESS: return true;
case insert_result::DUPLICATE: {
Expand Down Expand Up @@ -970,7 +971,8 @@ class operator_impl<
return false;
}

auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE);
auto const group_contains_available = group.ballot((state == detail::equal_result::EMPTY) or
(state == detail::equal_result::ERASED));
if (group_contains_available) {
auto const src_lane = __ffs(group_contains_available) - 1;
auto const status = [&, target_idx = intra_bucket_index]() {
Expand Down
4 changes: 4 additions & 0 deletions tests/static_map/erase_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ void test_erase(Map& map, size_type num_keys)
REQUIRE(cuco::test::all_of(
d_keys_exist.begin() + num_keys / 2, d_keys_exist.end(), thrust::identity{}));

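// Regression test for #606: re-insert keys that are still present in the map
// (the other half was erased above); the size check after the final erase
// presumably fails if these inserts created duplicates.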
map.insert(pairs_begin + num_keys / 2, pairs_begin + num_keys);
// TODO: also cover insert_and_find, insert_or_assign, and insert_or_apply

map.erase(keys_begin + num_keys / 2, keys_begin + num_keys);
REQUIRE(map.size() == 0);
}
Expand Down

0 comments on commit e414899

Please sign in to comment.