diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
index 01bc56e0b..93654184d 100644
--- a/apps/nccl/src/nccl.cu
+++ b/apps/nccl/src/nccl.cu
@@ -4,8 +4,10 @@
 #include <cuda_fp16.h>
 #include <mscclpp/concurrency_device.hpp>
 #include <mscclpp/core.hpp>
+#include <mscclpp/executor.hpp>
 #include <mscclpp/sm_channel.hpp>
 #include <mscclpp/sm_channel_device.hpp>
+#include <sstream>
 #include <unordered_map>
 #include <vector>
 
@@ -54,6 +56,9 @@ struct ncclComm {
   std::shared_ptr<mscclpp::Communicator> comm;
   std::vector<std::shared_ptr<mscclpp::Connection>> connections;
   std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> smSemaphores;
+  std::shared_ptr<mscclpp::Executor> executor;
+  std::shared_ptr<mscclpp::ExecutionPlan> allReducePacketIPPlan, allReducePacketOPPlan, allReduceIPPlan,
+      allReduceOPPlan;
 
   std::unordered_map<channelKey, ChannelInfo> channelInInfos;
   std::unordered_map<channelKey, ChannelInfo> channelOutInfos;
@@ -61,6 +66,7 @@ struct ncclComm {
   std::shared_ptr<char> scratchBuff;
   std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;
 
+  size_t smallMessageSizeBoundary, largeMessageSizeBoundary;
   uint32_t numScratchBuff;
   uint32_t buffFlag;
 };
@@ -97,6 +103,43 @@ static size_t ncclTypeSize(ncclDataType_t type) {
   return 0;
 }
 
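+// Parses a human-readable size string into a byte count: an optional K/M/G
+// suffix (either case) scales the value, e.g. "16K" -> 16384 and "1m" ->
+// 1048576; a bare number parses as bytes, and -1.0 signals a parse error.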
+double parseSize(const char* value) {
+  std::string valueStr(value);
+  std::istringstream iss(valueStr);
+  long long int units;
+  double size;
+  char size_lit = 0;
+
+  if (iss >> size) {
+    iss >> std::ws;  // eat whitespace
+    iss >> size_lit;
+  } else {
+    return -1.0;
+  }
+
+  if (size_lit != 0 && !std::isspace(size_lit)) {
+    switch (size_lit) {
+      case 'G':
+      case 'g':
+        units = 1024 * 1024 * 1024;
+        break;
+      case 'M':
+      case 'm':
+        units = 1024 * 1024;
+        break;
+      case 'K':
+      case 'k':
+        units = 1024;
+        break;
+      default:
+        return -1.0;
+    };
+  } else {
+    units = 1;
+  }
+  return size * units;
+}
+
 static mscclpp::Transport getTransport(int, int) {
   // if (rank / nRanksPerNode == peerRank / nRanksPerNode) {
   //   return mscclpp::Transport::CudaIpc;
@@ -151,6 +194,86 @@ static std::shared_ptr<mscclpp::DeviceHandle<mscclpp::SmChannel>> setupSmChannelDeviceHandles(
   return ptr;
 }
 
+static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
+                                          ncclRedOp_t, ncclComm_t comm, cudaStream_t stream) {
+  // Checking if the parameters are valid
+  if (sendbuff == nullptr || recvbuff == nullptr || count == 0 || ncclTypeSize(datatype) == 0 || comm == nullptr)
+    return ncclInvalidArgument;
+
+  // Declaring variables
+  size_t sendBytes, recvBytes;
+  CUdeviceptr sendBasePtr, recvBasePtr;
+  MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendBytes, (CUdeviceptr)sendbuff));
+  MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff));
+  size_t offsetIn = (char*)sendbuff - (char*)sendBasePtr;
+  size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
+  uint32_t scratchBuffIdx = (++(comm->buffFlag)) % comm->numScratchBuff;
+  size_t offsetScratch = (SCRATCH_SIZE / comm->numScratchBuff) * scratchBuffIdx;
+  int rank = comm->comm->bootstrap()->getRank();
+  channelKey sendKey{(void*)sendBasePtr, sendBytes};
+  channelKey recvKey{(void*)recvBasePtr, recvBytes};
+  mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;
+  mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels = nullptr;
+
+  // Creating the channels
+  if (count * ncclTypeSize(datatype) <= comm->largeMessageSizeBoundary) {
+    auto sendIt = comm->channelScratchInfos.find(sendKey);
+    if (sendIt == comm->channelScratchInfos.end()) {
+      std::vector<mscclpp::SmChannel> channels =
+          setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)sendBasePtr));
+      ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
+      sendIt = comm->channelScratchInfos.emplace(sendKey, channelInfo).first;
+    }
+
+    smChannels = sendIt->second.smChannelDeviceHandles.get();
+  } else {
+    std::vector<mscclpp::RegisteredMemory> remoteMemories;
+
+    auto sendIt = comm->channelInInfos.find(sendKey);
+    if (sendIt == comm->channelInInfos.end()) {
+      std::vector<mscclpp::SmChannel> channels =
+          setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)sendBasePtr));
+      ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
+      sendIt = comm->channelInInfos.emplace(sendKey, channelInfo).first;
+    }
+
+    auto recvIt = comm->channelOutInfos.find(recvKey);
+    if (recvIt == comm->channelOutInfos.end()) {
+      remoteMemories =
+          setupRemoteMemories(comm->comm, rank, (void*)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc);
+      std::vector<mscclpp::SmChannel> outChannels =
+          setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
+      ChannelInfo channelInfo{outChannels, setupSmChannelDeviceHandles(outChannels)};
+      recvIt = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
+    }
+
+    smChannels = sendIt->second.smChannelDeviceHandles.get();
+    smOutChannels = recvIt->second.smChannelDeviceHandles.get();
+  }
+
+  switch (datatype) {
+    case ncclFloat16:
+      CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smOutChannels,
+                          offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
+                          comm->comm->bootstrap()->getNranks(), count, stream));
+      break;
+    case ncclFloat32:
+      CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, smChannels,
+                          smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
+                          NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
+      break;
+    case ncclInt32:
+    case ncclUint32:
+      CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smOutChannels,
+                          offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE,
+                          comm->comm->bootstrap()->getNranks(), count, stream));
+      break;
+    default:
+      return ncclInvalidArgument;
+  }
+  return ncclSuccess;
+}
+
 NCCL_API ncclResult_t ncclGetVersion(int* version) {
   if (version == nullptr) return ncclInvalidArgument;
   *version = MSCCLPP_VERSION;
@@ -211,6 +334,30 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank) {
   commPtr->scratchBuff = mscclpp::allocExtSharedCuda<char>(SCRATCH_SIZE);
   commPtr->remoteScratchRegMemories =
       setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc);
+  commPtr->executor = std::make_shared<mscclpp::Executor>(mscclppComm);
+
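+  // Execution plans are optional: each env var names a plan JSON, where
+  // "allreduce_packet" plans serve mid-sized messages and "allreduce" plans
+  // serve large ones; IP/OP marks the in-place/out-of-place variant. Message
+  // sizes with no loaded plan fall back to the built-in kernels.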
+  if (getenv("ALLREDUCEPKT_IP_JSON_FILE"))
+    commPtr->allReducePacketIPPlan = std::make_shared<mscclpp::ExecutionPlan>(
+        mscclpp::ExecutionPlan("allreduce_packet", getenv("ALLREDUCEPKT_IP_JSON_FILE")));
+  if (getenv("ALLREDUCEPKT_OP_JSON_FILE"))
+    commPtr->allReducePacketOPPlan = std::make_shared<mscclpp::ExecutionPlan>(
+        mscclpp::ExecutionPlan("allreduce_packet", getenv("ALLREDUCEPKT_OP_JSON_FILE")));
+  if (getenv("ALLREDUCE_IP_JSON_FILE"))
+    commPtr->allReduceIPPlan =
+        std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allreduce", getenv("ALLREDUCE_IP_JSON_FILE")));
+  if (getenv("ALLREDUCE_OP_JSON_FILE"))
+    commPtr->allReduceOPPlan =
+        std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allreduce", getenv("ALLREDUCE_OP_JSON_FILE")));
+  if (getenv("ALLREDUCE_SMALL_MSG_BOUNDARY"))
+    commPtr->smallMessageSizeBoundary = parseSize(getenv("ALLREDUCE_SMALL_MSG_BOUNDARY"));
+  else
+    commPtr->smallMessageSizeBoundary = 16 * (1 << 10);
+  if (getenv("ALLREDUCE_LARGE_MSG_BOUNDARY"))
+    commPtr->largeMessageSizeBoundary = parseSize(getenv("ALLREDUCE_LARGE_MSG_BOUNDARY"));
+  else
+    commPtr->largeMessageSizeBoundary = 1 << 20;
+
+  if (commPtr->smallMessageSizeBoundary > commPtr->largeMessageSizeBoundary) return ncclInvalidArgument;
 
   *comm = commPtr;
   return ncclSuccess;
@@ -321,82 +468,46 @@ NCCL_API ncclResult_t ncclBroadcast(const void*, void*, size_t, ncclDataType_t,
 }
 
 NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
-                                    ncclRedOp_t, ncclComm_t comm, cudaStream_t stream) {
+                                    ncclRedOp_t reductionOperation, ncclComm_t comm, cudaStream_t stream) {
   // Checking if the parameters are valid
   if (sendbuff == nullptr || recvbuff == nullptr || count == 0 || ncclTypeSize(datatype) == 0 || comm == nullptr)
     return ncclInvalidArgument;
 
   // Declaring variables
-  size_t sendBytes, recvBytes;
-  CUdeviceptr sendBasePtr, recvBasePtr;
-  MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendBytes, (CUdeviceptr)sendbuff));
-  MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff));
-  size_t offsetIn = (char*)sendbuff - (char*)sendBasePtr;
-  size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
-  uint32_t scratchBuffIdx = (++(comm->buffFlag)) % comm->numScratchBuff;
-  size_t offsetScratch = (SCRATCH_SIZE / comm->numScratchBuff) * scratchBuffIdx;
+  size_t bytes = count * ncclTypeSize(datatype);
   int rank = comm->comm->bootstrap()->getRank();
-  channelKey sendKey{(void*)sendBasePtr, sendBytes};
-  channelKey recvKey{(void*)recvBasePtr, recvBytes};
-  mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;
-  mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels = nullptr;
-
-  // Creating the channels
-  if (count * ncclTypeSize(datatype) <= (1 << 20)) {
-    auto sendIt = comm->channelScratchInfos.find(sendKey);
-    if (sendIt == comm->channelScratchInfos.end()) {
-      std::vector<mscclpp::SmChannel> channels =
-          setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)sendBasePtr));
-      ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
-      sendIt = comm->channelScratchInfos.emplace(sendKey, channelInfo).first;
-    }
 
-    smChannels = sendIt->second.smChannelDeviceHandles.get();
+  if (bytes < comm->smallMessageSizeBoundary) {
+    return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream);
   } else {
-    std::vector<mscclpp::RegisteredMemory> remoteMemories;
-
-    auto sendIt = comm->channelInInfos.find(sendKey);
-    if (sendIt == comm->channelInInfos.end()) {
-      std::vector<mscclpp::SmChannel> channels =
-          setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)sendBasePtr));
-      ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
-      sendIt = comm->channelInInfos.emplace(sendKey, channelInfo).first;
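+    // Pick a plan by size: packet plans up to largeMessageSizeBoundary,
+    // regular plans above it; (sendbuff == recvbuff) selects the in-place
+    // (IP) variant. If no matching plan was loaded, use the fallback path.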
+    std::shared_ptr<mscclpp::ExecutionPlan> plan;
+    if (bytes <= comm->largeMessageSizeBoundary)
+      plan = (sendbuff == recvbuff) ? comm->allReducePacketIPPlan : comm->allReducePacketOPPlan;
+    else
+      plan = (sendbuff == recvbuff) ? comm->allReduceIPPlan : comm->allReduceOPPlan;
+
+    if (plan == nullptr)
+      return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream);
+
+    switch (datatype) {
+      case ncclFloat16:
+        comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT16, 1024,
+                                *plan, stream, mscclpp::PacketType::LL8);
+        break;
+      case ncclFloat32:
+        comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT32,
+                                1024, *plan, stream, mscclpp::PacketType::LL8);
+        break;
+      case ncclInt32:
+      case ncclUint32:
+        comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes, mscclpp::DataType::UINT32, 1024,
+                                *plan, stream, mscclpp::PacketType::LL8);
+        break;
+      default:
+        return ncclInvalidArgument;
     }
-
-    auto recvIt = comm->channelOutInfos.find(recvKey);
-    if (recvIt == comm->channelOutInfos.end()) {
-      remoteMemories =
-          setupRemoteMemories(comm->comm, rank, (void*)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc);
-      std::vector<mscclpp::SmChannel> outChannels =
-          setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
-      ChannelInfo channelInfo{outChannels, setupSmChannelDeviceHandles(outChannels)};
-      recvIt = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
-    }
-
-    smChannels = sendIt->second.smChannelDeviceHandles.get();
-    smOutChannels = recvIt->second.smChannelDeviceHandles.get();
   }
 
-  switch (datatype) {
-    case ncclFloat16:
-      CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smOutChannels,
-                          offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
-                          comm->comm->bootstrap()->getNranks(), count, stream));
-      break;
-    case ncclFloat32:
-      CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, smChannels,
-                          smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
-                          NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
-      break;
-    case ncclInt32:
-    case ncclUint32:
-      CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smOutChannels,
-                          offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE,
-                          comm->comm->bootstrap()->getNranks(), count, stream));
-      break;
-    default:
-      return ncclInvalidArgument;
-  }
   return ncclSuccess;
 }
 
@@ -442,6 +553,7 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
     CUDACHECK(allgather((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank, NRANKS_PER_NODE,
                         nRank, bytes / sizeof(int), stream));
   }
+  return ncclSuccess;
 }
 
diff --git a/python/test/test_mscclpp.py b/python/test/test_mscclpp.py
index 4af3ddb36..9535c869f 100644
--- a/python/test/test_mscclpp.py
+++ b/python/test/test_mscclpp.py
@@ -613,7 +613,10 @@ def test_executor(mpi_group: MpiGroup, filename: str):
     cp.random.seed(42)
     buffer = cp.random.random(nelems).astype(cp.float16)
     sub_arrays = cp.split(buffer, mpi_group.comm.size)
-    sendbuf = sub_arrays[mpi_group.comm.rank]
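+    # Copy the shard into its own allocation: a cp.split view shares the full
+    # buffer's base address, so the executor would register the whole buffer
+    # rather than just this rank's shard.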
+    nelems_per_rank = int(nelems / mpi_group.comm.size)
+    sendbuf = cp.empty(nelems_per_rank).astype(cp.float16)
+    for i in range(nelems_per_rank):
+        sendbuf[i] = sub_arrays[mpi_group.comm.rank][i]
     expected = cp.zeros_like(sendbuf)
     for i in range(mpi_group.comm.size):
         expected += sub_arrays[i]
diff --git a/src/executor/execution_plan.cc b/src/executor/execution_plan.cc
index e1b84a16c..6655a72df 100644
--- a/src/executor/execution_plan.cc
+++ b/src/executor/execution_plan.cc
@@ -123,7 +123,7 @@ std::vector<Operation> ExecutionPlan::Impl::getOperations(int rank, int threadblock) const {
 
 int ExecutionPlan::Impl::getThreadblockCount(int rank) const { return this->operations.at(rank).size(); }
 
-void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize) {
+void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset) {
   std::ifstream file(this->planPath);
   json obj = json::parse(file);
   if (this->name != obj["name"]) {
@@ -145,7 +145,31 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize) {
   this->setupChannels(gpus);
 
   this->inputSize = inputSize;
-  this->setupOperations(gpus);
+  this->setupOperations(gpus, contsSrcOffset, constDstOffset);
+}
+
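+// Lightweight re-parse for an already-cached execution context: reloads the
+// chunk metadata and rebuilds operations for a new message size and offsets,
+// skipping the channel setup that a full loadExecutionPlan performs.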
+void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset) {
+  std::ifstream file(this->planPath);
+  json obj = json::parse(file);
+  if (this->name != obj["name"]) {
+    throw Error("Plan name does not match", ErrorCode::ExecutorError);
+  }
+  std::string protocol = obj["protocol"];
+  if (protocol == "LL") {
+    this->isUsingPacket = true;
+  }
+  const auto& gpus = obj["gpus"];
+
+  for (const auto& gpu : gpus) {
+    int rank = gpu["id"];
+    this->inputChunks[rank] = gpu["inputChunks"];
+    this->outputChunks[rank] = gpu["outputChunks"];
+    this->scratchChunks[rank] = gpu["scratchChunks"];
+    this->chunkGroups[rank] = gpu["chunkGroups"];
+  }
+
+  this->inputSize = inputSize;
+  this->setupOperations(gpus, contsSrcOffset, constDstOffset);
 }
 
 // Construct the channel info. Step 1. Flatten SM and PROXY channels into separate vectors.
@@ -201,7 +225,7 @@ void ExecutionPlan::Impl::setupChannels(const json& gpus) {
   }
 }
 
-void ExecutionPlan::Impl::setupOperations(const json& gpus) {
+void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffset, size_t constDstOffset) {
   // setup threadblocks and operations
   for (const auto& gpu : gpus) {
     int rank = gpu["id"];
@@ -234,7 +258,8 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus) {
           // Get the relevant channel index in rank channelInfos
           operation.inputChannelIndexes[i] =
               channelIndexes[{srcBufferType, dstBufferType, operation.channelType}][op["i_cids"][i]["id"]];
-          operation.inputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["i_cids"][i]["off"]);
+          operation.inputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["i_cids"][i]["off"]) +
+                                      (srcBufferType != BufferType::SCRATCH ? contsSrcOffset : 0);
           chunkIndexes.push_back((uint32_t)op["i_cids"][i]["off"]);
         }
       }
@@ -243,7 +268,8 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus) {
         operation.nInputs = op["srcs"].size();
         operation.inputBufferType = convertToBufferType(op["srcs"][0]["buff"]);
         for (int i = 0; i < operation.nInputs; i++) {
-          operation.inputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["srcs"][i]["off"]);
+          operation.inputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["srcs"][i]["off"]) +
+                                      (operation.inputBufferType != BufferType::SCRATCH ? contsSrcOffset : 0);
           chunkIndexes.push_back((uint32_t)op["srcs"][i]["off"]);
         }
       }
@@ -254,7 +280,8 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus) {
           BufferType dstBufferType = convertToBufferType(op["o_buff"]["dst"]);
           operation.outputChannelIndexes[i] =
               channelIndexes[{srcBufferType, dstBufferType, operation.channelType}][op["o_cids"][i]["id"]];
-          operation.outputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["o_cids"][i]["off"]);
+          operation.outputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["o_cids"][i]["off"]) +
+                                       (dstBufferType != BufferType::SCRATCH ? constDstOffset : 0);
           chunkIndexes.push_back((uint32_t)op["o_cids"][i]["off"]);
         }
       }
@@ -263,7 +290,8 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus) {
         operation.nOutputs = op["dsts"].size();
         operation.outputBufferType = convertToBufferType(op["dsts"][0]["buff"]);
         for (int i = 0; i < operation.nOutputs; i++) {
-          operation.outputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["dsts"][i]["off"]);
+          operation.outputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["dsts"][i]["off"]) +
+                                       (operation.outputBufferType != BufferType::SCRATCH ? constDstOffset : 0);
           chunkIndexes.push_back((uint32_t)op["dsts"][i]["off"]);
         }
       }
@@ -340,6 +368,8 @@ void ExecutionPlan::Impl::reset() {
   this->chunkGroups.clear();
 }
 
+void ExecutionPlan::Impl::operationsReset() { this->operations.clear(); }
+
 ExecutionPlan::ExecutionPlan(const std::string& name, const std::string& planPath)
     : impl_(std::make_shared<Impl>(name, planPath)) {}
diff --git a/src/executor/executor.cc b/src/executor/executor.cc
index 62d749d00..8402ad099 100644
--- a/src/executor/executor.cc
+++ b/src/executor/executor.cc
@@ -78,14 +78,24 @@ struct Executor::Impl {
   }
   ~Impl() = default;
 
-  ExecutionContext setupExecutionContext(int rank, void* sendbuff, void* recvbuff, size_t sendBufferSize,
+  ExecutionContext setupExecutionContext(int rank, void* sendbuff, void* recvbuff, size_t messageSize,
+                                         size_t contsSrcOffset, size_t constDstOffset, size_t sendBufferSize,
                                          size_t recvBufferSize, const ExecutionPlan& plan) {
     ExecutionContextKey key = {sendbuff, recvbuff, sendBufferSize, recvBufferSize, plan.impl_->name};
     if (this->contexts.find(key) != this->contexts.end()) {
+      plan.impl_->operationsReset();
+      plan.impl_->lightLoadExecutionPlan(messageSize, contsSrcOffset, constDstOffset);
+      this->setupDeviceExecutionPlan(this->contexts[key], rank, plan);
+      this->contexts[key].deviceExecutionPlansBuffer =
+          allocExtSharedCuda<char>(this->contexts[key].deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan));
+      memcpyCuda(this->contexts[key].deviceExecutionPlansBuffer.get(),
+                 (char*)this->contexts[key].deviceExecutionPlans.data(),
+                 this->contexts[key].deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan), cudaMemcpyHostToDevice);
       return this->contexts[key];
     }
+
     plan.impl_->reset();
-    plan.impl_->loadExecutionPlan(sendBufferSize);
+    plan.impl_->loadExecutionPlan(messageSize, contsSrcOffset, constDstOffset);
 
     ExecutionContext context;
     size_t scratchBufferSize = plan.impl_->getScratchBufferSize(rank, sendBufferSize);
@@ -172,6 +182,16 @@ struct Executor::Impl {
       comm->setup();
       for (size_t i = 0; i < remoteRegMemoryFutures.size(); i++) {
         context.registeredMemories[{bufferType, connectedPeers[i]}] = std::move(remoteRegMemoryFutures[i].get());
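+        // Each rank's buffer must sit at the same offset within its
+        // registered allocation; the executor applies plan offsets uniformly,
+        // so divergent offsets across peers would corrupt remote addressing.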
+        CUdeviceptr myRegBaseAdr, peerRegBaseAdr;
+        size_t temp;
+        MSCCLPP_CUTHROW(cuMemGetAddressRange(&myRegBaseAdr, &temp, (CUdeviceptr)(char*)memory.data()));
+        MSCCLPP_CUTHROW(cuMemGetAddressRange(
+            &peerRegBaseAdr, &temp,
+            (CUdeviceptr)(char*)context.registeredMemories[{bufferType, connectedPeers[i]}].data()));
+        size_t myRegOffset = (char*)memory.data() - (char*)myRegBaseAdr;
+        size_t peerRegOffset =
+            (char*)context.registeredMemories[{bufferType, connectedPeers[i]}].data() - (char*)peerRegBaseAdr;
+        if (myRegOffset != peerRegOffset) throw Error("Divergent data offset between peers", ErrorCode::ExecutorError);
       }
     }
   }
@@ -295,13 +315,20 @@ struct Executor::Impl {
 
 Executor::Executor(std::shared_ptr<Communicator> comm) : impl_(std::make_unique<Impl>(comm)) {}
 
-void Executor::execute(int rank, void* sendbuff, void* recvBuff, size_t sendBuffSize, size_t recvBuffSize,
+void Executor::execute(int rank, void* sendbuff, void* recvbuff, size_t sendBuffSize, size_t recvBuffSize,
                        DataType dataType, int nthreads, const ExecutionPlan& plan, cudaStream_t stream,
                        PacketType packetType) {
-  ExecutionContext context =
-      this->impl_->setupExecutionContext(rank, sendbuff, recvBuff, sendBuffSize, recvBuffSize, plan);
+  size_t sendBytes, recvBytes;
+  CUdeviceptr sendBasePtr, recvBasePtr;
+  MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendBytes, (CUdeviceptr)sendbuff));
+  MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff));
+  size_t offsetIn = (char*)sendbuff - (char*)sendBasePtr;
+  size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
+
+  ExecutionContext context = this->impl_->setupExecutionContext(
+      rank, (void*)sendBasePtr, (void*)recvBasePtr, sendBuffSize, offsetIn, offsetOut, sendBytes, recvBytes, plan);
   // TODO(binyli): need to flush proxy channel here
   this->impl_->proxyService->startProxy();
-  this->impl_->launchKernel(context, rank, nthreads, sendbuff, recvBuff, dataType, stream, packetType);
+  this->impl_->launchKernel(context, rank, nthreads, sendbuff, recvbuff, dataType, stream, packetType);
 }
 
 Executor::~Executor() = default;
diff --git a/src/include/execution_kernel.hpp b/src/include/execution_kernel.hpp
index e781daa38..184092f5f 100644
--- a/src/include/execution_kernel.hpp
+++ b/src/include/execution_kernel.hpp
@@ -413,6 +413,7 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* input, T* output, T* scratch,
     } else if (op.type == OperationType::READ_REDUCE_COPY) {
       T* dst = getBuffer(input, output, scratch, op.dstBufferType);
       T* src = getBuffer(input, output, scratch, op.srcBufferType);
+
       handleReadReduceCopySend(dst, op.dstOffset, src, op.srcOffset, smChannels, op.outputChannelIndexes,
                                op.inputChannelIndexes, op.outputOffsets, op.inputOffsets, op.nOutputs, op.nInputs,
                                op.size, false);
diff --git a/src/include/execution_plan.hpp b/src/include/execution_plan.hpp
index 5e008c2ca..86545a0ee 100644
--- a/src/include/execution_plan.hpp
+++ b/src/include/execution_plan.hpp
@@ -57,11 +57,13 @@ struct ExecutionPlan::Impl {
   std::vector<Operation> getOperations(int rank, int threadblock) const;
   int getThreadblockCount(int rank) const;
 
-  void loadExecutionPlan(size_t inputSize);
+  void loadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset);
+  void lightLoadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset);
   void setupChannels(const nlohmann::json& gpus);
-  void setupOperations(const nlohmann::json& gpus);
+  void setupOperations(const nlohmann::json& gpus, size_t contsSrcOffset, size_t constDstOffset);
 
   void reset();
+  void operationsReset();
 
   const std::string name;
   const std::string planPath;