diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp
index 9bfea4e5..f50d2380 100644
--- a/src/torchcodec/_core/CudaDeviceInterface.cpp
+++ b/src/torchcodec/_core/CudaDeviceInterface.cpp
@@ -39,6 +39,11 @@ const int MAX_CUDA_GPUS = 128;
 const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
 std::vector<AVBufferRef*> g_cached_hw_device_ctxs[MAX_CUDA_GPUS];
 std::mutex g_cached_hw_device_mutexes[MAX_CUDA_GPUS];
+// Per-device cache of NppStreamContext, created lazily on first use. Only the
+// immutable device properties are cached: hStream and nStreamFlags may change
+// across calls, so each caller sets them on its own copy.
+std::map<torch::DeviceIndex, NppStreamContext> g_cached_npp_ctxs;
+std::mutex g_cached_npp_ctx_mutex;
 
 torch::DeviceIndex getFFMPEGCompatibleDeviceIndex(const torch::Device& device) {
   torch::DeviceIndex deviceIndex = device.index();
@@ -187,16 +192,26 @@ NppStreamContext createNppStreamContext(int deviceIndex) {
   nppCtx.nCudaDevAttrComputeCapabilityMajor = prop.major;
   nppCtx.nCudaDevAttrComputeCapabilityMinor = prop.minor;
 
-  // TODO when implementing the cache logic, move these out. See other TODO
-  // below.
-  nppCtx.hStream = at::cuda::getCurrentCUDAStream(deviceIndex).stream();
-  err = cudaStreamGetFlags(nppCtx.hStream, &nppCtx.nStreamFlags);
-  TORCH_CHECK(
-      err == cudaSuccess,
-      "cudaStreamGetFlags failed: ",
-      cudaGetErrorString(err));
+  // hStream and nStreamFlags are deliberately left unset here: they are
+  // per-call state, filled in by the caller on its copy of the context.
 
   return nppCtx;
 }
 
+// Returns a copy of the per-device NppStreamContext, creating and caching it
+// on first use. Returning a copy (the struct is plain data) lets each caller
+// set its own hStream/nStreamFlags without mutating shared cached state, and
+// avoids any pointer-ownership or lifetime questions.
+NppStreamContext getNppStreamContext(const torch::Device& device) {
+  torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device);
+  std::scoped_lock lock(g_cached_npp_ctx_mutex);
+  auto it = g_cached_npp_ctxs.find(deviceIndex);
+  if (it == g_cached_npp_ctxs.end()) {
+    it = g_cached_npp_ctxs
+             .emplace(
+                 deviceIndex,
+                 createNppStreamContext(static_cast<int>(deviceIndex)))
+             .first;
+  }
+  return it->second;
+}
+
 } // namespace
@@ -303,13 +318,17 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
     dst = allocateEmptyHWCTensor(height, width, device_);
   }
 
-  // TODO cache the NppStreamContext! It currently gets re-recated for every
-  // single frame. The cache should be per-device, similar to the existing
-  // hw_device_ctx cache. When implementing the cache logic, the
-  // NppStreamContext hStream and nStreamFlags should not be part of the cache
-  // because they may change across calls.
-  NppStreamContext nppCtx = createNppStreamContext(
-      static_cast<int>(getFFMPEGCompatibleDeviceIndex(device_)));
+  // Copy the cached per-device context, then set the stream fields here:
+  // they may change across calls, so they are not part of the cache.
+  NppStreamContext nppCtx = getNppStreamContext(device_);
+  nppCtx.hStream =
+      at::cuda::getCurrentCUDAStream(getFFMPEGCompatibleDeviceIndex(device_))
+          .stream();
+  cudaError_t err = cudaStreamGetFlags(nppCtx.hStream, &nppCtx.nStreamFlags);
+  TORCH_CHECK(
+      err == cudaSuccess,
+      "cudaStreamGetFlags failed: ",
+      cudaGetErrorString(err));
 
   NppiSize oSizeROI = {width, height};
   Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]};