Skip to content

Commit

Permalink
Re-enable arbitrary input pair types
Browse files Browse the repository at this point in the history
  • Loading branch information
sleeepyjack committed Nov 21, 2023
1 parent f3b82b2 commit f5bf6d8
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 36 deletions.
52 changes: 43 additions & 9 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,8 @@ class open_addressing_ref_impl {
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");

auto const key = this->extract_key(value);
auto const val = this->heterogeneous_value(value);
auto const key = this->extract_key(val);
auto probing_iter = probing_scheme_(key, storage_ref_.window_extent());

while (true) {
Expand All @@ -289,7 +290,7 @@ class open_addressing_ref_impl {
auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content);
switch (attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index,
slot_content,
value)) {
val)) {
case insert_result::CONTINUE: continue;
case insert_result::SUCCESS: return true;
case insert_result::DUPLICATE: return false;
Expand All @@ -314,7 +315,8 @@ class open_addressing_ref_impl {
__device__ bool insert(cooperative_groups::thread_block_tile<cg_size> const& group,
Value const& value) noexcept
{
auto const key = this->extract_key(value);
auto const val = this->heterogeneous_value(value);
auto const key = this->extract_key(val);
auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent());

while (true) {
Expand Down Expand Up @@ -352,7 +354,7 @@ class open_addressing_ref_impl {
(group.thread_rank() == src_lane)
? attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index,
window_slots[intra_window_index],
value)
val)
: insert_result::CONTINUE;

switch (group.shfl(status, src_lane)) {
Expand Down Expand Up @@ -392,7 +394,8 @@ class open_addressing_ref_impl {
"insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs.");
#endif

auto const key = this->extract_key(value);
auto const val = this->heterogeneous_value(value);
auto const key = this->extract_key(val);
auto probing_iter = probing_scheme_(key, storage_ref_.window_extent());

while (true) {
Expand All @@ -413,7 +416,7 @@ class open_addressing_ref_impl {
if (eq_res == detail::equal_result::EMPTY or
cuco::detail::bitwise_compare(this->extract_key(window_slots[i]),
this->erased_key_sentinel())) {
switch (this->attempt_insert_stable(window_ptr + i, window_slots[i], value)) {
switch (this->attempt_insert_stable(window_ptr + i, window_slots[i], val)) {
case insert_result::SUCCESS: {
if constexpr (has_payload) {
// wait to ensure that the write to the value part also took place
Expand Down Expand Up @@ -463,7 +466,8 @@ class open_addressing_ref_impl {
"insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs.");
#endif

auto const key = this->extract_key(value);
auto const val = this->heterogeneous_value(value);
auto const key = this->extract_key(val);
auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent());

while (true) {
Expand Down Expand Up @@ -514,7 +518,7 @@ class open_addressing_ref_impl {
auto const res = group.shfl(reinterpret_cast<intptr_t>(slot_ptr), src_lane);
auto const status = [&, target_idx = intra_window_index]() {
if (group.thread_rank() != src_lane) { return insert_result::CONTINUE; }
return this->attempt_insert_stable(slot_ptr, window_slots[target_idx], value);
return this->attempt_insert_stable(slot_ptr, window_slots[target_idx], val);
}();

switch (group.shfl(status, src_lane)) {
Expand Down Expand Up @@ -890,7 +894,6 @@ class open_addressing_ref_impl {
}
}

private:
/**
* @brief Extracts the key from a given value type.
*
Expand Down Expand Up @@ -948,6 +951,37 @@ class open_addressing_ref_impl {
}
}

/**
 * @brief Normalizes a pair-like input so that its payload has the container's mapped type while
 * the key part keeps whatever (possibly heterogeneous) type the caller supplied.
 *
 * If the container has no payload, or the input already is a `value_type`, the input is returned
 * as-is after `thrust::raw_reference_cast`. Otherwise a `cuco::pair` is built whose first element
 * is the input's first element and whose second element is the input's second element
 * `static_cast` to the type of `empty_slot_sentinel_.second`.
 *
 * @tparam T Input type which is convertible to 'value_type'
 *
 * @param value The input value
 *
 * @return The converted object
 */
template <typename T>
[[nodiscard]] __host__ __device__ constexpr auto heterogeneous_value(
  T const& value) const noexcept
{
  if constexpr (not this->has_payload or cuda::std::is_same_v<T, value_type>) {
    // Nothing to convert: no payload to cast, or the input is already the native slot type
    return thrust::raw_reference_cast(value);
  } else {
    // Payload type as declared for the second member of the empty-slot sentinel
    using mapped_type = decltype(this->empty_slot_sentinel_.second);

    if constexpr (cuco::detail::is_cuda_std_pair_like<T>::value) {
      // Input supports `cuda::std::get`
      return cuco::pair{cuda::std::get<0>(value),
                        static_cast<mapped_type>(cuda::std::get<1>(value))};
    } else if constexpr (cuco::detail::is_thrust_pair_like<T>::value) {
      // Input supports `thrust::get`
      return cuco::pair{thrust::get<0>(value), static_cast<mapped_type>(thrust::get<1>(value))};
    } else {
      // Last resort: rely on public `.first` / `.second` members being present
      return cuco::pair{thrust::raw_reference_cast(value.first),
                        static_cast<mapped_type>(value.second)};
    }
  }
}

/**
* @brief Gets the sentinel used to represent an erased slot.
*
Expand Down
55 changes: 28 additions & 27 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -183,14 +183,14 @@ class operator_impl<
/**
* @brief Inserts an element.
*
* @tparam ProbeKey Input key type which is convertible to 'key_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param value The element to insert
*
* @return True if the given element is successfully inserted
*/
template <typename ProbeKey>
__device__ bool insert(cuco::pair<ProbeKey, mapped_type> const& value) noexcept
template <typename Value>
__device__ bool insert(Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert(value);
Expand All @@ -199,16 +199,16 @@ class operator_impl<
/**
* @brief Inserts an element.
*
* @tparam ProbeKey Input key type which is convertible to 'key_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param group The Cooperative Group used to perform group insert
* @param value The element to insert
*
* @return True if the given element is successfully inserted
*/
template <typename ProbeKey>
template <typename Value>
__device__ bool insert(cooperative_groups::thread_block_tile<cg_size> const& group,
cuco::pair<ProbeKey, mapped_type> const& value) noexcept
Value const& value) noexcept
{
auto& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert(group, value);
Expand Down Expand Up @@ -242,17 +242,19 @@ class operator_impl<
* @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v`
* to the mapped_type corresponding to the key `k`.
*
* @tparam ProbeKey Input key type which is convertible to 'key_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param value The element to insert
*/
template <typename ProbeKey>
__device__ void insert_or_assign(cuco::pair<ProbeKey, mapped_type> const& value) noexcept
template <typename Value>
__device__ void insert_or_assign(Value const& value) noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");

ref_type& ref_ = static_cast<ref_type&>(*this);
auto const key = value.first;
ref_type& ref_ = static_cast<ref_type&>(*this);

auto const val = ref_.impl_.heterogeneous_value(value);
auto const key = ref_.impl_.extract_key(val);
auto& probing_scheme = ref_.impl_.probing_scheme();
auto storage_ref = ref_.impl_.storage_ref();
auto probing_iter = probing_scheme(key, storage_ref.window_extent());
Expand All @@ -268,14 +270,14 @@ class operator_impl<
auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content);
ref_.impl_.atomic_store(
&((storage_ref.data() + *probing_iter)->data() + intra_window_index)->second,
value.second);
val.second);
return;
}
if (eq_res == detail::equal_result::EMPTY or
cuco::detail::bitwise_compare(slot_content.first, ref_.impl_.erased_key_sentinel())) {
auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content);
if (attempt_insert_or_assign(
(storage_ref.data() + *probing_iter)->data() + intra_window_index, value)) {
(storage_ref.data() + *probing_iter)->data() + intra_window_index, val)) {
return;
}
}
Expand All @@ -290,18 +292,19 @@ class operator_impl<
* @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v`
* to the mapped_type corresponding to the key `k`.
*
* @tparam ProbeKey Input key type which is convertible to 'key_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param group The Cooperative Group used to perform group insert
* @param value The element to insert
*/
template <typename ProbeKey>
template <typename Value>
__device__ void insert_or_assign(cooperative_groups::thread_block_tile<cg_size> const& group,
cuco::pair<ProbeKey, mapped_type> const& value) noexcept
Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);

auto const key = value.first;
auto const val = ref_.impl_.heterogeneous_value(value);
auto const key = ref_.impl_.extract_key(val);
auto& probing_scheme = ref_.impl_.probing_scheme();
auto storage_ref = ref_.impl_.storage_ref();
auto probing_iter = probing_scheme(group, key, storage_ref.window_extent());
Expand Down Expand Up @@ -336,7 +339,7 @@ class operator_impl<
if (group.thread_rank() == src_lane) {
ref_.impl_.atomic_store(
&((storage_ref.data() + *probing_iter)->data() + intra_window_index)->second,
value.second);
val.second);
}
group.sync();
return;
Expand All @@ -349,7 +352,7 @@ class operator_impl<
auto const status =
(group.thread_rank() == src_lane)
? attempt_insert_or_assign(
(storage_ref.data() + *probing_iter)->data() + intra_window_index, value)
(storage_ref.data() + *probing_iter)->data() + intra_window_index, val)
: false;

// Exit if inserted or assigned
Expand Down Expand Up @@ -452,16 +455,15 @@ class operator_impl<
* element that prevented the insertion) and a `bool` denoting whether the insertion took place or
* not.
*
* @tparam ProbeKey Input key type which is convertible to 'key_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param value The element to insert
*
* @return a pair consisting of an iterator to the element and a bool indicating whether the
* insertion is successful or not.
*/
template <typename ProbeKey>
__device__ thrust::pair<iterator, bool> insert_and_find(
cuco::pair<ProbeKey, mapped_type> const& value) noexcept
template <typename Value>
__device__ thrust::pair<iterator, bool> insert_and_find(Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert_and_find(value);
Expand All @@ -474,18 +476,17 @@ class operator_impl<
* element that prevented the insertion) and a `bool` denoting whether the insertion took place or
* not.
*
* @tparam ProbeKey Input key type which is convertible to 'key_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param group The Cooperative Group used to perform group insert_and_find
* @param value The element to insert
*
* @return a pair consisting of an iterator to the element and a bool indicating whether the
* insertion is successful or not.
*/
template <typename ProbeKey>
template <typename Value>
__device__ thrust::pair<iterator, bool> insert_and_find(
cooperative_groups::thread_block_tile<cg_size> const& group,
cuco::pair<ProbeKey, mapped_type> const& value) noexcept
cooperative_groups::thread_block_tile<cg_size> const& group, Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert_and_find(group, value);
Expand Down

0 comments on commit f5bf6d8

Please sign in to comment.