Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/sparkucx-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
build-sparkucx:
strategy:
matrix:
spark_version: ["2.1", "2.4", "3.0"]
spark_version: ["2.1", "2.4", "3.0", "3.1"]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/sparkucx-release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
release:
strategy:
matrix:
spark_version: ["2.1", "2.4", "3.0"]
spark_version: ["2.1", "2.4", "3.0", "3.1"]
runs-on: ubuntu-latest
steps:
- name: Checkout code
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ that are supported by [UCX](https://github.com/openucx/ucx#supported-transports)
This open-source project is developed, maintained and supported by the [UCF consortium](http://www.ucfconsortium.org/).

## Runtime requirements
* Apache Spark 2.3/2.4/3.0
* Apache Spark 2.3/2.4/3.0/3.1
* Java 8+
* Installed UCX of version 1.10+, and [UCX supported transport hardware](https://github.com/openucx/ucx#supported-transports).

Expand Down Expand Up @@ -34,9 +34,9 @@ to spark (e.g. in $SPARK_HOME/conf/spark-defaults.conf):
```
spark.shuffle.manager org.apache.spark.shuffle.UcxShuffleManager
```
For spark-3.0 version add SparkUCX ShuffleIO plugin:
For Spark 3.0 or Spark 3.1, also add the SparkUCX ShuffleIO plugin:
```
spark.shuffle.sort.io.plugin.class org.apache.spark.shuffle.compat.spark_3_0.UcxLocalDiskShuffleDataIO
spark.shuffle.sort.io.plugin.class org.apache.spark.shuffle.compat.spark_3_0.UcxLocalDiskShuffleDataIO   # Spark 3.0
spark.shuffle.sort.io.plugin.class org.apache.spark.shuffle.compat.spark_3_1.UcxLocalDiskShuffleDataIO   # Spark 3.1
```

### Build
Expand Down
49 changes: 45 additions & 4 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ See file LICENSE for terms.
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<excludes>
<exclude>**/spark_3_1/**</exclude>
<exclude>**/spark_3_0/**</exclude>
<exclude>**/spark_2_4/**</exclude>
</excludes>
Expand All @@ -53,6 +54,7 @@ See file LICENSE for terms.
<artifactId>scala-maven-plugin</artifactId>
<configuration>
<excludes>
<exclude>**/spark_3_1/**</exclude>
<exclude>**/spark_3_0/**</exclude>
<exclude>**/spark_2_4/**</exclude>
</excludes>
Expand All @@ -62,7 +64,7 @@ See file LICENSE for terms.
</build>
<properties>
<spark.version>2.1.0</spark.version>
<sonar.exclusions>**/spark_3_0/**, **/spark_2_4/**</sonar.exclusions>
<sonar.exclusions>**/spark_3_1/**, **/spark_3_0/**, **/spark_2_4/**</sonar.exclusions>
<scala.version>2.11.12</scala.version>
<scala.compat.version>2.11</scala.compat.version>
</properties>
Expand All @@ -76,6 +78,7 @@ See file LICENSE for terms.
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<excludes>
<exclude>**/spark_3_1/**</exclude>
<exclude>**/spark_3_0/**</exclude>
<exclude>**/spark_2_1/**</exclude>
</excludes>
Expand All @@ -86,6 +89,7 @@ See file LICENSE for terms.
<artifactId>scala-maven-plugin</artifactId>
<configuration>
<excludes>
<exclude>**/spark_3_1/**</exclude>
<exclude>**/spark_2_1/**</exclude>
<exclude>**/spark_3_0/**</exclude>
</excludes>
Expand All @@ -95,13 +99,48 @@ See file LICENSE for terms.
</build>
<properties>
<spark.version>2.4.0</spark.version>
<sonar.exclusions>**/spark_3_0/**, **/spark_2_1/**</sonar.exclusions>
<sonar.exclusions>**/spark_3_1/**, **/spark_3_0/**, **/spark_2_1/**</sonar.exclusions>
<scala.version>2.11.12</scala.version>
<scala.compat.version>2.11</scala.compat.version>
</properties>
</profile>
<profile>
<id>spark-3.0</id>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<excludes>
<exclude>**/spark_3_1/**</exclude>
<exclude>**/spark_2_1/**</exclude>
<exclude>**/spark_2_4/**</exclude>
</excludes>
</configuration>
</plugin>
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<configuration>
<excludes>
<exclude>**/spark_3_1/**</exclude>
<exclude>**/spark_2_1/**</exclude>
<exclude>**/spark_2_4/**</exclude>
</excludes>
</configuration>
</plugin>
</plugins>
</build>
<properties>
<spark.version>3.0.1</spark.version>
<scala.version>2.12.10</scala.version>
<scala.compat.version>2.12</scala.compat.version>
<sonar.exclusions>**/spark_3_1/**, **/spark_2_1/**, **/spark_2_4/**</sonar.exclusions>
</properties>
</profile>
<profile>
<id>spark-3.1</id>
<activation>
<activeByDefault>true</activeByDefault>
</activation>
Expand All @@ -112,6 +151,7 @@ See file LICENSE for terms.
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<excludes>
<exclude>**/spark_3_0/**</exclude>
<exclude>**/spark_2_1/**</exclude>
<exclude>**/spark_2_4/**</exclude>
</excludes>
Expand All @@ -122,6 +162,7 @@ See file LICENSE for terms.
<artifactId>scala-maven-plugin</artifactId>
<configuration>
<excludes>
<exclude>**/spark_3_0/**</exclude>
<exclude>**/spark_2_1/**</exclude>
<exclude>**/spark_2_4/**</exclude>
</excludes>
Expand All @@ -130,10 +171,10 @@ See file LICENSE for terms.
</plugins>
</build>
<properties>
<spark.version>3.0.1</spark.version>
<spark.version>3.1.2</spark.version>
<scala.version>2.12.10</scala.version>
<scala.compat.version>2.12</scala.compat.version>
<sonar.exclusions>**/spark_2_1/**, **/spark_2_4/**</sonar.exclusions>
<sonar.exclusions>**/spark_3_0/**, **/spark_2_1/**, **/spark_2_4/**</sonar.exclusions>
</properties>
</profile>
</profiles>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/
package org.apache.spark.shuffle.ucx.reducer.compat.spark_3_1;

import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.shuffle.UcxWorkerWrapper;
import org.apache.spark.shuffle.ucx.UnsafeUtils;
import org.apache.spark.shuffle.ucx.memory.RegisteredMemory;
import org.apache.spark.shuffle.ucx.reducer.ReducerCallback;
import org.apache.spark.shuffle.ucx.reducer.OnBlocksFetchCallback;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.ShuffleBlockBatchId;
import org.apache.spark.storage.ShuffleBlockId;
import org.openucx.jucx.UcxUtils;
import org.openucx.jucx.ucp.UcpEndpoint;
import org.openucx.jucx.ucp.UcpRemoteKey;
import org.openucx.jucx.ucp.UcpRequest;

import java.nio.ByteBuffer;
import java.util.Map;

/**
* Callback, called when got all offsets for blocks
*/
public class OnOffsetsFetchCallback extends ReducerCallback {
private final RegisteredMemory offsetMemory;
private final long[] dataAddresses;
private Map<Integer, UcpRemoteKey> dataRkeysCache;
private final Map<Long, Integer> mapId2PartitionId;

public OnOffsetsFetchCallback(BlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener,
RegisteredMemory offsetMemory, long[] dataAddresses,
Map<Integer, UcpRemoteKey> dataRkeysCache,
Map<Long, Integer> mapId2PartitionId) {
super(blockIds, endpoint, listener);
this.offsetMemory = offsetMemory;
this.dataAddresses = dataAddresses;
this.dataRkeysCache = dataRkeysCache;
this.mapId2PartitionId = mapId2PartitionId;
}

@Override
public void onSuccess(UcpRequest request) {
ByteBuffer resultOffset = offsetMemory.getBuffer();
long totalSize = 0;
int[] sizes = new int[blockIds.length];
int offset = 0;
long blockOffset;
long blockLength;
int offsetSize = UnsafeUtils.LONG_SIZE;
for (int i = 0; i < blockIds.length; i++) {
// Blocks in metadata buffer are in form | blockOffsetStart | blockOffsetEnd |
if (blockIds[i] instanceof ShuffleBlockBatchId) {
ShuffleBlockBatchId blockBatchId = (ShuffleBlockBatchId) blockIds[i];
int blocksInBatch = blockBatchId.endReduceId() - blockBatchId.startReduceId();
blockOffset = resultOffset.getLong(offset * 2 * offsetSize);
blockLength = resultOffset.getLong(offset * 2 * offsetSize + offsetSize * blocksInBatch)
- blockOffset;
offset += blocksInBatch;
} else {
blockOffset = resultOffset.getLong(offset * 16);
blockLength = resultOffset.getLong(offset * 16 + 8) - blockOffset;
offset++;
}

assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE);
sizes[i] = (int) blockLength;
totalSize += blockLength;
dataAddresses[i] += blockOffset;
}

assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE);
mempool.put(offsetMemory);
RegisteredMemory blocksMemory = mempool.get((int) totalSize);

offset = 0;
// Submits N fetch blocks requests
for (int i = 0; i < blockIds.length; i++) {
int mapPartitionId = (blockIds[i] instanceof ShuffleBlockId) ?
mapId2PartitionId.get(((ShuffleBlockId)blockIds[i]).mapId()) :
mapId2PartitionId.get(((ShuffleBlockBatchId)blockIds[i]).mapId());
endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(mapPartitionId),
UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]);
offset += sizes[i];
}

// Process blocks when all fetched.
// Flush guarantees that callback would invoke when all fetch requests will completed.
endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/
package org.apache.spark.shuffle.ucx.reducer.compat.spark_3_1;

import org.apache.spark.SparkEnv;
import org.apache.spark.executor.TempShuffleReadMetrics;
import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.network.shuffle.BlockStoreClient;
import org.apache.spark.network.shuffle.DownloadFileManager;
import org.apache.spark.shuffle.DriverMetadata;
import org.apache.spark.shuffle.UcxShuffleManager;
import org.apache.spark.shuffle.UcxWorkerWrapper;
import org.apache.spark.shuffle.ucx.UnsafeUtils;
import org.apache.spark.shuffle.ucx.memory.RegisteredMemory;
import org.apache.spark.storage.*;
import org.openucx.jucx.UcxUtils;
import org.openucx.jucx.ucp.UcpEndpoint;
import org.openucx.jucx.ucp.UcpRemoteKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Option;


import java.util.HashMap;
import java.util.Map;

/**
 * Spark 3.1 {@link BlockStoreClient} that fetches shuffle blocks over UCX
 * RDMA reads instead of the regular network shuffle transport. For each fetch
 * it first reads the per-block offsets from the remote driver-published
 * metadata, then (in {@link OnOffsetsFetchCallback}) reads the block data.
 */
public class UcxShuffleClient extends BlockStoreClient {
private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class);
private final UcxWorkerWrapper workerWrapper;
// Maps a shuffle mapId (long, Spark 3.x) to its map partition id, which keys
// the per-partition metadata and remote-key caches below.
private final Map<Long, Integer> mapId2PartitionId;
private final TempShuffleReadMetrics shuffleReadMetrics;
private final int shuffleId;
// Unpacked remote keys per map partition; closed in close().
final HashMap<Integer, UcpRemoteKey> offsetRkeysCache = new HashMap<>();
final HashMap<Integer, UcpRemoteKey> dataRkeysCache = new HashMap<>();


public UcxShuffleClient(int shuffleId, UcxWorkerWrapper workerWrapper,
                        Map<Long, Integer> mapId2PartitionId, TempShuffleReadMetrics shuffleReadMetrics) {
this.workerWrapper = workerWrapper;
this.shuffleId = shuffleId;
this.mapId2PartitionId = mapId2PartitionId;
this.shuffleReadMetrics = shuffleReadMetrics;
}

/**
 * Submits n non blocking fetch offsets to get needed offsets for n blocks.
 * Offsets for block i are read into {@code offsetMemory} at a running byte
 * offset, and the remote data base address for block i is recorded in
 * {@code dataAddresses[i]}. Remote keys are unpacked lazily per partition.
 */
private void submitFetchOffsets(UcpEndpoint endpoint, BlockId[] blockIds,
                                RegisteredMemory offsetMemory,
                                long[] dataAddresses) {
DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(shuffleId);
long offset = 0;
int startReduceId;
long size;

for (int i = 0; i < blockIds.length; i++) {
BlockId blockId = blockIds[i];
int mapIdpartition;

// NOTE(review): mapId2PartitionId.get(...) auto-unboxes and will NPE if a
// mapId is missing — presumably callers always populate the map; verify.
if (blockId instanceof ShuffleBlockId) {
ShuffleBlockId shuffleBlockId = (ShuffleBlockId) blockId;
mapIdpartition = mapId2PartitionId.get(shuffleBlockId.mapId());
// Single block: fetch its (start, end) offset pair, two longs.
size = 2L * UnsafeUtils.LONG_SIZE;
startReduceId = shuffleBlockId.reduceId();
} else {
ShuffleBlockBatchId shuffleBlockBatchId = (ShuffleBlockBatchId) blockId;
mapIdpartition = mapId2PartitionId.get(shuffleBlockBatchId.mapId());
// Batch: fetch two longs per block in the batch.
size = (shuffleBlockBatchId.endReduceId() - shuffleBlockBatchId.startReduceId())
  * 2L * UnsafeUtils.LONG_SIZE;
startReduceId = shuffleBlockBatchId.startReduceId();
}

long offsetAddress = driverMetadata.offsetAddress(mapIdpartition);
dataAddresses[i] = driverMetadata.dataAddress(mapIdpartition);

// Unpack remote keys once per map partition and cache for reuse.
offsetRkeysCache.computeIfAbsent(mapIdpartition, mapId ->
  endpoint.unpackRemoteKey(driverMetadata.offsetRkey(mapIdpartition)));

dataRkeysCache.computeIfAbsent(mapIdpartition, mapId ->
  endpoint.unpackRemoteKey(driverMetadata.dataRkey(mapIdpartition)));

// Implicit (no callback) RDMA read of the offsets into the local buffer;
// completion is observed via the worker flush in fetchBlocks().
endpoint.getNonBlockingImplicit(
  offsetAddress + startReduceId * UnsafeUtils.LONG_SIZE,
  offsetRkeysCache.get(mapIdpartition),
  UcxUtils.getAddress(offsetMemory.getBuffer()) + offset,
  size);

offset += size;
}
}

/**
 * Fetches the given shuffle blocks from the remote executor over UCX.
 * Offsets are fetched first; {@link OnOffsetsFetchCallback} then submits the
 * data fetches and eventually notifies {@code listener}.
 * {@code downloadFileManager} is ignored — blocks are delivered in memory.
 */
@Override
public void fetchBlocks(String host, int port, String execId, String[] blockIds, BlockFetchingListener listener,
                        DownloadFileManager downloadFileManager) {
long startTime = System.currentTimeMillis();
BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty());
UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId);
long[] dataAddresses = new long[blockIds.length];
int totalBlocks = 0;

BlockId[] blocks = new BlockId[blockIds.length];

// Count individual blocks (a batch id expands to endReduceId - startReduceId
// blocks) to size the offset buffer: two longs per block.
for (int i = 0; i < blockIds.length; i++) {
blocks[i] = BlockId.apply(blockIds[i]);
if (blocks[i] instanceof ShuffleBlockId) {
totalBlocks += 1;
} else {
ShuffleBlockBatchId blockBatchId = (ShuffleBlockBatchId)blocks[i];
totalBlocks += (blockBatchId.endReduceId() - blockBatchId.startReduceId());
}
}

RegisteredMemory offsetMemory = ((UcxShuffleManager)SparkEnv.get().shuffleManager())
  .ucxNode().getMemoryPool().get(totalBlocks * 2 * UnsafeUtils.LONG_SIZE);
// Submits N implicit get requests without callback
submitFetchOffsets(endpoint, blocks, offsetMemory, dataAddresses);

// Worker-level flush guarantees that all submitted requests have completed
// when the callback is invoked.
// TODO: fix https://github.com/openucx/ucx/issues/4267 and use endpoint flush.
workerWrapper.worker().flushNonBlocking(
  new OnOffsetsFetchCallback(blocks, endpoint, listener, offsetMemory,
    dataAddresses, dataRkeysCache, mapId2PartitionId));

shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime);
}

/**
 * Releases all cached remote keys and logs the accumulated fetch wait time.
 */
@Override
public void close() {
offsetRkeysCache.values().forEach(UcpRemoteKey::close);
dataRkeysCache.values().forEach(UcpRemoteKey::close);
logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime());
}

}
Loading