diff --git a/pom.xml b/pom.xml
index 57315f45..8b57446a 100755
--- a/pom.xml
+++ b/pom.xml
@@ -45,6 +45,8 @@ See file LICENSE for terms.
**/spark_2_1/**
**/spark_3_0/**
+ org/apache/spark/network/**
+ org/apache/spark/shuffle/ucx/external/server/**
@@ -71,6 +73,8 @@ See file LICENSE for terms.
**/spark_2_1/**
**/spark_2_4/**
+ org/apache/spark/network/**
+ org/apache/spark/shuffle/ucx/external/server/**
@@ -83,6 +87,93 @@ See file LICENSE for terms.
**/spark_2_1/**, **/spark_2_4/**
+
+ sparkess-2.4
+
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+
+
+ org/apache/spark/network/**
+ org/apache/spark/shuffle/ucx/external/**
+ org/apache/spark/shuffle/ucx/memory/UcxLimitedMemPool.scala
+ org/apache/spark/shuffle/ucx/ShuffleTransport.scala
+ org/apache/spark/shuffle/utils/**
+
+
+ **/spark_3_0/**
+ org/apache/spark/shuffle/ucx/external/client/**
+
+
+
+
+
+
+ 2.4.0
+ **/spark_3_0/**, **/spark_2_1/**
+ 2.11.12
+ 2.11
+ 2.6.7
+
+
+
+ org.scala-lang
+ scala-library
+ ${scala.version}
+
+
+ com.fasterxml.jackson.core
+ jackson-databind
+ ${fasterxml.jackson.version}
+
+
+
+
+ sparkess-3.0
+
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+
+
+ org/apache/spark/network/**
+ org/apache/spark/shuffle/ucx/external/**
+ org/apache/spark/shuffle/ucx/memory/UcxLimitedMemPool.scala
+ org/apache/spark/shuffle/ucx/ShuffleTransport.scala
+ org/apache/spark/shuffle/utils/**
+
+
+ **/spark_2_4/**
+ org/apache/spark/shuffle/ucx/external/client/**
+
+
+
+
+
+
+ 3.0.1
+ 2.12.10
+ 2.12
+ **/spark_2_1/**, **/spark_2_4/**
+ 2.10.0
+
+
+
+ org.scala-lang
+ scala-library
+ ${scala.version}
+
+
+ com.fasterxml.jackson.core
+ jackson-databind
+ ${fasterxml.jackson.version}
+
+
+
+
@@ -93,10 +184,16 @@ See file LICENSE for terms.
${spark.version}
provided
+
+ org.apache.spark
+ spark-network-yarn_${scala.compat.version}
+ ${spark.version}
+ provided
+
org.openucx
jucx
- 1.13.1
+ 1.16.0
diff --git a/src/main/scala/org/apache/spark/network/shuffle/ExternalUcxShuffleBlockResolver.scala b/src/main/scala/org/apache/spark/network/shuffle/ExternalUcxShuffleBlockResolver.scala
new file mode 100644
index 00000000..d8d0cb9f
--- /dev/null
+++ b/src/main/scala/org/apache/spark/network/shuffle/ExternalUcxShuffleBlockResolver.scala
@@ -0,0 +1,91 @@
+package org.apache.spark.network.shuffle
+
+import java.io.File
+import java.nio.charset.StandardCharsets
+import java.lang.reflect.{Method, Field}
+
+import scala.collection.mutable
+
+import com.fasterxml.jackson.databind.ObjectMapper
+
+import org.apache.hadoop.metrics2.lib.DefaultMetricsSystem;
+import org.apache.hadoop.metrics2.impl.MetricsSystemImpl;
+
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
+import org.apache.spark.network.util.TransportConf
+import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId
+
+import org.apache.spark.shuffle.utils.UcxLogging
+import org.apache.spark.shuffle.ucx.ExternalUcxServerTransport
+
+class ExternalUcxShuffleBlockResolver(conf: TransportConf, registeredExecutorFile: File)
+ extends ExternalShuffleBlockResolver(conf, registeredExecutorFile) with UcxLogging {
+ private[spark] final val APP_KEY_PREFIX = "AppExecShuffleInfo";
+ private[spark] final val ucxMapper = new ObjectMapper
+ private[spark] var dbAppExecKeyMethod: Method = _
+ private[spark] val knownManagers = mutable.Set(
+ "org.apache.spark.shuffle.sort.SortShuffleManager",
+ "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager",
+ "org.apache.spark.shuffle.ExternalUcxShuffleManager")
+ private[spark] var ucxTransport: ExternalUcxServerTransport = _
+
+ // init()
+
+ private[spark] def dbAppExecKey(appExecId: AppExecId): Array[Byte] = {
+ // we stick a common prefix on all the keys so we can find them in the DB
+ val appExecJson = ucxMapper.writeValueAsString(appExecId);
+ val key = (APP_KEY_PREFIX + ";" + appExecJson);
+ key.getBytes(StandardCharsets.UTF_8);
+ }
+
+ // def init(): Unit = {
+ // val clazz = Class.forName("org.apache.spark.network.shuffle.ExternalShuffleBlockResolver")
+ // try {
+ // dbAppExecKeyMethod = clazz.getDeclaredMethod("dbAppExecKey", classOf[AppExecId])
+ // dbAppExecKeyMethod.setAccessible(true)
+ // } catch {
+ // case e: Exception => {
+ // logError(s"Get dbAppExecKey from ExternalUcxShuffleBlockResolver failed: $e")
+ // }
+ // }
+ // }
+
+ // def dbAppExecKey(fullId: AppExecId): Array[Byte] = {
+ // dbAppExecKeyMethod.invoke(this, fullId).asInstanceOf[Array[Byte]]
+ // }
+
+ def setTransport(transport: ExternalUcxServerTransport): Unit = {
+ ucxTransport = transport
+ }
+ /** Registers a new Executor with all the configuration we need to find its shuffle files. */
+ override def registerExecutor(
+ appId: String,
+ execId: String,
+ executorInfo: ExecutorShuffleInfo): Unit = {
+ val fullId = new AppExecId(appId, execId)
+ logInfo(s"Registered executor ${fullId} with ${executorInfo}")
+ if (!knownManagers.contains(executorInfo.shuffleManager)) {
+ throw new UnsupportedOperationException(
+ "Unsupported shuffle manager of executor: " + executorInfo)
+ }
+ try {
+ if (db != null) {
+ val key = dbAppExecKey(fullId)
+ val value = ucxMapper.writeValueAsString(executorInfo).getBytes(StandardCharsets.UTF_8)
+ db.put(key, value)
+ }
+ executors.put(fullId, executorInfo)
+ } catch {
+ case e: Exception => logError("Error saving registered executors", e)
+ }
+ }
+
+  override def applicationRemoved(appId: String, cleanupLocalDirs: Boolean): Unit = {
+    super.applicationRemoved(appId, cleanupLocalDirs) // delegate to the base resolver first
+    if (ucxTransport != null) ucxTransport.applicationRemoved(appId) // null until setTransport()
+  }
+  override def executorRemoved(executorId: String, appId: String): Unit = {
+    super.executorRemoved(executorId, appId) // delegate to the base resolver first
+    if (ucxTransport != null) ucxTransport.executorRemoved(executorId, appId) // unset until setTransport()
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/network/shuffle/spark_2_4/ExternalUcxShuffleBlockHandler.scala b/src/main/scala/org/apache/spark/network/shuffle/spark_2_4/ExternalUcxShuffleBlockHandler.scala
new file mode 100644
index 00000000..8a092638
--- /dev/null
+++ b/src/main/scala/org/apache/spark/network/shuffle/spark_2_4/ExternalUcxShuffleBlockHandler.scala
@@ -0,0 +1,20 @@
+package org.apache.spark.network.shuffle
+
+import java.io.File
+
+import org.apache.spark.network.server.OneForOneStreamManager
+import org.apache.spark.network.util.TransportConf
+
+import org.apache.spark.shuffle.utils.UcxLogging
+import org.apache.spark.shuffle.ucx.ExternalUcxServerTransport
+
+class ExternalUcxShuffleBlockHandler(conf: TransportConf, registeredExecutorFile: File)
+ extends ExternalShuffleBlockHandler(new OneForOneStreamManager(),
+ new ExternalUcxShuffleBlockResolver(conf, registeredExecutorFile)) with UcxLogging {
+ def ucxBlockManager(): ExternalUcxShuffleBlockResolver = {
+ blockManager.asInstanceOf[ExternalUcxShuffleBlockResolver]
+ }
+ def setTransport(transport: ExternalUcxServerTransport): Unit = {
+ ucxBlockManager.setTransport(transport)
+ }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/network/shuffle/spark_3_0/ExternalUcxShuffleBlockHandler.scala b/src/main/scala/org/apache/spark/network/shuffle/spark_3_0/ExternalUcxShuffleBlockHandler.scala
new file mode 100644
index 00000000..926050d2
--- /dev/null
+++ b/src/main/scala/org/apache/spark/network/shuffle/spark_3_0/ExternalUcxShuffleBlockHandler.scala
@@ -0,0 +1,20 @@
+package org.apache.spark.network.shuffle
+
+import java.io.File
+
+import org.apache.spark.network.server.OneForOneStreamManager
+import org.apache.spark.network.util.TransportConf
+
+import org.apache.spark.shuffle.utils.UcxLogging
+import org.apache.spark.shuffle.ucx.ExternalUcxServerTransport
+
+class ExternalUcxShuffleBlockHandler(conf: TransportConf, registeredExecutorFile: File)
+ extends ExternalBlockHandler(new OneForOneStreamManager(),
+ new ExternalUcxShuffleBlockResolver(conf, registeredExecutorFile)) with UcxLogging {
+ def ucxBlockManager(): ExternalUcxShuffleBlockResolver = {
+ blockManager.asInstanceOf[ExternalUcxShuffleBlockResolver]
+ }
+ def setTransport(transport: ExternalUcxServerTransport): Unit = {
+ ucxBlockManager.setTransport(transport)
+ }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/network/yarn/spark_2_4/UcxYarnShuffleService.scala b/src/main/scala/org/apache/spark/network/yarn/spark_2_4/UcxYarnShuffleService.scala
new file mode 100644
index 00000000..74d4fb4c
--- /dev/null
+++ b/src/main/scala/org/apache/spark/network/yarn/spark_2_4/UcxYarnShuffleService.scala
@@ -0,0 +1,328 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.yarn
+
+import java.io.File
+import java.io.IOException
+import java.nio.charset.StandardCharsets
+import java.nio.ByteBuffer
+
+import com.fasterxml.jackson.databind.ObjectMapper
+import com.google.common.collect.Lists
+import com.google.common.base.Preconditions
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.permission.FsPermission
+import org.apache.hadoop.yarn.server.api._
+import org.apache.spark.network.util.LevelDBProvider
+import org.iq80.leveldb.DB
+
+import org.apache.spark.network.TransportContext
+import org.apache.spark.network.crypto.AuthServerBootstrap
+import org.apache.spark.network.sasl.ShuffleSecretManager
+import org.apache.spark.network.server.TransportServer
+import org.apache.spark.network.server.TransportServerBootstrap
+import org.apache.spark.network.util.TransportConf
+import org.apache.spark.network.yarn.util.HadoopConfigProvider
+
+import org.apache.spark.shuffle.utils.UcxLogging
+import org.apache.spark.network.yarn.YarnShuffleService.AppId
+import org.apache.spark.network.shuffle.ExternalUcxShuffleBlockHandler
+import org.apache.spark.shuffle.ucx.ExternalUcxServerConf
+import org.apache.spark.shuffle.ucx.ExternalUcxServerTransport
+
+class UcxYarnShuffleService extends AuxiliaryService("sparkucx_shuffle") with UcxLogging {
+ private[this] var ucxTransport: ExternalUcxServerTransport = _
+ private[this] var secretManager: ShuffleSecretManager = _
+ private[this] var shuffleServer: TransportServer = _
+ private[this] var _conf: Configuration = _
+ private[this] var _recoveryPath: Path = _
+ private[this] var blockHandler: ExternalUcxShuffleBlockHandler = _
+ private[this] var registeredExecutorFile: File = _
+ private[this] var secretsFile: File = _
+ private[this] var db: DB = _
+
+ UcxYarnShuffleService.instance = this
+ /**
+ * Return whether authentication is enabled as specified by the configuration.
+ * If so, fetch requests will fail unless the appropriate authentication secret
+ * for the application is provided.
+ */
+ def isAuthenticationEnabled(): Boolean = secretManager != null
+ /**
+ * Start the shuffle server with the given configuration.
+ */
+ override protected def serviceInit(conf: Configuration) = {
+ _conf = conf
+
+ val stopOnFailure = conf.getBoolean(
+ UcxYarnShuffleService.STOP_ON_FAILURE_KEY,
+ UcxYarnShuffleService.DEFAULT_STOP_ON_FAILURE)
+
+ try {
+ // In case this NM was killed while there were running spark applications, we need to restore
+ // lost state for the existing executors. We look for an existing file in the NM's local dirs.
+ // If we don't find one, then we choose a file to use to save the state next time. Even if
+ // an application was stopped while the NM was down, we expect yarn to call stopApplication()
+ // when it comes back
+ if (_recoveryPath != null) {
+ registeredExecutorFile = initRecoveryDb(UcxYarnShuffleService.RECOVERY_FILE_NAME)
+ }
+
+ val transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf))
+ blockHandler = new ExternalUcxShuffleBlockHandler(transportConf, registeredExecutorFile)
+
+ // If authentication is enabled, set up the shuffle server to use a
+ // special RPC handler that filters out unauthenticated fetch requests
+ val bootstraps = Lists.newArrayList[TransportServerBootstrap]()
+ val authEnabled = conf.getBoolean(
+ UcxYarnShuffleService.SPARK_AUTHENTICATE_KEY,
+ UcxYarnShuffleService.DEFAULT_SPARK_AUTHENTICATE)
+ if (authEnabled) {
+ secretManager = new ShuffleSecretManager()
+ if (_recoveryPath != null) {
+ loadSecretsFromDb()
+ }
+ bootstraps.add(new AuthServerBootstrap(transportConf, secretManager))
+ }
+
+ // User might not like to replace tcp service, so use another port to transfer executors info.
+ val portConf = conf.getInt(
+ ExternalUcxServerConf.SPARK_UCX_SHUFFLE_SERVICE_TCP_PORT_KEY,
+ conf.getInt(
+ UcxYarnShuffleService.SPARK_SHUFFLE_SERVICE_PORT_KEY,
+ UcxYarnShuffleService.DEFAULT_SPARK_SHUFFLE_SERVICE_PORT))
+ val transportContext = new TransportContext(transportConf, blockHandler)
+ shuffleServer = transportContext.createServer(portConf, bootstraps)
+ // the port should normally be fixed, but for tests its useful to find an open port
+ val port = shuffleServer.getPort()
+
+ val authEnabledString = if (authEnabled) "enabled" else "not enabled"
+ logInfo(s"Started YARN shuffle service for Spark on port ${port}. " +
+ s"Authentication is ${authEnabledString}. Registered executor file is ${registeredExecutorFile}")
+
+ // Ucx Transport
+ logInfo("Start launching ExternalUcxServerTransport")
+ val ucxConf = new ExternalUcxServerConf(conf)
+ ucxTransport = new ExternalUcxServerTransport(ucxConf, blockHandler.ucxBlockManager)
+ ucxTransport.init()
+ blockHandler.setTransport(ucxTransport)
+ } catch {
+ case e: Exception => if (stopOnFailure) {
+ throw e
+ } else {
+ // logError(s"Start UcxYarnShuffleService failed: $e")
+ noteFailure(e)
+ }
+ }
+ }
+
+ private def loadSecretsFromDb(): Unit = {
+ secretsFile = initRecoveryDb(UcxYarnShuffleService.SECRETS_RECOVERY_FILE_NAME)
+
+ // Make sure this is protected in case its not in the NM recovery dir
+ val fs = FileSystem.getLocal(_conf)
+ fs.mkdirs(new Path(secretsFile.getPath()), new FsPermission(448.toShort)) // 0700
+
+ db = LevelDBProvider.initLevelDB(secretsFile, UcxYarnShuffleService.CURRENT_VERSION, UcxYarnShuffleService.mapper)
+ logInfo("Recovery location is: " + secretsFile.getPath())
+ if (db != null) {
+ logInfo("Going to reload spark shuffle data")
+ val itr = db.iterator()
+ itr.seek(UcxYarnShuffleService.APP_CREDS_KEY_PREFIX.getBytes(StandardCharsets.UTF_8))
+ while (itr.hasNext()) {
+ val e = itr.next()
+ val key = new String(e.getKey(), StandardCharsets.UTF_8)
+ if (!key.startsWith(UcxYarnShuffleService.APP_CREDS_KEY_PREFIX)) {
+ return
+ }
+ val id = UcxYarnShuffleService.parseDbAppKey(key)
+ val secret = UcxYarnShuffleService.mapper.readValue(e.getValue(), classOf[ByteBuffer])
+ logInfo("Reloading tokens for app: " + id)
+ secretManager.registerApp(id, secret)
+ }
+ }
+ }
+
+ override def initializeApplication(context: ApplicationInitializationContext): Unit = {
+ val appId = context.getApplicationId().toString()
+ try {
+ val shuffleSecret = context.getApplicationDataForService()
+ if (isAuthenticationEnabled()) {
+ val fullId = new AppId(appId)
+ if (db != null) {
+ val key = UcxYarnShuffleService.dbAppKey(fullId)
+ val value = UcxYarnShuffleService.mapper.writeValueAsString(shuffleSecret).getBytes(StandardCharsets.UTF_8)
+ db.put(key, value)
+ }
+ secretManager.registerApp(appId, shuffleSecret)
+ }
+ } catch {
+ case e: Exception => logError(s"Exception when initializing application ${appId}", e)
+ }
+ }
+
+ override def stopApplication(context: ApplicationTerminationContext): Unit = {
+ val appId = context.getApplicationId().toString()
+ try {
+ if (isAuthenticationEnabled()) {
+ val fullId = new AppId(appId)
+ if (db != null) {
+ try {
+ db.delete(UcxYarnShuffleService.dbAppKey(fullId))
+ } catch {
+ case e: IOException => logError(s"Error deleting ${appId} from executor state db", e)
+ }
+ }
+ secretManager.unregisterApp(appId)
+ }
+ blockHandler.applicationRemoved(appId, false /* clean up local dirs */)
+ } catch {
+ case e: Exception => logError(s"Exception when stopping application ${appId}", e)
+ }
+ }
+
+ override def initializeContainer(context: ContainerInitializationContext): Unit = {}
+
+ override def stopContainer(context: ContainerTerminationContext): Unit = {}
+
+ // Not currently used
+ override def getMetaData(): ByteBuffer = {
+ return ByteBuffer.allocate(0)
+ }
+
+ /**
+   * Set the recovery path for shuffle service recovery when NM is restarted. This will be called
+ * by NM if NM recovery is enabled.
+ */
+ override def setRecoveryPath(recoveryPath: Path): Unit = {
+ _recoveryPath = recoveryPath
+ }
+
+ /**
+ * Get the path specific to this auxiliary service to use for recovery.
+ */
+ protected def getRecoveryPath(fileName: String): Path = {
+ return _recoveryPath
+ }
+
+ /**
+ * Figure out the recovery path and handle moving the DB if YARN NM recovery gets enabled
+ * and DB exists in the local dir of NM by old version of shuffle service.
+ */
+ def initRecoveryDb(dbName: String): File = {
+ if (_recoveryPath == null) {
+ throw new NullPointerException("recovery path should not be null if NM recovery is enabled")
+ }
+
+ val recoveryFile = new File(_recoveryPath.toUri().getPath(), dbName)
+ if (recoveryFile.exists()) {
+ return recoveryFile
+ }
+
+ // db doesn't exist in recovery path go check local dirs for it
+ val localDirs = _conf.getTrimmedStrings("yarn.nodemanager.local-dirs")
+ for (dir <- localDirs) {
+ val f = new File(new Path(dir).toUri().getPath(), dbName)
+ if (f.exists()) {
+ // If the recovery path is set then either NM recovery is enabled or another recovery
+ // DB has been initialized. If NM recovery is enabled and had set the recovery path
+ // make sure to move all DBs to the recovery path from the old NM local dirs.
+ // If another DB was initialized first just make sure all the DBs are in the same
+ // location.
+ val newLoc = new Path(_recoveryPath, dbName)
+ val copyFrom = new Path(f.toURI())
+ if (!newLoc.equals(copyFrom)) {
+ logInfo("Moving " + copyFrom + " to: " + newLoc)
+ try {
+ // The move here needs to handle moving non-empty directories across NFS mounts
+ val fs = FileSystem.getLocal(_conf)
+ fs.rename(copyFrom, newLoc)
+ } catch {
+ // Fail to move recovery file to new path, just continue on with new DB location
+ case e: Exception => logError(
+ s"Failed to move recovery file ${dbName} to the path ${_recoveryPath.toString()}",
+ e)
+ }
+ }
+ return new File(newLoc.toUri().getPath())
+ }
+ }
+
+ return new File(_recoveryPath.toUri().getPath(), dbName)
+ }
+
+  /**
+   * Stops the service, closing the shuffle server, block handler, UCX transport and secrets DB.
+   * Each resource is closed on its own so a failing close() cannot leak the remaining ones.
+   */
+  override protected def serviceStop(): Unit = {
+    def closeQuietly(close: => Unit): Unit = {
+      try {
+        close
+      } catch {
+        case e: Exception => logError("Exception when stopping service", e)
+      }
+    }
+
+    if (shuffleServer != null) closeQuietly(shuffleServer.close())
+    if (blockHandler != null) closeQuietly(blockHandler.close())
+    if (ucxTransport != null) closeQuietly(ucxTransport.close())
+    if (db != null) closeQuietly(db.close())
+  }
+}
+
+object UcxYarnShuffleService {
+ // Port on which the shuffle server listens for fetch requests
+ val SPARK_SHUFFLE_SERVICE_PORT_KEY = "spark.shuffle.service.port"
+ val DEFAULT_SPARK_SHUFFLE_SERVICE_PORT = 7337
+
+ // Whether the shuffle server should authenticate fetch requests
+ val SPARK_AUTHENTICATE_KEY = "spark.authenticate"
+ val DEFAULT_SPARK_AUTHENTICATE = false
+
+ val RECOVERY_FILE_NAME = "registeredExecutors.ldb"
+ val SECRETS_RECOVERY_FILE_NAME = "sparkShuffleRecovery.ldb"
+
+ // Whether failure during service initialization should stop the NM.
+ val STOP_ON_FAILURE_KEY = "spark.yarn.shuffle.stopOnFailure"
+ val DEFAULT_STOP_ON_FAILURE = false
+
+ val mapper = new ObjectMapper()
+ val APP_CREDS_KEY_PREFIX = "AppCreds"
+ val CURRENT_VERSION = new LevelDBProvider.StoreVersion(1, 0)
+
+ var instance: UcxYarnShuffleService = _
+
+ private def parseDbAppKey(s: String): String = {
+ if (!s.startsWith(APP_CREDS_KEY_PREFIX)) {
+ throw new IllegalArgumentException("expected a string starting with " + APP_CREDS_KEY_PREFIX)
+ }
+ val json = s.substring(APP_CREDS_KEY_PREFIX.length() + 1)
+ val parsed = mapper.readValue(json, classOf[AppId])
+ return parsed.appId
+ }
+
+ private def dbAppKey(appExecId: AppId): Array[Byte] = {
+ // we stick a common prefix on all the keys so we can find them in the DB
+ val appExecJson = mapper.writeValueAsString(appExecId)
+ val key = (APP_CREDS_KEY_PREFIX + ";" + appExecJson)
+ return key.getBytes(StandardCharsets.UTF_8)
+ }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/network/yarn/spark_3_0/UcxYarnShuffleService.scala b/src/main/scala/org/apache/spark/network/yarn/spark_3_0/UcxYarnShuffleService.scala
new file mode 100644
index 00000000..35e977b8
--- /dev/null
+++ b/src/main/scala/org/apache/spark/network/yarn/spark_3_0/UcxYarnShuffleService.scala
@@ -0,0 +1,345 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.yarn
+
+import java.io.File
+import java.io.IOException
+import java.nio.charset.StandardCharsets
+import java.nio.ByteBuffer
+
+import com.fasterxml.jackson.databind.ObjectMapper
+import com.google.common.collect.Lists
+import com.google.common.base.Preconditions
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.permission.FsPermission
+import org.apache.hadoop.metrics2.impl.MetricsSystemImpl
+import org.apache.hadoop.metrics2.lib.DefaultMetricsSystem
+import org.apache.hadoop.yarn.server.api._
+import org.apache.spark.network.util.LevelDBProvider
+import org.iq80.leveldb.DB
+
+import org.apache.spark.network.TransportContext
+import org.apache.spark.network.crypto.AuthServerBootstrap
+import org.apache.spark.network.sasl.ShuffleSecretManager
+import org.apache.spark.network.server.TransportServer
+import org.apache.spark.network.server.TransportServerBootstrap
+import org.apache.spark.network.util.TransportConf
+import org.apache.spark.network.yarn.util.HadoopConfigProvider
+
+import org.apache.spark.shuffle.utils.UcxLogging
+import org.apache.spark.network.yarn.YarnShuffleService.AppId
+import org.apache.spark.network.shuffle.ExternalUcxShuffleBlockHandler
+import org.apache.spark.shuffle.ucx.ExternalUcxServerConf
+import org.apache.spark.shuffle.ucx.ExternalUcxServerTransport
+
+class UcxYarnShuffleService extends AuxiliaryService("sparkucx_shuffle") with UcxLogging {
+ private[this] var ucxTransport: ExternalUcxServerTransport = _
+ private[this] var secretManager: ShuffleSecretManager = _
+ private[this] var shuffleServer: TransportServer = _
+ private[this] var _conf: Configuration = _
+ private[this] var _recoveryPath: Path = _
+ private[this] var blockHandler: ExternalUcxShuffleBlockHandler = _
+ private[this] var registeredExecutorFile: File = _
+ private[this] var secretsFile: File = _
+ private[this] var db: DB = _
+
+ UcxYarnShuffleService.instance = this
+ /**
+ * Return whether authentication is enabled as specified by the configuration.
+ * If so, fetch requests will fail unless the appropriate authentication secret
+ * for the application is provided.
+ */
+ def isAuthenticationEnabled(): Boolean = secretManager != null
+ /**
+ * Start the shuffle server with the given configuration.
+ */
+ override protected def serviceInit(conf: Configuration) = {
+ _conf = conf
+
+ val stopOnFailure = conf.getBoolean(
+ UcxYarnShuffleService.STOP_ON_FAILURE_KEY,
+ UcxYarnShuffleService.DEFAULT_STOP_ON_FAILURE)
+
+ try {
+ // In case this NM was killed while there were running spark applications, we need to restore
+ // lost state for the existing executors. We look for an existing file in the NM's local dirs.
+ // If we don't find one, then we choose a file to use to save the state next time. Even if
+ // an application was stopped while the NM was down, we expect yarn to call stopApplication()
+ // when it comes back
+ if (_recoveryPath != null) {
+ registeredExecutorFile = initRecoveryDb(UcxYarnShuffleService.RECOVERY_FILE_NAME)
+ }
+
+ val transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf))
+ blockHandler = new ExternalUcxShuffleBlockHandler(transportConf, registeredExecutorFile)
+
+ // If authentication is enabled, set up the shuffle server to use a
+ // special RPC handler that filters out unauthenticated fetch requests
+ val bootstraps = Lists.newArrayList[TransportServerBootstrap]()
+ val authEnabled = conf.getBoolean(
+ UcxYarnShuffleService.SPARK_AUTHENTICATE_KEY,
+ UcxYarnShuffleService.DEFAULT_SPARK_AUTHENTICATE)
+ if (authEnabled) {
+ secretManager = new ShuffleSecretManager()
+ if (_recoveryPath != null) {
+ loadSecretsFromDb()
+ }
+ bootstraps.add(new AuthServerBootstrap(transportConf, secretManager))
+ }
+
+ // User might not like to replace tcp service, so use another port to transfer executors info.
+ val portConf = conf.getInt(
+ ExternalUcxServerConf.SPARK_UCX_SHUFFLE_SERVICE_TCP_PORT_KEY,
+ conf.getInt(
+ UcxYarnShuffleService.SPARK_SHUFFLE_SERVICE_PORT_KEY,
+ UcxYarnShuffleService.DEFAULT_SPARK_SHUFFLE_SERVICE_PORT))
+ val transportContext = new TransportContext(transportConf, blockHandler)
+ shuffleServer = transportContext.createServer(portConf, bootstraps)
+ // the port should normally be fixed, but for tests its useful to find an open port
+ val port = shuffleServer.getPort()
+
+ val authEnabledString = if (authEnabled) "enabled" else "not enabled"
+ logInfo(s"Started YARN shuffle service for Spark on port ${port}. " +
+ s"Authentication is ${authEnabledString}. Registered executor file is ${registeredExecutorFile}")
+
+ // register metrics on the block handler into the Node Manager's metrics system.
+ blockHandler.getAllMetrics().getMetrics().put("numRegisteredConnections",
+ shuffleServer.getRegisteredConnections());
+ val serviceMetrics =
+ new YarnShuffleServiceMetrics(blockHandler.getAllMetrics());
+
+ val metricsSystem = DefaultMetricsSystem.instance().asInstanceOf[MetricsSystemImpl];
+ metricsSystem.register(
+ "sparkUcxShuffleService", "Metrics on the Spark Shuffle Service", serviceMetrics);
+ logInfo("Registered metrics with Hadoop's DefaultMetricsSystem");
+
+ logInfo(s"Started YARN shuffle service for Spark on port ${port}. " +
+ s"Authentication is ${authEnabledString}. " +
+ s"Registered executor file is ${registeredExecutorFile}");
+
+ // Ucx Transport
+ logInfo("Start launching ExternalUcxServerTransport")
+ val ucxConf = new ExternalUcxServerConf(conf)
+ ucxTransport = new ExternalUcxServerTransport(ucxConf, blockHandler.ucxBlockManager)
+ ucxTransport.init()
+ blockHandler.setTransport(ucxTransport)
+ } catch {
+ case e: Exception => if (stopOnFailure) {
+ throw e
+ } else {
+ // logError(s"Start UcxYarnShuffleService failed: $e")
+ noteFailure(e)
+ }
+ }
+ }
+
+ private def loadSecretsFromDb(): Unit = {
+ secretsFile = initRecoveryDb(UcxYarnShuffleService.SECRETS_RECOVERY_FILE_NAME)
+
+ // Make sure this is protected in case its not in the NM recovery dir
+ val fs = FileSystem.getLocal(_conf)
+ fs.mkdirs(new Path(secretsFile.getPath()), new FsPermission(448.toShort)) // 0700
+
+ db = LevelDBProvider.initLevelDB(secretsFile, UcxYarnShuffleService.CURRENT_VERSION, UcxYarnShuffleService.mapper)
+ logInfo("Recovery location is: " + secretsFile.getPath())
+ if (db != null) {
+ logInfo("Going to reload spark shuffle data")
+ val itr = db.iterator()
+ itr.seek(UcxYarnShuffleService.APP_CREDS_KEY_PREFIX.getBytes(StandardCharsets.UTF_8))
+ while (itr.hasNext()) {
+ val e = itr.next()
+ val key = new String(e.getKey(), StandardCharsets.UTF_8)
+ if (!key.startsWith(UcxYarnShuffleService.APP_CREDS_KEY_PREFIX)) {
+ return
+ }
+ val id = UcxYarnShuffleService.parseDbAppKey(key)
+ val secret = UcxYarnShuffleService.mapper.readValue(e.getValue(), classOf[ByteBuffer])
+ logInfo("Reloading tokens for app: " + id)
+ secretManager.registerApp(id, secret)
+ }
+ }
+ }
+
+ override def initializeApplication(context: ApplicationInitializationContext): Unit = {
+ val appId = context.getApplicationId().toString()
+ try {
+ val shuffleSecret = context.getApplicationDataForService()
+ if (isAuthenticationEnabled()) {
+ val fullId = new AppId(appId)
+ if (db != null) {
+ val key = UcxYarnShuffleService.dbAppKey(fullId)
+ val value = UcxYarnShuffleService.mapper.writeValueAsString(shuffleSecret).getBytes(StandardCharsets.UTF_8)
+ db.put(key, value)
+ }
+ secretManager.registerApp(appId, shuffleSecret)
+ }
+ } catch {
+ case e: Exception => logError(s"Exception when initializing application ${appId}", e)
+ }
+ }
+
+ override def stopApplication(context: ApplicationTerminationContext): Unit = {
+ val appId = context.getApplicationId().toString()
+ try {
+ if (isAuthenticationEnabled()) {
+ val fullId = new AppId(appId)
+ if (db != null) {
+ try {
+ db.delete(UcxYarnShuffleService.dbAppKey(fullId))
+ } catch {
+ case e: IOException => logError(s"Error deleting ${appId} from executor state db", e)
+ }
+ }
+ secretManager.unregisterApp(appId)
+ }
+ blockHandler.applicationRemoved(appId, false /* clean up local dirs */)
+ } catch {
+ case e: Exception => logError(s"Exception when stopping application ${appId}", e)
+ }
+ }
+
+ override def initializeContainer(context: ContainerInitializationContext): Unit = {}
+
+ override def stopContainer(context: ContainerTerminationContext): Unit = {}
+
+ // Not currently used
+ override def getMetaData(): ByteBuffer = {
+ return ByteBuffer.allocate(0)
+ }
+
+ /**
+   * Set the recovery path for shuffle service recovery when NM is restarted. This will be called
+ * by NM if NM recovery is enabled.
+ */
+ override def setRecoveryPath(recoveryPath: Path): Unit = {
+ _recoveryPath = recoveryPath
+ }
+
+ /**
+ * Get the path specific to this auxiliary service to use for recovery.
+ */
+ protected def getRecoveryPath(fileName: String): Path = {
+ return _recoveryPath
+ }
+
+ /**
+ * Figure out the recovery path and handle moving the DB if YARN NM recovery gets enabled
+ * and DB exists in the local dir of NM by old version of shuffle service.
+ */
+ def initRecoveryDb(dbName: String): File = {
+ if (_recoveryPath == null) {
+ throw new NullPointerException("recovery path should not be null if NM recovery is enabled")
+ }
+
+ val recoveryFile = new File(_recoveryPath.toUri().getPath(), dbName)
+ if (recoveryFile.exists()) {
+ return recoveryFile
+ }
+
+ // db doesn't exist in recovery path go check local dirs for it
+ val localDirs = _conf.getTrimmedStrings("yarn.nodemanager.local-dirs")
+ for (dir <- localDirs) {
+ val f = new File(new Path(dir).toUri().getPath(), dbName)
+ if (f.exists()) {
+ // If the recovery path is set then either NM recovery is enabled or another recovery
+ // DB has been initialized. If NM recovery is enabled and had set the recovery path
+ // make sure to move all DBs to the recovery path from the old NM local dirs.
+ // If another DB was initialized first just make sure all the DBs are in the same
+ // location.
+ val newLoc = new Path(_recoveryPath, dbName)
+ val copyFrom = new Path(f.toURI())
+ if (!newLoc.equals(copyFrom)) {
+ logInfo("Moving " + copyFrom + " to: " + newLoc)
+ try {
+ // The move here needs to handle moving non-empty directories across NFS mounts
+ val fs = FileSystem.getLocal(_conf)
+ fs.rename(copyFrom, newLoc)
+ } catch {
+ // Fail to move recovery file to new path, just continue on with new DB location
+ case e: Exception => logError(
+ s"Failed to move recovery file ${dbName} to the path ${_recoveryPath.toString()}",
+ e)
+ }
+ }
+ return new File(newLoc.toUri().getPath())
+ }
+ }
+
+ return new File(_recoveryPath.toUri().getPath(), dbName)
+ }
+
+  /**
+   * Stops the service, closing the shuffle server, block handler, UCX transport and secrets DB.
+   * Each resource is closed on its own so a failing close() cannot leak the remaining ones.
+   */
+  override protected def serviceStop(): Unit = {
+    def closeQuietly(close: => Unit): Unit = {
+      try {
+        close
+      } catch {
+        case e: Exception => logError("Exception when stopping service", e)
+      }
+    }
+
+    if (shuffleServer != null) closeQuietly(shuffleServer.close())
+    if (blockHandler != null) closeQuietly(blockHandler.close())
+    if (ucxTransport != null) closeQuietly(ucxTransport.close())
+    if (db != null) closeQuietly(db.close())
+  }
+}
+
+object UcxYarnShuffleService {
+ // Port on which the shuffle server listens for fetch requests
+ val SPARK_SHUFFLE_SERVICE_PORT_KEY = "spark.shuffle.service.port"
+ val DEFAULT_SPARK_SHUFFLE_SERVICE_PORT = 7337
+
+ // Whether the shuffle server should authenticate fetch requests
+ val SPARK_AUTHENTICATE_KEY = "spark.authenticate"
+ val DEFAULT_SPARK_AUTHENTICATE = false
+
+ val RECOVERY_FILE_NAME = "registeredExecutors.ldb"
+ val SECRETS_RECOVERY_FILE_NAME = "sparkShuffleRecovery.ldb"
+
+ // Whether failure during service initialization should stop the NM.
+ val STOP_ON_FAILURE_KEY = "spark.yarn.shuffle.stopOnFailure"
+ val DEFAULT_STOP_ON_FAILURE = false
+
+ val mapper = new ObjectMapper()
+ val APP_CREDS_KEY_PREFIX = "AppCreds"
+ val CURRENT_VERSION = new LevelDBProvider.StoreVersion(1, 0)
+
+ var instance: UcxYarnShuffleService = _
+
+ private def parseDbAppKey(s: String): String = {
+ if (!s.startsWith(APP_CREDS_KEY_PREFIX)) {
+ throw new IllegalArgumentException("expected a string starting with " + APP_CREDS_KEY_PREFIX)
+ }
+ val json = s.substring(APP_CREDS_KEY_PREFIX.length() + 1)
+ val parsed = mapper.readValue(json, classOf[AppId])
+ return parsed.appId
+ }
+
+ private def dbAppKey(appExecId: AppId): Array[Byte] = {
+ // we stick a common prefix on all the keys so we can find them in the DB
+ val appExecJson = mapper.writeValueAsString(appExecId)
+ val key = (APP_CREDS_KEY_PREFIX + ";" + appExecJson)
+ return key.getBytes(StandardCharsets.UTF_8)
+ }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/ExternalUcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/ExternalUcxShuffleClient.scala
new file mode 100644
index 00000000..763e759d
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/ExternalUcxShuffleClient.scala
@@ -0,0 +1,49 @@
+package org.apache.spark.shuffle.compat.spark_2_4
+
+import org.openucx.jucx.UcxUtils
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ShuffleClient}
+import org.apache.spark.shuffle.ucx.{OperationCallback, OperationResult, UcxShuffleBlockId, ExternalUcxClientTransport, UcxFetchCallBack, UcxDownloadCallBack}
+import org.apache.spark.shuffle.utils.UnsafeUtils
+import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId}
+
+class ExternalUcxShuffleClient(val transport: ExternalUcxClientTransport) extends ShuffleClient{
+  /** Converts a Spark shuffle block name into the equivalent UCX block id. */
+  private def toUcxBlockId(name: String): UcxShuffleBlockId = {
+    val sparkId = SparkBlockId.apply(name).asInstanceOf[SparkShuffleBlockId]
+    UcxShuffleBlockId(sparkId.shuffleId, sparkId.mapId, sparkId.reduceId)
+  }
+
+  /**
+   * Requests `blockIds` from executor `execId` on `host` through the UCX transport: in batches of
+   * `transport.getMaxBlocksPerRequest` when `downloadFileManager` is null, else one stream per block.
+   */
+  override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String],
+                           listener: BlockFetchingListener,
+                           downloadFileManager: DownloadFileManager): Unit = {
+    if (downloadFileManager != null) {
+      blockIds.foreach { name =>
+        val ucxBid = toUcxBlockId(name)
+        val streamCallback = new UcxDownloadCallBack(name, listener, downloadFileManager,
+                                                     transport.sparkTransportConf)
+        transport.fetchBlockByStream(host, execId.toInt, ucxBid, streamCallback)
+      }
+    } else {
+      val total = blockIds.length
+      val ucxBlockIds = Array.ofDim[UcxShuffleBlockId](total)
+      val callbacks = Array.ofDim[OperationCallback](total)
+      blockIds.zipWithIndex.foreach { case (name, idx) =>
+        ucxBlockIds(idx) = toUcxBlockId(name)
+        callbacks(idx) = new UcxFetchCallBack(name, listener)
+      }
+      val batchSize = transport.getMaxBlocksPerRequest
+      (0 until total by batchSize).foreach { start =>
+        val stop = start + batchSize
+        transport.fetchBlocksByBlockIds(host, execId.toInt,
+                                        ucxBlockIds.slice(start, stop), callbacks.slice(start, stop))
+      }
+    }
+  }
+
+  override def close(): Unit = {}
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/ExternalUcxShuffleManager.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/ExternalUcxShuffleManager.scala
new file mode 100644
index 00000000..06d9956b
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/ExternalUcxShuffleManager.scala
@@ -0,0 +1,23 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle
+
+import org.apache.spark.shuffle.compat.spark_2_4.{ExternalUcxShuffleClient, ExternalUcxShuffleReader}
+import org.apache.spark.shuffle.ucx.ExternalBaseUcxShuffleManager
+import org.apache.spark.{SparkConf, TaskContext}
+
+/**
+ * Spark 2.4 shuffle manager that routes shuffle reads through an [[ExternalUcxShuffleClient]].
+ */
+class ExternalUcxShuffleManager(override val conf: SparkConf, isDriver: Boolean)
+  extends ExternalBaseUcxShuffleManager(conf, isDriver) {
+  private[spark] lazy val transport = awaitUcxTransport
+  private[spark] lazy val shuffleClient = new ExternalUcxShuffleClient(transport)
+  override def getReader[K, C](handle: ShuffleHandle, startPartition: Int,
+                               endPartition: Int, context: TaskContext): ShuffleReader[K, C] = {
+    new ExternalUcxShuffleReader(handle.asInstanceOf[BaseShuffleHandle[K,_,C]], startPartition,
+      endPartition, context, shuffleClient)
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/ExternalUcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/ExternalUcxShuffleReader.scala
new file mode 100644
index 00000000..16d19b36
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/ExternalUcxShuffleReader.scala
@@ -0,0 +1,117 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.compat.spark_2_4
+
+import java.io.InputStream
+import java.util.concurrent.LinkedBlockingQueue
+
+import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
+import org.apache.spark.internal.{Logging, config}
+import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
+import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator}
+import org.apache.spark.util.CompletionIterator
+import org.apache.spark.util.collection.ExternalSorter
+
+/**
+ * Spark 2.4 shuffle reader that hands an [[ExternalUcxShuffleClient]] to ShuffleBlockFetcherIterator,
+ * then deserializes, aggregates and sorts the fetched records as the shuffle dependency requires.
+ */
+class ExternalUcxShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C],
+                                     startPartition: Int,
+                                     endPartition: Int,
+                                     context: TaskContext,
+                                     shuffleClient: ExternalUcxShuffleClient,
+                                     serializerManager: SerializerManager = SparkEnv.get.serializerManager,
+                                     blockManager: BlockManager = SparkEnv.get.blockManager,
+                                     mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
+  extends ShuffleReader[K, C] with Logging {
+
+  private val dep = handle.dependency
+
+  /** Read the combined key-values for this reduce task */
+  override def read(): Iterator[Product2[K, C]] = {
+    // Blocks are fetched through the injected UCX shuffle client.
+    val wrappedStreams = new ShuffleBlockFetcherIterator(
+      context,
+      shuffleClient,
+      blockManager,
+      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId,
+        startPartition, endPartition),
+      serializerManager.wrapStream,
+      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
+      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
+      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
+      SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
+      SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
+      SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
+
+    val serializerInstance = dep.serializer.newInstance()
+    val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
+      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
+      // NextIterator. The NextIterator makes sure that close() is called on the
+      // underlying InputStream when all records have been read.
+      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+    }
+
+    // Update the context task metrics for each record read.
+    val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
+    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
+      recordIter.map { record =>
+        readMetrics.incRecordsRead(1)
+        record
+      },
+      context.taskMetrics().mergeShuffleReadMetrics())
+
+    // An interruptible iterator must be used here in order to support task cancellation
+    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
+
+    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
+      if (dep.mapSideCombine) {
+        // We are reading values that are already combined
+        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
+        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
+      } else {
+        // We don't know the value type, but also don't care -- the dependency *should*
+        // have made sure its compatible w/ this aggregator, which will convert the value
+        // type to the combined type C
+        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
+        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
+      }
+    } else {
+      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
+    }
+
+    // Sort the output if there is a sort ordering defined.
+    val resultIter = dep.keyOrdering match {
+      case Some(keyOrd: Ordering[K]) =>
+        // Create an ExternalSorter to sort the data.
+        val sorter =
+          new ExternalSorter[K, C, C](context,
+            ordering = Some(keyOrd), serializer = dep.serializer)
+        sorter.insertAll(aggregatedIter)
+        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
+        context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
+        // Use completion callback to stop sorter if task was finished/cancelled.
+        context.addTaskCompletionListener[Unit](_ => {
+          sorter.stop()
+        })
+        CompletionIterator[Product2[K, C],
+          Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
+      case None =>
+        aggregatedIter
+    }
+
+    resultIter match {
+      case _: InterruptibleIterator[Product2[K, C]] => resultIter
+      case _ =>
+        // Use another interruptible iterator here to support task cancellation as aggregator
+        // or(and) sorter may have consumed previous interruptible iterator.
+        new InterruptibleIterator[Product2[K, C]](context, resultIter)
+    }
+  }
+
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/ExternalUcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/ExternalUcxShuffleClient.scala
new file mode 100644
index 00000000..2a0c8814
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/ExternalUcxShuffleClient.scala
@@ -0,0 +1,49 @@
+package org.apache.spark.shuffle.compat.spark_3_0
+
+import org.openucx.jucx.UcxUtils
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, BlockStoreClient}
+import org.apache.spark.shuffle.ucx.{OperationCallback, OperationResult, UcxShuffleBlockId, ExternalUcxClientTransport, UcxFetchCallBack, UcxDownloadCallBack}
+import org.apache.spark.shuffle.utils.UnsafeUtils
+import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId}
+
+class ExternalUcxShuffleClient(val transport: ExternalUcxClientTransport) extends BlockStoreClient{
+  /** Converts a Spark shuffle block name into the equivalent UCX block id. */
+  private def toUcxBlockId(name: String): UcxShuffleBlockId = {
+    val sparkId = SparkBlockId.apply(name).asInstanceOf[SparkShuffleBlockId]
+    UcxShuffleBlockId(sparkId.shuffleId, sparkId.mapId, sparkId.reduceId)
+  }
+
+  /**
+   * Requests `blockIds` from executor `execId` on `host` through the UCX transport: in batches of
+   * `transport.getMaxBlocksPerRequest` when `downloadFileManager` is null, else one stream per block.
+   */
+  override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String],
+                           listener: BlockFetchingListener,
+                           downloadFileManager: DownloadFileManager): Unit = {
+    if (downloadFileManager != null) {
+      blockIds.foreach { name =>
+        val ucxBid = toUcxBlockId(name)
+        val streamCallback = new UcxDownloadCallBack(name, listener, downloadFileManager,
+                                                     transport.sparkTransportConf)
+        transport.fetchBlockByStream(host, execId.toInt, ucxBid, streamCallback)
+      }
+    } else {
+      val total = blockIds.length
+      val ucxBlockIds = Array.ofDim[UcxShuffleBlockId](total)
+      val callbacks = Array.ofDim[OperationCallback](total)
+      blockIds.zipWithIndex.foreach { case (name, idx) =>
+        ucxBlockIds(idx) = toUcxBlockId(name)
+        callbacks(idx) = new UcxFetchCallBack(name, listener)
+      }
+      val batchSize = transport.getMaxBlocksPerRequest
+      (0 until total by batchSize).foreach { start =>
+        val stop = start + batchSize
+        transport.fetchBlocksByBlockIds(host, execId.toInt,
+                                        ucxBlockIds.slice(start, stop), callbacks.slice(start, stop))
+      }
+    }
+  }
+
+  override def close(): Unit = {}
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/ExternalUcxShuffleManager.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/ExternalUcxShuffleManager.scala
new file mode 100644
index 00000000..49d099aa
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/ExternalUcxShuffleManager.scala
@@ -0,0 +1,26 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle
+
+import org.apache.spark.shuffle.compat.spark_3_0.{ExternalUcxShuffleClient, ExternalUcxShuffleReader}
+import org.apache.spark.shuffle.ucx.ExternalBaseUcxShuffleManager
+import org.apache.spark.{SparkConf, TaskContext}
+
+/**
+ * Spark 3.0 shuffle manager that routes shuffle reads through an [[ExternalUcxShuffleClient]].
+ */
+class ExternalUcxShuffleManager(override val conf: SparkConf, isDriver: Boolean)
+  extends ExternalBaseUcxShuffleManager(conf, isDriver) {
+  private[spark] lazy val transport = awaitUcxTransport
+  private[spark] lazy val shuffleClient = new ExternalUcxShuffleClient(transport)
+  override def getReader[K, C](
+    handle: ShuffleHandle, startPartition: MapId, endPartition: MapId,
+    context: TaskContext, metrics: ShuffleReadMetricsReporter):
+  ShuffleReader[K, C] = {
+    new ExternalUcxShuffleReader(handle.asInstanceOf[BaseShuffleHandle[K,_,C]],
+      startPartition, endPartition, context,
+      shuffleClient, metrics, shouldBatchFetch = false) // batch fetch is always requested off here
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/ExternalUcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/ExternalUcxShuffleReader.scala
new file mode 100644
index 00000000..5bf53c52
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/ExternalUcxShuffleReader.scala
@@ -0,0 +1,124 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.compat.spark_3_0
+
+import java.io.InputStream
+import java.util.concurrent.LinkedBlockingQueue
+
+import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
+import org.apache.spark.internal.{Logging, config}
+import org.apache.spark.io.CompressionCodec
+import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReadMetricsReporter, ShuffleReader}
+import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator}
+import org.apache.spark.util.CompletionIterator
+import org.apache.spark.util.collection.ExternalSorter
+
+/**
+ * Spark 3.0 shuffle reader that hands an [[ExternalUcxShuffleClient]] to ShuffleBlockFetcherIterator,
+ * then deserializes, aggregates and sorts the fetched records as the shuffle dependency requires.
+ */
+class ExternalUcxShuffleReader[K, C](
+    handle: BaseShuffleHandle[K, _, C], startPartition: Int,
+    endPartition: Int, context: TaskContext,
+    shuffleClient: ExternalUcxShuffleClient,
+    readMetrics: ShuffleReadMetricsReporter,
+    serializerManager: SerializerManager = SparkEnv.get.serializerManager,
+    blockManager: BlockManager = SparkEnv.get.blockManager,
+    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker,
+    shouldBatchFetch: Boolean = false)
+  extends ShuffleReader[K, C] with Logging {
+
+  private val dep = handle.dependency
+
+  /** Read the combined key-values for this reduce task */
+  override def read(): Iterator[Product2[K, C]] = {
+    // Use the injected tracker's iterator directly: a discarded `duplicate` twin buffers every element.
+    val blocksByAddress = mapOutputTracker.getMapSizesByExecutorId(
+      handle.shuffleId, startPartition, endPartition)
+    val shuffleIterator = new ShuffleBlockFetcherIterator(
+      context,
+      shuffleClient,
+      blockManager,
+      blocksByAddress,
+      serializerManager.wrapStream,
+      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
+      SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
+      SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
+      SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
+      SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
+      SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
+      SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
+      readMetrics,
+      // TODO: Support batch fetch
+      doBatchFetch = false)
+
+    val wrappedStreams = shuffleIterator.toCompletionIterator
+    val serializerInstance = dep.serializer.newInstance()
+
+    // Create a key/value iterator for each stream
+    val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
+      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
+      // NextIterator. The NextIterator makes sure that close() is called on the
+      // underlying InputStream when all records have been read.
+      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+    }
+
+    // Update the context task metrics for each record read.
+    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
+      recordIter.map { record =>
+        readMetrics.incRecordsRead(1)
+        record
+      },
+      context.taskMetrics().mergeShuffleReadMetrics())
+
+    // An interruptible iterator must be used here in order to support task cancellation
+    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
+
+    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
+      if (dep.mapSideCombine) {
+        // We are reading values that are already combined
+        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
+        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
+      } else {
+        // We don't know the value type, but also don't care -- the dependency *should*
+        // have made sure its compatible w/ this aggregator, which will convert the value
+        // type to the combined type C
+        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
+        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
+      }
+    } else {
+      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
+    }
+
+    // Sort the output if there is a sort ordering defined.
+    val resultIter = dep.keyOrdering match {
+      case Some(keyOrd: Ordering[K]) =>
+        // Create an ExternalSorter to sort the data.
+        val sorter =
+          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
+        sorter.insertAll(aggregatedIter)
+        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
+        context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
+        // Use completion callback to stop sorter if task was finished/cancelled.
+        context.addTaskCompletionListener[Unit](_ => {
+          sorter.stop()
+        })
+        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
+      case None =>
+        aggregatedIter
+    }
+
+    resultIter match {
+      case _: InterruptibleIterator[Product2[K, C]] => resultIter
+      case _ =>
+        // Use another interruptible iterator here to support task cancellation as aggregator
+        // or(and) sorter may have consumed previous interruptible iterator.
+        new InterruptibleIterator[Product2[K, C]](context, resultIter)
+    }
+  }
+
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala
index 314b88a3..64b8df32 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala
@@ -5,7 +5,9 @@
package org.apache.spark.shuffle.ucx
import java.nio.ByteBuffer
+import java.util.{HashSet, HashMap}
import java.util.concurrent.locks.StampedLock
+import org.openucx.jucx.ucp.UcpRequest
/**
* Class that represents some block in memory with it's address, size.
@@ -90,6 +92,8 @@ trait Request {
*/
trait OperationCallback {
def onComplete(result: OperationResult): Unit
+ def onError(result: OperationResult): Unit = ???
+ def onData(buf: ByteBuffer): Unit = ???
}
/**
@@ -167,3 +171,82 @@ trait ShuffleTransport {
def progress(): Unit
}
+
+/** [[Request]] backed by an optional [[UcpRequest]] handle plus the stats of the operation. */
+class UcxRequest(private var request: UcpRequest, stats: OperationStats) extends Request {
+
+  private[ucx] var completed = false
+
+  /** True once `completed` is set, or when a UCP request is attached and reports completion. */
+  override def isCompleted: Boolean = {
+    if (completed) true else request != null && request.isCompleted
+  }
+
+  override def getStats: Option[OperationStats] = Some(stats)
+
+  override def toString(): String =
+    s"UcxRequest(isCompleted=$isCompleted size=${stats.recvSize} cost=${stats.getElapsedTimeNs}ns)"
+  /** Attaches (or replaces) the underlying UCP request handle. */
+  private[ucx] def setRequest(request: UcpRequest): Unit = { this.request = request }
+}
+
+/** Timing and size counters for one UCX operation; `startTime` is taken at construction. */
+class UcxStats extends OperationStats {
+  private[ucx] val startTime: Long = System.nanoTime()
+  private[ucx] var amHandleTime: Long = 0L
+  private[ucx] var endTime: Long = 0L
+  private[ucx] var receiveSize: Long = 0L
+
+  /**
+   * Elapsed nanoseconds from operation submit to callback call, computed as
+   * `endTime - startTime`.
+   * The value depends on [[ ShuffleTransport.progress() ]] calls and does not
+   * indicate the actual data transfer time.
+   */
+  override def getElapsedTimeNs: Long = endTime - startTime
+
+  /** Number of valid bytes in the receive memory. */
+  override def recvSize: Long = receiveSize
+}
+
+/** State of one batch fetch: its callbacks, request, start timestamp and a set of received Ints. */
+class UcxFetchState(val callbacks: Seq[OperationCallback], val request: UcxRequest,
+                    val timestamp: Long, val recvSet: HashSet[Int] = new HashSet[Int]) {
+  override def toString(): String = {
+    val received = recvSet.size
+    s"UcxFetchState(chunks=${callbacks.size}, $request, start=$timestamp, received=$received)"
+  }
+}
+
+/** State of one streamed fetch: a mutable `remaining` counter and received blocks keyed by Long. */
+class UcxStreamState(val callback: OperationCallback, val request: UcxRequest,
+                     val timestamp: Long, var remaining: Long,
+                     val recvMap: HashMap[Long, MemoryBlock] = new HashMap[Long, MemoryBlock]) {
+  override def toString(): String = {
+    val received = recvMap.size
+    s"UcxStreamState($request, start=$timestamp, remaining=$remaining, received=$received)"
+  }
+}
+
+/**
+ * State of one sliced fetch: a memory block, a mutable `remaining` counter, the current
+ * `offset` and a set of received Long markers. `toString` reports this class's own name.
+ */
+class UcxSliceState(val callback: OperationCallback, val request: UcxRequest,
+                    val timestamp: Long, val mem: MemoryBlock, var remaining: Long,
+                    var offset: Long = 0L, val recvSet: HashSet[Long] = new HashSet[Long]) {
+  override def toString(): String = {
+    s"UcxSliceState($request, start=$timestamp, remaining=$remaining, received=${recvSet.size})"
+  }
+}
+
+/** [[OperationResult]] of a successful operation: SUCCESS status, null error, data in `mem`. */
+class UcxSucceedOperationResult(mem: MemoryBlock, stats: OperationStats) extends OperationResult {
+  override def getStatus: OperationStatus.Value = OperationStatus.SUCCESS
+
+  override def getError: TransportError = null
+
+  override def getStats: Option[OperationStats] = if (stats == null) None else Some(stats)
+
+  override def getData: MemoryBlock = mem
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxProgressThread.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxProgressThread.scala
new file mode 100644
index 00000000..68c849ce
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxProgressThread.scala
@@ -0,0 +1,25 @@
+package org.apache.spark.shuffle.ucx
+
+import org.openucx.jucx.ucp.UcpWorker
+
+class UcxProgressThread(worker: UcpWorker, useWakeup: Boolean) extends Thread {
+  setDaemon(true)
+  setName(s"UCX-progress-${super.getName}")
+
+  /** Calls `worker.progress` under the worker's monitor until it returns 0. */
+  private def drain(): Unit = worker.synchronized {
+    while (worker.progress != 0) {}
+  }
+
+  /**
+   * Progresses the worker until this thread is interrupted. With `useWakeup` the
+   * loop calls `worker.waitForEvents()` after each drain; otherwise it keeps
+   * draining without waiting.
+   */
+  override def run(): Unit = {
+    while (!isInterrupted) {
+      drain()
+      if (useWakeup) worker.waitForEvents()
+    }
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala
index ae8bb119..2e46c362 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala
@@ -19,39 +19,6 @@ import java.nio.ByteBuffer
import scala.collection.concurrent.TrieMap
import scala.collection.mutable
-class UcxRequest(private var request: UcpRequest, stats: OperationStats)
- extends Request {
-
- private[ucx] var completed = false
-
- override def isCompleted: Boolean = completed || ((request != null) && request.isCompleted)
-
- override def getStats: Option[OperationStats] = Some(stats)
-
- private[ucx] def setRequest(request: UcpRequest): Unit = {
- this.request = request
- }
-}
-
-class UcxStats extends OperationStats {
- private[ucx] val startTime = System.nanoTime()
- private[ucx] var amHandleTime = 0L
- private[ucx] var endTime: Long = 0L
- private[ucx] var receiveSize: Long = 0L
-
- /**
- * Time it took from operation submit to callback call.
- * This depends on [[ ShuffleTransport.progress() ]] calls,
- * and does not indicate actual data transfer time.
- */
- override def getElapsedTimeNs: Long = endTime - startTime
-
- /**
- * Indicates number of valid bytes in receive memory
- */
- override def recvSize: Long = receiveSize
-}
-
case class UcxShuffleBockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
override def serializedSize: Int = 12
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/ExternalUcxConf.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/ExternalUcxConf.scala
new file mode 100644
index 00000000..88d6d7d4
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/ExternalUcxConf.scala
@@ -0,0 +1,88 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.ucx
+
+/**
+ * Plugin configuration properties; each lazy val starts from the matching default in the companion.
+ */
+trait ExternalUcxConf {
+  lazy val preallocateBuffersMap: Map[Long, Int] =
+    ExternalUcxConf.preAllocateConfToMap(ExternalUcxConf.PREALLOCATE_BUFFERS_DEFAULT)
+  lazy val memoryLimit: Boolean = ExternalUcxConf.MEMORY_LIMIT_DEFAULT
+  lazy val memoryGroupSize: Int = ExternalUcxConf.MEMORY_GROUP_SIZE_DEFAULT
+  lazy val minBufferSize: Long = ExternalUcxConf.MIN_BUFFER_SIZE_DEFAULT
+  lazy val maxBufferSize: Long = ExternalUcxConf.MAX_BUFFER_SIZE_DEFAULT
+  lazy val minRegistrationSize: Long = ExternalUcxConf.MIN_REGISTRATION_SIZE_DEFAULT
+  lazy val maxRegistrationSize: Long = ExternalUcxConf.MAX_REGISTRATION_SIZE_DEFAULT
+  lazy val numPools: Int = ExternalUcxConf.NUM_POOLS_DEFAULT
+  lazy val listenerAddress: String = ExternalUcxConf.SOCKADDR_DEFAULT
+  lazy val useWakeup: Boolean = ExternalUcxConf.WAKEUP_FEATURE_DEFAULT
+  lazy val numIoThreads: Int = ExternalUcxConf.NUM_IO_THREADS_DEFAULT
+  lazy val numThreads: Int = ExternalUcxConf.NUM_THREADS_DEFAULT
+  lazy val numWorkers: Int = ExternalUcxConf.NUM_WORKERS_DEFAULT
+  lazy val maxBlocksPerRequest: Int = ExternalUcxConf.MAX_BLOCKS_IN_FLIGHT_DEFAULT
+  lazy val ucxServerPort: Int = ExternalUcxConf.SPARK_UCX_SHUFFLE_SERVICE_PORT_DEFAULT
+  lazy val maxReplySize: Long = ExternalUcxConf.MAX_REPLY_SIZE_DEFAULT
+}
+
+object ExternalUcxConf {
+  private[ucx] def getUcxConf(name: String) = s"spark.shuffle.ucx.$name"
+
+  lazy val PREALLOCATE_BUFFERS_KEY = getUcxConf("memory.preAllocateBuffers")
+  lazy val PREALLOCATE_BUFFERS_DEFAULT = ""
+
+  lazy val MEMORY_LIMIT_KEY = getUcxConf("memory.limit")
+  lazy val MEMORY_LIMIT_DEFAULT = false
+
+  lazy val MEMORY_GROUP_SIZE_KEY = getUcxConf("memory.groupSize")
+  lazy val MEMORY_GROUP_SIZE_DEFAULT = 3
+
+  lazy val MIN_BUFFER_SIZE_KEY = getUcxConf("memory.minBufferSize")
+  lazy val MIN_BUFFER_SIZE_DEFAULT = 4096L
+
+  lazy val MAX_BUFFER_SIZE_KEY = getUcxConf("memory.maxBufferSize")
+  lazy val MAX_BUFFER_SIZE_DEFAULT = Int.MaxValue.toLong
+
+  lazy val MIN_REGISTRATION_SIZE_KEY = getUcxConf("memory.minAllocationSize")
+  lazy val MIN_REGISTRATION_SIZE_DEFAULT = 1L * 1024 * 1024
+
+  lazy val MAX_REGISTRATION_SIZE_KEY = getUcxConf("memory.maxAllocationSize")
+  lazy val MAX_REGISTRATION_SIZE_DEFAULT = 16L * 1024 * 1024 * 1024
+
+  lazy val NUM_POOLS_KEY = getUcxConf("memory.numPools")
+  lazy val NUM_POOLS_DEFAULT = 1
+
+  lazy val SOCKADDR_KEY = getUcxConf("listener.sockaddr")
+  lazy val SOCKADDR_DEFAULT = "0.0.0.0:0"
+
+  lazy val WAKEUP_FEATURE_KEY = getUcxConf("useWakeup")
+  lazy val WAKEUP_FEATURE_DEFAULT = true
+
+  lazy val NUM_IO_THREADS_KEY = getUcxConf("numIoThreads")
+  lazy val NUM_IO_THREADS_DEFAULT = 1
+
+  lazy val NUM_THREADS_KEY = getUcxConf("numThreads")
+  lazy val NUM_THREADS_COMPAT_KEY = getUcxConf("numListenerThreads")
+  lazy val NUM_THREADS_DEFAULT = 4
+
+  lazy val NUM_WORKERS_KEY = getUcxConf("numWorkers")
+  lazy val NUM_WORKERS_COMPAT_KEY = getUcxConf("numClientWorkers")
+  lazy val NUM_WORKERS_DEFAULT = 1
+
+  lazy val MAX_BLOCKS_IN_FLIGHT_KEY = getUcxConf("maxBlocksPerRequest")
+  lazy val MAX_BLOCKS_IN_FLIGHT_DEFAULT = 50
+
+  lazy val MAX_REPLY_SIZE_KEY = getUcxConf("maxReplySize")
+  lazy val MAX_REPLY_SIZE_DEFAULT = 32L * 1024 * 1024
+
+  lazy val SPARK_UCX_SHUFFLE_SERVICE_PORT_KEY = getUcxConf("service.port")
+  lazy val SPARK_UCX_SHUFFLE_SERVICE_PORT_DEFAULT = 3338
+  /** Parses "size:count,size:count,..." into a map; blanks are ignored, bad entries are rejected. */
+  def preAllocateConfToMap(conf: String): Map[Long, Int] =
+    conf.split(",").map(_.trim).filter(_.nonEmpty).map(entry => entry.split(":").map(_.trim) match {
+      case Array(bufferSize, bufferCount) => (bufferSize.toLong, bufferCount.toInt)
+      case _ => throw new IllegalArgumentException(s"Malformed entry '$entry', expected <size>:<count>")
+    }).toMap
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/ExternalUcxTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/ExternalUcxTransport.scala
new file mode 100644
index 00000000..b38d4e45
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/ExternalUcxTransport.scala
@@ -0,0 +1,94 @@
+/*
+ * Copyright (C) 2022, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
+ * See file LICENSE for terms.
+ */
+package org.apache.spark.shuffle.ucx
+
+import org.apache.spark.shuffle.ucx.memory.UcxLimitedMemPool
+import org.apache.spark.shuffle.utils.UcxLogging
+import org.openucx.jucx.ucp._
+
+import java.nio.ByteBuffer
+import java.util.concurrent.ExecutorService
+
+object ExternalAmId {
+  // client -> server message ids
+  final val ADDRESS = 0
+  final val CONNECT = 1
+  final val FETCH_BLOCK = 2
+  final val FETCH_STREAM = 3
+  // server -> client message ids (these reuse the numeric values 0-3 of the ids above)
+  final val REPLY_ADDRESS = 0
+  final val REPLY_SLICE = 1
+  final val REPLY_BLOCK = 2
+  final val REPLY_STREAM = 3
+}
+
+class ExternalUcxTransport(val ucxShuffleConf: ExternalUcxConf) extends UcxLogging {
+  @volatile protected var initialized: Boolean = false
+  @volatile protected var running: Boolean = true
+  private[ucx] var ucxContext: UcpContext = _
+  private[ucx] var memPools: Array[UcxLimitedMemPool] = _
+  private[ucx] val ucpWorkerParams = new UcpWorkerParams().requestThreadSafety()
+  private[ucx] var taskExecutors: ExecutorService = _
+
+  /** Picks the memory pool for index `i`, wrapping around; floorMod also accepts negative `i`. */
+  def hostBounceBufferMemoryPool(i: Int = 0): UcxLimitedMemPool = {
+    memPools(Math.floorMod(i, memPools.length))
+  }
+
+  def estimateNumEps(): Int = 1
+
+  /** Creates the UCP context, requesting the AM feature and, when configured, the wakeup feature. */
+  def initContext(): Unit = {
+    val numEndpoints = estimateNumEps()
+    logInfo(s"Creating UCX context with an estimated number of endpoints: $numEndpoints")
+
+    // requestAmFeature() only needs to be called once on the params builder.
+    val params = new UcpParams().requestAmFeature().setMtWorkersShared(true)
+      .setEstimatedNumEps(numEndpoints).setConfig("USE_MT_MUTEX", "yes")
+
+    if (ucxShuffleConf.useWakeup) params.requestWakeupFeature()
+
+    ucxContext = new UcpContext(params)
+  }
+
+  /** Creates the memory pools, splitting the maximum registration size evenly between them. */
+  def initMemoryPool(): Unit = {
+    // Clamp to 1 last, so a non-positive numWorkers still yields one pool (avoids `% 0` above).
+    val numPools = ucxShuffleConf.numPools.min(ucxShuffleConf.numWorkers).max(1)
+    memPools = new Array[UcxLimitedMemPool](numPools)
+    for (i <- 0 until numPools) {
+      memPools(i) = new UcxLimitedMemPool(ucxContext)
+      memPools(i).init(ucxShuffleConf.minBufferSize,
+                       ucxShuffleConf.maxBufferSize,
+                       ucxShuffleConf.minRegistrationSize,
+                       ucxShuffleConf.maxRegistrationSize / numPools,
+                       ucxShuffleConf.preallocateBuffersMap,
+                       ucxShuffleConf.memoryLimit,
+                       ucxShuffleConf.memoryGroupSize)
+    }
+  }
+
+  def initTaskPool(threadNum: Int): Unit = {
+    taskExecutors = UcxThreadUtils.newFixedDaemonPool("UCX", threadNum)
+  }
+
+  /** Not implemented in this class: calling it throws NotImplementedError unless overridden. */
+  def init(): ByteBuffer = ???
+
+  @`inline`
+  def submit(task: Runnable): Unit = {
+    taskExecutors.submit(task)
+  }
+
+  /** Releases pools, context and executor when `initialized` is set; null members are skipped. */
+  def close(): Unit = {
+    if (initialized) {
+      Option(memPools).foreach(_.filter(_ != null).foreach(_.close()))
+      if (ucxContext != null) ucxContext.close()
+      if (taskExecutors != null) taskExecutors.shutdown()
+
+      initialized = false
+    }
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/UcxWorkerThread.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/UcxWorkerThread.scala
new file mode 100644
index 00000000..6c81e765
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/UcxWorkerThread.scala
@@ -0,0 +1,47 @@
+package org.apache.spark.shuffle.ucx
+
+import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.ConcurrentLinkedQueue
+import org.openucx.jucx.ucp.UcpWorker
+
+/**
+ * Daemon thread that progresses one UCP worker and runs posted tasks; with `useWakeup` it blocks in `waitForEvents()` while idle.
+ */
+class UcxWorkerThread(worker: UcpWorker, useWakeup: Boolean) extends Thread {
+  private val taskQueue = new ConcurrentLinkedQueue[Runnable]
+  private val running = new AtomicBoolean(true)
+
+  setDaemon(true)
+  setName(s"UCX-worker-${super.getName}")
+
+  /** Queues a task for the loop and signals the worker so that a waiting loop wakes up. */
+  @`inline`
+  def post(task: Runnable): Unit = {
+    taskQueue.add(task)
+    worker.signal()
+  }
+
+  /** Blocks in `waitForEvents()` unless a task is already queued. */
+  @`inline`
+  def await(): Unit = if (taskQueue.isEmpty) worker.waitForEvents()
+
+  /** Stops the loop, runs `cleanTask` on the calling thread, then closes the worker.
+   *  The loop thread is joined first so it no longer progresses the worker while it is closed. */
+  @`inline`
+  def close(cleanTask: Runnable): Unit = {
+    if (running.compareAndSet(true, false)) worker.signal()
+    if (Thread.currentThread() ne this) join()
+    cleanTask.run()
+    worker.close()
+  }
+
+  override def run(): Unit = {
+    val doAwait = if (useWakeup) await _ else () => {}
+    while (running.get()) {
+      val task = taskQueue.poll()
+      if (task != null) task.run()
+      while (worker.progress != 0) {}
+      doAwait()
+    }
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalBaseUcxShuffleManager.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalBaseUcxShuffleManager.scala
new file mode 100644
index 00000000..a4a2e12a
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalBaseUcxShuffleManager.scala
@@ -0,0 +1,128 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.ucx
+
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.util.Success
+
+import org.apache.spark.SparkException
+import org.apache.spark.rpc.RpcEnv
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.shuffle.ucx.rpc.{ExternalUcxExecutorRpcEndpoint, ExternalUcxDriverRpcEndpoint}
+import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{PushServiceAddress, PushAllServiceAddress}
+import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer
+import org.apache.spark.util.{RpcUtils, ThreadUtils}
+import org.apache.spark.{SecurityManager, SparkConf, SparkEnv}
+import org.openucx.jucx.{NativeLibs, UcxException}
+
+/**
+ * Common part for all spark versions for UcxShuffleManager logic.
+ * Construction submits a one-off setup task: the driver registers its RPC endpoint, an executor
+ * waits for the block manager's shuffle server id and then starts the UCX transport.
+ */
+abstract class ExternalBaseUcxShuffleManager(val conf: SparkConf, isDriver: Boolean) extends SortShuffleManager(conf) {
+  type ShuffleId = Int
+  type MapId = Int
+  type ReduceId = Long
+
+  /* Load UCX/JUCX libraries as soon as possible to avoid collision with JVM when register malloc/mmap hook. */
+  if (!isDriver) NativeLibs.load()
+
+  val ucxShuffleConf = new ExternalUcxClientConf(conf)
+
+  /** Assigned by [[startUcxTransport]] after the transport's `init()`; reset to null by [[stop]]. */
+  @volatile var ucxTransport: ExternalUcxClientTransport = _
+
+  private var executorEndpoint: ExternalUcxExecutorRpcEndpoint = _
+  private var driverEndpoint: ExternalUcxDriverRpcEndpoint = _
+
+  protected val driverRpcName = "SparkUCX_driver"
+
+  private val setupThread = ThreadUtils.newDaemonSingleThreadExecutor("UcxTransportSetupThread")
+
+  private[this] val latch = setupThread.submit(new Runnable {
+    override def run(): Unit = {
+      while (SparkEnv.get == null) {
+        Thread.sleep(10)
+      }
+      if (isDriver) {
+        val rpcEnv = SparkEnv.get.rpcEnv
+        logInfo(s"Setting up driver RPC")
+        driverEndpoint = new ExternalUcxDriverRpcEndpoint(rpcEnv)
+        rpcEnv.setupEndpoint(driverRpcName, driverEndpoint)
+      } else {
+        while (SparkEnv.get.blockManager.shuffleServerId == null) {
+          Thread.sleep(5)
+        }
+        startUcxTransport()
+      }
+    }
+  })
+
+  /**
+   * Returns the transport, first waiting up to 10 seconds on the setup task if it is not set; throws UcxException if it is still unset.
+   */
+  def awaitUcxTransport(): ExternalUcxClientTransport = {
+    if (ucxTransport == null) {
+      latch.get(10, TimeUnit.SECONDS)
+      if (ucxTransport == null) {
+        throw new UcxException("ExternalUcxClientTransport init timeout")
+      }
+    }
+    ucxTransport
+  }
+
+  /**
+   * Starts the client transport unless one is already set, registers the executor RPC endpoint
+   * and pushes this executor's service address to the driver. No locking is done here.
+   */
+  def startUcxTransport(): Unit = if (ucxTransport == null) {
+    val blockManager = SparkEnv.get.blockManager.shuffleServerId
+    val transport = new ExternalUcxClientTransport(ucxShuffleConf, blockManager)
+    transport.init()
+    ucxTransport = transport
+    val rpcEnv = SparkEnv.get.rpcEnv
+    executorEndpoint = new ExternalUcxExecutorRpcEndpoint(rpcEnv, ucxTransport, setupThread)
+    val endpoint = rpcEnv.setupEndpoint(
+      s"ucx-shuffle-executor-${blockManager.executorId}", executorEndpoint)
+    var driverCost = 0
+    // Named driverRef so it does not shadow the driverEndpoint field that stop() reads.
+    var driverRef = RpcUtils.makeDriverRef(driverRpcName, conf, rpcEnv)
+    while (driverRef == null) {
+      Thread.sleep(10)
+      driverCost += 10
+      driverRef = RpcUtils.makeDriverRef(driverRpcName, conf, rpcEnv)
+    }
+    driverRef.ask[PushAllServiceAddress](
+      PushServiceAddress(blockManager.host, transport.localServerPorts, endpoint))
+      .andThen {
+        case Success(msg) =>
+          logInfo(s"Driver take $driverCost ms.")
+          executorEndpoint.receive(msg)
+        // A failed ask was previously dropped without any trace in the log.
+        case scala.util.Failure(e) =>
+          logError("Pushing service address to the driver failed", e)
+      }
+  }
+
+  override def unregisterShuffle(shuffleId: Int): Boolean = {
+    // shuffleBlockResolver.asInstanceOf[CommonUcxShuffleBlockResolver].removeShuffle(shuffleId)
+    super.unregisterShuffle(shuffleId)
+  }
+
+  /** Called on both driver and executors to finally cleanup resources. */
+  override def stop(): Unit = synchronized {
+    super.stop()
+    if (ucxTransport != null) {
+      ucxTransport.close()
+      ucxTransport = null
+    }
+    if (executorEndpoint != null) executorEndpoint.stop()
+    if (driverEndpoint != null) driverEndpoint.stop()
+    setupThread.shutdown()
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalUcxClientConf.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalUcxClientConf.scala
new file mode 100644
index 00000000..bffcac4a
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalUcxClientConf.scala
@@ -0,0 +1,137 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.ucx
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.config.ConfigBuilder
+import org.apache.spark.network.util.ByteUnit
+
+/**
+ * Plugin configuration properties, read from the wrapped `sparkConf`.
+ */
+class ExternalUcxClientConf(val sparkConf: SparkConf) extends SparkConf with ExternalUcxConf {
+
+  def getSparkConf: SparkConf = sparkConf
+
+  // Memory Pool
+  private lazy val PREALLOCATE_BUFFERS =
+    ConfigBuilder(ExternalUcxConf.PREALLOCATE_BUFFERS_KEY)
+      .doc("Comma separated list of buffer size : buffer count pairs to preallocate in memory pool. E.g. 4k:1000,16k:500")
+      .stringConf.createWithDefault("")
+
+  override lazy val preallocateBuffersMap: Map[Long, Int] =
+    ExternalUcxConf.preAllocateConfToMap(
+      sparkConf.get(ExternalUcxConf.PREALLOCATE_BUFFERS_KEY,
+        ExternalUcxConf.PREALLOCATE_BUFFERS_DEFAULT))
+
+  private lazy val MEMORY_LIMIT = ConfigBuilder(ExternalUcxConf.MEMORY_LIMIT_KEY)
+    .doc("Enable memory pool size limit.")
+    .booleanConf
+    .createWithDefault(ExternalUcxConf.MEMORY_LIMIT_DEFAULT)
+
+  override lazy val memoryLimit: Boolean = sparkConf.getBoolean(MEMORY_LIMIT.key,
+    MEMORY_LIMIT.defaultValue.get)
+
+  private lazy val MEMORY_GROUP_SIZE = ConfigBuilder(ExternalUcxConf.MEMORY_GROUP_SIZE_KEY)
+    .doc("Memory group size.")
+    .intConf
+    .createWithDefault(ExternalUcxConf.MEMORY_GROUP_SIZE_DEFAULT)
+
+  override lazy val memoryGroupSize: Int = sparkConf.getInt(MEMORY_GROUP_SIZE.key,
+    MEMORY_GROUP_SIZE.defaultValue.get)
+
+  private lazy val MIN_BUFFER_SIZE = ConfigBuilder(ExternalUcxConf.MIN_BUFFER_SIZE_KEY)
+    .doc("Minimal buffer size in memory pool.")
+    .bytesConf(ByteUnit.BYTE)
+    .createWithDefault(ExternalUcxConf.MIN_BUFFER_SIZE_DEFAULT)
+
+  override lazy val minBufferSize: Long = sparkConf.getSizeAsBytes(MIN_BUFFER_SIZE.key,
+    MIN_BUFFER_SIZE.defaultValue.get)
+
+  private lazy val MAX_BUFFER_SIZE = ConfigBuilder(ExternalUcxConf.MAX_BUFFER_SIZE_KEY)
+    .doc("Maximal buffer size in memory pool.")
+    .bytesConf(ByteUnit.BYTE)
+    .createWithDefault(ExternalUcxConf.MAX_BUFFER_SIZE_DEFAULT)
+
+  override lazy val maxBufferSize: Long = sparkConf.getSizeAsBytes(MAX_BUFFER_SIZE.key,
+    MAX_BUFFER_SIZE.defaultValue.get)
+
+  private lazy val MIN_REGISTRATION_SIZE = ConfigBuilder(ExternalUcxConf.MIN_REGISTRATION_SIZE_KEY)
+    .doc("Minimal memory registration size in memory pool.")
+    .bytesConf(ByteUnit.BYTE)
+    .createWithDefault(ExternalUcxConf.MIN_REGISTRATION_SIZE_DEFAULT)
+
+  // Kept as Long end to end: a detour through Int would wrap sizes above Int.MaxValue bytes.
+  override lazy val minRegistrationSize: Long = sparkConf.getSizeAsBytes(MIN_REGISTRATION_SIZE.key,
+    MIN_REGISTRATION_SIZE.defaultValue.get)
+
+  private lazy val MAX_REGISTRATION_SIZE = ConfigBuilder(ExternalUcxConf.MAX_REGISTRATION_SIZE_KEY)
+    .doc("Maximal memory registration size in memory pool.")
+    .bytesConf(ByteUnit.BYTE)
+    .createWithDefault(ExternalUcxConf.MAX_REGISTRATION_SIZE_DEFAULT)
+
+  override lazy val maxRegistrationSize: Long = sparkConf.getSizeAsBytes(MAX_REGISTRATION_SIZE.key,
+    MAX_REGISTRATION_SIZE.defaultValue.get)
+
+  private lazy val NUM_POOLS = ConfigBuilder(ExternalUcxConf.NUM_POOLS_KEY)
+    .doc("Number of memory pool.")
+    .intConf
+    .createWithDefault(ExternalUcxConf.NUM_POOLS_DEFAULT)
+
+  override lazy val numPools: Int = sparkConf.getInt(NUM_POOLS.key,
+    NUM_POOLS.defaultValue.get)
+
+  private lazy val SOCKADDR =
+    ConfigBuilder(ExternalUcxConf.SOCKADDR_KEY)
+      .doc("Whether to use socket address to connect executors.")
+      .stringConf
+      .createWithDefault(ExternalUcxConf.SOCKADDR_DEFAULT)
+
+  override lazy val listenerAddress: String = sparkConf.get(SOCKADDR.key, SOCKADDR.defaultValueString)
+
+  private lazy val WAKEUP_FEATURE =
+    ConfigBuilder(ExternalUcxConf.WAKEUP_FEATURE_KEY)
+      .doc("Whether to use the UCX wakeup feature instead of busy polling for workers")
+      .booleanConf
+      .createWithDefault(ExternalUcxConf.WAKEUP_FEATURE_DEFAULT)
+
+  override lazy val useWakeup: Boolean = sparkConf.getBoolean(WAKEUP_FEATURE.key, WAKEUP_FEATURE.defaultValue.get)
+
+  private lazy val NUM_WORKERS = ConfigBuilder(ExternalUcxConf.NUM_WORKERS_KEY)
+    .doc("Number of client workers")
+    .intConf
+    .createWithDefault(ExternalUcxConf.NUM_WORKERS_DEFAULT)
+
+  // Lookup order: the plugin key, then spark.executor.cores, then the built-in default.
+  override lazy val numWorkers: Int = sparkConf.getInt(NUM_WORKERS.key,
+    sparkConf.getInt("spark.executor.cores", NUM_WORKERS.defaultValue.get))
+
+  private lazy val NUM_THREADS = ConfigBuilder(ExternalUcxConf.NUM_THREADS_KEY)
+    .doc("Number of threads in thread pool")
+    .intConf
+    .createWithDefault(ExternalUcxConf.NUM_THREADS_DEFAULT)
+
+  // Lookup order: the plugin key, then NUM_THREADS_COMPAT_KEY, then the built-in default.
+  override lazy val numThreads: Int = sparkConf.getInt(NUM_THREADS.key,
+    sparkConf.getInt(ExternalUcxConf.NUM_THREADS_COMPAT_KEY, NUM_THREADS.defaultValue.get))
+
+  private lazy val MAX_BLOCKS_IN_FLIGHT = ConfigBuilder(ExternalUcxConf.MAX_BLOCKS_IN_FLIGHT_KEY)
+    .doc("Maximum number blocks per request")
+    .intConf
+    .createWithDefault(ExternalUcxConf.MAX_BLOCKS_IN_FLIGHT_DEFAULT)
+
+  override lazy val maxBlocksPerRequest: Int = sparkConf.getInt(MAX_BLOCKS_IN_FLIGHT.key, MAX_BLOCKS_IN_FLIGHT.defaultValue.get)
+
+  private lazy val MAX_REPLY_SIZE = ConfigBuilder(ExternalUcxConf.MAX_REPLY_SIZE_KEY)
+    .doc("Maximum size of a reply")
+    .bytesConf(ByteUnit.BYTE)
+    .createWithDefault(ExternalUcxConf.MAX_REPLY_SIZE_DEFAULT)
+
+  override lazy val maxReplySize: Long = sparkConf.getSizeAsBytes(MAX_REPLY_SIZE.key, MAX_REPLY_SIZE.defaultValue.get)
+
+  override lazy val ucxServerPort: Int = sparkConf.getInt(
+    ExternalUcxConf.SPARK_UCX_SHUFFLE_SERVICE_PORT_KEY,
+    ExternalUcxConf.SPARK_UCX_SHUFFLE_SERVICE_PORT_DEFAULT)
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalUcxClientTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalUcxClientTransport.scala
new file mode 100644
index 00000000..564bbcd1
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalUcxClientTransport.scala
@@ -0,0 +1,228 @@
+package org.apache.spark.shuffle.ucx
+
+// import org.apache.spark.SparkEnv
+import org.apache.spark.network.util.TransportConf
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager}
+import org.apache.spark.shuffle.utils.{UcxLogging, UnsafeUtils}
+import org.apache.spark.shuffle.ucx.memory.UcxLimitedMemPool
+import org.apache.spark.shuffle.ucx.utils.{SerializableDirectBuffer, SerializationUtils}
+import org.apache.spark.storage.BlockManagerId
+import org.openucx.jucx.ucp._
+
+import java.net.InetSocketAddress
+import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicInteger
+import java.util.concurrent.{ConcurrentHashMap, CountDownLatch, TimeUnit}
+
+/**
+ * Executor-side UCX transport: owns `numWorkers` client workers and the map of shuffle server
+ * addresses they connect to.
+ */
+class ExternalUcxClientTransport(clientConf: ExternalUcxClientConf, blockManagerId: BlockManagerId)
+extends ExternalUcxTransport(clientConf) with UcxLogging {
+  private[spark] val executorId = blockManagerId.executorId.toLong
+  private[spark] val tcpServerPort = blockManagerId.port
+  private[spark] val ucxServerPort = clientConf.ucxServerPort
+  private[spark] val numWorkers = clientConf.numWorkers
+  private[spark] val timeoutMs = clientConf.getSparkConf.getTimeAsSeconds(
+    "spark.network.timeout", "120s") * 1000
+  private[spark] val sparkTransportConf = SparkTransportConf.fromSparkConf(
+    clientConf.getSparkConf, "ucx-shuffle", numWorkers)
+
+  private[ucx] val ucxServers = new ConcurrentHashMap[String, InetSocketAddress]
+  private[ucx] val localServerPortsDone = new CountDownLatch(1)
+  private[ucx] var localServerPortsBuffer: ByteBuffer = _
+  private[ucx] var localServerPorts: Seq[Int] = _
+
+  private[ucx] lazy val currentWorkerId = new AtomicInteger()
+  private[ucx] lazy val workerLocal = new ThreadLocal[ExternalUcxClientWorker]
+  private[ucx] var allocatedWorker: Array[ExternalUcxClientWorker] = _
+
+  private[ucx] val scheduledLatch = new CountDownLatch(1)
+
+  private var maxBlocksPerRequest = 0
+
+  override def estimateNumEps(): Int = numWorkers *
+    clientConf.sparkConf.getInt("spark.executor.instances", 1)
+
+  /**
+   * Creates the context, memory pools and workers, then asks the local shuffle server for its
+   * ports and blocks until the reply arrives. Returns the buffer stored by handleReplyAddress.
+   */
+  override def init(): ByteBuffer = {
+    initContext()
+    initMemoryPool()
+
+    if (clientConf.useWakeup) {
+      ucpWorkerParams.requestWakeupRX().requestWakeupTX().requestWakeupEdge()
+    }
+
+    logInfo(s"Allocating ${numWorkers} client workers")
+    val appId = clientConf.sparkConf.getAppId
+    allocatedWorker = new Array[ExternalUcxClientWorker](numWorkers)
+    for (i <- 0 until numWorkers) {
+      ucpWorkerParams.setClientId((executorId << 32) | i.toLong)
+      val worker = ucxContext.newWorker(ucpWorkerParams)
+      val workerId = new UcxWorkerId(appId, executorId.toInt, i)
+      allocatedWorker(i) = new ExternalUcxClientWorker(worker, this, workerId)
+    }
+
+    logInfo(s"Launching ${numWorkers} client workers")
+    allocatedWorker.foreach(_.start())
+
+    val maxAmHeaderSize = allocatedWorker(0).worker.getMaxAmHeaderSize.toInt
+    maxBlocksPerRequest = clientConf.maxBlocksPerRequest.min(
+      (maxAmHeaderSize - UnsafeUtils.INT_SIZE) / UnsafeUtils.INT_SIZE)
+
+    logInfo(s"Launching time-scheduled threads, period: $timeoutMs ms")
+    initTaskPool(1)
+    submit(() => progressTimeOut())
+
+    logInfo(s"Connecting server.")
+
+    val shuffleServer = new InetSocketAddress(blockManagerId.host, ucxServerPort)
+    allocatedWorker(0).requestAddress(shuffleServer)
+    localServerPortsDone.await()
+
+    logInfo(s"Connected server $shuffleServer")
+
+    initialized = true
+    localServerPortsBuffer
+  }
+
+  override def close(): Unit = {
+    if (initialized) {
+      running = false
+
+      scheduledLatch.countDown()
+
+      if (allocatedWorker != null) {
+        allocatedWorker.map(_.closing()).foreach(_.get(5, TimeUnit.MILLISECONDS))
+      }
+
+      super.close()
+    }
+  }
+
+  def getMaxBlocksPerRequest: Int = maxBlocksPerRequest
+
+  /** Returns the worker bound to the calling thread, binding the next one round-robin on first use. */
+  @`inline`
+  def selectWorker(): ExternalUcxClientWorker = {
+    Option(workerLocal.get).getOrElse {
+      val worker = allocatedWorker(
+        (currentWorkerId.incrementAndGet() % allocatedWorker.length).abs)
+      workerLocal.set(worker)
+      worker
+    }
+  }
+
+  @`inline`
+  def getServer(host: String): InetSocketAddress = {
+    ucxServers.computeIfAbsent(host, _ => {
+      logInfo(s"connecting $host with controller port")
+      new InetSocketAddress(host, ucxServerPort)
+    })
+  }
+
+  def connect(host: String, ports: Seq[Int]): Unit = {
+    val server = ucxServers.computeIfAbsent(host, _ => {
+      val id = executorId.toInt.abs % ports.length
+      new InetSocketAddress(host, ports(id))
+    })
+    allocatedWorker.foreach(_.connect(server))
+    logDebug(s"connect $host $server")
+  }
+
+  def connectAll(shuffleServerMap: Map[String, Seq[Int]]): Unit = {
+    val addressSet = shuffleServerMap.map(hostPorts => {
+      ucxServers.computeIfAbsent(hostPorts._1, _ => {
+        val id = executorId.toInt.abs % hostPorts._2.length
+        new InetSocketAddress(hostPorts._1, hostPorts._2(id))
+      })
+    }).toSeq
+    allocatedWorker.foreach(_.connectAll(addressSet))
+    logDebug(s"connectAll $addressSet")
+  }
+
+  def handleReplyAddress(msg: ByteBuffer): Unit = {
+    val num = msg.remaining() / UnsafeUtils.INT_SIZE
+    localServerPorts = (0 until num).map(_ => msg.getInt())
+    localServerPortsBuffer = msg
+    // NOTE(review): msg is stored after the reads above moved its position; confirm readers rewind it.
+    localServerPortsDone.countDown()
+
+    connect(blockManagerId.host, localServerPorts)
+  }
+
+  /**
+   * Fetches a batch of blocks from `host` through the worker chosen by [[selectWorker]].
+   */
+  def fetchBlocksByBlockIds(host: String, exeId: Int,
+                            blockIds: Seq[BlockId],
+                            callbacks: Seq[OperationCallback]): Unit = {
+    selectWorker().fetchBlocksByBlockIds(host, exeId, blockIds, callbacks)
+  }
+
+  def fetchBlockByStream(host: String, exeId: Int, blockId: BlockId,
+                         callback: OperationCallback): Unit = {
+    selectWorker().fetchBlockByStream(host, exeId, blockId, callback)
+  }
+
+  def progressTimeOut(): Unit = {
+    while (!scheduledLatch.await(timeoutMs, TimeUnit.MILLISECONDS)) {
+      allocatedWorker.foreach(_.progressTimeOut)
+    }
+  }
+}
+
+/** Forwards the outcome of a UCX fetch of `blockId` to a BlockFetchingListener. */
+private[shuffle] class UcxFetchCallBack(
+  blockId: String, listener: BlockFetchingListener)
+  extends OperationCallback {
+  /** Wraps the fetched memory block in a managed buffer whose `release` closes the block. */
+  override def onComplete(result: OperationResult): Unit = {
+    val block = result.getData
+    val view = UnsafeUtils.getByteBufferView(block.address, block.size.toInt)
+    val managed = new NioManagedBuffer(view) {
+      override def release: ManagedBuffer = {
+        block.close()
+        this
+      }
+    }
+    listener.onBlockFetchSuccess(blockId, managed)
+  }
+
+  override def onError(result: OperationResult): Unit =
+    listener.onBlockFetchFailure(blockId, result.getError)
+}
+
+/** Streams a fetched block into a temporary download file and reports the outcome to `listener`. */
+private[shuffle] class UcxDownloadCallBack(
+  blockId: String, listener: BlockFetchingListener,
+  downloadFileManager: DownloadFileManager,
+  transportConf: TransportConf)
+  extends OperationCallback {
+
+  private[this] val targetFile = downloadFileManager.createTempFile(transportConf)
+  private[this] val channel = targetFile.openForWriting()
+
+  /** Appends one received chunk to the temp file. */
+  override def onData(buffer: ByteBuffer): Unit = {
+    while (buffer.hasRemaining()) channel.write(buffer)
+  }
+
+  override def onComplete(result: OperationResult): Unit = {
+    listener.onBlockFetchSuccess(blockId, channel.closeAndRead())
+    if (!downloadFileManager.registerTempFileToClean(targetFile)) targetFile.delete()
+  }
+
+  /** Closes and deletes the partial temp file (it used to be leaked), then reports the failure. */
+  override def onError(result: OperationResult): Unit = {
+    try {
+      try channel.close() finally targetFile.delete()
+    } finally listener.onBlockFetchFailure(blockId, result.getError)
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalUcxClientWorker.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalUcxClientWorker.scala
new file mode 100644
index 00000000..37f80fd7
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalUcxClientWorker.scala
@@ -0,0 +1,543 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.ucx
+
+import java.io.Closeable
+import java.nio.channels.ReadableByteChannel
+import java.util.concurrent.{ConcurrentLinkedQueue, ConcurrentHashMap, Future, FutureTask}
+import java.util.concurrent.atomic.AtomicInteger
+import scala.collection.concurrent.TrieMap
+import scala.collection.JavaConverters._
+import scala.util.Random
+import org.openucx.jucx.ucp._
+import org.openucx.jucx.ucs.UcsConstants
+import org.openucx.jucx.ucs.UcsConstants.MEMORY_TYPE
+import org.openucx.jucx.{UcxCallback, UcxException, UcxUtils}
+import org.apache.spark.shuffle.ucx.memory.UcxSharedMemoryBlock
+import org.apache.spark.shuffle.ucx.utils.SerializationUtils
+import org.apache.spark.shuffle.utils.{UnsafeUtils, UcxLogging}
+import org.apache.spark.unsafe.Platform
+
+import java.nio.ByteBuffer
+import java.net.InetSocketAddress
+
+/**
+ * Worker per thread wrapper, that maintains connection and progress logic.
+ */
+case class ExternalUcxClientWorker(val worker: UcpWorker,
+ transport: ExternalUcxClientTransport,
+ workerId: UcxWorkerId)
+ extends Closeable with UcxLogging {
+ private[this] val tag = new AtomicInteger(Random.nextInt())
+ private[this] val memPool = transport.hostBounceBufferMemoryPool(workerId.workerId) // pool selected by this worker's id
+ private[this] val connectQueue = new ConcurrentLinkedQueue[InetSocketAddress] // servers waiting to be picked up by connectNext
+ private[this] val connectingServers = new ConcurrentHashMap[InetSocketAddress, (UcpEndpoint, UcpRequest)] // filled by startConnection
+ private[this] val shuffleServers = new ConcurrentHashMap[String, UcpEndpoint] // endpoints keyed by a host string
+ private[this] val executor = new UcxWorkerThread(
+ worker, transport.ucxShuffleConf.useWakeup)
+ private[this] lazy val requestData = new TrieMap[Int, UcxFetchState] // fetch state per tag
+ private[this] lazy val streamData = new TrieMap[Int, UcxStreamState] // stream state per tag
+ private[this] lazy val sliceData = new TrieMap[Int, UcxSliceState] // slice state per tag
+ private[this] var prevTag = 0
+
+ // AM receive handlers. REPLY_SLICE: one slice of a block; header = tag, remaining, length, offset.
+ worker.setAmRecvHandler(ExternalAmId.REPLY_SLICE,
+ (headerAddress: Long, headerSize: Long, ucpAmData: UcpAmData,
+ ep: UcpEndpoint) => {
+ val headerBuffer = UnsafeUtils.getByteBufferView(headerAddress,
+ headerSize.toInt)
+ val i = headerBuffer.getInt // request tag
+ val remaining = headerBuffer.getInt // value unused, but the read moves the buffer on to the length field
+ val length = headerBuffer.getLong
+ val offset = headerBuffer.getLong
+
+ handleReplySlice(i, offset, length, ucpAmData)
+ UcsConstants.STATUS.UCS_OK
+ }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG)
+ // REPLY_STREAM: one chunk of a streamed block; same header layout as REPLY_SLICE.
+ worker.setAmRecvHandler(ExternalAmId.REPLY_STREAM,
+ (headerAddress: Long, headerSize: Long, ucpAmData: UcpAmData,
+ _: UcpEndpoint) => {
+ val headerBuffer = UnsafeUtils.getByteBufferView(headerAddress,
+ headerSize.toInt)
+ val i = headerBuffer.getInt // request tag
+ val remaining = headerBuffer.getInt // value unused, but the read moves the buffer on to the length field
+ val length = headerBuffer.getLong
+ val offset = headerBuffer.getLong
+
+ handleReplyStream(i, offset, length, ucpAmData)
+ UcsConstants.STATUS.UCS_OK
+ }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG)
+ // REPLY_BLOCK: a batch of whole blocks in one message.
+ worker.setAmRecvHandler(ExternalAmId.REPLY_BLOCK,
+ (headerAddress: Long, headerSize: Long, ucpAmData: UcpAmData, _: UcpEndpoint) => {
+ val headerBuffer = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt)
+ val i = headerBuffer.getInt
+ // Header contains tag followed by sizes of blocks
+ val numBlocks = headerBuffer.remaining() / UnsafeUtils.INT_SIZE
+ val blockSizes = (0 until numBlocks).map(_ => headerBuffer.getInt())
+
+ handleReplyBlock(i, blockSizes, ucpAmData)
+ if (ucpAmData.isDataValid) { // handleReplyBlock keeps using the data in this case, so report in-progress
+ UcsConstants.STATUS.UCS_INPROGRESS
+ } else {
+ UcsConstants.STATUS.UCS_OK
+ }
+ }, UcpConstants.UCP_AM_FLAG_PERSISTENT_DATA | UcpConstants.UCP_AM_FLAG_WHOLE_MSG)
+ // REPLY_ADDRESS: the whole payload travels in the header and is copied before being handed on.
+ worker.setAmRecvHandler(ExternalAmId.REPLY_ADDRESS,
+ (headerAddress: Long, headerSize: Long, _: UcpAmData, _: UcpEndpoint) => {
+ val headerBuffer = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt)
+ val copiedBuffer = ByteBuffer.allocateDirect(headerBuffer.remaining())
+
+ copiedBuffer.put(headerBuffer)
+ copiedBuffer.rewind() // back to position 0 so the transport reads from the start
+
+ transport.handleReplyAddress(copiedBuffer)
+ UcsConstants.STATUS.UCS_OK
+ }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG)
+
+ private def handleReplySlice(i: Int, offset: Long, length: Long,
+ ucpAmData: UcpAmData): Unit = {
+ assert(!ucpAmData.isDataValid)
+ // The first slice seen for a tag allocates one buffer of the full `length`; later slices reuse it.
+ val sliceState = sliceData.getOrElseUpdate(i, {
+ requestData.get(i) match {
+ case Some(data) => {
+ val mem = memPool.get(length)
+ val memRef = new UcxRefCountMemoryBlock(mem, 0, length,
+ new AtomicInteger(1))
+ new UcxSliceState(data.callbacks(0), data.request, data.timestamp,
+ memRef, length)
+ }
+ case None => throw new UcxException(s"Slice tag $i context not found.")
+ }
+ })
+
+ val stats = sliceState.request.getStats.get.asInstanceOf[UcxStats]
+ stats.receiveSize += ucpAmData.getLength
+ stats.amHandleTime = System.nanoTime()
+ // Receive this slice straight into its place (`offset`) inside the block buffer.
+ val currentAddress = sliceState.mem.address + offset
+ worker.recvAmDataNonBlocking(
+ ucpAmData.getDataHandle, currentAddress, ucpAmData.getLength,
+ new UcxCallback() {
+ override def onSuccess(r: UcpRequest): Unit = {
+ stats.endTime = System.nanoTime()
+ logDebug(s"Slice receive rndv data size ${ucpAmData.getLength} " +
+ s"tag $i ($offset, $length) in ${stats.getElapsedTimeNs} ns " +
+ s"amHandle ${stats.endTime - stats.amHandleTime} ns")
+ receivedSlice(i, offset, length, ucpAmData.getLength, sliceState)
+ }
+ }, UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
+ }
+
+ /**
+  * Handles one REPLY_STREAM message: receives its rendezvous payload into a pooled buffer and
+  * passes that buffer to receivedStream once the receive has completed.
+  */
+ private def handleReplyStream(i: Int, offset: Long, length: Long,
+                               ucpAmData: UcpAmData): Unit = {
+   assert(!ucpAmData.isDataValid)
+
+   // Resolve the stream state before taking a pool buffer, so an unknown tag cannot leak one.
+   val streamState = streamData.getOrElse(i,
+     throw new UcxException(s"Stream tag $i context not found."))
+   if (streamState.remaining == Long.MaxValue) {
+     streamState.remaining = length
+   }
+
+   val mem = memPool.get(ucpAmData.getLength)
+   val memRef = new UcxRefCountMemoryBlock(mem, 0, ucpAmData.getLength,
+     new AtomicInteger(1))
+   val stats = streamState.request.getStats.get.asInstanceOf[UcxStats]
+   stats.receiveSize += ucpAmData.getLength
+   stats.amHandleTime = System.nanoTime()
+   worker.recvAmDataNonBlocking(
+     ucpAmData.getDataHandle, memRef.address, ucpAmData.getLength,
+     new UcxCallback() {
+       override def onSuccess(r: UcpRequest): Unit = {
+         stats.endTime = System.nanoTime()
+         logDebug(s"Stream receive rndv data ${ucpAmData.getLength} " +
+           s"tag $i ($offset, $length) in ${stats.getElapsedTimeNs} ns " +
+           s"amHandle ${stats.endTime - stats.amHandleTime} ns")
+         receivedStream(i, offset, length, memRef, streamState)
+       }
+     }, UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
+ }
+
+ private def handleReplyBlock( // splits one REPLY_BLOCK payload into per-block results
+ i: Int, blockSizes: Seq[Int], ucpAmData: UcpAmData): Unit = {
+ val data = requestData.get(i)
+
+ if (data.isEmpty) {
+ throw new UcxException(s"No data for tag $i.")
+ }
+
+ val fetchState = data.get
+ val callbacks = fetchState.callbacks
+ val request = fetchState.request
+ val stats = request.getStats.get.asInstanceOf[UcxStats]
+ stats.receiveSize = ucpAmData.getLength
+
+ val numBlocks = blockSizes.length
+
+ var offset = 0
+ val refCounts = new AtomicInteger(numBlocks) // one shared counter, initialised to the number of blocks
+ if (ucpAmData.isDataValid) { // payload is already available in the AM data
+ request.completed = true
+ stats.endTime = System.nanoTime()
+ logDebug(s"Received amData: $ucpAmData for tag $i " +
+ s"in ${stats.getElapsedTimeNs} ns")
+
+ val closeCb = () => executor.post(() => ucpAmData.close()) // closing the AM data is posted to the worker thread
+ val address = ucpAmData.getDataAddress
+ for (b <- 0 until numBlocks) {
+ val blockSize = blockSizes(b)
+ if (callbacks(b) != null) { // NOTE(review): offset only advances for non-null callbacks; confirm skipped blocks carry no bytes
+ val mem = new UcxSharedMemoryBlock(closeCb, refCounts, address + offset,
+ blockSize)
+ receivedChunk(i, b, mem, fetchState)
+ offset += blockSize
+ }
+ }
+ } else { // otherwise receive the payload into a pooled buffer first
+ val mem = memPool.get(ucpAmData.getLength)
+ stats.amHandleTime = System.nanoTime()
+ request.setRequest(worker.recvAmDataNonBlocking(ucpAmData.getDataHandle, mem.address, ucpAmData.getLength,
+ new UcxCallback() {
+ override def onSuccess(r: UcpRequest): Unit = {
+ request.completed = true
+ stats.endTime = System.nanoTime()
+ logDebug(s"Received rndv data of size: ${ucpAmData.getLength} " +
+ s"for tag $i in ${stats.getElapsedTimeNs} ns " +
+ s"time from amHandle: ${System.nanoTime() - stats.amHandleTime} ns")
+ for (b <- 0 until numBlocks) { // unlike the branch above, callbacks are not null-checked here
+ val blockSize = blockSizes(b)
+ val memRef = new UcxRefCountMemoryBlock(mem, offset, blockSize, refCounts)
+ receivedChunk(i, b, memRef, fetchState)
+ offset += blockSize
+ }
+
+ }
+ }, UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST))
+ }
+ }
+
+ /**
+  * Accounts for one received slice. When `remaining` reaches zero the slice callback is completed
+  * and the last chunk of the owning fetch is marked as received.
+  */
+ private def receivedSlice(tag: Int, offset: Long, length: Long, msgLen: Long,
+                           sliceState: UcxSliceState): Unit = {
+   val firstSeen = sliceState.recvSet.add(offset)
+   if (!firstSeen) {
+     logWarning(s"Received duplicate slice: tag $tag offset $offset")
+   } else {
+     logTrace(s"tag $tag $sliceState")
+     sliceState.remaining -= msgLen
+     if (sliceState.remaining == 0) {
+       val sliceStats = sliceState.request.getStats.get
+       val done = new UcxSucceedOperationResult(sliceState.mem, sliceStats)
+       sliceState.callback.onComplete(done)
+       sliceData.remove(tag)
+
+       requestData.get(tag) match {
+         case None =>
+           logWarning(s"No fetch found: tag $tag")
+         case Some(fetchState) =>
+           // block must be the last chunk
+           receivedChunk(tag, fetchState.callbacks.size - 1, null, fetchState)
+       }
+     }
+   }
+ }
+
+ /**
+  * Stores a received stream chunk under its offset, then feeds the stream callback every chunk
+  * that is contiguous with the data already delivered. Completes the callback and drops the
+  * stream state once nothing remains.
+  */
+ private def receivedStream(tag: Int, offset: Long, length: Long, mem: MemoryBlock,
+                            streamState: UcxStreamState): Unit = {
+   if (streamState.recvMap.put(offset, mem) != null) {
+     logWarning(s"Received duplicate stream: tag $tag offset $offset")
+   } else {
+     logTrace(s"tag $tag $streamState")
+     // `delivered` is the stream offset up to which data was already handed to the callback.
+     var delivered = length - streamState.remaining
+     var next = streamState.recvMap.get(delivered)
+     while (next != null) {
+       streamState.callback.onData(UnsafeUtils.getByteBufferView(next.address, next.size.toInt))
+       delivered += next.size
+       next.close()
+       next = streamState.recvMap.get(delivered)
+     }
+     streamState.remaining = length - delivered
+     if (streamState.remaining == 0) {
+       val done = new UcxSucceedOperationResult(null, streamState.request.getStats.get)
+       streamState.callback.onComplete(done)
+       streamData.remove(tag)
+     }
+   }
+ }
+
+ /**
+  * Records chunk `chunkId` of fetch `tag` as received and, when `mem` is non-null, completes that
+  * chunk's callback with it. The fetch state is removed once as many chunks as callbacks were seen.
+  */
+ private def receivedChunk(tag: Int, chunkId: Int, mem: MemoryBlock,
+                           fetchState: UcxFetchState): Unit = {
+   if (fetchState.recvSet.add(chunkId)) {
+     Option(mem).foreach { block =>
+       val done = new UcxSucceedOperationResult(block, fetchState.request.getStats.get)
+       fetchState.callbacks(chunkId).onComplete(done)
+     }
+
+     logTrace(s"tag $tag $fetchState")
+     if (fetchState.recvSet.size == fetchState.callbacks.size) {
+       requestData.remove(tag)
+     }
+   } else {
+     logWarning(s"Received duplicate chunk: tag $tag chunk $chunkId")
+   }
+ }
+
+ // Starts the UcxWorkerThread that progresses this worker and runs its posted tasks.
+ def start(): Unit =
+   executor.start()
+
+ override def close(): Unit = { // force-closes the endpoints and progresses until the close requests complete
+ val closeConnecting = connectingServers.values.asScala.filterNot {
+ case (_, req) => req.isCompleted // keep only connections whose connect request is still pending
+ }.map {
+ case (endpoint, _) => endpoint.closeNonBlockingForce()
+ }
+ val closeRequests = shuffleServers.asScala.map { // NOTE(review): an endpoint may also be in connectingServers; confirm a second close is harmless
+ case (_, endpoint) => endpoint.closeNonBlockingForce()
+ }
+ while (!closeConnecting.forall(_.isCompleted)) {
+ progress()
+ }
+ while (!closeRequests.forall(_.isCompleted)) {
+ progress()
+ }
+ }
+
+ /** Wraps close() in a FutureTask, passes it to `executor.close` and returns it as a future. */
+ def closing(): Future[Unit.type] = {
+   val cleanup: Runnable = new Runnable { override def run(): Unit = close() }
+   val cleanTask = new FutureTask(cleanup, Unit)
+   executor.close(cleanTask)
+   cleanTask
+ }
+
+ /**
+  * Progresses the worker once while holding the worker's monitor and returns the progress count.
+  * Note that the UcxWorkerThread loop progresses the same worker too, without taking the monitor.
+  */
+ def progress(): Int =
+   worker.synchronized { worker.progress() }
+
+ // Posts a task that sends an ADDRESS message to localServer, caching the endpoint under "0.0.0.0".
+ @`inline`
+ def requestAddress(localServer: InetSocketAddress): Unit =
+   executor.post { () =>
+     shuffleServers.computeIfAbsent("0.0.0.0", _ => doConnect(localServer, ExternalAmId.ADDRESS)._1)
+   }
+
+ // Queues the server; the queue is consumed by connectNext.
+ @`inline`
+ def connect(shuffleServer: InetSocketAddress): Unit =
+   connectQueue.add(shuffleServer)
+
+ // Queues every address, in order; the queue is consumed by connectNext.
+ @`inline`
+ def connectAll(addressSet: Seq[InetSocketAddress]): Unit =
+   for (address <- addressSet) connectQueue.add(address)
+
+ /**
+  * Posts a task that connects to the next queued server, if any; failures are only logged because the server may be down or rebooting.
+  */
+ @`inline`
+ private def connectNext(): Unit = if (!connectQueue.isEmpty) {
+   executor.post(() => {
+     // poll() yields null when another posted task emptied the queue after the check above.
+     Option(connectQueue.poll()).foreach { shuffleServer =>
+       try startConnection(shuffleServer) catch {
+         case scala.util.control.NonFatal(e) =>
+           logWarning(s"Endpoint to $shuffleServer got an error: $e")
+       }
+     }
+   })
+ }
+
+ /**
+  * Creates an endpoint to `shuffleServer` and sends it an active message `amId`
+  * whose header is the serialized `workerId` and whose payload is this worker's
+  * UCP address. `connectNext()` is invoked when the send completes, successfully
+  * or not.
+  *
+  * NOTE(review): the endpoint error handler removes `shuffleServer.getHostName()`
+  * from `shuffleServers`, while `getConnection` stores endpoints under
+  * `getAddress().getHostAddress()` — confirm the two keys are meant to match.
+  *
+  * @param shuffleServer socket address of the server to connect to
+  * @param amId          active message id used for the handshake
+  * @return the new endpoint together with the pending send request
+  */
+ private def doConnect(shuffleServer: InetSocketAddress,
+                       amId: Int): (UcpEndpoint, UcpRequest) = {
+   val onEndpointError = new UcpEndpointErrorHandler() {
+     override def onError(ep: UcpEndpoint, status: Int, errorMsg: String): Unit = {
+       logError(s"Endpoint to $shuffleServer got an error: $errorMsg")
+       shuffleServers.remove(shuffleServer.getHostName())
+     }
+   }
+   val params = new UcpEndpointParams()
+     .setPeerErrorHandlingMode()
+     .setSocketAddress(shuffleServer)
+     .sendClientId()
+     .setErrorHandler(onEndpointError)
+     .setName(s"Client to $shuffleServer")
+
+   logDebug(s"$workerId connecting to external service $shuffleServer")
+
+   // Handshake header: the serialized id of this worker.
+   val idBuffer = Platform.allocateDirectBuffer(workerId.serializedSize)
+   workerId.serialize(idBuffer)
+   idBuffer.rewind()
+   val localAddress = worker.getAddress
+
+   val endpoint = worker.newEndpoint(params)
+   val sendDone = new UcxCallback() {
+     override def onSuccess(request: UcpRequest): Unit = {
+       connectNext()
+     }
+     override def onError(ucsStatus: Int, errorMsg: String): Unit = {
+       logError(s"$workerId Sent connect to $shuffleServer failed: $errorMsg");
+       connectNext()
+       idBuffer.clear()
+       localAddress.clear()
+     }
+   }
+   val sendRequest = endpoint.sendAmNonBlocking(
+     amId, UcxUtils.getAddress(idBuffer), idBuffer.remaining(),
+     UcxUtils.getAddress(localAddress), localAddress.remaining(),
+     UcpConstants.UCP_AM_SEND_FLAG_EAGER | UcpConstants.UCP_AM_SEND_FLAG_REPLY,
+     sendDone, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
+   (endpoint, sendRequest)
+ }
+
+ /**
+  * Returns the (endpoint, request) pair recorded for `shuffleServer` in
+  * `connectingServers`, starting a `ExternalAmId.CONNECT` handshake first when
+  * no pair is recorded yet.
+  */
+ private def startConnection(shuffleServer: InetSocketAddress): (UcpEndpoint, UcpRequest) = {
+   connectingServers.computeIfAbsent(
+     shuffleServer, server => doConnect(server, ExternalAmId.CONNECT))
+ }
+
+ /**
+  * Returns the endpoint for `host`, connecting first if none is cached in
+  * `shuffleServers` (keyed by the server's host address).
+  *
+  * The wait for the connect request is done outside of any `computeIfAbsent`
+  * call: progressing the worker can run callbacks that modify `shuffleServers`,
+  * which is not allowed from inside a `ConcurrentHashMap` mapping function.
+  * Progress goes through `progress()` like everywhere else in this class.
+  *
+  * @param host host name understood by `transport.getServer`
+  * @return a connected endpoint to the shuffle server on `host`
+  * @throws UcxException if the connect request does not complete within
+  *                      `transport.timeoutMs`
+  */
+ private def getConnection(host: String): UcpEndpoint = {
+   val shuffleServer = transport.getServer(host)
+   val key = shuffleServer.getAddress().getHostAddress()
+   val cached = shuffleServers.get(key)
+   if (cached != null) {
+     cached
+   } else {
+     val (ep, req) = startConnection(shuffleServer)
+     val deadline = System.currentTimeMillis() + transport.timeoutMs
+     while (!req.isCompleted) {
+       progress()
+       if (System.currentTimeMillis() > deadline) {
+         throw new UcxException(s"connect $shuffleServer timeout")
+       }
+     }
+     // Keep an endpoint that a callback may have registered while we were waiting.
+     val raced = shuffleServers.putIfAbsent(key, ep)
+     if (raced != null) raced else ep
+   }
+ }
+
+ /**
+  * Sends a `ExternalAmId.FETCH_BLOCK` request for `blockIds` to the shuffle
+  * server on `host`.
+  *
+  * The message header is the serialized `workerId`, a fresh tag and `execId`;
+  * the payload is the serialized block ids. A `UcxFetchState` holding
+  * `callbacks` is stored in `requestData` under the tag before the send is
+  * posted to the executor.
+  *
+  * @param host      host of the shuffle server
+  * @param execId    executor id written into the request header
+  * @param blockIds  blocks to fetch
+  * @param callbacks callbacks stored with the fetch state
+  */
+ def fetchBlocksByBlockIds(host: String, execId: Int, blockIds: Seq[BlockId],
+                           callbacks: Seq[OperationCallback]): Unit = {
+   val startTime = System.nanoTime()
+   val headerSize = workerId.serializedSize + 2 * UnsafeUtils.INT_SIZE
+
+   val fetchTag = tag.incrementAndGet()
+
+   val payloadSize = blockIds.map(_.serializedSize).sum
+   val buffer = Platform.allocateDirectBuffer(headerSize + payloadSize)
+   workerId.serialize(buffer)
+   buffer.putInt(fetchTag)
+   buffer.putInt(execId)
+   blockIds.foreach(_.serialize(buffer))
+
+   val fetchState = new UcxFetchState(callbacks, new UcxRequest(null, new UcxStats()), startTime)
+   requestData.put(fetchTag, fetchState)
+
+   buffer.rewind()
+   val headerAddress = UnsafeUtils.getAdress(buffer)
+   val payloadAddress = headerAddress + headerSize
+
+   executor.post(() => {
+     val sendDone = new UcxCallback() {
+       override def onSuccess(request: UcpRequest): Unit = {
+         buffer.clear()
+         logDebug(s"Sent fetch to $host tag $fetchTag blocks ${blockIds.length} " +
+           s"in ${System.nanoTime() - startTime} ns")
+       }
+       override def onError(ucsStatus: Int, errorMsg: String): Unit = {
+         val err = s"Sent fetch to $host tag $fetchTag failed: $errorMsg";
+         logError(err)
+         throw new UcxException(err)
+       }
+     }
+     getConnection(host).sendAmNonBlocking(ExternalAmId.FETCH_BLOCK, headerAddress,
+       headerSize, payloadAddress, buffer.capacity() - headerSize,
+       UcpConstants.UCP_AM_SEND_FLAG_EAGER, sendDone, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
+   })
+ }
+
+ /**
+  * Sends a `ExternalAmId.FETCH_STREAM` request for a single block to the shuffle
+  * server on `host`.
+  *
+  * The whole request travels in the message header: the serialized `workerId`,
+  * a fresh tag, `execId` and the serialized `blockId`; the payload is empty.
+  * A `UcxStreamState` holding `callback` is stored in `streamData` under the tag
+  * before the send is posted to the executor.
+  *
+  * @param host     host of the shuffle server
+  * @param execId   executor id written into the request header
+  * @param blockId  block to fetch
+  * @param callback callback stored with the stream state
+  */
+ def fetchBlockByStream(host: String, execId: Int, blockId: BlockId,
+                        callback: OperationCallback): Unit = {
+   val startTime = System.nanoTime()
+   val headerSize = workerId.serializedSize + 2 * UnsafeUtils.INT_SIZE +
+     blockId.serializedSize
+
+   val streamTag = tag.incrementAndGet()
+
+   val buffer = Platform.allocateDirectBuffer(headerSize)
+   workerId.serialize(buffer)
+   buffer.putInt(streamTag)
+   buffer.putInt(execId)
+   blockId.serialize(buffer)
+
+   val streamState = new UcxStreamState(
+     callback, new UcxRequest(null, new UcxStats()), startTime, Long.MaxValue)
+   streamData.put(streamTag, streamState)
+
+   val headerAddress = UnsafeUtils.getAdress(buffer)
+
+   executor.post(() => {
+     val sendDone = new UcxCallback() {
+       override def onSuccess(request: UcpRequest): Unit = {
+         buffer.clear()
+         logDebug(s"Sent stream to $host tag $streamTag block $blockId " +
+           s"in ${System.nanoTime() - startTime} ns")
+       }
+       override def onError(ucsStatus: Int, errorMsg: String): Unit = {
+         val err = s"Sent stream to $host tag $streamTag failed: $errorMsg";
+         logError(err)
+         throw new UcxException(err)
+       }
+     }
+     getConnection(host).sendAmNonBlocking(ExternalAmId.FETCH_STREAM, headerAddress,
+       headerSize, headerAddress, 0, UcpConstants.UCP_AM_SEND_FLAG_EAGER,
+       sendDone, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
+   })
+ }
+
+ /**
+  * Fails requests that have been outstanding longer than `transport.timeoutMs`.
+  *
+  * Nothing is expired on a call where `tag` changed since the previous call;
+  * the new value is only remembered. Otherwise every entry of `requestData`,
+  * `streamData` and `sliceData` whose timestamp is older than the timeout is
+  * removed and its callback(s) receive a "timeout" failure.
+  */
+ def progressTimeOut(): Unit = {
+   val currTag = tag.get()
+   if (prevTag != currTag) {
+     prevTag = currTag
+   } else {
+     val validTime = System.nanoTime - transport.timeoutMs * 1000000L
+
+     if (requestData.nonEmpty) {
+       val expired = requestData.collect {
+         case (id, state) if state.timestamp < validTime => id
+       }
+       expired.foreach { id =>
+         requestData.remove(id).foreach { state =>
+           state.callbacks.foreach(_.onError(new UcxFailureOperationResult("timeout")))
+         }
+       }
+     }
+     if (streamData.nonEmpty) {
+       val expired = streamData.collect {
+         case (id, state) if state.timestamp < validTime => id
+       }
+       expired.foreach { id =>
+         streamData.remove(id).foreach { state =>
+           state.callback.onError(new UcxFailureOperationResult("timeout"))
+         }
+       }
+     }
+     if (sliceData.nonEmpty) {
+       val expired = sliceData.collect {
+         case (id, state) if state.timestamp < validTime => id
+       }
+       expired.foreach { id =>
+         sliceData.remove(id).foreach { state =>
+           state.callback.onError(new UcxFailureOperationResult("timeout"))
+         }
+       }
+     }
+   }
+ }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/server/ExternalUcxServerConf.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/server/ExternalUcxServerConf.scala
new file mode 100644
index 00000000..398236d5
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/server/ExternalUcxServerConf.scala
@@ -0,0 +1,87 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.ucx
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.spark.internal.config.ConfigBuilder
+import org.apache.spark.network.util.ByteUnit
+import org.apache.spark.SparkConf
+
+/**
+ * Configuration of the external UCX shuffle server.
+ *
+ * Every setting is read lazily from the given Hadoop `Configuration`; when a key
+ * is absent the matching default declared on `ExternalUcxConf` (or on the
+ * `ExternalUcxServerConf` companion for `ucxEpsNum`) is used.
+ *
+ * @param yarnConf configuration the settings are read from
+ */
+class ExternalUcxServerConf(val yarnConf: Configuration) extends ExternalUcxConf {
+  // Thin typed accessors over yarnConf, so each setting below is a single line.
+  private def yarnInt(key: String, default: Int): Int = yarnConf.getInt(key, default)
+
+  private def yarnLong(key: String, default: Long): Long = yarnConf.getLong(key, default)
+
+  private def yarnBoolean(key: String, default: Boolean): Boolean =
+    yarnConf.getBoolean(key, default)
+
+  override lazy val preallocateBuffersMap: Map[Long, Int] = {
+    val raw = yarnConf.get(
+      ExternalUcxConf.PREALLOCATE_BUFFERS_KEY, ExternalUcxConf.PREALLOCATE_BUFFERS_DEFAULT)
+    ExternalUcxConf.preAllocateConfToMap(raw)
+  }
+
+  override lazy val memoryLimit: Boolean =
+    yarnBoolean(ExternalUcxConf.MEMORY_LIMIT_KEY, ExternalUcxConf.MEMORY_LIMIT_DEFAULT)
+
+  override lazy val memoryGroupSize: Int =
+    yarnInt(ExternalUcxConf.MEMORY_GROUP_SIZE_KEY, ExternalUcxConf.MEMORY_GROUP_SIZE_DEFAULT)
+
+  override lazy val minBufferSize: Long =
+    yarnLong(ExternalUcxConf.MIN_BUFFER_SIZE_KEY, ExternalUcxConf.MIN_BUFFER_SIZE_DEFAULT)
+
+  override lazy val maxBufferSize: Long =
+    yarnLong(ExternalUcxConf.MAX_BUFFER_SIZE_KEY, ExternalUcxConf.MAX_BUFFER_SIZE_DEFAULT)
+
+  override lazy val minRegistrationSize: Long =
+    yarnLong(ExternalUcxConf.MIN_REGISTRATION_SIZE_KEY,
+      ExternalUcxConf.MIN_REGISTRATION_SIZE_DEFAULT)
+
+  override lazy val maxRegistrationSize: Long =
+    yarnLong(ExternalUcxConf.MAX_REGISTRATION_SIZE_KEY,
+      ExternalUcxConf.MAX_REGISTRATION_SIZE_DEFAULT)
+
+  override lazy val numPools: Int =
+    yarnInt(ExternalUcxConf.NUM_POOLS_KEY, ExternalUcxConf.NUM_POOLS_DEFAULT)
+
+  override lazy val listenerAddress: String =
+    yarnConf.get(ExternalUcxConf.SOCKADDR_KEY, ExternalUcxConf.SOCKADDR_DEFAULT)
+
+  override lazy val useWakeup: Boolean =
+    yarnBoolean(ExternalUcxConf.WAKEUP_FEATURE_KEY, ExternalUcxConf.WAKEUP_FEATURE_DEFAULT)
+
+  // The compat key is consulted only as the fallback of the primary key.
+  override lazy val numWorkers: Int =
+    yarnInt(ExternalUcxConf.NUM_WORKERS_KEY,
+      yarnInt(ExternalUcxConf.NUM_WORKERS_COMPAT_KEY, ExternalUcxConf.NUM_WORKERS_DEFAULT))
+
+  // The compat key is consulted only as the fallback of the primary key.
+  override lazy val numThreads: Int =
+    yarnInt(ExternalUcxConf.NUM_THREADS_KEY,
+      yarnInt(ExternalUcxConf.NUM_THREADS_COMPAT_KEY, ExternalUcxConf.NUM_THREADS_DEFAULT))
+
+  override lazy val ucxServerPort: Int =
+    yarnInt(ExternalUcxConf.SPARK_UCX_SHUFFLE_SERVICE_PORT_KEY,
+      ExternalUcxConf.SPARK_UCX_SHUFFLE_SERVICE_PORT_DEFAULT)
+
+  override lazy val maxReplySize: Long =
+    yarnLong(ExternalUcxConf.MAX_REPLY_SIZE_KEY, ExternalUcxConf.MAX_REPLY_SIZE_DEFAULT)
+
+  lazy val ucxEpsNum: Int =
+    yarnInt(ExternalUcxServerConf.SPARK_UCX_SHUFFLE_EPS_NUM_KEY,
+      ExternalUcxServerConf.SPARK_UCX_SHUFFLE_EPS_NUM_DEFAULT)
+}
+
+/** Configuration keys and defaults that exist only on the server side. */
+object ExternalUcxServerConf {
+  /** Key built from the "service.tcp.port" suffix. */
+  lazy val SPARK_UCX_SHUFFLE_SERVICE_TCP_PORT_KEY: String =
+    ExternalUcxConf.getUcxConf("service.tcp.port")
+
+  /** Key built from the "eps.num" suffix; read by `ExternalUcxServerConf.ucxEpsNum`. */
+  lazy val SPARK_UCX_SHUFFLE_EPS_NUM_KEY: String =
+    ExternalUcxConf.getUcxConf("eps.num")
+
+  /** Default for `SPARK_UCX_SHUFFLE_EPS_NUM_KEY`: 2^24 = 16777216. */
+  lazy val SPARK_UCX_SHUFFLE_EPS_NUM_DEFAULT = 1 << 24
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/server/ExternalUcxServerTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/server/ExternalUcxServerTransport.scala
new file mode 100644
index 00000000..ae12cc6f
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/server/ExternalUcxServerTransport.scala
@@ -0,0 +1,206 @@
+package org.apache.spark.shuffle.ucx
+
+// import org.apache.spark.SparkEnv
+import org.apache.spark.shuffle.utils.{UcxLogging, UnsafeUtils}
+import org.apache.spark.shuffle.ucx.utils.SerializationUtils
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer
+import org.apache.spark.network.shuffle.ExternalUcxShuffleBlockResolver
+import org.openucx.jucx.ucp._
+import org.openucx.jucx.ucs.UcsConstants
+import org.openucx.jucx.ucs.UcsConstants.MEMORY_TYPE
+import org.openucx.jucx.{UcxCallback, UcxException, UcxUtils}
+
+import java.net.InetSocketAddress
+import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicInteger
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit, CountDownLatch}
+import java.nio.channels.FileChannel
+import java.nio.file.StandardOpenOption
+import scala.collection.concurrent.TrieMap
+
+/**
+ * Server side transport of the external UCX shuffle service.
+ *
+ * Owns one "global" worker listening on `serverConf.ucxServerPort` and
+ * `serverConf.numWorkers` additional server workers (both created in `init`),
+ * and resolves shuffle blocks through `blockManager`.
+ *
+ * @param serverConf   server configuration
+ * @param blockManager resolver used to locate shuffle block data
+ */
+class ExternalUcxServerTransport(
+ serverConf: ExternalUcxServerConf, blockManager: ExternalUcxShuffleBlockResolver)
+ extends ExternalUcxTransport(serverConf)
+ with UcxLogging {
+ // appId -> client worker ids registered through handleConnect (the Unit value is unused).
+ private[ucx] val workerMap = new TrieMap[String, TrieMap[UcxWorkerId, Unit]]
+ // appId -> (shuffle id, map id) -> file channel opened by openBlock.
+ private[ucx] val fileMap = new TrieMap[String, ConcurrentHashMap[UcxShuffleMapId, FileChannel]]
+
+ // Server workers created in init(); null until then.
+ private[ucx] var allocatedWorker: Array[ExternalUcxServerWorker] = _
+ // Worker created in init() with serverConf.ucxServerPort; null until then.
+ private[ucx] var globalWorker: ExternalUcxServerWorker = _
+ // Listener ports of allocatedWorker, in the same order.
+ private[ucx] var serverPorts: Seq[Int] = _
+
+ // Counted down in close() to end the loop in scheduledReport().
+ private[ucx] val scheduledLatch = new CountDownLatch(1)
+
+ // Computed in init(); 0 before that.
+ private[ucx] var maxReplySize: Long = 0
+
+ // serverPorts encoded as consecutive ints in a direct buffer (filled in init()).
+ private var serverPortsBuffer: ByteBuffer = _
+
+ /** Number of endpoints to plan for, taken from `serverConf.ucxEpsNum`. */
+ override def estimateNumEps(): Int = {
+   serverConf.ucxEpsNum
+ }
+
+ /**
+  * Initializes the UCX context, memory pools and task pool, creates the global
+  * worker and `serverConf.numWorkers` server workers, and starts them all.
+  *
+  * `maxReplySize` is computed before the first `ExternalUcxServerWorker` is
+  * constructed, because each worker reads `getMaxReplySize()` once in its
+  * constructor; computing it later would leave the global worker with 0.
+  *
+  * @return the serialized address of the global worker's listener
+  */
+ override def init(): ByteBuffer = {
+   initContext()
+   initMemoryPool()
+
+   if (serverConf.useWakeup) {
+     ucpWorkerParams.requestWakeupRX().requestWakeupTX().requestWakeupEdge()
+   }
+
+   // additional 1 for mem pool report
+   initTaskPool(serverConf.numThreads + 1)
+   submit(() => scheduledReport())
+
+   logInfo(s"Allocating global worker")
+   val listenerWorker = ucxContext.newWorker(ucpWorkerParams)
+
+   // A reply must fit into one pooled buffer together with its AM header.
+   val maxAmHeaderSize = listenerWorker.getMaxAmHeaderSize
+   maxReplySize = serverConf.maxReplySize.min(serverConf.maxBufferSize -
+     maxAmHeaderSize)
+
+   globalWorker = new ExternalUcxServerWorker(
+     listenerWorker, this, new UcxWorkerId("Listener", 0, 0), serverConf.ucxServerPort)
+
+   logInfo(s"Allocating ${serverConf.numWorkers} server workers")
+
+   allocatedWorker = new Array[ExternalUcxServerWorker](serverConf.numWorkers)
+   for (i <- 0 until serverConf.numWorkers) {
+     val worker = ucxContext.newWorker(ucpWorkerParams)
+     val workerId = new UcxWorkerId("Server", 0, i)
+     // Port 0 is passed for these workers; the actual port is read back via getPort below.
+     allocatedWorker(i) = new ExternalUcxServerWorker(worker, this, workerId, 0)
+   }
+   serverPorts = allocatedWorker.map(_.getPort)
+
+   serverPortsBuffer = ByteBuffer.allocateDirect(
+     serverPorts.length * UnsafeUtils.INT_SIZE)
+   serverPorts.foreach(serverPortsBuffer.putInt(_))
+   serverPortsBuffer.rewind()
+
+   logInfo(s"Launching ${serverConf.numWorkers} server workers")
+   allocatedWorker.foreach(_.start())
+
+   logInfo(s"Launching global worker")
+
+   globalWorker.start()
+
+   initialized = true
+   logInfo(s"Started listener on ${globalWorker.getAddress} ${serverPorts} maxReplySize $maxReplySize")
+   SerializationUtils.serializeInetAddress(globalWorker.getAddress)
+ }
+
+ /**
+  * Close all transport resources.
+  *
+  * Does nothing unless `init()` completed. Asks the global worker and then all
+  * server workers to close, waiting 1 ms respectively 5 ms for each; the waits
+  * throw (e.g. `TimeoutException`) when a worker does not finish in time.
+  * `scheduledLatch` is released even when such a wait throws, so the periodic
+  * report task always terminates; the exception itself still propagates and
+  * `super.close()` is only reached when all workers closed in time.
+  */
+ override def close(): Unit = {
+   if (initialized) {
+     running = false
+
+     try {
+       if (globalWorker != null) {
+         globalWorker.closing().get(1, TimeUnit.MILLISECONDS)
+       }
+
+       if (allocatedWorker != null) {
+         // Start closing every worker first, then wait for each of them.
+         allocatedWorker.map(_.closing()).foreach(_.get(5, TimeUnit.MILLISECONDS))
+       }
+     } finally {
+       scheduledLatch.countDown()
+     }
+
+     super.close()
+
+     logInfo("UCX transport closed.")
+   }
+ }
+
+ /**
+  * Forgets everything recorded for `appId`: disconnects its registered client
+  * workers on every server worker and on the global worker, and closes the
+  * file channels that were opened for it.
+  */
+ def applicationRemoved(appId: String): Unit = {
+   workerMap.remove(appId) match {
+     case Some(clients) =>
+       val clientIds = clients.keys.toSeq
+       for (serverWorker <- allocatedWorker) {
+         serverWorker.disconnect(clientIds)
+       }
+       globalWorker.disconnect(clientIds)
+     case None =>
+   }
+   for (files <- fileMap.remove(appId)) {
+     files.values.forEach(channel => channel.close())
+   }
+   // allocatedWorker.foreach(_.debugClients())
+ }
+
+ /**
+  * Disconnects all client workers of application `appId` that belong to the
+  * executor `executorId`, on every server worker and on the global worker.
+  *
+  * Client workers are matched by their numeric `exeId`; an `executorId` that is
+  * not an integer cannot match any of them, so it is logged and ignored instead
+  * of failing with a `NumberFormatException`.
+  */
+ def executorRemoved(executorId: String, appId: String): Unit = {
+   scala.util.Try(executorId.toInt).toOption match {
+     case Some(exeId) =>
+       workerMap.get(appId).foreach { clients =>
+         val clientIds = clients.keys.filter(_.exeId == exeId).toSeq
+         allocatedWorker.foreach(_.disconnect(clientIds))
+         globalWorker.disconnect(clientIds)
+       }
+     case None =>
+       logWarning(s"Ignoring removal of executor '$executorId' of $appId: id is not numeric")
+   }
+ }
+
+ /** Returns the current `maxReplySize` (0 until `init()` has computed it). */
+ def getMaxReplySize(): Long = {
+   maxReplySize
+ }
+
+ /**
+  * Returns a duplicate of the server-ports buffer, so each caller gets its own
+  * position and limit over the shared content.
+  */
+ def getServerPortsBuffer(): ByteBuffer = serverPortsBuffer.duplicate()
+
+ /**
+  * Records `clientWorker` under its application id in `workerMap`.
+  *
+  * @param handler      worker that received the connect (not used here)
+  * @param clientWorker id of the connecting client worker
+  */
+ def handleConnect(handler: ExternalUcxServerWorker,
+                   clientWorker: UcxWorkerId): Unit = {
+   val appClients = workerMap.getOrElseUpdate(
+     clientWorker.appId, new TrieMap[UcxWorkerId, Unit])
+   appClients.getOrElseUpdate(clientWorker, ())
+ }
+
+ /**
+  * Submits a task that resolves every requested block to
+  * (file channel, offset, size) and passes the result to
+  * `handler.handleFetchBlockRequest`.
+  *
+  * Any failure inside the task is only logged.
+  *
+  * @param handler      worker that will send the reply
+  * @param clientWorker requesting client worker
+  * @param exeId        executor id used to look the blocks up
+  * @param replyTag     tag passed through to the reply
+  * @param blockIds     blocks to resolve
+  */
+ def handleFetchBlockRequest(handler: ExternalUcxServerWorker,
+                             clientWorker: UcxWorkerId, exeId: Int,
+                             replyTag: Int, blockIds: Seq[UcxShuffleBlockId]):
+ Unit = {
+   val task = new Runnable {
+     override def run(): Unit = {
+       // Both are kept in vars only so the error log can show how far we got.
+       var lastBlock: FileSegmentManagedBuffer = null
+       var resolved: Seq[(FileChannel, Long, Long)] = null
+       try {
+         resolved = for (bid <- blockIds) yield {
+           lastBlock = blockManager.getBlockData(
+             clientWorker.appId, exeId.toString, bid.shuffleId, bid.mapId, bid.reduceId)
+             .asInstanceOf[FileSegmentManagedBuffer]
+           val channel = openBlock(clientWorker.appId, bid, lastBlock)
+           (channel, lastBlock.getOffset, lastBlock.size)
+         }
+         handler.handleFetchBlockRequest(clientWorker, replyTag, resolved)
+       } catch {
+         case ex: Throwable =>
+           logError(s"Failed to reply fetch $clientWorker tag $replyTag files $resolved block $lastBlock $ex.")
+       }
+     }
+   }
+   submit(task)
+ }
+
+ /**
+  * Submits a task that resolves `bid` to (file channel, offset, size) and
+  * passes it to `handler.handleFetchBlockStream`.
+  *
+  * Any failure inside the task is only logged.
+  *
+  * @param handler      worker that will send the reply
+  * @param clientWorker requesting client worker
+  * @param exeId        executor id used to look the block up
+  * @param replyTag     tag passed through to the reply
+  * @param bid          block to resolve
+  */
+ def handleFetchBlockStream(handler: ExternalUcxServerWorker,
+                            clientWorker: UcxWorkerId, exeId: Int,
+                            replyTag: Int, bid: UcxShuffleBlockId): Unit = {
+   val task = new Runnable {
+     override def run(): Unit = {
+       // Both are kept in vars only so the error log can show how far we got.
+       var segment: FileSegmentManagedBuffer = null
+       var located: (FileChannel, Long, Long) = null
+       try {
+         segment = blockManager.getBlockData(
+           clientWorker.appId, exeId.toString, bid.shuffleId, bid.mapId, bid.reduceId)
+           .asInstanceOf[FileSegmentManagedBuffer]
+         val channel = openBlock(clientWorker.appId, bid, segment)
+         located = (channel, segment.getOffset, segment.size)
+         handler.handleFetchBlockStream(clientWorker, replyTag, located)
+       } catch {
+         case ex: Throwable =>
+           logError(s"Failed to reply stream $clientWorker tag $replyTag file $located block $segment $ex.")
+       }
+     }
+   }
+   submit(task)
+ }
+
+ /**
+  * Returns the read-only file channel cached for the (shuffle id, map id) of
+  * `bid` under `appId`, opening the file behind `blockData` on first use.
+  */
+ def openBlock(appId: String, bid: UcxShuffleBlockId,
+               blockData: FileSegmentManagedBuffer): FileChannel = {
+   val appFiles = fileMap.getOrElseUpdate(
+     appId, new ConcurrentHashMap[UcxShuffleMapId, FileChannel])
+   val mapKey = UcxShuffleMapId(bid.shuffleId, bid.mapId)
+   appFiles.computeIfAbsent(
+     mapKey,
+     _ => FileChannel.open(blockData.getFile().toPath(), StandardOpenOption.READ))
+ }
+
+ /**
+  * Calls `report` on every memory pool each time a 30 second wait on
+  * `scheduledLatch` times out; returns once the latch has been released.
+  */
+ def scheduledReport(): Unit = {
+   var released = scheduledLatch.await(30, TimeUnit.SECONDS)
+   while (!released) {
+     memPools.foreach(_.report)
+     released = scheduledLatch.await(30, TimeUnit.SECONDS)
+   }
+ }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/server/ExternalUcxServerWorker.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/server/ExternalUcxServerWorker.scala
new file mode 100644
index 00000000..0b1b96ab
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/server/ExternalUcxServerWorker.scala
@@ -0,0 +1,377 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.ucx
+
+import java.io.Closeable
+import java.nio.ByteBuffer
+import java.nio.channels.FileChannel
+import java.util.concurrent.{ConcurrentLinkedQueue, ConcurrentHashMap, CountDownLatch, Future, FutureTask}
+import scala.collection.mutable
+import scala.collection.JavaConverters._
+import org.openucx.jucx.ucp._
+import org.openucx.jucx.ucs.UcsConstants
+import org.openucx.jucx.ucs.UcsConstants.MEMORY_TYPE
+import org.openucx.jucx.{UcxCallback, UcxException, UcxUtils}
+import org.apache.spark.shuffle.ucx.memory.UcxLinkedMemBlock
+import org.apache.spark.shuffle.utils.{UnsafeUtils, UcxLogging}
+import java.net.InetSocketAddress
+
+/**
+ * An endpoint paired with a mutable `closed` flag.
+ *
+ * @param ucpEp  the wrapped endpoint
+ * @param closed whether the endpoint has been marked closed
+ */
+class ExternalUcxEndpoint(val ucpEp: UcpEndpoint, var closed: Boolean)
+/**
+ * Worker per thread wrapper, that maintains connection and progress logic.
+ *
+ * On construction it opens a listener on `0.0.0.0:port` and registers the
+ * active-message handlers of the shuffle protocol on `worker`.
+ *
+ * @param worker    the UCP worker being wrapped
+ * @param transport owning transport (memory pool, reply-size limit, task submission)
+ * @param workerId  id of this server worker
+ * @param port      port the listener is bound to
+ */
+case class ExternalUcxServerWorker(val worker: UcpWorker,
+ transport: ExternalUcxServerTransport,
+ workerId: UcxWorkerId,
+ port: Int)
+ extends Closeable with UcxLogging {
+ // Both values are read from the transport once, at construction time.
+ private[this] val memPool = transport.hostBounceBufferMemoryPool(workerId.workerId)
+ private[this] val maxReplySize = transport.getMaxReplySize()
+ // Endpoints created by `connected`, one per client worker id.
+ private[this] val shuffleClients = new ConcurrentHashMap[UcxWorkerId, ExternalUcxEndpoint]
+ private[ucx] val executor = new UcxWorkerThread(
+ worker, transport.ucxShuffleConf.useWakeup)
+
+ private val emptyCallback = () => {}
+ // Endpoints accepted by the listener, each with the callback the error handler
+ // runs when that endpoint fails.
+ private val endpoints = mutable.HashMap.empty[UcpEndpoint, () => Unit]
+ // NOTE(review): the connection handler below refers to `errorHandler`, which is
+ // declared after `listener`; it is only read when a connection request arrives.
+ private val listener = worker.newListener(
+ new UcpListenerParams().setSockAddr(new InetSocketAddress("0.0.0.0", port))
+ .setConnectionHandler((ucpConnectionRequest: UcpConnectionRequest) => {
+ val clientAddress = ucpConnectionRequest.getClientAddress()
+ try {
+ val ep = worker.newEndpoint(
+ new UcpEndpointParams().setConnectionRequest(ucpConnectionRequest)
+ .setPeerErrorHandlingMode().setErrorHandler(errorHandler)
+ .setName(s"Endpoint to $clientAddress"))
+ endpoints.getOrElseUpdate(ep, emptyCallback)
+ } catch {
+ case e: UcxException => logError(s"Accept $clientAddress fail: $e")
+ }
+ }))
+
+ private val listenerAddress = listener.getAddress
+ // Error handler for accepted endpoints: logs, runs the endpoint's registered
+ // callback (if any) and closes the endpoint.
+ private val errorHandler = new UcpEndpointErrorHandler {
+ override def onError(ucpEndpoint: UcpEndpoint, errorCode: Int, errorString: String): Unit = {
+ if (errorCode == UcsConstants.STATUS.UCS_ERR_CONNECTION_RESET) {
+ logInfo(s"Connection closed on ep: $ucpEndpoint")
+ } else {
+ logWarning(s"Ep $ucpEndpoint got an error: $errorString")
+ }
+ endpoints.remove(ucpEndpoint).foreach(_())
+ ucpEndpoint.close()
+ }
+ }
+
+ // FETCH_BLOCK: header = client worker id, reply tag, executor id; payload =
+ // consecutive serialized block ids. The request is handed to the transport,
+ // which resolves the blocks on its task pool before the reply is sent.
+ worker.setAmRecvHandler(ExternalAmId.FETCH_BLOCK,
+ (headerAddress: Long, headerSize: Long, amData: UcpAmData, _: UcpEndpoint) => {
+ val header = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt)
+ val shuffleClient = UcxWorkerId.deserialize(header)
+ val replyTag = header.getInt
+ val exeId = header.getInt
+ val buffer = UnsafeUtils.getByteBufferView(amData.getDataAddress,
+ amData.getLength.toInt)
+ // Number of block ids = payload length / fixed serialized size of one id.
+ val BlockNum = buffer.remaining() / UcxShuffleBlockId.serializedSize
+ val blockIds = (0 until BlockNum).map(
+ _ => UcxShuffleBlockId.deserialize(buffer))
+ logTrace(s"${workerId.workerId} Recv fetch from $shuffleClient tag $replyTag.")
+ transport.handleFetchBlockRequest(this, shuffleClient, exeId, replyTag, blockIds)
+ UcsConstants.STATUS.UCS_OK
+ }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG )
+
+ // FETCH_STREAM: the whole request is in the header = client worker id, reply
+ // tag, executor id, one block id. Handed to the transport like FETCH_BLOCK.
+ worker.setAmRecvHandler(ExternalAmId.FETCH_STREAM,
+ (headerAddress: Long, headerSize: Long, amData: UcpAmData, _: UcpEndpoint) => {
+ val header = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt)
+ val shuffleClient = UcxWorkerId.deserialize(header)
+ val replyTag = header.getInt
+ val exeId = header.getInt
+ val blockId = UcxShuffleBlockId.deserialize(header)
+ logTrace(s"${workerId.workerId} Recv stream from $shuffleClient tag $replyTag.")
+ transport.handleFetchBlockStream(this, shuffleClient, exeId, replyTag, blockId)
+ UcsConstants.STATUS.UCS_OK
+ }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG )
+
+ // CONNECT: header = client worker id; payload = the client's worker address.
+ // The address is copied into a fresh direct buffer, an endpoint back to the
+ // client is created from it, and the receiving endpoint gets a callback that
+ // disconnects this client when that endpoint fails.
+ worker.setAmRecvHandler(ExternalAmId.CONNECT,
+ (headerAddress: Long, headerSize: Long, amData: UcpAmData, ep: UcpEndpoint) => {
+ val header = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt)
+ val shuffleClient = UcxWorkerId.deserialize(header)
+ val workerAddress = UnsafeUtils.getByteBufferView(amData.getDataAddress,
+ amData.getLength.toInt)
+ val copiedAddress = ByteBuffer.allocateDirect(workerAddress.remaining)
+ copiedAddress.put(workerAddress)
+ connected(shuffleClient, copiedAddress)
+ endpoints.put(ep, () => doDisconnect(shuffleClient))
+ UcsConstants.STATUS.UCS_OK
+ }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG )
+
+ // ADDRESS: header and payload are ignored; the reply (REPLY_ADDRESS, sent by
+ // handleAddress on the receiving endpoint) carries the transport's server
+ // ports buffer.
+ worker.setAmRecvHandler(ExternalAmId.ADDRESS,
+ (headerAddress: Long, headerSize: Long, amData: UcpAmData, ep: UcpEndpoint) => {
+ handleAddress(ep)
+ UcsConstants.STATUS.UCS_OK
+ }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG )
+
+
+ /** Starts this worker's `executor`. */
+ def start(): Unit = executor.start()
+
+ /**
+  * Flush-closes every endpoint to a shuffle client, force-closes every accepted
+  * endpoint, progressing the worker until each group of close requests has
+  * completed, and finally closes the listener.
+  */
+ override def close(): Unit = {
+   // Progresses the worker until every request of the batch is completed.
+   def drain(requests: Iterable[UcpRequest]): Unit = {
+     requests.foreach { req =>
+       while (!req.isCompleted) {
+         worker.progress()
+       }
+     }
+   }
+
+   if (!shuffleClients.isEmpty) {
+     logInfo(s"$workerId closing ${shuffleClients.size} clients")
+     drain(shuffleClients.values.asScala.map(_.ucpEp.closeNonBlockingFlush()))
+   }
+   if (!endpoints.isEmpty) {
+     logInfo(s"$workerId closing ${endpoints.size} eps")
+     drain(endpoints.keys.map(_.closeNonBlockingForce()))
+   }
+   listener.close()
+ }
+
+ /**
+  * Wraps `close()` in a `FutureTask`, hands the task to `executor.close` and
+  * returns it so the caller can wait for the cleanup to finish.
+  */
+ def closing(): Future[Unit.type] = {
+   val shutdown: Runnable = new Runnable {
+     override def run(): Unit = close()
+   }
+   val task = new FutureTask[Unit.type](shutdown, Unit)
+   executor.close(task)
+   task
+ }
+
+ /** Port of the address the listener reported at construction time. */
+ def getPort(): Int = listenerAddress.getPort()
+
+ /** Socket address the listener reported at construction time. */
+ def getAddress(): InetSocketAddress = listenerAddress
+
+ /**
+  * Removes the endpoint of `shuffleClient`, if there is one, starts a flushing
+  * close on it and marks it closed. Failures are logged as warnings.
+  */
+ @`inline`
+ private def doDisconnect(shuffleClient: UcxWorkerId): Unit = {
+   try {
+     val removed = shuffleClients.remove(shuffleClient)
+     if (removed != null) {
+       removed.ucpEp.closeNonBlockingFlush()
+       // Set only after the close was issued; tasks already queued check this flag.
+       removed.closed = true
+       logDebug(s"Disconnect $shuffleClient")
+     }
+   } catch {
+     case e: Throwable => logWarning(s"Disconnect $shuffleClient: $e")
+   }
+ }
+
+ /** True when `ep` no longer has a native id. */
+ @`inline`
+ def isEpClosed(ep: UcpEndpoint): Boolean = {
+   val nativeId = ep.getNativeId()
+   nativeId == null
+ }
+
+ /** Posts a task to the executor that disconnects each of `workerIds` in order. */
+ @`inline`
+ def disconnect(workerIds: Seq[UcxWorkerId]): Unit = {
+   executor.post(() => {
+     for (clientId <- workerIds) {
+       doDisconnect(clientId)
+     }
+   })
+ }
+
+ /**
+  * Creates an endpoint back to `shuffleClient` from its worker address, unless
+  * one is already registered in `shuffleClients`. The endpoint's error handler
+  * unregisters the client and closes the endpoint. A `UcxException` while
+  * connecting is logged as a warning.
+  *
+  * @param shuffleClient id of the client worker
+  * @param workerAddress UCP worker address of the client
+  */
+ @`inline`
+ def connected(shuffleClient: UcxWorkerId, workerAddress: ByteBuffer): Unit = {
+   logDebug(s"$workerId connecting back to $shuffleClient by worker address")
+   try {
+     shuffleClients.computeIfAbsent(shuffleClient, _ => {
+       val onClientError = new UcpEndpointErrorHandler {
+         override def onError(ucpEndpoint: UcpEndpoint, errorCode: Int, errorString: String): Unit = {
+           logInfo(s"Connection to $shuffleClient closed: $errorString")
+           shuffleClients.remove(shuffleClient)
+           ucpEndpoint.close()
+         }
+       }
+       val params = new UcpEndpointParams()
+         .setErrorHandler(onClientError)
+         .setName(s"Server to $shuffleClient")
+         .setUcpAddress(workerAddress)
+       new ExternalUcxEndpoint(worker.newEndpoint(params), false)
+     })
+   } catch {
+     case e: UcxException => logWarning(s"Connection to $shuffleClient failed: $e")
+   }
+ }
+
+ /**
+  * Returns the endpoint registered for `shuffleClient`, waiting up to 10
+  * seconds (yielding the thread between checks) for it to appear.
+  *
+  * The map is read with a single `get` per iteration, so the result is never
+  * null even if the client is removed concurrently while we wait.
+  *
+  * @throws UcxException if no endpoint is registered within 10 seconds
+  */
+ def awaitConnection(shuffleClient: UcxWorkerId): ExternalUcxEndpoint = {
+   val startTime = System.currentTimeMillis()
+   var endpoint = shuffleClients.get(shuffleClient)
+   // wait until connected finished
+   while (endpoint == null) {
+     if (System.currentTimeMillis() - startTime > 10000) {
+       throw new UcxException(s"Don't get a worker address for $shuffleClient")
+     }
+     Thread.`yield`()
+     endpoint = shuffleClients.get(shuffleClient)
+   }
+   endpoint
+ }
+
+ /**
+  * Replies on `ep` with a `ExternalAmId.REPLY_ADDRESS` message whose header is
+  * the transport's server-ports buffer and whose payload is empty.
+  */
+ def handleAddress(ep: UcpEndpoint) = {
+   val ports = transport.getServerPortsBuffer()
+   val portsAddress = UnsafeUtils.getAdress(ports)
+   val sendDone = new UcxCallback {
+     override def onSuccess(request: UcpRequest): Unit = {
+       logTrace(s"$workerId sent to REPLY_ADDRESS to $ep")
+       ports.clear()
+     }
+
+     override def onError(ucsStatus: Int, errorMsg: String): Unit = {
+       logWarning(s"$workerId sent to REPLY_ADDRESS to $ep: $errorMsg")
+       ports.clear()
+     }
+   }
+   ep.sendAmNonBlocking(
+     ExternalAmId.REPLY_ADDRESS, portsAddress, ports.remaining(), portsAddress, 0,
+     UcpConstants.UCP_AM_SEND_FLAG_EAGER, sendDone, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
+   logTrace(s"$workerId sending to REPLY_ADDRESS $ports mem $portsAddress")
+ }
+
+ /**
+  * Replies to a fetch of several blocks.
+  *
+  * When all blocks together fit into `maxReplySize` they are sent as one
+  * chunked reply. Otherwise every block but the last is sent as one chunked
+  * reply and the last block is streamed with `ExternalAmId.REPLY_SLICE`.
+  *
+  * @param clientWorker requesting client worker
+  * @param replyTag     tag identifying the request on the client
+  * @param blockInfos   (file channel, offset, size) of each block
+  */
+ def handleFetchBlockRequest(
+     clientWorker: UcxWorkerId, replyTag: Int,
+     blockInfos: Seq[(FileChannel, Long, Long)]): Unit = {
+   val totalSize = blockInfos.map(_._3).sum
+   if (totalSize <= maxReplySize) {
+     handleFetchBlockChunks(clientWorker, replyTag, blockInfos, totalSize)
+   } else {
+     // The size of last block could > maxBytesInFlight / 5 in spark.
+     val tail = blockInfos.last
+     if (blockInfos.size > 1) {
+       val leading = blockInfos.dropRight(1)
+       handleFetchBlockChunks(clientWorker, replyTag, leading, totalSize - tail._3)
+     }
+     handleFetchBlockStream(clientWorker, replyTag, tail, ExternalAmId.REPLY_SLICE)
+   }
+ }
+
+ /**
+  * Reads all `blockInfos` into one pooled buffer and sends them to
+  * `clientWorker` as a single `ExternalAmId.REPLY_BLOCK` message.
+  *
+  * Message layout: header = reply tag followed by the size of every block as an
+  * int; payload = the block contents back to back.
+  *
+  * Each block is read until its slot in the buffer is full, because a single
+  * `FileChannel.read` may return fewer bytes than requested. If reading fails,
+  * the file ends early or the client connection cannot be obtained, the pooled
+  * buffer is released before the exception is rethrown.
+  *
+  * @param clientWorker requesting client worker
+  * @param replyTag     tag identifying the request on the client
+  * @param blockInfos   (file channel, offset, size) of each block
+  * @param blockSize    sum of the sizes of all blocks
+  * @throws java.io.EOFException if a file ends before a block was read completely
+  */
+ def handleFetchBlockChunks(
+     clientWorker: UcxWorkerId, replyTag: Int,
+     blockInfos: Seq[(FileChannel, Long, Long)], blockSize: Long): Unit = {
+   val tagAndSizes = UnsafeUtils.INT_SIZE + UnsafeUtils.INT_SIZE * blockInfos.length
+   val msgSize = tagAndSizes + blockSize
+   val resultMemory = memPool.get(msgSize).asInstanceOf[UcxLinkedMemBlock]
+   assert(resultMemory != null)
+
+   // Until the send task is posted, this method owns resultMemory.
+   val ep = try {
+     val resultBuffer = UcxUtils.getByteBufferView(resultMemory.address, msgSize)
+
+     resultBuffer.putInt(replyTag)
+     blockInfos.foreach(info => resultBuffer.putInt(info._3.toInt))
+
+     blockInfos.foreach { case (blockCh, blockOffset, length) =>
+       // Limit the buffer to this block's slot, then fill the slot completely.
+       resultBuffer.limit(resultBuffer.position() + length.toInt)
+       var filePos = blockOffset
+       while (resultBuffer.hasRemaining) {
+         val read = blockCh.read(resultBuffer, filePos)
+         if (read < 0) {
+           throw new java.io.EOFException(
+             s"Unexpected end of file at $filePos while reading block of size $length at $blockOffset")
+         }
+         filePos += read
+       }
+     }
+
+     awaitConnection(clientWorker)
+   } catch {
+     case ex: Throwable =>
+       resultMemory.close()
+       throw ex
+   }
+
+   executor.post(new Runnable {
+     override def run(): Unit = {
+       if (ep.closed) {
+         resultMemory.close()
+       } else {
+         val startTime = System.nanoTime()
+         ep.ucpEp.sendAmNonBlocking(ExternalAmId.REPLY_BLOCK,
+           resultMemory.address, tagAndSizes, resultMemory.address + tagAndSizes,
+           msgSize - tagAndSizes, 0, new UcxCallback {
+             override def onSuccess(request: UcpRequest): Unit = {
+               resultMemory.close()
+               logTrace(s"${workerId.workerId} Sent to ${clientWorker} ${blockInfos.length} blocks of size: " +
+                 s"${msgSize} tag $replyTag in ${System.nanoTime() - startTime} ns.")
+             }
+
+             override def onError(ucsStatus: Int, errorMsg: String): Unit = {
+               resultMemory.close()
+               logError(s"${workerId.workerId} Failed to reply fetch $clientWorker tag $replyTag $errorMsg.")
+             }
+           }, new UcpRequestParams().setMemoryType(
+             UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
+             .setMemoryHandle(resultMemory.memory))
+       }
+     }
+   })
+   logTrace(s"${workerId.workerId} Sending to $clientWorker tag $replyTag mem $resultMemory size $msgSize")
+ }
+
+ /**
+  * Sends one block to `clientWorker` as a sequence of slices of at most
+  * `maxReplySize` bytes, each in its own `amId` message.
+  *
+  * Every message header holds: reply tag (Int), number of slices still unsent
+  * (Int), total block length (Long) and the offset of this slice (Long). After
+  * a slice has been queued for sending, the next one is submitted to the
+  * transport's task pool.
+  *
+  * Each slice is read until its buffer is full, because a single
+  * `FileChannel.read` may return fewer bytes than requested. When preparing a
+  * slice fails, its pooled buffer is released and the failure is logged; the
+  * remaining slices are then not sent.
+  *
+  * @param clientWorker requesting client worker
+  * @param replyTag     tag identifying the request on the client
+  * @param blockInfo    (file channel, offset, size) of the block
+  * @param amId         active message id used for the slices
+  */
+ def handleFetchBlockStream(clientWorker: UcxWorkerId, replyTag: Int,
+                            blockInfo: (FileChannel, Long, Long),
+                            amId: Int = ExternalAmId.REPLY_STREAM): Unit = {
+   // tag: Int + unsent replies: Int + total length: Long + offset now: Long
+   val headerSize = UnsafeUtils.INT_SIZE + UnsafeUtils.INT_SIZE +
+     UnsafeUtils.LONG_SIZE + UnsafeUtils.LONG_SIZE
+   val (blockCh, blockOffset, blockSize) = blockInfo
+   val blockSlice = (0L until blockSize by maxReplySize).toArray
+   // make sure the last one is not too small
+   if (blockSlice.size >= 2) {
+     val mid = (blockSlice(blockSlice.size - 2) + blockSize) / 2
+     blockSlice(blockSlice.size - 1) = mid
+   }
+
+   def send(workerWrapper: ExternalUcxServerWorker, currentId: Int): Unit = {
+     // Pooled buffer still owned by this call; reset once the send task owns it.
+     var owned: UcxLinkedMemBlock = null
+     try {
+       val hasNext = (currentId + 1 != blockSlice.size)
+       val nextOffset = if (hasNext) blockSlice(currentId + 1) else blockSize
+       val currentOffset = blockSlice(currentId)
+       val currentSize = nextOffset - currentOffset
+       val unsent = blockSlice.length - currentId - 1
+       val msgSize = headerSize + currentSize.toInt
+       val mem = memPool.get(msgSize).asInstanceOf[UcxLinkedMemBlock]
+       owned = mem
+       val buffer = mem.toByteBuffer()
+
+       buffer.limit(msgSize)
+       buffer.putInt(replyTag)
+       buffer.putInt(unsent)
+       buffer.putLong(blockSize)
+       buffer.putLong(currentOffset)
+
+       // Fill the rest of the buffer with this slice of the file.
+       var filePos = blockOffset + currentOffset
+       while (buffer.hasRemaining) {
+         val read = blockCh.read(buffer, filePos)
+         if (read < 0) {
+           throw new java.io.EOFException(
+             s"Unexpected end of file at $filePos while reading slice of size $currentSize")
+         }
+         filePos += read
+       }
+
+       val ep = workerWrapper.awaitConnection(clientWorker)
+       workerWrapper.executor.post(new Runnable {
+         override def run(): Unit = {
+           if (ep.closed) {
+             mem.close()
+           } else {
+             val startTime = System.nanoTime()
+             ep.ucpEp.sendAmNonBlocking(amId, mem.address, headerSize,
+               mem.address + headerSize, currentSize,
+               UcpConstants.UCP_AM_SEND_FLAG_RNDV, new UcxCallback {
+                 override def onSuccess(request: UcpRequest): Unit = {
+                   mem.close()
+                   logTrace(s"${workerId.workerId} Sent to ${clientWorker} size $currentSize tag $replyTag seg " +
+                     s"$currentId in ${System.nanoTime() - startTime} ns.")
+                 }
+                 override def onError(ucsStatus: Int, errorMsg: String): Unit = {
+                   mem.close()
+                   logError(s"${workerId.workerId} Failed to reply stream $clientWorker tag $replyTag $currentId $errorMsg.")
+                 }
+               }, new UcpRequestParams()
+                 .setMemoryType(UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
+                 .setMemoryHandle(mem.memory))
+           }
+         }
+       })
+       // From here on the posted task releases the buffer.
+       owned = null
+       logTrace(s"${workerId.workerId} Sending to $clientWorker tag $replyTag $currentId mem $mem size $msgSize.")
+       if (hasNext) {
+         transport.submit(() => send(this, currentId + 1))
+       }
+     } catch {
+       case ex: Throwable =>
+         if (owned != null) {
+           owned.close()
+         }
+         logError(s"${workerId.workerId} Failed to reply stream $clientWorker tag $replyTag $currentId $ex.")
+     }
+   }
+
+   send(this, 0)
+ }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/spark_2_4/CompatUtils.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/spark_2_4/CompatUtils.scala
new file mode 100644
index 00000000..d8d01fdc
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/spark_2_4/CompatUtils.scala
@@ -0,0 +1,90 @@
+package org.apache.spark.shuffle.ucx
+
+import org.apache.spark.shuffle.utils.UcxThreadFactory
+import java.nio.ByteBuffer
+import java.util.concurrent.{Executors, ExecutorService}
+import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread}
+
+/** Key for one map output: the shuffle it belongs to plus the map identifier. */
+case class UcxShuffleMapId(shuffleId: Int, mapId: Int)
+
+/**
+ * Identifies a single shuffle block by shuffle, map output and reduce partition.
+ *
+ * The serialized form is three consecutive ints written in the order
+ * shuffleId, reduceId, mapId.
+ */
+case class UcxShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
+  override def serializedSize: Int = UcxShuffleBlockId.serializedSize
+
+  /** Writes this id at the buffer's current position, advancing it. */
+  override def serialize(byteBuffer: ByteBuffer): Unit = {
+    // Each put returns the buffer itself, so the three writes can be chained.
+    byteBuffer.putInt(shuffleId).putInt(reduceId).putInt(mapId)
+  }
+}
+
+object UcxShuffleBlockId {
+  // Three 4-byte ints: shuffleId, reduceId, mapId.
+  final val serializedSize = 12
+
+  /**
+   * Reads an id from the buffer's current position, advancing it.
+   * The field order read here is shuffleId, reduceId, mapId.
+   */
+  def deserialize(byteBuffer: ByteBuffer): UcxShuffleBlockId = {
+    val shuffle = byteBuffer.getInt()
+    val reduce = byteBuffer.getInt()
+    val map = byteBuffer.getInt()
+    UcxShuffleBlockId(shuffleId = shuffle, mapId = map, reduceId = reduce)
+  }
+}
+
+/**
+ * Identifies one worker (`workerId`) of one executor (`exeId`) inside the
+ * application `appId`.
+ *
+ * Serialized layout: exeId (int), workerId (int), appId byte length (int),
+ * then the appId encoded as UTF-8.
+ */
+case class UcxWorkerId(appId: String, exeId: Int, workerId: Int) extends BlockId {
+  import java.nio.charset.StandardCharsets
+
+  // Encoded lazily and only once. The length prefix and serializedSize must be
+  // the encoded BYTE length, which differs from appId.size for non-ASCII ids.
+  @transient private lazy val appIdBytes: Array[Byte] =
+    appId.getBytes(StandardCharsets.UTF_8)
+
+  /** Number of bytes `serialize` writes: three ints plus the encoded appId. */
+  override def serializedSize: Int = 12 + appIdBytes.length
+
+  /** Writes this id at the buffer's current position, advancing it. */
+  override def serialize(byteBuffer: ByteBuffer): Unit = {
+    byteBuffer.putInt(exeId)
+    byteBuffer.putInt(workerId)
+    byteBuffer.putInt(appIdBytes.length)
+    byteBuffer.put(appIdBytes)
+  }
+
+  override def toString(): String = s"UcxWorkerId($appId, $exeId, $workerId)"
+}
+
+object UcxWorkerId {
+  import java.nio.charset.StandardCharsets
+
+  /**
+   * Reads an id written by `UcxWorkerId.serialize` from the buffer's current
+   * position, advancing it.
+   */
+  def deserialize(byteBuffer: ByteBuffer): UcxWorkerId = {
+    val exeId = byteBuffer.getInt
+    val workerId = byteBuffer.getInt
+    val appIdSize = byteBuffer.getInt
+    val appIdBytes = new Array[Byte](appIdSize)
+    byteBuffer.get(appIdBytes)
+    // Decode with the same fixed charset used by serialize, not the platform default.
+    UcxWorkerId(new String(appIdBytes, StandardCharsets.UTF_8), exeId, workerId)
+  }
+
+  /**
+   * Packs workerId into the upper 32 bits and exeId into the lower 32 bits.
+   */
+  @`inline`
+  def makeExeWorkerId(id: UcxWorkerId): Long = {
+    // Mask exeId to 32 bits: a negative Int would otherwise sign-extend and
+    // overwrite the workerId half with ones.
+    (id.workerId.toLong << 32) | (id.exeId.toLong & 0xFFFFFFFFL)
+  }
+
+  /** Returns the lower 32 bits of a packed id. */
+  @`inline`
+  def extractExeId(exeWorkerId: Long): Int = {
+    exeWorkerId.toInt
+  }
+
+  /** Returns the upper 32 bits of a packed id. */
+  @`inline`
+  def extractWorkerId(exeWorkerId: Long): Int = {
+    (exeWorkerId >> 32).toInt
+  }
+
+  /** Builds an id from an application id and a value packed by `makeExeWorkerId`. */
+  def apply(appId: String, exeWorkerId: Long): UcxWorkerId = {
+    UcxWorkerId(appId, UcxWorkerId.extractExeId(exeWorkerId),
+      UcxWorkerId.extractWorkerId(exeWorkerId))
+  }
+}
+
+/** Factory helpers for the thread pools used by the UCX shuffle code. */
+object UcxThreadUtils {
+  /**
+   * Creates a fork-join pool with `maxThreadNumber` parallelism whose worker
+   * threads are named `prefix-<default worker name>`.
+   */
+  def newForkJoinPool(prefix: String, maxThreadNumber: Int): SForkJoinPool = {
+    val namingFactory = new SForkJoinPool.ForkJoinWorkerThreadFactory {
+      override def newThread(pool: SForkJoinPool) = {
+        // The worker constructor is protected, hence the empty anonymous subclass.
+        val worker = new SForkJoinWorkerThread(pool) {}
+        worker.setName(s"${prefix}-${worker.getName}")
+        worker
+      }
+    }
+    new SForkJoinPool(maxThreadNumber, namingFactory, null, false)
+  }
+
+  /**
+   * Creates a fixed-size pool of `maxThreadNumber` daemon threads named with `prefix`.
+   */
+  def newFixedDaemonPool(prefix: String, maxThreadNumber: Int): ExecutorService = {
+    val daemonFactory = new UcxThreadFactory()
+    daemonFactory.setDaemon(true)
+    daemonFactory.setPrefix(prefix)
+    Executors.newFixedThreadPool(maxThreadNumber, daemonFactory)
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/spark_3_0/CompatUtils.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/spark_3_0/CompatUtils.scala
new file mode 100644
index 00000000..0a1f80f4
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/spark_3_0/CompatUtils.scala
@@ -0,0 +1,89 @@
+package org.apache.spark.shuffle.ucx
+
+import org.apache.spark.shuffle.utils.UcxThreadFactory
+import java.nio.ByteBuffer
+import java.util.concurrent.{Executors, ExecutorService, ForkJoinPool, ForkJoinWorkerThread}
+
+/** Key for one map output: the shuffle it belongs to plus the (64-bit) map identifier. */
+case class UcxShuffleMapId(shuffleId: Int, mapId: Long)
+
+/**
+ * Identifies a single shuffle block by shuffle, map output and reduce partition.
+ *
+ * The serialized form is shuffleId (int), reduceId (int), mapId (long), in that order.
+ */
+case class UcxShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId {
+  override def serializedSize: Int = UcxShuffleBlockId.serializedSize
+
+  /** Writes this id at the buffer's current position, advancing it. */
+  override def serialize(byteBuffer: ByteBuffer): Unit = {
+    // Each put returns the buffer itself, so the three writes can be chained.
+    byteBuffer.putInt(shuffleId).putInt(reduceId).putLong(mapId)
+  }
+}
+
+object UcxShuffleBlockId {
+  // Two 4-byte ints (shuffleId, reduceId) followed by one 8-byte long (mapId).
+  final val serializedSize = 16
+
+  /**
+   * Reads an id from the buffer's current position, advancing it.
+   * The field order read here is shuffleId, reduceId, mapId.
+   */
+  def deserialize(byteBuffer: ByteBuffer): UcxShuffleBlockId = {
+    val shuffle = byteBuffer.getInt()
+    val reduce = byteBuffer.getInt()
+    val map = byteBuffer.getLong()
+    UcxShuffleBlockId(shuffleId = shuffle, mapId = map, reduceId = reduce)
+  }
+}
+
+/**
+ * Identifies one worker (`workerId`) of one executor (`exeId`) inside the
+ * application `appId`.
+ *
+ * Serialized layout: exeId (int), workerId (int), appId byte length (int),
+ * then the appId encoded as UTF-8.
+ */
+case class UcxWorkerId(appId: String, exeId: Int, workerId: Int) extends BlockId {
+  import java.nio.charset.StandardCharsets
+
+  // Encoded lazily and only once. The length prefix and serializedSize must be
+  // the encoded BYTE length, which differs from appId.size for non-ASCII ids.
+  @transient private lazy val appIdBytes: Array[Byte] =
+    appId.getBytes(StandardCharsets.UTF_8)
+
+  /** Number of bytes `serialize` writes: three ints plus the encoded appId. */
+  override def serializedSize: Int = 12 + appIdBytes.length
+
+  /** Writes this id at the buffer's current position, advancing it. */
+  override def serialize(byteBuffer: ByteBuffer): Unit = {
+    byteBuffer.putInt(exeId)
+    byteBuffer.putInt(workerId)
+    byteBuffer.putInt(appIdBytes.length)
+    byteBuffer.put(appIdBytes)
+  }
+
+  override def toString(): String = s"UcxWorkerId($appId, $exeId, $workerId)"
+}
+
+object UcxWorkerId {
+  import java.nio.charset.StandardCharsets
+
+  /**
+   * Reads an id written by `UcxWorkerId.serialize` from the buffer's current
+   * position, advancing it.
+   */
+  def deserialize(byteBuffer: ByteBuffer): UcxWorkerId = {
+    val exeId = byteBuffer.getInt
+    val workerId = byteBuffer.getInt
+    val appIdSize = byteBuffer.getInt
+    val appIdBytes = new Array[Byte](appIdSize)
+    byteBuffer.get(appIdBytes)
+    // Decode with the same fixed charset used by serialize, not the platform default.
+    UcxWorkerId(new String(appIdBytes, StandardCharsets.UTF_8), exeId, workerId)
+  }
+
+  /**
+   * Packs workerId into the upper 32 bits and exeId into the lower 32 bits.
+   */
+  @`inline`
+  def makeExeWorkerId(id: UcxWorkerId): Long = {
+    // Mask exeId to 32 bits: a negative Int would otherwise sign-extend and
+    // overwrite the workerId half with ones.
+    (id.workerId.toLong << 32) | (id.exeId.toLong & 0xFFFFFFFFL)
+  }
+
+  /** Returns the lower 32 bits of a packed id. */
+  @`inline`
+  def extractExeId(exeWorkerId: Long): Int = {
+    exeWorkerId.toInt
+  }
+
+  /** Returns the upper 32 bits of a packed id. */
+  @`inline`
+  def extractWorkerId(exeWorkerId: Long): Int = {
+    (exeWorkerId >> 32).toInt
+  }
+
+  /** Builds an id from an application id and a value packed by `makeExeWorkerId`. */
+  def apply(appId: String, exeWorkerId: Long): UcxWorkerId = {
+    UcxWorkerId(appId, UcxWorkerId.extractExeId(exeWorkerId),
+      UcxWorkerId.extractWorkerId(exeWorkerId))
+  }
+}
+
+/** Factory helpers for the thread pools used by the UCX shuffle code. */
+object UcxThreadUtils {
+  /**
+   * Creates a fork-join pool with `maxThreadNumber` parallelism whose worker
+   * threads are named `prefix-<default worker name>`.
+   */
+  def newForkJoinPool(prefix: String, maxThreadNumber: Int): ForkJoinPool = {
+    val namingFactory = new ForkJoinPool.ForkJoinWorkerThreadFactory {
+      override def newThread(pool: ForkJoinPool) = {
+        // The worker constructor is protected, hence the empty anonymous subclass.
+        val worker = new ForkJoinWorkerThread(pool) {}
+        worker.setName(s"${prefix}-${worker.getName}")
+        worker
+      }
+    }
+    new ForkJoinPool(maxThreadNumber, namingFactory, null, false)
+  }
+
+  /**
+   * Creates a fixed-size pool of `maxThreadNumber` daemon threads named with `prefix`.
+   */
+  def newFixedDaemonPool(prefix: String, maxThreadNumber: Int): ExecutorService = {
+    val daemonFactory = new UcxThreadFactory()
+    daemonFactory.setDaemon(true)
+    daemonFactory.setPrefix(prefix)
+    Executors.newFixedThreadPool(maxThreadNumber, daemonFactory)
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/memory/UcxLimitedMemPool.scala b/src/main/scala/org/apache/spark/shuffle/ucx/memory/UcxLimitedMemPool.scala
new file mode 100644
index 00000000..e0415b29
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/memory/UcxLimitedMemPool.scala
@@ -0,0 +1,251 @@
+package org.apache.spark.shuffle.ucx.memory
+
+import java.io.Closeable
+import java.util.concurrent.atomic.AtomicInteger
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedDeque, ConcurrentSkipListSet}
+import java.util.concurrent.Semaphore
+import java.nio.ByteBuffer
+import scala.collection.JavaConverters._
+
+import org.openucx.jucx.ucp.{UcpContext, UcpMemMapParams, UcpMemory}
+import org.openucx.jucx.ucs.UcsConstants
+import org.apache.spark.shuffle.utils.{SparkucxUtils, UcxLogging, UnsafeUtils}
+import org.apache.spark.shuffle.ucx.MemoryBlock
+
+/**
+ * A memory block shared through a reference counter.
+ *
+ * Every `close()` decrements `refCount`; the call that brings it to zero
+ * invokes `closeCb`.
+ */
+class UcxSharedMemoryBlock(val closeCb: () => Unit, val refCount: AtomicInteger,
+                           override val address: Long, override val size: Long)
+  extends MemoryBlock(address, size, true) {
+
+  override def close(): Unit = {
+    val remaining = refCount.decrementAndGet()
+    if (remaining == 0) {
+      closeCb()
+    }
+  }
+}
+
+/**
+ * A slice of UCX-registered memory handed out by a [[UcxMemoryAllocator]].
+ *
+ * The third `MemoryBlock` constructor argument is whether `memory` reports
+ * the host memory type. Closing the block hands it back to its allocator.
+ */
+class UcxMemBlock(private[ucx] val memory: UcpMemory,
+                  private[ucx] val allocator: UcxMemoryAllocator,
+                  override val address: Long, override val size: Long)
+  extends MemoryBlock(address, size,
+    memory.getMemType == UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) {
+
+  /** Buffer view over this block; the view length is capped at `Int.MaxValue`. */
+  def toByteBuffer() = {
+    val viewLength = math.min(size, Int.MaxValue.toLong).toInt
+    UnsafeUtils.getByteBufferView(address, viewLength)
+  }
+
+  override def close(): Unit = allocator.deallocate(this)
+}
+
+/**
+ * A [[UcxMemBlock]] that can be linked to the block it was split from.
+ *
+ * `superMem` is the larger block this one was carved out of (null when it was
+ * mapped directly) and `broMem` is the other half of that same split.
+ * Blocks are ordered by start address.
+ */
+class UcxLinkedMemBlock(private[memory] val superMem: UcxLinkedMemBlock,
+                        private[memory] var broMem: UcxLinkedMemBlock,
+                        override private[ucx] val memory: UcpMemory,
+                        override private[ucx] val allocator: UcxMemoryAllocator,
+                        override val address: Long, override val size: Long)
+  extends UcxMemBlock(memory, allocator, address, size) with Comparable[UcxLinkedMemBlock] {
+
+  override def compareTo(o: UcxLinkedMemBlock): Int =
+    java.lang.Long.compare(address, o.address)
+}
+
+/** Source of fixed-size [[UcxMemBlock]]s that can be returned for reuse. */
+trait UcxMemoryAllocator extends Closeable {
+  /** Hands out one block. */
+  def allocate(): UcxMemBlock
+
+  /** Takes back a block previously returned by `allocate()`. */
+  def deallocate(mem: UcxMemBlock): Unit
+
+  /**
+   * Allocates `numBuffers` blocks and then closes all of them; every block is
+   * allocated before the first one is closed.
+   */
+  def preallocate(numBuffers: Int): Unit = {
+    val warmed = Vector.fill(numBuffers)(allocate())
+    warmed.foreach(_.close())
+  }
+
+  /** Total size reported by this allocator. */
+  def totalSize(): Long
+}
+
+/**
+ * Shared state for allocators: the set of free blocks, a counter of memory
+ * mappings and the mapping parameters (host memory, allocated by UCX).
+ */
+abstract class UcxBaseMemAllocator extends UcxMemoryAllocator with UcxLogging {
+  private[memory] val stack = new ConcurrentSkipListSet[UcxMemBlock]
+  private[memory] val numAllocs = new AtomicInteger(0)
+  private[memory] val memMapParams = new UcpMemMapParams().allocate()
+    .setMemoryType(UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
+
+  /**
+   * Deregisters the memory behind every free block that still has a native id,
+   * logs a summary and empties the free set.
+   */
+  override def close(): Unit = {
+    var freeBlocks = 0
+    var lastBlockSize = 0L
+    val blocks = stack.iterator()
+    while (blocks.hasNext) {
+      val block = blocks.next()
+      // Several blocks may share one registration; the native id check lets
+      // only the first of them deregister it.
+      if (block.memory.getNativeId != null) {
+        lastBlockSize = block.size
+        block.memory.deregister()
+      }
+      freeBlocks += 1
+    }
+    if (freeBlocks != 0) {
+      logInfo(s"Closing $freeBlocks buffers size $lastBlockSize allocations " +
+        s"${numAllocs.get()}. Total ${SparkucxUtils.bytesToString(lastBlockSize * freeBlocks)}")
+      stack.clear()
+    }
+  }
+}
+
+/**
+ * Allocator of `length`-byte blocks.
+ *
+ * When `next` is null this allocator maps memory itself, in registrations of
+ * `max(length, minRegistrationSize)` bytes that are sliced into `length`-byte
+ * blocks. Otherwise it takes one block from `next` and splits it into two
+ * linked halves; when both halves are free again the parent block is handed
+ * back to `next`.
+ *
+ * An optional limit (see `setLimit`) bounds the number of blocks that may be
+ * outstanding at once; `allocate()` blocks while the limit is exhausted.
+ */
+case class UcxLinkedMemAllocator(length: Long, minRegistrationSize: Long,
+                                 next: UcxLinkedMemAllocator,
+                                 ucxContext: UcpContext)
+  extends UcxBaseMemAllocator() with Closeable {
+  private[this] val registrationSize = length.max(minRegistrationSize)
+  private[this] val sliceRange = (0L until (registrationSize / length))
+  private[this] var limit: Semaphore = _
+  logInfo(s"Allocator stack size $length")
+  if (next == null) {
+    memMapParams.setLength(registrationSize)
+  }
+
+  /**
+   * Returns a free block, mapping or splitting memory when none is cached.
+   * If the allocation fails, the permit taken from the limit is given back
+   * before the failure is rethrown.
+   */
+  override def allocate(): UcxMemBlock = {
+    acquireLimit()
+    try {
+      allocateBlock()
+    } catch {
+      // Cleanup-and-rethrow only: without this a failed mapping would leak
+      // the permit and permanently shrink a limited pool.
+      case failure: Throwable =>
+        releaseLimit()
+        throw failure
+    }
+  }
+
+  /** Produces one block; the caller already holds a permit (if limited). */
+  private[this] def allocateBlock(): UcxMemBlock = {
+    val cached = stack.pollFirst()
+    if (cached != null) {
+      cached
+    } else if (next == null) {
+      logDebug(s"Allocating buffer of size $length.")
+      var result: UcxMemBlock = null
+      // Loop because the freshly added slices are visible to other callers,
+      // which may take all of them before this poll.
+      while (result == null) {
+        numAllocs.incrementAndGet()
+        val memory = ucxContext.memoryMap(memMapParams)
+        var address = memory.getAddress
+        for (_ <- sliceRange) {
+          stack.add(new UcxLinkedMemBlock(null, null, memory, this, address, length))
+          address += length
+        }
+        result = stack.pollFirst()
+      }
+      result
+    } else {
+      // Split one parent block into two linked halves: return the first,
+      // cache the second.
+      val superMem = next.allocate().asInstanceOf[UcxLinkedMemBlock]
+      val address1 = superMem.address
+      val address2 = address1 + length
+      val block1 = new UcxLinkedMemBlock(superMem, null, superMem.memory, this,
+        address1, length)
+      val block2 = new UcxLinkedMemBlock(superMem, null, superMem.memory, this,
+        address2, length)
+      block1.broMem = block2
+      block2.broMem = block1
+      stack.add(block2)
+      block1
+    }
+  }
+
+  /**
+   * Takes a block back. A block without a parent is simply cached. For a split
+   * block, if its sibling is currently free both are dropped and the parent is
+   * returned to `next`; otherwise the block itself is cached.
+   */
+  override def deallocate(memBlock: UcxMemBlock): Unit = {
+    val block = memBlock.asInstanceOf[UcxLinkedMemBlock]
+    if (block.superMem == null) {
+      stack.add(block)
+      releaseLimit()
+    } else {
+      // Both halves lock on the shared parent so that exactly one of them
+      // observes the other as free.
+      val releaseNext = block.superMem.synchronized {
+        val siblingWasFree = stack.remove(block.broMem)
+        if (!siblingWasFree) {
+          stack.add(block)
+        }
+        siblingWasFree
+      }
+
+      if (releaseNext) {
+        next.deallocate(block.superMem)
+      }
+      releaseLimit()
+    }
+  }
+
+  /** Registration size multiplied by the number of mappings made so far. */
+  override def totalSize(): Long = registrationSize * numAllocs.get()
+
+  /** Takes one permit when a limit is set; blocks until one is available. */
+  def acquireLimit(): Unit = if (limit != null) {
+    limit.acquire(1)
+  }
+
+  /** Gives one permit back when a limit is set. */
+  def releaseLimit(): Unit = if (limit != null) {
+    limit.release(1)
+  }
+
+  /** Bounds the number of outstanding blocks to `num`. */
+  def setLimit(num: Int): Unit = {
+    limit = new Semaphore(num)
+  }
+}
+
+/**
+ * Pool of power-of-two sized buffers backed by one allocator per size.
+ *
+ * `init` must populate the allocators before `get` is used. Sizes are
+ * rounded up to the next power of two (and to at least the minimum buffer
+ * size) to select an allocator.
+ */
+case class UcxLimitedMemPool(ucxContext: UcpContext)
+  extends Closeable with UcxLogging {
+  private[memory] val allocatorMap = new ConcurrentHashMap[Long, UcxMemoryAllocator]()
+  private[memory] var minBufferSize: Long = 4096L
+  private[memory] var maxBufferSize: Long = 2L * 1024 * 1024 * 1024
+  private[memory] var minRegistrationSize: Long = 1024L * 1024
+  private[memory] var maxRegistrationSize: Long = 16L * 1024 * 1024 * 1024
+
+  /**
+   * Returns a block of at least `size` bytes.
+   *
+   * @throws NullPointerException if no allocator exists for the rounded size,
+   *                              e.g. `init` was not called or `size` exceeds
+   *                              the configured maximum buffer size
+   */
+  def get(size: Long): MemoryBlock = {
+    allocatorFor(size).allocate()
+  }
+
+  /** Logs the total size of every allocator that reports a non-zero total. */
+  def report(): Unit = {
+    val memInfo = allocatorMap.asScala.map(allocator =>
+      allocator._1 -> allocator._2.totalSize).filter(_._2 != 0)
+
+    if (memInfo.nonEmpty) {
+      logInfo(s"Memory pool use: $memInfo")
+    }
+  }
+
+  /**
+   * Creates one allocator per power of two between the (rounded) minimum and
+   * maximum buffer sizes and pre-allocates the requested buffers.
+   *
+   * Sizes are processed from largest to smallest in groups of `memGroupSize`;
+   * within a group each allocator splits blocks of the previous (twice as
+   * large) one, and only the first allocator of a group maps memory itself.
+   *
+   * @param minBufSize   smallest buffer size, rounded up to a power of two
+   * @param maxBufSize   largest buffer size, rounded up to a power of two
+   * @param minRegSize   minimum registration size, rounded up to a power of two
+   * @param maxRegSize   maximum registration size, rounded up to a power of two
+   * @param preAllocMap  buffer size -> number of buffers to pre-allocate
+   * @param limit        whether to bound the first allocator of every group
+   * @param memGroupSize number of sizes per group; must be greater than 2
+   */
+  def init(minBufSize: Long, maxBufSize: Long, minRegSize: Long, maxRegSize: Long,
+           preAllocMap: Map[Long, Int], limit: Boolean, memGroupSize: Int = 3):
+    Unit = {
+    assert(memGroupSize > 2, s"Invalid memGroupSize. Expect > 2. Actual $memGroupSize")
+    val maxMemFactor = 1.0 - 1.0 / (1 << (memGroupSize - 1))
+    minBufferSize = roundUpToTheNextPowerOf2(minBufSize)
+    maxBufferSize = roundUpToTheNextPowerOf2(maxBufSize)
+    minRegistrationSize = roundUpToTheNextPowerOf2(minRegSize)
+    maxRegistrationSize = roundUpToTheNextPowerOf2(maxRegSize)
+
+    // Supported sizes, largest first.
+    val memRange = (1 until 47).map(1L << _).filter(m =>
+      (m >= minBufferSize) && (m <= maxBufferSize)).reverse
+    val minLimit = (maxRegistrationSize / maxBufferSize * maxMemFactor).toLong
+    logInfo(s"limit $limit buf ($minBufferSize, $maxBufferSize) reg " +
+      s"($minRegistrationSize, $maxRegistrationSize)")
+
+    var shift = 0
+    for (i <- 0 until memRange.length by memGroupSize) {
+      var superAllocator: UcxLinkedMemAllocator = null
+      for (j <- 0 until memGroupSize.min(memRange.length - i)) {
+        val memSize = memRange(i + j)
+        val current = new UcxLinkedMemAllocator(memSize, minRegistrationSize,
+          superAllocator, ucxContext)
+        // set limit to top allocator
+        if (limit && (superAllocator == null)) {
+          val memLimit = (maxRegistrationSize / memSize).min(minLimit << shift)
+            .max(1L)
+            .min(Int.MaxValue)
+          logInfo(s"mem $memSize limit $memLimit")
+          current.setLimit(memLimit.toInt)
+          shift += 1
+        }
+        superAllocator = current
+        allocatorMap.put(memSize, current)
+      }
+    }
+    preAllocMap.foreach {
+      case (size, count) => allocatorFor(size).preallocate(count)
+    }
+  }
+
+  /** Looks up the allocator for `size`, failing with a descriptive message. */
+  private[this] def allocatorFor(size: Long): UcxMemoryAllocator = {
+    val bucket = roundUpToTheNextPowerOf2(size)
+    val allocator = allocatorMap.get(bucket)
+    if (allocator == null) {
+      // Same exception type the unchecked dereference used to raise, now with context.
+      throw new NullPointerException(s"No allocator for size $size (rounded to $bucket). " +
+        s"Supported buffer sizes: ($minBufferSize, $maxBufferSize)")
+    }
+    allocator
+  }
+
+  /**
+   * Rounds `size` up to the next power of two, never below `minBufferSize`.
+   */
+  protected def roundUpToTheNextPowerOf2(size: Long): Long = {
+    if (size < minBufferSize) {
+      minBufferSize
+    } else {
+      // Round up length to the nearest power of two
+      var length = size
+      length -= 1
+      length |= length >> 1
+      length |= length >> 2
+      length |= length >> 4
+      length |= length >> 8
+      length |= length >> 16
+      // The value is a 64-bit Long: without this step, sizes above 2^32 that
+      // are not already powers of two keep unset low bits.
+      length |= length >> 32
+      length += 1
+      length
+    }
+  }
+
+  /** Closes every allocator and forgets them. */
+  override def close(): Unit = {
+    allocatorMap.values.forEach(allocator => allocator.close())
+    allocatorMap.clear()
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/ExternalUcxDriverRpcEndpoint.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/ExternalUcxDriverRpcEndpoint.scala
new file mode 100644
index 00000000..0d02e31b
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/ExternalUcxDriverRpcEndpoint.scala
@@ -0,0 +1,45 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.ucx.rpc
+
+import java.net.InetSocketAddress
+
+import scala.collection.immutable.HashMap
+import scala.collection.mutable
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc._
+import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{PushServiceAddress, PushAllServiceAddress}
+import org.apache.spark.shuffle.ucx.utils.{SerializableDirectBuffer, SerializationUtils}
+
+/**
+ * Driver-side RPC endpoint that collects the shuffle service ports announced
+ * per host and distributes them to the endpoints that registered earlier.
+ */
+class ExternalUcxDriverRpcEndpoint(override val rpcEnv: RpcEnv)
+  extends ThreadSafeRpcEndpoint with Logging {
+
+  // Endpoints that have announced themselves so far.
+  private val endpoints = mutable.HashSet.empty[RpcEndpointRef]
+  // host -> ports; only the first announcement for a host is recorded.
+  private val shuffleServerMap = mutable.HashMap.empty[String, Seq[Int]]
+
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+    case message@PushServiceAddress(host: String, ports: Seq[Int], endpoint: RpcEndpointRef) =>
+      logInfo(s"Received message $ports from ${context.senderAddress}")
+      // 1. Tell the sender about the servers known so far. Nothing is replied
+      //    while no server has been recorded yet.
+      if (shuffleServerMap.nonEmpty) {
+        val msg = PushAllServiceAddress(shuffleServerMap.toMap)
+        logDebug(s"Replying $msg to $endpoint")
+        context.reply(msg)
+      }
+      // 2. Forward a first-time host announcement to every registered endpoint.
+      val isNewHost = !shuffleServerMap.contains(host)
+      if (isNewHost) {
+        shuffleServerMap.update(host, ports)
+        for (registered <- endpoints) {
+          logDebug(s"Sending message $ports to $registered")
+          registered.send(message)
+        }
+      }
+      // 3. Remember the sender for future announcements.
+      endpoints.add(endpoint)
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/ExternalUcxExecutorRpcEndpoint.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/ExternalUcxExecutorRpcEndpoint.scala
new file mode 100644
index 00000000..065ce23c
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/ExternalUcxExecutorRpcEndpoint.scala
@@ -0,0 +1,25 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.ucx.rpc
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}
+import org.apache.spark.shuffle.ucx.ExternalUcxClientTransport
+import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{PushServiceAddress, PushAllServiceAddress}
+import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer
+
+import java.util.concurrent.ExecutorService
+
+/**
+ * Executor-side RPC endpoint that forwards shuffle service addresses to the
+ * client transport. `executorService` is accepted but not referenced in this class.
+ */
+class ExternalUcxExecutorRpcEndpoint(override val rpcEnv: RpcEnv,
+                                     transport: ExternalUcxClientTransport,
+                                     executorService: ExecutorService)
+  extends RpcEndpoint {
+
+  override def receive: PartialFunction[Any, Unit] = {
+    // A single server announced its ports.
+    case PushServiceAddress(serverHost: String, serverPorts: Seq[Int], _: RpcEndpointRef) =>
+      transport.connect(serverHost, serverPorts)
+    // The full host -> ports map of known servers.
+    case PushAllServiceAddress(knownServers: Map[String, Seq[Int]]) =>
+      transport.connectAll(knownServers)
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala
index 8c476236..27e84000 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala
@@ -19,4 +19,8 @@ object UcxRpcMessages {
* Reply from driver with all executors in the cluster with their worker addresses.
*/
case class IntroduceAllExecutors(executorIdToAddress: Map[Long, SerializableDirectBuffer])
+
+ /**
+ * Announces the shuffle service `ports` of `host`, together with the RPC `endpoint`
+ * that ExternalUcxDriverRpcEndpoint registers and forwards later announcements to.
+ */
+ case class PushServiceAddress(host: String, ports: Seq[Int], endpoint: RpcEndpointRef)
+
+ /**
+ * Carries the host -> ports map of the shuffle servers recorded so far.
+ */
+ case class PushAllServiceAddress(shuffleServerMap: Map[String, Seq[Int]])
}
diff --git a/src/main/scala/org/apache/spark/shuffle/utils/SerializableDirectBuffer.scala b/src/main/scala/org/apache/spark/shuffle/utils/SerializableDirectBuffer.scala
index b1f6e970..63074e80 100755
--- a/src/main/scala/org/apache/spark/shuffle/utils/SerializableDirectBuffer.scala
+++ b/src/main/scala/org/apache/spark/shuffle/utils/SerializableDirectBuffer.scala
@@ -10,19 +10,18 @@ import java.nio.ByteBuffer
import java.nio.channels.Channels
import java.nio.charset.StandardCharsets
-import org.apache.spark.internal.Logging
-import org.apache.spark.util.Utils
+import org.apache.spark.shuffle.utils.{SparkucxUtils, UcxLogging}
/**
* A wrapper around a java.nio.ByteBuffer that is serializable through Java serialization, to make
* it easier to pass ByteBuffers in case class messages.
*/
class SerializableDirectBuffer(@transient var buffer: ByteBuffer) extends Serializable
- with Logging {
+ with UcxLogging {
def value: ByteBuffer = buffer
- private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
+ private def readObject(in: ObjectInputStream): Unit = SparkucxUtils.tryOrIOException {
val length = in.readInt()
buffer = ByteBuffer.allocateDirect(length)
var amountRead = 0
@@ -37,7 +36,7 @@ class SerializableDirectBuffer(@transient var buffer: ByteBuffer) extends Serial
buffer.rewind() // Allow us to read it later
}
- private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
+ private def writeObject(out: ObjectOutputStream): Unit = SparkucxUtils.tryOrIOException {
out.writeInt(buffer.limit())
buffer.rewind()
while (buffer.position() < buffer.limit()) {
@@ -48,11 +47,11 @@ class SerializableDirectBuffer(@transient var buffer: ByteBuffer) extends Serial
}
class DeserializableToExternalMemoryBuffer(@transient var buffer: ByteBuffer)() extends Serializable
- with Logging {
+ with UcxLogging {
def value: ByteBuffer = buffer
- private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
+ private def readObject(in: ObjectInputStream): Unit = SparkucxUtils.tryOrIOException {
val length = in.readInt()
var amountRead = 0
val channel = Channels.newChannel(in)
@@ -79,7 +78,7 @@ object SerializationUtils {
}
def serializeInetAddress(address: InetSocketAddress): ByteBuffer = {
- val hostAddress = new InetSocketAddress(Utils.localCanonicalHostName(), address.getPort)
+ val hostAddress = new InetSocketAddress(address.getAddress.getCanonicalHostName, address.getPort)
val hostString = hostAddress.getHostName.getBytes(StandardCharsets.UTF_8)
val result = ByteBuffer.allocateDirect(hostString.length + 4)
result.putInt(hostAddress.getPort)
diff --git a/src/main/scala/org/apache/spark/shuffle/utils/SparkucxUtils.scala b/src/main/scala/org/apache/spark/shuffle/utils/SparkucxUtils.scala
new file mode 100644
index 00000000..f28ccf0d
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/utils/SparkucxUtils.scala
@@ -0,0 +1,85 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.utils
+
+import java.io.IOException
+import java.math.{MathContext, RoundingMode}
+import java.util.Locale
+import java.util.concurrent.ThreadFactory
+import scala.util.control.NonFatal
+
+/** Small helpers for error handling and size formatting. */
+object SparkucxUtils extends UcxLogging {
+  /**
+   * Runs `block`, logging any non-fatal failure. An `IOException` is rethrown
+   * as is; any other non-fatal throwable is rethrown wrapped in an `IOException`.
+   */
+  def tryOrIOException[T](block: => T): T = {
+    try {
+      block
+    } catch {
+      case NonFatal(failure) =>
+        logError("Exception encountered", failure)
+        throw (failure match {
+          case io: IOException => io
+          case other => new IOException(other)
+        })
+    }
+  }
+
+  /** Formats a byte count as a human-readable string, e.g. "4.0 MB". */
+  def bytesToString(size: Long): String = bytesToString(BigInt(size))
+
+  /**
+   * Formats a byte count using binary units. A unit is used once the size is
+   * at least twice that unit; sizes of 2^11 EB and above are printed in
+   * scientific notation with three significant digits.
+   */
+  def bytesToString(size: BigInt): String = {
+    val exbibyte = 1L << 60
+    // Units from largest to smallest; the first one that fits is chosen.
+    val units = Seq(
+      (exbibyte, "EB"),
+      (1L << 50, "PB"),
+      (1L << 40, "TB"),
+      (1L << 30, "GB"),
+      (1L << 20, "MB"),
+      (1L << 10, "KB"))
+
+    if (size >= BigInt(1L << 11) * exbibyte) {
+      // The number is too large, show it in scientific notation.
+      BigDecimal(size, new MathContext(3, RoundingMode.HALF_UP)).toString() + " B"
+    } else {
+      val (value, unit) = units
+        .find { case (factor, _) => size >= 2 * factor }
+        .map { case (factor, name) => (BigDecimal(size) / factor, name) }
+        .getOrElse((BigDecimal(size), "B"))
+      "%.1f %s".formatLocal(Locale.US, value, unit)
+    }
+  }
+}
+
+/**
+ * Thread factory producing threads named `<prefix>-<default thread name>`.
+ * Threads are daemons and prefixed with "UCX" unless configured otherwise;
+ * the settings in effect when `newThread` is called are the ones applied.
+ */
+class UcxThreadFactory extends ThreadFactory {
+  private var daemon: Boolean = true
+  private var prefix: String = "UCX"
+
+  /** Sets whether created threads are daemons; returns this factory for chaining. */
+  def setDaemon(isDaemon: Boolean): this.type = {
+    daemon = isDaemon
+    this
+  }
+
+  /** Sets the thread name prefix; returns this factory for chaining. */
+  def setPrefix(name: String): this.type = {
+    prefix = name
+    this
+  }
+
+  def newThread(r: Runnable): Thread = {
+    val thread = new Thread(r)
+    thread.setDaemon(daemon)
+    thread.setName(s"${prefix}-${thread.getName}")
+    thread
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/utils/UcxLogging.scala b/src/main/scala/org/apache/spark/shuffle/utils/UcxLogging.scala
new file mode 100644
index 00000000..527e7b86
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/utils/UcxLogging.scala
@@ -0,0 +1,64 @@
+package org.apache.spark.shuffle.utils
+
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
+import org.apache.log4j.LogManager
+
+/**
+ * Mix-in providing lazily created SLF4J logging. Messages are passed by name
+ * and are only evaluated when the corresponding level is enabled.
+ */
+trait UcxLogging {
+  @transient private var log_ : Logger = null
+
+  // Logger name for this object; trailing $'s of Scala object class names are dropped.
+  protected def logName = this.getClass.getName.stripSuffix("$")
+
+  // Creates the logger on first use.
+  protected def log: Logger = {
+    if (log_ == null) {
+      log_ = LoggerFactory.getLogger(logName)
+    }
+    log_
+  }
+
+  // Message-only variants.
+  protected def logInfo(msg: => String): Unit =
+    if (log.isInfoEnabled) log.info(msg)
+
+  protected def logDebug(msg: => String): Unit =
+    if (log.isDebugEnabled) log.debug(msg)
+
+  protected def logTrace(msg: => String): Unit =
+    if (log.isTraceEnabled) log.trace(msg)
+
+  protected def logWarning(msg: => String): Unit =
+    if (log.isWarnEnabled) log.warn(msg)
+
+  protected def logError(msg: => String): Unit =
+    if (log.isErrorEnabled) log.error(msg)
+
+  // Variants that also record a Throwable.
+  protected def logInfo(msg: => String, throwable: Throwable): Unit =
+    if (log.isInfoEnabled) log.info(msg, throwable)
+
+  protected def logDebug(msg: => String, throwable: Throwable): Unit =
+    if (log.isDebugEnabled) log.debug(msg, throwable)
+
+  protected def logTrace(msg: => String, throwable: Throwable): Unit =
+    if (log.isTraceEnabled) log.trace(msg, throwable)
+
+  protected def logWarning(msg: => String, throwable: Throwable): Unit =
+    if (log.isWarnEnabled) log.warn(msg, throwable)
+
+  protected def logError(msg: => String, throwable: Throwable): Unit =
+    if (log.isErrorEnabled) log.error(msg, throwable)
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/utils/UnsafeUtils.scala b/src/main/scala/org/apache/spark/shuffle/utils/UnsafeUtils.scala
index 329fb01b..e4eab255 100755
--- a/src/main/scala/org/apache/spark/shuffle/utils/UnsafeUtils.scala
+++ b/src/main/scala/org/apache/spark/shuffle/utils/UnsafeUtils.scala
@@ -10,9 +10,8 @@ import java.nio.channels.FileChannel
import org.openucx.jucx.UcxException
import sun.nio.ch.{DirectBuffer, FileChannelImpl}
-import org.apache.spark.internal.Logging
-object UnsafeUtils extends Logging {
+object UnsafeUtils extends UcxLogging {
val INT_SIZE: Int = 4
val LONG_SIZE: Int = 8