[coll] Add global functions. #10203

Merged
2 commits merged on Apr 18, 2024
1 change: 1 addition & 0 deletions R-package/src/Makevars.in
@@ -104,6 +104,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/allreduce.o \
$(PKGROOT)/src/collective/broadcast.o \
$(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/comm_group.o \
$(PKGROOT)/src/collective/coll.o \
$(PKGROOT)/src/collective/communicator-inl.o \
$(PKGROOT)/src/collective/tracker.o \
1 change: 1 addition & 0 deletions R-package/src/Makevars.win
@@ -104,6 +104,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/allreduce.o \
$(PKGROOT)/src/collective/broadcast.o \
$(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/comm_group.o \
$(PKGROOT)/src/collective/coll.o \
$(PKGROOT)/src/collective/communicator-inl.o \
$(PKGROOT)/src/collective/tracker.o \
2 changes: 1 addition & 1 deletion src/collective/aggregator.h
@@ -165,7 +165,7 @@ template <typename T>
T GlobalRatio(Context const* ctx, MetaInfo const& info, T dividend, T divisor) {
std::array<T, 2> results{dividend, divisor};
auto rc = GlobalSum(ctx, info, linalg::MakeVec(results.data(), results.size()));
collective::SafeColl(rc);
SafeColl(rc);
std::tie(dividend, divisor) = std::tuple_cat(results);
if (divisor <= 0) {
return std::numeric_limits<T>::quiet_NaN();
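Note: to make the hunk above concrete, here is a hedged two-worker trace of GlobalRatio. The values are invented, and it assumes the function returns dividend / divisor after the guard shown.

// Hypothetical two-worker trace of GlobalRatio.
// Worker 0 holds dividend = 1.0, divisor = 2.0; worker 1 holds dividend = 3.0, divisor = 2.0.
// GlobalSum reduces both entries element-wise: results = {1.0 + 3.0, 2.0 + 2.0} = {4.0, 4.0}.
// Every worker then computes 4.0 / 4.0 = 1.0; a summed divisor <= 0 would instead return
// std::numeric_limits<T>::quiet_NaN(), as shown above.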
47 changes: 46 additions & 1 deletion src/collective/allgather.cc
@@ -33,16 +33,18 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
bool is_last_segment = send_rank == (world - 1);
auto send_nbytes = is_last_segment ? (data.size_bytes() - send_off) : segment_size;
auto send_seg = data.subspan(send_off, send_nbytes);
CHECK_NE(send_seg.size(), 0);
return next_ch->SendAll(send_seg.data(), send_seg.size_bytes());
} << [&] {
auto recv_rank = (rank + world - r - 1 + worker_off) % world;
auto recv_off = recv_rank * segment_size;
bool is_last_segment = recv_rank == (world - 1);
auto recv_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : segment_size;
auto recv_seg = data.subspan(recv_off, recv_nbytes);
CHECK_NE(recv_seg.size(), 0);
return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
} << [&] {
return prev_ch->Block();
return comm.Block();
};
if (!rc.OK()) {
return rc;
@@ -106,4 +108,47 @@ namespace detail {
return comm.Block();
}
} // namespace detail
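Note: a small hedged trace of the receive-side index arithmetic in the ring allgather hunk above; the world size, segment size, and buffer size are invented for illustration.

// Assume world = 3, worker_off = 0, segment_size = 3 bytes, data.size_bytes() = 10.
// rank = 1, round r = 0: recv_rank = (1 + 3 - 0 - 1) % 3 = 0, recv_off = 0, recv_nbytes = 3.
// rank = 1, round r = 1: recv_rank = (1 + 3 - 1 - 1) % 3 = 2, recv_off = 6, and because
//   recv_rank == world - 1 the segment absorbs the remainder: 10 - 6 = 4 bytes.
// The new CHECK_NE assertions guard against posting zero-length sends or receives.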

[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
Context const* ctx, CommGroup const& comm, std::vector<std::vector<char>> const& input) {
auto n_inputs = input.size();
std::vector<std::int64_t> sizes(n_inputs);
std::transform(input.cbegin(), input.cend(), sizes.begin(),
[](auto const& vec) { return vec.size(); });

std::vector<std::int64_t> recv_segments(comm.World() + 1, 0);

HostDeviceVector<std::int8_t> recv;
auto rc =
AllgatherV(ctx, comm, linalg::MakeVec(sizes.data(), sizes.size()), &recv_segments, &recv);
SafeColl(rc);

auto global_sizes = common::RestoreType<std::int64_t const>(recv.ConstHostSpan());
std::vector<std::int64_t> offset(global_sizes.size() + 1);
offset[0] = 0;
for (std::size_t i = 1; i < offset.size(); i++) {
offset[i] = offset[i - 1] + global_sizes[i - 1];
}

std::vector<char> collected;
for (auto const& vec : input) {
collected.insert(collected.end(), vec.cbegin(), vec.cend());
}
rc = AllgatherV(ctx, comm, linalg::MakeVec(collected.data(), collected.size()), &recv_segments,
&recv);
SafeColl(rc);
auto out = common::RestoreType<char const>(recv.ConstHostSpan());

std::vector<std::vector<char>> result;
for (std::size_t i = 1; i < offset.size(); ++i) {
std::vector<char> local(out.cbegin() + offset[i - 1], out.cbegin() + offset[i]);
result.emplace_back(std::move(local));
}
return result;
}

[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
Context const* ctx, std::vector<std::vector<char>> const& input) {
return VectorAllgatherV(ctx, *GlobalCommGroup(), input);
}
} // namespace xgboost::collective
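Note: a hedged usage sketch of the new free function defined above; it assumes the collective communicator has already been initialised elsewhere, and the buffer contents are invented.

#include <vector>

#include "allgather.h"        // assumed to be src/collective/allgather.h
#include "xgboost/context.h"  // for Context

void ExampleVectorAllgatherV() {
  xgboost::Context ctx;
  // Each worker contributes any number of byte buffers, each of any length.
  std::vector<std::vector<char>> local{{'a', 'b'}, {'c'}};
  // Every worker receives the buffers of all workers, concatenated in rank order.
  auto global = xgboost::collective::VectorAllgatherV(&ctx, local);
}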
111 changes: 111 additions & 0 deletions src/collective/allgather.h
@@ -102,4 +102,115 @@ template <typename T>

return detail::RingAllgatherV(comm, sizes, s_segments, erased_result);
}

template <typename T>
[[nodiscard]] Result Allgather(Context const* ctx, CommGroup const& comm,
linalg::VectorView<T> data) {
if (!comm.IsDistributed()) {
return Success();
}
CHECK(data.Contiguous());
auto erased = common::EraseType(data.Values());

auto const& cctx = comm.Ctx(ctx, data.Device());
auto backend = comm.Backend(data.Device());
return backend->Allgather(cctx, erased);
}

/**
* @brief Gather all data from all workers.
*
* @param data The input and output buffer; it needs to be pre-allocated by the caller.
*/
template <typename T>
[[nodiscard]] Result Allgather(Context const* ctx, linalg::VectorView<T> data) {
auto const& cg = *GlobalCommGroup();
if (data.Size() % cg.World() != 0) {
return Fail("The total number of elements should be multiple of the number of workers.");
}
return Allgather(ctx, cg, data);
}
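Note: a hedged sketch of the fixed-size overload above, assuming two workers and that each worker writes its own slot of the pre-allocated buffer before the call; SafeColl is used the same way as elsewhere in this patch.

#include <cstdint>
#include <vector>

#include "allgather.h"        // assumed to be src/collective/allgather.h
#include "xgboost/context.h"  // for Context
#include "xgboost/linalg.h"   // for MakeVec

void ExampleAllgather(std::int32_t rank) {
  xgboost::Context ctx;
  std::vector<double> buf(2, 0.0);        // one element per worker, pre-allocated
  buf[rank] = static_cast<double>(rank);  // this worker's contribution
  auto rc = xgboost::collective::Allgather(
      &ctx, xgboost::linalg::MakeVec(buf.data(), buf.size()));
  xgboost::collective::SafeColl(rc);      // buf now holds {0.0, 1.0} on every worker
}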

template <typename T>
[[nodiscard]] Result AllgatherV(Context const* ctx, CommGroup const& comm,
linalg::VectorView<T> data,
std::vector<std::int64_t>* recv_segments,
HostDeviceVector<std::int8_t>* recv) {
if (!comm.IsDistributed()) {
return Success();
}
std::vector<std::int64_t> sizes(comm.World(), 0);
sizes[comm.Rank()] = data.Values().size_bytes();
auto erased_sizes = common::EraseType(common::Span{sizes.data(), sizes.size()});
auto rc = comm.Backend(DeviceOrd::CPU())
->Allgather(comm.Ctx(ctx, DeviceOrd::CPU()), erased_sizes);
if (!rc.OK()) {
return rc;
}

recv_segments->resize(sizes.size() + 1);
detail::AllgatherVOffset(sizes, common::Span{recv_segments->data(), recv_segments->size()});
auto total_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0LL);
recv->SetDevice(data.Device());
recv->Resize(total_bytes);

auto s_segments = common::Span{recv_segments->data(), recv_segments->size()};

auto backend = comm.Backend(data.Device());
auto erased = common::EraseType(data.Values());

return backend->AllgatherV(
comm.Ctx(ctx, data.Device()), erased, common::Span{sizes.data(), sizes.size()}, s_segments,
data.Device().IsCUDA() ? recv->DeviceSpan() : recv->HostSpan(), AllgatherVAlgo::kBcast);
}

/**
* @brief Allgather with variable length data.
*
* @param data The input data.
* @param recv_segments Segment offsets for each worker. [0, 2, 5] means elements [0, 2) are
* from the first worker and elements [2, 5) are from the second one.
* @param recv The buffer storing the result.
*/
template <typename T>
[[nodiscard]] Result AllgatherV(Context const* ctx, linalg::VectorView<T> data,
std::vector<std::int64_t>* recv_segments,
HostDeviceVector<std::int8_t>* recv) {
return AllgatherV(ctx, *GlobalCommGroup(), data, recv_segments, recv);
}
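Note: a hedged usage sketch of the variable-length overload above; the per-worker element counts are invented, and common::RestoreType is used as in VectorAllgatherV from allgather.cc.

#include <cstdint>
#include <vector>

#include "../common/type.h"              // assumed path, for RestoreType
#include "allgather.h"                   // assumed to be src/collective/allgather.h
#include "xgboost/context.h"             // for Context
#include "xgboost/host_device_vector.h"  // for HostDeviceVector
#include "xgboost/linalg.h"              // for MakeVec

void ExampleAllgatherV(std::int32_t rank) {
  xgboost::Context ctx;
  std::vector<std::int64_t> local(rank + 1, rank);  // worker r contributes r + 1 values
  std::vector<std::int64_t> segments;               // filled with per-worker offsets
  xgboost::HostDeviceVector<std::int8_t> recv;      // concatenated result as raw bytes
  auto rc = xgboost::collective::AllgatherV(
      &ctx, xgboost::linalg::MakeVec(local.data(), local.size()), &segments, &recv);
  xgboost::collective::SafeColl(rc);
  // Reinterpret the gathered bytes back into the original element type.
  auto all = xgboost::common::RestoreType<std::int64_t const>(recv.ConstHostSpan());
}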

[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
Context const* ctx, CommGroup const& comm, std::vector<std::vector<char>> const& input);

/**
* @brief Gathers variable-length data from all processes and distributes it to all processes.
*
* @param input All the inputs from the local worker. The number of inputs can vary across
* workers, and so can the size of each vector in the input.
*
* @return The AllgatherV result, containing vectors from all workers.
*/
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
Context const* ctx, std::vector<std::vector<char>> const& input);
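Note: a small hedged illustration of the shape of the result; the byte values are invented.

// Worker 0 passes {{0x01}, {0x02, 0x03}} and worker 1 passes {{0x04, 0x05, 0x06}}.
// Afterwards every worker holds {{0x01}, {0x02, 0x03}, {0x04, 0x05, 0x06}}, i.e. all
// input vectors from all workers, ordered by rank.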

/**
* @brief Gathers variable-length strings from all processes and distributes them to all processes.
* @param input Variable-length list of variable-length strings.
*/
[[nodiscard]] inline Result AllgatherStrings(std::vector<std::string> const& input,
std::vector<std::string>* p_result) {
std::vector<std::vector<char>> inputs(input.size());
for (std::size_t i = 0; i < input.size(); ++i) {
inputs[i] = {input[i].cbegin(), input[i].cend()};
}
Context ctx;
auto out = VectorAllgatherV(&ctx, *GlobalCommGroup(), inputs);
auto& result = *p_result;
result.resize(out.size());
for (std::size_t i = 0; i < out.size(); ++i) {
result[i] = {out[i].cbegin(), out[i].cend()};
}
return Success();
}
} // namespace xgboost::collective
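Note: a hedged usage sketch of the AllgatherStrings helper above; the string contents are invented and the communicator is assumed to be initialised.

#include <cstdint>
#include <string>
#include <vector>

#include "allgather.h"  // assumed to be src/collective/allgather.h

void ExampleAllgatherStrings(std::int32_t rank) {
  std::vector<std::string> local{"worker-" + std::to_string(rank)};
  std::vector<std::string> all;
  auto rc = xgboost::collective::AllgatherStrings(local, &all);
  xgboost::collective::SafeColl(rc);
  // With two workers, `all` is {"worker-0", "worker-1"} on both of them.
}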
42 changes: 19 additions & 23 deletions src/collective/allreduce.cc
@@ -68,39 +68,35 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
auto s_buf = common::Span{buffer.data(), buffer.size()};

for (std::int32_t r = 0; r < world - 1; ++r) {
// send to ring next
auto send_rank = (rank + world - r) % world;
auto send_off = send_rank * n_bytes_in_seg;
common::Span<std::int8_t> seg, recv_seg;
auto rc = Success() << [&] {
// send to ring next
auto send_rank = (rank + world - r) % world;
auto send_off = send_rank * n_bytes_in_seg;

bool is_last_segment = send_rank == (world - 1);
bool is_last_segment = send_rank == (world - 1);

auto seg_nbytes = is_last_segment ? data.size_bytes() - send_off : n_bytes_in_seg;
auto send_seg = data.subspan(send_off, seg_nbytes);
auto seg_nbytes = is_last_segment ? data.size_bytes() - send_off : n_bytes_in_seg;
CHECK_EQ(seg_nbytes % sizeof(T), 0);

auto rc = next_ch->SendAll(send_seg);
if (!rc.OK()) {
return rc;
}

// receive from ring prev
auto recv_rank = (rank + world - r - 1) % world;
auto recv_off = recv_rank * n_bytes_in_seg;
auto send_seg = data.subspan(send_off, seg_nbytes);
return next_ch->SendAll(send_seg);
} << [&] {
// receive from ring prev
auto recv_rank = (rank + world - r - 1) % world;
auto recv_off = recv_rank * n_bytes_in_seg;

is_last_segment = recv_rank == (world - 1);
bool is_last_segment = recv_rank == (world - 1);

seg_nbytes = is_last_segment ? data.size_bytes() - recv_off : n_bytes_in_seg;
CHECK_EQ(seg_nbytes % sizeof(T), 0);
auto recv_seg = data.subspan(recv_off, seg_nbytes);
auto seg = s_buf.subspan(0, recv_seg.size());
auto seg_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : n_bytes_in_seg;
CHECK_EQ(seg_nbytes % sizeof(T), 0);

rc = std::move(rc) << [&] {
recv_seg = data.subspan(recv_off, seg_nbytes);
seg = s_buf.subspan(0, recv_seg.size());
return prev_ch->RecvAll(seg);
} << [&] {
return comm.Block();
};
if (!rc.OK()) {
return rc;
}

// accumulate to recv_seg
CHECK_EQ(seg.size(), recv_seg.size());
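Note: the rewrite above folds the send, receive, and block steps into the Result-chaining operator already used in this file; each subsequent lambda runs only while the chain is still successful, which is why a single rc check suffices. A minimal hedged sketch of the pattern, with hypothetical steps:

#include "xgboost/collective/result.h"  // for Result, Success, Fail

xgboost::collective::Result ExampleChain(bool fail_second_step) {
  namespace coll = xgboost::collective;
  auto rc = coll::Success() << [&] {
    return coll::Success();                       // step 1: e.g. post a send
  } << [&] {
    return fail_second_step ? coll::Fail("boom")  // step 2: e.g. post a receive
                            : coll::Success();
  } << [&] {
    return coll::Success();                       // step 3: e.g. block; skipped if step 2 failed
  };
  return rc;  // a single check covers the whole chain
}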
46 changes: 42 additions & 4 deletions src/collective/allreduce.h
@@ -1,15 +1,18 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#pragma once
#include <cstdint> // for int8_t
#include <functional> // for function
#include <type_traits> // for is_invocable_v, enable_if_t
#include <vector> // for vector

#include "../common/type.h" // for EraseType, RestoreType
#include "../data/array_interface.h" // for ArrayInterfaceHandler
#include "../data/array_interface.h" // for ToDType, ArrayInterfaceHandler
#include "comm.h" // for Comm, RestoreType
#include "comm_group.h" // for GlobalCommGroup
#include "xgboost/collective/result.h" // for Result
#include "xgboost/context.h" // for Context
#include "xgboost/span.h" // for Span

namespace xgboost::collective {
@@ -27,8 +30,7 @@ std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>
auto erased = common::EraseType(data);
auto type = ToDType<T>::kType;

auto erased_fn = [type, redop](common::Span<std::int8_t const> lhs,
common::Span<std::int8_t> out) {
auto erased_fn = [redop](common::Span<std::int8_t const> lhs, common::Span<std::int8_t> out) {
CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction.";
auto lhs_t = common::RestoreType<T const>(lhs);
auto rhs_t = common::RestoreType<T>(out);
@@ -37,4 +39,40 @@ std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>

return cpu_impl::RingAllreduce(comm, erased, erased_fn, type);
}

template <typename T, std::int32_t kDim>
[[nodiscard]] Result Allreduce(Context const* ctx, CommGroup const& comm,
linalg::TensorView<T, kDim> data, Op op) {
if (!comm.IsDistributed()) {
return Success();
}
CHECK(data.Contiguous());
auto erased = common::EraseType(data.Values());
auto type = ToDType<T>::kType;

auto backend = comm.Backend(data.Device());
return backend->Allreduce(comm.Ctx(ctx, data.Device()), erased, type, op);
}

template <typename T, std::int32_t kDim>
[[nodiscard]] Result Allreduce(Context const* ctx, linalg::TensorView<T, kDim> data, Op op) {
return Allreduce(ctx, *GlobalCommGroup(), data, op);
}

/**
* @brief Specialization for std::vector.
*/
template <typename T, typename Alloc>
[[nodiscard]] Result Allreduce(Context const* ctx, std::vector<T, Alloc>* data, Op op) {
return Allreduce(ctx, linalg::MakeVec(data->data(), data->size()), op);
}

/**
* @brief Specialization for scalar value.
*/
template <typename T>
[[nodiscard]] std::enable_if_t<std::is_standard_layout_v<T> && std::is_trivial_v<T>, Result>
Allreduce(Context const* ctx, T* data, Op op) {
return Allreduce(ctx, linalg::MakeVec(data, 1), op);
}
} // namespace xgboost::collective
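Note: a hedged usage sketch of the new global Allreduce overloads above; Op::kSum and Op::kMax are assumed from the existing collective Op enum, and the values are invented.

#include <cstdint>
#include <vector>

#include "allreduce.h"        // assumed to be src/collective/allreduce.h
#include "xgboost/context.h"  // for Context

void ExampleAllreduce(std::int32_t rank) {
  xgboost::Context ctx;
  // std::vector specialization: element-wise sum across workers, written back in place.
  std::vector<double> values{1.0, 2.0};
  auto rc = xgboost::collective::Allreduce(&ctx, &values, xgboost::collective::Op::kSum);
  xgboost::collective::SafeColl(rc);
  // Scalar specialization: a single trivially copyable value.
  std::int64_t n_samples = 100 + rank;
  rc = xgboost::collective::Allreduce(&ctx, &n_samples, xgboost::collective::Op::kMax);
  xgboost::collective::SafeColl(rc);
}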
27 changes: 24 additions & 3 deletions src/collective/broadcast.h
@@ -1,11 +1,15 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#pragma once
#include <cstdint> // for int32_t, int8_t

#include "comm.h" // for Comm
#include "xgboost/collective/result.h" // for
#include "../common/type.h"
#include "comm.h" // for Comm, EraseType
#include "comm_group.h" // for CommGroup
#include "xgboost/collective/result.h" // for Result
#include "xgboost/context.h" // for Context
#include "xgboost/linalg.h" // for VectorView
#include "xgboost/span.h" // for Span

namespace xgboost::collective {
@@ -23,4 +27,21 @@ template <typename T>
common::Span<std::int8_t>{reinterpret_cast<std::int8_t*>(data.data()), n_total_bytes};
return cpu_impl::Broadcast(comm, erased, root);
}

template <typename T>
[[nodiscard]] Result Broadcast(Context const* ctx, CommGroup const& comm,
linalg::VectorView<T> data, std::int32_t root) {
if (!comm.IsDistributed()) {
return Success();
}
CHECK(data.Contiguous());
auto erased = common::EraseType(data.Values());
auto backend = comm.Backend(data.Device());
return backend->Broadcast(comm.Ctx(ctx, data.Device()), erased, root);
}

template <typename T>
[[nodiscard]] Result Broadcast(Context const* ctx, linalg::VectorView<T> data, std::int32_t root) {
return Broadcast(ctx, *GlobalCommGroup(), data, root);
}
} // namespace xgboost::collective
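Note: a hedged usage sketch of the new global Broadcast overload above; rank 0 is taken as the root and the weight values are invented.

#include <cstdint>
#include <vector>

#include "broadcast.h"        // assumed to be src/collective/broadcast.h
#include "xgboost/context.h"  // for Context
#include "xgboost/linalg.h"   // for MakeVec

void ExampleBroadcast(std::int32_t rank) {
  xgboost::Context ctx;
  std::vector<float> weights(4, 0.0f);
  if (rank == 0) {
    weights = {0.1f, 0.2f, 0.3f, 0.4f};  // only the root starts with meaningful values
  }
  auto rc = xgboost::collective::Broadcast(
      &ctx, xgboost::linalg::MakeVec(weights.data(), weights.size()), /*root=*/0);
  xgboost::collective::SafeColl(rc);     // every worker now holds the root's values
}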