NCCL API Executor Integration #331

Merged 15 commits on Jul 25, 2024
194 changes: 128 additions & 66 deletions apps/nccl/src/nccl.cu
@@ -4,6 +4,7 @@
#include <algorithm>
#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/executor.hpp>
#include <mscclpp/sm_channel.hpp>
#include <mscclpp/sm_channel_device.hpp>
#include <unordered_map>
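
The new <mscclpp/executor.hpp> include brings in the executor API this file now depends on: mscclpp::Executor, mscclpp::ExecutionPlan, mscclpp::DataType, and mscclpp::PacketType, all used in the hunks below.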
@@ -54,6 +55,9 @@ struct ncclComm {
std::shared_ptr<mscclpp::Communicator> comm;
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> smSemaphores;
std::shared_ptr<mscclpp::Executor> executor;
std::shared_ptr<mscclpp::ExecutionPlan> allReducePacketIPPlan, allReducePacketOPPlan, allReduceIPPlan,
allReduceOPPlan;

std::unordered_map<channelKey, ChannelInfo> channelInInfos;
std::unordered_map<channelKey, ChannelInfo> channelOutInfos;
@@ -151,6 +155,86 @@ static std::shared_ptr<mscclpp::DeviceHandle<mscclpp::SmChannel>> setupSmChannel
return ptr;
}

static ncclResult_t ncclAllReduceOld(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
ncclRedOp_t, ncclComm_t comm, cudaStream_t stream) {
// Checking if the parameters are valid
if (sendbuff == nullptr || recvbuff == nullptr || count == 0 || ncclTypeSize(datatype) == 0 || comm == nullptr)
return ncclInvalidArgument;

// Declaring variables
size_t sendBytes, recvBytes;
CUdeviceptr sendBasePtr, recvBasePtr;
MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendBytes, (CUdeviceptr)sendbuff));
MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff));
size_t offsetIn = (char*)sendbuff - (char*)sendBasePtr;
size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
uint32_t scratchBuffIdx = (++(comm->buffFlag)) % comm->numScratchBuff;
size_t offsetScratch = (SCRATCH_SIZE / comm->numScratchBuff) * scratchBuffIdx;
int rank = comm->comm->bootstrap()->getRank();
channelKey sendKey{(void*)sendBasePtr, sendBytes};
channelKey recvKey{(void*)recvBasePtr, recvBytes};
mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;
mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels = nullptr;

// Creating the channels
if (count * ncclTypeSize(datatype) <= (1 << 20)) {
auto sendIt = comm->channelScratchInfos.find(sendKey);
if (sendIt == comm->channelScratchInfos.end()) {
std::vector<mscclpp::SmChannel> channels =
setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)sendBasePtr));
ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
sendIt = comm->channelScratchInfos.emplace(sendKey, channelInfo).first;
}

smChannels = sendIt->second.smChannelDeviceHandles.get();
} else {
std::vector<mscclpp::RegisteredMemory> remoteMemories;

auto sendIt = comm->channelInInfos.find(sendKey);
if (sendIt == comm->channelInInfos.end()) {
std::vector<mscclpp::SmChannel> channels =
setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)sendBasePtr));
ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
sendIt = comm->channelInInfos.emplace(sendKey, channelInfo).first;
}

auto recvIt = comm->channelOutInfos.find(recvKey);
if (recvIt == comm->channelOutInfos.end()) {
remoteMemories =
setupRemoteMemories(comm->comm, rank, (void*)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc);
std::vector<mscclpp::SmChannel> outChannels =
setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
ChannelInfo channelInfo{outChannels, setupSmChannelDeviceHandles(outChannels)};
recvIt = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
}

smChannels = sendIt->second.smChannelDeviceHandles.get();
smOutChannels = recvIt->second.smChannelDeviceHandles.get();
}

switch (datatype) {
case ncclFloat16:
CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smOutChannels,
offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
comm->comm->bootstrap()->getNranks(), count, stream));
break;
case ncclFloat32:
CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, smChannels,
smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
break;
case ncclInt32:
case ncclUint32:
CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smOutChannels,
offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE,
comm->comm->bootstrap()->getNranks(), count, stream));
break;
default:
return ncclInvalidArgument;
}
return ncclSuccess;
}
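
The fallback path above (now ncclAllReduceOld) resolves each buffer to its base allocation with cuMemGetAddressRange and caches channels per (base pointer, allocation size) key, so repeated collectives on the same allocation skip the channel setup. A minimal sketch of that lookup-or-create pattern, with stand-in types (Channel and makeChannels are illustrative, not the mscclpp API):

#include <cstddef>
#include <functional>
#include <unordered_map>
#include <vector>

// Stand-ins for mscclpp::SmChannel and setupSmChannels.
struct Channel {};
std::vector<Channel> makeChannels(void* /*base*/) { return std::vector<Channel>(8); }

struct ChannelKey {
  void* basePtr;
  std::size_t bytes;
  bool operator==(const ChannelKey& o) const { return basePtr == o.basePtr && bytes == o.bytes; }
};

struct ChannelKeyHash {
  std::size_t operator()(const ChannelKey& k) const {
    return std::hash<void*>()(k.basePtr) ^ (std::hash<std::size_t>()(k.bytes) << 1);
  }
};

using ChannelCache = std::unordered_map<ChannelKey, std::vector<Channel>, ChannelKeyHash>;

// Lookup-or-create: the first collective on an allocation pays the setup cost;
// later calls with the same (base, size) key reuse the cached channels.
std::vector<Channel>& getOrCreate(ChannelCache& cache, void* base, std::size_t bytes) {
  ChannelKey key{base, bytes};
  auto it = cache.find(key);
  if (it == cache.end()) it = cache.emplace(key, makeChannels(base)).first;
  return it->second;
}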

NCCL_API ncclResult_t ncclGetVersion(int* version) {
if (version == nullptr) return ncclInvalidArgument;
*version = MSCCLPP_VERSION;
@@ -211,6 +295,20 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
commPtr->scratchBuff = mscclpp::allocExtSharedCuda<char>(SCRATCH_SIZE);
commPtr->remoteScratchRegMemories =
setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc);
commPtr->executor = std::make_shared<mscclpp::Executor>(mscclppComm);

if (getenv("ALLREDUCEPKT_IP_JSON_FILE"))
commPtr->allReducePacketIPPlan = std::make_shared<mscclpp::ExecutionPlan>(
mscclpp::ExecutionPlan("allreduce_packet", getenv("ALLREDUCEPKT_IP_JSON_FILE")));
if (getenv("ALLREDUCEPKT_OP_JSON_FILE"))
commPtr->allReducePacketOPPlan = std::make_shared<mscclpp::ExecutionPlan>(
mscclpp::ExecutionPlan("allreduce_packet", getenv("ALLREDUCEPKT_OP_JSON_FILE")));
if (getenv("ALLREDUCE_IP_JSON_FILE"))
commPtr->allReduceIPPlan =
std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allreduce", getenv("ALLREDUCE_IP_JSON_FILE")));
if (getenv("ALLREDUCE_OP_JSON_FILE"))
commPtr->allReduceOPPlan =
std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allreduce", getenv("ALLREDUCE_OP_JSON_FILE")));

*comm = commPtr;
return ncclSuccess;
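
Plan selection is driven entirely by environment variables read at communicator creation: the Packet variants cover the LL (low-latency) packet protocol, and the IP/OP suffixes distinguish in-place from out-of-place plans, chosen later by comparing sendbuff and recvbuff. A usage sketch under stated assumptions: the JSON paths are placeholders, and each plan file must carry the matching name ("allreduce_packet" or "allreduce") or loading will throw "Plan name does not match".

#include <cstdlib>

// Each variable is optional; an unset one leaves the corresponding plan null,
// and ncclAllReduce then falls back to the old kernel path for that case.
void configureAllReducePlans() {
  setenv("ALLREDUCEPKT_IP_JSON_FILE", "/path/to/allreduce_packet_ip.json", /*overwrite=*/1);
  setenv("ALLREDUCEPKT_OP_JSON_FILE", "/path/to/allreduce_packet_op.json", 1);
  setenv("ALLREDUCE_IP_JSON_FILE", "/path/to/allreduce_ip.json", 1);
  setenv("ALLREDUCE_OP_JSON_FILE", "/path/to/allreduce_op.json", 1);
  // Must run before ncclCommInitRank, which is where the plans are read.
}
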
@@ -321,82 +419,45 @@ NCCL_API ncclResult_t ncclBroadcast(const void*, void*, size_t, ncclDataType_t,
}

NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
ncclRedOp_t, ncclComm_t comm, cudaStream_t stream) {
ncclRedOp_t reductionOperation, ncclComm_t comm, cudaStream_t stream) {
// Checking if the parameters are valid
if (sendbuff == nullptr || recvbuff == nullptr || count == 0 || ncclTypeSize(datatype) == 0 || comm == nullptr)
return ncclInvalidArgument;

// Declaring variables
size_t sendBytes, recvBytes;
CUdeviceptr sendBasePtr, recvBasePtr;
MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendBytes, (CUdeviceptr)sendbuff));
MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff));
size_t offsetIn = (char*)sendbuff - (char*)sendBasePtr;
size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
uint32_t scratchBuffIdx = (++(comm->buffFlag)) % comm->numScratchBuff;
size_t offsetScratch = (SCRATCH_SIZE / comm->numScratchBuff) * scratchBuffIdx;
size_t bytes = count * ncclTypeSize(datatype);
int rank = comm->comm->bootstrap()->getRank();
channelKey sendKey{(void*)sendBasePtr, sendBytes};
channelKey recvKey{(void*)recvBasePtr, recvBytes};
mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;
mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels = nullptr;

// Creating the channels
if (count * ncclTypeSize(datatype) <= (1 << 20)) {
auto sendIt = comm->channelScratchInfos.find(sendKey);
if (sendIt == comm->channelScratchInfos.end()) {
std::vector<mscclpp::SmChannel> channels =
setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)sendBasePtr));
ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
sendIt = comm->channelScratchInfos.emplace(sendKey, channelInfo).first;
}

smChannels = sendIt->second.smChannelDeviceHandles.get();
if (bytes < 16 * (1 << 10)) {
return ncclAllReduceOld(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream);
} else {
std::vector<mscclpp::RegisteredMemory> remoteMemories;

auto sendIt = comm->channelInInfos.find(sendKey);
if (sendIt == comm->channelInInfos.end()) {
std::vector<mscclpp::SmChannel> channels =
setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)sendBasePtr));
ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
sendIt = comm->channelInInfos.emplace(sendKey, channelInfo).first;
}

auto recvIt = comm->channelOutInfos.find(recvKey);
if (recvIt == comm->channelOutInfos.end()) {
remoteMemories =
setupRemoteMemories(comm->comm, rank, (void*)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc);
std::vector<mscclpp::SmChannel> outChannels =
setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
ChannelInfo channelInfo{outChannels, setupSmChannelDeviceHandles(outChannels)};
recvIt = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
std::shared_ptr<mscclpp::ExecutionPlan> plan;
if (bytes <= (1 << 20))
plan = (sendbuff == recvbuff) ? comm->allReducePacketIPPlan : comm->allReducePacketOPPlan;
else
plan = (sendbuff == recvbuff) ? comm->allReduceIPPlan : comm->allReduceOPPlan;

if (plan == nullptr) return ncclAllReduceOld(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream);

switch (datatype) {
case ncclFloat16:
comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT16, 1024,
*plan, stream, mscclpp::PacketType::LL8);
break;
case ncclFloat32:
comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT32,
1024, *plan, stream, mscclpp::PacketType::LL8);
break;
case ncclInt32:
case ncclUint32:
comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes, mscclpp::DataType::UINT32, 1024,
*plan, stream, mscclpp::PacketType::LL8);
break;
default:
return ncclInvalidArgument;
}

smChannels = sendIt->second.smChannelDeviceHandles.get();
smOutChannels = recvIt->second.smChannelDeviceHandles.get();
}

switch (datatype) {
case ncclFloat16:
CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smOutChannels,
offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
comm->comm->bootstrap()->getNranks(), count, stream));
break;
case ncclFloat32:
CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, smChannels,
smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
break;
case ncclInt32:
case ncclUint32:
CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smOutChannels,
offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE,
comm->comm->bootstrap()->getNranks(), count, stream));
break;
default:
return ncclInvalidArgument;
}
return ncclSuccess;
}
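
The rewritten ncclAllReduce is now a thin dispatcher over message size and buffer aliasing; note that ncclInt32 and ncclUint32 both map to mscclpp::DataType::UINT32, and a null plan falls back to ncclAllReduceOld. A compact restatement of the policy, with the thresholds copied from the code above:

#include <cstddef>

enum class Path { OldKernel, PacketPlan, NonPacketPlan };

// Thresholds from ncclAllReduce: 16 KiB and 1 MiB.
Path selectAllReducePath(std::size_t bytes) {
  if (bytes < 16 * (1 << 10)) return Path::OldKernel;  // small: original kernel path
  if (bytes <= (1 << 20)) return Path::PacketPlan;     // mid-size: LL packet plan
  return Path::NonPacketPlan;                          // large: non-packet plan
}
// Within the plan paths, sendbuff == recvbuff selects the *IPPlan (in-place)
// variant, anything else the *OPPlan (out-of-place) variant.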

@@ -442,6 +503,7 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
CUDACHECK(allgather<true>((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank,
NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream));
}

return ncclSuccess;
}

@@ -468,4 +530,4 @@ NCCL_API ncclResult_t ncclGroupStart() {
NCCL_API ncclResult_t ncclGroupEnd() {
// Do nothing
return ncclSuccess;
}
}
5 changes: 4 additions & 1 deletion python/test/test_mscclpp.py
@@ -613,7 +613,10 @@ def test_executor(mpi_group: MpiGroup, filename: str):
cp.random.seed(42)
buffer = cp.random.random(nelems).astype(cp.float16)
sub_arrays = cp.split(buffer, mpi_group.comm.size)
sendbuf = sub_arrays[mpi_group.comm.rank]
nelems_per_rank = int(nelems / mpi_group.comm.size)
sendbuf = cp.empty(nelems_per_rank).astype(cp.float16)
for i in range(nelems_per_rank):
sendbuf[i] = sub_arrays[mpi_group.comm.rank][i]
expected = cp.zeros_like(sendbuf)
for i in range(mpi_group.comm.size):
expected += sub_arrays[i]
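
The test change replaces a view into buffer (cp.split returns views) with an explicit element-by-element copy into a freshly allocated sendbuf, presumably so the executor receives a standalone contiguous allocation rather than a slice of the shared input.
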
44 changes: 37 additions & 7 deletions src/executor/execution_plan.cc
@@ -123,7 +123,7 @@ std::vector<Operation> ExecutionPlan::Impl::getOperations(int rank, int threadbl

int ExecutionPlan::Impl::getThreadblockCount(int rank) const { return this->operations.at(rank).size(); }

void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize) {
void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset) {
std::ifstream file(this->planPath);
json obj = json::parse(file);
if (this->name != obj["name"]) {
@@ -145,7 +145,31 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize) {
this->setupChannels(gpus);

this->inputSize = inputSize;
this->setupOperations(gpus);
this->setupOperations(gpus, contsSrcOffset, constDstOffset);
}

void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset) {
std::ifstream file(this->planPath);
json obj = json::parse(file);
if (this->name != obj["name"]) {
throw Error("Plan name does not match", ErrorCode::ExecutorError);
}
std::string protocol = obj["protocol"];
if (protocol == "LL") {
this->isUsingPacket = true;
}
const auto& gpus = obj["gpus"];

for (const auto& gpu : gpus) {
int rank = gpu["id"];
this->inputChunks[rank] = gpu["inputChunks"];
this->outputChunks[rank] = gpu["outputChunks"];
this->scratchChunks[rank] = gpu["scratchChunks"];
this->chunkGroups[rank] = gpu["chunkGroups"];
}

this->inputSize = inputSize;
this->setupOperations(gpus, contsSrcOffset, constDstOffset);
}
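
lightLoadExecutionPlan re-parses only the chunk metadata and re-derives operations with fresh constant offsets, skipping setupChannels; together with operationsReset() further down, this lets an already-loaded plan be re-targeted at new buffer offsets without re-establishing channels. A sketch of that intended reuse, with a stand-in for the internal Impl class (method names and signatures taken from this diff, call sites simplified):

#include <cstddef>

// Stand-in mirroring the Impl methods touched in this diff.
struct PlanImplSketch {
  void loadExecutionPlan(std::size_t inputSize, std::size_t srcOff, std::size_t dstOff);       // channels + operations
  void lightLoadExecutionPlan(std::size_t inputSize, std::size_t srcOff, std::size_t dstOff);  // operations only
  void operationsReset();                                                                      // clear derived operations
};

void reuseLoadedPlan(PlanImplSketch& plan, std::size_t inputSize) {
  plan.loadExecutionPlan(inputSize, 0, 0);             // first run: full load
  // ... launch ...
  plan.operationsReset();                              // drop stale per-threadblock operations
  plan.lightLoadExecutionPlan(inputSize, 4096, 4096);  // re-derive with new offsets; channels reused
  // ... launch again with shifted buffers ...
}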

// Construct the channel info. Step 1. Flatten SM and PROXY channels into separate vectors.
@@ -201,7 +225,7 @@ void ExecutionPlan::Impl::setupChannels(const json& gpus) {
}
}

void ExecutionPlan::Impl::setupOperations(const json& gpus) {
void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffset, size_t constDstOffset) {
// setup threadblocks and operations
for (const auto& gpu : gpus) {
int rank = gpu["id"];
@@ -234,7 +258,8 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus) {
// Get the relevant channel index in rank channelInfos
operation.inputChannelIndexes[i] =
channelIndexes[{srcBufferType, dstBufferType, operation.channelType}][op["i_cids"][i]["id"]];
operation.inputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["i_cids"][i]["off"]);
operation.inputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["i_cids"][i]["off"]) +
(srcBufferType != BufferType::SCRATCH ? contsSrcOffset : 0);
chunkIndexes.push_back((uint32_t)op["i_cids"][i]["off"]);
}
}
@@ -243,7 +268,8 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus) {
operation.nInputs = op["srcs"].size();
operation.inputBufferType = convertToBufferType(op["srcs"][0]["buff"]);
for (int i = 0; i < operation.nInputs; i++) {
operation.inputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["srcs"][i]["off"]);
operation.inputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["srcs"][i]["off"]) +
(operation.inputBufferType != BufferType::SCRATCH ? contsSrcOffset : 0);
chunkIndexes.push_back((uint32_t)op["srcs"][i]["off"]);
}
}
@@ -254,7 +280,8 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus) {
BufferType dstBufferType = convertToBufferType(op["o_buff"]["dst"]);
operation.outputChannelIndexes[i] =
channelIndexes[{srcBufferType, dstBufferType, operation.channelType}][op["o_cids"][i]["id"]];
operation.outputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["o_cids"][i]["off"]);
operation.outputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["o_cids"][i]["off"]) +
(dstBufferType != BufferType::SCRATCH ? constDstOffset : 0);
chunkIndexes.push_back((uint32_t)op["o_cids"][i]["off"]);
}
}
@@ -263,7 +290,8 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus) {
operation.nOutputs = op["dsts"].size();
operation.outputBufferType = convertToBufferType(op["dsts"][0]["buff"]);
for (int i = 0; i < operation.nOutputs; i++) {
operation.outputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["dsts"][i]["off"]);
operation.outputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["dsts"][i]["off"]) +
(operation.outputBufferType != BufferType::SCRATCH ? constDstOffset : 0);
chunkIndexes.push_back((uint32_t)op["dsts"][i]["off"]);
}
}
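
All four hunks above apply the same rule: a plan-relative offset is shifted by the caller-supplied constant offset only when it addresses a user buffer, since scratch is executor-owned and its offsets stay plan-relative. Factored out as a sketch (BufferType here is a simplified stand-in for the library's enum):

#include <cstddef>

enum class BufferType { INPUT, OUTPUT, SCRATCH };

// The expression repeated in setupOperations, written once.
std::size_t adjustOffset(std::size_t planOffset, BufferType type, std::size_t constOffset) {
  return planOffset + (type == BufferType::SCRATCH ? 0 : constOffset);
}
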
@@ -340,6 +368,8 @@ void ExecutionPlan::Impl::reset() {
this->chunkGroups.clear();
}

void ExecutionPlan::Impl::operationsReset() { this->operations.clear(); }

ExecutionPlan::ExecutionPlan(const std::string& name, const std::string& planPath)
: impl_(std::make_shared<Impl>(name, planPath)) {}
