Skip to content

[WIP] native nvcuvid backend #798

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 22 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions benchmark_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
#!/usr/bin/env python3

import torch
from time import perf_counter_ns
import argparse
from pathlib import Path
from torchcodec.decoders import VideoDecoder
from joblib import Parallel, delayed
import os
from contextlib import contextmanager
import torchvision.io


@contextmanager
def with_cache(enabled=True):
    """Temporarily enable or disable the NVDEC decoder cache.

    The switch is the ``TORCHCODEC_DISABLE_NVDEC_CACHE`` environment
    variable: it is set to ``"1"`` when ``enabled`` is false and removed
    when ``enabled`` is true. Whatever state the variable had on entry
    (a value, or unset) is put back on exit, even if the body raises.

    Args:
        enabled: Whether the cache should be active inside the ``with`` body.
    """
    var_name = "TORCHCODEC_DISABLE_NVDEC_CACHE"
    saved_value = os.environ.get(var_name)
    try:
        if enabled:
            os.environ.pop(var_name, None)
        else:
            os.environ[var_name] = "1"
        yield
    finally:
        # Put the variable back exactly as it was found.
        if saved_value is None:
            os.environ.pop(var_name, None)
        else:
            os.environ[var_name] = saved_value


def bench(f, *args, num_exp=100, warmup=0, **kwargs):
    """Time repeated calls of ``f(*args, **kwargs)``.

    Args:
        f: Callable to benchmark.
        *args: Positional arguments forwarded to ``f``.
        num_exp: Number of timed calls.
        warmup: Number of untimed calls made before timing starts.
        **kwargs: Keyword arguments forwarded to ``f``.

    Returns:
        A 1-D float tensor holding one elapsed time, in nanoseconds, per
        timed call.
    """
    for _ in range(warmup):
        f(*args, **kwargs)

    def timed_call():
        t0 = perf_counter_ns()
        f(*args, **kwargs)
        return perf_counter_ns() - t0

    return torch.tensor([timed_call() for _ in range(num_exp)]).float()


def report_stats(times, unit="ms"):
    """Print the median and standard deviation of timings and return the median.

    Args:
        times: Tensor of elapsed times expressed in nanoseconds.
        unit: Unit to report in; one of ``"ns"``, ``"µs"``, ``"ms"``, ``"s"``.

    Returns:
        The median of ``times`` converted to ``unit``, as a Python float.

    Raises:
        KeyError: If ``unit`` is not one of the supported units.
    """
    ns_to_unit = {"ns": 1, "µs": 1e-3, "ms": 1e-6, "s": 1e-9}
    scaled = times * ns_to_unit[unit]
    # The local must be called ``med``: the ``{med = ...}`` specifier below
    # prints the expression text as part of the output.
    med = scaled.median().item()
    std = scaled.std().item()
    print(f"{med = :.2f}{unit} +- {std:.2f}")
    return med

# TODO call sync

def decode_videos_threaded(num_threads, decoder_implem):
    """Decode a handful of frames from every file in the global ``video_files``.

    Each video is decoded on ``cuda:0`` in its own joblib thread-preferred
    job, and the decoded frames are copied back to the CPU.

    Args:
        num_threads: Value passed to joblib as ``n_jobs``.
        decoder_implem: ``"ffmpeg"`` to use the decoder with no
            ``device_variant``, or ``"ours"`` to use the ``"custom_nvdec"``
            variant.

    Returns:
        A list with one CPU tensor of decoded frames per video, in the order
        of ``video_files``.

    Raises:
        AssertionError: If ``decoder_implem`` is not a supported value.
    """
    assert decoder_implem in ["ffmpeg", "ours"], "Invalid decoder implementation"
    device_variant = None if decoder_implem == "ffmpeg" else "custom_nvdec"
    num_frames_to_decode = 10

    def decode_one_video(video_path):
        device = torch.device("cuda:0")
        decoder = VideoDecoder(
            str(video_path),
            device=device,
            device_variant=device_variant,
            seek_mode="approximate",
        )
        # Stay 10 frames away from the end of the stream, but never let the
        # upper bound go negative: for clips shorter than 10 frames that
        # would make linspace produce negative frame indices.
        last_index = max(len(decoder) - 10, 0)
        indices = torch.linspace(
            0, last_index, num_frames_to_decode, dtype=torch.int
        ).tolist()
        frames = decoder.get_frames_at(indices)
        return frames.data.cpu()  # Move to CPU for PNG saving

    # Always collect and return all decoded frames
    results = Parallel(n_jobs=num_threads, prefer="threads")(
        delayed(decode_one_video)(video_path) for video_path in video_files
    )
    # Wait for any outstanding GPU work so callers timing this function
    # measure completed decodes.
    torch.cuda.synchronize()
    return results


def validate_decode_correctness(video_path, num_threads=1):
    """Save decoded frames from different implementations for visual comparison.

    The video is decoded three times (FFmpeg backend, custom backend with
    the cache enabled, custom backend with the cache disabled) and up to the
    first five frames of each are written side by side as PNG files into a
    ``results`` directory under the current working directory.

    The global ``video_files`` list is temporarily replaced by a list
    containing only ``video_path`` and restored afterwards.

    Args:
        video_path: Path of the single video to decode.
        num_threads: Forwarded to ``decode_videos_threaded``.
    """
    global video_files

    out_dir = Path("results")
    out_dir.mkdir(exist_ok=True)

    # decode_videos_threaded reads the global list, so swap it for the
    # duration of the check.
    saved_video_files = video_files
    video_files = [Path(video_path)]
    try:
        # Each call returns one entry per video; there is exactly one here.
        frames_ffmpeg = decode_videos_threaded(num_threads, "ffmpeg")[0]
        with with_cache(enabled=True):
            frames_ours_cached = decode_videos_threaded(num_threads, "ours")[0]
        with with_cache(enabled=False):
            frames_ours_nocache = decode_videos_threaded(num_threads, "ours")[0]

        print(f"Frame shapes: ffmpeg={frames_ffmpeg.shape}, cached={frames_ours_cached.shape}, nocache={frames_ours_nocache.shape}")

        num_saved = min(5, frames_ffmpeg.shape[0])
        for i in range(num_saved):
            # Per-frame tensors are presumably [C, H, W], so dim=2 places the
            # three images next to each other horizontally.
            side_by_side = torch.cat(
                [frames_ffmpeg[i], frames_ours_cached[i], frames_ours_nocache[i]],
                dim=2,
            )
            png_path = out_dir / f"frame_{i:02d}_comparison.png"
            torchvision.io.write_png(side_by_side, str(png_path))
    finally:
        video_files = saved_video_files


# Command-line entry point: benchmark the FFmpeg backend and the custom NVDEC
# backend (with and without its decoder cache) on every .mp4 file in a folder.
parser = argparse.ArgumentParser()
# The glob below only picks up *.mp4 files, so say so in the help text.
parser.add_argument("video_folder", help="Folder containing .mp4 files")
# default=1 matches how joblib interprets an unset n_jobs (None) and keeps the
# status message below from printing "None threads".
parser.add_argument("--num-threads", type=int, default=1, help="Number of threads (default: 1)")
args = parser.parse_args()

# Module-level on purpose: decode_videos_threaded and
# validate_decode_correctness read this list as a global.
video_files = list(Path(args.video_folder).glob("*.mp4"))
if not video_files:
    # Timing an empty job list would report meaningless numbers.
    parser.error(f"No .mp4 files found in {args.video_folder}")
print(f"Decoding a few frames from {len(video_files)} video files in {args.video_folder} with {args.num_threads} threads")

# validate_decode_correctness(video_files[0], num_threads=args.num_threads)

print("=== Benchmarking FFmpeg backend ===")
times = bench(decode_videos_threaded, args.num_threads, decoder_implem="ffmpeg", warmup=0, num_exp=10)
report_stats(times)

print("\n=== Benchmarking our backend WITH cache ===")
with with_cache(enabled=True):
    times = bench(decode_videos_threaded, args.num_threads, decoder_implem="ours", warmup=0, num_exp=10)
    report_stats(times)

print("\n=== Benchmarking our backend WITHOUT cache ===")
with with_cache(enabled=False):
    times = bench(decode_videos_threaded, args.num_threads, decoder_implem="ours", warmup=0, num_exp=10)
    report_stats(times)
75 changes: 75 additions & 0 deletions benchmark_nvdec_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#!/usr/bin/env python3
"""
Simple multi-threaded NVDEC decoder cache benchmark.
"""

import argparse
from pathlib import Path
import torch
from time import perf_counter_ns
from torchcodec.decoders import VideoDecoder
from joblib import Parallel, delayed


def bench(f, *args, num_exp=100, warmup=0, **kwargs):
    """Measure how long ``f(*args, **kwargs)`` takes over several runs.

    Args:
        f: Callable under test.
        *args: Positional arguments passed through to ``f``.
        num_exp: How many timed runs to perform.
        warmup: How many untimed runs to perform first.
        **kwargs: Keyword arguments passed through to ``f``.

    Returns:
        A 1-D float tensor of per-run durations in nanoseconds.
    """
    for _ in range(warmup):
        f(*args, **kwargs)

    elapsed_ns = [0] * num_exp
    for run in range(num_exp):
        tic = perf_counter_ns()
        f(*args, **kwargs)
        elapsed_ns[run] = perf_counter_ns() - tic
    return torch.tensor(elapsed_ns).float()


def report_stats(times, unit="ms"):
    """Summarise nanosecond timings: print median and spread, return the median.

    Args:
        times: Tensor of durations in nanoseconds.
        unit: Reporting unit, one of ``"ns"``, ``"µs"``, ``"ms"`` or ``"s"``.

    Returns:
        The median duration expressed in ``unit`` (Python float).

    Raises:
        KeyError: If ``unit`` is not a supported unit.
    """
    conversion = {
        "ns": 1,
        "µs": 1e-3,
        "ms": 1e-6,
        "s": 1e-9,
    }
    converted = times * conversion[unit]
    # ``med`` must keep this name: the ``{med = ...}`` f-string specifier
    # echoes the expression text in the printed line.
    med, std = converted.median().item(), converted.std().item()
    print(f"{med = :.2f}{unit} +- {std:.2f}")
    return med


def decode_videos(video_folder, num_threads, num_frames=10):
    """Decode frames from all .h264 files using multiple threads.

    Every ``*.h264`` file directly inside ``video_folder`` is opened with the
    ``"custom_nvdec"`` device variant on ``cuda:0`` and its first
    ``num_frames`` frames (fewer if the video is shorter) are decoded one by
    one. The decoded frames are discarded; the function exists to be timed.

    Args:
        video_folder: Directory to search for ``.h264`` files.
        num_threads: Value passed to joblib as ``n_jobs``.
        num_frames: Maximum number of leading frames to decode per video.
    """
    video_files = list(Path(video_folder).glob("*.h264"))
    device = torch.device("cuda:0")

    def decode_single_video(video_path):
        decoder = VideoDecoder(str(video_path), device=device, device_variant="custom_nvdec")
        for i in range(min(num_frames, len(decoder))):
            # Only the decode matters here, so the frame is not kept.
            decoder.get_frame_at(i)
        return video_path.name

    # Use joblib to run in parallel
    Parallel(n_jobs=num_threads, backend="threading")(
        delayed(decode_single_video)(video_path) for video_path in video_files
    )
    # Wait for any GPU work still in flight so that callers timing this
    # function measure completed decodes rather than just kernel launches.
    torch.cuda.synchronize(device)


def main():
    """Parse the command line, then benchmark ``decode_videos`` and print stats."""
    cli = argparse.ArgumentParser()
    cli.add_argument("video_folder", help="Folder with .h264 files")
    cli.add_argument("--num-threads", type=int, default=4, help="Number of threads")
    cli.add_argument("--num-frames", type=int, default=10, help="Frames per video")
    args = cli.parse_args()

    # Counted only for the status message; decode_videos globs the folder
    # again on every run.
    num_found = sum(1 for _ in Path(args.video_folder).glob("*.h264"))
    print(f"Found {num_found} .h264 files")
    print(f"Using {args.num_threads} threads, {args.num_frames} frames per video")
    print("Benchmarking...")

    timings = bench(
        decode_videos,
        args.video_folder,
        args.num_threads,
        args.num_frames,
        warmup=0,
        num_exp=10,
    )
    report_stats(timings, unit="ms")


if __name__ == "__main__":
    main()
95 changes: 92 additions & 3 deletions src/torchcodec/_core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ endif()
if (WIN32)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4 ${TORCHCODEC_WERROR_OPTION} ${TORCH_CXX_FLAGS}")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic ${TORCHCODEC_WERROR_OPTION} ${TORCH_CXX_FLAGS}")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic ${TORCHCODEC_WERROR_OPTION} ${TORCH_CXX_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCHCODEC_WERROR_OPTION} ${TORCH_CXX_FLAGS}")
endif()


function(make_torchcodec_sublibrary
library_name
type
Expand Down Expand Up @@ -109,7 +109,7 @@ function(make_torchcodec_libraries
)

if(ENABLE_CUDA)
list(APPEND core_sources CudaDeviceInterface.cpp)
list(APPEND core_sources CudaDeviceInterface.cpp CustomNvdecDeviceInterface.cpp)
endif()

set(core_library_dependencies
Expand All @@ -122,6 +122,86 @@ function(make_torchcodec_libraries
${CUDA_nppi_LIBRARY}
${CUDA_nppicc_LIBRARY}
)

# Find NVIDIA Video Codec SDK library (libnvcuvid).
# For an SDK download that is not installed in one of the standard locations
# below, set the NVCODEC_SDK_ROOT environment variable to the SDK root, or pass
# -DCUDA_NVCUVID_LIBRARY=<path to libnvcuvid> on the cmake command line
# (find_library does not search again when the variable is already set).
# Do not hard-code machine-specific paths here and do not overwrite the result
# afterwards: that would defeat the NOT-found check that follows.
find_library(CUDA_NVCUVID_LIBRARY
  NAMES nvcuvid
  PATHS
    $ENV{NVCODEC_SDK_ROOT}/Lib/linux/stubs/x86_64
    $ENV{NVCODEC_SDK_ROOT}/Lib/linux/x86_64
    $ENV{NVCODEC_SDK_ROOT}/lib64
    $ENV{NVCODEC_SDK_ROOT}/lib
    /usr/local/cuda/lib64
    /usr/local/cuda/lib
    ${CUDA_TOOLKIT_ROOT_DIR}/lib64
    ${CUDA_TOOLKIT_ROOT_DIR}/lib
    /opt/cuda/lib64
    /opt/cuda/lib
    $ENV{CUDA_PATH}/lib64
    $ENV{CUDA_PATH}/lib
    $ENV{CUDA_HOME}/lib64
    $ENV{CUDA_HOME}/lib
    /usr/lib64
    /usr/lib
)

# Find NVIDIA Video Codec SDK headers
find_path(NVCODEC_INCLUDE_DIR
NAMES cuviddec.h nvcuvid.h
PATHS
$ENV{NVCODEC_SDK_ROOT}/Interface
$ENV{NVCODEC_SDK_ROOT}/include
$ENV{HOME}/Downloads/Video_Codec_SDK_12.2.72/Interface
/usr/local/cuda/include
${CUDA_TOOLKIT_ROOT_DIR}/include
$ENV{CUDA_PATH}/include
$ENV{CUDA_HOME}/include
/opt/cuda/include
PATH_SUFFIXES
Interface
Video_Codec_SDK_12.2.72/Interface
Video_Codec_SDK/Interface
)

if(NOT CUDA_NVCUVID_LIBRARY)
message(FATAL_ERROR "Cannot find libnvcuvid, you may need to manually register and download at https://developer.nvidia.com/nvidia-video-codec-sdk. Then copy libnvcuvid to cuda_toolkit_root/lib64/")
endif()

if(NOT NVCODEC_INCLUDE_DIR)
message(FATAL_ERROR "Cannot find NVIDIA Video Codec SDK headers (cuviddec.h, nvcuvid.h). Please download the NVIDIA Video Codec SDK from https://developer.nvidia.com/nvidia-video-codec-sdk and copy the headers to your CUDA include directory.")
endif()

message(STATUS "Found NVIDIA Video Codec SDK library: ${CUDA_NVCUVID_LIBRARY}")
message(STATUS "Found NVIDIA Video Codec SDK headers: ${NVCODEC_INCLUDE_DIR}")

# Add CUDA Driver API library (needed for cuCtxGetCurrent, etc.)
find_library(CUDA_DRIVER_LIBRARY
NAMES cuda
PATHS
/usr/local/cuda/lib64
/usr/local/cuda/lib
${CUDA_TOOLKIT_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib
/opt/cuda/lib64
/opt/cuda/lib
$ENV{CUDA_PATH}/lib64
$ENV{CUDA_PATH}/lib
$ENV{CUDA_HOME}/lib64
$ENV{CUDA_HOME}/lib
/usr/lib64
/usr/lib
)

if(NOT CUDA_DRIVER_LIBRARY)
message(FATAL_ERROR "Cannot find CUDA Driver API library (libcuda.so)")
endif()

message(STATUS "Found CUDA Driver API library: ${CUDA_DRIVER_LIBRARY}")

# Add nvcuvid and cuda driver libraries to dependencies
list(APPEND core_library_dependencies ${CUDA_NVCUVID_LIBRARY} ${CUDA_DRIVER_LIBRARY})
endif()

make_torchcodec_sublibrary(
Expand All @@ -131,6 +211,15 @@ function(make_torchcodec_libraries
"${core_library_dependencies}"
)

# Add NVDEC include directories after target creation
if(ENABLE_CUDA AND NVCODEC_INCLUDE_DIR)
target_include_directories(${core_library_name}
PRIVATE
${NVCODEC_INCLUDE_DIR}
)
endif()


# 2. Create libtorchcodec_custom_opsN.{ext}.
set(custom_ops_library_name "libtorchcodec_custom_ops${ffmpeg_major_version}")
set(custom_ops_sources
Expand Down
Loading
Loading