[CUDA] Add lambdarank objective for cuda_exp (#5453)
* add lambdarank for cuda_exp

* support unlimited number of ranks in labels

* fix lint errors

* remove warning for lambdarank with cuda_exp

* Update src/objective/cuda/cuda_rank_objective.hpp

Co-authored-by: Nikita Titov <[email protected]>

* Update src/objective/cuda/cuda_rank_objective.hpp

Co-authored-by: Nikita Titov <[email protected]>

Co-authored-by: Nikita Titov <[email protected]>
shiyu1994 and StrikerRUS authored Sep 5, 2022
1 parent c9a3b47 commit 1d5f46f
Showing 7 changed files with 588 additions and 9 deletions.
60 changes: 54 additions & 6 deletions include/LightGBM/cuda/cuda_algorithms.hpp
@@ -18,17 +18,11 @@

#include <algorithm>

#define NUM_BANKS_DATA_PARTITION (16)
#define LOG_NUM_BANKS_DATA_PARTITION (4)
#define GLOBAL_PREFIX_SUM_BLOCK_SIZE (1024)

#define BITONIC_SORT_NUM_ELEMENTS (1024)
#define BITONIC_SORT_DEPTH (11)
#define BITONIC_SORT_QUERY_ITEM_BLOCK_SIZE (10)

#define CONFLICT_FREE_INDEX(n) \
((n) + ((n) >> LOG_NUM_BANKS_DATA_PARTITION)) \

namespace LightGBM {

template <typename T>
@@ -223,6 +217,54 @@ __device__ __forceinline__ void BitonicArgSort_1024(const VAL_T* scores, INDEX_T
}
}

template <typename VAL_T, typename INDEX_T, bool ASCENDING>
__device__ __forceinline__ void BitonicArgSort_2048(const VAL_T* scores, INDEX_T* indices) {
for (INDEX_T base = 0; base < 2048; base += 1024) {
for (INDEX_T outer_depth = 10; outer_depth >= 1; --outer_depth) {
const INDEX_T outer_segment_length = 1 << (11 - outer_depth);
const INDEX_T outer_segment_index = threadIdx.x / outer_segment_length;
const bool ascending = ((base == 0) ^ ASCENDING) ? (outer_segment_index % 2 > 0) : (outer_segment_index % 2 == 0);
for (INDEX_T inner_depth = outer_depth; inner_depth < 11; ++inner_depth) {
const INDEX_T segment_length = 1 << (11 - inner_depth);
const INDEX_T half_segment_length = segment_length >> 1;
const INDEX_T half_segment_index = threadIdx.x / half_segment_length;
if (half_segment_index % 2 == 0) {
const INDEX_T index_to_compare = threadIdx.x + half_segment_length + base;
if ((scores[indices[threadIdx.x + base]] > scores[indices[index_to_compare]]) == ascending) {
const INDEX_T index = indices[threadIdx.x + base];
indices[threadIdx.x + base] = indices[index_to_compare];
indices[index_to_compare] = index;
}
}
__syncthreads();
}
}
}
const unsigned int index_to_compare = threadIdx.x + 1024;
if (scores[indices[index_to_compare]] > scores[indices[threadIdx.x]]) {
const INDEX_T temp_index = indices[index_to_compare];
indices[index_to_compare] = indices[threadIdx.x];
indices[threadIdx.x] = temp_index;
}
__syncthreads();
for (INDEX_T base = 0; base < 2048; base += 1024) {
for (INDEX_T inner_depth = 1; inner_depth < 11; ++inner_depth) {
const INDEX_T segment_length = 1 << (11 - inner_depth);
const INDEX_T half_segment_length = segment_length >> 1;
const INDEX_T half_segment_index = threadIdx.x / half_segment_length;
if (half_segment_index % 2 == 0) {
const INDEX_T index_to_compare = threadIdx.x + half_segment_length + base;
if (scores[indices[threadIdx.x + base]] < scores[indices[index_to_compare]]) {
const INDEX_T index = indices[threadIdx.x + base];
indices[threadIdx.x + base] = indices[index_to_compare];
indices[index_to_compare] = index;
}
}
__syncthreads();
}
}
}
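
A note on usage (an editorial sketch, not part of this commit's diff): BitonicArgSort_2048 is a block-level bitonic argsort over exactly 2048 elements. All 1024 threads of a block must call it together, each thread owning positions threadIdx.x and threadIdx.x + 1024; the first loop sorts the two 1024-element halves in opposite directions, the single compare-exchange in the middle merges them across the 1024 boundary, and the final loop completes the merge, leaving the indices ordered by score (descending, in the ASCENDING = false configuration used for scores). A minimal, hypothetical wrapper kernel — ArgSort2048Kernel, d_scores, and d_indices are illustrative names, and cuda_algorithms.hpp is assumed to be included:

__global__ void ArgSort2048Kernel(const float* scores, int* indices) {
  // each of the 1024 threads seeds two identity indices before sorting
  indices[threadIdx.x] = static_cast<int>(threadIdx.x);
  indices[threadIdx.x + 1024] = static_cast<int>(threadIdx.x + 1024);
  __syncthreads();
  // ASCENDING = false, matching the descending order used for scores here
  BitonicArgSort_2048<float, int, false>(scores, indices);
}

// launch with exactly one block of 1024 threads per 2048-element segment:
// ArgSort2048Kernel<<<1, 1024>>>(d_scores, d_indices);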

template <typename VAL_T, typename INDEX_T, bool ASCENDING, uint32_t BLOCK_DIM, uint32_t MAX_DEPTH>
__device__ void BitonicArgSortDevice(const VAL_T* values, INDEX_T* indices, const int len) {
__shared__ VAL_T shared_values[BLOCK_DIM];
@@ -387,6 +429,12 @@ __device__ void BitonicArgSortDevice(const VAL_T* values, INDEX_T* indices, cons
}
}

void BitonicArgSortItemsGlobal(
const double* scores,
const int num_queries,
const data_size_t* cuda_query_boundaries,
data_size_t* out_indices);

template <typename VAL_T, typename INDEX_T, bool ASCENDING>
void BitonicArgSortGlobal(const VAL_T* values, INDEX_T* indices, const size_t len);

28 changes: 28 additions & 0 deletions src/cuda/cuda_algorithms.cu
@@ -77,6 +77,34 @@ void ShufflePrefixSumGlobal(uint64_t* values, size_t len, uint64_t* block_prefix
ShufflePrefixSumGlobalInner<uint64_t>(values, len, block_prefix_sum_buffer);
}

__global__ void BitonicArgSortItemsGlobalKernel(const double* scores,
const int num_queries,
const data_size_t* cuda_query_boundaries,
data_size_t* out_indices) {
const int query_index_start = static_cast<int>(blockIdx.x) * BITONIC_SORT_QUERY_ITEM_BLOCK_SIZE;
const int query_index_end = min(query_index_start + BITONIC_SORT_QUERY_ITEM_BLOCK_SIZE, num_queries);
for (int query_index = query_index_start; query_index < query_index_end; ++query_index) {
const data_size_t query_item_start = cuda_query_boundaries[query_index];
const data_size_t query_item_end = cuda_query_boundaries[query_index + 1];
const data_size_t num_items_in_query = query_item_end - query_item_start;
BitonicArgSortDevice<double, data_size_t, false, BITONIC_SORT_NUM_ELEMENTS, 11>(scores + query_item_start,
out_indices + query_item_start,
num_items_in_query);
__syncthreads();
}
}

void BitonicArgSortItemsGlobal(
const double* scores,
const int num_queries,
const data_size_t* cuda_query_boundaries,
data_size_t* out_indices) {
const int num_blocks = (num_queries + BITONIC_SORT_QUERY_ITEM_BLOCK_SIZE - 1) / BITONIC_SORT_QUERY_ITEM_BLOCK_SIZE;
BitonicArgSortItemsGlobalKernel<<<num_blocks, BITONIC_SORT_NUM_ELEMENTS>>>(
scores, num_queries, cuda_query_boundaries, out_indices);
SynchronizeCUDADevice(__FILE__, __LINE__);
}
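
A note on the wrapper above (editorial, not part of this commit's diff): each thread block runs BITONIC_SORT_NUM_ELEMENTS = 1024 threads and processes up to BITONIC_SORT_QUERY_ITEM_BLOCK_SIZE = 10 queries, argsorting each query's scores with ASCENDING = false, i.e. best-scored items first. Both scores and out_indices are offset by query_item_start inside the kernel, so the written indices are local to each query's segment. A hypothetical call site, assuming the device buffers are already populated (buffer names are illustrative):

// sort every query's items by predicted score, 10 queries per block
BitonicArgSortItemsGlobal(
    cuda_scores,            // [num_items] predicted scores on the device
    num_queries,            // number of queries
    cuda_query_boundaries,  // [num_queries + 1] item offsets per query
    cuda_item_indices);     // [num_items] output: per-query argsort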

template <typename T>
__global__ void BlockReduceSum(T* block_buffer, const data_size_t num_blocks) {
__shared__ T shared_buffer[32];
65 changes: 65 additions & 0 deletions src/objective/cuda/cuda_rank_objective.cpp
@@ -0,0 +1,65 @@
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/

#ifdef USE_CUDA_EXP

#include <string>
#include <vector>

#include "cuda_rank_objective.hpp"

namespace LightGBM {

CUDALambdarankNDCG::CUDALambdarankNDCG(const Config& config):
LambdarankNDCG(config) {}

CUDALambdarankNDCG::CUDALambdarankNDCG(const std::vector<std::string>& strs): LambdarankNDCG(strs) {}

void CUDALambdarankNDCG::Init(const Metadata& metadata, data_size_t num_data) {
const int num_threads = OMP_NUM_THREADS();
LambdarankNDCG::Init(metadata, num_data);

std::vector<uint16_t> thread_max_num_items_in_query(num_threads);
Threading::For<data_size_t>(0, num_queries_, 1,
[this, &thread_max_num_items_in_query] (int thread_index, data_size_t start, data_size_t end) {
for (data_size_t query_index = start; query_index < end; ++query_index) {
const data_size_t query_item_count = query_boundaries_[query_index + 1] - query_boundaries_[query_index];
if (query_item_count > thread_max_num_items_in_query[thread_index]) {
thread_max_num_items_in_query[thread_index] = query_item_count;
}
}
});
data_size_t max_items_in_query = 0;
for (int thread_index = 0; thread_index < num_threads; ++thread_index) {
if (thread_max_num_items_in_query[thread_index] > max_items_in_query) {
max_items_in_query = thread_max_num_items_in_query[thread_index];
}
}
max_items_in_query_aligned_ = 1;
--max_items_in_query;
while (max_items_in_query > 0) {
max_items_in_query >>= 1;
max_items_in_query_aligned_ <<= 1;
}
if (max_items_in_query_aligned_ > 2048) {
cuda_item_indices_buffer_.Resize(static_cast<size_t>(metadata.query_boundaries()[metadata.num_queries()]));
}
cuda_labels_ = metadata.cuda_metadata()->cuda_label();
cuda_query_boundaries_ = metadata.cuda_metadata()->cuda_query_boundaries();
cuda_inverse_max_dcgs_.Resize(inverse_max_dcgs_.size());
CopyFromHostToCUDADevice(cuda_inverse_max_dcgs_.RawData(), inverse_max_dcgs_.data(), inverse_max_dcgs_.size(), __FILE__, __LINE__);
cuda_label_gain_.Resize(label_gain_.size());
CopyFromHostToCUDADevice(cuda_label_gain_.RawData(), label_gain_.data(), label_gain_.size(), __FILE__, __LINE__);
}

void CUDALambdarankNDCG::GetGradients(const double* score, score_t* gradients, score_t* hessians) const {
LaunchGetGradientsKernel(score, gradients, hessians);
}


} // namespace LightGBM

#endif // USE_CUDA_EXP
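
Two details in CUDALambdarankNDCG::Init are worth spelling out (editorial, not part of this commit's diff). The loop over max_items_in_query rounds the longest query length up to the nearest power of two, and the result picks the sorting path: only when max_items_in_query_aligned_ exceeds 2048 does Init allocate the global cuda_item_indices_buffer_, sized to the total item count metadata.query_boundaries()[metadata.num_queries()]; queries of up to 2048 items appear to be handled by the in-block sorts above without it. A standalone restatement of the rounding step with worked values (illustrative, not the commit's code):

#include <cassert>
#include <cstdint>

// Smallest power of two >= n, for n >= 1, mirroring the rounding loop
// in CUDALambdarankNDCG::Init.
int32_t RoundUpToPowerOfTwo(int32_t n) {
  int32_t aligned = 1;
  --n;  // decrement first so exact powers of two map to themselves
  while (n > 0) {
    n >>= 1;
    aligned <<= 1;
  }
  return aligned;
}

int main() {
  assert(RoundUpToPowerOfTwo(5) == 8);
  assert(RoundUpToPowerOfTwo(2048) == 2048);
  // a longest query of 3000 items rounds to 4096 (> 2048), the condition
  // under which Init allocates cuda_item_indices_buffer_
  assert(RoundUpToPowerOfTwo(3000) == 4096);
  return 0;
}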