Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NCCL API Executor Integration #331

Merged
merged 15 commits into from
Jul 25, 2024
Merged
245 changes: 179 additions & 66 deletions apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include <iostream>
#include <algorithm>
#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/executor.hpp>
#include <mscclpp/sm_channel.hpp>
#include <mscclpp/sm_channel_device.hpp>
#include <unordered_map>
#include <vector>
#include <sstream>

#include "allgather.hpp"
#include "allreduce.hpp"
Expand Down Expand Up @@ -54,13 +57,17 @@ struct ncclComm {
std::shared_ptr<mscclpp::Communicator> comm;
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> smSemaphores;
std::shared_ptr<mscclpp::Executor> executor;
std::shared_ptr<mscclpp::ExecutionPlan> allReducePacketIPPlan, allReducePacketOPPlan, allReduceIPPlan,
allReduceOPPlan;

std::unordered_map<channelKey, ChannelInfo> channelInInfos;
std::unordered_map<channelKey, ChannelInfo> channelOutInfos;
std::unordered_map<channelKey, ChannelInfo> channelScratchInfos;
std::shared_ptr<char> scratchBuff;
std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;

size_t smallMessageSizeBoundary, largeMessageSizeBoundary;
uint32_t numScratchBuff;
uint32_t buffFlag;
};
Expand Down Expand Up @@ -97,6 +104,43 @@ static size_t ncclTypeSize(ncclDataType_t type) {
return 0;
}

/// Parses a human-readable byte size such as "512", "16K", "1.5m" or "2 G".
///
/// The text must start with a non-negative decimal number. It may be followed,
/// after optional whitespace, by one unit letter: K/k (x1024), M/m (x1024^2) or
/// G/g (x1024^3). Only the first non-blank character after the number is
/// inspected; anything following it is ignored.
///
/// @param value NUL-terminated text to parse; nullptr is treated as invalid.
/// @return The size in bytes, or -1.0 when `value` is null, does not start with
///         a number, is negative (or NaN), or carries an unknown unit letter.
double parseSize(const char* value) {
  if (value == nullptr) return -1.0;

  const std::string text(value);
  std::istringstream iss(text);

  double size = 0.0;
  if (!(iss >> size)) return -1.0;

  // Written as a negated ">=" so that NaN is rejected together with negatives.
  // Callers store the result in unsigned sizes, so a negative value must not leak out.
  if (!(size >= 0.0)) return -1.0;

  // Formatted extraction skips leading whitespace by itself; when nothing is
  // left in the stream the extraction fails and `unit` keeps its 0 value.
  char unit = 0;
  iss >> unit;

  double multiplier = 1.0;
  switch (unit) {
    case 0:  // no unit suffix: plain bytes
      break;
    case 'G':
    case 'g':
      multiplier = 1024.0 * 1024.0 * 1024.0;
      break;
    case 'M':
    case 'm':
      multiplier = 1024.0 * 1024.0;
      break;
    case 'K':
    case 'k':
      multiplier = 1024.0;
      break;
    default:
      return -1.0;
  }
  return size * multiplier;
}

static mscclpp::Transport getTransport(int, int) {
// if (rank / nRanksPerNode == peerRank / nRanksPerNode) {
// return mscclpp::Transport::CudaIpc;
Expand Down Expand Up @@ -151,6 +195,86 @@ static std::shared_ptr<mscclpp::DeviceHandle<mscclpp::SmChannel>> setupSmChannel
return ptr;
}

// Channel-based all-reduce path.
//
// Validates the arguments, finds (or builds and caches on `comm`) the SM channels keyed by the allocation
// that contains `sendbuff` / `recvbuff`, then calls the `allreduce` overload matching `datatype`.
// The reduction-operation parameter is unnamed and never read.
//
// Returns ncclInvalidArgument when a buffer or the communicator is null, when `count` is 0, when
// ncclTypeSize(datatype) is 0, or when `datatype` is not Float16/Float32/Int32/Uint32; ncclSuccess otherwise.
// Driver-API and launch errors are routed through the MSCCLPP_CUTHROW / CUDACHECK macros.
static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
                                          ncclRedOp_t, ncclComm_t comm, cudaStream_t stream) {
  // Reject unusable arguments before touching the communicator.
  const bool bufferMissing = (sendbuff == nullptr) || (recvbuff == nullptr);
  if (bufferMissing || count == 0 || ncclTypeSize(datatype) == 0 || comm == nullptr) return ncclInvalidArgument;

  // Base address and extent of the allocation that contains each user buffer.
  CUdeviceptr inBase, outBase;
  size_t inExtent, outExtent;
  MSCCLPP_CUTHROW(cuMemGetAddressRange(&inBase, &inExtent, (CUdeviceptr)sendbuff));
  MSCCLPP_CUTHROW(cuMemGetAddressRange(&outBase, &outExtent, (CUdeviceptr)recvbuff));
  const size_t offsetIn = (char*)sendbuff - (char*)inBase;
  const size_t offsetOut = (char*)recvbuff - (char*)outBase;

  // Every call advances the flag and uses the next slice of the scratch buffer.
  comm->buffFlag += 1;
  const uint32_t scratchSlot = comm->buffFlag % comm->numScratchBuff;
  const size_t offsetScratch = (SCRATCH_SIZE / comm->numScratchBuff) * scratchSlot;

  const int rank = comm->comm->bootstrap()->getRank();
  const channelKey inKey{(void*)inBase, inExtent};
  const channelKey outKey{(void*)outBase, outExtent};
  mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;
  mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels = nullptr;

  const bool fitsScratchPath = count * ncclTypeSize(datatype) <= (1 << 20);
  if (fitsScratchPath) {
    // Messages of at most 1 MiB: only input -> remote-scratch channels are needed; smOutChannels stays null.
    auto cached = comm->channelScratchInfos.find(inKey);
    if (cached == comm->channelScratchInfos.end()) {
      std::vector<mscclpp::SmChannel> fresh = setupSmChannels(comm, comm->remoteScratchRegMemories, (void*)inBase);
      ChannelInfo info{fresh, setupSmChannelDeviceHandles(fresh)};
      cached = comm->channelScratchInfos.emplace(inKey, info).first;
    }
    smChannels = cached->second.smChannelDeviceHandles.get();
  } else {
    // Larger messages: one channel set for the input buffer and one for the output buffer.
    std::vector<mscclpp::RegisteredMemory> remoteMemories;

    auto cachedIn = comm->channelInInfos.find(inKey);
    if (cachedIn == comm->channelInInfos.end()) {
      std::vector<mscclpp::SmChannel> fresh = setupSmChannels(comm, comm->remoteScratchRegMemories, (void*)inBase);
      ChannelInfo info{fresh, setupSmChannelDeviceHandles(fresh)};
      cachedIn = comm->channelInInfos.emplace(inKey, info).first;
    }

    auto cachedOut = comm->channelOutInfos.find(outKey);
    if (cachedOut == comm->channelOutInfos.end()) {
      // The output allocation has to be registered with the peers before channels can target it.
      remoteMemories = setupRemoteMemories(comm->comm, rank, (void*)outBase, outExtent, mscclpp::Transport::CudaIpc);
      std::vector<mscclpp::SmChannel> fresh = setupSmChannels(comm, remoteMemories, (void*)outBase);
      ChannelInfo info{fresh, setupSmChannelDeviceHandles(fresh)};
      cachedOut = comm->channelOutInfos.emplace(outKey, info).first;
    }

    smChannels = cachedIn->second.smChannelDeviceHandles.get();
    smOutChannels = cachedOut->second.smChannelDeviceHandles.get();
  }

  // Launch the kernel wrapper for the element type; 32-bit signed and unsigned ints share the int overload.
  char* const scratch = comm->scratchBuff.get();
  const auto worldSize = comm->comm->bootstrap()->getNranks();
  switch (datatype) {
    case ncclFloat16:
      CUDACHECK(allreduce((half*)sendbuff, (half*)scratch, (half*)recvbuff, smChannels, smOutChannels, offsetIn,
                          offsetOut, offsetScratch, rank, NRANKS_PER_NODE, worldSize, count, stream));
      break;
    case ncclFloat32:
      CUDACHECK(allreduce((float*)sendbuff, (float*)scratch, (float*)recvbuff, smChannels, smOutChannels, offsetIn,
                          offsetOut, offsetScratch, rank, NRANKS_PER_NODE, worldSize, count, stream));
      break;
    case ncclInt32:
    case ncclUint32:
      CUDACHECK(allreduce((int*)sendbuff, (int*)scratch, (int*)recvbuff, smChannels, smOutChannels, offsetIn,
                          offsetOut, offsetScratch, rank, NRANKS_PER_NODE, worldSize, count, stream));
      break;
    default:
      return ncclInvalidArgument;
  }
  return ncclSuccess;
}

NCCL_API ncclResult_t ncclGetVersion(int* version) {
if (version == nullptr) return ncclInvalidArgument;
*version = MSCCLPP_VERSION;
Expand Down Expand Up @@ -211,6 +335,30 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
commPtr->scratchBuff = mscclpp::allocExtSharedCuda<char>(SCRATCH_SIZE);
commPtr->remoteScratchRegMemories =
setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc);
commPtr->executor = std::make_shared<mscclpp::Executor>(mscclppComm);

if (getenv("ALLREDUCEPKT_IP_JSON_FILE"))
commPtr->allReducePacketIPPlan = std::make_shared<mscclpp::ExecutionPlan>(
mscclpp::ExecutionPlan("allreduce_packet", getenv("ALLREDUCEPKT_IP_JSON_FILE")));
if (getenv("ALLREDUCEPKT_OP_JSON_FILE"))
commPtr->allReducePacketOPPlan = std::make_shared<mscclpp::ExecutionPlan>(
mscclpp::ExecutionPlan("allreduce_packet", getenv("ALLREDUCEPKT_OP_JSON_FILE")));
if (getenv("ALLREDUCE_IP_JSON_FILE"))
commPtr->allReduceIPPlan =
std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allreduce", getenv("ALLREDUCE_IP_JSON_FILE")));
if (getenv("ALLREDUCE_OP_JSON_FILE"))
commPtr->allReduceOPPlan =
std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allreduce", getenv("ALLREDUCE_OP_JSON_FILE")));
if (getenv("ALLREDUCE_SMALL_MSG_BOUNDARY"))
commPtr->smallMessageSizeBoundary = parseSize(getenv("ALLREDUCE_SMALL_MSG_BOUNDARY"));
else
commPtr->smallMessageSizeBoundary = 16 * (1 << 10);
if (getenv("ALLREDUCE_LARGE_MSG_BOUNDARY"))
commPtr->largeMessageSizeBoundary = parseSize(getenv("ALLREDUCE_LARGE_MSG_BOUNDARY"));
else
commPtr->largeMessageSizeBoundary = 1 << 20;

if (commPtr->smallMessageSizeBoundary > commPtr->largeMessageSizeBoundary) return ncclInvalidArgument;

*comm = commPtr;
return ncclSuccess;
Expand Down Expand Up @@ -321,82 +469,46 @@ NCCL_API ncclResult_t ncclBroadcast(const void*, void*, size_t, ncclDataType_t,
}

// All-reduce entry point of the NCCL-compatible API.
//
// Dispatch is based on the payload size in bytes (count * element size):
//   * smaller than comm->smallMessageSizeBoundary -> ncclAllReduceFallback;
//   * up to and including comm->largeMessageSizeBoundary -> the "packet" execution plan;
//   * anything larger -> the regular execution plan.
// In-place calls (sendbuff == recvbuff) use the *IPPlan members, all other calls the *OPPlan members.
// When the selected plan is null the call is also served by ncclAllReduceFallback.
//
// Returns ncclInvalidArgument when a buffer or the communicator is null, when `count` is 0, when
// ncclTypeSize(datatype) is 0, or (on the executor path) when `datatype` is not
// Float16/Float32/Int32/Uint32. Otherwise returns ncclSuccess or whatever the fallback returns.
NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
                                    ncclRedOp_t reductionOperation, ncclComm_t comm, cudaStream_t stream) {
  // Check that the parameters are valid.
  if (sendbuff == nullptr || recvbuff == nullptr || count == 0 || ncclTypeSize(datatype) == 0 || comm == nullptr)
    return ncclInvalidArgument;

  const size_t bytes = count * ncclTypeSize(datatype);
  const int rank = comm->comm->bootstrap()->getRank();

  // Small messages never go through the executor.
  if (bytes < comm->smallMessageSizeBoundary)
    return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream);

  // Pick the plan by size class and by whether the operation is in place.
  const bool inPlace = (sendbuff == recvbuff);
  std::shared_ptr<mscclpp::ExecutionPlan> plan;
  if (bytes <= comm->largeMessageSizeBoundary)
    plan = inPlace ? comm->allReducePacketIPPlan : comm->allReducePacketOPPlan;
  else
    plan = inPlace ? comm->allReduceIPPlan : comm->allReduceOPPlan;

  // A null plan means no plan is available for this case; use the channel-based path instead.
  if (plan == nullptr)
    return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream);

  // NOTE(review): 1024 is passed in the argument slot after the data type — presumably the thread count
  // per block; confirm against Executor::execute.
  switch (datatype) {
    case ncclFloat16:
      comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT16, 1024,
                              *plan, stream, mscclpp::PacketType::LL8);
      break;
    case ncclFloat32:
      comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT32,
                              1024, *plan, stream, mscclpp::PacketType::LL8);
      break;
    case ncclInt32:
    case ncclUint32:
      // NOTE(review): signed and unsigned 32-bit integers are both submitted as UINT32 — confirm this is
      // intended for every reduction the plans implement.
      comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes, mscclpp::DataType::UINT32, 1024,
                              *plan, stream, mscclpp::PacketType::LL8);
      break;
    default:
      return ncclInvalidArgument;
  }

  return ncclSuccess;
}

Expand Down Expand Up @@ -442,6 +554,7 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
CUDACHECK(allgather<true>((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank,
NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream));
}

return ncclSuccess;
}

Expand All @@ -468,4 +581,4 @@ NCCL_API ncclResult_t ncclGroupStart() {
// Group-end entry point of the NCCL-compatible API.
// This implementation performs no work and always returns ncclSuccess.
NCCL_API ncclResult_t ncclGroupEnd() {
// Intentionally a no-op: nothing is executed or flushed here.
return ncclSuccess;
}
}
5 changes: 4 additions & 1 deletion python/test/test_mscclpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,10 @@ def test_executor(mpi_group: MpiGroup, filename: str):
cp.random.seed(42)
buffer = cp.random.random(nelems).astype(cp.float16)
sub_arrays = cp.split(buffer, mpi_group.comm.size)
sendbuf = sub_arrays[mpi_group.comm.rank]
nelems_per_rank = int(nelems / mpi_group.comm.size)
sendbuf = cp.empty(nelems_per_rank).astype(cp.float16)
for i in range(nelems_per_rank):
sendbuf[i] = sub_arrays[mpi_group.comm.rank][i]
expected = cp.zeros_like(sendbuf)
for i in range(mpi_group.comm.size):
expected += sub_arrays[i]
Expand Down
Loading
Loading