Skip to content

Commit

Permalink
Minor improvements for arrow_filter_policy (#654)
Browse files Browse the repository at this point in the history
  • Loading branch information
mhaseeb123 authored Dec 10, 2024
1 parent 95f338d commit 096346b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 15 deletions.
2 changes: 1 addition & 1 deletion include/cuco/bloom_filter_policies.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace cuco {
* By default, cuco::xxhash_64 hasher will be used.
*
*/
template <class Key, class XXHash64 = cuco::xxhash_64<Key>>
template <typename Key, template <typename> class XXHash64 = cuco::xxhash_64>
using arrow_filter_policy = detail::arrow_filter_policy<Key, XXHash64>;

/**
Expand Down
23 changes: 9 additions & 14 deletions include/cuco/detail/bloom_filter/arrow_filter_policy.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ namespace cuco::detail {
* void bulk_insert_and_eval_arrow_policy_bloom_filter(device_vector<KeyType> const& positive_keys,
* device_vector<KeyType> const& negative_keys)
* {
* using xxhash_64 = cuco::xxhash_64<KeyType>;
* using policy_type = cuco::arrow_filter_policy<KeyType, xxhash_64>;
* using policy_type = cuco::arrow_filter_policy<KeyType, cuco::xxhash_64>;
*
* // Warn or throw if the number of filter blocks is greater than maximum used by Arrow policy.
* static_assert(NUM_FILTER_BLOCKS <= policy_type::max_filter_blocks, "NUM_FILTER_BLOCKS must be
Expand Down Expand Up @@ -81,14 +80,13 @@ namespace cuco::detail {
* @tparam Key The type of the values to generate a fingerprint for.
* @tparam XXHash64 64-bit XXHash hasher implementation for fingerprint generation.
*/
template <class Key, class XXHash64>
template <class Key, template <typename> class XXHash64>
class arrow_filter_policy {
public:
using hasher = XXHash64; ///< 64-bit XXHash hasher for Arrow bloom filter policy
using word_type = std::uint32_t; ///< uint32_t for Arrow bloom filter policy
using hash_argument_type = typename hasher::argument_type; ///< Hash function input type
using hash_result_type = decltype(std::declval<hasher>()(
std::declval<hash_argument_type>())); ///< hash function output type
using hasher = XXHash64<Key>; ///< 64-bit XXHash hasher for Arrow bloom filter policy
using word_type = std::uint32_t; ///< uint32_t for Arrow bloom filter policy
using key_type = Key; ///< Hash function input type
using hash_value_type = std::uint64_t; ///< hash function output type

static constexpr uint32_t bits_set_per_block = 8; ///< hardcoded bits set per Arrow filter block
static constexpr uint32_t words_per_block = 8; ///< hardcoded words per Arrow filter block
Expand Down Expand Up @@ -135,10 +133,7 @@ class arrow_filter_policy {
*
* @return The hash value of the key
*/
__device__ constexpr hash_result_type hash(hash_argument_type const& key) const
{
return hash_(key);
}
__device__ constexpr hash_value_type hash(key_type const& key) const { return hash_(key); }

/**
* @brief Determines the filter block a key is added into.
Expand All @@ -155,7 +150,7 @@ class arrow_filter_policy {
* @return The block index for the given key's hash value
*/
template <class Extent>
__device__ constexpr auto block_index(hash_result_type hash, Extent num_blocks) const
__device__ constexpr auto block_index(hash_value_type hash, Extent num_blocks) const
{
constexpr auto hash_bits = cuda::std::numeric_limits<word_type>::digits;
// TODO: assert if num_blocks > max_filter_blocks
Expand All @@ -173,7 +168,7 @@ class arrow_filter_policy {
*
* @return The bit pattern for the word/segment in the filter block
*/
__device__ constexpr word_type word_pattern(hash_result_type hash, std::uint32_t word_index) const
__device__ constexpr word_type word_pattern(hash_value_type hash, std::uint32_t word_index) const
{
// SALT array to calculate bit indexes for the current word
auto constexpr salt = SALT();
Expand Down

0 comments on commit 096346b

Please sign in to comment.