diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala
index cff68d1d..37e9f63a 100755
--- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala
@@ -1,34 +1,43 @@
 package org.apache.spark.shuffle.compat.spark_2_4
 
-import org.openucx.jucx.UcxUtils
-import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ShuffleClient}
-import org.apache.spark.shuffle.ucx.{OperationCallback, OperationResult, UcxShuffleBockId, UcxShuffleTransport}
-import org.apache.spark.shuffle.utils.UnsafeUtils
+import org.apache.spark.shuffle.ucx.{UcxFetchCallBack, UcxDownloadCallBack, UcxShuffleBockId, UcxShuffleTransport}
 import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId}
 
 class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient{
   override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String],
                            listener: BlockFetchingListener,
                            downloadFileManager: DownloadFileManager): Unit = {
-    val ucxBlockIds = Array.ofDim[UcxShuffleBockId](blockIds.length)
-    val callbacks = Array.ofDim[OperationCallback](blockIds.length)
-    for (i <- blockIds.indices) {
-      val blockId = SparkBlockId.apply(blockIds(i)).asInstanceOf[SparkShuffleBlockId]
-      ucxBlockIds(i) = UcxShuffleBockId(blockId.shuffleId, blockId.mapId, blockId.reduceId)
-      callbacks(i) = (result: OperationResult) => {
-        val memBlock = result.getData
-        val buffer = UnsafeUtils.getByteBufferView(memBlock.address, memBlock.size.toInt)
-        listener.onBlockFetchSuccess(blockIds(i), new NioManagedBuffer(buffer) {
-          override def release: ManagedBuffer = {
-            memBlock.close()
-            this
-          }
-        })
+    if (downloadFileManager == null) {
+      val ucxBlockIds = Array.ofDim[UcxShuffleBockId](blockIds.length)
+      val callbacks = Array.ofDim[UcxFetchCallBack](blockIds.length)
+      for (i <- blockIds.indices) {
+        val blockId = SparkBlockId.apply(blockIds(i))
+          .asInstanceOf[SparkShuffleBlockId]
+        ucxBlockIds(i) = UcxShuffleBockId(blockId.shuffleId, blockId.mapId,
+          blockId.reduceId)
+        callbacks(i) = new UcxFetchCallBack(blockIds(i), listener)
+      }
+      val maxBlocksPerRequest = transport.maxBlocksPerRequest
+      val resultBufferAllocator = transport.hostBounceBufferMemoryPool.get _
+      for (i <- 0 until blockIds.length by maxBlocksPerRequest) {
+        val j = i + maxBlocksPerRequest
+        transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds.slice(i, j),
+          resultBufferAllocator, callbacks.slice(i, j))
+      }
+    } else {
+      for (i <- blockIds.indices) {
+        val blockId = SparkBlockId.apply(blockIds(i))
+          .asInstanceOf[SparkShuffleBlockId]
+        val ucxBlockId = UcxShuffleBockId(blockId.shuffleId, blockId.mapId,
+          blockId.reduceId)
+        val callback = new UcxDownloadCallBack(blockIds(i), listener,
+          downloadFileManager, transport.sparkTransportConf)
+        transport.fetchBlockByStream(execId.toLong, ucxBlockId, callback)
       }
     }
-    val resultBufferAllocator = (size: Long) => transport.hostBounceBufferMemoryPool.get(size)
-    transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds, resultBufferAllocator, callbacks)
   }
 
   override def close(): Unit = {
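Note: the spark_2_4 client now walks the block-id array in fixed-size windows instead of recursively halving oversized requests. A minimal, self-contained sketch of the same windowing idiom, with a stub `fetch` standing in for `transport.fetchBlocksByBlockIds`:

```scala
object BatchingSketch extends App {
  val maxBlocksPerRequest = 4
  val blockIds: Array[String] = (0 until 10).map(i => s"block_$i").toArray

  // Stub for transport.fetchBlocksByBlockIds.
  def fetch(batch: Array[String]): Unit =
    println(s"request with ${batch.length} blocks: ${batch.mkString(", ")}")

  // Walk the ids in windows of at most maxBlocksPerRequest entries;
  // slice clamps the upper bound, so the last short batch needs no special case.
  for (i <- 0 until blockIds.length by maxBlocksPerRequest) {
    fetch(blockIds.slice(i, i + maxBlocksPerRequest))
  }
}
```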
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala
index 50b6cfd5..7ee0d443 100755
--- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala
@@ -5,10 +5,8 @@
 package org.apache.spark.shuffle.compat.spark_3_0
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockStoreClient, DownloadFileManager}
-import org.apache.spark.shuffle.ucx.{OperationCallback, OperationResult, UcxShuffleBockId, UcxShuffleTransport}
-import org.apache.spark.shuffle.utils.UnsafeUtils
+import org.apache.spark.shuffle.ucx.{UcxFetchCallBack, UcxDownloadCallBack, UcxShuffleBockId, UcxShuffleTransport}
 import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId}
 
 class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Map[Long, Int]) extends BlockStoreClient with Logging {
@@ -16,32 +14,38 @@ class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Ma
   override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String],
                            listener: BlockFetchingListener,
                            downloadFileManager: DownloadFileManager): Unit = {
-    if (blockIds.length > transport.ucxShuffleConf.maxBlocksPerRequest) {
-      val (b1, b2) = blockIds.splitAt(blockIds.length / 2)
-      fetchBlocks(host, port, execId, b1, listener, downloadFileManager)
-      fetchBlocks(host, port, execId, b2, listener, downloadFileManager)
-      return
-    }
-
-    val ucxBlockIds = Array.ofDim[UcxShuffleBockId](blockIds.length)
-    val callbacks = Array.ofDim[OperationCallback](blockIds.length)
-    for (i <- blockIds.indices) {
-      val blockId = SparkBlockId.apply(blockIds(i)).asInstanceOf[SparkShuffleBlockId]
-      ucxBlockIds(i) = UcxShuffleBockId(blockId.shuffleId, mapId2PartitionId(blockId.mapId), blockId.reduceId)
-      callbacks(i) = (result: OperationResult) => {
-        val memBlock = result.getData
-        val buffer = UnsafeUtils.getByteBufferView(memBlock.address, memBlock.size.toInt)
-        listener.onBlockFetchSuccess(blockIds(i), new NioManagedBuffer(buffer) {
-          override def release: ManagedBuffer = {
-            memBlock.close()
-            this
-          }
-        })
+    if (downloadFileManager == null) {
+      val ucxBlockIds = Array.ofDim[UcxShuffleBockId](blockIds.length)
+      val callbacks = Array.ofDim[UcxFetchCallBack](blockIds.length)
+      for (i <- blockIds.indices) {
+        val blockId = SparkBlockId.apply(blockIds(i))
+          .asInstanceOf[SparkShuffleBlockId]
+        ucxBlockIds(i) = UcxShuffleBockId(blockId.shuffleId,
+          mapId2PartitionId(blockId.mapId), blockId.reduceId)
+        callbacks(i) = new UcxFetchCallBack(blockIds(i), listener)
+      }
+      val maxBlocksPerRequest = transport.maxBlocksPerRequest
+      val resultBufferAllocator = transport.hostBounceBufferMemoryPool.get _
+      for (i <- 0 until blockIds.length by maxBlocksPerRequest) {
+        val j = i + maxBlocksPerRequest
+        transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds.slice(i, j),
+          resultBufferAllocator, callbacks.slice(i, j))
+      }
+    } else {
+      for (i <- blockIds.indices) {
+        val blockId = SparkBlockId.apply(blockIds(i))
+          .asInstanceOf[SparkShuffleBlockId]
+        val ucxBlockId = UcxShuffleBockId(blockId.shuffleId,
+          mapId2PartitionId(blockId.mapId), blockId.reduceId)
+        val callback = new UcxDownloadCallBack(blockIds(i), listener,
+          downloadFileManager, transport.sparkTransportConf)
+        transport.fetchBlockByStream(execId.toLong, ucxBlockId, callback)
       }
     }
-    val resultBufferAllocator = (size: Long) => transport.hostBounceBufferMemoryPool.get(size)
-    transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds, resultBufferAllocator, callbacks)
-    transport.progress()
   }
 
   override def close(): Unit = {
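Note: in both clients the allocator is now passed as a method reference rather than a hand-written closure. A sketch of the eta-expansion involved, with `Pool` as a hypothetical stand-in for `UcxHostBounceBuffersPool`:

```scala
class Pool {
  def get(size: Long): Array[Byte] = new Array[Byte](size.toInt)
}

object AllocatorSketch extends App {
  val pool = new Pool
  // Before: an explicit lambda delegating to the pool.
  val explicit: Long => Array[Byte] = (size: Long) => pool.get(size)
  // After: `pool.get _` eta-expands the method into the same function value.
  val etaExpanded: Long => Array[Byte] = pool.get _

  assert(explicit(8).length == etaExpanded(8).length)
}
```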
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/CommonUcxShuffleBlockResolver.scala b/src/main/scala/org/apache/spark/shuffle/ucx/CommonUcxShuffleBlockResolver.scala
index 41a2337f..7230d0be 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/CommonUcxShuffleBlockResolver.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/CommonUcxShuffleBlockResolver.scala
@@ -44,8 +44,8 @@ abstract class CommonUcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffle
       val block = new Block {
         private val fileOffset = offset
 
-        override def getBlock(byteBuffer: ByteBuffer): Unit = {
-          channel.read(byteBuffer, fileOffset)
+        override def getBlock(byteBuffer: ByteBuffer, offset: Long): Unit = {
+          channel.read(byteBuffer, fileOffset + offset)
         }
 
         override def getSize: Long = blockLength
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala
index 314b88a3..68344409 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala
@@ -43,7 +43,7 @@ trait Block extends BlockLock {
   def getMemoryBlock: MemoryBlock = ???
 
   // Get block from a file into byte buffer backed bunce buffer
-  def getBlock(byteBuffer: ByteBuffer): Unit
+  def getBlock(byteBuffer: ByteBuffer, offset: Long): Unit
 }
 
 object OperationStatus extends Enumeration {
@@ -90,6 +90,7 @@ trait Request {
  */
trait OperationCallback {
   def onComplete(result: OperationResult): Unit
+  def onData(buf: ByteBuffer): Unit = ???
 }
 
 /**
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxFetchCallBack.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxFetchCallBack.scala
new file mode 100644
index 00000000..ee03001f
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxFetchCallBack.scala
@@ -0,0 +1,50 @@
+package org.apache.spark.shuffle.ucx
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.network.util.TransportConf
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager}
+
+import org.apache.spark.shuffle.utils.UnsafeUtils
+
+class UcxFetchCallBack(
+    blockId: String, listener: BlockFetchingListener)
+  extends OperationCallback {
+
+  override def onComplete(result: OperationResult): Unit = {
+    val memBlock = result.getData
+    val buffer = UnsafeUtils.getByteBufferView(memBlock.address,
+      memBlock.size.toInt)
+    listener.onBlockFetchSuccess(blockId, new NioManagedBuffer(buffer) {
+      override def release: ManagedBuffer = {
+        memBlock.close()
+        this
+      }
+    })
+  }
+}
+
+class UcxDownloadCallBack(
+    blockId: String, listener: BlockFetchingListener,
+    downloadFileManager: DownloadFileManager,
+    transportConf: TransportConf)
+  extends OperationCallback {
+
+  private[this] val targetFile = downloadFileManager.createTempFile(
+    transportConf)
+  private[this] val channel = targetFile.openForWriting()
+
+  override def onData(buffer: ByteBuffer): Unit = {
+    while (buffer.hasRemaining()) {
+      channel.write(buffer)
+    }
+  }
+
+  override def onComplete(result: OperationResult): Unit = {
+    listener.onBlockFetchSuccess(blockId, channel.closeAndRead())
+    if (!downloadFileManager.registerTempFileToClean(targetFile)) {
+      targetFile.delete()
+    }
+  }
+}
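Note: `UcxDownloadCallBack.onData` drains the buffer in a loop because a `WritableByteChannel` may accept fewer bytes than requested per call. A runnable sketch of that idiom against an in-memory channel:

```scala
import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer
import java.nio.channels.{Channels, WritableByteChannel}

object DrainSketch extends App {
  val sink = new ByteArrayOutputStream()
  val channel: WritableByteChannel = Channels.newChannel(sink)
  val buffer = ByteBuffer.wrap("shuffle bytes".getBytes("UTF-8"))

  // write() advances buffer.position by however many bytes were accepted,
  // so looping until the buffer is empty handles short writes correctly.
  while (buffer.hasRemaining) {
    channel.write(buffer)
  }
  channel.close()
  assert(sink.toByteArray.length == "shuffle bytes".length)
}
```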
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala
index 86a26e20..a9741599 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala
@@ -91,4 +91,11 @@ class UcxShuffleConf(sparkConf: SparkConf) extends SparkConf {
     .createWithDefault(50)
 
   lazy val maxBlocksPerRequest: Int = sparkConf.getInt(MAX_BLOCKS_IN_FLIGHT.key, MAX_BLOCKS_IN_FLIGHT.defaultValue.get)
+
+  private lazy val MAX_REPLY_SIZE = ConfigBuilder(getUcxConf("maxReplySize"))
+    .doc("Maximum size of fetch reply message")
+    .bytesConf(ByteUnit.MiB)
+    .createWithDefault(32)
+
+  lazy val maxReplySize: Long = sparkConf.getSizeAsBytes(MAX_REPLY_SIZE.key, MAX_REPLY_SIZE.defaultValueString)
 }
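Note: the new `maxReplySize` option caps a single fetch reply at 32 MiB by default; blocks at or above the cap take the stream/slice path instead. A sketch of overriding it; the full `spark.shuffle.ucx.maxReplySize` key is an assumption here, since the prefix is built by `getUcxConf`:

```scala
import org.apache.spark.SparkConf

object ReplySizeConfSketch extends App {
  // Assumed full key; verify against UcxShuffleConf.getUcxConf in your build.
  val key = "spark.shuffle.ucx.maxReplySize"
  val conf = new SparkConf().set(key, "64m")
  // getSizeAsBytes parses "64m" into bytes, mirroring maxReplySize above.
  println(conf.getSizeAsBytes(key, "32m")) // 67108864
}
```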
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala
index ae8bb119..5366e8fc 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala
@@ -4,8 +4,10 @@
  */
 package org.apache.spark.shuffle.ucx
 
+import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
+import org.apache.spark.util.ThreadUtils
 import org.apache.spark.shuffle.ucx.memory.UcxHostBounceBuffersPool
 import org.apache.spark.shuffle.ucx.rpc.GlobalWorkerRpcThread
 import org.apache.spark.shuffle.ucx.utils.{SerializableDirectBuffer, SerializationUtils}
@@ -88,6 +90,14 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
   private var progressThread: Thread = _
   var hostBounceBufferMemoryPool: UcxHostBounceBuffersPool = _
 
+  private[spark] lazy val replyThreadPool =
+    ThreadUtils.newDaemonFixedThreadPool(ucxShuffleConf.numListenerThreads,
+      "UcxListenerThread")
+  private[spark] lazy val sparkTransportConf = SparkTransportConf.fromSparkConf(
+    ucxShuffleConf.getSparkConf, "shuffle", ucxShuffleConf.numWorkers)
+  private[spark] lazy val maxBlocksPerRequest = maxBlocksInAmHeader.min(
+    ucxShuffleConf.maxBlocksPerRequest).toInt
+
   private val errorHandler = new UcpEndpointErrorHandler {
     override def onError(ucpEndpoint: UcpEndpoint, errorCode: Int, errorString: String): Unit = {
       if (errorCode == UcsConstants.STATUS.UCS_ERR_CONNECTION_RESET) {
@@ -190,6 +200,10 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
     }
   }
 
+  def maxBlocksInAmHeader(): Long = {
+    (globalWorker.getMaxAmHeaderSize - 2) / UnsafeUtils.INT_SIZE
+  }
+
   /**
    * Add executor's worker address. For standalone testing purpose and for implementations that makes
    * connection establishment outside of UcxShuffleManager.
@@ -268,8 +282,14 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
   override def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId],
                                      resultBufferAllocator: BufferAllocator,
                                      callbacks: Seq[OperationCallback]): Seq[Request] = {
-    allocatedClientWorkers((Thread.currentThread().getId % allocatedClientWorkers.length).toInt)
-      .fetchBlocksByBlockIds(executorId, blockIds, resultBufferAllocator, callbacks)
+    selectClientWorker.fetchBlocksByBlockIds(executorId, blockIds,
+      resultBufferAllocator, callbacks)
+  }
+
+  def fetchBlockByStream(executorId: ExecutorId, blockId: BlockId,
+                         callback: OperationCallback): Unit = {
+    selectClientWorker.fetchBlockByStream(executorId, blockId, callback)
   }
 
   def connectServerWorkers(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = {
@@ -277,24 +297,51 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
   }
 
   def handleFetchBlockRequest(replyTag: Int, amData: UcpAmData, replyExecutor: Long): Unit = {
-    val buffer = UnsafeUtils.getByteBufferView(amData.getDataAddress, amData.getLength.toInt)
-    val blockIds = mutable.ArrayBuffer.empty[BlockId]
-
-    // 1. Deserialize blockIds from header
-    while (buffer.remaining() > 0) {
-      val blockId = UcxShuffleBockId.deserialize(buffer)
-      if (!registeredBlocks.contains(blockId)) {
-        throw new UcxException(s"$blockId is not registered")
+    replyThreadPool.submit(new Runnable {
+      override def run(): Unit = {
+        val buffer = UnsafeUtils.getByteBufferView(amData.getDataAddress,
+          amData.getLength.toInt)
+        val blockIds = mutable.ArrayBuffer.empty[BlockId]
+
+        // 1. Deserialize blockIds from header
+        while (buffer.remaining() > 0) {
+          val blockId = UcxShuffleBockId.deserialize(buffer)
+          if (!registeredBlocks.contains(blockId)) {
+            throw new UcxException(s"$blockId is not registered")
+          }
+          blockIds += blockId
+        }
+
+        val blocks = blockIds.map(bid => registeredBlocks(bid))
+        amData.close()
+        selectServerWorker.handleFetchBlockRequest(blocks, replyTag,
+          replyExecutor)
       }
-      blockIds += blockId
-    }
+    })
+  }
+
+  def handleFetchBlockStream(replyTag: Int, blockId: BlockId,
+                             replyExecutor: Long): Unit = {
+    replyThreadPool.submit(new Runnable {
+      override def run(): Unit = {
+        val block = registeredBlocks(blockId)
+        selectServerWorker.handleFetchBlockStream(block, replyTag,
+          replyExecutor)
+      }
+    })
+  }
 
-    val blocks = blockIds.map(bid => registeredBlocks(bid))
-    amData.close()
-    allocatedServerWorkers((Thread.currentThread().getId % allocatedServerWorkers.length).toInt)
-      .handleFetchBlockRequest(blocks, replyTag, replyExecutor)
+  @inline
+  def selectClientWorker(): UcxWorkerWrapper = {
+    allocatedClientWorkers(
+      (Thread.currentThread().getId % allocatedClientWorkers.length).toInt)
   }
 
+  @inline
+  def selectServerWorker(): UcxWorkerWrapper = {
+    allocatedServerWorkers(
+      (Thread.currentThread().getId % allocatedServerWorkers.length).toInt)
+  }
+
   /**
    * Progress outstanding operations. This routine is blocking (though may poll for event).
@@ -304,7 +351,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
    * But not guaranteed that at least one [[ fetchBlocksByBlockIds ]] completed!
    */
   override def progress(): Unit = {
-    allocatedClientWorkers((Thread.currentThread().getId % allocatedClientWorkers.length).toInt).progress()
+    selectClientWorker.progress()
   }
 
   def progressConnect(): Unit = {
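Note: `selectClientWorker`/`selectServerWorker` pin a calling thread to one worker by taking its thread id modulo the pool size, so repeated calls from one thread reuse the same worker. A sketch of the selection idiom with a hypothetical `Worker` type:

```scala
object WorkerSelectSketch extends App {
  final case class Worker(id: Int)
  val workers = Array.tabulate(4)(i => Worker(i))

  // Same idiom as selectClientWorker: thread id modulo pool size.
  def selectWorker(): Worker =
    workers((Thread.currentThread().getId % workers.length).toInt)

  // Repeated calls from the same thread always pick the same worker, giving
  // per-thread affinity without synchronization on the selection path.
  assert(selectWorker() == selectWorker())
}
```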
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxStreamState.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxStreamState.scala
new file mode 100644
index 00000000..ee350eda
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxStreamState.scala
@@ -0,0 +1,11 @@
+package org.apache.spark.shuffle.ucx
+
+class UcxStreamState(val callback: OperationCallback,
+                     val request: UcxRequest,
+                     var remaining: Int) {}
+
+class UcxSliceState(val callback: OperationCallback,
+                    val request: UcxRequest,
+                    val mem: MemoryBlock,
+                    var offset: Long,
+                    var remaining: Int) {}
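Note: both state classes track a `remaining` countdown carried in each reply chunk: it starts at `Int.MaxValue` as a sentinel, must strictly decrease with every chunk, and reaches zero on the last one. A sketch of that protocol:

```scala
object RemainingCounterSketch extends App {
  var remaining = Int.MaxValue // sentinel before the first chunk arrives

  def onChunk(chunkRemaining: Int): Unit = {
    // Mirrors the out-of-order check in the slice/stream reply handlers.
    require(chunkRemaining < remaining,
      s"out of order: $chunkRemaining >= $remaining")
    remaining = chunkRemaining
    if (remaining == 0) println("stream complete")
  }

  // A 3-chunk stream arrives as remaining = 2, 1, 0.
  Seq(2, 1, 0).foreach(onChunk)
}
```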
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala
index 66cefd76..88f09787 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala
@@ -5,7 +5,7 @@
 package org.apache.spark.shuffle.ucx
 
 import java.io.Closeable
-import java.util.concurrent.ConcurrentLinkedQueue
+import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch}
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.concurrent.TrieMap
 import scala.util.Random
@@ -24,6 +24,16 @@
 import java.nio.ByteBuffer
 import scala.collection.parallel.ForkJoinTaskSupport
 
+class UcxSucceedOperationResult(mem: MemoryBlock, stats: OperationStats) extends OperationResult {
+  override def getStatus: OperationStatus.Value = OperationStatus.SUCCESS
+
+  override def getError: TransportError = null
+
+  override def getStats: Option[OperationStats] = Some(stats)
+
+  override def getData: MemoryBlock = mem
+}
+
 class UcxFailureOperationResult(errorMsg: String) extends OperationResult {
   override def getStatus: OperationStatus.Value = OperationStatus.FAILURE
 
@@ -65,6 +75,8 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i
   private final val connections = new TrieMap[transport.ExecutorId, UcpEndpoint]
   private val requestData = new TrieMap[Int, (Seq[OperationCallback], UcxRequest, transport.BufferAllocator)]
+  private[ucx] lazy val streamData = new TrieMap[Int, UcxStreamState]
+  private[ucx] lazy val sliceData = new TrieMap[Int, UcxSliceState]
   private val tag = new AtomicInteger(Random.nextInt())
 
   private val flushRequests = new ConcurrentLinkedQueue[UcpRequest]()
@@ -72,6 +84,147 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i
     transport.ucxShuffleConf.numIoThreads)
   private val ioTaskSupport = new ForkJoinTaskSupport(ioThreadPool)
 
+  private[ucx] lazy val maxReplySize = transport.ucxShuffleConf.maxReplySize
+  private[ucx] lazy val memPool = transport.hostBounceBufferMemoryPool
+
+  private[this] case class UcxSliceReplyHandle() extends UcpAmRecvCallback() {
+    override def onReceive(headerAddress: Long, headerSize: Long,
+                           ucpAmData: UcpAmData, ep: UcpEndpoint): Int = {
+      val headerBuffer = UnsafeUtils.getByteBufferView(headerAddress,
+        headerSize.toInt)
+      val i = headerBuffer.getInt
+      val remaining = headerBuffer.getInt
+
+      val sliceState = sliceData.getOrElseUpdate(i, {
+        requestData.remove(i) match {
+          case Some(data) => {
+            val mem = memPool.get(maxReplySize * (remaining + 1))
+            new UcxSliceState(data._1(0), data._2, mem, 0L, Int.MaxValue)
+          }
+          case None => throw new UcxException(s"Slice tag $i context not found.")
+        }
+      })
+
+      if (remaining >= sliceState.remaining) {
+        throw new UcxException(
+          s"Slice tag $i out of order $remaining >= ${sliceState.remaining}.")
+      }
+      sliceState.remaining = remaining
+
+      val stats = sliceState.request.getStats.get.asInstanceOf[UcxStats]
+      stats.receiveSize += ucpAmData.getLength
+
+      val currentAddress = sliceState.mem.address + sliceState.offset
+      if (ucpAmData.isDataValid) {
+        stats.endTime = System.nanoTime()
+        logDebug(s"Slice receive amData ${ucpAmData} tag $i in " +
+          s"${stats.getElapsedTimeNs} ns")
+        val curBuf = UnsafeUtils.getByteBufferView(
+          ucpAmData.getDataAddress, ucpAmData.getLength.toInt)
+        val buffer = UnsafeUtils.getByteBufferView(
+          currentAddress, ucpAmData.getLength.toInt)
+        buffer.put(curBuf)
+        sliceState.offset += ucpAmData.getLength()
+        if (remaining == 0) {
+          val result = new UcxRefCountMemoryBlock(sliceState.mem, 0,
+            sliceState.offset, new AtomicInteger(1))
+          sliceState.callback.onComplete(
+            new UcxSucceedOperationResult(result, stats))
+          sliceData.remove(i)
+        }
+      } else {
+        stats.amHandleTime = System.nanoTime()
+        worker.recvAmDataNonBlocking(
+          ucpAmData.getDataHandle, currentAddress, ucpAmData.getLength,
+          new UcxCallback() {
+            override def onSuccess(r: UcpRequest): Unit = {
+              stats.endTime = System.nanoTime()
+              logDebug(s"Slice receive rndv data size ${ucpAmData.getLength} " +
+                s"tag $i in ${stats.getElapsedTimeNs} ns amHandle " +
+                s"${stats.endTime - stats.amHandleTime} ns")
+              sliceState.offset += ucpAmData.getLength()
+              if (remaining == 0) {
+                val result = new UcxRefCountMemoryBlock(sliceState.mem, 0,
+                  sliceState.offset, new AtomicInteger(1))
+                sliceState.callback.onComplete(
+                  new UcxSucceedOperationResult(result, stats))
+                sliceData.remove(i)
+              }
+            }
+          }, UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
+      }
+      UcsConstants.STATUS.UCS_OK
+    }
+  }
+
+  private[this] case class UcxStreamReplyHandle() extends UcpAmRecvCallback() {
+    override def onReceive(headerAddress: Long, headerSize: Long,
+                           ucpAmData: UcpAmData, ep: UcpEndpoint): Int = {
+      val headerBuffer = UnsafeUtils.getByteBufferView(headerAddress,
+        headerSize.toInt)
+      val i = headerBuffer.getInt
+      val remaining = headerBuffer.getInt
+
+      val data = streamData.get(i)
+      if (data.isEmpty) {
+        throw new UcxException(s"Stream tag $i context not found.")
+      }
+
+      val streamState = data.get
+      if (remaining >= streamState.remaining) {
+        throw new UcxException(
+          s"Stream tag $i out of order $remaining >= ${streamState.remaining}.")
+      }
+      streamState.remaining = remaining
+
+      val stats = streamState.request.getStats.get.asInstanceOf[UcxStats]
+      stats.receiveSize += ucpAmData.getLength
+
+      if (ucpAmData.isDataValid) {
+        stats.endTime = System.nanoTime()
+        logDebug(s"Stream receive amData ${ucpAmData} tag $i in " +
+          s"${stats.getElapsedTimeNs} ns")
+        val buffer = UnsafeUtils.getByteBufferView(
+          ucpAmData.getDataAddress, ucpAmData.getLength.toInt)
+        streamState.callback.onData(buffer)
+        if (remaining == 0) {
+          streamState.callback.onComplete(
+            new UcxSucceedOperationResult(null, stats))
+          streamData.remove(i)
+        }
+      } else {
+        val mem = memPool.get(ucpAmData.getLength)
+        stats.amHandleTime = System.nanoTime()
+        worker.recvAmDataNonBlocking(
+          ucpAmData.getDataHandle, mem.address, ucpAmData.getLength,
+          new UcxCallback() {
+            override def onSuccess(r: UcpRequest): Unit = {
+              stats.endTime = System.nanoTime()
+              logDebug(s"Stream receive rndv data ${ucpAmData.getLength} " +
+                s"tag $i in ${stats.getElapsedTimeNs} ns amHandle " +
+                s"${stats.endTime - stats.amHandleTime} ns")
+              val buffer = UnsafeUtils.getByteBufferView(
+                mem.address, ucpAmData.getLength.toInt)
+              streamState.callback.onData(buffer)
+              mem.close()
+              if (remaining == 0) {
+                streamState.callback.onComplete(
+                  new UcxSucceedOperationResult(null, stats))
+                streamData.remove(i)
+              }
+            }
+          }, UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
+      }
+      UcsConstants.STATUS.UCS_OK
+    }
+  }
+
+  // Receive block data handlers for sliced (amId 3) and streamed (amId 2) replies
+  worker.setAmRecvHandler(3, UcxSliceReplyHandle(), UcpConstants.UCP_AM_FLAG_WHOLE_MSG)
+  worker.setAmRecvHandler(2, UcxStreamReplyHandle(), UcpConstants.UCP_AM_FLAG_WHOLE_MSG)
+
   if (isClientWorker) {
     // Receive block data handler
     worker.setAmRecvHandler(1,
@@ -241,15 +394,6 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i
     val headerSize = UnsafeUtils.INT_SIZE + UnsafeUtils.LONG_SIZE
     val ep = getConnection(executorId)
 
-    if (worker.getMaxAmHeaderSize <=
-      headerSize + UnsafeUtils.INT_SIZE * blockIds.length) {
-      val (b1, b2) = blockIds.splitAt(blockIds.length / 2)
-      val (c1, c2) = callbacks.splitAt(callbacks.length / 2)
-      val r1 = fetchBlocksByBlockIds(executorId, b1, resultBufferAllocator, c1)
-      val r2 = fetchBlocksByBlockIds(executorId, b2, resultBufferAllocator, c2)
-      return r1 ++ r2
-    }
-
     val t = tag.incrementAndGet()
 
     val buffer = Platform.allocateDirectBuffer(headerSize + blockIds.map(_.serializedSize).sum)
@@ -279,6 +423,11 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i
   }
 
   def handleFetchBlockRequest(blocks: Seq[Block], replyTag: Int, replyExecutor: Long): Unit = try {
+    if ((blocks.length == 1) && (blocks(0).getSize >= maxReplySize)) {
+      handleFetchBlockStream(blocks(0), replyTag, replyExecutor, amId = 3)
+      return
+    }
+
     val tagAndSizes = UnsafeUtils.INT_SIZE + UnsafeUtils.INT_SIZE * blocks.length
     val resultMemory = transport.hostBounceBufferMemoryPool.get(tagAndSizes + blocks.map(_.getSize).sum)
       .asInstanceOf[UcxBounceBufferMemoryBlock]
@@ -306,7 +455,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i
     }
 
     for (i <- blocksCollection) {
-      blocks(i).getBlock(localBuffers(i))
+      blocks(i).getBlock(localBuffers(i), 0)
     }
 
     val startTime = System.nanoTime()
@@ -331,4 +480,97 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i
     case ex: Throwable => logError(s"Failed to read and send data: $ex")
   }
 
+  def fetchBlockByStream(executorId: transport.ExecutorId, blockId: BlockId,
+                         callback: OperationCallback): Unit = {
+    val startTime = System.nanoTime()
+    val headerSize = UnsafeUtils.INT_SIZE + UnsafeUtils.LONG_SIZE +
+      blockId.serializedSize
+
+    val t = tag.incrementAndGet()
+
+    val buffer = Platform.allocateDirectBuffer(headerSize)
+    buffer.putInt(t)
+    buffer.putLong(id)
+    blockId.serialize(buffer)
+
+    val request = new UcxRequest(null, new UcxStats())
+    streamData.put(t, new UcxStreamState(callback, request, Int.MaxValue))
+
+    val address = UnsafeUtils.getAdress(buffer)
+
+    val ep = getConnection(executorId)
+    worker.synchronized {
+      ep.sendAmNonBlocking(2, address, headerSize, address, 0,
+        UcpConstants.UCP_AM_SEND_FLAG_EAGER, new UcxCallback() {
+          override def onSuccess(request: UcpRequest): Unit = {
+            buffer.clear()
+            logDebug(s"Worker $id sent stream to $executorId block $blockId " +
+              s"tag $t in ${System.nanoTime() - startTime} ns")
+          }
+        }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
+    }
+  }
+
+  def handleFetchBlockStream(block: Block, replyTag: Int,
+                             replyExecutor: Long, amId: Int = 2): Unit = {
+    val headerSize = UnsafeUtils.INT_SIZE + UnsafeUtils.INT_SIZE
+    val maxBodySize = maxReplySize - headerSize.toLong
+    val blockSize = block.getSize
+    val blockSlice = (0L until blockSize by maxBodySize)
+    val firstLatch = new CountDownLatch(1)
+
+    def send(workerWrapper: UcxWorkerWrapper, currentId: Int,
+             sendLatch: CountDownLatch): Unit = try {
+      val mem = memPool.get(maxReplySize)
+        .asInstanceOf[UcxBounceBufferMemoryBlock]
+      val buffer = UcxUtils.getByteBufferView(mem.address, mem.size)
+
+      val remaining = blockSlice.length - currentId - 1
+      val currentOffset = blockSlice(currentId)
+      val currentSize = (blockSize - currentOffset).min(maxBodySize)
+      buffer.limit(headerSize + currentSize.toInt)
+      buffer.putInt(replyTag)
+      buffer.putInt(remaining)
+      block.getBlock(buffer, currentOffset)
+
+      val nextLatch = new CountDownLatch(1)
+      sendLatch.await()
+
+      val startTime = System.nanoTime()
+      val ep = workerWrapper.connections(replyExecutor)
+      val req = workerWrapper.worker.synchronized {
+        ep.sendAmNonBlocking(amId, mem.address, headerSize,
+          mem.address + headerSize, currentSize, 0, new UcxCallback {
+            override def onSuccess(request: UcpRequest): Unit = {
+              logTrace(s"Reply stream block $currentId size $currentSize tag " +
+                s"$replyTag in ${System.nanoTime() - startTime} ns.")
+              mem.close()
+              nextLatch.countDown()
+            }
+            override def onError(ucsStatus: Int, errorMsg: String): Unit = {
+              logError(s"Failed to reply stream $errorMsg")
+              mem.close()
+              nextLatch.countDown()
+            }
+          }, new UcpRequestParams()
+            .setMemoryType(UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
+            .setMemoryHandle(mem.memory))
+      }
+      if (remaining > 0) {
+        transport.replyThreadPool.submit(new Runnable {
+          override def run(): Unit = send(transport.selectServerWorker,
+            currentId + 1, nextLatch)
+        })
+      }
+      while (!req.isCompleted) {
+        progress()
+      }
+    } catch {
+      case ex: Throwable =>
+        logError(s"Failed to reply stream tag $replyTag id $currentId $ex.")
+    }
+
+    firstLatch.countDown()
+    send(this, 0, firstLatch)
+  }
 }
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala b/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala
index 16bd821b..433c7f05 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala
@@ -195,8 +195,8 @@ object UcxPerfBenchmark extends App with Logging {
 
         override def getSize: Long = options.blockSize
 
-        override def getBlock(byteBuffer: ByteBuffer): Unit = {
-          channel.read(byteBuffer, fileOffset)
+        override def getBlock(byteBuffer: ByteBuffer, offset: Long): Unit = {
+          channel.read(byteBuffer, fileOffset + offset)
         }
       }
       ucxTransport.register(blockId, block)
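Note: `getBlock` now takes a slice offset so the server can read one `maxReplySize`-sized piece of a block per send. A sketch of the extended contract over a positional `FileChannel.read`, using synthetic file contents:

```scala
import java.nio.ByteBuffer
import java.nio.channels.FileChannel
import java.nio.file.{Files, StandardOpenOption}

object OffsetReadSketch extends App {
  val path = Files.createTempFile("block", ".bin")
  Files.write(path, Array.tabulate[Byte](64)(_.toByte))
  val channel = FileChannel.open(path, StandardOpenOption.READ)

  val fileOffset = 16L // where this block starts in the shuffle data file
  // Positional read: thread-safe and independent of the channel's position.
  def getBlock(buf: ByteBuffer, offset: Long): Unit =
    channel.read(buf, fileOffset + offset)

  val slice = ByteBuffer.allocate(8)
  getBlock(slice, 8) // reads bytes 24..31 of the file
  assert(slice.array()(0) == 24)
  channel.close()
}
```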
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala
index a9f27b83..40e28376 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala
@@ -7,28 +7,20 @@
 package org.apache.spark.shuffle.ucx.rpc
 
 import org.openucx.jucx.ucp.{UcpAmData, UcpConstants, UcpEndpoint, UcpWorker}
 import org.openucx.jucx.ucs.UcsConstants
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.shuffle.ucx.UcxShuffleTransport
+import org.apache.spark.shuffle.ucx.{UcxShuffleTransport, UcxShuffleBockId}
 import org.apache.spark.shuffle.utils.UnsafeUtils
-import org.apache.spark.util.ThreadUtils
 
 class GlobalWorkerRpcThread(globalWorker: UcpWorker,
                             transport: UcxShuffleTransport) extends Thread with Logging {
   setDaemon(true)
   setName("Global worker progress thread")
 
-  private val replyWorkersThreadPool = ThreadUtils.newDaemonFixedThreadPool(transport.ucxShuffleConf.numListenerThreads,
-    "UcxListenerThread")
-
   // Main RPC thread. Submit each RPC request to separate thread and send reply back from separate worker.
   globalWorker.setAmRecvHandler(0, (headerAddress: Long, headerSize: Long, amData: UcpAmData, _: UcpEndpoint) => {
     val header = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt)
     val replyTag = header.getInt
     val replyExecutor = header.getLong
-    replyWorkersThreadPool.submit(new Runnable {
-      override def run(): Unit = {
-        transport.handleFetchBlockRequest(replyTag, amData, replyExecutor)
-      }
-    })
+    transport.handleFetchBlockRequest(replyTag, amData, replyExecutor)
     UcsConstants.STATUS.UCS_INPROGRESS
   }, UcpConstants.UCP_AM_FLAG_PERSISTENT_DATA | UcpConstants.UCP_AM_FLAG_WHOLE_MSG)
@@ -43,6 +35,15 @@ class GlobalWorkerRpcThread(globalWorker: UcpWorker, transport: UcxShuffleTransp
     UcsConstants.STATUS.UCS_OK
   }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG)
 
+  globalWorker.setAmRecvHandler(2, (headerAddress: Long, headerSize: Long, amData: UcpAmData, _: UcpEndpoint) => {
+    val header = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt)
+    val replyTag = header.getInt
+    val replyExecutor = header.getLong
+    val blockId = UcxShuffleBockId.deserialize(header)
+    transport.handleFetchBlockStream(replyTag, blockId, replyExecutor)
+    UcsConstants.STATUS.UCS_OK
+  }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG)
+
   override def run(): Unit = {
     if (transport.ucxShuffleConf.useWakeup) {
       while (!isInterrupted) {
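Note: the new amId-2 handler on the global worker parses the stream request header exactly as `fetchBlockByStream` writes it: an `Int` reply tag, a `Long` executor id, then the serialized block id. A sketch of that layout's encode/decode symmetry, with three ints standing in for `UcxShuffleBockId`'s fields:

```scala
import java.nio.ByteBuffer

object AmHeaderSketch extends App {
  // Client side: write the header.
  val header = ByteBuffer.allocate(4 + 8 + 3 * 4)
  header.putInt(42)                    // replyTag
  header.putLong(7L)                   // replyExecutor
  header.putInt(1).putInt(2).putInt(3) // shuffleId, mapId, reduceId
  header.flip()

  // Server side: read the fields back in the same order.
  assert(header.getInt == 42)
  assert(header.getLong == 7L)
  val blockIdFields = (header.getInt, header.getInt, header.getInt)
  assert(blockIdFields == ((1, 2, 3)))
}
```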