
[wip] Add npp_stream_ctx cache #800


Draft: wants to merge 2 commits into main
68 changes: 51 additions & 17 deletions src/torchcodec/_core/CudaDeviceInterface.cpp
@@ -39,6 +39,9 @@ const int MAX_CUDA_GPUS = 128;
const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
std::vector<AVBufferRef*> g_cached_hw_device_ctxs[MAX_CUDA_GPUS];
std::mutex g_cached_hw_device_mutexes[MAX_CUDA_GPUS];
+// NPP stream context cache, with up to MAX_CUDA_GPUS contexts per GPU.
+std::map<int, std::vector<NppStreamContext*>> g_cached_npp_stream_ctxs;
+std::mutex g_cached_npp_stream_mutexes[MAX_CUDA_GPUS];

torch::DeviceIndex getFFMPEGCompatibleDeviceIndex(const torch::Device& device) {
torch::DeviceIndex deviceIndex = device.index();
@@ -162,7 +165,7 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
#endif
}

-NppStreamContext createNppStreamContext(int deviceIndex) {
+NppStreamContext* createNppStreamContext(int deviceIndex) {
// From 12.9, NPP recommends using a user-created NppStreamContext and using
// the `_Ctx()` calls:
// https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#npp-release-12-9-update-1
@@ -172,6 +175,7 @@ NppStreamContext createNppStreamContext(int deviceIndex) {
// https://github.com/NVIDIA/CUDALibrarySamples/blob/d97803a40fab83c058bb3d68b6c38bd6eebfff43/NPP/README.md?plain=1#L54-L72

NppStreamContext nppCtx{};
+  NppStreamContext* nppCtxPtr = &nppCtx;
cudaDeviceProp prop{};
cudaError_t err = cudaGetDeviceProperties(&prop, deviceIndex);
TORCH_CHECK(
@@ -187,16 +191,44 @@ NppStreamContext createNppStreamContext(int deviceIndex) {
nppCtx.nCudaDevAttrComputeCapabilityMajor = prop.major;
nppCtx.nCudaDevAttrComputeCapabilityMinor = prop.minor;

-  // TODO when implementing the cache logic, move these out. See other TODO
-  // below.
-  nppCtx.hStream = at::cuda::getCurrentCUDAStream(deviceIndex).stream();
-  err = cudaStreamGetFlags(nppCtx.hStream, &nppCtx.nStreamFlags);
-  TORCH_CHECK(
-      err == cudaSuccess,
-      "cudaStreamGetFlags failed: ",
-      cudaGetErrorString(err));
-  return nppCtx;
+  return nppCtxPtr;
Comment by @NicolasHug (Member), Aug 8, 2025:
Oh, I think this is the problem: we're returning a pointer to nppCtx, which was allocated on the stack, not on the heap. So the pointer's value (and what it points to) is only valid within the createNppStreamContext function, not after it returns.
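
A minimal sketch of the heap-allocating variant this comment suggests (hypothetical, not part of this PR; it reuses the PR's names and shows only the fields visible in this diff):

NppStreamContext* createNppStreamContext(int deviceIndex) {
  // Allocate on the heap so the returned pointer stays valid after this
  // function returns; the cache (or the caller) owns the allocation.
  auto* nppCtx = new NppStreamContext{};
  cudaDeviceProp prop{};
  cudaError_t err = cudaGetDeviceProperties(&prop, deviceIndex);
  TORCH_CHECK(
      err == cudaSuccess,
      "cudaGetDeviceProperties failed: ",
      cudaGetErrorString(err));
  nppCtx->nCudaDevAttrComputeCapabilityMajor = prop.major;
  nppCtx->nCudaDevAttrComputeCapabilityMinor = prop.minor;
  // ... populate the remaining device properties from prop, as in the PR ...
  return nppCtx;
}

Storing std::unique_ptr<NppStreamContext> in g_cached_npp_stream_ctxs instead of raw pointers would make the cache's ownership of these heap allocations explicit.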

}

+NppStreamContext* getNppStreamContextFromCache(const torch::Device& device) {
+  torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device);
+  std::scoped_lock lock(g_cached_npp_stream_mutexes[deviceIndex]);
+  if (g_cached_npp_stream_ctxs[deviceIndex].size() > 0) {
+    NppStreamContext* npp_stream_ctx =
+        g_cached_npp_stream_ctxs[deviceIndex].back();
+    g_cached_npp_stream_ctxs[deviceIndex].pop_back();
+    return npp_stream_ctx;
+  } else {
+    return nullptr;
+  }
+}

+void addNppStreamContextToCache(
+    const torch::Device& device,
+    NppStreamContext* ctx) {
+  torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device);
+  std::scoped_lock lock(g_cached_npp_stream_mutexes[deviceIndex]);
+  // Add to cache if cache has capacity
+  if (g_cached_npp_stream_ctxs[deviceIndex].size() < MAX_CUDA_GPUS) {
+    g_cached_npp_stream_ctxs[deviceIndex].push_back(ctx);
+  }
+}

+NppStreamContext* getNppStreamContext(const torch::Device& device) {
+  // Return the cached NppStreamContext if it exists.
+  NppStreamContext* cached_npp_ctx_ptr = getNppStreamContextFromCache(device);
+  if (cached_npp_ctx_ptr != nullptr) {
+    return cached_npp_ctx_ptr;
+  }
+  // Create a new NppStreamContext, and cache it.
+  NppStreamContext* npp_ctx_ptr = createNppStreamContext(
+      static_cast<int>(getFFMPEGCompatibleDeviceIndex(device)));
+  addNppStreamContextToCache(device, npp_ctx_ptr);
+  return npp_ctx_ptr;
+}

} // namespace
@@ -303,13 +335,15 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
dst = allocateEmptyHWCTensor(height, width, device_);
}

-  // TODO cache the NppStreamContext! It currently gets re-created for every
-  // single frame. The cache should be per-device, similar to the existing
-  // hw_device_ctx cache. When implementing the cache logic, the
-  // NppStreamContext hStream and nStreamFlags should not be part of the cache
-  // because they may change across calls.
-  NppStreamContext nppCtx = createNppStreamContext(
-      static_cast<int>(getFFMPEGCompatibleDeviceIndex(device_)));
+  NppStreamContext nppCtx = *getNppStreamContext(device_);

+  torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device_);
+  nppCtx.hStream = at::cuda::getCurrentCUDAStream(deviceIndex).stream();
+  cudaError_t err = cudaStreamGetFlags(nppCtx.hStream, &nppCtx.nStreamFlags);
+  TORCH_CHECK(
+      err == cudaSuccess,
+      "cudaStreamGetFlags failed: ",
+      cudaGetErrorString(err));

NppiSize oSizeROI = {width, height};
Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]};