diff --git a/README.md b/README.md
index 0e38ddb..81117f7 100644
--- a/README.md
+++ b/README.md
@@ -3,12 +3,138 @@ CUDA Stream Compaction
 
 **University of Pennsylvania, CIS 565: GPU Programming and Architecture, Project 2**
 
-* (TODO) YOUR NAME HERE
-  * (TODO) [LinkedIn](), [personal website](), [twitter](), etc.
-* Tested on: (TODO) Windows 22, i7-2222 @ 2.22GHz 22GB, GTX 222 222MB (Moore 2222 Lab)
+* Siyu Zheng
+* Tested on: Windows 10, i7-8750 @ 2.20GHz, 16GB, GTX 1060 6GB, Visual Studio 2015, CUDA 8.0 (personal laptop)
 
-### (TODO: Your README)
+## Description
 
-Include analysis, etc. (Remember, this is public, so don't put
-anything here that you don't want to share with the world.)
+### Stream Compaction
+![](img/streamCompaction.png)
+The goal of stream compaction is, given an array of elements, to create a new array containing only the elements that meet a certain criterion (e.g. non-null), while preserving their order. It is used in path tracing, collision detection, sparse matrix compression, etc.
+
+* Step 1: Compute a temporary boolean array
+* Step 2: Run exclusive scan on the temporary array
+* Step 3: Scatter
+
+### CPU Scan
+Use a for loop to compute an exclusive prefix sum.
+![](img/cpu.png)
+
+Number of adds: O(n)
+
+### Naive GPU Scan
+
+![](img/naive.png)
+
+Use double buffering to ping-pong between two arrays while scanning. First compute an inclusive scan, then shift the result right (inserting the identity) to get the exclusive scan.
+
+Number of adds: O(n log2(n))
+
+### Work-Efficient GPU Scan
+Up-Sweep (Reduce) Phase:
+
+![](img/upsweep.png)
+
+In the reduce phase, we traverse the tree from the leaves to the root, computing partial sums at the internal nodes of the tree.
+
+Down-Sweep Phase:
+
+![](img/downsweep.png)
+
+In the down-sweep phase, we traverse back down the tree from the root, using the partial sums from the reduce phase to build the scan in place on the array. We start by inserting zero at the root of the tree; on each step, each node at the current level passes its own value to its left child, and the sum of its value and the former value of its left child to its right child.
+
+### Thrust's Implementation
+
+Wraps a call to the Thrust library function thrust::exclusive_scan(first, last, result).
+
+## Performance Analysis
+
+* Roughly optimize the block sizes of each implementation for minimal run time (array size 1 << 15):
+
+| Runtime (ms) | 128     | 256      | 512      | 1024     |
+| ------------ | ------- | -------- | -------- | -------- |
+| naive        | 0.16784 | 0.132096 | 0.157504 | 0.155584 |
+| coherent     | 1639.7  | 1534.2   | 0.094048 | 0.096736 |
+
+In my experiments, the performance for the different block sizes was quite close. I chose 1024 for the remaining tests.
+
+* Compare all of these GPU scan implementations (naive, work-efficient, and Thrust) to the serial CPU version of scan, with array size on the independent axis:
+![](img/pot1.png)
+![](img/pot2.png)
+![](img/npot1.png)
+![](img/npot2.png)
+
+* A brief explanation of the phenomena seen here:
+
+At first, my non-optimized work-efficient GPU scan was slower than the CPU approach. I then optimized it with a resizable blocksPerGrid, so that at each level of depth in the scan the idle threads are no longer launched; the array indexing in the up-sweep and down-sweep stages is adjusted to maintain correctness. As a result, for array sizes larger than 1 << 16, the work-efficient GPU approach outperforms the CPU approach.
+
+Comparing these four implementations, we can see that when the array size is small, the CPU approach has the best performance, and the work-efficient GPU approach is faster than the naive one. Once the array size grows beyond 1 << 16, the GPU implementations outperform the CPU. The Thrust approach has the best performance at large array sizes, and its running time grows only slowly with size, so it is quite stable.
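+
+The launch-sizing optimization described above can be sketched as follows (a simplified excerpt of the host loop in `scan` in `efficient.cu`; `blockSize`, `dev_data`, and the kernel match that file):
+
+```cuda
+// Up-sweep with one thread per active node: at level d only len / 2^(d+1)
+// nodes change, so the grid shrinks as d grows instead of covering all len elements.
+for (int d = 0; d < upLimit; d++) {
+    int tempLen = len >> (d + 1);                      // active nodes at this level
+    dim3 grid((tempLen + blockSize - 1) / blockSize);  // ceil(tempLen / blockSize)
+    kernOptUpSweep<<<grid, blockSize>>>(tempLen, d, dev_data);
+}
+```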
+
+I examined the timeline when the array size is 1 << 15. The call to thrust::exclusive_scan takes about half as long as each of my kernel sweep calls, so in the Thrust implementation most of the expense is memory allocation and copying. I suspect the base cost of memory operations in Thrust is quite large, but as the array size grows the memory cost does not increase much, perhaps because of optimizations such as contiguous memory access. As a result, for large arrays the Thrust implementation has the best performance.
+
+The performance bottleneck of the naive approach is mainly the algorithm itself. For the non-optimized work-efficient scan, the many idle threads are the bottleneck. For the optimized work-efficient GPU approach, the bottleneck is mainly memory I/O; switching to shared memory should improve performance considerably, as sketched below.
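+
+As an illustration of that idea, a single block can keep its working set entirely in shared memory. This is a minimal Hillis-Steele block scan sketch, not the work-efficient tree scan used in this repo, and it assumes `n == blockDim.x` with `n` a power of two:
+
+```cuda
+__global__ void kernBlockScanShared(int n, int *odata, const int *idata) {
+    extern __shared__ int temp[];   // n ints, sized at launch time
+    int t = threadIdx.x;
+    temp[t] = idata[t];
+    __syncthreads();
+    for (int offset = 1; offset < n; offset *= 2) {
+        // Read before the barrier, write after it, to avoid races on temp[].
+        int val = (t >= offset) ? temp[t - offset] : 0;
+        __syncthreads();
+        temp[t] += val;             // inclusive scan step
+        __syncthreads();
+    }
+    odata[t] = temp[t];
+}
+// Hypothetical launch: kernBlockScanShared<<<1, n, n * sizeof(int)>>>(n, dev_out, dev_in);
+```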
+
+## Result
+
+Array size = 1 << 15
+
+```
+
+****************
+** SCAN TESTS **
+****************
+    [  10  39  41   0  14  37  18  40   1  42  27  21  10 ...  14   0 ]
+==== cpu scan, power-of-two ====
+   elapsed time: 0.10945ms    (std::chrono Measured)
+    [   0  10  49  90  90 104 141 159 199 200 242 269 290 ... 803563 803577 ]
+==== cpu scan, non-power-of-two ====
+   elapsed time: 0.116406ms    (std::chrono Measured)
+    [   0  10  49  90  90 104 141 159 199 200 242 269 290 ... 803493 803514 ]
+    passed
+==== naive scan, power-of-two ====
+   elapsed time: 0.235072ms    (CUDA Measured)
+    passed
+==== naive scan, non-power-of-two ====
+   elapsed time: 0.197024ms    (CUDA Measured)
+    passed
+==== work-efficient scan, power-of-two ====
+   elapsed time: 0.147424ms    (CUDA Measured)
+    passed
+==== work-efficient scan, non-power-of-two ====
+   elapsed time: 0.119808ms    (CUDA Measured)
+    passed
+==== thrust scan, power-of-two ====
+   elapsed time: 0.299008ms    (CUDA Measured)
+    passed
+==== thrust scan, non-power-of-two ====
+   elapsed time: 0.253952ms    (CUDA Measured)
+    passed
+
+*****************************
+** STREAM COMPACTION TESTS **
+*****************************
+    [   0   3   0   2   1   3   2   3   3   1   3   0   0 ...   1   0 ]
+==== cpu compact without scan, power-of-two ====
+   elapsed time: 0.121971ms    (std::chrono Measured)
+    [   3   2   1   3   2   3   3   1   3   2   1   2   1 ...   2   1 ]
+    passed
+==== cpu compact without scan, non-power-of-two ====
+   elapsed time: 0.139594ms    (std::chrono Measured)
+    [   3   2   1   3   2   3   3   1   3   2   1   2   1 ...   1   1 ]
+    passed
+==== cpu compact with scan ====
+   elapsed time: 0.552812ms    (std::chrono Measured)
+    [   3   2   1   3   2   3   3   1   3   2   1   2   1 ...   2   1 ]
+    passed
+==== work-efficient compact, power-of-two ====
+   elapsed time: 0.141568ms    (CUDA Measured)
+    passed
+==== work-efficient compact, non-power-of-two ====
+   elapsed time: 0.243712ms    (CUDA Measured)
+    passed
+
+```
diff --git a/img/cpu.png b/img/cpu.png
new file mode 100644
index 0000000..5128ff6
Binary files /dev/null and b/img/cpu.png differ
diff --git a/img/downsweep.png b/img/downsweep.png
new file mode 100644
index 0000000..b683da1
Binary files /dev/null and b/img/downsweep.png differ
diff --git a/img/naive.png b/img/naive.png
new file mode 100644
index 0000000..2c0e7f8
Binary files /dev/null and b/img/naive.png differ
diff --git a/img/npot1.png b/img/npot1.png
new file mode 100644
index 0000000..e7b263f
Binary files /dev/null and b/img/npot1.png differ
diff --git a/img/npot2.png b/img/npot2.png
new file mode 100644
index 0000000..6866563
Binary files /dev/null and b/img/npot2.png differ
diff --git a/img/pot1.png b/img/pot1.png
new file mode 100644
index 0000000..f863b19
Binary files /dev/null and b/img/pot1.png differ
diff --git a/img/pot2.png b/img/pot2.png
new file mode 100644
index 0000000..72dde6f
Binary files /dev/null and b/img/pot2.png differ
diff --git a/img/streamCompaction.png b/img/streamCompaction.png
new file mode 100644
index 0000000..0dfa4d2
Binary files /dev/null and b/img/streamCompaction.png differ
diff --git a/img/upsweep.png b/img/upsweep.png
new file mode 100644
index 0000000..845d34c
Binary files /dev/null and b/img/upsweep.png differ
diff --git a/src/main.cpp b/src/main.cpp
index 1850161..e490452 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -13,7 +13,7 @@
 #include <stream_compaction/thrust.h>
 #include "testing_helpers.hpp"
 
-const int SIZE = 1 << 8; // feel free to change the size of array
+const int SIZE = 1 << 15; // feel free to change the size of array
 const int NPOT = SIZE - 3; // Non-Power-Of-Two
 int *a = new int[SIZE];
 int *b = new int[SIZE];
@@ -81,6 +81,20 @@ int main(int argc, char* argv[]) {
     //printArray(NPOT, c, true);
     printCmpResult(NPOT, b, c);
 
+    //zeroArray(SIZE, c);
+    //printDesc("work-efficient scan, shared memory, power-of-two");
+    //StreamCompaction::Efficient::sharedMemoryScan(SIZE, c, a);
+    //printElapsedTime(StreamCompaction::Efficient::timer().getGpuElapsedTimeForPreviousOperation(), "(CUDA Measured)");
+    ////printArray(SIZE, c, true);
+    //printCmpResult(SIZE, b, c);
+
+    //zeroArray(SIZE, c);
+    //printDesc("work-efficient scan, shared memory, non-power-of-two");
+    //StreamCompaction::Efficient::sharedMemoryScan(NPOT, c, a);
+    //printElapsedTime(StreamCompaction::Efficient::timer().getGpuElapsedTimeForPreviousOperation(), "(CUDA Measured)");
+    ////printArray(NPOT, c, true);
+    //printCmpResult(NPOT, b, c);
+
     zeroArray(SIZE, c);
     printDesc("thrust scan, power-of-two");
     StreamCompaction::Thrust::scan(SIZE, c, a);
diff --git a/stream_compaction/CMakeLists.txt b/stream_compaction/CMakeLists.txt
index cdbef77..e31ca3c 100644
--- a/stream_compaction/CMakeLists.txt
+++ b/stream_compaction/CMakeLists.txt
@@ -13,5 +13,5 @@ set(SOURCE_FILES
 
 cuda_add_library(stream_compaction
     ${SOURCE_FILES}
-    OPTIONS -arch=sm_20
+    OPTIONS -arch=sm_30
     )
diff --git a/stream_compaction/common.cu b/stream_compaction/common.cu
index 8fc0211..a1b286b 100644
--- a/stream_compaction/common.cu
+++ b/stream_compaction/common.cu
@@ -24,6 +24,11 @@ namespace StreamCompaction {
          */
        __global__ void kernMapToBoolean(int n, int *bools, const int *idata) {
            // TODO
+            int index = blockDim.x * blockIdx.x + threadIdx.x;
+            if (index >= n) {
+                return;
+            }
+            bools[index] = idata[index] == 0 ? 0 : 1;
        }
 
        /**
@@ -33,6 +38,13 @@ namespace StreamCompaction {
        __global__ void kernScatter(int n, int *odata,
                const int *idata, const int *bools, const int *indices) {
            // TODO
+            int index = blockDim.x * blockIdx.x + threadIdx.x;
+            if (index >= n) {
+                return;
+            }
+            // Write only the kept elements, at the positions given by the scanned indices.
+            if (bools[index] == 1) {
+                odata[indices[index]] = idata[index];
+            }
        }
 
    }
 }
diff --git a/stream_compaction/cpu.cu b/stream_compaction/cpu.cu
index 05ce667..c81e96e 100644
--- a/stream_compaction/cpu.cu
+++ b/stream_compaction/cpu.cu
@@ -1,15 +1,16 @@
 #include <cstdio>
 #include "cpu.h"
-#include "common.h"
+#include "common.h"
+
 namespace StreamCompaction {
    namespace CPU {
-        using StreamCompaction::Common::PerformanceTimer;
-        PerformanceTimer& timer()
-        {
-            static PerformanceTimer timer;
-            return timer;
+        using StreamCompaction::Common::PerformanceTimer;
+        PerformanceTimer& timer()
+        {
+            static PerformanceTimer timer;
+            return timer;
        }
 
        /**
@@ -18,9 +19,13 @@ namespace StreamCompaction {
         * (Optional) For better understanding before starting moving to GPU, you can simulate your GPU scan in this function first.
         */
        void scan(int n, int *odata, const int *idata) {
-            timer().startCpuTimer();
+            //timer().startCpuTimer();
            // TODO
+            // Exclusive prefix sum: odata[i] = sum of idata[0..i-1].
+            odata[0] = 0;
+            for (int i = 1; i < n; i++) {
+                odata[i] = odata[i - 1] + idata[i - 1];
+            }
+            //timer().endCpuTimer();
-            timer().endCpuTimer();
        }
 
        /**
@@ -31,11 +36,18 @@ namespace StreamCompaction {
        int compactWithoutScan(int n, int *odata, const int *idata) {
            timer().startCpuTimer();
            // TODO
+            int count = 0;
+            // Pack each nonzero element forward, preserving order.
+            for (int i = 0; i < n; i++) {
+                if (idata[i] == 0) {
+                    continue;
+                }
+                odata[count++] = idata[i];
+            }
            timer().endCpuTimer();
-            return -1;
+            return count;
        }
 
-        /**
+        /**
         * CPU stream compaction using scan and scatter, like the parallel version.
         *
         * @returns the number of elements remaining after compaction.
         */
@@ -43,8 +55,27 @@ namespace StreamCompaction {
        int compactWithScan(int n, int *odata, const int *idata) {
            timer().startCpuTimer();
            // TODO
+            int *mdata = new int[n];
+            int *sdata = new int[n];
+            int count = 0;
+            // Map each element to a boolean: 0 stays 0, anything else becomes 1.
+            for (int i = 0; i < n; i++) {
+                if (idata[i] == 0) {
+                    mdata[i] = 0;
+                }
+                else {
+                    mdata[i] = 1;
+                }
+            }
+            scan(n, sdata, mdata);
+            // Scatter: the exclusive scan gives each kept element its output index.
+            for (int i = 0; i < n; i++) {
+                if (mdata[i] != 0) {
+                    odata[sdata[i]] = idata[i];
+                    count++;
+                }
+            }
+            delete[] mdata;
+            delete[] sdata;
            timer().endCpuTimer();
-            return -1;
+            return count;
        }
    }
 }
diff --git a/stream_compaction/efficient.cu b/stream_compaction/efficient.cu
index 36c5ef2..0d2462a 100644
--- a/stream_compaction/efficient.cu
+++ b/stream_compaction/efficient.cu
@@ -2,25 +2,277 @@
 #include <cuda_runtime.h>
 #include "common.h"
 #include "efficient.h"
+#include "device_launch_parameters.h"
+#include <thrust/device_vector.h>
+#include <thrust/scan.h>
+#include <thrust/execution_policy.h>
+#include "thrust.h"
+
+#define NUM_BANKS 16
+#define LOG_NUM_BANKS 4
+#define CONFLICT_FREE_OFFSET(n) \
+    ((n) >> NUM_BANKS + (n) >> (2 * LOG_NUM_BANKS))
 
 namespace StreamCompaction {
    namespace Efficient {
-        using StreamCompaction::Common::PerformanceTimer;
-        PerformanceTimer& timer()
-        {
-            static PerformanceTimer timer;
-            return timer;
+        using StreamCompaction::Common::PerformanceTimer;
+        PerformanceTimer& timer()
+        {
+            static PerformanceTimer timer;
+            return timer;
        }
 
+        // One up-sweep level: each node at stride `step` accumulates its left child.
+        __global__ void kernUpSweep(int n, int d, int *odata) {
+            int index = (blockDim.x * blockIdx.x) + threadIdx.x;
+            if (index >= n) {
+                return;
+            }
+            int step = 1 << (d + 1);
+            if (index % step == 0) {
+                odata[index + step - 1] += odata[index + (1 << d) - 1];
+            }
+        }
+
+        // One down-sweep level: pass the node's value to its left child and the
+        // running sum to its right child.
+        __global__ void kernDownSweep(int n, int d, int *odata) {
+            int index = (blockDim.x * blockIdx.x) + threadIdx.x;
+            if (index >= n) {
+                return;
+            }
+            int step = 1 << (d + 1);
+            if (index % step == 0) {
+                int t = odata[index + (1 << d) - 1];
+                odata[index + (1 << d) - 1] = odata[index + step - 1];
+                odata[index + step - 1] += t;
+            }
+        }
+
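+        // Worked example of the two sweeps on [1 2 3 4 5 6 7 8]
+        // (illustrative trace only, not part of the code path):
+        //   up-sweep   d=0: [1 3 3  7 5 11 7 15]
+        //              d=1: [1 3 3 10 5 11 7 26]
+        //              d=2: [1 3 3 10 5 11 7 36]
+        //   zero root:      [1 3 3 10 5 11 7  0]
+        //   down-sweep d=2: [1 3 3  0 5 11 7 10]
+        //              d=1: [1 0 3  3 5 10 7 21]
+        //              d=0: [0 1 3  6 10 15 21 28]   <- exclusive scan
+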
        /**
         * Performs prefix-sum (aka scan) on idata, storing the result into odata.
         */
-        void scan(int n, int *odata, const int *idata) {
-            timer().startGpuTimer();
+        void nonoptscan(int n, int *odata, const int *idata) {
            // TODO
-            timer().endGpuTimer();
+            int upLimit = ilog2ceil(n);
+            int len = 1 << upLimit;   // pad to the next power of two
+
+            int blockSize = 1024;
+            dim3 threadsPerBlock(blockSize);
+            dim3 blocksPerGrid((len + blockSize - 1) / blockSize);
+
+            int *dev_data;
+            cudaMalloc((void**)&dev_data, len * sizeof(int));
+            checkCUDAError("cudaMalloc dev_data failed!");
+            cudaMemcpy(dev_data, idata, n * sizeof(int), cudaMemcpyHostToDevice);
+            checkCUDAError("cudaMemcpy dev_data failed!");
+            cudaMemset(dev_data + n, 0, (len - n) * sizeof(int));   // zero the padding
+
+            timer().startGpuTimer();
+            for (int d = 0; d <= upLimit - 1; d++) {
+                kernUpSweep<<<blocksPerGrid, threadsPerBlock>>>(len, d, dev_data);
+                checkCUDAError("kernUpSweep failed!");
+            }
+
+            cudaMemset(&dev_data[len - 1], 0, sizeof(int));
+            checkCUDAError("cudaMemset set last element to zero failed!");
+
+            for (int d = upLimit - 1; d >= 0; d--) {
+                kernDownSweep<<<blocksPerGrid, threadsPerBlock>>>(len, d, dev_data);
+                checkCUDAError("kernDownSweep failed!");
+            }
+            timer().endGpuTimer();
+            cudaMemcpy(odata, dev_data, n * sizeof(int), cudaMemcpyDeviceToHost);
+            checkCUDAError("cudaMemcpy dev_data failed!");
+            cudaFree(dev_data);
+            cudaDeviceSynchronize();
        }
 
+        // Compact indexing: thread i handles the node at offset i * step, so only
+        // len / step threads are needed at level d (step = 2^(d+1)) and no thread idles.
+        __global__ void kernOptUpSweep(int n, int d, int *odata) {
+            int index = (blockDim.x * blockIdx.x) + threadIdx.x;
+            if (index >= n) {
+                return;
+            }
+            int step = 1 << (d + 1);
+            odata[index * step + step - 1] += odata[index * step + (1 << d) - 1];
+        }
+
+        __global__ void kernOptDownSweep(int n, int d, int *odata) {
+            int index = (blockDim.x * blockIdx.x) + threadIdx.x;
+            if (index >= n) {
+                return;
+            }
+            int step = 1 << (d + 1);
+            int t = odata[index * step + (1 << d) - 1];
+            odata[index * step + (1 << d) - 1] = odata[index * step + step - 1];
+            odata[index * step + step - 1] += t;
+        }
+
+        void scan(int n, int *odata, const int *idata) {
+            int upLimit = ilog2ceil(n);
+            int len = 1 << upLimit;
+
+            int blockSize = 1024;
+            dim3 threadsPerBlock(blockSize);
+            dim3 blocksPerGrid((len + blockSize - 1) / blockSize);
+
+            int *dev_data;
+            cudaMalloc((void**)&dev_data, len * sizeof(int));
+            checkCUDAError("cudaMalloc dev_data failed!");
+            cudaMemcpy(dev_data, idata, n * sizeof(int), cudaMemcpyHostToDevice);
+            checkCUDAError("cudaMemcpy dev_data failed!");
+            cudaMemset(dev_data + n, 0, (len - n) * sizeof(int));   // zero the padding
+
+            timer().startGpuTimer();
+            for (int d = 0; d <= upLimit - 1; d++) {
+                int step = 1 << (d + 1);
+                int tempLen = len / step;   // active threads at this level
+                blocksPerGrid = dim3((tempLen + blockSize - 1) / blockSize);
+                kernOptUpSweep<<<blocksPerGrid, threadsPerBlock>>>(tempLen, d, dev_data);
+                checkCUDAError("kernOptUpSweep failed!");
+            }
+
+            cudaMemset(&dev_data[len - 1], 0, sizeof(int));
+            checkCUDAError("cudaMemset set last element to zero failed!");
+
+            for (int d = upLimit - 1; d >= 0; d--) {
+                int step = 1 << (d + 1);
+                int tempLen = len / step;
+                blocksPerGrid = dim3((tempLen + blockSize - 1) / blockSize);
+                kernOptDownSweep<<<blocksPerGrid, threadsPerBlock>>>(tempLen, d, dev_data);
+                checkCUDAError("kernOptDownSweep failed!");
+            }
+            timer().endGpuTimer();
+            cudaMemcpy(odata, dev_data, n * sizeof(int), cudaMemcpyDeviceToHost);
+            checkCUDAError("cudaMemcpy dev_data failed!");
+            cudaFree(dev_data);
+            cudaDeviceSynchronize();
+        }
+
+        // In-place exclusive scan on device memory (used by compact below).
+        void gpuScan(int n, int *data) {
+            int upLimit = ilog2ceil(n);
+            int blockSize = 1024;
+            dim3 threadsPerBlock(blockSize);
+            dim3 blocksPerGrid((n + blockSize - 1) / blockSize);
+
+            for (int d = 0; d <= upLimit - 1; d++) {
+                kernUpSweep<<<blocksPerGrid, threadsPerBlock>>>(n, d, data);
+                checkCUDAError("kernUpSweep failed!");
+            }
+
+            cudaMemset(&data[n - 1], 0, sizeof(int));
+            checkCUDAError("cudaMemset set last element to zero failed!");
+
+            for (int d = upLimit - 1; d >= 0; d--) {
+                kernDownSweep<<<blocksPerGrid, threadsPerBlock>>>(n, d, data);
+                checkCUDAError("kernDownSweep failed!");
+            }
+        }
+
+        // Block-level scan in shared memory with bank-conflict padding, adapted
+        // from the GPU Gems 3 ch. 39 code. As written it assumes the whole array
+        // fits in one block's shared memory (the per-block `stride` handling is
+        // incomplete); the test calls in main.cpp remain commented out.
+        __global__ void kernSharedMemoryScan(int n, int *odata, const int *idata, int *blockData) {
+            int index = (blockDim.x * blockIdx.x) + threadIdx.x;
+            if (index >= n) {
+                return;
+            }
+            extern __shared__ int temp[];
+            int thid = threadIdx.x;
+            int offset = 1;
+            int stride = blockIdx.x * blockDim.x;
+            int ai = thid;
+            int bi = thid + (n / 2);   // each thread loads two elements
+            int bankOffsetA = CONFLICT_FREE_OFFSET(ai);
+            int bankOffsetB = CONFLICT_FREE_OFFSET(bi);
+            temp[ai + bankOffsetA] = idata[ai + stride];
+            temp[bi + bankOffsetB] = idata[bi + stride];
+            for (int d = n >> 1; d > 0; d >>= 1)   // up-sweep in shared memory
+            {
+                __syncthreads();
+                if (thid < d)
+                {
+                    int ai = offset * (2 * thid + 1) - 1;
+                    int bi = offset * (2 * thid + 2) - 1;
+                    ai += CONFLICT_FREE_OFFSET(ai);
+                    bi += CONFLICT_FREE_OFFSET(bi);
+                    temp[bi] += temp[ai];
+                }
+                offset *= 2;
+            }
+            if (thid == 0) {
+                // Save the block total, then clear the root before the down-sweep.
+                blockData[blockIdx.x] = temp[n - 1 + CONFLICT_FREE_OFFSET(n - 1)];
+                temp[n - 1 + CONFLICT_FREE_OFFSET(n - 1)] = 0;
+            }
+            for (int d = 1; d < n; d *= 2)   // traverse down tree & build scan
+            {
+                offset >>= 1;
+                __syncthreads();
+                if (thid < d)
+                {
+                    int ai = offset * (2 * thid + 1) - 1;
+                    int bi = offset * (2 * thid + 2) - 1;
+                    ai += CONFLICT_FREE_OFFSET(ai);
+                    bi += CONFLICT_FREE_OFFSET(bi);
+                    int t = temp[ai];
+                    temp[ai] = temp[bi];
+                    temp[bi] += t;
+                }
+            }
+            __syncthreads();
+            odata[ai + stride] = temp[ai + bankOffsetA];
+            odata[bi + stride] = temp[bi + bankOffsetB];
+        }
+
+        __global__ void kernAddBlockData(int n, int *odata, const int *idata, const int *blockData) {
+            int index = (blockDim.x * blockIdx.x) + threadIdx.x;
+            if (index >= n) {
+                return;
+            }
+            // Add each block's scanned offset to its per-block scan results.
+            odata[index] = idata[index] + blockData[blockIdx.x];
+        }
+
+        void sharedMemoryScan(int n, int *odata, const int *idata) {
+            int upLimit = ilog2ceil(n);
+            int len = 1 << upLimit;
+
+            int blockSize = 64;
+            dim3 threadsPerBlock(blockSize);
+            dim3 blocksPerGrid((len + blockSize - 1) / blockSize);
+            int partSize = (len + blockSize - 1) / blockSize;
+            int *dev_idata;
+            int *dev_odata;
+            int *dev_blockData;
+            int *dev_blockSum;
+            cudaMalloc((void**)&dev_idata, len * sizeof(int));
+            checkCUDAError("cudaMalloc dev_idata failed!");
+            cudaMalloc((void**)&dev_odata, len * sizeof(int));
+            checkCUDAError("cudaMalloc dev_odata failed!");
+            cudaMalloc((void**)&dev_blockData, partSize * sizeof(int));
+            checkCUDAError("cudaMalloc dev_blockData failed!");
+            cudaMalloc((void**)&dev_blockSum, partSize * sizeof(int));
+            checkCUDAError("cudaMalloc dev_blockSum failed!");
+            cudaMemcpy(dev_idata, idata, n * sizeof(int), cudaMemcpyHostToDevice);
+            checkCUDAError("cudaMemcpy dev_idata failed!");
+
+            timer().startGpuTimer();
+            // Dynamic shared memory must cover temp[] plus bank-conflict padding.
+            kernSharedMemoryScan<<<blocksPerGrid, threadsPerBlock, 2 * len * sizeof(int)>>>(len, dev_odata, dev_idata, dev_blockData);
+            checkCUDAError("kernSharedMemoryScan failed!");
+            //thrust::inclusive_scan(dev_blockData, dev_blockData + n, dev_blockSum);
+            //kernAddBlockData<<<blocksPerGrid, threadsPerBlock>>>(len, dev_odata, dev_idata, dev_blockSum);
+            //checkCUDAError("kernAddBlockData failed!");
+            timer().endGpuTimer();
+            cudaMemcpy(odata, dev_odata, n * sizeof(int), cudaMemcpyDeviceToHost);
+            checkCUDAError("cudaMemcpy dev_odata failed!");
+
+            cudaFree(dev_idata);
+            cudaFree(dev_odata);
+            cudaFree(dev_blockData);
+            cudaFree(dev_blockSum);
+            cudaDeviceSynchronize();
+        }
+
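+        // To finish the multi-block version, the per-block totals would need an
+        // exclusive scan before being added back. A possible sketch (untested;
+        // raw device pointers must be wrapped so Thrust dispatches to the device
+        // path, and partSize, not n, is the number of block totals):
+        //
+        //   thrust::device_ptr<int> bd = thrust::device_pointer_cast(dev_blockData);
+        //   thrust::device_ptr<int> bs = thrust::device_pointer_cast(dev_blockSum);
+        //   thrust::exclusive_scan(bd, bd + partSize, bs);
+        //   kernAddBlockData<<<blocksPerGrid, threadsPerBlock>>>(len, dev_odata, dev_odata, dev_blockSum);
+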
        /**
         * Performs stream compaction on idata, storing the result into odata.
         * All zeroes are discarded.
@@ -31,10 +283,52 @@ namespace StreamCompaction {
         * @returns The number of elements remaining after compaction.
         */
        int compact(int n, int *odata, const int *idata) {
+            int upLimit = ilog2ceil(n);
+            int len = 1 << upLimit;
+            int blockSize = 1024;
+            dim3 threadsPerBlock(blockSize);
+            dim3 blocksPerGrid((len + blockSize - 1) / blockSize);
+            int *dev_odata;
+            int *dev_idata;
+            int *dev_bools;
+            int *dev_indices;
+
+            cudaMalloc((void**)&dev_odata, sizeof(int) * len);
+            checkCUDAError("cudaMalloc dev_odata failed!");
+            cudaMalloc((void**)&dev_idata, sizeof(int) * len);
+            checkCUDAError("cudaMalloc dev_idata failed!");
+            cudaMalloc((void**)&dev_bools, sizeof(int) * len);
+            checkCUDAError("cudaMalloc dev_bools failed!");
+            cudaMalloc((void**)&dev_indices, sizeof(int) * len);
+            checkCUDAError("cudaMalloc dev_indices failed!");
+            cudaMemcpy(dev_idata, idata, sizeof(int) * n, cudaMemcpyHostToDevice);
+            checkCUDAError("cudaMemcpy dev_idata failed!");
+            cudaMemset(dev_idata + n, 0, sizeof(int) * (len - n));   // padding maps to false
+            cudaMemset(dev_odata, 0, sizeof(int) * len);             // clear output for the count below
+
            timer().startGpuTimer();
-            // TODO
-            timer().endGpuTimer();
-            return -1;
+            // Map input to booleans, scan the booleans for output indices, then scatter.
+            StreamCompaction::Common::kernMapToBoolean<<<blocksPerGrid, threadsPerBlock>>>(len, dev_bools, dev_idata);
+            checkCUDAError("kernMapToBoolean failed!");
+
+            cudaMemcpy(dev_indices, dev_bools, sizeof(int) * len, cudaMemcpyDeviceToDevice);
+            gpuScan(len, dev_indices);
+
+            StreamCompaction::Common::kernScatter<<<blocksPerGrid, threadsPerBlock>>>(len, dev_odata, dev_idata, dev_bools, dev_indices);
+            checkCUDAError("kernScatter failed!");
+
+            timer().endGpuTimer();
+            cudaMemcpy(odata, dev_odata, sizeof(int) * n, cudaMemcpyDeviceToHost);
+            checkCUDAError("cudaMemcpy odata failed!");
+
+            cudaFree(dev_odata);
+            cudaFree(dev_idata);
+            cudaFree(dev_indices);
+            cudaFree(dev_bools);
+            int count = 0;
+            // The compacted output contains no zeroes, so count up to the first zero.
+            while (count < n && odata[count] != 0) {
+                count++;
+            }
+            return count;
        }
    }
 }
diff --git a/stream_compaction/efficient.h b/stream_compaction/efficient.h
index 803cb4f..1ac55d8 100644
--- a/stream_compaction/efficient.h
+++ b/stream_compaction/efficient.h
@@ -8,6 +8,8 @@ namespace StreamCompaction {
 
        void scan(int n, int *odata, const int *idata);
 
+        void sharedMemoryScan(int n, int *odata, const int *idata);
+
        int compact(int n, int *odata, const int *idata);
    }
 }
diff --git a/stream_compaction/naive.cu b/stream_compaction/naive.cu
index 9218f8e..181baeb 100644
--- a/stream_compaction/naive.cu
+++ b/stream_compaction/naive.cu
@@ -5,21 +5,75 @@
 
 namespace StreamCompaction {
    namespace Naive {
-        using StreamCompaction::Common::PerformanceTimer;
-        PerformanceTimer& timer()
-        {
-            static PerformanceTimer timer;
-            return timer;
+        using StreamCompaction::Common::PerformanceTimer;
+        PerformanceTimer& timer()
+        {
+            static PerformanceTimer timer;
+            return timer;
        }
 
        // TODO: __global__
+        // One naive (Hillis-Steele) pass: each element beyond `bound` adds the
+        // element `bound` positions to its left; ping-pong buffers avoid races.
+        __global__ void kernNaiveScan(int n, int bound, int *odata, const int *idata) {
+            int index = (blockDim.x * blockIdx.x) + threadIdx.x;
+            if (index >= n) {
+                return;
+            }
+            if (index >= bound) {
+                odata[index] = idata[index - bound] + idata[index];
+            }
+            else {
+                odata[index] = idata[index];
+            }
+        }
+
+        // Shift right by one and insert the identity to turn the inclusive scan
+        // into an exclusive scan.
+        __global__ void kernInclusiveToExclusive(int n, int *odata, const int *idata) {
+            int index = (blockDim.x * blockIdx.x) + threadIdx.x;
+            if (index >= n) {
+                return;
+            }
+            odata[index] = index == 0 ? 0 : idata[index - 1];
+        }
+
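+        // Worked example on [1 2 3 4] (illustrative trace only):
+        //   bound=1: [1 3 5 7]
+        //   bound=2: [1 3 6 10]   <- inclusive scan
+        //   shift:   [0 1 3 6]    <- exclusive scan
+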
        /**
         * Performs prefix-sum (aka scan) on idata, storing the result into odata.
         */
        void scan(int n, int *odata, const int *idata) {
-            timer().startGpuTimer();
            // TODO
-            timer().endGpuTimer();
+            int blockSize = 1024;
+            dim3 threadsPerBlock(blockSize);
+            dim3 blocksPerGrid((n + blockSize - 1) / blockSize);
+
+            // allocate memory
+            int *dev_idata;
+            int *dev_odata;
+
+            cudaMalloc((void**)&dev_idata, n * sizeof(int));
+            checkCUDAError("cudaMalloc dev_idata failed!");
+            cudaMalloc((void**)&dev_odata, n * sizeof(int));
+            checkCUDAError("cudaMalloc dev_odata failed!");
+
+            cudaMemcpy(dev_idata, idata, n * sizeof(int), cudaMemcpyHostToDevice);
+            checkCUDAError("cudaMemcpy dev_idata failed!");
+            timer().startGpuTimer();
+            int depth = ilog2ceil(n);
+            int bound = 1;
+            for (int d = 1; d <= depth; d++) {
+                kernNaiveScan<<<blocksPerGrid, threadsPerBlock>>>(n, bound, dev_odata, dev_idata);
+                checkCUDAError("kernNaiveScan failed!");
+                std::swap(dev_odata, dev_idata);   // ping-pong the buffers
+                bound *= 2;
+            }
+            kernInclusiveToExclusive<<<blocksPerGrid, threadsPerBlock>>>(n, dev_odata, dev_idata);
+            checkCUDAError("kernInclusiveToExclusive failed!");
+            cudaMemcpy(odata, dev_odata, n * sizeof(int), cudaMemcpyDeviceToHost);
+            checkCUDAError("cudaMemcpy dev_odata failed!");
+            timer().endGpuTimer();
+            cudaFree(dev_odata);
+            cudaFree(dev_idata);
+            cudaDeviceSynchronize();
+        }
    }
 }
diff --git a/stream_compaction/thrust.cu b/stream_compaction/thrust.cu
index 36b732d..468b85f 100644
--- a/stream_compaction/thrust.cu
+++ b/stream_compaction/thrust.cu
@@ -8,21 +8,29 @@
 
 namespace StreamCompaction {
    namespace Thrust {
-        using StreamCompaction::Common::PerformanceTimer;
-        PerformanceTimer& timer()
-        {
-            static PerformanceTimer timer;
-            return timer;
+        using StreamCompaction::Common::PerformanceTimer;
+        PerformanceTimer& timer()
+        {
+            static PerformanceTimer timer;
+            return timer;
        }
 
        /**
         * Performs prefix-sum (aka scan) on idata, storing the result into odata.
         */
        void scan(int n, int *odata, const int *idata) {
-            timer().startGpuTimer();
+            // TODO use `thrust::exclusive_scan`
            // example: for device_vectors dv_in and dv_out:
            // thrust::exclusive_scan(dv_in.begin(), dv_in.end(), dv_out.begin());
+
+            // Copy the input into device vectors; only the scan itself is timed.
+            thrust::device_vector<int> dev_idata(idata, idata + n);
+            thrust::device_vector<int> dev_odata(odata, odata + n);
+
+            timer().startGpuTimer();
+            //thrust::exclusive_scan(idata, idata + n, odata);
+            thrust::exclusive_scan(dev_idata.begin(), dev_idata.end(), dev_odata.begin());
            timer().endGpuTimer();
+            thrust::copy(dev_odata.begin(), dev_odata.end(), odata);
        }
    }
 }