Skip to content

Commit

Permalink
Use custom operators instead of casrt
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Dec 7, 2023
1 parent 5b965fe commit dd1d0b2
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 12 deletions.
18 changes: 17 additions & 1 deletion include/cuco/detail/extent/extent.inl
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,30 @@ template <typename SizeType, std::size_t N>
struct window_extent {
using value_type = SizeType; ///< Extent value type

__host__ __device__ explicit constexpr operator value_type() const noexcept { return N; }
__host__ __device__ constexpr value_type value() const noexcept { return N; }
__host__ __device__ explicit constexpr operator value_type() const noexcept { return value(); }

private:
__host__ __device__ explicit constexpr window_extent() noexcept {}
__host__ __device__ explicit constexpr window_extent(SizeType) noexcept {}

template <int32_t CGSize_, int32_t WindowSize_, typename SizeType_, std::size_t N_>
friend auto constexpr make_window_extent(extent<SizeType_, N_> ext);

template <typename Rhs>
friend __host__ __device__ constexpr value_type operator/(window_extent const& lhs,
Rhs rhs) noexcept
{
return lhs.value() / rhs;
}

template <typename Lhs>
friend __host__ __device__ constexpr value_type operator%(Lhs lhs,
window_extent const& rhs) noexcept
{
return lhs % rhs.value();
;
}
};

template <typename SizeType>
Expand Down
20 changes: 9 additions & 11 deletions include/cuco/detail/probing_scheme_impl.inl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class probing_iterator {
{
// TODO: step_size_ can be a build time constant (e.g. linear probing)
// Worth passing another extent type?
curr_index_ = (curr_index_ + step_size_) % static_cast<size_type>(upper_bound_);
curr_index_ = (curr_index_ + step_size_) % upper_bound_;
return *this;
}

Expand Down Expand Up @@ -114,8 +114,7 @@ __host__ __device__ constexpr auto linear_probing<CGSize, Hash>::operator()(
{
using size_type = typename Extent::value_type;
return detail::probing_iterator<Extent>{
cuco::detail::sanitize_hash<size_type>(hash_(probe_key) + g.thread_rank()) %
static_cast<size_type>(upper_bound),
cuco::detail::sanitize_hash<size_type>(hash_(probe_key) + g.thread_rank()) % upper_bound,
cg_size,
upper_bound};
}
Expand All @@ -134,10 +133,10 @@ __host__ __device__ constexpr auto double_hashing<CGSize, Hash1, Hash2>::operato
{
using size_type = typename Extent::value_type;
return detail::probing_iterator<Extent>{
cuco::detail::sanitize_hash<size_type>(hash1_(probe_key)) % static_cast<size_type>(upper_bound),
cuco::detail::sanitize_hash<size_type>(hash1_(probe_key)) % upper_bound,
max(size_type{1},
cuco::detail::sanitize_hash<size_type>(hash2_(probe_key)) %
static_cast<size_type>(upper_bound)), // step size in range [1, prime - 1]
upper_bound), // step size in range [1, prime - 1]
upper_bound};
}

Expand All @@ -150,12 +149,11 @@ __host__ __device__ constexpr auto double_hashing<CGSize, Hash1, Hash2>::operato
{
using size_type = typename Extent::value_type;
return detail::probing_iterator<Extent>{
cuco::detail::sanitize_hash<size_type>(hash1_(probe_key) + g.thread_rank()) %
static_cast<size_type>(upper_bound),
static_cast<size_type>((cuco::detail::sanitize_hash<size_type>(hash2_(probe_key)) %
(static_cast<size_type>(upper_bound) / cg_size - 1) +
1) *
cg_size),
cuco::detail::sanitize_hash<size_type>(hash1_(probe_key) + g.thread_rank()) % upper_bound,
static_cast<size_type>(
(cuco::detail::sanitize_hash<size_type>(hash2_(probe_key)) % (upper_bound / cg_size - 1) +
1) *
cg_size),
upper_bound}; // TODO use fast_int operator
}
} // namespace experimental
Expand Down
13 changes: 13 additions & 0 deletions include/cuco/utility/fast_int.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ struct fast_int {
evaluate_magic_numbers();
}

/**
* @brief Get the underlying integer value.
*
* @return Underlying value
*/
__host__ __device__ constexpr value_type value() const noexcept { return value_; }

/**
* @brief Explicit conversion operator to the underlying value type.
*
Expand Down Expand Up @@ -143,6 +150,12 @@ struct fast_int {
return rhs.mulhi(rhs.magic_, mul) >> rhs.shift_;
}

template <typename Rhs>
friend __host__ __device__ constexpr value_type operator/(fast_int const& lhs, Rhs rhs) noexcept
{
return lhs.value() / static_cast<value_type>(rhs);
}

template <typename Lhs>
friend __host__ __device__ constexpr value_type operator%(Lhs lhs, fast_int const& rhs) noexcept
{
Expand Down

0 comments on commit dd1d0b2

Please sign in to comment.