From 285f315acd05700176d1310689600db48cf68ae6 Mon Sep 17 00:00:00 2001 From: Damon Brown Date: Thu, 2 Sep 2021 15:23:51 +0000 Subject: [PATCH 1/2] Support Spark-3.1 version. --- README.md | 6 +- pom.xml | 49 ++++- .../spark_3_1/OnOffsetsFetchCallback.java | 93 +++++++++ .../compat/spark_3_1/UcxShuffleClient.java | 136 +++++++++++++ .../spark_3_1/UcxLocalDiskShuffleDataIO.scala | 20 ++ ...cxLocalDiskShuffleExecutorComponents.scala | 47 +++++ .../spark_3_1/UcxShuffleBlockResolver.scala | 52 +++++ .../compat/spark_3_1/UcxShuffleManager.scala | 75 +++++++ .../compat/spark_3_1/UcxShuffleReader.scala | 191 ++++++++++++++++++ 9 files changed, 662 insertions(+), 7 deletions(-) create mode 100755 src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_1/OnOffsetsFetchCallback.java create mode 100755 src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_1/UcxShuffleClient.java create mode 100755 src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxLocalDiskShuffleDataIO.scala create mode 100755 src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxLocalDiskShuffleExecutorComponents.scala create mode 100755 src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxShuffleBlockResolver.scala create mode 100755 src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxShuffleManager.scala create mode 100755 src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxShuffleReader.scala diff --git a/README.md b/README.md index b9cb6da3..5e61196f 100755 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ that are supported by [UCX](https://github.com/openucx/ucx#supported-transports) This open-source project is developed, maintained and supported by the [UCF consortium](http://www.ucfconsortium.org/). ## Runtime requirements -* Apache Spark 2.3/2.4/3.0 +* Apache Spark 2.3/2.4/3.0/3.1 * Java 8+ * Installed UCX of version 1.10+, and [UCX supported transport hardware](https://github.com/openucx/ucx#supported-transports). @@ -34,9 +34,9 @@ to spark (e.g. in $SPARK_HOME/conf/spark-defaults.conf): ``` spark.shuffle.manager org.apache.spark.shuffle.UcxShuffleManager ``` -For spark-3.0 version add SparkUCX ShuffleIO plugin: +For spark-3.0 or spark-3.1 versions add SparkUCX ShuffleIO plugin: ``` -spark.shuffle.sort.io.plugin.class org.apache.spark.shuffle.compat.spark_3_0.UcxLocalDiskShuffleDataIO +spark.shuffle.sort.io.plugin.class org.apache.spark.shuffle.compat.spark_(3_0|3_1).UcxLocalDiskShuffleDataIO ``` ### Build diff --git a/pom.xml b/pom.xml index b23b54ac..0bb569ff 100755 --- a/pom.xml +++ b/pom.xml @@ -43,6 +43,7 @@ See file LICENSE for terms. maven-compiler-plugin + **/spark_3_1/** **/spark_3_0/** **/spark_2_4/** @@ -53,6 +54,7 @@ See file LICENSE for terms. scala-maven-plugin + **/spark_3_1/** **/spark_3_0/** **/spark_2_4/** @@ -62,7 +64,7 @@ See file LICENSE for terms. 2.1.0 - **/spark_3_0/**, **/spark_2_4/** + **/spark_3_1/**, **/spark_3_0/**, **/spark_2_4/** 2.11.12 2.11 @@ -76,6 +78,7 @@ See file LICENSE for terms. maven-compiler-plugin + **/spark_3_1/** **/spark_3_0/** **/spark_2_1/** @@ -86,6 +89,7 @@ See file LICENSE for terms. scala-maven-plugin + **/spark_3_1/** **/spark_2_1/** **/spark_3_0/** @@ -95,13 +99,48 @@ See file LICENSE for terms. 2.4.0 - **/spark_3_0/**, **/spark_2_1/** + **/spark_3_1/**, **/spark_3_0/**, **/spark_2_1/** 2.11.12 2.11 spark-3.0 + + + + org.apache.maven.plugins + maven-compiler-plugin + + + **/spark_3_1/** + **/spark_2_1/** + **/spark_2_4/** + + + + + net.alchim31.maven + scala-maven-plugin + + + **/spark_3_1/** + **/spark_2_1/** + **/spark_2_4/** + + + + + + + 3.0.1 + 2.12.10 + 2.12 + **/spark_3_1/**, **/spark_2_1/**, **/spark_2_4/** + + + + spark-3.1 true @@ -112,6 +151,7 @@ See file LICENSE for terms. maven-compiler-plugin + **/spark_3_0/** **/spark_2_1/** **/spark_2_4/** @@ -122,6 +162,7 @@ See file LICENSE for terms. scala-maven-plugin + **/spark_3_0/** **/spark_2_1/** **/spark_2_4/** @@ -130,10 +171,10 @@ See file LICENSE for terms. - 3.0.1 + 3.1.2 2.12.10 2.12 - **/spark_2_1/**, **/spark_2_4/** + **/spark_3_0/**, **/spark_2_1/**, **/spark_2_4/** diff --git a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_1/OnOffsetsFetchCallback.java b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_1/OnOffsetsFetchCallback.java new file mode 100755 index 00000000..14612111 --- /dev/null +++ b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_1/OnOffsetsFetchCallback.java @@ -0,0 +1,93 @@ +/* + * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ +package org.apache.spark.shuffle.ucx.reducer.compat.spark_3_1; + +import org.apache.spark.network.shuffle.BlockFetchingListener; +import org.apache.spark.shuffle.UcxWorkerWrapper; +import org.apache.spark.shuffle.ucx.UnsafeUtils; +import org.apache.spark.shuffle.ucx.memory.RegisteredMemory; +import org.apache.spark.shuffle.ucx.reducer.ReducerCallback; +import org.apache.spark.shuffle.ucx.reducer.OnBlocksFetchCallback; +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.ShuffleBlockBatchId; +import org.apache.spark.storage.ShuffleBlockId; +import org.openucx.jucx.UcxUtils; +import org.openucx.jucx.ucp.UcpEndpoint; +import org.openucx.jucx.ucp.UcpRemoteKey; +import org.openucx.jucx.ucp.UcpRequest; + +import java.nio.ByteBuffer; +import java.util.Map; + +/** + * Callback, called when got all offsets for blocks + */ +public class OnOffsetsFetchCallback extends ReducerCallback { + private final RegisteredMemory offsetMemory; + private final long[] dataAddresses; + private Map dataRkeysCache; + private final Map mapId2PartitionId; + + public OnOffsetsFetchCallback(BlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener, + RegisteredMemory offsetMemory, long[] dataAddresses, + Map dataRkeysCache, + Map mapId2PartitionId) { + super(blockIds, endpoint, listener); + this.offsetMemory = offsetMemory; + this.dataAddresses = dataAddresses; + this.dataRkeysCache = dataRkeysCache; + this.mapId2PartitionId = mapId2PartitionId; + } + + @Override + public void onSuccess(UcpRequest request) { + ByteBuffer resultOffset = offsetMemory.getBuffer(); + long totalSize = 0; + int[] sizes = new int[blockIds.length]; + int offset = 0; + long blockOffset; + long blockLength; + int offsetSize = UnsafeUtils.LONG_SIZE; + for (int i = 0; i < blockIds.length; i++) { + // Blocks in metadata buffer are in form | blockOffsetStart | blockOffsetEnd | + if (blockIds[i] instanceof ShuffleBlockBatchId) { + ShuffleBlockBatchId blockBatchId = (ShuffleBlockBatchId) blockIds[i]; + int blocksInBatch = blockBatchId.endReduceId() - blockBatchId.startReduceId(); + blockOffset = resultOffset.getLong(offset * 2 * offsetSize); + blockLength = resultOffset.getLong(offset * 2 * offsetSize + offsetSize * blocksInBatch) + - blockOffset; + offset += blocksInBatch; + } else { + blockOffset = resultOffset.getLong(offset * 16); + blockLength = resultOffset.getLong(offset * 16 + 8) - blockOffset; + offset++; + } + + assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE); + sizes[i] = (int) blockLength; + totalSize += blockLength; + dataAddresses[i] += blockOffset; + } + + assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE); + mempool.put(offsetMemory); + RegisteredMemory blocksMemory = mempool.get((int) totalSize); + + offset = 0; + // Submits N fetch blocks requests + for (int i = 0; i < blockIds.length; i++) { + int mapPartitionId = (blockIds[i] instanceof ShuffleBlockId) ? + mapId2PartitionId.get(((ShuffleBlockId)blockIds[i]).mapId()) : + mapId2PartitionId.get(((ShuffleBlockBatchId)blockIds[i]).mapId()); + endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(mapPartitionId), + UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]); + offset += sizes[i]; + } + + // Process blocks when all fetched. + // Flush guarantees that callback would invoke when all fetch requests will completed. + endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes)); + } +} diff --git a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_1/UcxShuffleClient.java b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_1/UcxShuffleClient.java new file mode 100755 index 00000000..c83cc3e1 --- /dev/null +++ b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_1/UcxShuffleClient.java @@ -0,0 +1,136 @@ +/* + * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ +package org.apache.spark.shuffle.ucx.reducer.compat.spark_3_1; + +import org.apache.spark.SparkEnv; +import org.apache.spark.executor.TempShuffleReadMetrics; +import org.apache.spark.network.shuffle.BlockFetchingListener; +import org.apache.spark.network.shuffle.BlockStoreClient; +import org.apache.spark.network.shuffle.DownloadFileManager; +import org.apache.spark.shuffle.DriverMetadata; +import org.apache.spark.shuffle.UcxShuffleManager; +import org.apache.spark.shuffle.UcxWorkerWrapper; +import org.apache.spark.shuffle.ucx.UnsafeUtils; +import org.apache.spark.shuffle.ucx.memory.RegisteredMemory; +import org.apache.spark.storage.*; +import org.openucx.jucx.UcxUtils; +import org.openucx.jucx.ucp.UcpEndpoint; +import org.openucx.jucx.ucp.UcpRemoteKey; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.Option; + + +import java.util.HashMap; +import java.util.Map; + +public class UcxShuffleClient extends BlockStoreClient { + private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class); + private final UcxWorkerWrapper workerWrapper; + private final Map mapId2PartitionId; + private final TempShuffleReadMetrics shuffleReadMetrics; + private final int shuffleId; + final HashMap offsetRkeysCache = new HashMap<>(); + final HashMap dataRkeysCache = new HashMap<>(); + + + public UcxShuffleClient(int shuffleId, UcxWorkerWrapper workerWrapper, + Map mapId2PartitionId, TempShuffleReadMetrics shuffleReadMetrics) { + this.workerWrapper = workerWrapper; + this.shuffleId = shuffleId; + this.mapId2PartitionId = mapId2PartitionId; + this.shuffleReadMetrics = shuffleReadMetrics; + } + + /** + * Submits n non blocking fetch offsets to get needed offsets for n blocks. + */ + private void submitFetchOffsets(UcpEndpoint endpoint, BlockId[] blockIds, + RegisteredMemory offsetMemory, + long[] dataAddresses) { + DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(shuffleId); + long offset = 0; + int startReduceId; + long size; + + for (int i = 0; i < blockIds.length; i++) { + BlockId blockId = blockIds[i]; + int mapIdpartition; + + if (blockId instanceof ShuffleBlockId) { + ShuffleBlockId shuffleBlockId = (ShuffleBlockId) blockId; + mapIdpartition = mapId2PartitionId.get(shuffleBlockId.mapId()); + size = 2L * UnsafeUtils.LONG_SIZE; + startReduceId = shuffleBlockId.reduceId(); + } else { + ShuffleBlockBatchId shuffleBlockBatchId = (ShuffleBlockBatchId) blockId; + mapIdpartition = mapId2PartitionId.get(shuffleBlockBatchId.mapId()); + size = (shuffleBlockBatchId.endReduceId() - shuffleBlockBatchId.startReduceId()) + * 2L * UnsafeUtils.LONG_SIZE; + startReduceId = shuffleBlockBatchId.startReduceId(); + } + + long offsetAddress = driverMetadata.offsetAddress(mapIdpartition); + dataAddresses[i] = driverMetadata.dataAddress(mapIdpartition); + + offsetRkeysCache.computeIfAbsent(mapIdpartition, mapId -> + endpoint.unpackRemoteKey(driverMetadata.offsetRkey(mapIdpartition))); + + dataRkeysCache.computeIfAbsent(mapIdpartition, mapId -> + endpoint.unpackRemoteKey(driverMetadata.dataRkey(mapIdpartition))); + + endpoint.getNonBlockingImplicit( + offsetAddress + startReduceId * UnsafeUtils.LONG_SIZE, + offsetRkeysCache.get(mapIdpartition), + UcxUtils.getAddress(offsetMemory.getBuffer()) + offset, + size); + + offset += size; + } + } + + @Override + public void fetchBlocks(String host, int port, String execId, String[] blockIds, BlockFetchingListener listener, + DownloadFileManager downloadFileManager) { + long startTime = System.currentTimeMillis(); + BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty()); + UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId); + long[] dataAddresses = new long[blockIds.length]; + int totalBlocks = 0; + + BlockId[] blocks = new BlockId[blockIds.length]; + + for (int i = 0; i < blockIds.length; i++) { + blocks[i] = BlockId.apply(blockIds[i]); + if (blocks[i] instanceof ShuffleBlockId) { + totalBlocks += 1; + } else { + ShuffleBlockBatchId blockBatchId = (ShuffleBlockBatchId)blocks[i]; + totalBlocks += (blockBatchId.endReduceId() - blockBatchId.startReduceId()); + } + } + + RegisteredMemory offsetMemory = ((UcxShuffleManager)SparkEnv.get().shuffleManager()) + .ucxNode().getMemoryPool().get(totalBlocks * 2 * UnsafeUtils.LONG_SIZE); + // Submits N implicit get requests without callback + submitFetchOffsets(endpoint, blocks, offsetMemory, dataAddresses); + + // flush guarantees that all that requests completes when callback is called. + // TODO: fix https://github.com/openucx/ucx/issues/4267 and use endpoint flush. + workerWrapper.worker().flushNonBlocking( + new OnOffsetsFetchCallback(blocks, endpoint, listener, offsetMemory, + dataAddresses, dataRkeysCache, mapId2PartitionId)); + + shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime); + } + + @Override + public void close() { + offsetRkeysCache.values().forEach(UcpRemoteKey::close); + dataRkeysCache.values().forEach(UcpRemoteKey::close); + logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime()); + } + +} diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxLocalDiskShuffleDataIO.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxLocalDiskShuffleDataIO.scala new file mode 100755 index 00000000..47c6e448 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxLocalDiskShuffleDataIO.scala @@ -0,0 +1,20 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.compat.spark_3_1 + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO + +/** + * Ucx local disk IO plugin to handle logic of writing to local disk and shuffle memory registration. + */ +case class UcxLocalDiskShuffleDataIO(sparkConf: SparkConf) extends LocalDiskShuffleDataIO(sparkConf) with Logging { + + override def executor(): ShuffleExecutorComponents = { + new UcxLocalDiskShuffleExecutorComponents(sparkConf) + } +} diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxLocalDiskShuffleExecutorComponents.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxLocalDiskShuffleExecutorComponents.scala new file mode 100755 index 00000000..088377de --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxLocalDiskShuffleExecutorComponents.scala @@ -0,0 +1,47 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.compat.spark_3_1 + +import java.util +import java.util.Optional + +import org.apache.spark.internal.Logging +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.shuffle.sort.io.{LocalDiskShuffleExecutorComponents, LocalDiskShuffleMapOutputWriter, LocalDiskSingleSpillMapOutputWriter} +import org.apache.spark.shuffle.UcxShuffleManager +import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, SingleSpillShuffleMapOutputWriter} + +/** + * Entry point to UCX executor. + */ +class UcxLocalDiskShuffleExecutorComponents(sparkConf: SparkConf) + extends LocalDiskShuffleExecutorComponents(sparkConf) with Logging{ + + private var blockResolver: UcxShuffleBlockResolver = _ + + override def initializeExecutor(appId: String, execId: String, extraConfigs: util.Map[String, String]): Unit = { + val ucxShuffleManager = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager] + ucxShuffleManager.startUcxNodeIfMissing() + blockResolver = ucxShuffleManager.shuffleBlockResolver + } + + override def createMapOutputWriter(shuffleId: Int, mapTaskId: Long, numPartitions: Int): ShuffleMapOutputWriter = { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting writers.") + } + new LocalDiskShuffleMapOutputWriter( + shuffleId, mapTaskId, numPartitions, blockResolver, sparkConf) + } + + override def createSingleFileMapOutputWriter(shuffleId: Int, mapId: Long): Optional[SingleSpillShuffleMapOutputWriter] = { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting writers.") + } + Optional.of(new LocalDiskSingleSpillMapOutputWriter(shuffleId, mapId, blockResolver)) + } + +} diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxShuffleBlockResolver.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxShuffleBlockResolver.scala new file mode 100755 index 00000000..1fa8c912 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxShuffleBlockResolver.scala @@ -0,0 +1,52 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.compat.spark_3_1 + +import java.io.{File, RandomAccessFile} + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.network.shuffle.ExecutorDiskUtils +import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID +import org.apache.spark.shuffle.{CommonUcxShuffleBlockResolver, CommonUcxShuffleManager} +import org.apache.spark.storage.ShuffleIndexBlockId + +/** + * Mapper entry point for UcxShuffle plugin. Performs memory registration + * of data and index files and publish addresses to driver metadata buffer. + */ +class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager) + extends CommonUcxShuffleBlockResolver(ucxShuffleManager) { + + override def getIndexFile( + shuffleId: Int, + mapId: Long, + dirs: Option[Array[String]] = None): File = { + val blockId = ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID) + val blockManager = SparkEnv.get.blockManager + dirs + .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, blockId.name)) + .getOrElse(blockManager.diskBlockManager.getFile(blockId)) + } + + override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Long, + lengths: Array[Long], dataTmp: File): Unit = { + super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp) + // In Spark-3.0 MapId is long and unique among all jobs in spark. We need to use partitionId as offset + // in metadata buffer + val partitionId = TaskContext.getPartitionId() + val dataFile = getDataFile(shuffleId, mapId) + val dataBackFile = new RandomAccessFile(dataFile, "rw") + + if (dataBackFile.length() == 0) { + dataBackFile.close() + return + } + + val indexFile = getIndexFile(shuffleId, mapId) + val indexBackFile = new RandomAccessFile(indexFile, "rw") + + writeIndexFileAndCommitCommon(shuffleId, partitionId, lengths, dataTmp, indexBackFile, dataBackFile) + } +} diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxShuffleManager.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxShuffleManager.scala new file mode 100755 index 00000000..64a55726 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxShuffleManager.scala @@ -0,0 +1,75 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle + +import scala.collection.JavaConverters._ + +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.compat.spark_3_1.{UcxShuffleBlockResolver, UcxShuffleReader} +import org.apache.spark.shuffle.sort.{SerializedShuffleHandle, SortShuffleWriter, UnsafeShuffleWriter} +import org.apache.spark.util.ShutdownHookManager +import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext} + +/** + * Main entry point of Ucx shuffle plugin. It extends spark's default SortShufflePlugin + * and injects needed logic in override methods. + */ +class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) extends CommonUcxShuffleManager(conf, isDriver) { + ShutdownHookManager.addShutdownHook(Int.MaxValue - 1)(stop) + private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) + + override val shuffleBlockResolver = new UcxShuffleBlockResolver(this) + + override def registerShuffle[K, V, C](shuffleId: ShuffleId, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + assume(isDriver) + val numMaps = dependency.partitioner.numPartitions + val baseHandle = super.registerShuffle(shuffleId, dependency).asInstanceOf[BaseShuffleHandle[K, V, C]] + registerShuffleCommon(baseHandle, shuffleId, numMaps) + } + + override def getWriter[K, V](handle: ShuffleHandle, mapId: Long, context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { + shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K, V, _]]) + val env = SparkEnv.get + handle.asInstanceOf[UcxShuffleHandle[K, V, _]].baseHandle match { + case unsafeShuffleHandle: SerializedShuffleHandle[K@unchecked, V@unchecked] => + new UnsafeShuffleWriter( + env.blockManager, + context.taskMemoryManager(), + unsafeShuffleHandle, + mapId, + context, + env.conf, + metrics, + shuffleExecutorComponents) + case other: BaseShuffleHandle[K@unchecked, V@unchecked, _] => + new SortShuffleWriter( + shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents) + } + } + + override def getReader[K, C](handle: ShuffleHandle, startMapIndex: Int, endMapIndex: Int, + startPartition: MapId, endPartition: MapId, context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + + startUcxNodeIfMissing() + shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K, _, C]]) + new UcxShuffleReader(handle.asInstanceOf[UcxShuffleHandle[K,_,C]], startMapIndex, endMapIndex, startPartition, endPartition, + context, readMetrics = metrics, shouldBatchFetch = true) + } + + + private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = { + val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor() + val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX) + .toMap + executorComponents.initializeExecutor( + conf.getAppId, + SparkEnv.get.executorId, + extraConfigs.asJava) + executorComponents + } + +} diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxShuffleReader.scala new file mode 100755 index 00000000..6ca31966 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_1/UcxShuffleReader.scala @@ -0,0 +1,191 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.compat.spark_3_1 + +import java.io.InputStream +import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.JavaConverters._ + +import org.apache.spark.internal.{Logging, config} +import org.apache.spark.io.CompressionCodec +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.shuffle.ucx.reducer.compat.spark_3_1.UcxShuffleClient +import org.apache.spark.shuffle.{ShuffleReadMetricsReporter, ShuffleReader, UcxShuffleHandle, UcxShuffleManager} +import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockBatchId, ShuffleBlockFetcherIterator, ShuffleBlockId} +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.ExternalSorter +import org.apache.spark.{InterruptibleIterator, SparkEnv, SparkException, TaskContext} + + +/** + * Extension of Spark's shuffe reader with a logic of injection UcxShuffleClient, + * and lazy progress only when result queue is empty. + */ +class UcxShuffleReader[K, C](handle: UcxShuffleHandle[K, _, C], + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + serializerManager: SerializerManager = SparkEnv.get.serializerManager, + blockManager: BlockManager = SparkEnv.get.blockManager, + readMetrics: ShuffleReadMetricsReporter, + shouldBatchFetch: Boolean = false) extends ShuffleReader[K, C] with Logging { + + private val dep = handle.baseHandle.dependency + + /** Read the combined key-values for this reduce task */ + override def read(): Iterator[Product2[K, C]] = { + val (blocksByAddressIterator1, blocksByAddressIterator2) = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( + handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition).duplicate + val mapIdToBlockIndex = blocksByAddressIterator2.flatMap{ + case (_, blocks) => blocks.map { + case (blockId, _, mapIdx) => blockId match { + case x: ShuffleBlockId => (x.mapId.asInstanceOf[java.lang.Long], mapIdx.asInstanceOf[java.lang.Integer]) + case x: ShuffleBlockBatchId => (x.mapId.asInstanceOf[java.lang.Long], mapIdx.asInstanceOf[java.lang.Integer]) + case _ => throw new SparkException("Unknown block") + } + } + }.toMap + + val workerWrapper = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager] + .ucxNode.getThreadLocalWorker + val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() + val shuffleClient = new UcxShuffleClient(handle.shuffleId, workerWrapper, mapIdToBlockIndex.asJava, shuffleMetrics) + val shuffleIterator = new ShuffleBlockFetcherIterator( + context, + shuffleClient, + blockManager, + blocksByAddressIterator1, + serializerManager.wrapStream, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, + SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), + readMetrics, + fetchContinuousBlocksInBatch) + + val wrappedStreams = shuffleIterator.toCompletionIterator + + // Ucx shuffle logic + // Java reflection to get access to private results queue + val queueField = shuffleIterator.getClass.getDeclaredField( + "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results") + queueField.setAccessible(true) + val resultQueue = queueField.get(shuffleIterator).asInstanceOf[LinkedBlockingQueue[_]] + + // Do progress if queue is empty before calling next on ShuffleIterator + val ucxWrappedStream = new Iterator[(BlockId, InputStream)] { + override def next(): (BlockId, InputStream) = { + val startTime = System.currentTimeMillis() + workerWrapper.fillQueueWithBlocks(resultQueue) + readMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime) + wrappedStreams.next() + } + + override def hasNext: Boolean = { + val result = wrappedStreams.hasNext + if (!result) { + shuffleClient.close() + } + result + } + } + // End of ucx shuffle logic + + val serializerInstance = dep.serializer.newInstance() + + // Create a key/value iterator for each stream + val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } + + // Update the context task metrics for each record read. + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map { record => + readMetrics.incRecordsRead(1) + record + }, + context.taskMetrics().mergeShuffleReadMetrics()) + + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) + + val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { + if (dep.mapSideCombine) { + // We are reading values that are already combined + val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) + } else { + // We don't know the value type, but also don't care -- the dependency *should* + // have made sure its compatible w/ this aggregator, which will convert the value + // type to the combined type C + val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) + } + } else { + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] + } + + // Sort the output if there is a sort ordering defined. + val resultIter = dep.keyOrdering match { + case Some(keyOrd: Ordering[K]) => + // Create an ExternalSorter to sort the data. + val sorter = + new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer) + sorter.insertAll(aggregatedIter) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) + // Use completion callback to stop sorter if task was finished/cancelled. + context.addTaskCompletionListener[Unit](_ => { + sorter.stop() + }) + CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) + case None => + aggregatedIter + } + + resultIter match { + case _: InterruptibleIterator[Product2[K, C]] => resultIter + case _ => + // Use another interruptible iterator here to support task cancellation as aggregator + // or(and) sorter may have consumed previous interruptible iterator. + new InterruptibleIterator[Product2[K, C]](context, resultIter) + } + } + + private def fetchContinuousBlocksInBatch: Boolean = { + val conf = SparkEnv.get.conf + val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects + val compressed = conf.get(config.SHUFFLE_COMPRESS) + val codecConcatenation = if (compressed) { + CompressionCodec.supportsConcatenationOfSerializedStreams(CompressionCodec.createCodec(conf)) + } else { + true + } + val useOldFetchProtocol = conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL) + + val doBatchFetch = shouldBatchFetch && serializerRelocatable && + (!compressed || codecConcatenation) && !useOldFetchProtocol + if (shouldBatchFetch && !doBatchFetch) { + logWarning("The feature tag of continuous shuffle block fetching is set to true, but " + + "we can not enable the feature because other conditions are not satisfied. " + + s"Shuffle compress: $compressed, serializer ${dep.serializer.getClass.getName} " + + s"relocatable: $serializerRelocatable, " + + s"codec concatenation: $codecConcatenation, use old shuffle fetch protocol: " + + s"$useOldFetchProtocol.") + } + doBatchFetch + } + +} From f0eaee37f36844154038312ba7392baecf1e6812 Mon Sep 17 00:00:00 2001 From: Damon Brown Date: Tue, 26 Oct 2021 17:48:48 +0000 Subject: [PATCH 2/2] feat: Spark 3.1 github build ci --- .github/workflows/sparkucx-ci.yml | 2 +- .github/workflows/sparkucx-release.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/sparkucx-ci.yml b/.github/workflows/sparkucx-ci.yml index 3f627287..4bf27cae 100755 --- a/.github/workflows/sparkucx-ci.yml +++ b/.github/workflows/sparkucx-ci.yml @@ -9,7 +9,7 @@ jobs: build-sparkucx: strategy: matrix: - spark_version: ["2.1", "2.4", "3.0"] + spark_version: ["2.1", "2.4", "3.0", "3.1"] runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 diff --git a/.github/workflows/sparkucx-release.yml b/.github/workflows/sparkucx-release.yml index cfa93c58..842a3832 100644 --- a/.github/workflows/sparkucx-release.yml +++ b/.github/workflows/sparkucx-release.yml @@ -13,7 +13,7 @@ jobs: release: strategy: matrix: - spark_version: ["2.1", "2.4", "3.0"] + spark_version: ["2.1", "2.4", "3.0", "3.1"] runs-on: ubuntu-latest steps: - name: Checkout code