diff --git a/pom.xml b/pom.xml index 57315f45..8b57446a 100755 --- a/pom.xml +++ b/pom.xml @@ -45,6 +45,8 @@ See file LICENSE for terms. **/spark_2_1/** **/spark_3_0/** + org/apache/spark/network/** + org/apache/spark/shuffle/ucx/external/server/** @@ -71,6 +73,8 @@ See file LICENSE for terms. **/spark_2_1/** **/spark_2_4/** + org/apache/spark/network/** + org/apache/spark/shuffle/ucx/external/server/** @@ -83,6 +87,93 @@ See file LICENSE for terms. **/spark_2_1/**, **/spark_2_4/** + + sparkess-2.4 + + + + net.alchim31.maven + scala-maven-plugin + + + org/apache/spark/network/** + org/apache/spark/shuffle/ucx/external/** + org/apache/spark/shuffle/ucx/memory/UcxLimitedMemPool.scala + org/apache/spark/shuffle/ucx/ShuffleTransport.scala + org/apache/spark/shuffle/utils/** + + + **/spark_3_0/** + org/apache/spark/shuffle/ucx/external/client/** + + + + + + + 2.4.0 + **/spark_3_0/**, **/spark_2_1/** + 2.11.12 + 2.11 + 2.6.7 + + + + org.scala-lang + scala-library + ${scala.version} + + + com.fasterxml.jackson.core + jackson-databind + ${fasterxml.jackson.version} + + + + + sparkess-3.0 + + + + net.alchim31.maven + scala-maven-plugin + + + org/apache/spark/network/** + org/apache/spark/shuffle/ucx/external/** + org/apache/spark/shuffle/ucx/memory/UcxLimitedMemPool.scala + org/apache/spark/shuffle/ucx/ShuffleTransport.scala + org/apache/spark/shuffle/utils/** + + + **/spark_2_4/** + org/apache/spark/shuffle/ucx/external/client/** + + + + + + + 3.0.1 + 2.12.10 + 2.12 + **/spark_2_1/**, **/spark_2_4/** + 2.10.0 + + + + org.scala-lang + scala-library + ${scala.version} + + + com.fasterxml.jackson.core + jackson-databind + ${fasterxml.jackson.version} + + + + @@ -93,10 +184,16 @@ See file LICENSE for terms. ${spark.version} provided + + org.apache.spark + spark-network-yarn_${scala.compat.version} + ${spark.version} + provided + org.openucx jucx - 1.13.1 + 1.16.0 diff --git a/src/main/scala/org/apache/spark/network/shuffle/ExternalUcxShuffleBlockResolver.scala b/src/main/scala/org/apache/spark/network/shuffle/ExternalUcxShuffleBlockResolver.scala new file mode 100644 index 00000000..d8d0cb9f --- /dev/null +++ b/src/main/scala/org/apache/spark/network/shuffle/ExternalUcxShuffleBlockResolver.scala @@ -0,0 +1,91 @@ +package org.apache.spark.network.shuffle + +import java.io.File +import java.nio.charset.StandardCharsets +import java.lang.reflect.{Method, Field} + +import scala.collection.mutable + +import com.fasterxml.jackson.databind.ObjectMapper + +import org.apache.hadoop.metrics2.lib.DefaultMetricsSystem; +import org.apache.hadoop.metrics2.impl.MetricsSystemImpl; + +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo +import org.apache.spark.network.util.TransportConf +import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId + +import org.apache.spark.shuffle.utils.UcxLogging +import org.apache.spark.shuffle.ucx.ExternalUcxServerTransport + +class ExternalUcxShuffleBlockResolver(conf: TransportConf, registeredExecutorFile: File) + extends ExternalShuffleBlockResolver(conf, registeredExecutorFile) with UcxLogging { + private[spark] final val APP_KEY_PREFIX = "AppExecShuffleInfo"; + private[spark] final val ucxMapper = new ObjectMapper + private[spark] var dbAppExecKeyMethod: Method = _ + private[spark] val knownManagers = mutable.Set( + "org.apache.spark.shuffle.sort.SortShuffleManager", + "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager", + "org.apache.spark.shuffle.ExternalUcxShuffleManager") + private[spark] var ucxTransport: ExternalUcxServerTransport = 
_ + + // init() + + private[spark] def dbAppExecKey(appExecId: AppExecId): Array[Byte] = { + // we stick a common prefix on all the keys so we can find them in the DB + val appExecJson = ucxMapper.writeValueAsString(appExecId); + val key = (APP_KEY_PREFIX + ";" + appExecJson); + key.getBytes(StandardCharsets.UTF_8); + } + + // def init(): Unit = { + // val clazz = Class.forName("org.apache.spark.network.shuffle.ExternalShuffleBlockResolver") + // try { + // dbAppExecKeyMethod = clazz.getDeclaredMethod("dbAppExecKey", classOf[AppExecId]) + // dbAppExecKeyMethod.setAccessible(true) + // } catch { + // case e: Exception => { + // logError(s"Get dbAppExecKey from ExternalUcxShuffleBlockResolver failed: $e") + // } + // } + // } + + // def dbAppExecKey(fullId: AppExecId): Array[Byte] = { + // dbAppExecKeyMethod.invoke(this, fullId).asInstanceOf[Array[Byte]] + // } + + def setTransport(transport: ExternalUcxServerTransport): Unit = { + ucxTransport = transport + } + /** Registers a new Executor with all the configuration we need to find its shuffle files. */ + override def registerExecutor( + appId: String, + execId: String, + executorInfo: ExecutorShuffleInfo): Unit = { + val fullId = new AppExecId(appId, execId) + logInfo(s"Registered executor ${fullId} with ${executorInfo}") + if (!knownManagers.contains(executorInfo.shuffleManager)) { + throw new UnsupportedOperationException( + "Unsupported shuffle manager of executor: " + executorInfo) + } + try { + if (db != null) { + val key = dbAppExecKey(fullId) + val value = ucxMapper.writeValueAsString(executorInfo).getBytes(StandardCharsets.UTF_8) + db.put(key, value) + } + executors.put(fullId, executorInfo) + } catch { + case e: Exception => logError("Error saving registered executors", e) + } + } + + override def applicationRemoved(appId: String, cleanupLocalDirs: Boolean): Unit = { + super.applicationRemoved(appId, cleanupLocalDirs) + ucxTransport.applicationRemoved(appId) + } + override def executorRemoved(executorId: String, appId: String): Unit = { + super.executorRemoved(executorId, appId) + ucxTransport.executorRemoved(executorId, appId) + } +} \ No newline at end of file diff --git a/src/main/scala/org/apache/spark/network/shuffle/spark_2_4/ExternalUcxShuffleBlockHandler.scala b/src/main/scala/org/apache/spark/network/shuffle/spark_2_4/ExternalUcxShuffleBlockHandler.scala new file mode 100644 index 00000000..8a092638 --- /dev/null +++ b/src/main/scala/org/apache/spark/network/shuffle/spark_2_4/ExternalUcxShuffleBlockHandler.scala @@ -0,0 +1,20 @@ +package org.apache.spark.network.shuffle + +import java.io.File + +import org.apache.spark.network.server.OneForOneStreamManager +import org.apache.spark.network.util.TransportConf + +import org.apache.spark.shuffle.utils.UcxLogging +import org.apache.spark.shuffle.ucx.ExternalUcxServerTransport + +class ExternalUcxShuffleBlockHandler(conf: TransportConf, registeredExecutorFile: File) + extends ExternalShuffleBlockHandler(new OneForOneStreamManager(), + new ExternalUcxShuffleBlockResolver(conf, registeredExecutorFile)) with UcxLogging { + def ucxBlockManager(): ExternalUcxShuffleBlockResolver = { + blockManager.asInstanceOf[ExternalUcxShuffleBlockResolver] + } + def setTransport(transport: ExternalUcxServerTransport): Unit = { + ucxBlockManager.setTransport(transport) + } +} \ No newline at end of file diff --git a/src/main/scala/org/apache/spark/network/shuffle/spark_3_0/ExternalUcxShuffleBlockHandler.scala 
b/src/main/scala/org/apache/spark/network/shuffle/spark_3_0/ExternalUcxShuffleBlockHandler.scala new file mode 100644 index 00000000..926050d2 --- /dev/null +++ b/src/main/scala/org/apache/spark/network/shuffle/spark_3_0/ExternalUcxShuffleBlockHandler.scala @@ -0,0 +1,20 @@ +package org.apache.spark.network.shuffle + +import java.io.File + +import org.apache.spark.network.server.OneForOneStreamManager +import org.apache.spark.network.util.TransportConf + +import org.apache.spark.shuffle.utils.UcxLogging +import org.apache.spark.shuffle.ucx.ExternalUcxServerTransport + +class ExternalUcxShuffleBlockHandler(conf: TransportConf, registeredExecutorFile: File) + extends ExternalBlockHandler(new OneForOneStreamManager(), + new ExternalUcxShuffleBlockResolver(conf, registeredExecutorFile)) with UcxLogging { + def ucxBlockManager(): ExternalUcxShuffleBlockResolver = { + blockManager.asInstanceOf[ExternalUcxShuffleBlockResolver] + } + def setTransport(transport: ExternalUcxServerTransport): Unit = { + ucxBlockManager.setTransport(transport) + } +} \ No newline at end of file diff --git a/src/main/scala/org/apache/spark/network/yarn/spark_2_4/UcxYarnShuffleService.scala b/src/main/scala/org/apache/spark/network/yarn/spark_2_4/UcxYarnShuffleService.scala new file mode 100644 index 00000000..74d4fb4c --- /dev/null +++ b/src/main/scala/org/apache/spark/network/yarn/spark_2_4/UcxYarnShuffleService.scala @@ -0,0 +1,328 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License") you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.yarn + +import java.io.File +import java.io.IOException +import java.nio.charset.StandardCharsets +import java.nio.ByteBuffer + +import com.fasterxml.jackson.databind.ObjectMapper +import com.google.common.collect.Lists +import com.google.common.base.Preconditions +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.permission.FsPermission +import org.apache.hadoop.yarn.server.api._ +import org.apache.spark.network.util.LevelDBProvider +import org.iq80.leveldb.DB + +import org.apache.spark.network.TransportContext +import org.apache.spark.network.crypto.AuthServerBootstrap +import org.apache.spark.network.sasl.ShuffleSecretManager +import org.apache.spark.network.server.TransportServer +import org.apache.spark.network.server.TransportServerBootstrap +import org.apache.spark.network.util.TransportConf +import org.apache.spark.network.yarn.util.HadoopConfigProvider + +import org.apache.spark.shuffle.utils.UcxLogging +import org.apache.spark.network.yarn.YarnShuffleService.AppId +import org.apache.spark.network.shuffle.ExternalUcxShuffleBlockHandler +import org.apache.spark.shuffle.ucx.ExternalUcxServerConf +import org.apache.spark.shuffle.ucx.ExternalUcxServerTransport + +class UcxYarnShuffleService extends AuxiliaryService("sparkucx_shuffle") with UcxLogging { + private[this] var ucxTransport: ExternalUcxServerTransport = _ + private[this] var secretManager: ShuffleSecretManager = _ + private[this] var shuffleServer: TransportServer = _ + private[this] var _conf: Configuration = _ + private[this] var _recoveryPath: Path = _ + private[this] var blockHandler: ExternalUcxShuffleBlockHandler = _ + private[this] var registeredExecutorFile: File = _ + private[this] var secretsFile: File = _ + private[this] var db: DB = _ + + UcxYarnShuffleService.instance = this + /** + * Return whether authentication is enabled as specified by the configuration. + * If so, fetch requests will fail unless the appropriate authentication secret + * for the application is provided. + */ + def isAuthenticationEnabled(): Boolean = secretManager != null + /** + * Start the shuffle server with the given configuration. + */ + override protected def serviceInit(conf: Configuration) = { + _conf = conf + + val stopOnFailure = conf.getBoolean( + UcxYarnShuffleService.STOP_ON_FAILURE_KEY, + UcxYarnShuffleService.DEFAULT_STOP_ON_FAILURE) + + try { + // In case this NM was killed while there were running spark applications, we need to restore + // lost state for the existing executors. We look for an existing file in the NM's local dirs. + // If we don't find one, then we choose a file to use to save the state next time. 
Even if
+      // an application was stopped while the NM was down, we expect YARN to call stopApplication()
+      // when it comes back.
+      if (_recoveryPath != null) {
+        registeredExecutorFile = initRecoveryDb(UcxYarnShuffleService.RECOVERY_FILE_NAME)
+      }
+
+      val transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf))
+      blockHandler = new ExternalUcxShuffleBlockHandler(transportConf, registeredExecutorFile)
+
+      // If authentication is enabled, set up the shuffle server to use a
+      // special RPC handler that filters out unauthenticated fetch requests
+      val bootstraps = Lists.newArrayList[TransportServerBootstrap]()
+      val authEnabled = conf.getBoolean(
+        UcxYarnShuffleService.SPARK_AUTHENTICATE_KEY,
+        UcxYarnShuffleService.DEFAULT_SPARK_AUTHENTICATE)
+      if (authEnabled) {
+        secretManager = new ShuffleSecretManager()
+        if (_recoveryPath != null) {
+          loadSecretsFromDb()
+        }
+        bootstraps.add(new AuthServerBootstrap(transportConf, secretManager))
+      }
+
+      // Users may not want the TCP shuffle service replaced, so use a separate port to transfer executor info.
+      val portConf = conf.getInt(
+        ExternalUcxServerConf.SPARK_UCX_SHUFFLE_SERVICE_TCP_PORT_KEY,
+        conf.getInt(
+          UcxYarnShuffleService.SPARK_SHUFFLE_SERVICE_PORT_KEY,
+          UcxYarnShuffleService.DEFAULT_SPARK_SHUFFLE_SERVICE_PORT))
+      val transportContext = new TransportContext(transportConf, blockHandler)
+      shuffleServer = transportContext.createServer(portConf, bootstraps)
+      // the port should normally be fixed, but for tests it's useful to find an open port
+      val port = shuffleServer.getPort()
+
+      val authEnabledString = if (authEnabled) "enabled" else "not enabled"
+      logInfo(s"Started YARN shuffle service for Spark on port ${port}. " +
+        s"Authentication is ${authEnabledString}. Registered executor file is ${registeredExecutorFile}")
+
+      // UCX transport
+      logInfo("Launching ExternalUcxServerTransport")
+      val ucxConf = new ExternalUcxServerConf(conf)
+      ucxTransport = new ExternalUcxServerTransport(ucxConf, blockHandler.ucxBlockManager)
+      ucxTransport.init()
+      blockHandler.setTransport(ucxTransport)
+    } catch {
+      case e: Exception => if (stopOnFailure) {
+        throw e
+      } else {
+        // logError(s"Start UcxYarnShuffleService failed: $e")
+        noteFailure(e)
+      }
+    }
+  }
+
+  private def loadSecretsFromDb(): Unit = {
+    secretsFile = initRecoveryDb(UcxYarnShuffleService.SECRETS_RECOVERY_FILE_NAME)
+
+    // Make sure this is protected in case it's not in the NM recovery dir
+    val fs = FileSystem.getLocal(_conf)
+    fs.mkdirs(new Path(secretsFile.getPath()), new FsPermission(448.toShort)) // 0700
+
+    db = LevelDBProvider.initLevelDB(secretsFile, UcxYarnShuffleService.CURRENT_VERSION, UcxYarnShuffleService.mapper)
+    logInfo("Recovery location is: " + secretsFile.getPath())
+    if (db != null) {
+      logInfo("Going to reload spark shuffle data")
+      val itr = db.iterator()
+      itr.seek(UcxYarnShuffleService.APP_CREDS_KEY_PREFIX.getBytes(StandardCharsets.UTF_8))
+      while (itr.hasNext()) {
+        val e = itr.next()
+        val key = new String(e.getKey(), StandardCharsets.UTF_8)
+        if (!key.startsWith(UcxYarnShuffleService.APP_CREDS_KEY_PREFIX)) {
+          return
+        }
+        val id = UcxYarnShuffleService.parseDbAppKey(key)
+        val secret = UcxYarnShuffleService.mapper.readValue(e.getValue(), classOf[ByteBuffer])
+        logInfo("Reloading tokens for app: " + id)
+        secretManager.registerApp(id, secret)
+      }
+    }
+  }
+
+  override def initializeApplication(context: ApplicationInitializationContext): Unit = {
+    val appId = context.getApplicationId().toString()
+    try {
+      val shuffleSecret = context.getApplicationDataForService()
+      if (isAuthenticationEnabled()) {
+        val fullId = new AppId(appId)
+        if (db != null) {
+          val key = UcxYarnShuffleService.dbAppKey(fullId)
+          val value = UcxYarnShuffleService.mapper.writeValueAsString(shuffleSecret).getBytes(StandardCharsets.UTF_8)
+          db.put(key, value)
+        }
+        secretManager.registerApp(appId, shuffleSecret)
+      }
+    } catch {
+      case e: Exception => logError(s"Exception when initializing application ${appId}", e)
+    }
+  }
+
+  override def stopApplication(context: ApplicationTerminationContext): Unit = {
+    val appId = context.getApplicationId().toString()
+    try {
+      if (isAuthenticationEnabled()) {
+        val fullId = new AppId(appId)
+        if (db != null) {
+          try {
+            db.delete(UcxYarnShuffleService.dbAppKey(fullId))
+          } catch {
+            case e: IOException => logError(s"Error deleting ${appId} from executor state db", e)
+          }
+        }
+        secretManager.unregisterApp(appId)
+      }
+      blockHandler.applicationRemoved(appId, false /* cleanupLocalDirs */)
+    } catch {
+      case e: Exception => logError(s"Exception when stopping application ${appId}", e)
+    }
+  }
+
+  override def initializeContainer(context: ContainerInitializationContext): Unit = {}
+
+  override def stopContainer(context: ContainerTerminationContext): Unit = {}
+
+  // Not currently used
+  override def getMetaData(): ByteBuffer = {
+    return ByteBuffer.allocate(0)
+  }
+
+  /**
+   * Set the recovery path for shuffle service recovery when the NM is restarted. This will be
+   * called by the NM if NM recovery is enabled.
+   */
+  override def setRecoveryPath(recoveryPath: Path): Unit = {
+    _recoveryPath = recoveryPath
+  }
+
+  /**
+   * Get the path specific to this auxiliary service to use for recovery.
+   */
+  protected def getRecoveryPath(fileName: String): Path = {
+    return _recoveryPath
+  }
+
+  /**
+   * Figure out the recovery path and handle moving the DB if YARN NM recovery gets enabled
+   * and a DB exists in an NM local dir left by an old version of the shuffle service.
+   */
+  def initRecoveryDb(dbName: String): File = {
+    if (_recoveryPath == null) {
+      throw new NullPointerException("recovery path should not be null if NM recovery is enabled")
+    }
+
+    val recoveryFile = new File(_recoveryPath.toUri().getPath(), dbName)
+    if (recoveryFile.exists()) {
+      return recoveryFile
+    }
+
+    // DB doesn't exist in the recovery path; go check the NM local dirs for it
+    val localDirs = _conf.getTrimmedStrings("yarn.nodemanager.local-dirs")
+    for (dir <- localDirs) {
+      val f = new File(new Path(dir).toUri().getPath(), dbName)
+      if (f.exists()) {
+        // If the recovery path is set then either NM recovery is enabled or another recovery
+        // DB has been initialized. If NM recovery is enabled and had set the recovery path,
+        // make sure to move all DBs to the recovery path from the old NM local dirs.
+        // If another DB was initialized first just make sure all the DBs are in the same
+        // location.
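+        // A hedged worked example of the move performed below (paths are
+        // illustrative, not taken from a real deployment): with
+        //   yarn.nodemanager.local-dirs = /data1/yarn,/data2/yarn
+        //   _recoveryPath               = /var/yarn/nm-recovery
+        //   dbName                      = "registeredExecutors.ldb"
+        // an old DB found at /data1/yarn/registeredExecutors.ldb is renamed to
+        // /var/yarn/nm-recovery/registeredExecutors.ldb, so subsequent restarts
+        // read a single canonical location.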
+ val newLoc = new Path(_recoveryPath, dbName) + val copyFrom = new Path(f.toURI()) + if (!newLoc.equals(copyFrom)) { + logInfo("Moving " + copyFrom + " to: " + newLoc) + try { + // The move here needs to handle moving non-empty directories across NFS mounts + val fs = FileSystem.getLocal(_conf) + fs.rename(copyFrom, newLoc) + } catch { + // Fail to move recovery file to new path, just continue on with new DB location + case e: Exception => logError( + s"Failed to move recovery file ${dbName} to the path ${_recoveryPath.toString()}", + e) + } + } + return new File(newLoc.toUri().getPath()) + } + } + + return new File(_recoveryPath.toUri().getPath(), dbName) + } + + override protected def serviceStop(): Unit = { + try { + if (shuffleServer != null) { + shuffleServer.close() + } + if (blockHandler != null) { + blockHandler.close() + } + if (ucxTransport != null) { + ucxTransport.close() + } + if (db != null) { + db.close() + } + } catch { + case e: Exception => logError("Exception when stopping service", e) + } + } +} + +object UcxYarnShuffleService { + // Port on which the shuffle server listens for fetch requests + val SPARK_SHUFFLE_SERVICE_PORT_KEY = "spark.shuffle.service.port" + val DEFAULT_SPARK_SHUFFLE_SERVICE_PORT = 7337 + + // Whether the shuffle server should authenticate fetch requests + val SPARK_AUTHENTICATE_KEY = "spark.authenticate" + val DEFAULT_SPARK_AUTHENTICATE = false + + val RECOVERY_FILE_NAME = "registeredExecutors.ldb" + val SECRETS_RECOVERY_FILE_NAME = "sparkShuffleRecovery.ldb" + + // Whether failure during service initialization should stop the NM. + val STOP_ON_FAILURE_KEY = "spark.yarn.shuffle.stopOnFailure" + val DEFAULT_STOP_ON_FAILURE = false + + val mapper = new ObjectMapper() + val APP_CREDS_KEY_PREFIX = "AppCreds" + val CURRENT_VERSION = new LevelDBProvider.StoreVersion(1, 0) + + var instance: UcxYarnShuffleService = _ + + private def parseDbAppKey(s: String): String = { + if (!s.startsWith(APP_CREDS_KEY_PREFIX)) { + throw new IllegalArgumentException("expected a string starting with " + APP_CREDS_KEY_PREFIX) + } + val json = s.substring(APP_CREDS_KEY_PREFIX.length() + 1) + val parsed = mapper.readValue(json, classOf[AppId]) + return parsed.appId + } + + private def dbAppKey(appExecId: AppId): Array[Byte] = { + // we stick a common prefix on all the keys so we can find them in the DB + val appExecJson = mapper.writeValueAsString(appExecId) + val key = (APP_CREDS_KEY_PREFIX + ";" + appExecJson) + return key.getBytes(StandardCharsets.UTF_8) + } +} \ No newline at end of file diff --git a/src/main/scala/org/apache/spark/network/yarn/spark_3_0/UcxYarnShuffleService.scala b/src/main/scala/org/apache/spark/network/yarn/spark_3_0/UcxYarnShuffleService.scala new file mode 100644 index 00000000..35e977b8 --- /dev/null +++ b/src/main/scala/org/apache/spark/network/yarn/spark_3_0/UcxYarnShuffleService.scala @@ -0,0 +1,345 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License") you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.yarn + +import java.io.File +import java.io.IOException +import java.nio.charset.StandardCharsets +import java.nio.ByteBuffer + +import com.fasterxml.jackson.databind.ObjectMapper +import com.google.common.collect.Lists +import com.google.common.base.Preconditions +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.permission.FsPermission +import org.apache.hadoop.metrics2.impl.MetricsSystemImpl +import org.apache.hadoop.metrics2.lib.DefaultMetricsSystem +import org.apache.hadoop.yarn.server.api._ +import org.apache.spark.network.util.LevelDBProvider +import org.iq80.leveldb.DB + +import org.apache.spark.network.TransportContext +import org.apache.spark.network.crypto.AuthServerBootstrap +import org.apache.spark.network.sasl.ShuffleSecretManager +import org.apache.spark.network.server.TransportServer +import org.apache.spark.network.server.TransportServerBootstrap +import org.apache.spark.network.util.TransportConf +import org.apache.spark.network.yarn.util.HadoopConfigProvider + +import org.apache.spark.shuffle.utils.UcxLogging +import org.apache.spark.network.yarn.YarnShuffleService.AppId +import org.apache.spark.network.shuffle.ExternalUcxShuffleBlockHandler +import org.apache.spark.shuffle.ucx.ExternalUcxServerConf +import org.apache.spark.shuffle.ucx.ExternalUcxServerTransport + +class UcxYarnShuffleService extends AuxiliaryService("sparkucx_shuffle") with UcxLogging { + private[this] var ucxTransport: ExternalUcxServerTransport = _ + private[this] var secretManager: ShuffleSecretManager = _ + private[this] var shuffleServer: TransportServer = _ + private[this] var _conf: Configuration = _ + private[this] var _recoveryPath: Path = _ + private[this] var blockHandler: ExternalUcxShuffleBlockHandler = _ + private[this] var registeredExecutorFile: File = _ + private[this] var secretsFile: File = _ + private[this] var db: DB = _ + + UcxYarnShuffleService.instance = this + /** + * Return whether authentication is enabled as specified by the configuration. + * If so, fetch requests will fail unless the appropriate authentication secret + * for the application is provided. + */ + def isAuthenticationEnabled(): Boolean = secretManager != null + /** + * Start the shuffle server with the given configuration. + */ + override protected def serviceInit(conf: Configuration) = { + _conf = conf + + val stopOnFailure = conf.getBoolean( + UcxYarnShuffleService.STOP_ON_FAILURE_KEY, + UcxYarnShuffleService.DEFAULT_STOP_ON_FAILURE) + + try { + // In case this NM was killed while there were running spark applications, we need to restore + // lost state for the existing executors. We look for an existing file in the NM's local dirs. + // If we don't find one, then we choose a file to use to save the state next time. 
Even if
+      // an application was stopped while the NM was down, we expect YARN to call stopApplication()
+      // when it comes back.
+      if (_recoveryPath != null) {
+        registeredExecutorFile = initRecoveryDb(UcxYarnShuffleService.RECOVERY_FILE_NAME)
+      }
+
+      val transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf))
+      blockHandler = new ExternalUcxShuffleBlockHandler(transportConf, registeredExecutorFile)
+
+      // If authentication is enabled, set up the shuffle server to use a
+      // special RPC handler that filters out unauthenticated fetch requests
+      val bootstraps = Lists.newArrayList[TransportServerBootstrap]()
+      val authEnabled = conf.getBoolean(
+        UcxYarnShuffleService.SPARK_AUTHENTICATE_KEY,
+        UcxYarnShuffleService.DEFAULT_SPARK_AUTHENTICATE)
+      if (authEnabled) {
+        secretManager = new ShuffleSecretManager()
+        if (_recoveryPath != null) {
+          loadSecretsFromDb()
+        }
+        bootstraps.add(new AuthServerBootstrap(transportConf, secretManager))
+      }
+
+      // Users may not want the TCP shuffle service replaced, so use a separate port to transfer executor info.
+      val portConf = conf.getInt(
+        ExternalUcxServerConf.SPARK_UCX_SHUFFLE_SERVICE_TCP_PORT_KEY,
+        conf.getInt(
+          UcxYarnShuffleService.SPARK_SHUFFLE_SERVICE_PORT_KEY,
+          UcxYarnShuffleService.DEFAULT_SPARK_SHUFFLE_SERVICE_PORT))
+      val transportContext = new TransportContext(transportConf, blockHandler)
+      shuffleServer = transportContext.createServer(portConf, bootstraps)
+      // the port should normally be fixed, but for tests it's useful to find an open port
+      val port = shuffleServer.getPort()
+
+      val authEnabledString = if (authEnabled) "enabled" else "not enabled"
+      logInfo(s"Started YARN shuffle service for Spark on port ${port}. " +
+        s"Authentication is ${authEnabledString}. Registered executor file is ${registeredExecutorFile}")
+
+      // Register metrics from the block handler into the Node Manager's metrics system.
+      blockHandler.getAllMetrics().getMetrics().put("numRegisteredConnections",
+        shuffleServer.getRegisteredConnections())
+      val serviceMetrics =
+        new YarnShuffleServiceMetrics(blockHandler.getAllMetrics())
+
+      val metricsSystem = DefaultMetricsSystem.instance().asInstanceOf[MetricsSystemImpl]
+      metricsSystem.register(
+        "sparkUcxShuffleService", "Metrics on the Spark Shuffle Service", serviceMetrics)
+      logInfo("Registered metrics with Hadoop's DefaultMetricsSystem")
+
+      // UCX transport
+      logInfo("Launching ExternalUcxServerTransport")
+      val ucxConf = new ExternalUcxServerConf(conf)
+      ucxTransport = new ExternalUcxServerTransport(ucxConf, blockHandler.ucxBlockManager)
+      ucxTransport.init()
+      blockHandler.setTransport(ucxTransport)
+    } catch {
+      case e: Exception => if (stopOnFailure) {
+        throw e
+      } else {
+        // logError(s"Start UcxYarnShuffleService failed: $e")
+        noteFailure(e)
+      }
+    }
+  }
+
+  private def loadSecretsFromDb(): Unit = {
+    secretsFile = initRecoveryDb(UcxYarnShuffleService.SECRETS_RECOVERY_FILE_NAME)
+
+    // Make sure this is protected in case it's not in the NM recovery dir
+    val fs = FileSystem.getLocal(_conf)
+    fs.mkdirs(new Path(secretsFile.getPath()), new FsPermission(448.toShort)) // 0700
+
+    db = LevelDBProvider.initLevelDB(secretsFile, UcxYarnShuffleService.CURRENT_VERSION, UcxYarnShuffleService.mapper)
+    logInfo("Recovery location is: " + secretsFile.getPath())
+    if (db != null) {
+      logInfo("Going to reload spark shuffle data")
+      val itr = db.iterator()
+      itr.seek(UcxYarnShuffleService.APP_CREDS_KEY_PREFIX.getBytes(StandardCharsets.UTF_8))
+      while (itr.hasNext()) {
+        val e = itr.next()
+        val key = new String(e.getKey(), StandardCharsets.UTF_8)
+        if (!key.startsWith(UcxYarnShuffleService.APP_CREDS_KEY_PREFIX)) {
+          return
+        }
+        val id = UcxYarnShuffleService.parseDbAppKey(key)
+        val secret = UcxYarnShuffleService.mapper.readValue(e.getValue(), classOf[ByteBuffer])
+        logInfo("Reloading tokens for app: " + id)
+        secretManager.registerApp(id, secret)
+      }
+    }
+  }
+
+  override def initializeApplication(context: ApplicationInitializationContext): Unit = {
+    val appId = context.getApplicationId().toString()
+    try {
+      val shuffleSecret = context.getApplicationDataForService()
+      if (isAuthenticationEnabled()) {
+        val fullId = new AppId(appId)
+        if (db != null) {
+          val key = UcxYarnShuffleService.dbAppKey(fullId)
+          val value = UcxYarnShuffleService.mapper.writeValueAsString(shuffleSecret).getBytes(StandardCharsets.UTF_8)
+          db.put(key, value)
+        }
+        secretManager.registerApp(appId, shuffleSecret)
+      }
+    } catch {
+      case e: Exception => logError(s"Exception when initializing application ${appId}", e)
+    }
+  }
+
+  override def stopApplication(context: ApplicationTerminationContext): Unit = {
+    val appId = context.getApplicationId().toString()
+    try {
+      if (isAuthenticationEnabled()) {
+        val fullId = new AppId(appId)
+        if (db != null) {
+          try {
+            db.delete(UcxYarnShuffleService.dbAppKey(fullId))
+          } catch {
+            case e: IOException => logError(s"Error deleting ${appId} from executor state db", e)
+          }
+        }
+        secretManager.unregisterApp(appId)
+      }
+      blockHandler.applicationRemoved(appId, false /* cleanupLocalDirs */)
+    } catch {
+      case e: Exception => logError(s"Exception when stopping application ${appId}", e)
+    }
+  }
+
+  override def initializeContainer(context: ContainerInitializationContext): Unit = {}
+
+  override def stopContainer(context: ContainerTerminationContext): Unit = {}
+
+  // Not currently used
+  override def getMetaData(): ByteBuffer = {
+    return ByteBuffer.allocate(0)
+  }
+
+  /**
+   * Set the recovery path for shuffle service recovery when the NM is restarted. This will be
+   * called by the NM if NM recovery is enabled.
+   */
+  override def setRecoveryPath(recoveryPath: Path): Unit = {
+    _recoveryPath = recoveryPath
+  }
+
+  /**
+   * Get the path specific to this auxiliary service to use for recovery.
+   */
+  protected def getRecoveryPath(fileName: String): Path = {
+    return _recoveryPath
+  }
+
+  /**
+   * Figure out the recovery path and handle moving the DB if YARN NM recovery gets enabled
+   * and a DB exists in an NM local dir left by an old version of the shuffle service.
+   */
+  def initRecoveryDb(dbName: String): File = {
+    if (_recoveryPath == null) {
+      throw new NullPointerException("recovery path should not be null if NM recovery is enabled")
+    }
+
+    val recoveryFile = new File(_recoveryPath.toUri().getPath(), dbName)
+    if (recoveryFile.exists()) {
+      return recoveryFile
+    }
+
+    // DB doesn't exist in the recovery path; go check the NM local dirs for it
+    val localDirs = _conf.getTrimmedStrings("yarn.nodemanager.local-dirs")
+    for (dir <- localDirs) {
+      val f = new File(new Path(dir).toUri().getPath(), dbName)
+      if (f.exists()) {
+        // If the recovery path is set then either NM recovery is enabled or another recovery
+        // DB has been initialized. If NM recovery is enabled and had set the recovery path,
+        // make sure to move all DBs to the recovery path from the old NM local dirs.
+        // If another DB was initialized first just make sure all the DBs are in the same
+        // location.
+        val newLoc = new Path(_recoveryPath, dbName)
+        val copyFrom = new Path(f.toURI())
+        if (!newLoc.equals(copyFrom)) {
+          logInfo("Moving " + copyFrom + " to: " + newLoc)
+          try {
+            // The move here needs to handle moving non-empty directories across NFS mounts
+            val fs = FileSystem.getLocal(_conf)
+            fs.rename(copyFrom, newLoc)
+          } catch {
+            // Fail to move recovery file to new path, just continue on with new DB location
+            case e: Exception => logError(
+              s"Failed to move recovery file ${dbName} to the path ${_recoveryPath.toString()}",
+              e)
+          }
+        }
+        return new File(newLoc.toUri().getPath())
+      }
+    }
+
+    return new File(_recoveryPath.toUri().getPath(), dbName)
+  }
+
+  override protected def serviceStop(): Unit = {
+    try {
+      if (shuffleServer != null) {
+        shuffleServer.close()
+      }
+      if (blockHandler != null) {
+        blockHandler.close()
+      }
+      if (ucxTransport != null) {
+        ucxTransport.close()
+      }
+      if (db != null) {
+        db.close()
+      }
+    } catch {
+      case e: Exception => logError("Exception when stopping service", e)
+    }
+  }
+}
+
+object UcxYarnShuffleService {
+  // Port on which the shuffle server listens for fetch requests
+  val SPARK_SHUFFLE_SERVICE_PORT_KEY = "spark.shuffle.service.port"
+  val DEFAULT_SPARK_SHUFFLE_SERVICE_PORT = 7337
+
+  // Whether the shuffle server should authenticate fetch requests
+  val SPARK_AUTHENTICATE_KEY = "spark.authenticate"
+  val DEFAULT_SPARK_AUTHENTICATE = false
+
+  val RECOVERY_FILE_NAME = "registeredExecutors.ldb"
+  val SECRETS_RECOVERY_FILE_NAME = "sparkShuffleRecovery.ldb"
+
+  // Whether failure during service initialization should stop the NM.
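+  // A hedged yarn-site.xml sketch for wiring this auxiliary service into the
+  // NodeManager (the service name matches the constructor argument above;
+  // values are illustrative, not taken from a real cluster):
+  //   <property>
+  //     <name>yarn.nodemanager.aux-services</name>
+  //     <value>sparkucx_shuffle</value>
+  //   </property>
+  //   <property>
+  //     <name>yarn.nodemanager.aux-services.sparkucx_shuffle.class</name>
+  //     <value>org.apache.spark.network.yarn.UcxYarnShuffleService</value>
+  //   </property>
+  //   <property>
+  //     <name>spark.yarn.shuffle.stopOnFailure</name>
+  //     <value>false</value>
+  //   </property>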
+ val STOP_ON_FAILURE_KEY = "spark.yarn.shuffle.stopOnFailure" + val DEFAULT_STOP_ON_FAILURE = false + + val mapper = new ObjectMapper() + val APP_CREDS_KEY_PREFIX = "AppCreds" + val CURRENT_VERSION = new LevelDBProvider.StoreVersion(1, 0) + + var instance: UcxYarnShuffleService = _ + + private def parseDbAppKey(s: String): String = { + if (!s.startsWith(APP_CREDS_KEY_PREFIX)) { + throw new IllegalArgumentException("expected a string starting with " + APP_CREDS_KEY_PREFIX) + } + val json = s.substring(APP_CREDS_KEY_PREFIX.length() + 1) + val parsed = mapper.readValue(json, classOf[AppId]) + return parsed.appId + } + + private def dbAppKey(appExecId: AppId): Array[Byte] = { + // we stick a common prefix on all the keys so we can find them in the DB + val appExecJson = mapper.writeValueAsString(appExecId) + val key = (APP_CREDS_KEY_PREFIX + ";" + appExecJson) + return key.getBytes(StandardCharsets.UTF_8) + } +} \ No newline at end of file diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/ExternalUcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/ExternalUcxShuffleClient.scala new file mode 100644 index 00000000..763e759d --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/ExternalUcxShuffleClient.scala @@ -0,0 +1,49 @@ +package org.apache.spark.shuffle.compat.spark_2_4 + +import org.openucx.jucx.UcxUtils +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ShuffleClient} +import org.apache.spark.shuffle.ucx.{OperationCallback, OperationResult, UcxShuffleBlockId, ExternalUcxClientTransport, UcxFetchCallBack, UcxDownloadCallBack} +import org.apache.spark.shuffle.utils.UnsafeUtils +import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId} + +class ExternalUcxShuffleClient(val transport: ExternalUcxClientTransport) extends ShuffleClient{ + + override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String], + listener: BlockFetchingListener, + downloadFileManager: DownloadFileManager): Unit = { + if (downloadFileManager == null) { + val ucxBlockIds = Array.ofDim[UcxShuffleBlockId](blockIds.length) + val callbacks = Array.ofDim[OperationCallback](blockIds.length) + for (i <- blockIds.indices) { + val blockId = SparkBlockId.apply(blockIds(i)) + .asInstanceOf[SparkShuffleBlockId] + ucxBlockIds(i) = UcxShuffleBlockId(blockId.shuffleId, blockId.mapId, + blockId.reduceId) + callbacks(i) = new UcxFetchCallBack(blockIds(i), listener) + } + val maxBlocksPerRequest = transport.getMaxBlocksPerRequest + for (i <- 0 until ucxBlockIds.length by maxBlocksPerRequest) { + val j = i + maxBlocksPerRequest + transport.fetchBlocksByBlockIds(host, execId.toInt, + ucxBlockIds.slice(i, j), + callbacks.slice(i, j)) + } + } else { + for (i <- blockIds.indices) { + val blockId = SparkBlockId.apply(blockIds(i)) + .asInstanceOf[SparkShuffleBlockId] + val ucxBid = UcxShuffleBlockId(blockId.shuffleId, blockId.mapId, + blockId.reduceId) + val callback = new UcxDownloadCallBack(blockIds(i), listener, + downloadFileManager, + transport.sparkTransportConf) + transport.fetchBlockByStream(host, execId.toInt, ucxBid, callback) + } + } + } + + override def close(): Unit = { + + } +} \ No newline at end of file diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/ExternalUcxShuffleManager.scala 
b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/ExternalUcxShuffleManager.scala new file mode 100644 index 00000000..06d9956b --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/ExternalUcxShuffleManager.scala @@ -0,0 +1,23 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle + +import org.apache.spark.shuffle.compat.spark_2_4.{ExternalUcxShuffleClient, ExternalUcxShuffleReader} +import org.apache.spark.shuffle.ucx.ExternalBaseUcxShuffleManager +import org.apache.spark.{SparkConf, TaskContext} + +/** + * Common part for all spark versions for UcxShuffleManager logic + */ +class ExternalUcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) + extends ExternalBaseUcxShuffleManager(conf, isDriver) { + private[spark] lazy val transport = awaitUcxTransport + private[spark] lazy val shuffleClient = new ExternalUcxShuffleClient(transport) + override def getReader[K, C](handle: ShuffleHandle, startPartition: Int, + endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { + new ExternalUcxShuffleReader(handle.asInstanceOf[BaseShuffleHandle[K,_,C]], startPartition, + endPartition, context, shuffleClient) + } +} \ No newline at end of file diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/ExternalUcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/ExternalUcxShuffleReader.scala new file mode 100644 index 00000000..16d19b36 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/ExternalUcxShuffleReader.scala @@ -0,0 +1,117 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.compat.spark_2_4 + +import java.io.InputStream +import java.util.concurrent.LinkedBlockingQueue + +import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext} +import org.apache.spark.internal.{Logging, config} +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator} +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.ExternalSorter + +/** + * Extension of Spark's shuffle reader with a logic of injection UcxShuffleClient, + * and lazy progress only when result queue is empty. 
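+ * That is, block fetches are routed through the external UCX shuffle service
+ * client rather than Spark's Netty-based shuffle client.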
+ */ +class ExternalUcxShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C], + startPartition: Int, + endPartition: Int, + context: TaskContext, + shuffleClient: ExternalUcxShuffleClient, + serializerManager: SerializerManager = SparkEnv.get.serializerManager, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) + extends ShuffleReader[K, C] with Logging { + + private val dep = handle.dependency + + /** Read the combined key-values for this reduce task */ + override def read(): Iterator[Product2[K, C]] = { + val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() + val wrappedStreams = new ShuffleBlockFetcherIterator( + context, + shuffleClient, + blockManager, + mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, + startPartition, endPartition), + serializerManager.wrapStream, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsBytes("spark.reducer.maxSizeInFlight", "48m"), + SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), + SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) + + val serializerInstance = dep.serializer.newInstance() + val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } + + // Update the context task metrics for each record read. + val readMetrics = context.taskMetrics.createTempShuffleReadMetrics() + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map { record => + readMetrics.incRecordsRead(1) + record + }, + context.taskMetrics().mergeShuffleReadMetrics()) + + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) + + val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { + if (dep.mapSideCombine) { + // We are reading values that are already combined + val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) + } else { + // We don't know the value type, but also don't care -- the dependency *should* + // have made sure its compatible w/ this aggregator, which will convert the value + // type to the combined type C + val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) + } + } else { + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] + } + + // Sort the output if there is a sort ordering defined. + val resultIter = dep.keyOrdering match { + case Some(keyOrd: Ordering[K]) => + // Create an ExternalSorter to sort the data. 
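+        // ExternalSorter spills to disk when the data being sorted does not fit
+        // in execution memory; the spill metrics recorded below come from those
+        // spills.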
+ val sorter = + new ExternalSorter[K, C, C](context, + ordering = Some(keyOrd), serializer = dep.serializer) + sorter.insertAll(aggregatedIter) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) + // Use completion callback to stop sorter if task was finished/cancelled. + context.addTaskCompletionListener[Unit](_ => { + sorter.stop() + }) + CompletionIterator[Product2[K, C], + Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) + case None => + aggregatedIter + } + + resultIter match { + case _: InterruptibleIterator[Product2[K, C]] => resultIter + case _ => + // Use another interruptible iterator here to support task cancellation as aggregator + // or(and) sorter may have consumed previous interruptible iterator. + new InterruptibleIterator[Product2[K, C]](context, resultIter) + } + } + +} \ No newline at end of file diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/ExternalUcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/ExternalUcxShuffleClient.scala new file mode 100644 index 00000000..2a0c8814 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/ExternalUcxShuffleClient.scala @@ -0,0 +1,49 @@ +package org.apache.spark.shuffle.compat.spark_3_0 + +import org.openucx.jucx.UcxUtils +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, BlockStoreClient} +import org.apache.spark.shuffle.ucx.{OperationCallback, OperationResult, UcxShuffleBlockId, ExternalUcxClientTransport, UcxFetchCallBack, UcxDownloadCallBack} +import org.apache.spark.shuffle.utils.UnsafeUtils +import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId} + +class ExternalUcxShuffleClient(val transport: ExternalUcxClientTransport) extends BlockStoreClient{ + + override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String], + listener: BlockFetchingListener, + downloadFileManager: DownloadFileManager): Unit = { + if (downloadFileManager == null) { + val ucxBlockIds = Array.ofDim[UcxShuffleBlockId](blockIds.length) + val callbacks = Array.ofDim[OperationCallback](blockIds.length) + for (i <- blockIds.indices) { + val blockId = SparkBlockId.apply(blockIds(i)) + .asInstanceOf[SparkShuffleBlockId] + ucxBlockIds(i) = UcxShuffleBlockId(blockId.shuffleId, blockId.mapId, + blockId.reduceId) + callbacks(i) = new UcxFetchCallBack(blockIds(i), listener) + } + val maxBlocksPerRequest = transport.getMaxBlocksPerRequest + for (i <- 0 until blockIds.length by maxBlocksPerRequest) { + val j = i + maxBlocksPerRequest + transport.fetchBlocksByBlockIds(host, execId.toInt, + ucxBlockIds.slice(i, j), + callbacks.slice(i, j)) + } + } else { + for (i <- blockIds.indices) { + val blockId = SparkBlockId.apply(blockIds(i)) + .asInstanceOf[SparkShuffleBlockId] + val ucxBid = UcxShuffleBlockId(blockId.shuffleId, blockId.mapId, + blockId.reduceId) + val callback = new UcxDownloadCallBack(blockIds(i), listener, + downloadFileManager, + transport.sparkTransportConf) + transport.fetchBlockByStream(host, execId.toInt, ucxBid, callback) + } + } + } + + override def close(): Unit = { + + } +} \ No newline at end of file diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/ExternalUcxShuffleManager.scala 
b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/ExternalUcxShuffleManager.scala new file mode 100644 index 00000000..49d099aa --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/ExternalUcxShuffleManager.scala @@ -0,0 +1,26 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle + +import org.apache.spark.shuffle.compat.spark_3_0.{ExternalUcxShuffleClient, ExternalUcxShuffleReader} +import org.apache.spark.shuffle.ucx.ExternalBaseUcxShuffleManager +import org.apache.spark.{SparkConf, TaskContext} + +/** + * Common part for all spark versions for UcxShuffleManager logic + */ +class ExternalUcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) + extends ExternalBaseUcxShuffleManager(conf, isDriver) { + private[spark] lazy val transport = awaitUcxTransport + private[spark] lazy val shuffleClient = new ExternalUcxShuffleClient(transport) + override def getReader[K, C]( + handle: ShuffleHandle, startPartition: MapId, endPartition: MapId, + context: TaskContext, metrics: ShuffleReadMetricsReporter): + ShuffleReader[K, C] = { + new ExternalUcxShuffleReader(handle.asInstanceOf[BaseShuffleHandle[K,_,C]], + startPartition, endPartition, context, + shuffleClient, metrics, shouldBatchFetch = false) + } +} \ No newline at end of file diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/ExternalUcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/ExternalUcxShuffleReader.scala new file mode 100644 index 00000000..5bf53c52 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/ExternalUcxShuffleReader.scala @@ -0,0 +1,124 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.compat.spark_3_0 + +import java.io.InputStream +import java.util.concurrent.LinkedBlockingQueue + +import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext} +import org.apache.spark.internal.{Logging, config} +import org.apache.spark.io.CompressionCodec +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReadMetricsReporter, ShuffleReader} +import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator} +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.ExternalSorter + +/** + * Extension of Spark's shuffle reader with a logic of injection UcxShuffleClient, + * and lazy progress only when result queue is empty. 
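+ * That is, block fetches are routed through the external UCX shuffle service
+ * client rather than Spark's Netty-based shuffle client.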
+ */ +class ExternalUcxShuffleReader[K, C]( + handle: BaseShuffleHandle[K, _, C], startPartition: Int, + endPartition: Int, context: TaskContext, + shuffleClient: ExternalUcxShuffleClient, + readMetrics: ShuffleReadMetricsReporter, + serializerManager: SerializerManager = SparkEnv.get.serializerManager, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, + shouldBatchFetch: Boolean = false) + extends ShuffleReader[K, C] with Logging { + + private val dep = handle.dependency + + /** Read the combined key-values for this reduce task */ + override def read(): Iterator[Product2[K, C]] = { + val (blocksByAddressIterator1, _) = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( + handle.shuffleId, startPartition, endPartition).duplicate + val shuffleIterator = new ShuffleBlockFetcherIterator( + context, + shuffleClient, + blockManager, + blocksByAddressIterator1, + serializerManager.wrapStream, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, + SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), + readMetrics, + // TODO: Support batch fetch + doBatchFetch = false) + + val wrappedStreams = shuffleIterator.toCompletionIterator + + val serializerInstance = dep.serializer.newInstance() + + // Create a key/value iterator for each stream + val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } + + // Update the context task metrics for each record read. + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map { record => + readMetrics.incRecordsRead(1) + record + }, + context.taskMetrics().mergeShuffleReadMetrics()) + + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) + + val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { + if (dep.mapSideCombine) { + // We are reading values that are already combined + val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) + } else { + // We don't know the value type, but also don't care -- the dependency *should* + // have made sure its compatible w/ this aggregator, which will convert the value + // type to the combined type C + val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) + } + } else { + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] + } + + // Sort the output if there is a sort ordering defined. + val resultIter = dep.keyOrdering match { + case Some(keyOrd: Ordering[K]) => + // Create an ExternalSorter to sort the data. 
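+        // sorter.stop() runs both via the CompletionIterator (once the iterator
+        // is fully consumed) and via the task-completion listener (if the task
+        // finishes or is cancelled early), so spill files are cleaned up on
+        // either path.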
+ val sorter = + new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer) + sorter.insertAll(aggregatedIter) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) + // Use completion callback to stop sorter if task was finished/cancelled. + context.addTaskCompletionListener[Unit](_ => { + sorter.stop() + }) + CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) + case None => + aggregatedIter + } + + resultIter match { + case _: InterruptibleIterator[Product2[K, C]] => resultIter + case _ => + // Use another interruptible iterator here to support task cancellation as aggregator + // or(and) sorter may have consumed previous interruptible iterator. + new InterruptibleIterator[Product2[K, C]](context, resultIter) + } + } + +} \ No newline at end of file diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala index 314b88a3..64b8df32 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala @@ -5,7 +5,9 @@ package org.apache.spark.shuffle.ucx import java.nio.ByteBuffer +import java.util.{HashSet, HashMap} import java.util.concurrent.locks.StampedLock +import org.openucx.jucx.ucp.UcpRequest /** * Class that represents some block in memory with it's address, size. @@ -90,6 +92,8 @@ trait Request { */ trait OperationCallback { def onComplete(result: OperationResult): Unit + def onError(result: OperationResult): Unit = ??? + def onData(buf: ByteBuffer): Unit = ??? } /** @@ -167,3 +171,82 @@ trait ShuffleTransport { def progress(): Unit } + +class UcxRequest(private var request: UcpRequest, stats: OperationStats) + extends Request { + + private[ucx] var completed = false + + override def isCompleted: Boolean = completed || ((request != null) && request.isCompleted) + + override def getStats: Option[OperationStats] = Some(stats) + + override def toString(): String = { + s"UcxRequest(isCompleted=$isCompleted size=${stats.recvSize} cost=${stats.getElapsedTimeNs}ns)" + } + + private[ucx] def setRequest(request: UcpRequest): Unit = { + this.request = request + } +} + +class UcxStats extends OperationStats { + private[ucx] val startTime = System.nanoTime() + private[ucx] var amHandleTime = 0L + private[ucx] var endTime: Long = 0L + private[ucx] var receiveSize: Long = 0L + + /** + * Time it took from operation submit to callback call. + * This depends on [[ ShuffleTransport.progress() ]] calls, + * and does not indicate actual data transfer time. 
+   */
+  override def getElapsedTimeNs: Long = endTime - startTime
+
+  /**
+   * Indicates number of valid bytes in receive memory
+   */
+  override def recvSize: Long = receiveSize
+}
+
+class UcxFetchState(val callbacks: Seq[OperationCallback],
+                    val request: UcxRequest,
+                    val timestamp: Long,
+                    val recvSet: HashSet[Int] = new HashSet[Int]) {
+  override def toString(): String = {
+    s"UcxFetchState(chunks=${callbacks.size}, $request, start=$timestamp, received=${recvSet.size})"
+  }
+}
+
+class UcxStreamState(val callback: OperationCallback,
+                     val request: UcxRequest,
+                     val timestamp: Long,
+                     var remaining: Long,
+                     val recvMap: HashMap[Long, MemoryBlock] = new HashMap[Long, MemoryBlock]) {
+  override def toString(): String = {
+    s"UcxStreamState($request, start=$timestamp, remaining=$remaining, received=${recvMap.size})"
+  }
+}
+
+class UcxSliceState(val callback: OperationCallback,
+                    val request: UcxRequest,
+                    val timestamp: Long,
+                    val mem: MemoryBlock,
+                    var remaining: Long,
+                    var offset: Long = 0L,
+                    val recvSet: HashSet[Long] = new HashSet[Long]) {
+  override def toString(): String = {
+    s"UcxSliceState($request, start=$timestamp, remaining=$remaining, received=${recvSet.size})"
+  }
+}
+
+class UcxSucceedOperationResult(mem: MemoryBlock, stats: OperationStats)
+  extends OperationResult {
+  override def getStatus: OperationStatus.Value = OperationStatus.SUCCESS
+
+  override def getError: TransportError = null
+
+  override def getStats: Option[OperationStats] = Option(stats)
+
+  override def getData: MemoryBlock = mem
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxProgressThread.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxProgressThread.scala
new file mode 100644
index 00000000..68c849ce
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxProgressThread.scala
@@ -0,0 +1,25 @@
+package org.apache.spark.shuffle.ucx
+
+import org.openucx.jucx.ucp.UcpWorker
+
+class UcxProgressThread(worker: UcpWorker, useWakeup: Boolean) extends Thread {
+  setDaemon(true)
+  setName(s"UCX-progress-${super.getName}")
+
+  override def run(): Unit = {
+    if (useWakeup) {
+      while (!isInterrupted) {
+        worker.synchronized {
+          while (worker.progress != 0) {}
+        }
+        worker.waitForEvents()
+      }
+    } else {
+      while (!isInterrupted) {
+        worker.synchronized {
+          while (worker.progress != 0) {}
+        }
+      }
+    }
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala
index ae8bb119..2e46c362 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala
@@ -19,39 +19,6 @@ import java.nio.ByteBuffer
 import scala.collection.concurrent.TrieMap
 import scala.collection.mutable
 
-class UcxRequest(private var request: UcpRequest, stats: OperationStats)
-  extends Request {
-
-  private[ucx] var completed = false
-
-  override def isCompleted: Boolean = completed || ((request != null) && request.isCompleted)
-
-  override def getStats: Option[OperationStats] = Some(stats)
-
-  private[ucx] def setRequest(request: UcpRequest): Unit = {
-    this.request = request
-  }
-}
-
-class UcxStats extends OperationStats {
-  private[ucx] val startTime = System.nanoTime()
-  private[ucx] var amHandleTime = 0L
-  private[ucx] var endTime: Long = 0L
-  private[ucx] var receiveSize: Long = 0L
-
-  /**
-   * Time it took from operation submit to callback call.
- * This depends on [[ ShuffleTransport.progress() ]] calls, - * and does not indicate actual data transfer time. - */ - override def getElapsedTimeNs: Long = endTime - startTime - - /** - * Indicates number of valid bytes in receive memory - */ - override def recvSize: Long = receiveSize -} - case class UcxShuffleBockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { override def serializedSize: Int = 12 diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/ExternalUcxConf.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/ExternalUcxConf.scala new file mode 100644 index 00000000..88d6d7d4 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/ExternalUcxConf.scala @@ -0,0 +1,88 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.ucx + +/** + * Plugin configuration properties. + */ +trait ExternalUcxConf { + lazy val preallocateBuffersMap: Map[Long, Int] = + ExternalUcxConf.preAllocateConfToMap(ExternalUcxConf.PREALLOCATE_BUFFERS_DEFAULT) + lazy val memoryLimit: Boolean = ExternalUcxConf.MEMORY_LIMIT_DEFAULT + lazy val memoryGroupSize: Int = ExternalUcxConf.MEMORY_GROUP_SIZE_DEFAULT + lazy val minBufferSize: Long = ExternalUcxConf.MIN_BUFFER_SIZE_DEFAULT + lazy val maxBufferSize: Long = ExternalUcxConf.MAX_BUFFER_SIZE_DEFAULT + lazy val minRegistrationSize: Long = ExternalUcxConf.MIN_REGISTRATION_SIZE_DEFAULT + lazy val maxRegistrationSize: Long = ExternalUcxConf.MAX_REGISTRATION_SIZE_DEFAULT + lazy val numPools: Int = ExternalUcxConf.NUM_POOLS_DEFAULT + lazy val listenerAddress: String = ExternalUcxConf.SOCKADDR_DEFAULT + lazy val useWakeup: Boolean = ExternalUcxConf.WAKEUP_FEATURE_DEFAULT + lazy val numIoThreads: Int = ExternalUcxConf.NUM_IO_THREADS_DEFAULT + lazy val numThreads: Int = ExternalUcxConf.NUM_THREADS_DEFAULT + lazy val numWorkers: Int = ExternalUcxConf.NUM_WORKERS_DEFAULT + lazy val maxBlocksPerRequest: Int = ExternalUcxConf.MAX_BLOCKS_IN_FLIGHT_DEFAULT + lazy val ucxServerPort: Int = ExternalUcxConf.SPARK_UCX_SHUFFLE_SERVICE_PORT_DEFAULT + lazy val maxReplySize: Long = ExternalUcxConf.MAX_REPLY_SIZE_DEFAULT +} + +object ExternalUcxConf { + private[ucx] def getUcxConf(name: String) = s"spark.shuffle.ucx.$name" + + lazy val PREALLOCATE_BUFFERS_KEY = getUcxConf("memory.preAllocateBuffers") + lazy val PREALLOCATE_BUFFERS_DEFAULT = "" + + lazy val MEMORY_LIMIT_KEY = getUcxConf("memory.limit") + lazy val MEMORY_LIMIT_DEFAULT = false + + lazy val MEMORY_GROUP_SIZE_KEY = getUcxConf("memory.groupSize") + lazy val MEMORY_GROUP_SIZE_DEFAULT = 3 + + lazy val MIN_BUFFER_SIZE_KEY = getUcxConf("memory.minBufferSize") + lazy val MIN_BUFFER_SIZE_DEFAULT = 4096L + + lazy val MAX_BUFFER_SIZE_KEY = getUcxConf("memory.maxBufferSize") + lazy val MAX_BUFFER_SIZE_DEFAULT = Int.MaxValue.toLong + + lazy val MIN_REGISTRATION_SIZE_KEY = getUcxConf("memory.minAllocationSize") + lazy val MIN_REGISTRATION_SIZE_DEFAULT = 1L * 1024 * 1024 + + lazy val MAX_REGISTRATION_SIZE_KEY = getUcxConf("memory.maxAllocationSize") + lazy val MAX_REGISTRATION_SIZE_DEFAULT = 16L * 1024 * 1024 * 1024 + + lazy val NUM_POOLS_KEY = getUcxConf("memory.numPools") + lazy val NUM_POOLS_DEFAULT = 1 + + lazy val SOCKADDR_KEY = getUcxConf("listener.sockaddr") + lazy val SOCKADDR_DEFAULT = "0.0.0.0:0" + + lazy val WAKEUP_FEATURE_KEY = getUcxConf("useWakeup") + lazy val WAKEUP_FEATURE_DEFAULT = true + + lazy val NUM_IO_THREADS_KEY = getUcxConf("numIoThreads") + lazy val 
NUM_IO_THREADS_DEFAULT = 1
+
+  lazy val NUM_THREADS_KEY = getUcxConf("numThreads")
+  lazy val NUM_THREADS_COMPAT_KEY = getUcxConf("numListenerThreads")
+  lazy val NUM_THREADS_DEFAULT = 4
+
+  lazy val NUM_WORKERS_KEY = getUcxConf("numWorkers")
+  lazy val NUM_WORKERS_COMPAT_KEY = getUcxConf("numClientWorkers")
+  lazy val NUM_WORKERS_DEFAULT = 1
+
+  lazy val MAX_BLOCKS_IN_FLIGHT_KEY = getUcxConf("maxBlocksPerRequest")
+  lazy val MAX_BLOCKS_IN_FLIGHT_DEFAULT = 50
+
+  lazy val MAX_REPLY_SIZE_KEY = getUcxConf("maxReplySize")
+  lazy val MAX_REPLY_SIZE_DEFAULT = 32L * 1024 * 1024
+
+  lazy val SPARK_UCX_SHUFFLE_SERVICE_PORT_KEY = getUcxConf("service.port")
+  lazy val SPARK_UCX_SHUFFLE_SERVICE_PORT_DEFAULT = 3338
+
+  // Parses e.g. "4096:1024,65536:128" into Map(4096L -> 1024, 65536L -> 128).
+  def preAllocateConfToMap(conf: String): Map[Long, Int] =
+    conf.split(",").withFilter(s => s.nonEmpty).map(entry =>
+      entry.split(":") match {
+        case Array(bufferSize, bufferCount) => (bufferSize.toLong, bufferCount.toInt)
+        case _ => throw new IllegalArgumentException(s"Invalid preAllocateBuffers entry: $entry")
+      }).toMap
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/ExternalUcxTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/ExternalUcxTransport.scala
new file mode 100644
index 00000000..b38d4e45
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/ExternalUcxTransport.scala
@@ -0,0 +1,94 @@
+/*
+ * Copyright (C) 2022, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
+ * See file LICENSE for terms.
+ */
+package org.apache.spark.shuffle.ucx
+
+import org.apache.spark.shuffle.ucx.memory.UcxLimitedMemPool
+import org.apache.spark.shuffle.utils.UcxLogging
+import org.openucx.jucx.ucp._
+
+import java.nio.ByteBuffer
+import java.util.concurrent.ExecutorService
+
+object ExternalAmId {
+  // client -> server
+  final val ADDRESS = 0
+  final val CONNECT = 1
+  final val FETCH_BLOCK = 2
+  final val FETCH_STREAM = 3
+  // server -> client
+  final val REPLY_ADDRESS = 0
+  final val REPLY_SLICE = 1
+  final val REPLY_BLOCK = 2
+  final val REPLY_STREAM = 3
+}
+
+class ExternalUcxTransport(val ucxShuffleConf: ExternalUcxConf) extends UcxLogging {
+  @volatile protected var initialized: Boolean = false
+  @volatile protected var running: Boolean = true
+  private[ucx] var ucxContext: UcpContext = _
+  private[ucx] var memPools: Array[UcxLimitedMemPool] = _
+  private[ucx] val ucpWorkerParams = new UcpWorkerParams().requestThreadSafety()
+  private[ucx] var taskExecutors: ExecutorService = _
+
+  def hostBounceBufferMemoryPool(i: Int = 0): UcxLimitedMemPool = {
+    memPools(i % memPools.length)
+  }
+
+  def estimateNumEps(): Int = 1
+
+  def initContext(): Unit = {
+    val numEndpoints = estimateNumEps()
+    logInfo(s"Creating UCX context with an estimated number of endpoints: $numEndpoints")
+
+    val params = new UcpParams().requestAmFeature().setMtWorkersShared(true)
+      .setEstimatedNumEps(numEndpoints)
+      .setConfig("USE_MT_MUTEX", "yes")
+
+    if (ucxShuffleConf.useWakeup) {
+      params.requestWakeupFeature()
+    }
+
+    ucxContext = new UcpContext(params)
+  }
+
+  def initMemoryPool(): Unit = {
+    val numPools = ucxShuffleConf.numPools.max(1).min(ucxShuffleConf.numWorkers)
+    memPools = new Array[UcxLimitedMemPool](numPools)
+    for (i <- 0 until numPools) {
+      memPools(i) = new UcxLimitedMemPool(ucxContext)
+      memPools(i).init(ucxShuffleConf.minBufferSize,
+                       ucxShuffleConf.maxBufferSize,
+                       ucxShuffleConf.minRegistrationSize,
+                       ucxShuffleConf.maxRegistrationSize / numPools,
+                       ucxShuffleConf.preallocateBuffersMap,
+                       ucxShuffleConf.memoryLimit,
+                       ucxShuffleConf.memoryGroupSize)
+    }
+  }
+
+  def initTaskPool(threadNum: Int): Unit = {
+    taskExecutors = UcxThreadUtils.newFixedDaemonPool("UCX", threadNum)
+  }
+
+  def init(): ByteBuffer = ???
+
+  @`inline`
+  def submit(task: Runnable): Unit = {
+    taskExecutors.submit(task)
+  }
+
+  def close(): Unit = {
+    if (initialized) {
+      memPools.filter(_ != null).foreach(_.close())
+      if (ucxContext != null) {
+        ucxContext.close()
+      }
+      if (taskExecutors != null) {
+        taskExecutors.shutdown()
+      }
+      initialized = false
+    }
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/UcxWorkerThread.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/UcxWorkerThread.scala
new file mode 100644
index 00000000..6c81e765
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/UcxWorkerThread.scala
@@ -0,0 +1,47 @@
+package org.apache.spark.shuffle.ucx
+
+import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.ConcurrentLinkedQueue
+import org.openucx.jucx.ucp.UcpWorker
+
+class UcxWorkerThread(worker: UcpWorker, useWakeup: Boolean) extends Thread {
+  private val taskQueue = new ConcurrentLinkedQueue[Runnable]
+  private val running = new AtomicBoolean(true)
+
+  setDaemon(true)
+  setName(s"UCX-worker-${super.getName}")
+
+  @`inline`
+  def post(task: Runnable): Unit = {
+    taskQueue.add(task)
+    worker.signal()
+  }
+
+  @`inline`
+  def await() = {
+    if (taskQueue.isEmpty) {
+      worker.waitForEvents()
+    }
+  }
+
+  @`inline`
+  def close(cleanTask: Runnable): Unit = {
+    if (running.compareAndSet(true, false)) {
+      worker.signal()
+    }
+    cleanTask.run()
+    worker.close()
+  }
+
+  override def run(): Unit = {
+    val doAwait = if (useWakeup) await _ else () => {}
+    while (running.get()) {
+      val task = taskQueue.poll()
+      if (task != null) {
+        task.run()
+      }
+      while (worker.progress != 0) {}
+      doAwait()
+    }
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalBaseUcxShuffleManager.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalBaseUcxShuffleManager.scala
new file mode 100644
index 00000000..a4a2e12a
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalBaseUcxShuffleManager.scala
@@ -0,0 +1,128 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.ucx
+
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.util.Success
+
+import org.apache.spark.SparkException
+import org.apache.spark.rpc.RpcEnv
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.shuffle.ucx.rpc.{ExternalUcxExecutorRpcEndpoint, ExternalUcxDriverRpcEndpoint}
+import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{PushServiceAddress, PushAllServiceAddress}
+import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer
+import org.apache.spark.util.{RpcUtils, ThreadUtils}
+import org.apache.spark.{SecurityManager, SparkConf, SparkEnv}
+import org.openucx.jucx.{NativeLibs, UcxException}
+
+/**
+ * Common UcxShuffleManager logic shared by all supported Spark versions.
+ */
+abstract class ExternalBaseUcxShuffleManager(val conf: SparkConf, isDriver: Boolean) extends SortShuffleManager(conf) {
+  type ShuffleId = Int
+  type MapId = Int
+  type ReduceId = Long
+
+  /* Load UCX/JUCX libraries as soon as possible to avoid a collision with the JVM when registering malloc/mmap hooks.
+   */
+  if (!isDriver) {
+    NativeLibs.load()
+  }
+
+  val ucxShuffleConf = new ExternalUcxClientConf(conf)
+
+  @volatile var ucxTransport: ExternalUcxClientTransport = _
+
+  private var executorEndpoint: ExternalUcxExecutorRpcEndpoint = _
+  private var driverEndpoint: ExternalUcxDriverRpcEndpoint = _
+
+  protected val driverRpcName = "SparkUCX_driver"
+
+  private val setupThread = ThreadUtils.newDaemonSingleThreadExecutor("UcxTransportSetupThread")
+
+  private[this] val latch = setupThread.submit(new Runnable {
+    override def run(): Unit = {
+      while (SparkEnv.get == null) {
+        Thread.sleep(10)
+      }
+      if (isDriver) {
+        val rpcEnv = SparkEnv.get.rpcEnv
+        logInfo(s"Setting up driver RPC")
+        driverEndpoint = new ExternalUcxDriverRpcEndpoint(rpcEnv)
+        rpcEnv.setupEndpoint(driverRpcName, driverEndpoint)
+      } else {
+        while (SparkEnv.get.blockManager.shuffleServerId == null) {
+          Thread.sleep(5)
+        }
+        startUcxTransport()
+      }
+    }
+  })
+
+  def awaitUcxTransport(): ExternalUcxClientTransport = {
+    if (ucxTransport == null) {
+      latch.get(10, TimeUnit.SECONDS)
+      if (ucxTransport == null) {
+        throw new UcxException("ExternalUcxClientTransport init timeout")
+      }
+    }
+    ucxTransport
+  }
+
+  /**
+   * Atomically starts the UCX transport singleton - one for all shuffle threads.
+   */
+  def startUcxTransport(): Unit = if (ucxTransport == null) {
+    val blockManager = SparkEnv.get.blockManager.shuffleServerId
+    val transport = new ExternalUcxClientTransport(ucxShuffleConf, blockManager)
+    transport.init()
+    ucxTransport = transport
+    val rpcEnv = SparkEnv.get.rpcEnv
+    executorEndpoint = new ExternalUcxExecutorRpcEndpoint(rpcEnv, ucxTransport, setupThread)
+    val endpoint = rpcEnv.setupEndpoint(
+      s"ucx-shuffle-executor-${blockManager.executorId}",
+      executorEndpoint)
+    // Resolving the driver reference may race with driver-side endpoint setup, so retry.
+    var driverCost = 0
+    var driverRef = RpcUtils.makeDriverRef(driverRpcName, conf, rpcEnv)
+    while (driverRef == null) {
+      Thread.sleep(10)
+      driverCost += 10
+      driverRef = RpcUtils.makeDriverRef(driverRpcName, conf, rpcEnv)
+    }
+    driverRef.ask[PushAllServiceAddress](
+      PushServiceAddress(blockManager.host, transport.localServerPorts, endpoint))
+      .andThen {
+        case Success(msg) =>
+          logInfo(s"Driver took $driverCost ms.")
+          executorEndpoint.receive(msg)
+      }
+  }
+
+  override def unregisterShuffle(shuffleId: Int): Boolean = {
+    // shuffleBlockResolver.asInstanceOf[CommonUcxShuffleBlockResolver].removeShuffle(shuffleId)
+    super.unregisterShuffle(shuffleId)
+  }
+
+  /**
+   * Called on both driver and executors to finally cleanup resources.
+   */
+  override def stop(): Unit = synchronized {
+    super.stop()
+    if (ucxTransport != null) {
+      ucxTransport.close()
+      ucxTransport = null
+    }
+    if (executorEndpoint != null) {
+      executorEndpoint.stop()
+    }
+    if (driverEndpoint != null) {
+      driverEndpoint.stop()
+    }
+    setupThread.shutdown()
+  }
+
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalUcxClientConf.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalUcxClientConf.scala
new file mode 100644
index 00000000..bffcac4a
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalUcxClientConf.scala
@@ -0,0 +1,137 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.ucx
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.config.ConfigBuilder
+import org.apache.spark.network.util.ByteUnit
+
+/**
+ * Plugin configuration properties.
+ */
+class ExternalUcxClientConf(val sparkConf: SparkConf) extends SparkConf with ExternalUcxConf {
+
+  def getSparkConf: SparkConf = sparkConf
+
+  // Memory Pool
+  private lazy val PREALLOCATE_BUFFERS =
+    ConfigBuilder(ExternalUcxConf.PREALLOCATE_BUFFERS_KEY)
+      .doc("Comma separated list of buffer size : buffer count pairs to preallocate in memory pool. E.g. 4k:1000,16k:500")
+      .stringConf.createWithDefault("")
+
+  override lazy val preallocateBuffersMap: Map[Long, Int] =
+    ExternalUcxConf.preAllocateConfToMap(
+      sparkConf.get(ExternalUcxConf.PREALLOCATE_BUFFERS_KEY,
+                    ExternalUcxConf.PREALLOCATE_BUFFERS_DEFAULT))
+
+  private lazy val MEMORY_LIMIT = ConfigBuilder(ExternalUcxConf.MEMORY_LIMIT_KEY)
+    .doc("Enable memory pool size limit.")
+    .booleanConf
+    .createWithDefault(ExternalUcxConf.MEMORY_LIMIT_DEFAULT)
+
+  override lazy val memoryLimit: Boolean = sparkConf.getBoolean(MEMORY_LIMIT.key,
+    MEMORY_LIMIT.defaultValue.get)
+
+  private lazy val MEMORY_GROUP_SIZE = ConfigBuilder(ExternalUcxConf.MEMORY_GROUP_SIZE_KEY)
+    .doc("Memory group size.")
+    .intConf
+    .createWithDefault(ExternalUcxConf.MEMORY_GROUP_SIZE_DEFAULT)
+
+  override lazy val memoryGroupSize: Int = sparkConf.getInt(MEMORY_GROUP_SIZE.key,
+    MEMORY_GROUP_SIZE.defaultValue.get)
+
+  private lazy val MIN_BUFFER_SIZE = ConfigBuilder(ExternalUcxConf.MIN_BUFFER_SIZE_KEY)
+    .doc("Minimal buffer size in memory pool.")
+    .bytesConf(ByteUnit.BYTE)
+    .createWithDefault(ExternalUcxConf.MIN_BUFFER_SIZE_DEFAULT)
+
+  override lazy val minBufferSize: Long = sparkConf.getSizeAsBytes(MIN_BUFFER_SIZE.key,
+    MIN_BUFFER_SIZE.defaultValue.get)
+
+  private lazy val MAX_BUFFER_SIZE = ConfigBuilder(ExternalUcxConf.MAX_BUFFER_SIZE_KEY)
+    .doc("Maximal buffer size in memory pool.")
+    .bytesConf(ByteUnit.BYTE)
+    .createWithDefault(ExternalUcxConf.MAX_BUFFER_SIZE_DEFAULT)
+
+  override lazy val maxBufferSize: Long = sparkConf.getSizeAsBytes(MAX_BUFFER_SIZE.key,
+    MAX_BUFFER_SIZE.defaultValue.get)
+
+  private lazy val MIN_REGISTRATION_SIZE =
+    ConfigBuilder(ExternalUcxConf.MIN_REGISTRATION_SIZE_KEY)
+      .doc("Minimal memory registration size in memory pool.")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefault(ExternalUcxConf.MIN_REGISTRATION_SIZE_DEFAULT)
+
+  override lazy val minRegistrationSize: Long = sparkConf.getSizeAsBytes(MIN_REGISTRATION_SIZE.key,
+    MIN_REGISTRATION_SIZE.defaultValue.get)
+
+  private lazy val MAX_REGISTRATION_SIZE =
+    ConfigBuilder(ExternalUcxConf.MAX_REGISTRATION_SIZE_KEY)
+      .doc("Maximal memory registration size in memory pool.")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefault(ExternalUcxConf.MAX_REGISTRATION_SIZE_DEFAULT)
+
+  override lazy val maxRegistrationSize: Long = sparkConf.getSizeAsBytes(MAX_REGISTRATION_SIZE.key,
+    MAX_REGISTRATION_SIZE.defaultValue.get)
+
+  private lazy val NUM_POOLS =
+    ConfigBuilder(ExternalUcxConf.NUM_POOLS_KEY)
+      .doc("Number of memory pools.")
+      .intConf
+      .createWithDefault(ExternalUcxConf.NUM_POOLS_DEFAULT)
+
+  override lazy val numPools: Int = sparkConf.getInt(NUM_POOLS.key,
+    NUM_POOLS.defaultValue.get)
+
+  private lazy val SOCKADDR =
+    ConfigBuilder(ExternalUcxConf.SOCKADDR_KEY)
+      .doc("Socket address (host:port) the UCX listener binds to.")
+      .stringConf
+      .createWithDefault(ExternalUcxConf.SOCKADDR_DEFAULT)
+
+  override lazy val listenerAddress: String = sparkConf.get(SOCKADDR.key, SOCKADDR.defaultValueString)
+
+  private lazy val WAKEUP_FEATURE =
+    ConfigBuilder(ExternalUcxConf.WAKEUP_FEATURE_KEY)
+      .doc("Whether to use the wakeup feature (event-driven progress) instead of busy polling")
+      .booleanConf
+      .createWithDefault(ExternalUcxConf.WAKEUP_FEATURE_DEFAULT)
+
+  override lazy val useWakeup: Boolean = sparkConf.getBoolean(WAKEUP_FEATURE.key, WAKEUP_FEATURE.defaultValue.get)
+
+  private lazy val NUM_WORKERS = ConfigBuilder(ExternalUcxConf.NUM_WORKERS_KEY)
+    .doc("Number of client workers")
+    .intConf
+    .createWithDefault(ExternalUcxConf.NUM_WORKERS_DEFAULT)
+
+  override lazy val numWorkers: Int = sparkConf.getInt(NUM_WORKERS.key, sparkConf.getInt("spark.executor.cores",
+    NUM_WORKERS.defaultValue.get))
+
+  private lazy val NUM_THREADS = ConfigBuilder(ExternalUcxConf.NUM_THREADS_KEY)
+    .doc("Number of threads in thread pool")
+    .intConf
+    .createWithDefault(ExternalUcxConf.NUM_THREADS_DEFAULT)
+
+  override lazy val numThreads: Int = sparkConf.getInt(NUM_THREADS.key,
+    sparkConf.getInt(ExternalUcxConf.NUM_THREADS_COMPAT_KEY, NUM_THREADS.defaultValue.get))
+
+  private lazy val MAX_BLOCKS_IN_FLIGHT = ConfigBuilder(ExternalUcxConf.MAX_BLOCKS_IN_FLIGHT_KEY)
+    .doc("Maximum number of blocks per fetch request")
+    .intConf
+    .createWithDefault(ExternalUcxConf.MAX_BLOCKS_IN_FLIGHT_DEFAULT)
+
+  override lazy val maxBlocksPerRequest: Int = sparkConf.getInt(MAX_BLOCKS_IN_FLIGHT.key, MAX_BLOCKS_IN_FLIGHT.defaultValue.get)
+
+  private lazy val MAX_REPLY_SIZE = ConfigBuilder(ExternalUcxConf.MAX_REPLY_SIZE_KEY)
+    .doc("Maximum size of a single reply message")
+    .bytesConf(ByteUnit.BYTE)
+    .createWithDefault(ExternalUcxConf.MAX_REPLY_SIZE_DEFAULT)
+
+  override lazy val maxReplySize: Long = sparkConf.getSizeAsBytes(MAX_REPLY_SIZE.key, MAX_REPLY_SIZE.defaultValue.get)
+
+  override lazy val ucxServerPort: Int = sparkConf.getInt(
+    ExternalUcxConf.SPARK_UCX_SHUFFLE_SERVICE_PORT_KEY,
+    ExternalUcxConf.SPARK_UCX_SHUFFLE_SERVICE_PORT_DEFAULT)
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalUcxClientTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalUcxClientTransport.scala
new file mode 100644
index 00000000..564bbcd1
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalUcxClientTransport.scala
@@ -0,0 +1,228 @@
+package org.apache.spark.shuffle.ucx
+
+// import org.apache.spark.SparkEnv
+import org.apache.spark.network.util.TransportConf
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager}
+import org.apache.spark.shuffle.utils.{UcxLogging, UnsafeUtils}
+import org.apache.spark.shuffle.ucx.memory.UcxLimitedMemPool
+import org.apache.spark.shuffle.ucx.utils.{SerializableDirectBuffer, SerializationUtils}
+import org.apache.spark.storage.BlockManagerId
+import org.openucx.jucx.ucp._
+
+import java.net.InetSocketAddress
+import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicInteger
+import java.util.concurrent.{ConcurrentHashMap, CountDownLatch, TimeUnit}
+
+class ExternalUcxClientTransport(clientConf: ExternalUcxClientConf, blockManagerId: BlockManagerId)
+extends ExternalUcxTransport(clientConf) with UcxLogging {
+  private[spark] val executorId = blockManagerId.executorId.toLong
+  private[spark] val tcpServerPort = blockManagerId.port
+  private[spark] val ucxServerPort = clientConf.ucxServerPort
+  private[spark] val numWorkers = clientConf.numWorkers
+  private[spark] val timeoutMs = clientConf.getSparkConf.getTimeAsSeconds(
+    "spark.network.timeout", "120s") * 1000
+  private[spark] val sparkTransportConf = SparkTransportConf.fromSparkConf(
+    clientConf.getSparkConf, "ucx-shuffle", numWorkers)
+
+  private[ucx] val ucxServers = new ConcurrentHashMap[String, InetSocketAddress]
+  private[ucx] val localServerPortsDone = new CountDownLatch(1)
+  private[ucx] var localServerPortsBuffer: ByteBuffer = _
+  private[ucx] var localServerPorts: Seq[Int] = _
+
+  private[ucx] lazy val currentWorkerId = new AtomicInteger()
+  private[ucx] lazy val workerLocal = new ThreadLocal[ExternalUcxClientWorker]
+  private[ucx] var allocatedWorker: Array[ExternalUcxClientWorker] = _
+
+  private[ucx] val scheduledLatch = new CountDownLatch(1)
+
+  private var maxBlocksPerRequest = 0
+
+  override def estimateNumEps(): Int = numWorkers *
+    clientConf.sparkConf.getInt("spark.executor.instances", 1)
+
+  override def init(): ByteBuffer = {
+    initContext()
+    initMemoryPool()
+
+    if (clientConf.useWakeup) {
+      ucpWorkerParams.requestWakeupRX().requestWakeupTX().requestWakeupEdge()
+    }
+
+    logInfo(s"Allocating ${numWorkers} client workers")
+    val appId = clientConf.sparkConf.getAppId
+    allocatedWorker = new Array[ExternalUcxClientWorker](numWorkers)
+    for (i <- 0 until numWorkers) {
+      ucpWorkerParams.setClientId((executorId << 32) | i.toLong)
+      val worker = ucxContext.newWorker(ucpWorkerParams)
+      val workerId = new UcxWorkerId(appId, executorId.toInt, i)
+      allocatedWorker(i) = new ExternalUcxClientWorker(worker, this, workerId)
+    }
+
+    logInfo(s"Launching ${numWorkers} client workers")
+    allocatedWorker.foreach(_.start)
+
+    val maxAmHeaderSize = allocatedWorker(0).worker.getMaxAmHeaderSize.toInt
+    maxBlocksPerRequest = clientConf.maxBlocksPerRequest.min(
+      (maxAmHeaderSize - UnsafeUtils.INT_SIZE) / UnsafeUtils.INT_SIZE)
+
+    logInfo(s"Launching time-scheduled threads, period: $timeoutMs ms")
+    initTaskPool(1)
+    submit(() => progressTimeOut())
+
+    logInfo(s"Connecting to server.")
+
+    val shuffleServer = new InetSocketAddress(blockManagerId.host, ucxServerPort)
+    allocatedWorker(0).requestAddress(shuffleServer)
+    localServerPortsDone.await()
+
+    logInfo(s"Connected to server $shuffleServer")
+
+    initialized = true
+    localServerPortsBuffer
+  }
+
+  override def close(): Unit = {
+    if (initialized) {
+      running = false
+
+      scheduledLatch.countDown()
+
+      if (allocatedWorker != null) {
+        allocatedWorker.map(_.closing).foreach(_.get(5, TimeUnit.MILLISECONDS))
+      }
+
+      super.close()
+    }
+  }
+
+  def getMaxBlocksPerRequest: Int = maxBlocksPerRequest
+
+  // @inline
+  // def selectWorker(): ExternalUcxClientWorker = {
+  //   allocatedWorker(
+  //     (currentWorkerId.incrementAndGet() % allocatedWorker.length).abs)
+  // }
+
+  @inline
+  def selectWorker(): ExternalUcxClientWorker = {
+    Option(workerLocal.get) match {
+      case Some(worker) => worker
+      case None => {
+        val worker = allocatedWorker(
+          (currentWorkerId.incrementAndGet() % allocatedWorker.length).abs)
+        workerLocal.set(worker)
+        worker
+      }
+    }
+  }
+
+  @`inline`
+  def getServer(host: String): InetSocketAddress = {
+    ucxServers.computeIfAbsent(host, _ => {
+      logInfo(s"Connecting to $host via controller port")
+      new InetSocketAddress(host, ucxServerPort)
+    })
+  }
+
+  def connect(host: String, ports: Seq[Int]): Unit = {
+    val server = ucxServers.computeIfAbsent(host, _ => {
+      val id = executorId.toInt.abs % ports.length
+      new InetSocketAddress(host, ports(id))
+    })
+    allocatedWorker.foreach(_.connect(server))
+    logDebug(s"connect $host $server")
+  }
+
+  def connectAll(shuffleServerMap: Map[String, Seq[Int]]): Unit = {
+    val addressSet = shuffleServerMap.map(hostPorts => {
+      ucxServers.computeIfAbsent(hostPorts._1, _ => {
+        val id = executorId.toInt.abs % hostPorts._2.length
+        new InetSocketAddress(hostPorts._1, hostPorts._2(id))
+      })
+    }).toSeq
+    allocatedWorker.foreach(_.connectAll(addressSet))
+    logDebug(s"connectAll $addressSet")
+  }
+
+  def handleReplyAddress(msg: ByteBuffer): Unit = {
+    val num = msg.remaining() / UnsafeUtils.INT_SIZE
+    localServerPorts = (0 until num).map(_ => msg.getInt())
+    localServerPortsBuffer = msg
+
+    localServerPortsDone.countDown()
+
+    connect(blockManagerId.host, localServerPorts)
+  }
+
+  /**
+   * Fetches a batch of blocks from a remote host in a single request.
+   */
+  def fetchBlocksByBlockIds(host: String, exeId: Int,
+                            blockIds: Seq[BlockId],
+                            callbacks: Seq[OperationCallback]): Unit = {
+    selectWorker.fetchBlocksByBlockIds(host, exeId, blockIds, callbacks)
+  }
+
+  def fetchBlockByStream(host: String, exeId: Int, blockId: BlockId,
+                         callback: OperationCallback): Unit = {
+    selectWorker.fetchBlockByStream(host, exeId, blockId, callback)
+  }
+
+  def progressTimeOut(): Unit = {
+    while (!scheduledLatch.await(timeoutMs, TimeUnit.MILLISECONDS)) {
+      allocatedWorker.foreach(_.progressTimeOut)
+    }
+  }
+}
+
+private[shuffle] class UcxFetchCallBack(
+    blockId: String, listener: BlockFetchingListener)
+  extends OperationCallback {
+
+  override def onComplete(result: OperationResult): Unit = {
+    val memBlock = result.getData
+    val buffer = UnsafeUtils.getByteBufferView(memBlock.address,
+      memBlock.size.toInt)
+    listener.onBlockFetchSuccess(blockId, new NioManagedBuffer(buffer) {
+      override def release: ManagedBuffer = {
+        memBlock.close()
+        this
+      }
+    })
+  }
+
+  override def onError(result: OperationResult): Unit = {
+    listener.onBlockFetchFailure(blockId, result.getError)
+  }
+}
+
+private[shuffle] class UcxDownloadCallBack(
+    blockId: String, listener: BlockFetchingListener,
+    downloadFileManager: DownloadFileManager,
+    transportConf: TransportConf)
+  extends OperationCallback {
+
+  private[this] val targetFile = downloadFileManager.createTempFile(
+    transportConf)
+  private[this] val channel = targetFile.openForWriting()
+
+  override def onData(buffer: ByteBuffer): Unit = {
+    while (buffer.hasRemaining()) {
+      channel.write(buffer)
+    }
+  }
+
+  override def onComplete(result: OperationResult): Unit = {
+    listener.onBlockFetchSuccess(blockId, channel.closeAndRead())
+    if (!downloadFileManager.registerTempFileToClean(targetFile)) {
+      targetFile.delete()
+    }
+  }
+
+  override def onError(result: OperationResult): Unit = {
+    listener.onBlockFetchFailure(blockId, result.getError)
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalUcxClientWorker.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalUcxClientWorker.scala
new file mode 100644
index 00000000..37f80fd7
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/client/ExternalUcxClientWorker.scala
@@ -0,0 +1,543 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/ +package org.apache.spark.shuffle.ucx + +import java.io.Closeable +import java.nio.channels.ReadableByteChannel +import java.util.concurrent.{ConcurrentLinkedQueue, ConcurrentHashMap, Future, FutureTask} +import java.util.concurrent.atomic.AtomicInteger +import scala.collection.concurrent.TrieMap +import scala.collection.JavaConverters._ +import scala.util.Random +import org.openucx.jucx.ucp._ +import org.openucx.jucx.ucs.UcsConstants +import org.openucx.jucx.ucs.UcsConstants.MEMORY_TYPE +import org.openucx.jucx.{UcxCallback, UcxException, UcxUtils} +import org.apache.spark.shuffle.ucx.memory.UcxSharedMemoryBlock +import org.apache.spark.shuffle.ucx.utils.SerializationUtils +import org.apache.spark.shuffle.utils.{UnsafeUtils, UcxLogging} +import org.apache.spark.unsafe.Platform + +import java.nio.ByteBuffer +import java.net.InetSocketAddress + +/** + * Worker per thread wrapper, that maintains connection and progress logic. + */ +case class ExternalUcxClientWorker(val worker: UcpWorker, + transport: ExternalUcxClientTransport, + workerId: UcxWorkerId) + extends Closeable with UcxLogging { + private[this] val tag = new AtomicInteger(Random.nextInt()) + private[this] val memPool = transport.hostBounceBufferMemoryPool(workerId.workerId) + private[this] val connectQueue = new ConcurrentLinkedQueue[InetSocketAddress] + private[this] val connectingServers = new ConcurrentHashMap[InetSocketAddress, (UcpEndpoint, UcpRequest)] + private[this] val shuffleServers = new ConcurrentHashMap[String, UcpEndpoint] + private[this] val executor = new UcxWorkerThread( + worker, transport.ucxShuffleConf.useWakeup) + private[this] lazy val requestData = new TrieMap[Int, UcxFetchState] + private[this] lazy val streamData = new TrieMap[Int, UcxStreamState] + private[this] lazy val sliceData = new TrieMap[Int, UcxSliceState] + private[this] var prevTag = 0 + + // Receive block data handler + worker.setAmRecvHandler(ExternalAmId.REPLY_SLICE, + (headerAddress: Long, headerSize: Long, ucpAmData: UcpAmData, + ep: UcpEndpoint) => { + val headerBuffer = UnsafeUtils.getByteBufferView(headerAddress, + headerSize.toInt) + val i = headerBuffer.getInt + val remaining = headerBuffer.getInt + val length = headerBuffer.getLong + val offset = headerBuffer.getLong + + handleReplySlice(i, offset, length, ucpAmData) + UcsConstants.STATUS.UCS_OK + }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG) + + worker.setAmRecvHandler(ExternalAmId.REPLY_STREAM, + (headerAddress: Long, headerSize: Long, ucpAmData: UcpAmData, + _: UcpEndpoint) => { + val headerBuffer = UnsafeUtils.getByteBufferView(headerAddress, + headerSize.toInt) + val i = headerBuffer.getInt + val remaining = headerBuffer.getInt + val length = headerBuffer.getLong + val offset = headerBuffer.getLong + + handleReplyStream(i, offset, length, ucpAmData) + UcsConstants.STATUS.UCS_OK + }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG) + + worker.setAmRecvHandler(ExternalAmId.REPLY_BLOCK, + (headerAddress: Long, headerSize: Long, ucpAmData: UcpAmData, _: UcpEndpoint) => { + val headerBuffer = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt) + val i = headerBuffer.getInt + // Header contains tag followed by sizes of blocks + val numBlocks = headerBuffer.remaining() / UnsafeUtils.INT_SIZE + val blockSizes = (0 until numBlocks).map(_ => headerBuffer.getInt()) + + handleReplyBlock(i, blockSizes, ucpAmData) + if (ucpAmData.isDataValid) { + UcsConstants.STATUS.UCS_INPROGRESS + } else { + UcsConstants.STATUS.UCS_OK + } + }, UcpConstants.UCP_AM_FLAG_PERSISTENT_DATA | 
UcpConstants.UCP_AM_FLAG_WHOLE_MSG) + + worker.setAmRecvHandler(ExternalAmId.REPLY_ADDRESS, + (headerAddress: Long, headerSize: Long, _: UcpAmData, _: UcpEndpoint) => { + val headerBuffer = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt) + val copiedBuffer = ByteBuffer.allocateDirect(headerBuffer.remaining()) + + copiedBuffer.put(headerBuffer) + copiedBuffer.rewind() + + transport.handleReplyAddress(copiedBuffer) + UcsConstants.STATUS.UCS_OK + }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG) + + private def handleReplySlice(i: Int, offset: Long, length: Long, + ucpAmData: UcpAmData): Unit = { + assert(!ucpAmData.isDataValid) + + val sliceState = sliceData.getOrElseUpdate(i, { + requestData.get(i) match { + case Some(data) => { + val mem = memPool.get(length) + val memRef = new UcxRefCountMemoryBlock(mem, 0, length, + new AtomicInteger(1)) + new UcxSliceState(data.callbacks(0), data.request, data.timestamp, + memRef, length) + } + case None => throw new UcxException(s"Slice tag $i context not found.") + } + }) + + val stats = sliceState.request.getStats.get.asInstanceOf[UcxStats] + stats.receiveSize += ucpAmData.getLength + stats.amHandleTime = System.nanoTime() + + val currentAddress = sliceState.mem.address + offset + worker.recvAmDataNonBlocking( + ucpAmData.getDataHandle, currentAddress, ucpAmData.getLength, + new UcxCallback() { + override def onSuccess(r: UcpRequest): Unit = { + stats.endTime = System.nanoTime() + logDebug(s"Slice receive rndv data size ${ucpAmData.getLength} " + + s"tag $i ($offset, $length) in ${stats.getElapsedTimeNs} ns " + + s"amHandle ${stats.endTime - stats.amHandleTime} ns") + receivedSlice(i, offset, length, ucpAmData.getLength, sliceState) + } + }, UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) + } + + private def handleReplyStream(i: Int, offset: Long, length: Long, + ucpAmData: UcpAmData): Unit = { + assert(!ucpAmData.isDataValid) + + val mem = memPool.get(ucpAmData.getLength) + val memRef = new UcxRefCountMemoryBlock(mem, 0, ucpAmData.getLength, + new AtomicInteger(1)) + + val data = streamData.get(i) + if (data.isEmpty) { + throw new UcxException(s"Stream tag $i context not found.") + } + + val streamState = data.get + if (streamState.remaining == Long.MaxValue) { + streamState.remaining = length + } + + val stats = streamState.request.getStats.get.asInstanceOf[UcxStats] + stats.receiveSize += ucpAmData.getLength + stats.amHandleTime = System.nanoTime() + worker.recvAmDataNonBlocking( + ucpAmData.getDataHandle, memRef.address, ucpAmData.getLength, + new UcxCallback() { + override def onSuccess(r: UcpRequest): Unit = { + stats.endTime = System.nanoTime() + logDebug(s"Stream receive rndv data ${ucpAmData.getLength} " + + s"tag $i ($offset, $length) in ${stats.getElapsedTimeNs} ns " + + s"amHandle ${stats.endTime - stats.amHandleTime} ns") + receivedStream(i, offset, length, memRef, streamState) + } + }, UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) + } + + private def handleReplyBlock( + i: Int, blockSizes: Seq[Int], ucpAmData: UcpAmData): Unit = { + val data = requestData.get(i) + + if (data.isEmpty) { + throw new UcxException(s"No data for tag $i.") + } + + val fetchState = data.get + val callbacks = fetchState.callbacks + val request = fetchState.request + val stats = request.getStats.get.asInstanceOf[UcxStats] + stats.receiveSize = ucpAmData.getLength + + val numBlocks = blockSizes.length + + var offset = 0 + val refCounts = new AtomicInteger(numBlocks) + if (ucpAmData.isDataValid) { + request.completed = true + stats.endTime = System.nanoTime() 
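+      // Eager path: the payload arrived inline with the active message, so each
+      // block handed out below is a zero-copy view into the AM buffer. All views
+      // share one reference counter, and closeCb releases the AM data handle on
+      // the worker thread once the last view is closed.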
+ logDebug(s"Received amData: $ucpAmData for tag $i " + + s"in ${stats.getElapsedTimeNs} ns") + + val closeCb = () => executor.post(() => ucpAmData.close()) + val address = ucpAmData.getDataAddress + for (b <- 0 until numBlocks) { + val blockSize = blockSizes(b) + if (callbacks(b) != null) { + val mem = new UcxSharedMemoryBlock(closeCb, refCounts, address + offset, + blockSize) + receivedChunk(i, b, mem, fetchState) + offset += blockSize + } + } + } else { + val mem = memPool.get(ucpAmData.getLength) + stats.amHandleTime = System.nanoTime() + request.setRequest(worker.recvAmDataNonBlocking(ucpAmData.getDataHandle, mem.address, ucpAmData.getLength, + new UcxCallback() { + override def onSuccess(r: UcpRequest): Unit = { + request.completed = true + stats.endTime = System.nanoTime() + logDebug(s"Received rndv data of size: ${ucpAmData.getLength} " + + s"for tag $i in ${stats.getElapsedTimeNs} ns " + + s"time from amHandle: ${System.nanoTime() - stats.amHandleTime} ns") + for (b <- 0 until numBlocks) { + val blockSize = blockSizes(b) + val memRef = new UcxRefCountMemoryBlock(mem, offset, blockSize, refCounts) + receivedChunk(i, b, memRef, fetchState) + offset += blockSize + } + + } + }, UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)) + } + } + + private def receivedSlice(tag: Int, offset: Long, length: Long, msgLen: Long, + sliceState: UcxSliceState): Unit = { + if (!sliceState.recvSet.add(offset)) { + logWarning(s"Received duplicate slice: tag $tag offset $offset") + return + } + + logTrace(s"tag $tag $sliceState") + sliceState.remaining -= msgLen + if (sliceState.remaining != 0) { + return + } + + val stats = sliceState.request.getStats.get + val result = new UcxSucceedOperationResult(sliceState.mem, stats) + sliceState.callback.onComplete(result) + sliceData.remove(tag) + + val data = requestData.get(tag) + if (data.isEmpty) { + logWarning(s"No fetch found: tag $tag") + return + } + // block must be the last chunk + val fetchState = data.get + val chunkId = fetchState.callbacks.size - 1 + receivedChunk(tag, chunkId, null, fetchState) + } + + private def receivedStream(tag: Int, offset: Long, length: Long, mem: MemoryBlock, + streamState: UcxStreamState): Unit = { + if (streamState.recvMap.put(offset, mem) != null) { + logWarning(s"Received duplicate stream: tag $tag offset $offset") + return + } + + logTrace(s"tag $tag $streamState") + var received = length - streamState.remaining + var memNow = streamState.recvMap.get(received) + while (memNow != null) { + val buf = UnsafeUtils.getByteBufferView(memNow.address, memNow.size.toInt) + streamState.callback.onData(buf) + received += memNow.size + memNow.close() + memNow = streamState.recvMap.get(received) + } + + streamState.remaining = length - received + if (streamState.remaining != 0) { + return + } + + val stats = streamState.request.getStats.get + val result = new UcxSucceedOperationResult(null, stats) + streamState.callback.onComplete(result) + streamData.remove(tag) + } + + private def receivedChunk(tag: Int, chunkId: Int, mem: MemoryBlock, + fetchState: UcxFetchState): Unit = { + if (!fetchState.recvSet.add(chunkId)) { + logWarning(s"Received duplicate chunk: tag $tag chunk $chunkId") + return + } + + if (mem != null) { + val stats = fetchState.request.getStats.get + val result = new UcxSucceedOperationResult(mem, stats) + fetchState.callbacks(chunkId).onComplete(result) + } + + logTrace(s"tag $tag $fetchState") + if (fetchState.recvSet.size != fetchState.callbacks.size) { + return + } + + requestData.remove(tag) + } + + def start(): 
Unit = { + executor.start() + } + + override def close(): Unit = { + val closeConnecting = connectingServers.values.asScala.filterNot { + case (_, req) => req.isCompleted + }.map { + case (endpoint, _) => endpoint.closeNonBlockingForce() + } + val closeRequests = shuffleServers.asScala.map { + case (_, endpoint) => endpoint.closeNonBlockingForce() + } + while (!closeConnecting.forall(_.isCompleted)) { + progress() + } + while (!closeRequests.forall(_.isCompleted)) { + progress() + } + } + + def closing(): Future[Unit.type] = { + val cleanTask = new FutureTask(new Runnable { + override def run() = close() + }, Unit) + executor.close(cleanTask) + cleanTask + } + + /** + * The only place for worker progress + */ + def progress(): Int = worker.synchronized { + worker.progress() + } + + @`inline` + def requestAddress(localServer: InetSocketAddress): Unit = { + executor.post(() => shuffleServers.computeIfAbsent("0.0.0.0", _ => { + doConnect(localServer, ExternalAmId.ADDRESS)._1 + })) + } + + @`inline` + def connect(shuffleServer: InetSocketAddress): Unit = { + connectQueue.add(shuffleServer) + } + + @`inline` + def connectAll(addressSet: Seq[InetSocketAddress]): Unit = { + addressSet.foreach(connectQueue.add(_)) + } + + @`inline` + private def connectNext(): Unit = { + if (!connectQueue.isEmpty) { + executor.post(() => { + val shuffleServer = connectQueue.poll() + try { + startConnection(shuffleServer) + } catch { + // We cannot throw error here as current server might be down/rebooting. + case e: Throwable => + logWarning(s"Endpoint to $shuffleServer got an error: $e") + } + }) + } + } + + private def doConnect(shuffleServer: InetSocketAddress, + amId: Int): (UcpEndpoint, UcpRequest) = { + val endpointParams = new UcpEndpointParams().setPeerErrorHandlingMode() + .setSocketAddress(shuffleServer).sendClientId() + .setErrorHandler(new UcpEndpointErrorHandler() { + override def onError(ep: UcpEndpoint, status: Int, errorMsg: String): Unit = { + logError(s"Endpoint to $shuffleServer got an error: $errorMsg") + shuffleServers.remove(shuffleServer.getHostName()) + } + }).setName(s"Client to $shuffleServer") + + logDebug(s"$workerId connecting to external service $shuffleServer") + + val header = Platform.allocateDirectBuffer(workerId.serializedSize) + workerId.serialize(header) + header.rewind() + val workerAddress = worker.getAddress + + val ep = worker.newEndpoint(endpointParams) + val req = ep.sendAmNonBlocking( + amId, UcxUtils.getAddress(header), header.remaining(), + UcxUtils.getAddress(workerAddress), workerAddress.remaining(), + UcpConstants.UCP_AM_SEND_FLAG_EAGER | UcpConstants.UCP_AM_SEND_FLAG_REPLY, + new UcxCallback() { + override def onSuccess(request: UcpRequest): Unit = { + connectNext() + } + override def onError(ucsStatus: Int, errorMsg: String): Unit = { + logError(s"$workerId Sent connect to $shuffleServer failed: $errorMsg"); + connectNext() + header.clear() + workerAddress.clear() + } + }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) + ep -> req + } + + private def startConnection(shuffleServer: InetSocketAddress): (UcpEndpoint, UcpRequest) = { + connectingServers.computeIfAbsent(shuffleServer, _ => + doConnect(shuffleServer, ExternalAmId.CONNECT)) + } + + private def getConnection(host: String): UcpEndpoint = { + val shuffleServer = transport.getServer(host) + shuffleServers.computeIfAbsent(shuffleServer.getAddress().getHostAddress(), _ => { + val (ep, req) = startConnection(shuffleServer) + if (!req.isCompleted) { + val deadline = System.currentTimeMillis() + transport.timeoutMs + do { + 
worker.progress() + if (System.currentTimeMillis() > deadline) { + throw new UcxException(s"connect $shuffleServer timeout") + } + } while (!req.isCompleted) + } + ep + }) + } + + def fetchBlocksByBlockIds(host: String, execId: Int, blockIds: Seq[BlockId], + callbacks: Seq[OperationCallback]): Unit = { + val startTime = System.nanoTime() + val headerSize = UnsafeUtils.INT_SIZE + UnsafeUtils.INT_SIZE + workerId.serializedSize + + val t = tag.incrementAndGet() + + val buffer = Platform.allocateDirectBuffer(headerSize + blockIds.map(_.serializedSize).sum) + workerId.serialize(buffer) + buffer.putInt(t) + buffer.putInt(execId) + blockIds.foreach(b => b.serialize(buffer)) + + val request = new UcxRequest(null, new UcxStats()) + requestData.put(t, new UcxFetchState(callbacks, request, startTime)) + + buffer.rewind() + val address = UnsafeUtils.getAdress(buffer) + val dataAddress = address + headerSize + + executor.post(() => { + val ep = getConnection(host) + ep.sendAmNonBlocking(ExternalAmId.FETCH_BLOCK, address, + headerSize, dataAddress, buffer.capacity() - headerSize, + UcpConstants.UCP_AM_SEND_FLAG_EAGER, new UcxCallback() { + override def onSuccess(request: UcpRequest): Unit = { + buffer.clear() + logDebug(s"Sent fetch to $host tag $t blocks ${blockIds.length} " + + s"in ${System.nanoTime() - startTime} ns") + } + override def onError(ucsStatus: Int, errorMsg: String): Unit = { + val err = s"Sent fetch to $host tag $t failed: $errorMsg"; + logError(err) + throw new UcxException(err) + } + }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) + }) + } + + def fetchBlockByStream(host: String, execId: Int, blockId: BlockId, + callback: OperationCallback): Unit = { + val startTime = System.nanoTime() + val headerSize = workerId.serializedSize + UnsafeUtils.INT_SIZE + + UnsafeUtils.INT_SIZE + blockId.serializedSize + + val t = tag.incrementAndGet() + + val buffer = Platform.allocateDirectBuffer(headerSize) + workerId.serialize(buffer) + buffer.putInt(t) + buffer.putInt(execId) + blockId.serialize(buffer) + + val request = new UcxRequest(null, new UcxStats()) + streamData.put(t, new UcxStreamState(callback, request, startTime, + Long.MaxValue)) + + val address = UnsafeUtils.getAdress(buffer) + + executor.post(() => { + val ep = getConnection(host) + ep.sendAmNonBlocking(ExternalAmId.FETCH_STREAM, address, headerSize, + address, 0, UcpConstants.UCP_AM_SEND_FLAG_EAGER, new UcxCallback() { + override def onSuccess(request: UcpRequest): Unit = { + buffer.clear() + logDebug(s"Sent stream to $host tag $t block $blockId " + + s"in ${System.nanoTime() - startTime} ns") + } + override def onError(ucsStatus: Int, errorMsg: String): Unit = { + val err = s"Sent stream to $host tag $t failed: $errorMsg"; + logError(err) + throw new UcxException(err) + } + }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) + }) + } + + def progressTimeOut(): Unit = { + val currTag = tag.get() + if (prevTag != currTag) { + prevTag = currTag + return + } + + val validTime = System.nanoTime - transport.timeoutMs * 1000000L + if (requestData.nonEmpty) { + requestData.filterNot { + case (_, request) => request.timestamp >= validTime + }.keys.foreach(requestData.remove(_).foreach(request => { + request.callbacks.foreach(_.onError(new UcxFailureOperationResult("timeout"))) + })) + } + if (streamData.nonEmpty) { + streamData.filterNot { + case (_, request) => request.timestamp >= validTime + }.keys.foreach(streamData.remove(_).foreach(request => { + request.callback.onError(new UcxFailureOperationResult("timeout")) + })) + } + if (sliceData.nonEmpty) { + 
sliceData.filterNot { + case (_, request) => request.timestamp >= validTime + }.keys.foreach(sliceData.remove(_).foreach(request => { + request.callback.onError(new UcxFailureOperationResult("timeout")) + })) + } + } +} \ No newline at end of file diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/server/ExternalUcxServerConf.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/server/ExternalUcxServerConf.scala new file mode 100644 index 00000000..398236d5 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/server/ExternalUcxServerConf.scala @@ -0,0 +1,87 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.ucx + +import org.apache.hadoop.conf.Configuration +import org.apache.spark.internal.config.ConfigBuilder +import org.apache.spark.network.util.ByteUnit +import org.apache.spark.SparkConf + +/** + * Plugin configuration properties. + */ +class ExternalUcxServerConf(val yarnConf: Configuration) extends ExternalUcxConf { + override lazy val preallocateBuffersMap: Map[Long, Int] = + ExternalUcxConf.preAllocateConfToMap( + yarnConf.get(ExternalUcxConf.PREALLOCATE_BUFFERS_KEY, + ExternalUcxConf.PREALLOCATE_BUFFERS_DEFAULT)) + + override lazy val memoryLimit: Boolean = yarnConf.getBoolean( + ExternalUcxConf.MEMORY_LIMIT_KEY, + ExternalUcxConf.MEMORY_LIMIT_DEFAULT) + + override lazy val memoryGroupSize: Int = yarnConf.getInt( + ExternalUcxConf.MEMORY_GROUP_SIZE_KEY, + ExternalUcxConf.MEMORY_GROUP_SIZE_DEFAULT) + + override lazy val minBufferSize: Long = yarnConf.getLong( + ExternalUcxConf.MIN_BUFFER_SIZE_KEY, + ExternalUcxConf.MIN_BUFFER_SIZE_DEFAULT) + + override lazy val maxBufferSize: Long = yarnConf.getLong( + ExternalUcxConf.MAX_BUFFER_SIZE_KEY, + ExternalUcxConf.MAX_BUFFER_SIZE_DEFAULT) + + override lazy val minRegistrationSize: Long = yarnConf.getLong( + ExternalUcxConf.MIN_REGISTRATION_SIZE_KEY, + ExternalUcxConf.MIN_REGISTRATION_SIZE_DEFAULT) + + override lazy val maxRegistrationSize: Long = yarnConf.getLong( + ExternalUcxConf.MAX_REGISTRATION_SIZE_KEY, + ExternalUcxConf.MAX_REGISTRATION_SIZE_DEFAULT) + + override lazy val numPools: Int = yarnConf.getInt( + ExternalUcxConf.NUM_POOLS_KEY, + ExternalUcxConf.NUM_POOLS_DEFAULT) + + override lazy val listenerAddress: String = yarnConf.get( + ExternalUcxConf.SOCKADDR_KEY, + ExternalUcxConf.SOCKADDR_DEFAULT) + + override lazy val useWakeup: Boolean = yarnConf.getBoolean( + ExternalUcxConf.WAKEUP_FEATURE_KEY, + ExternalUcxConf.WAKEUP_FEATURE_DEFAULT) + + override lazy val numWorkers: Int = yarnConf.getInt( + ExternalUcxConf.NUM_WORKERS_KEY, + yarnConf.getInt( + ExternalUcxConf.NUM_WORKERS_COMPAT_KEY, + ExternalUcxConf.NUM_WORKERS_DEFAULT)) + + override lazy val numThreads: Int = yarnConf.getInt( + ExternalUcxConf.NUM_THREADS_KEY, + yarnConf.getInt( + ExternalUcxConf.NUM_THREADS_COMPAT_KEY, + ExternalUcxConf.NUM_THREADS_DEFAULT)) + + override lazy val ucxServerPort: Int = yarnConf.getInt( + ExternalUcxConf.SPARK_UCX_SHUFFLE_SERVICE_PORT_KEY, + ExternalUcxConf.SPARK_UCX_SHUFFLE_SERVICE_PORT_DEFAULT) + + override lazy val maxReplySize: Long = yarnConf.getLong( + ExternalUcxConf.MAX_REPLY_SIZE_KEY, + ExternalUcxConf.MAX_REPLY_SIZE_DEFAULT) + + lazy val ucxEpsNum: Int = yarnConf.getInt( + ExternalUcxServerConf.SPARK_UCX_SHUFFLE_EPS_NUM_KEY, + ExternalUcxServerConf.SPARK_UCX_SHUFFLE_EPS_NUM_DEFAULT) +} + +object ExternalUcxServerConf { + lazy val SPARK_UCX_SHUFFLE_SERVICE_TCP_PORT_KEY = 
ExternalUcxConf.getUcxConf("service.tcp.port") + + lazy val SPARK_UCX_SHUFFLE_EPS_NUM_KEY = ExternalUcxConf.getUcxConf("eps.num") + lazy val SPARK_UCX_SHUFFLE_EPS_NUM_DEFAULT = 16777216 +} \ No newline at end of file diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/server/ExternalUcxServerTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/server/ExternalUcxServerTransport.scala new file mode 100644 index 00000000..ae12cc6f --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/server/ExternalUcxServerTransport.scala @@ -0,0 +1,206 @@ +package org.apache.spark.shuffle.ucx + +// import org.apache.spark.SparkEnv +import org.apache.spark.shuffle.utils.{UcxLogging, UnsafeUtils} +import org.apache.spark.shuffle.ucx.utils.SerializationUtils +import org.apache.spark.network.buffer.FileSegmentManagedBuffer +import org.apache.spark.network.shuffle.ExternalUcxShuffleBlockResolver +import org.openucx.jucx.ucp._ +import org.openucx.jucx.ucs.UcsConstants +import org.openucx.jucx.ucs.UcsConstants.MEMORY_TYPE +import org.openucx.jucx.{UcxCallback, UcxException, UcxUtils} + +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.{ConcurrentHashMap, TimeUnit, CountDownLatch} +import java.nio.channels.FileChannel +import java.nio.file.StandardOpenOption +import scala.collection.concurrent.TrieMap + +class ExternalUcxServerTransport( + serverConf: ExternalUcxServerConf, blockManager: ExternalUcxShuffleBlockResolver) + extends ExternalUcxTransport(serverConf) + with UcxLogging { + private[ucx] val workerMap = new TrieMap[String, TrieMap[UcxWorkerId, Unit]] + private[ucx] val fileMap = new TrieMap[String, ConcurrentHashMap[UcxShuffleMapId, FileChannel]] + + private[ucx] var allocatedWorker: Array[ExternalUcxServerWorker] = _ + private[ucx] var globalWorker: ExternalUcxServerWorker = _ + private[ucx] var serverPorts: Seq[Int] = _ + + private[ucx] val scheduledLatch = new CountDownLatch(1) + + private[ucx] var maxReplySize: Long = 0 + + private var serverPortsBuffer: ByteBuffer = _ + + override def estimateNumEps(): Int = serverConf.ucxEpsNum + + override def init(): ByteBuffer = { + initContext() + initMemoryPool() + + if (serverConf.useWakeup) { + ucpWorkerParams.requestWakeupRX().requestWakeupTX().requestWakeupEdge() + } + + // additional 1 for mem pool report + initTaskPool(serverConf.numThreads + 1) + submit(() => scheduledReport()) + + logInfo(s"Allocating global worker") + val worker = ucxContext.newWorker(ucpWorkerParams) + globalWorker = new ExternalUcxServerWorker( + worker, this, new UcxWorkerId("Listener", 0, 0), serverConf.ucxServerPort) + + val maxAmHeaderSize = worker.getMaxAmHeaderSize + maxReplySize = serverConf.maxReplySize.min(serverConf.maxBufferSize - + maxAmHeaderSize) + + logInfo(s"Allocating ${serverConf.numWorkers} server workers") + + allocatedWorker = new Array[ExternalUcxServerWorker](serverConf.numWorkers) + for (i <- 0 until serverConf.numWorkers) { + val worker = ucxContext.newWorker(ucpWorkerParams) + val workerId = new UcxWorkerId("Server", 0, i) + allocatedWorker(i) = new ExternalUcxServerWorker(worker, this, workerId, 0) + } + serverPorts = allocatedWorker.map(_.getPort) + + serverPortsBuffer = ByteBuffer.allocateDirect( + serverPorts.length * UnsafeUtils.INT_SIZE) + serverPorts.foreach(serverPortsBuffer.putInt(_)) + serverPortsBuffer.rewind() + + logInfo(s"Launching ${serverConf.numWorkers} server workers") + 
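+    // Each server worker runs on its own UcxWorkerThread: a daemon loop that
+    // drains the posted-task queue, drives worker.progress(), and, when
+    // useWakeup is enabled, parks in waitForEvents() until signalled.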
allocatedWorker.foreach(_.start) + + logInfo(s"Launching global worker") + + globalWorker.start + + initialized = true + logInfo(s"Started listener on ${globalWorker.getAddress} ${serverPorts} maxReplySize $maxReplySize") + SerializationUtils.serializeInetAddress(globalWorker.getAddress) + } + + /** + * Close all transport resources + */ + override def close(): Unit = { + if (initialized) { + running = false + + if (globalWorker != null) { + globalWorker.closing().get(1, TimeUnit.MILLISECONDS) + } + + if (allocatedWorker != null) { + allocatedWorker.map(_.closing).foreach(_.get(5, TimeUnit.MILLISECONDS)) + } + + scheduledLatch.countDown() + + super.close() + + logInfo("UCX transport closed.") + } + } + + def applicationRemoved(appId: String): Unit = { + workerMap.remove(appId).foreach(clients => { + val clientIds = clients.keys.toSeq + allocatedWorker.foreach(_.disconnect(clientIds)) + globalWorker.disconnect(clientIds) + }) + fileMap.remove(appId).foreach(files => files.values.forEach(_.close)) + // allocatedWorker.foreach(_.debugClients()) + } + + def executorRemoved(executorId: String, appId: String): Unit = { + val exeId = executorId.toInt + workerMap.get(appId).map(clients => { + val clientIds = clients.filterKeys(_.exeId == exeId).keys.toSeq + allocatedWorker.foreach(_.disconnect(clientIds)) + globalWorker.disconnect(clientIds) + }) + } + + def getMaxReplySize(): Long = maxReplySize + + def getServerPortsBuffer(): ByteBuffer = { + serverPortsBuffer.duplicate() + } + + def handleConnect(handler: ExternalUcxServerWorker, + clientWorker: UcxWorkerId): Unit = { + workerMap.getOrElseUpdate(clientWorker.appId, { + new TrieMap[UcxWorkerId, Unit] + }).getOrElseUpdate(clientWorker, Unit) + } + + def handleFetchBlockRequest(handler: ExternalUcxServerWorker, + clientWorker: UcxWorkerId, exeId: Int, + replyTag: Int, blockIds: Seq[UcxShuffleBlockId]): + Unit = { + submit(new Runnable { + override def run(): Unit = { + var block: FileSegmentManagedBuffer = null + var blockInfos: Seq[(FileChannel, Long, Long)] = null + try { + blockInfos = blockIds.map(bid => { + block = blockManager.getBlockData(clientWorker.appId, exeId.toString, + bid.shuffleId, bid.mapId, + bid.reduceId).asInstanceOf[ + FileSegmentManagedBuffer] + (openBlock(clientWorker.appId, bid, block), block.getOffset, block.size) + }) + handler.handleFetchBlockRequest(clientWorker, replyTag, blockInfos) + } catch { + case ex: Throwable => + logError(s"Failed to reply fetch $clientWorker tag $replyTag files $blockInfos block $block $ex.") + } + } + }) + } + + def handleFetchBlockStream(handler: ExternalUcxServerWorker, + clientWorker: UcxWorkerId, exeId: Int, + replyTag: Int, bid: UcxShuffleBlockId): Unit = { + submit(new Runnable { + override def run(): Unit = { + var block: FileSegmentManagedBuffer = null + var blockInfo: (FileChannel, Long, Long) = null + try { + block = blockManager.getBlockData(clientWorker.appId, exeId.toString, + bid.shuffleId, bid.mapId, + bid.reduceId).asInstanceOf[ + FileSegmentManagedBuffer] + blockInfo = + (openBlock(clientWorker.appId, bid, block), block.getOffset, block.size) + handler.handleFetchBlockStream(clientWorker, replyTag, blockInfo) + } catch { + case ex: Throwable => + logError(s"Failed to reply stream $clientWorker tag $replyTag file $blockInfo block $block $ex.") + } + } + }) + } + + def openBlock(appId: String, bid: UcxShuffleBlockId, + blockData: FileSegmentManagedBuffer): FileChannel = { + fileMap.getOrElseUpdate(appId, { + new ConcurrentHashMap[UcxShuffleMapId, FileChannel] + 
+    }).computeIfAbsent(
+      UcxShuffleMapId(bid.shuffleId, bid.mapId),
+      _ => FileChannel.open(blockData.getFile().toPath(), StandardOpenOption.READ)
+    )
+  }
+
+  def scheduledReport(): Unit = {
+    while (!scheduledLatch.await(30, TimeUnit.SECONDS)) {
+      memPools.foreach(_.report)
+    }
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/server/ExternalUcxServerWorker.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/server/ExternalUcxServerWorker.scala
new file mode 100644
index 00000000..0b1b96ab
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/server/ExternalUcxServerWorker.scala
@@ -0,0 +1,377 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.ucx
+
+import java.io.Closeable
+import java.nio.ByteBuffer
+import java.nio.channels.FileChannel
+import java.util.concurrent.{ConcurrentLinkedQueue, ConcurrentHashMap, CountDownLatch, Future, FutureTask}
+import scala.collection.mutable
+import scala.collection.JavaConverters._
+import org.openucx.jucx.ucp._
+import org.openucx.jucx.ucs.UcsConstants
+import org.openucx.jucx.ucs.UcsConstants.MEMORY_TYPE
+import org.openucx.jucx.{UcxCallback, UcxException, UcxUtils}
+import org.apache.spark.shuffle.ucx.memory.UcxLinkedMemBlock
+import org.apache.spark.shuffle.utils.{UnsafeUtils, UcxLogging}
+import java.net.InetSocketAddress
+
+class ExternalUcxEndpoint(val ucpEp: UcpEndpoint, var closed: Boolean)
+
+/**
+ * Worker per thread wrapper, that maintains connection and progress logic.
+ */
+case class ExternalUcxServerWorker(val worker: UcpWorker,
+                                   transport: ExternalUcxServerTransport,
+                                   workerId: UcxWorkerId,
+                                   port: Int)
+  extends Closeable with UcxLogging {
+  private[this] val memPool = transport.hostBounceBufferMemoryPool(workerId.workerId)
+  private[this] val maxReplySize = transport.getMaxReplySize()
+  private[this] val shuffleClients = new ConcurrentHashMap[UcxWorkerId, ExternalUcxEndpoint]
+  private[ucx] val executor = new UcxWorkerThread(
+    worker, transport.ucxShuffleConf.useWakeup)
+
+  private val emptyCallback = () => {}
+  private val endpoints = mutable.HashMap.empty[UcpEndpoint, () => Unit]
+
+  // The error handler must be initialized before the listener below: referencing
+  // it from the connection handler ahead of its declaration would capture a null
+  // value at construction time.
+  private val errorHandler = new UcpEndpointErrorHandler {
+    override def onError(ucpEndpoint: UcpEndpoint, errorCode: Int, errorString: String): Unit = {
+      if (errorCode == UcsConstants.STATUS.UCS_ERR_CONNECTION_RESET) {
+        logInfo(s"Connection closed on ep: $ucpEndpoint")
+      } else {
+        logWarning(s"Ep $ucpEndpoint got an error: $errorString")
+      }
+      endpoints.remove(ucpEndpoint).foreach(_())
+      ucpEndpoint.close()
+    }
+  }
+
+  private val listener = worker.newListener(
+    new UcpListenerParams().setSockAddr(new InetSocketAddress("0.0.0.0", port))
+      .setConnectionHandler((ucpConnectionRequest: UcpConnectionRequest) => {
+        val clientAddress = ucpConnectionRequest.getClientAddress()
+        try {
+          val ep = worker.newEndpoint(
+            new UcpEndpointParams().setConnectionRequest(ucpConnectionRequest)
+              .setPeerErrorHandlingMode().setErrorHandler(errorHandler)
+              .setName(s"Endpoint to $clientAddress"))
+          endpoints.getOrElseUpdate(ep, emptyCallback)
+        } catch {
+          case e: UcxException => logError(s"Accept $clientAddress fail: $e")
+        }
+      }))
+
+  private val listenerAddress = listener.getAddress
+
+  // Main RPC thread. Submit each RPC request to separate thread and send reply back from separate worker.
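+  // Wire layout of a FETCH_BLOCK active message, as serialized by
+  // ExternalUcxClientWorker.fetchBlocksByBlockIds:
+  //   header: UcxWorkerId | replyTag (Int) | exeId (Int)
+  //   data:   N x UcxShuffleBlockId, with N = data length / UcxShuffleBlockId.serializedSize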
+  worker.setAmRecvHandler(ExternalAmId.FETCH_BLOCK,
+    (headerAddress: Long, headerSize: Long, amData: UcpAmData, _: UcpEndpoint) => {
+      val header = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt)
+      val shuffleClient = UcxWorkerId.deserialize(header)
+      val replyTag = header.getInt
+      val exeId = header.getInt
+      val buffer = UnsafeUtils.getByteBufferView(amData.getDataAddress,
+                                                 amData.getLength.toInt)
+      val blockNum = buffer.remaining() / UcxShuffleBlockId.serializedSize
+      val blockIds = (0 until blockNum).map(
+        _ => UcxShuffleBlockId.deserialize(buffer))
+      logTrace(s"${workerId.workerId} Recv fetch from $shuffleClient tag $replyTag.")
+      transport.handleFetchBlockRequest(this, shuffleClient, exeId, replyTag, blockIds)
+      UcsConstants.STATUS.UCS_OK
+    }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG)
+
+  // Main RPC thread. Submit each RPC request to a separate thread and send the stream back from a separate worker.
+  worker.setAmRecvHandler(ExternalAmId.FETCH_STREAM,
+    (headerAddress: Long, headerSize: Long, amData: UcpAmData, _: UcpEndpoint) => {
+      val header = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt)
+      val shuffleClient = UcxWorkerId.deserialize(header)
+      val replyTag = header.getInt
+      val exeId = header.getInt
+      val blockId = UcxShuffleBlockId.deserialize(header)
+      logTrace(s"${workerId.workerId} Recv stream from $shuffleClient tag $replyTag.")
+      transport.handleFetchBlockStream(this, shuffleClient, exeId, replyTag, blockId)
+      UcsConstants.STATUS.UCS_OK
+    }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG)
+
+  // AM to receive the client worker's address and connect server workers back to it.
+  worker.setAmRecvHandler(ExternalAmId.CONNECT,
+    (headerAddress: Long, headerSize: Long, amData: UcpAmData, ep: UcpEndpoint) => {
+      val header = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt)
+      val shuffleClient = UcxWorkerId.deserialize(header)
+      val workerAddress = UnsafeUtils.getByteBufferView(amData.getDataAddress,
+                                                        amData.getLength.toInt)
+      val copiedAddress = ByteBuffer.allocateDirect(workerAddress.remaining)
+      copiedAddress.put(workerAddress)
+      connected(shuffleClient, copiedAddress)
+      endpoints.put(ep, () => doDisconnect(shuffleClient))
+      UcsConstants.STATUS.UCS_OK
+    }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG)
+
+  // Main RPC thread. Reply with the listening ports of this server's workers.
+  worker.setAmRecvHandler(ExternalAmId.ADDRESS,
+    (headerAddress: Long, headerSize: Long, amData: UcpAmData, ep: UcpEndpoint) => {
+      handleAddress(ep)
+      UcsConstants.STATUS.UCS_OK
+    }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG)
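+  // ------------------------------------------------------------------------
+  // Editorial sketch (not part of the patch): the header layout the
+  // FETCH_BLOCK / FETCH_STREAM handlers above expect from a client. The names
+  // `workerId`, `replyTag`, `exeId` here are assumed client-side values.
+  //
+  //   val header = ByteBuffer.allocateDirect(
+  //     workerId.serializedSize + UnsafeUtils.INT_SIZE * 2)
+  //   workerId.serialize(header)  // appId / exeId / workerId, see UcxWorkerId
+  //   header.putInt(replyTag)     // tag echoed back in the reply
+  //   header.putInt(exeId)        // executor owning the shuffle files
+  //   header.rewind()
+  //   // FETCH_BLOCK carries the serialized UcxShuffleBlockIds as AM data;
+  //   // FETCH_STREAM instead appends a single block id to the header.
+  // ------------------------------------------------------------------------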
+
+  def start(): Unit = {
+    executor.start()
+  }
+
+  override def close(): Unit = {
+    if (!shuffleClients.isEmpty) {
+      logInfo(s"$workerId closing ${shuffleClients.size} clients")
+      shuffleClients.values.asScala.map(
+        _.ucpEp.closeNonBlockingFlush()).foreach(req =>
+        while (!req.isCompleted) {
+          worker.progress()
+        }
+      )
+    }
+    if (!endpoints.isEmpty) {
+      logInfo(s"$workerId closing ${endpoints.size} eps")
+      endpoints.keys.map(
+        _.closeNonBlockingForce()).foreach(req =>
+        while (!req.isCompleted) {
+          worker.progress()
+        }
+      )
+    }
+    listener.close()
+  }
+
+  def closing(): Future[Unit.type] = {
+    val cleanTask = new FutureTask(new Runnable {
+      override def run() = close()
+    }, Unit)
+    executor.close(cleanTask)
+    cleanTask
+  }
+
+  def getPort(): Int = {
+    listenerAddress.getPort()
+  }
+
+  def getAddress(): InetSocketAddress = {
+    listenerAddress
+  }
+
+  @`inline`
+  private def doDisconnect(shuffleClient: UcxWorkerId): Unit = {
+    try {
+      Option(shuffleClients.remove(shuffleClient)).foreach(ep => {
+        ep.ucpEp.closeNonBlockingFlush()
+        ep.closed = true
+        logDebug(s"Disconnect $shuffleClient")
+      })
+    } catch {
+      case e: Throwable => logWarning(s"Disconnect $shuffleClient: $e")
+    }
+  }
+
+  @`inline`
+  def isEpClosed(ep: UcpEndpoint): Boolean = {
+    ep.getNativeId() == null
+  }
+
+  @`inline`
+  def disconnect(workerIds: Seq[UcxWorkerId]): Unit = {
+    executor.post(() => workerIds.foreach(doDisconnect(_)))
+  }
+
+  @`inline`
+  def connected(shuffleClient: UcxWorkerId, workerAddress: ByteBuffer): Unit = {
+    logDebug(s"$workerId connecting back to $shuffleClient by worker address")
+    try {
+      shuffleClients.computeIfAbsent(shuffleClient, _ => {
+        val ucpEp = worker.newEndpoint(new UcpEndpointParams()
+          .setErrorHandler(new UcpEndpointErrorHandler {
+            override def onError(ucpEndpoint: UcpEndpoint, errorCode: Int, errorString: String): Unit = {
+              logInfo(s"Connection to $shuffleClient closed: $errorString")
+              shuffleClients.remove(shuffleClient)
+              ucpEndpoint.close()
+            }
+          })
+          .setName(s"Server to $shuffleClient")
+          .setUcpAddress(workerAddress))
+        new ExternalUcxEndpoint(ucpEp, false)
+      })
+    } catch {
+      case e: UcxException => logWarning(s"Connection to $shuffleClient failed: $e")
+    }
+  }
+
+  def awaitConnection(shuffleClient: UcxWorkerId): ExternalUcxEndpoint = {
+    // Wait (up to 10s) for the CONNECT handler to finish establishing the
+    // reverse connection to this client.
+    var ep = shuffleClients.get(shuffleClient)
+    if (ep == null) {
+      val startTime = System.currentTimeMillis()
+      while ({ ep = shuffleClients.get(shuffleClient); ep == null }) {
+        if (System.currentTimeMillis() - startTime > 10000) {
+          throw new UcxException(s"No worker address received from $shuffleClient")
+        }
+        Thread.`yield`
+      }
+    }
+    ep
+  }
+
+  def handleAddress(ep: UcpEndpoint) = {
+    val msg = transport.getServerPortsBuffer()
+    val header = UnsafeUtils.getAdress(msg)
+    ep.sendAmNonBlocking(
+      ExternalAmId.REPLY_ADDRESS, header, msg.remaining(), header, 0,
+      UcpConstants.UCP_AM_SEND_FLAG_EAGER, new UcxCallback {
+        override def onSuccess(request: UcpRequest): Unit = {
+          logTrace(s"$workerId sent REPLY_ADDRESS to $ep")
+          msg.clear()
+        }
+
+        override def onError(ucsStatus: Int, errorMsg: String): Unit = {
+          logWarning(s"$workerId failed to send REPLY_ADDRESS to $ep: $errorMsg")
+          msg.clear()
+        }
+      }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
+    logTrace(s"$workerId sending REPLY_ADDRESS $msg mem $header")
+  }
+
+  def handleFetchBlockRequest(
+      clientWorker: UcxWorkerId, replyTag: Int,
+      blockInfos: Seq[(FileChannel, Long, Long)]): Unit = {
+    val blockSize = blockInfos.map(x => x._3).sum
+    if (blockSize <= maxReplySize) {
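+      // The whole reply fits into one bounce buffer: pack every block into a
+      // single REPLY_BLOCK message.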
+      handleFetchBlockChunks(clientWorker, replyTag, blockInfos, blockSize)
+      return
+    }
+    // The size of the last block can exceed maxBytesInFlight / 5 in Spark.
+    val lastBlock = blockInfos.last
+    if (blockInfos.size > 1) {
+      val chunks = blockInfos.slice(0, blockInfos.size - 1)
+      val chunksSize = blockSize - lastBlock._3
+      handleFetchBlockChunks(clientWorker, replyTag, chunks, chunksSize)
+    }
+    handleFetchBlockStream(clientWorker, replyTag, lastBlock,
+                           ExternalAmId.REPLY_SLICE)
+  }
+
+  def handleFetchBlockChunks(
+      clientWorker: UcxWorkerId, replyTag: Int,
+      blockInfos: Seq[(FileChannel, Long, Long)], blockSize: Long): Unit = {
+    val tagAndSizes = UnsafeUtils.INT_SIZE + UnsafeUtils.INT_SIZE * blockInfos.length
+    val msgSize = tagAndSizes + blockSize
+    val resultMemory = memPool.get(msgSize).asInstanceOf[UcxLinkedMemBlock]
+    assert(resultMemory != null)
+    val resultBuffer = UcxUtils.getByteBufferView(resultMemory.address, msgSize)
+
+    resultBuffer.putInt(replyTag)
+    val blocksRange = 0 until blockInfos.length
+    for (i <- blocksRange) {
+      resultBuffer.putInt(blockInfos(i)._3.toInt)
+    }
+
+    for (i <- blocksRange) {
+      val (blockCh, blockOffset, blockLen) = blockInfos(i)
+      resultBuffer.limit(resultBuffer.position() + blockLen.toInt)
+      blockCh.read(resultBuffer, blockOffset)
+    }
+
+    val ep = awaitConnection(clientWorker)
+    executor.post(new Runnable {
+      override def run(): Unit = {
+        if (ep.closed) {
+          resultMemory.close()
+          return
+        }
+
+        val startTime = System.nanoTime()
+        ep.ucpEp.sendAmNonBlocking(ExternalAmId.REPLY_BLOCK,
+          resultMemory.address, tagAndSizes, resultMemory.address + tagAndSizes,
+          msgSize - tagAndSizes, 0, new UcxCallback {
+            override def onSuccess(request: UcpRequest): Unit = {
+              resultMemory.close()
+              logTrace(s"${workerId.workerId} Sent to ${clientWorker} ${blockInfos.length} blocks of size: " +
+                s"${msgSize} tag $replyTag in ${System.nanoTime() - startTime} ns.")
+            }
+
+            override def onError(ucsStatus: Int, errorMsg: String): Unit = {
+              resultMemory.close()
+              logError(s"${workerId.workerId} Failed to reply fetch $clientWorker tag $replyTag $errorMsg.")
+            }
+          }, new UcpRequestParams().setMemoryType(
+            UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
+            .setMemoryHandle(resultMemory.memory))
+      }
+    })
+    logTrace(s"${workerId.workerId} Sending to $clientWorker tag $replyTag mem $resultMemory size $msgSize")
+  }
+
+  def handleFetchBlockStream(clientWorker: UcxWorkerId, replyTag: Int,
+                             blockInfo: (FileChannel, Long, Long),
+                             amId: Int = ExternalAmId.REPLY_STREAM): Unit = {
+    // Header: tag: Int + unsent replies: Int + total length: Long + current offset: Long
+    val headerSize = UnsafeUtils.INT_SIZE + UnsafeUtils.INT_SIZE +
+      UnsafeUtils.LONG_SIZE + UnsafeUtils.LONG_SIZE
+    val (blockCh, blockOffset, blockSize) = blockInfo
+    val blockSlice = (0L until blockSize by maxReplySize).toArray
+    // Make sure the last slice is not too small: move its start to the
+    // midpoint of the final segment, evening out the last two slices.
+    if (blockSlice.size >= 2) {
+      val mid = (blockSlice(blockSlice.size - 2) + blockSize) / 2
+      blockSlice(blockSlice.size - 1) = mid
+    }
+
+    def send(workerWrapper: ExternalUcxServerWorker, currentId: Int): Unit = try {
+      val hasNext = (currentId + 1 != blockSlice.size)
+      val nextOffset = if (hasNext) blockSlice(currentId + 1) else blockSize
+      val currentOffset = blockSlice(currentId)
+      val currentSize = nextOffset - currentOffset
+      val unsent = blockSlice.length - currentId - 1
+      val msgSize = headerSize + currentSize.toInt
+      val mem = memPool.get(msgSize).asInstanceOf[UcxLinkedMemBlock]
+      val buffer = mem.toByteBuffer()
+
+      buffer.limit(msgSize)
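+      // Stream header: reply tag, number of segments still to come, total
+      // block size, and this segment's offset within the block.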
+      buffer.putInt(replyTag)
+      buffer.putInt(unsent)
+      buffer.putLong(blockSize)
+      buffer.putLong(currentOffset)
+      blockCh.read(buffer, blockOffset + currentOffset)
+
+      val ep = workerWrapper.awaitConnection(clientWorker)
+      workerWrapper.executor.post(new Runnable {
+        override def run(): Unit = {
+          if (ep.closed) {
+            mem.close()
+            return
+          }
+
+          val startTime = System.nanoTime()
+          ep.ucpEp.sendAmNonBlocking(amId, mem.address, headerSize,
+            mem.address + headerSize, currentSize,
+            UcpConstants.UCP_AM_SEND_FLAG_RNDV, new UcxCallback {
+              override def onSuccess(request: UcpRequest): Unit = {
+                mem.close()
+                logTrace(s"${workerId.workerId} Sent to ${clientWorker} size $currentSize tag $replyTag seg " +
+                  s"$currentId in ${System.nanoTime() - startTime} ns.")
+              }
+              override def onError(ucsStatus: Int, errorMsg: String): Unit = {
+                mem.close()
+                logError(s"${workerId.workerId} Failed to reply stream $clientWorker tag $replyTag $currentId $errorMsg.")
+              }
+            }, new UcpRequestParams()
+              .setMemoryType(UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
+              .setMemoryHandle(mem.memory))
+        }
+      })
+      logTrace(s"${workerId.workerId} Sending to $clientWorker tag $replyTag $currentId mem $mem size $msgSize.")
+      if (hasNext) {
+        transport.submit(() => send(this, currentId + 1))
+      }
+    } catch {
+      case ex: Throwable =>
+        logError(s"${workerId.workerId} Failed to reply stream $clientWorker tag $replyTag $currentId $ex.")
+    }
+
+    send(this, 0)
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/spark_2_4/CompatUtils.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/spark_2_4/CompatUtils.scala
new file mode 100644
index 00000000..d8d01fdc
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/spark_2_4/CompatUtils.scala
@@ -0,0 +1,90 @@
+package org.apache.spark.shuffle.ucx
+
+import org.apache.spark.shuffle.utils.UcxThreadFactory
+import java.nio.ByteBuffer
+import java.util.concurrent.{Executors, ExecutorService}
+import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread}
+
+case class UcxShuffleMapId(shuffleId: Int, mapId: Int) {}
+
+case class UcxShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
+  override def serializedSize: Int = UcxShuffleBlockId.serializedSize
+
+  override def serialize(byteBuffer: ByteBuffer): Unit = {
+    byteBuffer.putInt(shuffleId)
+    byteBuffer.putInt(reduceId)
+    byteBuffer.putInt(mapId)
+  }
+}
+
+object UcxShuffleBlockId {
+  final val serializedSize = 12
+
+  def deserialize(byteBuffer: ByteBuffer): UcxShuffleBlockId = {
+    val shuffleId = byteBuffer.getInt
+    val reduceId = byteBuffer.getInt
+    val mapId = byteBuffer.getInt
+    UcxShuffleBlockId(shuffleId, mapId, reduceId)
+  }
+}
+
+case class UcxWorkerId(appId: String, exeId: Int, workerId: Int) extends BlockId {
+  override def serializedSize: Int = 12 + appId.size
+
+  override def serialize(byteBuffer: ByteBuffer): Unit = {
+    byteBuffer.putInt(exeId)
+    byteBuffer.putInt(workerId)
+    byteBuffer.putInt(appId.size)
+    byteBuffer.put(appId.getBytes)
+  }
+
+  override def toString(): String = s"UcxWorkerId($appId, $exeId, $workerId)"
+}
+
+object UcxWorkerId {
+  def deserialize(byteBuffer: ByteBuffer): UcxWorkerId = {
+    val exeId = byteBuffer.getInt
+    val workerId = byteBuffer.getInt
+    val appIdSize = byteBuffer.getInt
+    val appIdBytes = new Array[Byte](appIdSize)
+    byteBuffer.get(appIdBytes)
+    UcxWorkerId(new String(appIdBytes), exeId, workerId)
+  }
+
+  @`inline`
+  def makeExeWorkerId(id: UcxWorkerId): Long = {
+    (id.workerId.toLong << 32) | id.exeId
+  }
+
+  @`inline`
+  def extractExeId(exeWorkerId: Long): Int = {
+    exeWorkerId.toInt
+  }
+
+  @`inline`
+  def extractWorkerId(exeWorkerId: Long): Int = {
+    (exeWorkerId >> 32).toInt
+  }
+
+  def apply(appId: String, exeWorkerId: Long): UcxWorkerId = {
+    UcxWorkerId(appId, UcxWorkerId.extractExeId(exeWorkerId),
+                UcxWorkerId.extractWorkerId(exeWorkerId))
+  }
+}
+
+object UcxThreadUtils {
+  def newForkJoinPool(prefix: String, maxThreadNumber: Int): SForkJoinPool = {
+    val factory = new SForkJoinPool.ForkJoinWorkerThreadFactory {
+      override def newThread(pool: SForkJoinPool) =
+        new SForkJoinWorkerThread(pool) {
+          setName(s"${prefix}-${super.getName}")
+        }
+    }
+    new SForkJoinPool(maxThreadNumber, factory, null, false)
+  }
+
+  def newFixedDaemonPool(prefix: String, maxThreadNumber: Int): ExecutorService = {
+    val factory = new UcxThreadFactory().setDaemon(true).setPrefix(prefix)
+    Executors.newFixedThreadPool(maxThreadNumber, factory)
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/external/spark_3_0/CompatUtils.scala b/src/main/scala/org/apache/spark/shuffle/ucx/external/spark_3_0/CompatUtils.scala
new file mode 100644
index 00000000..0a1f80f4
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/external/spark_3_0/CompatUtils.scala
@@ -0,0 +1,89 @@
+package org.apache.spark.shuffle.ucx
+
+import org.apache.spark.shuffle.utils.UcxThreadFactory
+import java.nio.ByteBuffer
+import java.util.concurrent.{Executors, ExecutorService, ForkJoinPool, ForkJoinWorkerThread}
+
+case class UcxShuffleMapId(shuffleId: Int, mapId: Long) {}
+
+case class UcxShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId {
+  override def serializedSize: Int = UcxShuffleBlockId.serializedSize
+
+  override def serialize(byteBuffer: ByteBuffer): Unit = {
+    byteBuffer.putInt(shuffleId)
+    byteBuffer.putInt(reduceId)
+    byteBuffer.putLong(mapId)
+  }
+}
+
+object UcxShuffleBlockId {
+  final val serializedSize = 16
+
+  def deserialize(byteBuffer: ByteBuffer): UcxShuffleBlockId = {
+    val shuffleId = byteBuffer.getInt
+    val reduceId = byteBuffer.getInt
+    val mapId = byteBuffer.getLong
+    UcxShuffleBlockId(shuffleId, mapId, reduceId)
+  }
+}
+
+case class UcxWorkerId(appId: String, exeId: Int, workerId: Int) extends BlockId {
+  override def serializedSize: Int = 12 + appId.size
+
+  override def serialize(byteBuffer: ByteBuffer): Unit = {
+    byteBuffer.putInt(exeId)
+    byteBuffer.putInt(workerId)
+    byteBuffer.putInt(appId.size)
+    byteBuffer.put(appId.getBytes)
+  }
+
+  override def toString(): String = s"UcxWorkerId($appId, $exeId, $workerId)"
+}
+
+object UcxWorkerId {
+  def deserialize(byteBuffer: ByteBuffer): UcxWorkerId = {
+    val exeId = byteBuffer.getInt
+    val workerId = byteBuffer.getInt
+    val appIdSize = byteBuffer.getInt
+    val appIdBytes = new Array[Byte](appIdSize)
+    byteBuffer.get(appIdBytes)
+    UcxWorkerId(new String(appIdBytes), exeId, workerId)
+  }
+
+  @`inline`
+  def makeExeWorkerId(id: UcxWorkerId): Long = {
+    (id.workerId.toLong << 32) | id.exeId
+  }
+
+  @`inline`
+  def extractExeId(exeWorkerId: Long): Int = {
+    exeWorkerId.toInt
+  }
+
+  @`inline`
+  def extractWorkerId(exeWorkerId: Long): Int = {
+    (exeWorkerId >> 32).toInt
+  }
+
+  def apply(appId: String, exeWorkerId: Long): UcxWorkerId = {
+    UcxWorkerId(appId, UcxWorkerId.extractExeId(exeWorkerId),
+                UcxWorkerId.extractWorkerId(exeWorkerId))
+  }
+}
+
+object UcxThreadUtils {
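+  // Spark 3.x / Scala 2.12 uses java.util.concurrent.ForkJoinPool directly;
+  // the spark_2_4 variant above goes through the scala.concurrent.forkjoin
+  // aliases that existed for Scala 2.11. This is the only difference besides
+  // mapId widening from Int to Long.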
+  def newForkJoinPool(prefix: String, maxThreadNumber: Int): ForkJoinPool = {
+    val factory = new ForkJoinPool.ForkJoinWorkerThreadFactory {
+      override def newThread(pool: ForkJoinPool) =
+        new ForkJoinWorkerThread(pool) {
+          setName(s"${prefix}-${super.getName}")
+        }
+    }
+    new ForkJoinPool(maxThreadNumber, factory, null, false)
+  }
+
+  def newFixedDaemonPool(prefix: String, maxThreadNumber: Int): ExecutorService = {
+    val factory = new UcxThreadFactory().setDaemon(true).setPrefix(prefix)
+    Executors.newFixedThreadPool(maxThreadNumber, factory)
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/memory/UcxLimitedMemPool.scala b/src/main/scala/org/apache/spark/shuffle/ucx/memory/UcxLimitedMemPool.scala
new file mode 100644
index 00000000..e0415b29
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/memory/UcxLimitedMemPool.scala
@@ -0,0 +1,251 @@
+package org.apache.spark.shuffle.ucx.memory
+
+import java.io.Closeable
+import java.util.concurrent.atomic.AtomicInteger
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedDeque, ConcurrentSkipListSet}
+import java.util.concurrent.Semaphore
+import java.nio.ByteBuffer
+import scala.collection.JavaConverters._
+
+import org.openucx.jucx.ucp.{UcpContext, UcpMemMapParams, UcpMemory}
+import org.openucx.jucx.ucs.UcsConstants
+import org.apache.spark.shuffle.utils.{SparkucxUtils, UcxLogging, UnsafeUtils}
+import org.apache.spark.shuffle.ucx.MemoryBlock
+
+class UcxSharedMemoryBlock(val closeCb: () => Unit, val refCount: AtomicInteger,
+                           override val address: Long, override val size: Long)
+  extends MemoryBlock(address, size, true) {
+
+  override def close(): Unit = {
+    if (refCount.decrementAndGet() == 0) {
+      closeCb()
+    }
+  }
+}
+
+class UcxMemBlock(private[ucx] val memory: UcpMemory,
+                  private[ucx] val allocator: UcxMemoryAllocator,
+                  override val address: Long, override val size: Long)
+  extends MemoryBlock(address, size, memory.getMemType ==
+    UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) {
+  def toByteBuffer() = {
+    UnsafeUtils.getByteBufferView(address, size.min(Int.MaxValue).toInt)
+  }
+
+  override def close(): Unit = {
+    allocator.deallocate(this)
+  }
+}
+
+class UcxLinkedMemBlock(private[memory] val superMem: UcxLinkedMemBlock,
+                        private[memory] var broMem: UcxLinkedMemBlock,
+                        override private[ucx] val memory: UcpMemory,
+                        override private[ucx] val allocator: UcxMemoryAllocator,
+                        override val address: Long, override val size: Long)
+  extends UcxMemBlock(memory, allocator, address, size) with Comparable[UcxLinkedMemBlock] {
+  override def compareTo(o: UcxLinkedMemBlock): Int = {
+    address.compareTo(o.address)
+  }
+}
+
+trait UcxMemoryAllocator extends Closeable {
+  def allocate(): UcxMemBlock
+  def deallocate(mem: UcxMemBlock): Unit
+  def preallocate(numBuffers: Int): Unit = {
+    (0 until numBuffers).map(x => allocate()).foreach(_.close())
+  }
+  def totalSize(): Long
+}
+
+abstract class UcxBaseMemAllocator extends UcxMemoryAllocator with UcxLogging {
+  private[memory] val stack = new ConcurrentSkipListSet[UcxMemBlock]
+  private[memory] val numAllocs = new AtomicInteger(0)
+  private[memory] val memMapParams = new UcpMemMapParams().allocate()
+    .setMemoryType(UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
+
+  override def close(): Unit = {
+    var numBuffers = 0
+    var length = 0L
+    stack.forEach(block => {
+      if (block.memory.getNativeId != null) {
+        length = block.size
+        block.memory.deregister()
+      }
+      numBuffers += 1
+    })
+    if (numBuffers != 0) {
+      logInfo(s"Closing $numBuffers buffers size $length allocations " +
s"${numAllocs.get()}. Total ${SparkucxUtils.bytesToString(length * numBuffers)}") + stack.clear() + } + } +} + +case class UcxLinkedMemAllocator(length: Long, minRegistrationSize: Long, + next: UcxLinkedMemAllocator, + ucxContext: UcpContext) + extends UcxBaseMemAllocator() with Closeable { + private[this] val registrationSize = length.max(minRegistrationSize) + private[this] val sliceRange = (0L until (registrationSize / length)) + private[this] var limit: Semaphore = _ + logInfo(s"Allocator stack size $length") + if (next == null) { + memMapParams.setLength(registrationSize) + } + + override def allocate(): UcxMemBlock = { + acquireLimit() + var result = stack.pollFirst() + if (result != null) { + result + } else if (next == null) { + logDebug(s"Allocating buffer of size $length.") + while (result == null) { + numAllocs.incrementAndGet() + val memory = ucxContext.memoryMap(memMapParams) + var address = memory.getAddress + for (i <- sliceRange) { + stack.add(new UcxLinkedMemBlock(null, null, memory, this, address, length)) + address += length + } + result = stack.pollFirst() + } + result + } else { + val superMem = next.allocate().asInstanceOf[UcxLinkedMemBlock] + val address1 = superMem.address + val address2 = address1 + length + val block1 = new UcxLinkedMemBlock(superMem, null, superMem.memory, this, + address1, length) + val block2 = new UcxLinkedMemBlock(superMem, null, superMem.memory, this, + address2, length) + block1.broMem = block2 + block2.broMem = block1 + stack.add(block2) + block1 + } + } + + override def deallocate(memBlock: UcxMemBlock): Unit = { + val block = memBlock.asInstanceOf[UcxLinkedMemBlock] + if (block.superMem == null) { + stack.add(block) + releaseLimit() + return + } + + var releaseNext = false + block.superMem.synchronized { + releaseNext = stack.remove(block.broMem) + if (!releaseNext) { + stack.add(block) + } + } + + if (releaseNext) { + next.deallocate(block.superMem) + } + releaseLimit() + } + + override def totalSize(): Long = registrationSize * numAllocs.get() + + def acquireLimit() = if (limit != null) { + limit.acquire(1) + } + + def releaseLimit() = if (limit != null) { + limit.release(1) + } + + def setLimit(num: Int): Unit = { + limit = new Semaphore(num) + } +} + +case class UcxLimitedMemPool(ucxContext: UcpContext) + extends Closeable with UcxLogging { + private[memory] val allocatorMap = new ConcurrentHashMap[Long, UcxMemoryAllocator]() + private[memory] var minBufferSize: Long = 4096L + private[memory] var maxBufferSize: Long = 2L * 1024 * 1024 * 1024 + private[memory] var minRegistrationSize: Long = 1024L * 1024 + private[memory] var maxRegistrationSize: Long = 16L * 1024 * 1024 * 1024 + + def get(size: Long): MemoryBlock = { + allocatorMap.get(roundUpToTheNextPowerOf2(size)).allocate() + } + + def report(): Unit = { + val memInfo = allocatorMap.asScala.map(allocator => + allocator._1 -> allocator._2.totalSize).filter(_._2 != 0) + + if (memInfo.nonEmpty) { + logInfo(s"Memory pool use: $memInfo") + } + } + + def init(minBufSize: Long, maxBufSize: Long, minRegSize: Long, maxRegSize: Long, + preAllocMap: Map[Long, Int], limit: Boolean, memGroupSize: Int = 3): + Unit = { + assert(memGroupSize > 2, s"Invalid memGroupSize. Expect > 2. 
+    val maxMemFactor = 1.0 - 1.0 / (1 << (memGroupSize - 1))
+    minBufferSize = roundUpToTheNextPowerOf2(minBufSize)
+    maxBufferSize = roundUpToTheNextPowerOf2(maxBufSize)
+    minRegistrationSize = roundUpToTheNextPowerOf2(minRegSize)
+    maxRegistrationSize = roundUpToTheNextPowerOf2(maxRegSize)
+
+    val memRange = (1 until 47).map(1L << _).filter(m =>
+      (m >= minBufferSize) && (m <= maxBufferSize)).reverse
+    val minLimit = (maxRegistrationSize / maxBufferSize * maxMemFactor).toLong
+    logInfo(s"limit $limit buf ($minBufferSize, $maxBufferSize) reg " +
+      s"($minRegistrationSize, $maxRegistrationSize)")
+
+    var shift = 0
+    for (i <- 0 until memRange.length by memGroupSize) {
+      var superAllocator: UcxLinkedMemAllocator = null
+      for (j <- 0 until memGroupSize.min(memRange.length - i)) {
+        val memSize = memRange(i + j)
+        val current = new UcxLinkedMemAllocator(memSize, minRegistrationSize,
+                                                superAllocator, ucxContext)
+        // Set the limit on the top allocator of each group only.
+        if (limit && (superAllocator == null)) {
+          val memLimit = (maxRegistrationSize / memSize).min(minLimit << shift)
+            .max(1L)
+            .min(Int.MaxValue)
+          logInfo(s"mem $memSize limit $memLimit")
+          current.setLimit(memLimit.toInt)
+          shift += 1
+        }
+        superAllocator = current
+        allocatorMap.put(memSize, current)
+      }
+    }
+    preAllocMap.foreach {
+      case (size, count) => {
+        allocatorMap.get(roundUpToTheNextPowerOf2(size)).preallocate(count)
+      }
+    }
+  }
+
+  protected def roundUpToTheNextPowerOf2(size: Long): Long = {
+    if (size < minBufferSize) {
+      minBufferSize
+    } else {
+      // Round up length to the nearest power of two
+      var length = size
+      length -= 1
+      length |= length >> 1
+      length |= length >> 2
+      length |= length >> 4
+      length |= length >> 8
+      length |= length >> 16
+      length |= length >> 32  // cover Long sizes above 2^32
+      length += 1
+      length
+    }
+  }
+
+  override def close(): Unit = {
+    allocatorMap.values.forEach(allocator => allocator.close())
+    allocatorMap.clear()
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/ExternalUcxDriverRpcEndpoint.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/ExternalUcxDriverRpcEndpoint.scala
new file mode 100644
index 00000000..0d02e31b
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/ExternalUcxDriverRpcEndpoint.scala
@@ -0,0 +1,45 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.ucx.rpc
+
+import java.net.InetSocketAddress
+
+import scala.collection.immutable.HashMap
+import scala.collection.mutable
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc._
+import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{PushServiceAddress, PushAllServiceAddress}
+import org.apache.spark.shuffle.ucx.utils.{SerializableDirectBuffer, SerializationUtils}
+
+class ExternalUcxDriverRpcEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging {
+
+  private val endpoints = mutable.HashSet.empty[RpcEndpointRef]
+  private var shuffleServerMap = mutable.HashMap.empty[String, Seq[Int]]
+
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+    case message@PushServiceAddress(host: String, ports: Seq[Int], endpoint: RpcEndpointRef) => {
+      // The driver receives a message from an executor with its shuffle server host and ports.
+      logInfo(s"Received message $ports from ${context.senderAddress}")
+      // 1. Introduce existing members of the cluster to the new executor.
+      if (shuffleServerMap.nonEmpty) {
+        val msg = PushAllServiceAddress(shuffleServerMap.toMap)
+        logDebug(s"Replying $msg to $endpoint")
+        context.reply(msg)
+      }
+      // 2. Introduce the newly joined executor to each existing member.
+      if (!shuffleServerMap.contains(host)) {
+        shuffleServerMap += host -> ports
+        endpoints.foreach(ep => {
+          logDebug(s"Sending message $ports to $ep")
+          ep.send(message)
+        })
+      }
+      // 3. Add the endpoint to the registered endpoints.
+      endpoints.add(endpoint)
+    }
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/ExternalUcxExecutorRpcEndpoint.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/ExternalUcxExecutorRpcEndpoint.scala
new file mode 100644
index 00000000..065ce23c
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/ExternalUcxExecutorRpcEndpoint.scala
@@ -0,0 +1,25 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.ucx.rpc
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}
+import org.apache.spark.shuffle.ucx.ExternalUcxClientTransport
+import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{PushServiceAddress, PushAllServiceAddress}
+import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer
+
+import java.util.concurrent.ExecutorService
+
+class ExternalUcxExecutorRpcEndpoint(override val rpcEnv: RpcEnv, transport: ExternalUcxClientTransport,
+                                     executorService: ExecutorService)
+  extends RpcEndpoint {
+
+  override def receive: PartialFunction[Any, Unit] = {
+    case PushServiceAddress(host: String, ports: Seq[Int], _: RpcEndpointRef) =>
+      transport.connect(host, ports)
+    case PushAllServiceAddress(shuffleServerMap: Map[String, Seq[Int]]) =>
+      transport.connectAll(shuffleServerMap)
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala
index 8c476236..27e84000 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala
@@ -19,4 +19,8 @@ object UcxRpcMessages {
    * Reply from driver with all executors in the cluster with their worker addresses.
    */
   case class IntroduceAllExecutors(executorIdToAddress: Map[Long, SerializableDirectBuffer])
+
+  case class PushServiceAddress(host: String, ports: Seq[Int], endpoint: RpcEndpointRef)
+
+  case class PushAllServiceAddress(shuffleServerMap: Map[String, Seq[Int]])
 }
diff --git a/src/main/scala/org/apache/spark/shuffle/utils/SerializableDirectBuffer.scala b/src/main/scala/org/apache/spark/shuffle/utils/SerializableDirectBuffer.scala
index b1f6e970..63074e80 100755
--- a/src/main/scala/org/apache/spark/shuffle/utils/SerializableDirectBuffer.scala
+++ b/src/main/scala/org/apache/spark/shuffle/utils/SerializableDirectBuffer.scala
@@ -10,19 +10,18 @@ import java.nio.ByteBuffer
 import java.nio.channels.Channels
 import java.nio.charset.StandardCharsets
 
-import org.apache.spark.internal.Logging
-import org.apache.spark.util.Utils
+import org.apache.spark.shuffle.utils.{SparkucxUtils, UcxLogging}
 
 /**
  * A wrapper around a java.nio.ByteBuffer that is serializable through Java serialization, to make
  * it easier to pass ByteBuffers in case class messages.
  */
 class SerializableDirectBuffer(@transient var buffer: ByteBuffer) extends Serializable
-  with Logging {
+  with UcxLogging {
 
   def value: ByteBuffer = buffer
 
-  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
+  private def readObject(in: ObjectInputStream): Unit = SparkucxUtils.tryOrIOException {
     val length = in.readInt()
     buffer = ByteBuffer.allocateDirect(length)
     var amountRead = 0
@@ -37,7 +36,7 @@ class SerializableDirectBuffer(@transient var buffer: ByteBuffer) extends Serial
     buffer.rewind() // Allow us to read it later
   }
 
-  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
+  private def writeObject(out: ObjectOutputStream): Unit = SparkucxUtils.tryOrIOException {
     out.writeInt(buffer.limit())
     buffer.rewind()
     while (buffer.position() < buffer.limit()) {
@@ -48,11 +47,11 @@ class SerializableDirectBuffer(@transient var buffer: ByteBuffer) extends Serial
 }
 
 class DeserializableToExternalMemoryBuffer(@transient var buffer: ByteBuffer)() extends Serializable
-  with Logging {
+  with UcxLogging {
 
   def value: ByteBuffer = buffer
 
-  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
+  private def readObject(in: ObjectInputStream): Unit = SparkucxUtils.tryOrIOException {
     val length = in.readInt()
     var amountRead = 0
     val channel = Channels.newChannel(in)
@@ -79,7 +78,7 @@ object SerializationUtils {
   }
 
   def serializeInetAddress(address: InetSocketAddress): ByteBuffer = {
-    val hostAddress = new InetSocketAddress(Utils.localCanonicalHostName(), address.getPort)
+    val hostAddress = new InetSocketAddress(address.getAddress.getCanonicalHostName, address.getPort)
     val hostString = hostAddress.getHostName.getBytes(StandardCharsets.UTF_8)
     val result = ByteBuffer.allocateDirect(hostString.length + 4)
     result.putInt(hostAddress.getPort)
diff --git a/src/main/scala/org/apache/spark/shuffle/utils/SparkucxUtils.scala b/src/main/scala/org/apache/spark/shuffle/utils/SparkucxUtils.scala
new file mode 100644
index 00000000..f28ccf0d
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/utils/SparkucxUtils.scala
@@ -0,0 +1,85 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.utils
+
+import java.io.IOException
+import java.math.{MathContext, RoundingMode}
+import java.util.Locale
+import java.util.concurrent.ThreadFactory
+import scala.util.control.NonFatal
+
+object SparkucxUtils extends UcxLogging {
+  def tryOrIOException[T](block: => T): T = {
+    try {
+      block
+    } catch {
+      case e: IOException =>
+        logError("Exception encountered", e)
+        throw e
+      case NonFatal(e) =>
+        logError("Exception encountered", e)
+        throw new IOException(e)
+    }
+  }
+
+  def bytesToString(size: Long): String = bytesToString(BigInt(size))
+
+  def bytesToString(size: BigInt): String = {
+    val EB = 1L << 60
+    val PB = 1L << 50
+    val TB = 1L << 40
+    val GB = 1L << 30
+    val MB = 1L << 20
+    val KB = 1L << 10
+
+    if (size >= BigInt(1L << 11) * EB) {
+      // The number is too large, show it in scientific notation.
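+      // (At 2^11 EB the EB rendering below would already need four digits;
+      // the MathContext keeps three significant digits.)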
+      BigDecimal(size, new MathContext(3, RoundingMode.HALF_UP)).toString() + " B"
+    } else {
+      val (value, unit) = {
+        if (size >= 2 * EB) {
+          (BigDecimal(size) / EB, "EB")
+        } else if (size >= 2 * PB) {
+          (BigDecimal(size) / PB, "PB")
+        } else if (size >= 2 * TB) {
+          (BigDecimal(size) / TB, "TB")
+        } else if (size >= 2 * GB) {
+          (BigDecimal(size) / GB, "GB")
+        } else if (size >= 2 * MB) {
+          (BigDecimal(size) / MB, "MB")
+        } else if (size >= 2 * KB) {
+          (BigDecimal(size) / KB, "KB")
+        } else {
+          (BigDecimal(size), "B")
+        }
+      }
+      "%.1f %s".formatLocal(Locale.US, value, unit)
+    }
+  }
+}
+
+class UcxThreadFactory extends ThreadFactory {
+  private var daemon: Boolean = true
+  private var prefix: String = "UCX"
+
+  private class NamedThread(r: Runnable) extends Thread(r) {
+    setDaemon(daemon)
+    setName(s"${prefix}-${super.getName}")
+  }
+
+  def setDaemon(isDaemon: Boolean): this.type = {
+    daemon = isDaemon
+    this
+  }
+
+  def setPrefix(name: String): this.type = {
+    prefix = name
+    this
+  }
+
+  def newThread(r: Runnable): Thread = {
+    new NamedThread(r)
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/utils/UcxLogging.scala b/src/main/scala/org/apache/spark/shuffle/utils/UcxLogging.scala
new file mode 100644
index 00000000..527e7b86
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/utils/UcxLogging.scala
@@ -0,0 +1,64 @@
+package org.apache.spark.shuffle.utils
+
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
+import org.apache.log4j.LogManager
+
+trait UcxLogging {
+  @transient private var log_ : Logger = null
+
+  // Method to get the logger name for this object
+  protected def logName = {
+    // Ignore trailing $'s in the class names for Scala objects
+    this.getClass.getName.stripSuffix("$")
+  }
+
+  protected def log: Logger = {
+    if (log_ == null) {
+      log_ = LoggerFactory.getLogger(logName)
+    }
+    log_
+  }
+
+  // Log methods that take only a String
+  protected def logInfo(msg: => String) {
+    if (log.isInfoEnabled) log.info(msg)
+  }
+
+  protected def logDebug(msg: => String) {
+    if (log.isDebugEnabled) log.debug(msg)
+  }
+
+  protected def logTrace(msg: => String) {
+    if (log.isTraceEnabled) log.trace(msg)
+  }
+
+  protected def logWarning(msg: => String) {
+    if (log.isWarnEnabled) log.warn(msg)
+  }
+
+  protected def logError(msg: => String) {
+    if (log.isErrorEnabled) log.error(msg)
+  }
+
+  // Log methods that take Throwables (Exceptions/Errors) too
+  protected def logInfo(msg: => String, throwable: Throwable) {
+    if (log.isInfoEnabled) log.info(msg, throwable)
+  }
+
+  protected def logDebug(msg: => String, throwable: Throwable) {
+    if (log.isDebugEnabled) log.debug(msg, throwable)
+  }
+
+  protected def logTrace(msg: => String, throwable: Throwable) {
+    if (log.isTraceEnabled) log.trace(msg, throwable)
+  }
+
+  protected def logWarning(msg: => String, throwable: Throwable) {
+    if (log.isWarnEnabled) log.warn(msg, throwable)
+  }
+
+  protected def logError(msg: => String, throwable: Throwable) {
+    if (log.isErrorEnabled) log.error(msg, throwable)
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/utils/UnsafeUtils.scala b/src/main/scala/org/apache/spark/shuffle/utils/UnsafeUtils.scala
index 329fb01b..e4eab255 100755
--- a/src/main/scala/org/apache/spark/shuffle/utils/UnsafeUtils.scala
+++ b/src/main/scala/org/apache/spark/shuffle/utils/UnsafeUtils.scala
@@ -10,9 +10,8 @@ import java.nio.channels.FileChannel
 
 import org.openucx.jucx.UcxException
 import sun.nio.ch.{DirectBuffer, FileChannelImpl}
-import org.apache.spark.internal.Logging
 
-object UnsafeUtils extends Logging {
+object UnsafeUtils extends UcxLogging {
   val INT_SIZE: Int = 4
   val LONG_SIZE: Int = 8