Skip to content

Commit

Permalink
Add allocator template parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
amukkara committed Sep 10, 2023
1 parent 7f3e3ac commit cdbabef
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 33 deletions.
40 changes: 21 additions & 19 deletions include/cuco/detail/trie/trie.inl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
namespace cuco {
namespace experimental {

template <typename LabelType>
constexpr trie<LabelType>::trie()
: num_keys_{0},
template <typename LabelType, class Allocator>
constexpr trie<LabelType, Allocator>::trie(Allocator const& allocator)
: allocator_{allocator},
num_keys_{0},
num_nodes_{1},
last_key_{},
num_levels_{2},
Expand All @@ -37,15 +38,15 @@ constexpr trie<LabelType>::trie()
levels_[0].labels_.push_back(root_label_);
}

template <typename LabelType>
trie<LabelType>::~trie() noexcept(false)
template <typename LabelType, class Allocator>
trie<LabelType, Allocator>::~trie() noexcept(false)
{
if (d_levels_ptr_) { CUCO_CUDA_TRY(cudaFree(d_levels_ptr_)); }
if (device_ptr_) { CUCO_CUDA_TRY(cudaFree(device_ptr_)); }
}

template <typename LabelType>
void trie<LabelType>::insert(const std::vector<LabelType>& key) noexcept
template <typename LabelType, class Allocator>
void trie<LabelType, Allocator>::insert(const std::vector<LabelType>& key) noexcept
{
if (key == last_key_) { return; } // Ignore duplicate keys
assert(num_keys_ == 0 || key > last_key_); // Keys are expected to be inserted in sorted order
Expand Down Expand Up @@ -95,8 +96,8 @@ void trie<LabelType>::insert(const std::vector<LabelType>& key) noexcept
last_key_ = key;
}

template <typename LabelType>
void trie<LabelType>::build() noexcept(false)
template <typename LabelType, class Allocator>
void trie<LabelType, Allocator>::build() noexcept(false)
{
// Perform build level-by-level for all levels, followed by a deep-copy from host to device
size_type offset = 0;
Expand Down Expand Up @@ -125,13 +126,13 @@ void trie<LabelType>::build() noexcept(false)
CUCO_CUDA_TRY(cudaMemcpy(device_ptr_, this, sizeof(trie<LabelType>), cudaMemcpyHostToDevice));
}

template <typename LabelType>
template <typename LabelType, class Allocator>
template <typename KeyIt, typename OffsetIt, typename OutputIt>
void trie<LabelType>::lookup(KeyIt keys_begin,
OffsetIt offsets_begin,
OffsetIt offsets_end,
OutputIt outputs_begin,
cuda_stream_ref stream) const noexcept
void trie<LabelType, Allocator>::lookup(KeyIt keys_begin,
OffsetIt offsets_begin,
OffsetIt offsets_end,
OutputIt outputs_begin,
cuda_stream_ref stream) const noexcept
{
auto num_keys = cuco::detail::distance(offsets_begin, offsets_end) - 1;
if (num_keys == 0) { return; }
Expand Down Expand Up @@ -159,16 +160,17 @@ __global__ void trie_lookup_kernel(
}
}

template <typename LabelType>
template <typename LabelType, class Allocator>
template <typename... Operators>
auto trie<LabelType>::ref(Operators...) const noexcept
auto trie<LabelType, Allocator>::ref(Operators...) const noexcept
{
static_assert(sizeof...(Operators), "No operators specified");
return ref_type<Operators...>{device_ptr_};
}

template <typename LabelType>
trie<LabelType>::level::level() : louds_{}, outs_{}, labels_{}, labels_ptr_{nullptr}, offset_{0}
template <typename LabelType, class Allocator>
trie<LabelType, Allocator>::level::level()
: louds_{}, outs_{}, labels_{}, labels_ptr_{nullptr}, offset_{0}
{
}

Expand Down
12 changes: 6 additions & 6 deletions include/cuco/detail/trie/trie_ref.inl
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
namespace cuco {
namespace experimental {

template <typename LabelType, typename... Operators>
__host__ __device__ constexpr trie_ref<LabelType, Operators...>::trie_ref(
const trie<LabelType>* trie) noexcept
template <typename LabelType, class Allocator, typename... Operators>
__host__ __device__ constexpr trie_ref<LabelType, Allocator, Operators...>::trie_ref(
const trie<LabelType, Allocator>* trie) noexcept
: trie_{trie}
{
}

namespace detail {

template <typename LabelType, typename... Operators>
class operator_impl<op::trie_lookup_tag, trie_ref<LabelType, Operators...>> {
using ref_type = trie_ref<LabelType, Operators...>;
template <typename LabelType, class Allocator, typename... Operators>
class operator_impl<op::trie_lookup_tag, trie_ref<LabelType, Allocator, Operators...>> {
using ref_type = trie_ref<LabelType, Allocator, Operators...>;
using size_type = size_t;

public:
Expand Down
14 changes: 11 additions & 3 deletions include/cuco/trie.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,17 @@ namespace experimental {
* @brief Trie class
*
* @tparam label_type type of individual characters of vector keys (eg. char or int)
* @tparam Allocator Type of allocator used for device storage
*/
template <typename LabelType>
template <typename LabelType, class Allocator = thrust::device_malloc_allocator<std::byte>>
class trie {
public:
constexpr trie();
/**
* @brief Constructs an empty trie
*
* @param allocator Allocator used for allocating device storage
*/
constexpr trie(Allocator const& allocator = Allocator{});
~trie() noexcept(false);

/**
Expand Down Expand Up @@ -88,6 +94,7 @@ class trie {
[[nodiscard]] auto ref(Operators... ops) const noexcept;

private:
Allocator allocator_; ///< Allocator
size_type num_keys_; ///< Number of keys inserted into trie
size_type num_nodes_; ///< Number of internal nodes
std::vector<LabelType> last_key_; ///< Last key inserted into trie
Expand All @@ -110,7 +117,8 @@ class trie {

template <typename... Operators>
using ref_type =
cuco::experimental::trie_ref<LabelType, Operators...>; ///< Non-owning container ref type
cuco::experimental::trie_ref<LabelType, Allocator, Operators...>; ///< Non-owning container ref
///< type

// Mixins need to be friends with this class in order to access private members
template <typename Op, typename Ref>
Expand Down
11 changes: 6 additions & 5 deletions include/cuco/trie_ref.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
namespace cuco {
namespace experimental {

template <typename LabelType>
template <typename LabelType, class Allocator>
class trie;

/**
Expand All @@ -15,18 +15,19 @@ class trie;
* @tparam LabelType Trie label type
* @tparam Operators Device operator options defined in `include/cuco/operator.hpp`
*/
template <typename LabelType, typename... Operators>
class trie_ref : public detail::operator_impl<Operators, trie_ref<LabelType, Operators...>>... {
template <typename LabelType, class Allocator, typename... Operators>
class trie_ref
: public detail::operator_impl<Operators, trie_ref<LabelType, Allocator, Operators...>>... {
public:
/**
* @brief Constructs trie_ref.
*
* @param trie Non-owning ref of trie
*/
__host__ __device__ explicit constexpr trie_ref(const trie<LabelType>* trie) noexcept;
__host__ __device__ explicit constexpr trie_ref(const trie<LabelType, Allocator>* trie) noexcept;

private:
const trie<LabelType>* trie_;
const trie<LabelType, Allocator>* trie_;

// Mixins need to be friends with this class in order to access private members
template <typename Op, typename Ref>
Expand Down

0 comments on commit cdbabef

Please sign in to comment.