[xla:gpu] Add NCCL communicator splitting
Porting #7586 to new NcclClique APIs

PiperOrigin-RevId: 608015848
ezhulenev authored and copybara-github committed Feb 17, 2024
1 parent 840f8e0 commit d6cfa24
Showing 13 changed files with 278 additions and 90 deletions.
7 changes: 7 additions & 0 deletions xla/debug_options_flags.cc
@@ -143,6 +143,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_gpu_nccl_termination_timeout_seconds(-1);
opts.set_xla_gpu_enable_shared_constants(true);
opts.set_xla_gpu_enable_nccl_user_buffers(false);
opts.set_xla_gpu_enable_nccl_comm_splitting(false);

// Set 4GB space limit for redzone scratch allocator.
opts.set_xla_gpu_redzone_scratch_max_megabytes(1LL << 12);
@@ -1169,6 +1170,12 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"Enables NCCL User Buffer Registration. collective_memory_size in the "
"allocator config must also be set to a non-zero value that is large "
"enough to meet peak collective memory usage."));
flag_list->push_back(tsl::Flag(
"xla_gpu_enable_nccl_comm_splitting",
bool_setter_for(&DebugOptions::set_xla_gpu_enable_nccl_comm_splitting),
debug_options->xla_gpu_enable_nccl_comm_splitting(),
"Enables NCCL communicator splitting which allows sharing NCCL resources "
"between different NCCL cliques."));
flag_list->push_back(tsl::Flag(
"xla_gpu_redzone_scratch_max_megabytes",
int64_setter_for(
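The new flag above only toggles a DebugOptions field. As a minimal, hypothetical sketch (not part of this commit), it could be enabled programmatically as shown below, or on the command line / via XLA_FLAGS as --xla_gpu_enable_nccl_comm_splitting=true; the include path for the generated DebugOptions header is an assumption.

#include "xla/xla.pb.h"  // generated header for DebugOptions (path assumed)

// Build a DebugOptions proto with NCCL communicator splitting enabled, using
// the setter registered by the flag above.
xla::DebugOptions DebugOptionsWithCommSplitting() {
  xla::DebugOptions opts;
  opts.set_xla_gpu_enable_nccl_comm_splitting(true);
  return opts;
}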
3 changes: 2 additions & 1 deletion xla/service/gpu/BUILD
@@ -980,6 +980,7 @@ cc_library(
"//xla/stream_executor",
"//xla/stream_executor/gpu:gpu_activation",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
@@ -1066,7 +1067,6 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/hash",
@@ -1079,6 +1079,7 @@
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:hash",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
],
69 changes: 33 additions & 36 deletions xla/service/gpu/nccl_api.cc
@@ -17,7 +17,6 @@ limitations under the License.

#include <cstddef>
#include <cstdint>
#include <iterator>
#include <optional>
#include <string_view>
#include <utility>
@@ -302,13 +301,9 @@ class DefaultNcclApi final : public NcclApi {
int32_t nranks, const NcclCliqueId& clique_id,
absl::Span<const DeviceRank> ranks) final;

absl::StatusOr<OwnedNcclComm> CommSplit(NcclCommHandle comm,
std::optional<int32_t> color,
int32_t key) final;

absl::StatusOr<std::vector<OwnedNcclComm>> CommSplit(
absl::Span<const DeviceComm> comms,
absl::Span<const int32_t> ranks) final;
absl::Span<const NcclCommHandle> comms, int32_t color,
absl::Span<const int32_t> keys) final;

absl::Status CommAbort(NcclCommHandle comm) final;
absl::Status CommFinalize(NcclCommHandle comm) final;
@@ -408,42 +403,44 @@ DefaultNcclApi::CommInitRanks(int32_t nranks, const NcclCliqueId& clique_id,
return comms;
}

absl::StatusOr<NcclApi::OwnedNcclComm> DefaultNcclApi::CommSplit(
NcclCommHandle comm, std::optional<int32_t> color, int32_t key) {
VLOG(1) << "Split NCCL communicator " << comm << " with color "
<< color.value_or(NCCL_SPLIT_NOCOLOR) << " and key " << key;

ncclComm_t split_comm = nullptr;
XLA_NCCL_RETURN_IF_ERROR(ncclCommSplit(Cast(comm),
color.value_or(NCCL_SPLIT_NOCOLOR),
key, &split_comm, /*config=*/nullptr));

return OwnedNcclComm(Cast(split_comm), NcclCommDeleter{this});
}

absl::StatusOr<std::vector<NcclApi::OwnedNcclComm>> DefaultNcclApi::CommSplit(
absl::Span<const DeviceComm> comms, absl::Span<const int32_t> ranks) {
VLOG(1) << "Split " << comms.size() << " NCCL communicators with"
<< " participating split ranks [" << absl::StrJoin(ranks, ",") << "]";
absl::Span<const NcclCommHandle> comms, int32_t color,
absl::Span<const int32_t> keys) {
VLOG(1) << absl::StreamFormat(
"Split %d NCCL communicators using color %d and keys: [%s]", comms.size(),
color, absl::StrJoin(keys, ","));

if (keys.size() != comms.size()) {
return absl::InvalidArgumentError(
absl::StrFormat("Comms and keys must have the same size, but %d != %d",
comms.size(), keys.size()));
}

std::vector<OwnedNcclComm> split_comms;
std::vector<ncclComm_t> split_comms;
split_comms.resize(comms.size(), nullptr);

TF_RETURN_IF_ERROR(GroupStart());
for (int32_t rank = 0; rank < comms.size(); ++rank) {
se::gpu::ScopedActivateExecutorContext activate_context(comms[rank].device);

if (auto it = absl::c_find(ranks, rank); it != ranks.end()) {
TF_ASSIGN_OR_RETURN(split_comms.emplace_back(),
CommSplit(comms[rank].comm, /*color=*/0,
/*key=*/std::distance(ranks.begin(), it)));
} else {
TF_RETURN_IF_ERROR(
CommSplit(comms[rank].comm, std::nullopt, rank).status());
}
for (size_t i = 0; i < comms.size(); ++i) {
VLOG(1) << "Split NCCL communicator " << comms[i] << " with color " << color
<< " and key " << keys[i];
XLA_NCCL_RETURN_IF_ERROR(ncclCommSplit(
Cast(comms[i]), color, keys[i], &split_comms[i], /*config=*/nullptr));
}
TF_RETURN_IF_ERROR(GroupEnd());

return split_comms;
// Check that every split rank got a communicator and convert created
// communicators into owned RAII wrappers.
std::vector<OwnedNcclComm> split_owned_comms;
for (size_t i = 0; i < split_comms.size(); ++i) {
if (split_comms[i] == nullptr) {
return absl::InternalError(absl::StrFormat(
"Failed to create a split communicator with color %d for key %d",
color, keys[i]));
}
split_owned_comms.emplace_back(Cast(split_comms[i]), NcclCommDeleter{this});
}

return split_owned_comms;
}

absl::Status DefaultNcclApi::CommAbort(NcclCommHandle comm) {
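To illustrate how a caller might use the new batched CommSplit (a sketch under assumptions, not code from this commit): every participating rank passes the same color and its index as the key, mirroring the color/key arguments forwarded to ncclCommSplit above. The helper name SplitIntoSubClique and the nested-type qualification NcclApi::NcclCommHandle are assumptions for illustration.

#include <cstdint>
#include <numeric>
#include <vector>

#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/service/gpu/nccl_api.h"

namespace xla::gpu {

// Hypothetical helper: split all `parent_comms` into one new clique. Every
// rank uses color 0, and its position in the span becomes the key, so the
// split communicators keep the parent ordering.
absl::StatusOr<std::vector<NcclApi::OwnedNcclComm>> SplitIntoSubClique(
    NcclApi* api, absl::Span<const NcclApi::NcclCommHandle> parent_comms) {
  std::vector<int32_t> keys(parent_comms.size());
  std::iota(keys.begin(), keys.end(), 0);  // keys = 0, 1, ..., N-1
  return api->CommSplit(parent_comms, /*color=*/0, keys);
}

}  // namespace xla::gpu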
27 changes: 11 additions & 16 deletions xla/service/gpu/nccl_api.h
@@ -19,7 +19,6 @@ limitations under the License.
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <vector>

#include "absl/status/status.h"
@@ -117,13 +116,11 @@ class NcclApi {
};

struct DeviceRank {
se::StreamExecutor* device;
int32_t rank;
};
DeviceRank(se::StreamExecutor* device, int32_t rank)
: device(device), rank(rank) {}

struct DeviceComm {
se::StreamExecutor* device;
NcclCommHandle comm;
int32_t rank;
};

// Returns a slice of device memory `buff` containing `count` values of data
@@ -150,23 +147,21 @@
//
// This API doesn't have a corresponding API in NCCL and implemented as
// multiple calls to CommInitRank within a single group.
//
// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcomminitrank
virtual absl::StatusOr<std::vector<OwnedNcclComm>> CommInitRanks(
int32_t nranks, const NcclCliqueId& clique_id,
absl::Span<const DeviceRank> ranks) = 0;

// Creates a new communicator by splitting an existing one.
//
// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommsplit
virtual absl::StatusOr<OwnedNcclComm> CommSplit(NcclCommHandle comm,
std::optional<int32_t> color,
int32_t key) = 0;

// Creates new communicators by splitting existing ones.
// Creates new communicators by splitting `comms`.
//
// This API doesn't have a corresponding API in NCCL and implemented as
// multiple calls to CommSplit within a single group.
// multiple calls to ncclCommSplit within a single group.
//
// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommsplit
virtual absl::StatusOr<std::vector<OwnedNcclComm>> CommSplit(
absl::Span<const DeviceComm> comms, absl::Span<const int32_t> ranks) = 0;
absl::Span<const NcclCommHandle> comms, int32_t color,
absl::Span<const int32_t> keys) = 0;

// Abort any uncompleted operations and destroys the communicator. Frees
// resources that are allocated to a communicator object comm.
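For reference, a minimal sketch of the raw NCCL primitive that the batched CommSplit above wraps between GroupStart and GroupEnd (see the ncclCommSplit documentation linked in the header comment); the function and variable names here are placeholders. Ranks that pass the same non-negative color join one new communicator, ordered by key, while a rank that passes NCCL_SPLIT_NOCOLOR receives no new communicator.

#include <nccl.h>  // NCCL 2.18+ provides ncclCommSplit (include style assumed)

// Sketch only: one rank's call to the underlying NCCL split primitive.
ncclResult_t SplitOneRank(ncclComm_t parent_comm, int my_key,
                          ncclComm_t* split_comm) {
  // Color 0 joins the new communicator; NCCL_SPLIT_NOCOLOR would opt out.
  return ncclCommSplit(parent_comm, /*color=*/0, /*key=*/my_key, split_comm,
                       /*config=*/nullptr);
}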
9 changes: 2 additions & 7 deletions xla/service/gpu/nccl_api_stub.cc
@@ -96,14 +96,9 @@ class NcclApiStub final : public NcclApi {
return UnimplementedError();
}

absl::StatusOr<OwnedNcclComm> CommSplit(NcclCommHandle comm,
std::optional<int32_t>,
int32_t) final {
return UnimplementedError();
}

absl::StatusOr<std::vector<OwnedNcclComm>> CommSplit(
absl::Span<const DeviceComm>, absl::Span<const int32_t>) final {
absl::Span<const NcclCommHandle>, int32_t,
absl::Span<const int32_t>) final {
return UnimplementedError();
}

(Diffs for the remaining 8 changed files in this commit are not shown.)
