Allowing custom MPI_Comm for MPI (#559)
* removing SyncCommunicator

* using custom mpi communicator

* changing cylon version

* adding test

* setting MPI_ERRORS_RETURN

* minor bug fix

* adding test

* fixing gcylon build

* attempting to fix MPI::DataType::Free error

* attempting to fix MPI::DataType::Free error

* attempting to fix MPI::DataType::Free error

* attempting to fix MPI::DataType::Free error
nirandaperera authored Jan 14, 2022
1 parent fa52dd4 commit 49b343d
Showing 38 changed files with 602 additions and 283 deletions.
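For orientation before the file-by-file hunks: the point of this change is to let a Cylon MPI context run on a user-supplied MPI_Comm rather than implicitly on MPI_COMM_WORLD. The sketch below is a minimal, hypothetical usage example; the config factory taking the communicator (written here as MPIConfig::Make(sub_comm)) is an assumption, since that part of the API is not among the hunks shown.

```cpp
#include <mpi.h>

#include <cylon/ctx/cylon_context.hpp>
#include <cylon/net/mpi/mpi_communicator.hpp>

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);

  // Split MPI_COMM_WORLD; Cylon is then scoped to the resulting sub-communicator.
  int rank;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm sub_comm;
  MPI_Comm_split(MPI_COMM_WORLD, /*color=*/rank % 2, /*key=*/rank, &sub_comm);

  // Assumption: MPIConfig::Make accepts the custom communicator.
  auto config = cylon::net::MPIConfig::Make(sub_comm);
  auto ctx = cylon::CylonContext::InitDistributed(config);

  // ... distributed table operations on ctx, ranked/sized within sub_comm ...

  ctx.reset();                // release the context before freeing its communicator
  MPI_Comm_free(&sub_comm);
  MPI_Finalize();
  return 0;
}
```

Nothing forces the color = rank % 2 split; any communicator the application already owns, for example one created by another MPI-based framework, could be passed the same way.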
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
@@ -179,6 +179,7 @@ include(Build)
# if building gcylon, no need to build cylon
option(GCYLON_BUILD "Build GCylon" OFF)
if (GCYLON_BUILD)
message("GCylon build enabled")
add_subdirectory(src/gcylon)
add_subdirectory(src/examples/gcylon)
if (CYLON_WITH_TEST)
@@ -192,6 +193,7 @@ endif ()
# ucx
option(CYLON_UCX "Build Cylon with UCX" OFF)
if (CYLON_UCX)
message("Cylon UCX enabled")
# Definition used for checking
add_definitions(-DBUILD_CYLON_UCX)

8 changes: 6 additions & 2 deletions cpp/src/cylon/arrow/arrow_partition_kernels.cpp
@@ -528,8 +528,12 @@ class RangePartitionKernel : public PartitionKernel {
std::vector<uint64_t> global_counts, *global_counts_ptr;
if (ctx->GetWorldSize() > 1) { // if distributed, all-reduce all local bin counts
global_counts.resize(num_bins + 2, 0);
RETURN_CYLON_STATUS_IF_FAILED(cylon::mpi::AllReduce(local_counts.data(), global_counts.data(), num_bins + 2,
cylon::UInt64(), cylon::net::SUM));
RETURN_CYLON_STATUS_IF_FAILED(cylon::mpi::AllReduce(ctx,
local_counts.data(),
global_counts.data(),
num_bins + 2,
cylon::UInt64(),
cylon::net::SUM));
global_counts_ptr = &global_counts;
local_counts.clear();
} else { // else, just use local bin counts
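In the hunk above, cylon::mpi::AllReduce gains the context as its first argument, so the collective runs on whichever communicator the context carries instead of an implicit MPI_COMM_WORLD. A minimal fragment showing the new call shape, assuming a distributed ctx is already in scope:

```cpp
// Sum per-worker bin counts across all workers (mirrors the updated call above).
std::vector<uint64_t> local_counts = {4, 8, 15};
std::vector<uint64_t> global_counts(local_counts.size(), 0);

cylon::Status s = cylon::mpi::AllReduce(ctx,
                                        local_counts.data(),
                                        global_counts.data(),
                                        local_counts.size(),
                                        cylon::UInt64(),
                                        cylon::net::SUM);
if (!s.is_ok()) {
  // handle the failed reduction
}
```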
48 changes: 24 additions & 24 deletions cpp/src/cylon/compute/aggregate_utils.hpp
@@ -31,7 +31,7 @@ namespace compute {
* @return
*/
template<typename NUM_ARROW_T, typename = arrow::enable_if_has_c_type<NUM_ARROW_T>>
cylon::Status AllReduce(cylon::net::CommType comm_type,
cylon::Status AllReduce(const std::shared_ptr<CylonContext> &ctx,
const arrow::Datum &send,
std::shared_ptr<Result> &output,
const std::shared_ptr<DataType> &data_type,
@@ -45,13 +45,13 @@ cylon::Status AllReduce(cylon::net::CommType comm_type,
auto recv_scalar = std::make_shared<ScalarT>(*send_scalar);
std::memset(&recv_scalar->value, 0, sizeof(CType));

switch (comm_type) {
switch (ctx->GetCommType()) {
case net::LOCAL: {
output = std::make_shared<Result>(send);
return cylon::Status::OK();
}
case cylon::net::CommType::MPI: {
cylon::Status status = cylon::mpi::AllReduce(&(send_scalar->value),
cylon::Status status = cylon::mpi::AllReduce(ctx, &(send_scalar->value),
&(recv_scalar->value),
send.length(),
data_type,
@@ -69,7 +69,7 @@ cylon::Status AllReduce(cylon::net::CommType comm_type,
}

template<typename NUM_ARROW_T, typename = arrow::enable_if_has_c_type<NUM_ARROW_T>>
cylon::Status AllReduce(cylon::net::CommType comm_type,
cylon::Status AllReduce(const std::shared_ptr<CylonContext> &ctx,
const arrow::Datum &send,
std::shared_ptr<Result> &output,
const std::shared_ptr<DataType> &data_type,
@@ -87,7 +87,7 @@ cylon::Status AllReduce(cylon::net::CommType comm_type,
arrow::ScalarVector rcv_scalar_vector;
rcv_scalar_vector.reserve(reduce_ops.size());

switch (comm_type) {
switch (ctx->GetCommType()) {
case net::LOCAL: {
output = std::make_shared<Result>(send);
return cylon::Status::OK();
@@ -99,12 +99,13 @@ cylon::Status AllReduce(cylon::net::CommType comm_type,
auto rcv_scalar = std::make_shared<ScalarT>(*send_scalar);
std::memset(&rcv_scalar->value, 0, sizeof(CType));

RETURN_CYLON_STATUS_IF_FAILED(cylon::mpi::AllReduce(send_scalar->data(),
RETURN_CYLON_STATUS_IF_FAILED(cylon::mpi::AllReduce(ctx, send_scalar->data(),
rcv_scalar->mutable_data(),
1, data_type, reduce_ops[i]));
rcv_scalar_vector.push_back(rcv_scalar);
}
auto rcv_struct_scalar = std::make_shared<arrow::StructScalar>(rcv_scalar_vector, send_struct_scalar->type);
auto rcv_struct_scalar = std::make_shared<arrow::StructScalar>(rcv_scalar_vector,
send_struct_scalar->type);
// build the output datum
arrow::Datum global_result(rcv_struct_scalar);
output = std::make_shared<Result>(global_result);
@@ -121,24 +122,23 @@ cylon::Status DoAllReduce(const std::shared_ptr<CylonContext> &ctx,
std::shared_ptr<Result> &rcv,
const std::shared_ptr<DataType> &dtype,
const RED_OPS &red_op) {
auto comm_type = ctx->GetCommType();
switch (dtype->getType()) {
case Type::BOOL:return AllReduce<arrow::BooleanType>(comm_type, snd, rcv, dtype, red_op);
case Type::UINT8:return AllReduce<arrow::UInt8Type>(comm_type, snd, rcv, dtype, red_op);
case Type::INT8:return AllReduce<arrow::Int8Type>(comm_type, snd, rcv, dtype, red_op);
case Type::UINT16:return AllReduce<arrow::UInt16Type>(comm_type, snd, rcv, dtype, red_op);
case Type::INT16:return AllReduce<arrow::Int16Type>(comm_type, snd, rcv, dtype, red_op);
case Type::UINT32:return AllReduce<arrow::UInt32Type>(comm_type, snd, rcv, dtype, red_op);
case Type::INT32:return AllReduce<arrow::Int32Type>(comm_type, snd, rcv, dtype, red_op);
case Type::UINT64:return AllReduce<arrow::UInt64Type>(comm_type, snd, rcv, dtype, red_op);
case Type::INT64:return AllReduce<arrow::Int64Type>(comm_type, snd, rcv, dtype, red_op);
case Type::FLOAT:return AllReduce<arrow::FloatType>(comm_type, snd, rcv, dtype, red_op);
case Type::DOUBLE:return AllReduce<arrow::DoubleType>(comm_type, snd, rcv, dtype, red_op);
case Type::DATE32:return AllReduce<arrow::Date32Type>(comm_type, snd, rcv, dtype, red_op);
case Type::DATE64:return AllReduce<arrow::Date64Type>(comm_type, snd, rcv, dtype, red_op);
case Type::TIMESTAMP:return AllReduce<arrow::TimestampType>(comm_type, snd, rcv, dtype, red_op);
case Type::TIME32:return AllReduce<arrow::Time32Type>(comm_type, snd, rcv, dtype, red_op);
case Type::TIME64:return AllReduce<arrow::Time64Type>(comm_type, snd, rcv, dtype, red_op);
case Type::BOOL:return AllReduce<arrow::BooleanType>(ctx, snd, rcv, dtype, red_op);
case Type::UINT8:return AllReduce<arrow::UInt8Type>(ctx, snd, rcv, dtype, red_op);
case Type::INT8:return AllReduce<arrow::Int8Type>(ctx, snd, rcv, dtype, red_op);
case Type::UINT16:return AllReduce<arrow::UInt16Type>(ctx, snd, rcv, dtype, red_op);
case Type::INT16:return AllReduce<arrow::Int16Type>(ctx, snd, rcv, dtype, red_op);
case Type::UINT32:return AllReduce<arrow::UInt32Type>(ctx, snd, rcv, dtype, red_op);
case Type::INT32:return AllReduce<arrow::Int32Type>(ctx, snd, rcv, dtype, red_op);
case Type::UINT64:return AllReduce<arrow::UInt64Type>(ctx, snd, rcv, dtype, red_op);
case Type::INT64:return AllReduce<arrow::Int64Type>(ctx, snd, rcv, dtype, red_op);
case Type::FLOAT:return AllReduce<arrow::FloatType>(ctx, snd, rcv, dtype, red_op);
case Type::DOUBLE:return AllReduce<arrow::DoubleType>(ctx, snd, rcv, dtype, red_op);
case Type::DATE32:return AllReduce<arrow::Date32Type>(ctx, snd, rcv, dtype, red_op);
case Type::DATE64:return AllReduce<arrow::Date64Type>(ctx, snd, rcv, dtype, red_op);
case Type::TIMESTAMP:return AllReduce<arrow::TimestampType>(ctx, snd, rcv, dtype, red_op);
case Type::TIME32:return AllReduce<arrow::Time32Type>(ctx, snd, rcv, dtype, red_op);
case Type::TIME64:return AllReduce<arrow::Time64Type>(ctx, snd, rcv, dtype, red_op);
default: return {cylon::Code::Invalid, "data type not supported for all reduce!"};
}
}
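The compute-layer wrappers above follow the same pattern: the context replaces the bare CommType, and DoAllReduce dispatches on the Arrow type before forwarding to the typed AllReduce overloads. A hedged fragment of a caller; the concrete data-type helper (cylon::Int64()) and the reduce-op value are assumptions based on the UInt64/SUM usage elsewhere in this commit:

```cpp
// Hypothetical scalar reduction through the compute layer.
arrow::Datum local_sum(std::make_shared<arrow::Int64Scalar>(42));
std::shared_ptr<cylon::compute::Result> global_sum;

cylon::Status s = cylon::compute::DoAllReduce(ctx,
                                              local_sum,
                                              global_sum,
                                              cylon::Int64(),   // assumed data-type helper
                                              cylon::net::SUM); // assumed reduce op
```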
59 changes: 44 additions & 15 deletions cpp/src/cylon/ctx/cylon_context.cpp
@@ -17,11 +17,12 @@
#include <vector>
#include <arrow/memory_pool.h>

#include <cylon/ctx/cylon_context.hpp>
#include <cylon/net/mpi/mpi_communicator.hpp>
#include "cylon/ctx/cylon_context.hpp"
#include "cylon/util/macros.hpp"

#include "cylon/net/mpi/mpi_communicator.hpp"
#ifdef BUILD_CYLON_UCX
#include <cylon/net/ucx/ucx_communicator.hpp>
#include "cylon/net/ucx/ucx_communicator.hpp"
#endif

namespace cylon {
@@ -38,32 +39,64 @@ std::shared_ptr<CylonContext> CylonContext::InitDistributed(const std::shared_pt
auto ctx = std::make_shared<CylonContext>(true);
ctx->communicator = std::make_shared<net::MPICommunicator>();
ctx->communicator->Init(config);

ctx->sync_communicator_ = std::make_shared<net::MPISyncCommunicator>();
ctx->is_distributed = true;
return ctx;
}

#ifdef BUILD_CYLON_UCX
#ifdef BUILD_CYLON_UCX
else if (config->Type() == net::CommType::UCX) {
auto ctx = std::make_shared<CylonContext>(true);
ctx->communicator = std::make_shared<net::UCXCommunicator>();
ctx->communicator->Init(config);

ctx->sync_communicator_ = std::make_shared<net::UCXSyncCommunicator>();
ctx->is_distributed = true;
return ctx;
}
#endif
#endif
else {
throw "Unsupported communication type";
}
return nullptr;
}
std::shared_ptr<net::Communicator> CylonContext::GetCommunicator() const {

Status CylonContext::InitDistributed(const std::shared_ptr<cylon::net::CommConfig> &config,
std::shared_ptr<CylonContext> *ctx) {
switch (config->Type()) {
case net::LOCAL: return {Code::Invalid, "InitDistributed called on Local communication"};

case net::MPI: {
*ctx = std::make_shared<CylonContext>(true);
(*ctx)->communicator = std::make_shared<net::MPICommunicator>();
const auto &status = (*ctx)->communicator->Init(config);
if (!status.is_ok()) {
ctx->reset();
return status;
}
return Status::OK();
}

case net::UCX: {
#ifdef BUILD_CYLON_UCX
*ctx = std::make_shared<CylonContext>(true);
(*ctx)->communicator = std::make_shared<net::UCXCommunicator>();
const auto &status = (*ctx)->communicator->Init(config);
if (!status.is_ok()) {
ctx->reset();
return status;
}
return Status::OK();
#else
return {Code::NotImplemented, "UCX communication not implemented"};
#endif
}

case net::TCP:return {Code::NotImplemented, "TCP communication not implemented"};
}
return Status::OK();
}

const std::shared_ptr<net::Communicator> &CylonContext::GetCommunicator() const {
if (!is_distributed) {
LOG(FATAL) << "No communicator available for local mode!";
return nullptr;
}
return this->communicator;
}
@@ -132,8 +165,4 @@ bool CylonContext::IsDistributed() const {
cylon::net::CommType CylonContext::GetCommType() {
return is_distributed ? this->communicator->GetCommType() : net::CommType::LOCAL;
}

const std::shared_ptr<net::SyncCommunicator> &CylonContext::sync_communicator() const {
return sync_communicator_;
}
} // namespace cylon
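The new InitDistributed overload above reports failures through a Status and an out-parameter instead of throwing a string literal, and it resets the context pointer when the communicator's Init fails. A minimal sketch of the error-checked path; the CommConfig is assumed to be prepared by the caller:

```cpp
// Initialize a context with the Status-returning overload added in this commit.
cylon::Status MakeContext(const std::shared_ptr<cylon::net::CommConfig> &config,
                          std::shared_ptr<cylon::CylonContext> *ctx) {
  const auto &status = cylon::CylonContext::InitDistributed(config, ctx);
  if (!status.is_ok()) {
    return status;  // *ctx was reset inside InitDistributed on failure
  }
  (*ctx)->Barrier();  // Barrier() is const now and routes through the stored communicator
  return cylon::Status::OK();
}
```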
9 changes: 4 additions & 5 deletions cpp/src/cylon/ctx/cylon_context.hpp
@@ -32,7 +32,6 @@ class CylonContext {
std::unordered_map<std::string, std::string> config{};
bool is_distributed;
std::shared_ptr<cylon::net::Communicator> communicator{};
std::shared_ptr<cylon::net::SyncCommunicator> sync_communicator_{};
cylon::MemoryPool *memory_pool{};
int32_t sequence_no = 0;

@@ -55,6 +54,8 @@ class CylonContext {
* @return <cylon::CylonContext*>
*/
static std::shared_ptr<CylonContext> InitDistributed(const std::shared_ptr<cylon::net::CommConfig> &config);
static Status InitDistributed(const std::shared_ptr<cylon::net::CommConfig> &config,
std::shared_ptr<CylonContext> *ctx);

/**
* Completes and closes all operations under the context
@@ -80,9 +81,7 @@ class CylonContext {
* Returns the Communicator instance
* @return <cylon::net::Communicator>
*/
std::shared_ptr<net::Communicator> GetCommunicator() const;

const std::shared_ptr<net::SyncCommunicator>& sync_communicator() const;
const std::shared_ptr<net::Communicator> &GetCommunicator() const;

/**
* Sets a Communicator
@@ -143,7 +142,7 @@ class CylonContext {
/**
* Performs a barrier operation
*/
void Barrier() {
void Barrier() const {
if (this->IsDistributed()) this->GetCommunicator()->Barrier();
}
};
5 changes: 0 additions & 5 deletions cpp/src/cylon/net/communicator.hpp
@@ -37,11 +37,6 @@ class Communicator {
virtual CommType GetCommType() const = 0;

virtual ~Communicator() = default;
};

class SyncCommunicator {
public:
virtual ~SyncCommunicator() = default;

virtual Status AllGather(const std::shared_ptr<Table> &table,
std::vector<std::shared_ptr<Table>> *out) const = 0;
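With SyncCommunicator removed, its collective table operations now live directly on Communicator, so callers reach them through the context's single communicator instance. A short fragment of an AllGather over local tables, assuming a distributed ctx and a populated cylon::Table:

```cpp
// Gather every worker's local table to all workers via the merged interface.
std::shared_ptr<cylon::Table> local_table;  // assumed to be built elsewhere
std::vector<std::shared_ptr<cylon::Table>> gathered;

const auto &comm = ctx->GetCommunicator();
cylon::Status s = comm->AllGather(local_table, &gathered);
if (!s.is_ok()) {
  // handle the failed collective
}
```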
14 changes: 7 additions & 7 deletions cpp/src/cylon/net/mpi/mpi_channel.cpp
@@ -40,7 +40,7 @@ void MPIChannel::init(int ed, const std::vector<int> &receives, const std::vecto
buf->receiveId = source;
pendingReceives.insert(std::pair<int, PendingReceive *>(source, buf));
MPI_Irecv(buf->headerBuf, CYLON_CHANNEL_HEADER_SIZE, MPI_INT,
source, edge, MPI_COMM_WORLD, &buf->request);
source, edge, comm_, &buf->request);
// set the flag to true so we can identify later which buffers are posted
buf->status = RECEIVE_LENGTH_POSTED;
}
@@ -49,7 +49,7 @@ void MPIChannel::init(int ed, const std::vector<int> &receives, const std::vecto
sends[target] = new PendingSend();
}
// get the rank
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_rank(comm_, &rank);
}

int MPIChannel::send(std::shared_ptr<TxRequest> request) {
@@ -97,7 +97,7 @@ void MPIChannel::progressReceives() {
}
x.second->length = length;
MPI_Irecv(x.second->data->GetByteBuffer(), length, MPI_BYTE, x.second->receiveId, edge,
MPI_COMM_WORLD, &(x.second->request));
comm_, &(x.second->request));
x.second->status = RECEIVE_POSTED;
// copy the count - 2 to the buffer
int *header = nullptr;
@@ -131,7 +131,7 @@ void MPIChannel::progressReceives() {
// clear the array
std::fill_n(x.second->headerBuf, CYLON_CHANNEL_HEADER_SIZE, 0);
MPI_Irecv(x.second->headerBuf, CYLON_CHANNEL_HEADER_SIZE, MPI_INT,
x.second->receiveId, edge, MPI_COMM_WORLD, &(x.second->request));
x.second->receiveId, edge, comm_, &(x.second->request));
x.second->status = RECEIVE_LENGTH_POSTED;
// call the back end
rcv_fn->receivedData(x.first, x.second->data, x.second->length);
@@ -155,7 +155,7 @@ void MPIChannel::progressSends() {
// now post the actual send
std::shared_ptr<TxRequest> r = x.second->pendingData.front();
MPI_Isend(r->buffer, r->length, MPI_BYTE,
r->target, edge, MPI_COMM_WORLD, &(x.second->request));
r->target, edge, comm_, &(x.second->request));
x.second->status = SEND_POSTED;
x.second->pendingData.pop();
// we set to the current send and pop it
@@ -220,7 +220,7 @@ void MPIChannel::sendHeader(const std::pair<const int, PendingSend *> &x) const
}
// we have to add 2 to the header length
MPI_Isend(&(x.second->headerBuf[0]), 2 + r->headerLength, MPI_INT,
x.first, edge, MPI_COMM_WORLD, &(x.second->request));
x.first, edge, comm_, &(x.second->request));
x.second->status = SEND_LENGTH_POSTED;
}

@@ -229,7 +229,7 @@ void MPIChannel::sendFinishHeader(const std::pair<const int, PendingSend *> &x)
x.second->headerBuf[0] = 0;
x.second->headerBuf[1] = CYLON_MSG_FIN;
MPI_Isend(&(x.second->headerBuf[0]), 2, MPI_INT,
x.first, edge, MPI_COMM_WORLD, &(x.second->request));
x.first, edge, comm_, &(x.second->request));
x.second->status = SEND_FINISH;
}

5 changes: 4 additions & 1 deletion cpp/src/cylon/net/mpi/mpi_channel.hpp
@@ -72,13 +72,15 @@ struct PendingReceive {
*/
class MPIChannel : public Channel {
public:
explicit MPIChannel(MPI_Comm comm) : comm_(comm) {}

/**
* Initialize the channel
*
* @param receives receive from these ranks
*/
void init(int edge, const std::vector<int> &receives, const std::vector<int> &sendIds,
ChannelReceiveCallback *rcv, ChannelSendCallback *send, Allocator *alloc) override;
ChannelReceiveCallback *rcv, ChannelSendCallback *send, Allocator *alloc) override;

/**
* Send the message to the target.
@@ -126,6 +128,7 @@ class MPIChannel : public Channel {
Allocator *allocator;
// mpi rank
int rank;
MPI_Comm comm_;

/**
* Send finish request
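MPIChannel now stores its own communicator (comm_), and every MPI_Irecv/MPI_Isend in mpi_channel.cpp is posted on it rather than on MPI_COMM_WORLD. How MPICommunicator hands its communicator to each channel is not shown in these hunks, so the following is only a sketch; duplicating the application's communicator for channel traffic is an assumption, as is the cylon namespace qualifier:

```cpp
// Sketch only: keep channel traffic off MPI_COMM_WORLD by giving the channel
// a duplicate of an application-owned communicator.
MPI_Comm user_comm = MPI_COMM_WORLD;  // stand-in; any application-owned communicator
MPI_Comm channel_comm;
MPI_Comm_dup(user_comm, &channel_comm);

auto channel = std::make_shared<cylon::MPIChannel>(channel_comm);
// channel->init(edge, receive_ids, send_ids, rcv_callback, send_callback, allocator);
```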