Skip to content

Commit

Permalink
bug fix and code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
seunghwak committed Feb 4, 2025
1 parent 8309b05 commit 0ac05c5
Showing 1 changed file with 19 additions and 20 deletions.
39 changes: 19 additions & 20 deletions cpp/src/prims/detail/sample_and_compute_local_nbr_indices.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1274,13 +1274,13 @@ rmm::device_uvector<edge_t> compute_homogeneous_uniform_sampling_index_without_r
thrust::make_counting_iterator((*retry_segment_indices).size()),
[high_partition_oversampling_K,
unique_counts = raft::device_span<edge_t>(unique_counts.data(), unique_counts.size()),
retry_segment_indices = (*retry_segment_indices).data(),
retry_segment_sorted_pair_first = thrust::make_zip_iterator(
thrust::make_tuple((*retry_segment_sorted_nbr_indices).begin(),
(*retry_segment_sorted_sample_indices).begin())),
segment_sorted_pair_first = thrust::make_zip_iterator(thrust::make_tuple(
retry_segment_indices = (*retry_segment_indices).data(),
retry_segment_sorted_pair_first =
thrust::make_zip_iterator((*retry_segment_sorted_nbr_indices).begin(),
(*retry_segment_sorted_sample_indices).begin()),
segment_sorted_pair_first = thrust::make_zip_iterator(
segment_sorted_tmp_nbr_indices.begin(),
segment_sorted_tmp_sample_indices.begin()))] __device__(size_t i) {
segment_sorted_tmp_sample_indices.begin())] __device__(size_t i) {
auto input_pair_first =
retry_segment_sorted_pair_first + high_partition_oversampling_K * i;
auto segment_idx = retry_segment_indices[i];
Expand Down Expand Up @@ -1310,9 +1310,9 @@ rmm::device_uvector<edge_t> compute_homogeneous_uniform_sampling_index_without_r
thrust::make_counting_iterator(num_segments),
[high_partition_oversampling_K,
unique_counts = raft::device_span<edge_t>(unique_counts.data(), unique_counts.size()),
segment_sorted_pair_first = thrust::make_zip_iterator(thrust::make_tuple(
segment_sorted_pair_first = thrust::make_zip_iterator(
segment_sorted_tmp_nbr_indices.begin(),
segment_sorted_tmp_sample_indices.begin()))] __device__(size_t i) {
segment_sorted_tmp_sample_indices.begin())] __device__(size_t i) {
auto pair_first = segment_sorted_pair_first + high_partition_oversampling_K * i;
assert(high_partition_oversampling_K > 0);
thrust::tuple<edge_t, int32_t> prev = *pair_first;
Expand Down Expand Up @@ -1701,9 +1701,9 @@ rmm::device_uvector<edge_t> compute_heterogeneous_uniform_sampling_index_without
retry_segment_sorted_pair_first = thrust::make_zip_iterator(
thrust::make_tuple((*retry_segment_sorted_per_type_nbr_indices).begin(),
(*retry_segment_sorted_sample_indices).begin())),
segment_sorted_pair_first = thrust::make_zip_iterator(thrust::make_tuple(
segment_sorted_pair_first = thrust::make_zip_iterator(
segment_sorted_tmp_per_type_nbr_indices.begin(),
segment_sorted_tmp_sample_indices.begin()))] __device__(size_t i) {
segment_sorted_tmp_sample_indices.begin())] __device__(size_t i) {
auto input_pair_first =
retry_segment_sorted_pair_first + high_partition_oversampling_K * i;
auto segment_idx = retry_segment_indices[i];
Expand Down Expand Up @@ -1733,9 +1733,9 @@ rmm::device_uvector<edge_t> compute_heterogeneous_uniform_sampling_index_without
thrust::make_counting_iterator(num_segments),
[high_partition_oversampling_K,
unique_counts = raft::device_span<edge_t>(unique_counts.data(), unique_counts.size()),
segment_sorted_pair_first = thrust::make_zip_iterator(thrust::make_tuple(
segment_sorted_pair_first = thrust::make_zip_iterator(
segment_sorted_tmp_per_type_nbr_indices.begin(),
segment_sorted_tmp_sample_indices.begin()))] __device__(size_t i) {
segment_sorted_tmp_sample_indices.begin())] __device__(size_t i) {
auto pair_first = segment_sorted_pair_first + high_partition_oversampling_K * i;
assert(high_partition_oversampling_K > 0);
thrust::tuple<edge_t, int32_t> prev = *pair_first;
Expand Down Expand Up @@ -2871,11 +2871,10 @@ shuffle_and_compute_local_nbr_values(
thrust::make_zip_iterator(sample_local_nbr_values.begin(),
thrust::make_transform_iterator(
thrust::make_counting_iterator(size_t{0}), divider_t<size_t>{K}));
auto output_tuple_first =
thrust::make_zip_iterator(thrust::make_tuple(minor_comm_ranks.begin(),
intra_partition_displacements.begin(),
sample_local_nbr_values.begin(),
key_indices.begin()));
auto output_tuple_first = thrust::make_zip_iterator(minor_comm_ranks.begin(),
intra_partition_displacements.begin(),
sample_local_nbr_values.begin(),
key_indices.begin());
thrust::for_each(
handle.get_thrust_policy(),
thrust::make_counting_iterator(size_t{0}),
Expand Down Expand Up @@ -3838,7 +3837,7 @@ homogeneous_biased_sample_without_replacement(

// shuffle local sampling outputs

std::vector<size_t> tx_counts(high_local_frontier_sizes);
std::vector<size_t> tx_counts(high_local_frontier_sizes.size());
std::transform(high_local_frontier_sizes.begin(),
high_local_frontier_sizes.end(),
tx_counts.begin(),
Expand Down Expand Up @@ -4470,7 +4469,7 @@ heterogeneous_biased_sample_without_replacement(
// local sample and update indices

rmm::device_uvector<size_t> aggregate_high_local_frontier_output_offsets(
high_local_frontier_offsets.back(), handle.get_stream());
high_local_frontier_offsets.back() + 1, handle.get_stream());
{
auto K_first = thrust::make_transform_iterator(
std::get<1>(aggregate_high_local_frontier_index_type_pairs).begin(),
Expand Down Expand Up @@ -4537,7 +4536,7 @@ heterogeneous_biased_sample_without_replacement(

// shuffle local sampling outputs

std::vector<size_t> tx_counts(high_local_frontier_sizes);
std::vector<size_t> tx_counts(high_local_frontier_sizes.size());
{
rmm::device_uvector<size_t> d_high_local_frontier_offsets(
high_local_frontier_offsets.size(), handle.get_stream());
Expand Down

0 comments on commit 0ac05c5

Please sign in to comment.