[REVIEW] Improvements in feature sampling #4278

Merged

Commits (61 total, changes shown from 56)
25dbf45
Saving the changes made so far
vinaydes Sep 13, 2021
a7a6707
Added Thrust shuffle based colid generation
vinaydes Sep 24, 2021
b5658fb
Removing host copy of colids
vinaydes Sep 24, 2021
b05689b
Refactoring kernel arguments
vinaydes Sep 24, 2021
5dad925
Removing unused select function
vinaydes Sep 24, 2021
558dd46
Added print to distinguish code version
vinaydes Sep 27, 2021
9f3adbe
Timing measurement calls added
vinaydes Sep 30, 2021
7cc1676
Added count sort based sampling again for better comparison
vinaydes Oct 1, 2021
2a6f4a2
Minor changes
vinaydes Oct 1, 2021
669bb9c
Working adaptive sampling kernel
vinaydes Oct 11, 2021
9d2ffd0
Removing thrust and other unused code
vinaydes Oct 12, 2021
8b3704f
Making kiss99 static for now
vinaydes Oct 12, 2021
1c50ab7
Removing some more unused code
vinaydes Oct 12, 2021
b1c63bb
Formatting changes
vinaydes Oct 12, 2021
684c668
Fixing select kernel call format
vinaydes Oct 12, 2021
3490114
Merge branch 'branch-21.12' into enh-rf-better-feature-sampling
vinaydes Oct 12, 2021
3138301
Undo local build fix
vinaydes Oct 12, 2021
fbec82f
Merge branch 'branch-22.02' into enh-rf-better-feature-sampling
vinaydes Jan 20, 2022
7f82825
Changing the RAFT repo link
vinaydes Jan 20, 2022
ce44301
Correcting get_raft
vinaydes Jan 20, 2022
0409f57
Merge branch 'branch-22.04' into enh-rf-better-feature-sampling
vinaydes Mar 12, 2022
9c8f068
Fixing merge issue
vinaydes Mar 12, 2022
0beff3b
Merge branch 'branch-22.04' into enh-rf-better-feature-sampling
vinaydes Apr 13, 2022
fcbd074
Merge related changes
vinaydes Apr 13, 2022
eb5723d
DOC
raydouglass Mar 17, 2022
aac34a4
Replace cudf.logical_not with ~ (#4669)
canonizer Mar 31, 2022
28a027e
float64 support in multi-sum and child_index() (#4648)
canonizer Mar 31, 2022
a3ea64d
float64 support in FIL functions (#4655)
canonizer Apr 2, 2022
592f860
Add libcuml-tests package (#4635)
jjacobelli Apr 4, 2022
d2f64f5
float64 support in FIL core (#4646)
canonizer Apr 7, 2022
0c5463b
Replace 22.04.x with 22.06.x in yaml files (#4692)
daxiongshu Apr 8, 2022
19bddfc
unpin dask for development
galipremsagar Apr 8, 2022
e5829eb
Change "principals" to "principles" (#4695)
cakiki Apr 11, 2022
a4374d6
Fixing couple of run time issues
vinaydes Apr 14, 2022
097f359
Merge branch 'branch-22.06' into enh-rf-better-feature-sampling
venkywonka Apr 19, 2022
972bb1c
Replacing kiss99 with RNG from RAFT
vinaydes Apr 19, 2022
fb1bf91
Formatting minor change
vinaydes Apr 19, 2022
9ef9a60
Merge remote-tracking branch 'vinay/enh-rf-better-feature-sampling' i…
venkywonka Apr 19, 2022
26509e4
Merge branch 'branch-22.06' into enh-rf-better-feature-sampling
vinaydes Apr 20, 2022
7758b34
Tidying up memory requirement
vinaydes Apr 20, 2022
49c45b1
include select_kernel and adaptive_sample_kernel
venkywonka Apr 26, 2022
ab3bb4c
[debug] menu-driven sampling strategies
venkywonka May 12, 2022
347f616
menu driven debugging code
venkywonka May 19, 2022
32f9ac1
clean debug code
venkywonka May 19, 2022
000ab2f
add modifications and menu-driven feature-sampling for debug
venkywonka May 25, 2022
5ad8f94
predicate kernel launch based on required work-per-thread
venkywonka Jun 2, 2022
32ba8ce
clean some code
venkywonka Jun 2, 2022
d3c16fb
Merge branch 'branch-22.06' of https://github.com/rapidsai/cuml into …
venkywonka Jun 2, 2022
5594cde
clean more code
venkywonka Jun 2, 2022
ec928c9
revert CMakeLists changes
venkywonka Jun 2, 2022
80bfa4f
clang format
venkywonka Jun 2, 2022
ecddd5e
Merge remote-tracking branch 'vinay/enh-rf-better-feature-sampling' i…
venkywonka Jun 2, 2022
6d7badd
adding review changes
venkywonka Jun 8, 2022
1f79239
fix a memory bug + others
venkywonka Jun 13, 2022
ff92c6e
change seed for a corner case
venkywonka Jun 29, 2022
8d76a78
change seed for a corner case
venkywonka Jun 29, 2022
8e5990d
Merge branch 'branch-22.08' into enh-rf-better-feature-sampling
vinaydes Jul 29, 2022
c094f65
Addressing review comments about docstring
vinaydes Jul 29, 2022
aae286c
Update cpp/src/decisiontree/batched-levelalgo/kernels/builder_kernels…
vinaydes Aug 1, 2022
ea0428a
formatting changes
vinaydes Aug 1, 2022
b5751f1
Merge branch 'branch-22.08' into enh-rf-better-feature-sampling
vinaydes Aug 2, 2022
86 changes: 83 additions & 3 deletions cpp/src/decisiontree/batched-levelalgo/builder.cuh
@@ -190,6 +190,7 @@ struct Builder {
int n_blks_for_cols = 10;
/** Memory alignment value */
const size_t align_value = 512;
IdxT* colids;
/** rmm device workspace buffer */
rmm::device_uvector<char> d_buff;
/** pinned host buffer to store the trained nodes */
@@ -281,6 +282,7 @@ struct Builder {
d_wsize += calculateAlignedBytes(sizeof(NodeWorkItem) * max_batch); // d_work_Items
d_wsize += // workload_info
calculateAlignedBytes(sizeof(WorkloadInfo<IdxT>) * max_blocks_dimx);
d_wsize += calculateAlignedBytes(sizeof(IdxT) * max_batch * dataset.n_sampled_cols); // colids

// all nodes in the tree
h_wsize += // h_workload_info
@@ -320,6 +322,8 @@ struct Builder {
d_wspace += calculateAlignedBytes(sizeof(NodeWorkItem) * max_batch);
workload_info = reinterpret_cast<WorkloadInfo<IdxT>*>(d_wspace);
d_wspace += calculateAlignedBytes(sizeof(WorkloadInfo<IdxT>) * max_blocks_dimx);
colids = reinterpret_cast<IdxT*>(d_wspace);
d_wspace += calculateAlignedBytes(sizeof(IdxT) * max_batch * dataset.n_sampled_cols);

RAFT_CUDA_TRY(
cudaMemsetAsync(done_count, 0, sizeof(int) * max_batch * n_col_blks, builder_stream));
@@ -378,7 +382,7 @@ struct Builder {

auto doSplit(const std::vector<NodeWorkItem>& work_items)
{
raft::common::nvtx::range fun_scope("Builder::doSplit @bulder_base.cuh [batched-levelalgo]");
raft::common::nvtx::range fun_scope("Builder::doSplit @builder.cuh [batched-levelalgo]");
// start fresh on the number of *new* nodes created in this batch
RAFT_CUDA_TRY(cudaMemsetAsync(n_nodes, 0, sizeof(IdxT), builder_stream));
initSplit<DataT, IdxT, TPB_DEFAULT>(splits, work_items.size(), builder_stream);
@@ -388,11 +392,86 @@

auto [n_blocks_dimx, n_large_nodes] = this->updateWorkloadInfo(work_items);

// do feature-sampling
if (dataset.n_sampled_cols != dataset.N) {
raft::common::nvtx::range fun_scope("feature-sampling");
constexpr int block_threads = 128;
constexpr int max_samples_per_thread = 72; // register spillage if more than this limit
// decide if the problem size is suitable for the excess-sampling strategy.
//
// our required shared memory is a function of number of samples we'll need to sample (in
// parallel, with replacement) in excess to get 'k' uniques out of 'n' features. estimated
// static shared memory required by cub's block-wide collectives:
// max_samples_per_thread * block_threads * sizeof(IdxT)
//
// The maximum items to sample (the constant `max_samples_per_thread`, set at
// compile-time) is calibrated so that:
// 1. There are no register spills or accesses to global memory
// 2. The required static shared memory (i.e., `max_samples_per_thread * block_threads *
// sizeof(IdxT)`) does not exceed 46KB.
//
// number of samples we'll need to sample (in parallel, with replacement), to expect 'k'
// unique samples from 'n' is given by the following equation: log(1 - k/n)/log(1 - 1/n) ref:
// https://stats.stackexchange.com/questions/296005/the-expected-number-of-unique-elements-drawn-with-replacement
IdxT n_parallel_samples =
std::ceil(raft::myLog(1 - double(dataset.n_sampled_cols) / double(dataset.N)) /
(raft::myLog(1 - 1.f / double(dataset.N))));
// maximum sampling work possible by all threads in a block:
// `max_samples_per_thread * block_threads`
// dynamically calculated sampling work to be done per block:
// `n_parallel_samples`
// the former must be greater than or equal to the latter for the excess-sampling-based strategy
if (max_samples_per_thread * block_threads >= n_parallel_samples) {
raft::common::nvtx::range fun_scope("excess-sampling-based approach");
dim3 grid;
grid.x = work_items.size();
grid.y = 1;
grid.z = 1;

if (n_parallel_samples <= block_threads)
// each thread randomly samples only 1 sample
excess_sample_with_replacement_kernel<IdxT, 1, block_threads>
<<<grid, block_threads, 0, builder_stream>>>(colids,
d_work_items,
work_items.size(),
treeid,
seed,
dataset.N,
dataset.n_sampled_cols,
n_parallel_samples);
else
// each thread does more work and samples `max_samples_per_thread` samples
excess_sample_with_replacement_kernel<IdxT, max_samples_per_thread, block_threads>
<<<grid, block_threads, 0, builder_stream>>>(colids,
d_work_items,
work_items.size(),
treeid,
seed,
dataset.N,
dataset.n_sampled_cols,
n_parallel_samples);
raft::common::nvtx::pop_range();
} else {
raft::common::nvtx::range fun_scope("reservoir-sampling-based approach");
// using algo-L (reservoir sampling) strategy to sample 'dataset.n_sampled_cols' unique
// features from 'dataset.N' total features
dim3 grid;
grid.x = (work_items.size() + 127) / 128;
grid.y = 1;
grid.z = 1;
algo_L_sample_kernel<<<grid, block_threads, 0, builder_stream>>>(
colids, d_work_items, work_items.size(), treeid, seed, dataset.N, dataset.n_sampled_cols);
raft::common::nvtx::pop_range();
}
RAFT_CUDA_TRY(cudaPeekAtLastError());
raft::common::nvtx::pop_range();
}

// iterate through a batch of columns (to reduce the memory pressure) and
// compute the best split at the end
for (IdxT c = 0; c < dataset.n_sampled_cols; c += n_blks_for_cols) {
computeSplit(c, n_blocks_dimx, n_large_nodes);
RAFT_CUDA_TRY(cudaGetLastError());
RAFT_CUDA_TRY(cudaPeekAtLastError());
}

// create child nodes (or make the current ones leaf)
@@ -407,7 +486,7 @@ struct Builder {
dataset,
d_work_items,
splits);
RAFT_CUDA_TRY(cudaGetLastError());
RAFT_CUDA_TRY(cudaPeekAtLastError());
raft::common::nvtx::pop_range();
raft::update_host(h_splits, splits, work_items.size(), builder_stream);
handle.sync_stream(builder_stream);
@@ -462,6 +541,7 @@ struct Builder {
quantiles,
d_work_items,
col,
colids,
done_count,
mutex,
splits,
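For reference, below is a minimal host-side C++ sketch (not part of this PR) of the strategy-selection arithmetic described in the comments above: it evaluates the expected number of with-replacement draws needed to obtain k unique features out of n, log(1 - k/n) / log(1 - 1/n), and picks between the excess-sampling and reservoir-sampling paths the same way doSplit() does. std::log/std::ceil stand in for raft::myLog, the standalone main() and the feature counts n and k are made-up example values.

// Illustrative sketch only: constants mirror builder.cuh, but this is not the PR's code path.
#include <cmath>
#include <cstdio>

int main()
{
  const int block_threads          = 128;  // threads per block, as in the kernel launch
  const int max_samples_per_thread = 72;   // calibrated limit from the comments above

  // n = total features, k = features to sample per node (example values; the real
  // builder only reaches this code when k < n, i.e. n_sampled_cols != N)
  const double n = 1000.0, k = 100.0;

  // expected number of with-replacement draws needed to see k uniques out of n
  const int n_parallel_samples =
    static_cast<int>(std::ceil(std::log(1.0 - k / n) / std::log(1.0 - 1.0 / n)));

  if (max_samples_per_thread * block_threads >= n_parallel_samples) {
    std::printf("excess-sampling strategy: %d parallel draws per node\n", n_parallel_samples);
  } else {
    std::printf("reservoir sampling (Algorithm L): %d draws exceed block capacity\n",
                n_parallel_samples);
  }
  return 0;
}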
215 changes: 215 additions & 0 deletions cpp/src/decisiontree/batched-levelalgo/kernels/builder_kernels.cuh
@@ -20,6 +20,10 @@
#include "../objectives.cuh"
#include "../quantiles.h"

#include <raft/random/rng.hpp>

#include <cub/cub.cuh>

namespace ML {
namespace DT {

@@ -60,6 +64,13 @@ HDI bool SplitNotValid(const SplitT& split,
(IdxT(num_rows) - split.nLeft) < min_samples_leaf;
}

/* Returns 'dataset' rounded up to a correctly-aligned pointer of type OutT* */
template <typename OutT, typename InT>
DI OutT* alignPointer(InT dataset)
{
return reinterpret_cast<OutT*>(raft::alignTo(reinterpret_cast<size_t>(dataset), sizeof(OutT)));
}

template <typename DataT, typename LabelT, typename IdxT, int TPB>
__global__ void nodeSplitKernel(const IdxT max_depth,
const IdxT min_samples_leaf,
@@ -111,6 +122,209 @@ HDI IdxT lower_bound(DataT* array, IdxT len, DataT element)
return start;
}

template <typename IdxT>
struct CustomDifference {
__device__ IdxT operator()(const IdxT& lhs, const IdxT& rhs)
{
if (lhs == rhs)
return 0;
else
return 1;
}
};

/**
* @brief Generates 'k' unique samples of features from 'n' feature sample-space.
* Does this for each work-item (node), feeding a unique seed for each (treeid, nodeid
* (=blockIdx.x), threadIdx.x). The method used is random, parallel sampling with replacement of an
* excess of 'k' samples (hence the name), then eliminating the duplicates by ordering them. The
* excess number of samples (=`n_parallel_samples`) is calculated such that after ordering there are
* at least 'k' uniques.
*/
template <typename IdxT, int MAX_SAMPLES_PER_THREAD, int BLOCK_THREADS = 128>
__global__ void excess_sample_with_replacement_kernel(
IdxT* colids,
const NodeWorkItem* work_items,
size_t work_items_size,
IdxT treeid,
uint64_t seed,
size_t n /* total cols to sample from*/,
size_t k /* number of unique cols to sample */,
int n_parallel_samples /* number of cols to sample with replacement */)
{
if (blockIdx.x >= work_items_size) return;

const uint32_t nodeid = work_items[blockIdx.x].idx;

uint64_t subsequence(fnv1a32_basis);
subsequence = fnv1a32(subsequence, uint32_t(threadIdx.x));
subsequence = fnv1a32(subsequence, uint32_t(treeid));
subsequence = fnv1a32(subsequence, uint32_t(nodeid));

raft::random::PCGenerator gen(seed, subsequence, uint64_t(0));
raft::random::UniformIntDistParams<IdxT, uint64_t> uniform_int_dist_params;

uniform_int_dist_params.start = 0;
uniform_int_dist_params.end = n;
uniform_int_dist_params.diff =
uint64_t(uniform_int_dist_params.end - uniform_int_dist_params.start);

IdxT n_uniques = 0;
IdxT items[MAX_SAMPLES_PER_THREAD];
IdxT col_indices[MAX_SAMPLES_PER_THREAD];
IdxT mask[MAX_SAMPLES_PER_THREAD];
// populate this
for (int i = 0; i < MAX_SAMPLES_PER_THREAD; ++i)
mask[i] = 0;

do {
// blocked arrangement
for (int cta_sample_idx = MAX_SAMPLES_PER_THREAD * threadIdx.x, thread_local_sample_idx = 0;
thread_local_sample_idx < MAX_SAMPLES_PER_THREAD;
++cta_sample_idx, ++thread_local_sample_idx) {
// the mask of the previous iteration, if it exists, is re-used here
// so previously generated unique random numbers are used.
// newly generated random numbers may or may not duplicate the previously generated ones,
// but this ensures some forward progress in order to generate at least 'k' unique random
// samples.
if (mask[thread_local_sample_idx] == 0 and cta_sample_idx < n_parallel_samples)
raft::random::custom_next(
gen, &items[thread_local_sample_idx], uniform_int_dist_params, IdxT(0), IdxT(0));
else if (mask[thread_local_sample_idx] ==
0) // indices that exceed `n_parallel_samples` will not generate
items[thread_local_sample_idx] = n - 1;
else
continue; // this case is for samples whose mask == 1 (keeping the previous iteration's
// generated random number)
}

// Specialize BlockRadixSort type for our thread block
typedef cub::BlockRadixSort<IdxT, BLOCK_THREADS, MAX_SAMPLES_PER_THREAD> BlockRadixSortT;
// BlockAdjacentDifference
typedef cub::BlockAdjacentDifference<IdxT, BLOCK_THREADS> BlockAdjacentDifferenceT;
// BlockScan
typedef cub::BlockScan<IdxT, BLOCK_THREADS> BlockScanT;

// Shared memory
__shared__ union TempStorage {
typename BlockRadixSortT::TempStorage sort;
typename BlockAdjacentDifferenceT::TempStorage diff;
typename BlockScanT::TempStorage scan;
} temp_storage;

// collectively sort items
BlockRadixSortT(temp_storage.sort).Sort(items);

__syncthreads();

// compute the mask
// compute the adjacent differences according to the functor
BlockAdjacentDifferenceT(temp_storage.diff)
.FlagHeads(mask, items, mask, CustomDifference<IdxT>());

__syncthreads();

// do a scan on the mask to get the indices for gathering
BlockScanT(temp_storage.scan).ExclusiveSum(mask, col_indices, n_uniques);

__syncthreads();

} while (n_uniques < k);

// write the items[] of only the ones with mask[]=1 to col[offset + col_idx[]]
IdxT col_offset = k * blockIdx.x;
for (int i = 0; i < MAX_SAMPLES_PER_THREAD; ++i) {
if (mask[i] and col_indices[i] < k) { colids[col_offset + col_indices[i]] = items[i]; }
}
}

// algo L of the reservoir sampling algorithm
/**
* @brief Generates 'k' unique samples of features from 'n' feature sample-space using the algo-L
* algorithm of reservoir sampling. wiki :
* https://en.wikipedia.org/wiki/Reservoir_sampling#An_optimal_algorithm
*/
template <typename IdxT>
__global__ void algo_L_sample_kernel(int* colids,
const NodeWorkItem* work_items,
size_t work_items_size,
IdxT treeid,
uint64_t seed,
size_t n /* total cols to sample from*/,
size_t k /* cols to sample */)
{
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= work_items_size) return;
const uint32_t nodeid = work_items[tid].idx;
uint64_t subsequence = (uint64_t(treeid) << 32) | uint64_t(nodeid);
raft::random::PCGenerator gen(seed, subsequence, uint64_t(0));
raft::random::UniformIntDistParams<IdxT, uint64_t> uniform_int_dist_params;
uniform_int_dist_params.start = 0;
uniform_int_dist_params.end = k;
uniform_int_dist_params.diff =
uint64_t(uniform_int_dist_params.end - uniform_int_dist_params.start);
float fp_uniform_val;
IdxT int_uniform_val;
// fp_uniform_val will have a random value between 0 and 1
gen.next(fp_uniform_val);
double W = raft::myExp(raft::myLog(fp_uniform_val) / k);

size_t col(0);
// initially fill the reservoir array in increasing order of cols till k
while (1) {
colids[tid * k + col] = col;
if (col == k - 1)
break;
else
++col;
}
// randomly sample from a geometric distribution
while (col < n) {
// fp_uniform_val will have a random value between 0 and 1
gen.next(fp_uniform_val);
col += static_cast<int>(raft::myLog(fp_uniform_val) / raft::myLog(1 - W)) + 1;
if (col < n) {
// int_uniform_val will now have a random value between 0...k
raft::random::custom_next(gen, &int_uniform_val, uniform_int_dist_params, IdxT(0), IdxT(0));
colids[tid * k + int_uniform_val] = col; // the bad memory coalescing here is hidden
// fp_uniform_val will have a random value between 0 and 1
gen.next(fp_uniform_val);
W *= raft::myExp(raft::myLog(fp_uniform_val) / k);
}
}
}

template <typename IdxT>
__global__ void adaptive_sample_kernel(int* colids,
const NodeWorkItem* work_items,
size_t work_items_size,
IdxT treeid,
uint64_t seed,
int N,
int M)
{
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= work_items_size) return;
const uint32_t nodeid = work_items[tid].idx;

uint64_t subsequence = (uint64_t(treeid) << 32) | uint64_t(nodeid);
raft::random::PCGenerator gen(seed, subsequence, uint64_t(0));

int selected_count = 0;
for (int i = 0; i < N; i++) {
uint32_t toss = 0;
gen.next(toss);
uint64_t lhs = uint64_t(M - selected_count);
lhs <<= 32;
uint64_t rhs = uint64_t(toss) * (N - i);
if (lhs > rhs) {
colids[tid * M + selected_count] = i;
selected_count++;
if (selected_count == M) break;
}
}
}

template <typename DataT,
typename LabelT,
typename IdxT,
@@ -126,6 +340,7 @@ __global__ void computeSplitKernel(BinT* histograms,
const Quantiles<DataT, IdxT> quantiles,
const NodeWorkItem* work_items,
IdxT colStart,
const IdxT* colids,
int* done_count,
int* mutex,
volatile Split<DataT, IdxT>* splits,
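To make the reservoir-sampling path easier to follow, here is a minimal single-threaded sketch (not from the PR) of Algorithm L, the same strategy algo_L_sample_kernel applies once per work-item on the GPU. The helper name sample_k_of_n is hypothetical, and std::mt19937 stands in for raft::random::PCGenerator.

// Hypothetical CPU illustration of Algorithm L reservoir sampling: pick k unique
// column indices from [0, n) using roughly k*(1 + log(n/k)) random draws on average.
#include <cmath>
#include <cstdio>
#include <random>
#include <vector>

std::vector<int> sample_k_of_n(int k, int n, unsigned seed)
{
  std::mt19937 gen(seed);
  std::uniform_real_distribution<double> uni01(0.0, 1.0);
  std::uniform_int_distribution<int> pick(0, k - 1);

  std::vector<int> reservoir(k);
  for (int i = 0; i < k; ++i) reservoir[i] = i;  // fill with the first k columns

  double w = std::exp(std::log(uni01(gen)) / k);
  int col  = k - 1;
  while (true) {
    // geometric skip: jump over columns that would have been rejected anyway
    col += static_cast<int>(std::log(uni01(gen)) / std::log(1.0 - w)) + 1;
    if (col >= n) break;
    reservoir[pick(gen)] = col;                  // replace a random reservoir slot
    w *= std::exp(std::log(uni01(gen)) / k);
  }
  return reservoir;
}

int main()
{
  for (int c : sample_k_of_n(10, 100, 42)) std::printf("%d ", c);
  std::printf("\n");
  return 0;
}

The geometric skip is what keeps this cheap when only a small fraction of the features is sampled: on average it needs about k*(1 + log(n/k)) random numbers rather than one per feature.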