Skip to content

Commit

Permalink
Batch reply. (#16)
Browse files Browse the repository at this point in the history
Signed-off-by: Peter Rudenko <[email protected]>
  • Loading branch information
petro-rudenko authored Mar 24, 2022
1 parent c6c2aaf commit 6bd8710
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,19 @@ class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Ma
override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String],
listener: BlockFetchingListener,
downloadFileManager: DownloadFileManager): Unit = {
if (blockIds.length > transport.ucxShuffleConf.maxBlocksPerRequest) {
val (b1, b2) = blockIds.splitAt(blockIds.length / 2)
fetchBlocks(host, port, execId, b1, listener, downloadFileManager)
fetchBlocks(host, port, execId, b2, listener, downloadFileManager)
return
}

val ucxBlockIds = Array.ofDim[UcxShuffleBockId](blockIds.length)
val callbacks = Array.ofDim[OperationCallback](blockIds.length)
for (i <- blockIds.indices) {
val blockId = SparkBlockId.apply(blockIds(i)).asInstanceOf[SparkShuffleBlockId]
ucxBlockIds(i) = UcxShuffleBockId(blockId.shuffleId, mapId2PartitionId(blockId.mapId), blockId.reduceId)
callbacks(i) = (result: OperationResult) => {
logInfo(s"Received ${ucxBlockIds(i)} " +
s"in ${result.getStats.get.getElapsedTimeNs} ns")
val memBlock = result.getData
val buffer = UnsafeUtils.getByteBufferView(memBlock.address, memBlock.size.toInt)
listener.onBlockFetchSuccess(blockIds(i), new NioManagedBuffer(buffer) {
Expand Down
13 changes: 10 additions & 3 deletions src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ class UcxShuffleConf(sparkConf: SparkConf) extends SparkConf {

lazy val useWakeup: Boolean = sparkConf.getBoolean(WAKEUP_FEATURE.key, WAKEUP_FEATURE.defaultValue.get)

private lazy val NUM_PROGRESS_THREADS= ConfigBuilder(getUcxConf("numProgressThreads"))
.doc("Number of threads in progress thread pool")
private lazy val NUM_IO_THREADS= ConfigBuilder(getUcxConf("numIoThreads"))
.doc("Number of threads in io thread pool")
.intConf
.createWithDefault(3)

lazy val numProgressThreads: Int = sparkConf.getInt(NUM_PROGRESS_THREADS.key, NUM_PROGRESS_THREADS.defaultValue.get)
lazy val numIoThreads: Int = sparkConf.getInt(NUM_IO_THREADS.key, NUM_IO_THREADS.defaultValue.get)

private lazy val NUM_WORKERS = ConfigBuilder(getUcxConf("numWorkers"))
.doc("Number of client workers")
Expand All @@ -94,4 +94,11 @@ class UcxShuffleConf(sparkConf: SparkConf) extends SparkConf {

lazy val numWorkers: Int = sparkConf.getInt(NUM_WORKERS.key, sparkConf.getInt("spark.executor.cores",
NUM_WORKERS.defaultValue.get))

// Config entry for the cap on how many blocks a single fetch request may carry.
// Named after its key ("maxBlocksPerRequest"), like the other entries of this class.
private lazy val MAX_BLOCKS_PER_REQUEST = ConfigBuilder(getUcxConf("maxBlocksPerRequest"))
  .doc("Maximum number blocks per request")
  .intConf
  .createWithDefault(50)

/**
 * Maximum number of blocks sent in one fetch request (default 50).
 *
 * The value is clamped to at least 1. The shuffle client halves a batch for as long as
 * it is larger than this limit, so a configured value of 0 or below would make that
 * halving recurse forever instead of ever issuing a request.
 */
lazy val maxBlocksPerRequest: Int =
  math.max(1, sparkConf.getInt(MAX_BLOCKS_PER_REQUEST.key, MAX_BLOCKS_PER_REQUEST.defaultValue.get))
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,22 @@ package org.apache.spark.shuffle.ucx

import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.concurrent.TrieMap
import scala.collection.mutable
import scala.collection.parallel.ForkJoinTaskSupport

import org.openucx.jucx.ucp._
import org.openucx.jucx.ucs.UcsConstants
import org.openucx.jucx.ucs.UcsConstants.MEMORY_TYPE
import org.openucx.jucx.{UcxCallback, UcxUtils}
import org.openucx.jucx.{UcxCallback, UcxException, UcxUtils}
import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.shuffle.ucx.memory.UcxHostBounceBuffersPool
import org.apache.spark.shuffle.ucx.rpc.GlobalWorkerRpcThread
import org.apache.spark.shuffle.ucx.utils.{SerializableDirectBuffer, SerializationUtils}
import org.apache.spark.shuffle.utils.UnsafeUtils
import org.apache.spark.util.ThreadUtils

class UcxRequest(private var request: UcpRequest, stats: OperationStats)
extends Request {
Expand Down Expand Up @@ -96,6 +98,9 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
private val registeredBlocks = new TrieMap[BlockId, Block]
private var progressThread: Thread = _
var hostBounceBufferMemoryPool: UcxHostBounceBuffersPool = _
private val threadPool = ThreadUtils.newForkJoinPool("IO threads",
ucxShuffleConf.numIoThreads)
private val taskSupport = new ForkJoinTaskSupport(threadPool)

private val errorHandler = new UcpEndpointErrorHandler {
override def onError(ucpEndpoint: UcpEndpoint, errorCode: Int, errorString: String): Unit = {
Expand Down Expand Up @@ -191,6 +196,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
ucxContext.close()
ucxContext = null
}
threadPool.shutdown()
}
}

Expand All @@ -204,7 +210,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
}

/**
 * Records the address of every executor in the given map.
 *
 * For each entry, the `value` of the serialized buffer is stored in `executorAddresses`
 * under the executor id.
 *
 * @param executorIdsToAddress executor id mapped to its serialized address buffer
 */
def addExecutors(executorIdsToAddress: Map[ExecutorId, SerializableDirectBuffer]): Unit = {
  for ((executorId, address) <- executorIdsToAddress) {
    executorAddresses.put(executorId, address.value)
  }
}
Expand Down Expand Up @@ -252,7 +258,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
}

/**
 * Removes every registered block that belongs to the given shuffle.
 *
 * @param shuffleId id of the shuffle whose blocks are dropped from `registeredBlocks`
 */
def unregisterShuffle(shuffleId: Int): Unit = {
  registeredBlocks.keysIterator.foreach {
    case bid@UcxShuffleBockId(sid, _, _) if sid == shuffleId => registeredBlocks.remove(bid)
    // Explicit fallthrough: without it, the first key that belongs to another shuffle
    // (or is not a UcxShuffleBockId at all) fails the single guarded case and the
    // pattern-matching function throws scala.MatchError, aborting the cleanup.
    case _ =>
  }
}
Expand All @@ -271,42 +277,54 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
.fetchBlocksByBlockIds(executorId, blockIds, resultBufferAllocator, callbacks)
}

def handleFetchBlockRequest(startTag: Int, buffer: ByteBuffer, replyEp: UcpEndpoint): Unit = try {
def handleFetchBlockRequest(replyTag: Int, buffer: ByteBuffer, replyEp: UcpEndpoint): Unit = try {
val blockIds = mutable.ArrayBuffer.empty[BlockId]

// 1. Deserialize blockIds from header
while (buffer.remaining() > 0) {
blockIds += UcxShuffleBockId.deserialize(buffer)
val blockId = UcxShuffleBockId.deserialize(buffer)
if (!registeredBlocks.contains(blockId)) {
throw new UcxException(s"$blockId is not registered")
}
blockIds += blockId
}

val blocks = blockIds.map(bid => registeredBlocks(bid))
val resultMemory = hostBounceBufferMemoryPool.get(4 * blockIds.length
+ blocks.map(_.getSize).sum)
val tagAndSizes = UnsafeUtils.INT_SIZE + UnsafeUtils.INT_SIZE * blockIds.length
val resultMemory = hostBounceBufferMemoryPool.get(tagAndSizes + blocks.map(_.getSize).sum)
val resultBuffer = UcxUtils.getByteBufferView(resultMemory.address,
resultMemory.size)
val outstandingRequests = new AtomicInteger(blockIds.length)

for ((block, i) <- blocks.zipWithIndex) {
val headerAddress = resultMemory.address + resultBuffer.position()
resultBuffer.putInt(startTag + i)
val localBuffer = resultBuffer.slice()
localBuffer.limit(block.getSize.toInt)
block.getBlock(localBuffer)
resultBuffer.position(resultBuffer.position() + block.getSize.toInt)

val startTime = System.nanoTime()
replyEp.sendAmNonBlocking(1, headerAddress, 4L,
headerAddress + 4, block.getSize, 0, new UcxCallback {
override def onSuccess(request: UcpRequest): Unit = {
if (outstandingRequests.decrementAndGet() == 0) {
hostBounceBufferMemoryPool.put(resultMemory)
}
logTrace(s"Sent ${blockIds(i)} to tag ${startTag + i} " +
s"in ${System.nanoTime() - startTime} ns.")
}
}, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
resultBuffer.putInt(replyTag)

var offset = 0
val localBuffers = blocks.zipWithIndex.map {
case (block, i) =>
resultBuffer.putInt(UnsafeUtils.INT_SIZE + i * UnsafeUtils.INT_SIZE, block.getSize.toInt)
resultBuffer.position(tagAndSizes + offset)
val localBuffer = resultBuffer.slice()
offset += block.getSize.toInt
localBuffer.limit(block.getSize.toInt)
localBuffer
}
// Do parallel read of blocks
val blocksCollection = blocks.indices.par
blocksCollection.tasksupport = taskSupport
for (i <- blocksCollection) {
blocks(i).getBlock(localBuffers(i))
}

val startTime = System.nanoTime()
replyEp.sendAmNonBlocking(1, resultMemory.address, tagAndSizes,
resultMemory.address + tagAndSizes, resultMemory.size - tagAndSizes, 0, new UcxCallback {
override def onSuccess(request: UcpRequest): Unit = {
logTrace(s"Sent ${blockIds.length} blocks of size: ${resultMemory.size} " +
s"to tag $replyTag in ${System.nanoTime() - startTime} ns.")
hostBounceBufferMemoryPool.put(resultMemory)
}
}, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)

} catch {
case ex: Exception => logError(ex.getLocalizedMessage)
case ex: Throwable => logError(s"Failed to read and send data: $ex")
}

/**
Expand Down
113 changes: 74 additions & 39 deletions src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,25 @@ class UcxFailureOperationResult(errorMsg: String) extends OperationResult {
override def getData: MemoryBlock = null
}

class UcxAmDataMemoryBlock(ucpAmData: UcpAmData)
extends MemoryBlock(ucpAmData.getDataAddress, ucpAmData.getLength, true) with Logging {
class UcxAmDataMemoryBlock(ucpAmData: UcpAmData, offset: Long, size: Long,
refCount: AtomicInteger)
extends MemoryBlock(ucpAmData.getDataAddress + offset, size, true) with Logging {

override def close(): Unit = {
ucpAmData.close()
if (refCount.decrementAndGet() == 0) {
ucpAmData.close()
}
}
}

/**
 * A slice of `baseBlock` that shares ownership of it through a reference counter.
 *
 * The slice starts `offset` bytes past `baseBlock.address` and spans `size` bytes.
 * Closing a slice decrements `refCount`; the slice that brings the counter to zero
 * closes `baseBlock` itself.
 *
 * NOTE(review): the third `MemoryBlock` constructor argument (`true`) presumably marks
 * the memory as host memory — confirm against the `MemoryBlock` definition.
 *
 * @param baseBlock block that backs this slice
 * @param offset    byte offset of the slice inside `baseBlock`
 * @param size      length of the slice in bytes
 * @param refCount  counter shared by all slices of `baseBlock`
 */
class UcxRefCountMemoryBlock(baseBlock: MemoryBlock, offset: Long, size: Long,
                             refCount: AtomicInteger)
  extends MemoryBlock(baseBlock.address + offset, size, true) with Logging {

  override def close(): Unit = {
    val remaining = refCount.decrementAndGet()
    // Only the last slice to be closed releases the backing block.
    if (remaining == 0) baseBlock.close()
  }
}

Expand All @@ -45,42 +59,51 @@ class UcxWorkerWrapper(val worker: UcpWorker, val transport: UcxShuffleTransport
extends Closeable with Logging {

private final val connections = new TrieMap[transport.ExecutorId, UcpEndpoint]
private val requestData = new TrieMap[Int, (OperationCallback, UcxRequest, transport.BufferAllocator)]
private val requestData = new TrieMap[Int, (Seq[OperationCallback], UcxRequest, transport.BufferAllocator)]
private val tag = new AtomicInteger(Random.nextInt())

// Receive block data handler
worker.setAmRecvHandler(1,
(headerAddress: Long, headerSize: Long, ucpAmData: UcpAmData, _: UcpEndpoint) => {
val headerBuffer = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt)
val i = headerBuffer.getInt(0)
val i = headerBuffer.getInt
val data = requestData.remove(i)

if (data.isEmpty) {
throw new UcxException(s"No data for tag $i.")
}

val (callback, request, allocator) = data.get
val (callbacks, request, allocator) = data.get
val stats = request.getStats.get.asInstanceOf[UcxStats]
stats.receiveSize = ucpAmData.getLength

// Header contains tag followed by sizes of blocks
val numBlocks = (headerSize.toInt - UnsafeUtils.INT_SIZE) / UnsafeUtils.INT_SIZE

var offset = 0
val refCounts = new AtomicInteger(numBlocks)
if (ucpAmData.isDataValid) {
if (callback != null) {
request.completed = true
stats.endTime = System.nanoTime()
logDebug(s"Received amData: $ucpAmData for tag $i " +
s"in ${stats.getElapsedTimeNs} ns")
callback.onComplete(new OperationResult {
override def getStatus: OperationStatus.Value = OperationStatus.SUCCESS

override def getError: TransportError = null

override def getStats: Option[OperationStats] = Some(stats)

override def getData: MemoryBlock = new UcxAmDataMemoryBlock(ucpAmData)
})
UcsConstants.STATUS.UCS_INPROGRESS
} else {
UcsConstants.STATUS.UCS_OK
request.completed = true
stats.endTime = System.nanoTime()
logDebug(s"Received amData: $ucpAmData for tag $i " +
s"in ${stats.getElapsedTimeNs} ns")

for (b <- 0 until numBlocks) {
val blockSize = headerBuffer.getInt
if (callbacks(b) != null) {
callbacks(b).onComplete(new OperationResult {
override def getStatus: OperationStatus.Value = OperationStatus.SUCCESS

override def getError: TransportError = null

override def getStats: Option[OperationStats] = Some(stats)

override def getData: MemoryBlock = new UcxAmDataMemoryBlock(ucpAmData, offset, blockSize, refCounts)
})
offset += blockSize
}
}
if (callbacks.isEmpty) UcsConstants.STATUS.UCS_OK else UcsConstants.STATUS.UCS_INPROGRESS
} else {
val mem = allocator(ucpAmData.getLength)
stats.amHandleTime = System.nanoTime()
Expand All @@ -92,15 +115,20 @@ class UcxWorkerWrapper(val worker: UcpWorker, val transport: UcxShuffleTransport
logDebug(s"Received rndv data of size: ${mem.size} for tag $i in " +
s"${stats.getElapsedTimeNs} ns " +
s"time from amHandle: ${System.nanoTime() - stats.amHandleTime} ns")
callback.onComplete(new OperationResult {
override def getStatus: OperationStatus.Value = OperationStatus.SUCCESS
for (b <- 0 until numBlocks) {
val blockSize = headerBuffer.getInt
callbacks(b).onComplete(new OperationResult {
override def getStatus: OperationStatus.Value = OperationStatus.SUCCESS

override def getError: TransportError = null
override def getError: TransportError = null

override def getStats: Option[OperationStats] = Some(stats)
override def getStats: Option[OperationStats] = Some(stats)

override def getData: MemoryBlock = new UcxRefCountMemoryBlock(mem, offset, blockSize, refCounts)
})
offset += blockSize
}

override def getData: MemoryBlock = mem
})
}
}, UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST))
UcsConstants.STATUS.UCS_OK
Expand Down Expand Up @@ -177,21 +205,28 @@ class UcxWorkerWrapper(val worker: UcpWorker, val transport: UcxShuffleTransport
callbacks: Seq[OperationCallback]): Seq[Request] = {
val startTime = System.nanoTime()
val ep = getConnection(executorId)
val t = tag.getAndAdd(blockIds.length)

val buffer = Platform.allocateDirectBuffer(4 + blockIds.map(_.serializedSize).sum)

if (worker.getMaxAmHeaderSize <=
UnsafeUtils.INT_SIZE + UnsafeUtils.INT_SIZE * blockIds.length) {
val (b1, b2) = blockIds.splitAt(blockIds.length / 2)
val (c1, c2) = callbacks.splitAt(callbacks.length / 2)
val r1 = fetchBlocksByBlockIds(executorId, b1, resultBufferAllocator, c1)
val r2 = fetchBlocksByBlockIds(executorId, b2, resultBufferAllocator, c2)
return r1 ++ r2
}

val t = tag.incrementAndGet()

val buffer = Platform.allocateDirectBuffer(UnsafeUtils.INT_SIZE + blockIds.map(_.serializedSize).sum)
buffer.putInt(t)
blockIds.foreach(b => b.serialize(buffer))

val requests = new Array[UcxRequest](blockIds.size)
for (i <- blockIds.indices) {
val stats = new UcxStats()
requests(i) = new UcxRequest(null, stats)
requestData.put(t + i, (callbacks(i), requests(i), resultBufferAllocator))
}
val request = new UcxRequest(null, new UcxStats())
requestData.put(t, (callbacks, request, resultBufferAllocator))

val address = UnsafeUtils.getAdress(buffer)
val dataAddress = address + 4
val dataAddress = address + UnsafeUtils.INT_SIZE

ep.sendAmNonBlocking(0, address, 4, dataAddress, buffer.capacity() - 4,
UcpConstants.UCP_AM_SEND_FLAG_REPLY | UcpConstants.UCP_AM_SEND_FLAG_EAGER, new UcxCallback() {
Expand All @@ -202,7 +237,7 @@ class UcxWorkerWrapper(val worker: UcpWorker, val transport: UcxShuffleTransport
}
}, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)

requests
Seq(request)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,14 @@ class GlobalWorkerRpcThread(globalWorker: UcpWorker, transport: UcxShuffleTransp
setDaemon(true)
setName("Global worker progress thread")

private val threadPool = ThreadUtils.newDaemonFixedThreadPool(transport.ucxShuffleConf.numProgressThreads,
"Progress threads")

globalWorker.setAmRecvHandler(0, (headerAddress: Long, headerSize: Long, amData: UcpAmData,
replyEp: UcpEndpoint) => {
threadPool.submit(new Runnable() {
private val startTag = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt).getInt
private val data = UnsafeUtils.getByteBufferView(amData.getDataAddress, amData.getLength.toInt)
private val ep = replyEp

override def run(): Unit = {
transport.handleFetchBlockRequest(startTag, data, ep)
amData.close()
}
})
UcsConstants.STATUS.UCS_INPROGRESS
}, UcpConstants.UCP_AM_FLAG_PERSISTENT_DATA | UcpConstants.UCP_AM_FLAG_WHOLE_MSG)
val replyTag = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt).getInt
val data = UnsafeUtils.getByteBufferView(amData.getDataAddress, amData.getLength.toInt)
transport.handleFetchBlockRequest(replyTag, data, replyEp)
UcsConstants.STATUS.UCS_OK
})

override def run(): Unit = {
if (transport.ucxShuffleConf.useWakeup) {
Expand Down
Loading

0 comments on commit 6bd8710

Please sign in to comment.