diff --git a/benchmark_cache.py b/benchmark_cache.py new file mode 100644 index 00000000..15310090 --- /dev/null +++ b/benchmark_cache.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 + +import torch +from time import perf_counter_ns +import argparse +from pathlib import Path +from torchcodec.decoders import VideoDecoder +from joblib import Parallel, delayed +import os +from contextlib import contextmanager +import torchvision.io + + +@contextmanager +def with_cache(enabled=True): + """Context manager to enable/disable NVDEC decoder cache.""" + original_env_value = os.environ.get("TORCHCODEC_DISABLE_NVDEC_CACHE") + try: + if not enabled: + os.environ["TORCHCODEC_DISABLE_NVDEC_CACHE"] = "1" + elif "TORCHCODEC_DISABLE_NVDEC_CACHE" in os.environ: + del os.environ["TORCHCODEC_DISABLE_NVDEC_CACHE"] + yield + finally: + # Restore original environment variable state + if original_env_value is not None: + os.environ["TORCHCODEC_DISABLE_NVDEC_CACHE"] = original_env_value + elif "TORCHCODEC_DISABLE_NVDEC_CACHE" in os.environ: + del os.environ["TORCHCODEC_DISABLE_NVDEC_CACHE"] + + +def bench(f, *args, num_exp=100, warmup=0, **kwargs): + for _ in range(warmup): + f(*args, **kwargs) + + times = [] + for _ in range(num_exp): + start = perf_counter_ns() + f(*args, **kwargs) + end = perf_counter_ns() + times.append(end - start) + return torch.tensor(times).float() + + +def report_stats(times, unit="ms"): + mul = { + "ns": 1, + "µs": 1e-3, + "ms": 1e-6, + "s": 1e-9, + }[unit] + times = times * mul + std = times.std().item() + med = times.median().item() + print(f"{med = :.2f}{unit} +- {std:.2f}") + return med + +# TODO call sync + +def decode_videos_threaded(num_threads, decoder_implem): + assert decoder_implem in ["ffmpeg", "ours"], "Invalid decoder implementation" + device_variant = None if decoder_implem == "ffmpeg" else "custom_nvdec" + num_frames_to_decode = 10 + + def decode_one_video(video_path): + device = torch.device("cuda:0") + decoder = VideoDecoder(str(video_path), device=device, device_variant=device_variant, seek_mode="approximate") + indices = torch.linspace(0, len(decoder)-10, num_frames_to_decode, dtype=torch.int).tolist() + frames = decoder.get_frames_at(indices) + return frames.data.cpu() # Move to CPU for PNG saving + + # Always collect and return all decoded frames + results = Parallel(n_jobs=num_threads, prefer="threads")( + delayed(decode_one_video)(video_path) for video_path in video_files + ) + torch.cuda.synchronize() + return results + + +def validate_decode_correctness(video_path, num_threads=1): + """Save decoded frames from different implementations for visual comparison.""" + # Create results directory + results_dir = Path("results") + results_dir.mkdir(exist_ok=True) + + # Test single video with different implementations + global video_files + original_files = video_files + video_files = [Path(video_path)] # Override for single video test + + try: + # Get frames from each implementation (results is a list from joblib) + frames_ffmpeg = decode_videos_threaded(num_threads, "ffmpeg")[0] # First (and only) video + + with with_cache(enabled=True): + frames_ours_cached = decode_videos_threaded(num_threads, "ours")[0] + + with with_cache(enabled=False): + frames_ours_nocache = decode_videos_threaded(num_threads, "ours")[0] + + # Frames are already uint8, no conversion needed + print(f"Frame shapes: ffmpeg={frames_ffmpeg.shape}, cached={frames_ours_cached.shape}, nocache={frames_ours_nocache.shape}") + + # Save concatenated frames for visual comparison + num_frames = 
frames_ffmpeg.shape[0]
+        for i in range(min(5, num_frames)):  # Save first 5 frames
+            # Frames are already [N, C, H, W], so just select frame i
+            frame_ffmpeg = frames_ffmpeg[i]  # Shape: [C, H, W]
+            frame_cached = frames_ours_cached[i]
+            frame_nocache = frames_ours_nocache[i]
+
+            # Concatenate along width dimension (dim=2)
+            concat_frame = torch.cat([frame_ffmpeg, frame_cached, frame_nocache], dim=2)
+
+            output_path = results_dir / f"frame_{i:02d}_comparison.png"
+            torchvision.io.write_png(concat_frame, str(output_path))
+
+    finally:
+        video_files = original_files  # Restore original file list
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("video_folder", help="Folder containing .mp4 files")
+parser.add_argument("--num-threads", type=int, help="Number of threads")
+args = parser.parse_args()
+
+video_files = list(Path(args.video_folder).glob("*.mp4"))
+print(f"Decoding a few frames from {len(video_files)} video files in {args.video_folder} with {args.num_threads} threads")
+
+# validate_decode_correctness(video_files[0], num_threads=args.num_threads)
+
+print("=== Benchmarking FFmpeg backend ===")
+times = bench(decode_videos_threaded, args.num_threads, decoder_implem="ffmpeg", warmup=0, num_exp=10)
+report_stats(times)
+
+print("\n=== Benchmarking our backend WITH cache ===")
+with with_cache(enabled=True):
+    times = bench(decode_videos_threaded, args.num_threads, decoder_implem="ours", warmup=0, num_exp=10)
+    report_stats(times)
+
+print("\n=== Benchmarking our backend WITHOUT cache ===")
+with with_cache(enabled=False):
+    times = bench(decode_videos_threaded, args.num_threads, decoder_implem="ours", warmup=0, num_exp=10)
+    report_stats(times)
\ No newline at end of file
diff --git a/benchmark_nvdec_simple.py b/benchmark_nvdec_simple.py
new file mode 100644
index 00000000..bf24378b
--- /dev/null
+++ b/benchmark_nvdec_simple.py
@@ -0,0 +1,75 @@
+#!/usr/bin/env python3
+"""
+Simple multi-threaded NVDEC decoder cache benchmark.
+""" + +import argparse +from pathlib import Path +import torch +from time import perf_counter_ns +from torchcodec.decoders import VideoDecoder +from joblib import Parallel, delayed + + +def bench(f, *args, num_exp=100, warmup=0, **kwargs): + for _ in range(warmup): + f(*args, **kwargs) + + times = [] + for _ in range(num_exp): + start = perf_counter_ns() + f(*args, **kwargs) + end = perf_counter_ns() + times.append(end - start) + return torch.tensor(times).float() + + +def report_stats(times, unit="ms"): + mul = { + "ns": 1, + "µs": 1e-3, + "ms": 1e-6, + "s": 1e-9, + }[unit] + times = times * mul + std = times.std().item() + med = times.median().item() + print(f"{med = :.2f}{unit} +- {std:.2f}") + return med + + +def decode_videos(video_folder, num_threads, num_frames=10): + """Decode frames from all .h264 files using multiple threads.""" + video_files = list(Path(video_folder).glob("*.h264")) + + def decode_single_video(video_path): + decoder = VideoDecoder(str(video_path), device=torch.device("cuda:0"), device_variant="custom_nvdec") + for i in range(min(num_frames, len(decoder))): + frame = decoder.get_frame_at(i) + return video_path.name + + # Use joblib to run in parallel + Parallel(n_jobs=num_threads, backend="threading")( + delayed(decode_single_video)(video_path) for video_path in video_files + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("video_folder", help="Folder with .h264 files") + parser.add_argument("--num-threads", type=int, default=4, help="Number of threads") + parser.add_argument("--num-frames", type=int, default=10, help="Frames per video") + + args = parser.parse_args() + + video_files = list(Path(args.video_folder).glob("*.h264")) + print(f"Found {len(video_files)} .h264 files") + print(f"Using {args.num_threads} threads, {args.num_frames} frames per video") + print("Benchmarking...") + + times = bench(decode_videos, args.video_folder, args.num_threads, args.num_frames, warmup=0, num_exp=10) + report_stats(times, unit="ms") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index 1e6d2ec8..c37a22ca 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -31,10 +31,10 @@ endif() if (WIN32) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4 ${TORCHCODEC_WERROR_OPTION} ${TORCH_CXX_FLAGS}") else() - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic ${TORCHCODEC_WERROR_OPTION} ${TORCH_CXX_FLAGS}") + # set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic ${TORCHCODEC_WERROR_OPTION} ${TORCH_CXX_FLAGS}") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCHCODEC_WERROR_OPTION} ${TORCH_CXX_FLAGS}") endif() - function(make_torchcodec_sublibrary library_name type @@ -109,7 +109,7 @@ function(make_torchcodec_libraries ) if(ENABLE_CUDA) - list(APPEND core_sources CudaDeviceInterface.cpp) + list(APPEND core_sources CudaDeviceInterface.cpp CustomNvdecDeviceInterface.cpp) endif() set(core_library_dependencies @@ -122,6 +122,86 @@ function(make_torchcodec_libraries ${CUDA_nppi_LIBRARY} ${CUDA_nppicc_LIBRARY} ) + + # Find NVIDIA Video Codec SDK library + find_library(CUDA_NVCUVID_LIBRARY + NAMES nvcuvid + PATHS + $ENV{NVCODEC_SDK_ROOT}/Lib/linux/stubs/x86_64 + $ENV{NVCODEC_SDK_ROOT}/Lib/linux/x86_64 + $ENV{NVCODEC_SDK_ROOT}/lib64 + $ENV{NVCODEC_SDK_ROOT}/lib + /home/nicolashug/Downloads/Video_Codec_SDK_12.2.72/Lib/linux/stubs/x86_64 + /usr/local/cuda/lib64 + /usr/local/cuda/lib + 
${CUDA_TOOLKIT_ROOT_DIR}/lib64 + ${CUDA_TOOLKIT_ROOT_DIR}/lib + /opt/cuda/lib64 + /opt/cuda/lib + $ENV{CUDA_PATH}/lib64 + $ENV{CUDA_PATH}/lib + $ENV{CUDA_HOME}/lib64 + $ENV{CUDA_HOME}/lib + /usr/lib64 + /usr/lib + ) + set(CUDA_NVCUVID_LIBRARY "/home/nicolashug/Downloads/Video_Codec_SDK_12.2.72/Lib/linux/stubs/x86_64/libnvcuvid.so") + + # Find NVIDIA Video Codec SDK headers + find_path(NVCODEC_INCLUDE_DIR + NAMES cuviddec.h nvcuvid.h + PATHS + $ENV{NVCODEC_SDK_ROOT}/Interface + $ENV{NVCODEC_SDK_ROOT}/include + $ENV{HOME}/Downloads/Video_Codec_SDK_12.2.72/Interface + /usr/local/cuda/include + ${CUDA_TOOLKIT_ROOT_DIR}/include + $ENV{CUDA_PATH}/include + $ENV{CUDA_HOME}/include + /opt/cuda/include + PATH_SUFFIXES + Interface + Video_Codec_SDK_12.2.72/Interface + Video_Codec_SDK/Interface + ) + + if(NOT CUDA_NVCUVID_LIBRARY) + message(FATAL_ERROR "Cannot find libnvcuvid, you may need to manually register and download at https://developer.nvidia.com/nvidia-video-codec-sdk. Then copy libnvcuvid to cuda_toolkit_root/lib64/") + endif() + + if(NOT NVCODEC_INCLUDE_DIR) + message(FATAL_ERROR "Cannot find NVIDIA Video Codec SDK headers (cuviddec.h, nvcuvid.h). Please download the NVIDIA Video Codec SDK from https://developer.nvidia.com/nvidia-video-codec-sdk and copy the headers to your CUDA include directory.") + endif() + + message(STATUS "Found NVIDIA Video Codec SDK library: ${CUDA_NVCUVID_LIBRARY}") + message(STATUS "Found NVIDIA Video Codec SDK headers: ${NVCODEC_INCLUDE_DIR}") + + # Add CUDA Driver API library (needed for cuCtxGetCurrent, etc.) + find_library(CUDA_DRIVER_LIBRARY + NAMES cuda + PATHS + /usr/local/cuda/lib64 + /usr/local/cuda/lib + ${CUDA_TOOLKIT_ROOT_DIR}/lib64 + ${CUDA_TOOLKIT_ROOT_DIR}/lib + /opt/cuda/lib64 + /opt/cuda/lib + $ENV{CUDA_PATH}/lib64 + $ENV{CUDA_PATH}/lib + $ENV{CUDA_HOME}/lib64 + $ENV{CUDA_HOME}/lib + /usr/lib64 + /usr/lib + ) + + if(NOT CUDA_DRIVER_LIBRARY) + message(FATAL_ERROR "Cannot find CUDA Driver API library (libcuda.so)") + endif() + + message(STATUS "Found CUDA Driver API library: ${CUDA_DRIVER_LIBRARY}") + + # Add nvcuvid and cuda driver libraries to dependencies + list(APPEND core_library_dependencies ${CUDA_NVCUVID_LIBRARY} ${CUDA_DRIVER_LIBRARY}) endif() make_torchcodec_sublibrary( @@ -131,6 +211,15 @@ function(make_torchcodec_libraries "${core_library_dependencies}" ) + # Add NVDEC include directories after target creation + if(ENABLE_CUDA AND NVCODEC_INCLUDE_DIR) + target_include_directories(${core_library_name} + PRIVATE + ${NVCODEC_INCLUDE_DIR} + ) + endif() + + # 2. Create libtorchcodec_custom_opsN.{ext}. 
set(custom_ops_library_name "libtorchcodec_custom_ops${ffmpeg_major_version}") set(custom_ops_sources diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 9bfea4e5..cfe94491 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -15,7 +15,7 @@ extern "C" { namespace facebook::torchcodec { namespace { -static bool g_cuda = +static bool g_cuda_default = registerDeviceInterface(torch::kCUDA, [](const torch::Device& device) { return new CudaDeviceInterface(device); }); @@ -203,7 +203,7 @@ NppStreamContext createNppStreamContext(int deviceIndex) { CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device) : DeviceInterface(device) { - TORCH_CHECK(g_cuda, "CudaDeviceInterface was not registered!"); + TORCH_CHECK(g_cuda_default, "CudaDeviceInterface was not registered!"); TORCH_CHECK( device_.type() == torch::kCUDA, "Unsupported device: ", device_.str()); } @@ -258,29 +258,29 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( return; } - // Above we checked that the AVFrame was on GPU, but that's not enough, we - // also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits), - // because this is what the NPP color conversion routines expect. - // TODO: we should investigate how to can perform color conversion for - // non-8bit videos. This is supported on CPU. - TORCH_CHECK( - avFrame->hw_frames_ctx != nullptr, - "The AVFrame does not have a hw_frames_ctx. " - "That's unexpected, please report this to the TorchCodec repo."); - - auto hwFramesCtx = - reinterpret_cast(avFrame->hw_frames_ctx->data); - AVPixelFormat actualFormat = hwFramesCtx->sw_format; - TORCH_CHECK( - actualFormat == AV_PIX_FMT_NV12, - "The AVFrame is ", - (av_get_pix_fmt_name(actualFormat) ? av_get_pix_fmt_name(actualFormat) - : "unknown"), - ", but we expected AV_PIX_FMT_NV12. This typically happens when " - "the video isn't 8bit, which is not supported on CUDA at the moment. " - "Try using the CPU device instead. " - "If the video is 10bit, we are tracking 10bit support in " - "https://github.com/pytorch/torchcodec/issues/776"); + // // Above we checked that the AVFrame was on GPU, but that's not enough, we + // // also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits), + // // because this is what the NPP color conversion routines expect. + // // TODO: we should investigate how to can perform color conversion for + // // non-8bit videos. This is supported on CPU. + // TORCH_CHECK( + // avFrame->hw_frames_ctx != nullptr, + // "The AVFrame does not have a hw_frames_ctx. " + // "That's unexpected, please report this to the TorchCodec repo."); + + // auto hwFramesCtx = + // reinterpret_cast(avFrame->hw_frames_ctx->data); + // AVPixelFormat actualFormat = hwFramesCtx->sw_format; + // TORCH_CHECK( + // actualFormat == AV_PIX_FMT_NV12, + // "The AVFrame is ", + // (av_get_pix_fmt_name(actualFormat) ? av_get_pix_fmt_name(actualFormat) + // : "unknown"), + // ", but we expected AV_PIX_FMT_NV12. This typically happens when " + // "the video isn't 8bit, which is not supported on CUDA at the moment. " + // "Try using the CPU device instead. 
" + // "If the video is 10bit, we are tracking 10bit support in " + // "https://github.com/pytorch/torchcodec/issues/776"); auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame); diff --git a/src/torchcodec/_core/CustomNvdecDeviceInterface.cpp b/src/torchcodec/_core/CustomNvdecDeviceInterface.cpp new file mode 100644 index 00000000..2d8ff963 --- /dev/null +++ b/src/torchcodec/_core/CustomNvdecDeviceInterface.cpp @@ -0,0 +1,553 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include +#include + +#include "src/torchcodec/_core/CustomNvdecDeviceInterface.h" +#include "src/torchcodec/_core/DeviceInterface.h" +#include "src/torchcodec/_core/FFMPEGCommon.h" + +// Include NVIDIA Video Codec SDK headers +#include +#include + +extern "C" { +#include +#include +} + +namespace facebook::torchcodec { + +namespace { + +// Check if caching is disabled via environment variable +bool isCacheDisabled() { + const char* envVar = std::getenv("TORCHCODEC_DISABLE_NVDEC_CACHE"); + return envVar && std::string(envVar) == "1"; +} + +// Simple decoder cache for reusing CUvideodecoder objects +struct CachedDecoder { + CUvideodecoder decoder = nullptr; + CUcontext context = nullptr; + CUVIDDECODECREATEINFO createInfo = {}; + + // For LRU eviction + std::chrono::steady_clock::time_point lastUsed; +}; + +class DecoderCache { + public: + static constexpr size_t MAX_CACHE_SIZE = 10; + + static CUvideodecoder tryGetCachedDecoder( + CUcontext currentContext, + const CUVIDDECODECREATEINFO& requestedCreateInfo) { + std::lock_guard lock(mutex_); + + if (isCacheDisabled()) { + // If caching is disabled, clear any existing cache and return null + for (auto& entry : cached_) { + if (entry.decoder) { + cuvidDestroyDecoder(entry.decoder); + entry.decoder = nullptr; + } + } + cached_.clear(); + return nullptr; + } + + // Find matching decoder + for (auto& entry : cached_) { + // TODO probably need additional or different cache key checks + if (entry.decoder && + entry.context == currentContext && + isCreateInfoCompatible(entry.createInfo, requestedCreateInfo)) { + auto decoder = entry.decoder; + entry.decoder = nullptr; // Remove from cache + entry.lastUsed = std::chrono::steady_clock::now(); + return decoder; + } + } + return nullptr; + } + + static void cacheDecoder( + CUvideodecoder decoder, + CUcontext context, + const CUVIDDECODECREATEINFO& createInfo) { + if (isCacheDisabled()) { + // If caching is disabled, destroy the decoder immediately + cuvidDestroyDecoder(decoder); + return; + } + + std::lock_guard lock(mutex_); + + // Find empty slot or oldest entry + auto now = std::chrono::steady_clock::now(); + size_t targetIndex = 0; + bool foundEmpty = false; + + for (size_t i = 0; i < cached_.size(); ++i) { + if (!cached_[i].decoder) { + targetIndex = i; + foundEmpty = true; + break; + } + if (cached_[i].lastUsed < cached_[targetIndex].lastUsed) { + targetIndex = i; + } + } + + // If cache is full and no empty slot, destroy the oldest + if (!foundEmpty && targetIndex < cached_.size()) { + if (cached_[targetIndex].decoder) { + cuvidDestroyDecoder(cached_[targetIndex].decoder); + } + } + + // Add to cache (ensure we don't exceed MAX_CACHE_SIZE) + if (cached_.size() < MAX_CACHE_SIZE && targetIndex >= cached_.size()) { + cached_.resize(targetIndex + 1); + } else if (cached_.size() 
>= MAX_CACHE_SIZE && targetIndex >= MAX_CACHE_SIZE) { + targetIndex = 0; // Use first slot if we somehow got here + } + + cached_[targetIndex] = {decoder, context, createInfo, now}; + } + + // Clean up cache on program exit + static void clearCache() { + std::lock_guard lock(mutex_); + for (auto& entry : cached_) { + if (entry.decoder) { + cuvidDestroyDecoder(entry.decoder); + entry.decoder = nullptr; + } + } + cached_.clear(); + } + + private: + static bool isCreateInfoCompatible( + const CUVIDDECODECREATEINFO& cached, + const CUVIDDECODECREATEINFO& requested) { + return cached.CodecType == requested.CodecType && + cached.ulWidth == requested.ulWidth && + cached.ulHeight == requested.ulHeight && + cached.ChromaFormat == requested.ChromaFormat && + cached.OutputFormat == requested.OutputFormat && + cached.bitDepthMinus8 == requested.bitDepthMinus8; + } + + static std::mutex mutex_; + static std::vector cached_; +}; + +std::mutex DecoderCache::mutex_; +std::vector DecoderCache::cached_; + +// Global cleanup helper +struct CacheCleanup { + ~CacheCleanup() { + DecoderCache::clearCache(); + } +}; +static CacheCleanup g_cache_cleanup; + +// Register the custom NVDEC device interface with 'custom_nvdec' variant +static bool g_cuda_custom_nvdec = registerDeviceInterface( + DeviceInterfaceKey(torch::kCUDA, "custom_nvdec"), + [](const torch::Device& device) { + return new CustomNvdecDeviceInterface(device); + }); + +// NVDEC callback functions +static int CUDAAPI +HandleVideoSequence(void* pUserData, CUVIDEOFORMAT* pVideoFormat) { + // printf("Static HandleVideoSequence called\n"); + CustomNvdecDeviceInterface* decoder = + static_cast(pUserData); + return decoder->handleVideoSequence(pVideoFormat); +} + +static int CUDAAPI +HandlePictureDecode(void* pUserData, CUVIDPICPARAMS* pPicParams) { + // printf("Static HandlePictureDecode called\n"); + CustomNvdecDeviceInterface* decoder = + static_cast(pUserData); + return decoder->handlePictureDecode(pPicParams); +} + +static int CUDAAPI +HandlePictureDisplay(void* pUserData, CUVIDPARSERDISPINFO* pDispInfo) { + // printf("Static HandlePictureDisplay called\n"); + CustomNvdecDeviceInterface* decoder = + static_cast(pUserData); + return decoder->handlePictureDisplay(pDispInfo); +} + +} // namespace + +CustomNvdecDeviceInterface::CustomNvdecDeviceInterface( + const torch::Device& device) + : DeviceInterface(device) { + TORCH_CHECK( + g_cuda_custom_nvdec, "CustomNvdecDeviceInterface was not registered!"); + TORCH_CHECK( + device_.type() == torch::kCUDA, "Unsupported device: ", device_.str()); +} + +CustomNvdecDeviceInterface::~CustomNvdecDeviceInterface() { + // Clean up any remaining frames in the queue + { + std::lock_guard lock(frameQueueMutex_); + while (!frameQueue_.empty()) { + FrameData frameData = frameQueue_.front(); + frameQueue_.pop(); + + // Unmap the frame if it's still mapped + if (decoder_ && frameData.framePtr != 0) { + cuvidUnmapVideoFrame(decoder_, frameData.framePtr); + } + } + } + + // Cache decoder instead of destroying it + if (decoder_ && context_) { + DecoderCache::cacheDecoder(decoder_, context_, createInfo_); + decoder_ = nullptr; + } + + // Clean up video parser + if (videoParser_) { + cuvidDestroyVideoParser(videoParser_); + videoParser_ = nullptr; + } + + isInitialized_ = false; + parserInitialized_ = false; +} + +std::optional CustomNvdecDeviceInterface::findCodec( + const AVCodecID& codecId) { + // For custom NVDEC, we bypass FFmpeg codec selection entirely + // We'll handle the codec selection in our own NVDEC initialization + 
(void)codecId; // Suppress unused parameter warning + return std::nullopt; +} + +void CustomNvdecDeviceInterface::initializeContext( + AVCodecContext* codecContext) { + // Don't set hw_device_ctx - we handle decoding directly with NVDEC SDK + // Just ensure CUDA context exists for PyTorch tensors + torch::Tensor dummyTensor = torch::empty( + {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_)); + + // Initialize our custom NVDEC decoder + initializeNvdecDecoder(codecContext->codec_id); + + // Initialize video parser with the codec ID and extradata + initializeVideoParser(codecContext->codec_id, codecContext->extradata, codecContext->extradata_size); +} + +void CustomNvdecDeviceInterface::initializeNvdecDecoder(AVCodecID codecId) { + if (isInitialized_) { + return; // Already initialized + } + + // Store the codec ID for later use + currentCodecId_ = codecId; + + // Convert AVCodecID to NVDEC codec type + cudaVideoCodec nvCodec; + switch (codecId) { + case AV_CODEC_ID_H264: + nvCodec = cudaVideoCodec_H264; + break; + case AV_CODEC_ID_HEVC: + nvCodec = cudaVideoCodec_HEVC; + break; + case AV_CODEC_ID_AV1: + nvCodec = cudaVideoCodec_AV1; + break; + case AV_CODEC_ID_VP8: + nvCodec = cudaVideoCodec_VP8; + break; + case AV_CODEC_ID_VP9: + nvCodec = cudaVideoCodec_VP9; + break; + default: + TORCH_CHECK( + false, + "Unsupported codec for custom NVDEC: ", + avcodec_get_name(codecId)); + } + + // Initialize video format structure (decoder will be created in + // handleVideoSequence) + memset(&videoFormat_, 0, sizeof(videoFormat_)); + videoFormat_.codec = nvCodec; + videoFormat_.coded_width = 0; // Will be set when we get the first frame + videoFormat_.coded_height = 0; // Will be set when we get the first frame + videoFormat_.chroma_format = cudaVideoChromaFormat_420; + videoFormat_.bit_depth_luma_minus8 = 0; + videoFormat_.bit_depth_chroma_minus8 = 0; + + isInitialized_ = true; +} + +void CustomNvdecDeviceInterface::initializeVideoParser(AVCodecID codecId, uint8_t* extradata, int extradata_size) { + if (parserInitialized_) { + return; + } + + // printf("Initializing NVDEC video parser for codec\n"); + + // Set up video parser parameters + CUVIDPARSERPARAMS parserParams = {}; + parserParams.CodecType = videoFormat_.codec; + parserParams.ulMaxNumDecodeSurfaces = 1; + parserParams.ulClockRate = 1000; + parserParams.ulErrorThreshold = 0; + parserParams.ulMaxDisplayDelay = 1; + parserParams.pUserData = this; + parserParams.pfnSequenceCallback = HandleVideoSequence; + parserParams.pfnDecodePicture = HandlePictureDecode; + parserParams.pfnDisplayPicture = HandlePictureDisplay; + + // printf("Parser params: pUserData=%p, pfnSequenceCallback=%p, pfnDecodePicture=%p, pfnDisplayPicture=%p\n", + // parserParams.pUserData, (void*)parserParams.pfnSequenceCallback, + // (void*)parserParams.pfnDecodePicture, (void*)parserParams.pfnDisplayPicture); + + CUresult result = cuvidCreateVideoParser(&videoParser_, &parserParams); + TORCH_CHECK( + result == CUDA_SUCCESS, "Failed to create video parser: ", result); + + parserInitialized_ = true; +} + +int CustomNvdecDeviceInterface::handleVideoSequence( + CUVIDEOFORMAT* pVideoFormat) { + // printf("In CustomNvdecDeviceInterface::handleVideoSequence\n"); + TORCH_CHECK(pVideoFormat != nullptr, "Invalid video format"); + + // Store video format + videoFormat_ = *pVideoFormat; + + // Ensure we have a CUDA context - create a tensor to force PyTorch to set it up + torch::Tensor contextTensor = torch::empty( + {1}, 
torch::TensorOptions().dtype(torch::kUInt8).device(device_)); + + // Force CUDA operations to establish context properly + contextTensor = contextTensor + 1; // TODO ugh? Same with tensor creation above. + c10::cuda::device_synchronize(); // TODO is this needed? multi-threading was buggy without it + + // Now get the current CUDA context + CUresult cuResult = cuCtxGetCurrent(&context_); + if (cuResult != CUDA_SUCCESS || context_ == nullptr) { + TORCH_CHECK(false, "Failed to get CUDA context after tensor operation"); + } + + // Create decoder with the video format + createInfo_ = {}; + createInfo_.CodecType = pVideoFormat->codec; + createInfo_.ulWidth = pVideoFormat->coded_width; + createInfo_.ulHeight = pVideoFormat->coded_height; + createInfo_.ulNumDecodeSurfaces = 4; + createInfo_.ChromaFormat = pVideoFormat->chroma_format; + createInfo_.OutputFormat = cudaVideoSurfaceFormat_NV12; + createInfo_.bitDepthMinus8 = pVideoFormat->bit_depth_luma_minus8; + createInfo_.ulTargetWidth = pVideoFormat->coded_width; + createInfo_.ulTargetHeight = pVideoFormat->coded_height; + createInfo_.ulNumOutputSurfaces = 2; + createInfo_.ulCreationFlags = cudaVideoCreate_PreferCUVID; + createInfo_.vidLock = nullptr; + + // Try to get a cached decoder first + decoder_ = DecoderCache::tryGetCachedDecoder(context_, createInfo_); + + if (!decoder_) { + // No suitable cached decoder found, create a new one + CUresult result = cuvidCreateDecoder(&decoder_, &createInfo_); + if (result != CUDA_SUCCESS) { + TORCH_CHECK(false, "Failed to create NVDEC decoder: ", result); + } + } + + return 1; // Success +} + +int CustomNvdecDeviceInterface::handlePictureDecode( + CUVIDPICPARAMS* pPicParams) { + TORCH_CHECK(pPicParams != nullptr, "Invalid picture parameters"); + // printf("In CustomNvdecDeviceInterface::handlePictureDecode\n"); + + if (!decoder_) { + return 0; // No decoder available + } + + CUresult result = cuvidDecodePicture(decoder_, pPicParams); + return (result == CUDA_SUCCESS) ? 
1 : 0; +} + +int CustomNvdecDeviceInterface::handlePictureDisplay( + CUVIDPARSERDISPINFO* pDispInfo) { + TORCH_CHECK(pDispInfo != nullptr, "Invalid display info"); + + // Queue the frame for later retrieval + std::lock_guard lock(frameQueueMutex_); + + // Map the decoded frame + CUdeviceptr framePtr = 0; + unsigned int pitch = 0; + CUVIDPROCPARAMS procParams = {}; + procParams.progressive_frame = pDispInfo->progressive_frame; + procParams.top_field_first = pDispInfo->top_field_first; + procParams.unpaired_field = pDispInfo->repeat_first_field < 0; + + CUresult result = cuvidMapVideoFrame( + decoder_, + pDispInfo->picture_index, + &framePtr, + &pitch, + &procParams); + if (result == CUDA_SUCCESS) { + FrameData frameData = {framePtr, pitch, *pDispInfo}; + frameQueue_.push(frameData); + } + + return 1; +} + +UniqueAVFrame CustomNvdecDeviceInterface::decodePacketDirectly( + ReferenceAVPacket& packet) { + TORCH_CHECK(isInitialized_, "NVDEC decoder not initialized"); + + // Extract compressed data from AVPacket + uint8_t* compressedData = packet->data; + int size = packet->size; + int64_t pts = packet->pts; + + TORCH_CHECK(compressedData != nullptr && size > 0, "Invalid packet data"); + + // Video parser should already be initialized from initializeContext + TORCH_CHECK(parserInitialized_, "Video parser not initialized"); + + // Parse the packet data (now already in Annex B format from bitstream filter) + // printf("About to parse packet: size=%d, pts=%lld\n", size, pts); + // printf("First 8 bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", + // compressedData[0], compressedData[1], compressedData[2], compressedData[3], + // compressedData[4], compressedData[5], compressedData[6], compressedData[7]); + + CUVIDSOURCEDATAPACKET cudaPacket = {0}; // Initialize all fields to 0 + cudaPacket.payload = compressedData; + cudaPacket.payload_size = size; + cudaPacket.flags = CUVID_PKT_TIMESTAMP; + cudaPacket.timestamp = pts; + + CUresult result = cuvidParseVideoData(videoParser_, &cudaPacket); + // printf("Parse result: %d\n", result); + TORCH_CHECK(result == CUDA_SUCCESS, "Failed to parse video data: ", result); + + // Check if we have any decoded frames available + std::lock_guard lock(frameQueueMutex_); + if (frameQueue_.empty()) { + // No frame ready yet (async decoding) + return UniqueAVFrame(nullptr); + } + + // Get the first available frame + FrameData frameData = frameQueue_.front(); + frameQueue_.pop(); + + // Convert the NVDEC frame to AVFrame + UniqueAVFrame avFrame = convertCudaFrameToAVFrame(frameData.framePtr, frameData.pitch, frameData.dispInfo); + + // Unmap the frame + cuvidUnmapVideoFrame(decoder_, frameData.framePtr); + + return avFrame; +} + + +UniqueAVFrame CustomNvdecDeviceInterface::convertCudaFrameToAVFrame( + CUdeviceptr framePtr, + unsigned int pitch, + const CUVIDPARSERDISPINFO& dispInfo) { + TORCH_CHECK(framePtr != 0, "Invalid CUDA frame pointer"); + + // Get frame dimensions from video format + int width = videoFormat_.coded_width; + int height = videoFormat_.coded_height; + + TORCH_CHECK(width > 0 && height > 0, "Invalid frame dimensions"); + TORCH_CHECK(pitch >= width, "Pitch must be >= width"); + + // printf("Frame conversion: width=%d, height=%d, pitch=%u\n", width, height, pitch); + + // Allocate AVFrame + UniqueAVFrame avFrame(av_frame_alloc()); + TORCH_CHECK(avFrame.get() != nullptr, "Failed to allocate AVFrame"); + + // Set frame properties + avFrame->width = width; + avFrame->height = height; + avFrame->format = AV_PIX_FMT_CUDA; // Indicate this is GPU data + 
avFrame->pts = dispInfo.timestamp; + avFrame->duration = 0; // Will be set by caller if needed + + // For NVDEC output in NV12 format, we need to set up the data pointers + // The framePtr points to the beginning of the NV12 data + avFrame->data[0] = reinterpret_cast(framePtr); // Y plane + avFrame->data[1] = reinterpret_cast(framePtr + (pitch * height)); // UV plane (using pitch, not width) + avFrame->data[2] = nullptr; + avFrame->data[3] = nullptr; + + // Set line sizes for NV12 format using the actual NVDEC pitch + avFrame->linesize[0] = pitch; // Y plane stride (use actual pitch from NVDEC) + avFrame->linesize[1] = pitch; // UV plane stride (use actual pitch from NVDEC) + avFrame->linesize[2] = 0; + avFrame->linesize[3] = 0; + + return avFrame; +} + +void CustomNvdecDeviceInterface::convertAVFrameToFrameOutput( + const VideoStreamOptions& videoStreamOptions, + const AVRational& timeBase, + UniqueAVFrame& avFrame, + FrameOutput& frameOutput, + std::optional preAllocatedOutputTensor) { + // For custom NVDEC, the frame should already be on GPU + // We need to convert from NVDEC's output format (typically NV12) to RGB + + TORCH_CHECK( + avFrame->format == AV_PIX_FMT_CUDA, + "Expected CUDA format frame from custom NVDEC decoder"); + + auto cpuDevice = torch::Device(torch::kCUDA); + auto cpuInterface = createDeviceInterface(cpuDevice); + + FrameOutput cpuFrameOutput; + cpuInterface->convertAVFrameToFrameOutput( + videoStreamOptions, + timeBase, + avFrame, + cpuFrameOutput, + preAllocatedOutputTensor); + + frameOutput.data = cpuFrameOutput.data.to(device_); +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CustomNvdecDeviceInterface.h b/src/torchcodec/_core/CustomNvdecDeviceInterface.h new file mode 100644 index 00000000..201a3c79 --- /dev/null +++ b/src/torchcodec/_core/CustomNvdecDeviceInterface.h @@ -0,0 +1,96 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#pragma once + +#include "src/torchcodec/_core/DeviceInterface.h" + +#include +#include +#include +#include + +// Include NVIDIA Video Codec SDK headers +#include +#include + +namespace facebook::torchcodec { + + +// Custom NVDEC device interface that provides direct control over NVDEC +// while keeping FFmpeg for demuxing +class CustomNvdecDeviceInterface : public DeviceInterface { + public: + CustomNvdecDeviceInterface(const torch::Device& device); + + virtual ~CustomNvdecDeviceInterface(); + + std::optional findCodec(const AVCodecID& codecId) override; + + void initializeContext(AVCodecContext* codecContext) override; + + void convertAVFrameToFrameOutput( + const VideoStreamOptions& videoStreamOptions, + const AVRational& timeBase, + UniqueAVFrame& avFrame, + FrameOutput& frameOutput, + std::optional preAllocatedOutputTensor = + std::nullopt) override; + + // Extension point overrides for direct packet decoding + bool canDecodePacketDirectly() const override { + return true; + } + + UniqueAVFrame decodePacketDirectly(ReferenceAVPacket& packet) override; + + // Legacy method name - kept for compatibility + UniqueAVFrame decodePacket(ReferenceAVPacket& packet) { + return decodePacketDirectly(packet); + } + + public: + // NVDEC callback functions (must be public for C callbacks) + int handleVideoSequence(CUVIDEOFORMAT* pVideoFormat); + int handlePictureDecode(CUVIDPICPARAMS* pPicParams); + int handlePictureDisplay(CUVIDPARSERDISPINFO* pDispInfo); + + private: + // NVDEC decoder context and parser + CUvideoparser videoParser_ = nullptr; + CUvideodecoder decoder_ = nullptr; + CUcontext context_ = nullptr; + + // Video format info + CUVIDEOFORMAT videoFormat_; + CUVIDDECODECREATEINFO createInfo_; + AVCodecID currentCodecId_ = AV_CODEC_ID_NONE; + bool isInitialized_ = false; + bool parserInitialized_ = false; + + // Frame queue for async decoding - stores frame pointer, pitch, and display info + struct FrameData { + CUdeviceptr framePtr; + unsigned int pitch; + CUVIDPARSERDISPINFO dispInfo; + }; + std::queue frameQueue_; + std::mutex frameQueueMutex_; + + // Custom context initialization for direct NVDEC usage + void initializeNvdecDecoder(AVCodecID codecId); + + // Initialize video parser + void initializeVideoParser(AVCodecID codecId, uint8_t* extradata, int extradata_size); + + // Convert CUDA frame pointer to AVFrame + UniqueAVFrame convertCudaFrameToAVFrame( + CUdeviceptr framePtr, + unsigned int pitch, + const CUVIDPARSERDISPINFO& dispInfo); +}; + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/DeviceInterface.cpp b/src/torchcodec/_core/DeviceInterface.cpp index 70b00fb6..20cf845a 100644 --- a/src/torchcodec/_core/DeviceInterface.cpp +++ b/src/torchcodec/_core/DeviceInterface.cpp @@ -11,7 +11,8 @@ namespace facebook::torchcodec { namespace { -using DeviceInterfaceMap = std::map; +using DeviceInterfaceMap = + std::map; static std::mutex g_interface_mutex; DeviceInterfaceMap& getDeviceMap() { @@ -29,51 +30,128 @@ std::string getDeviceType(const std::string& device) { } // namespace +// New registration function with variant support bool registerDeviceInterface( - torch::DeviceType deviceType, + const DeviceInterfaceKey& key, CreateDeviceInterfaceFn createInterface) { std::scoped_lock lock(g_interface_mutex); DeviceInterfaceMap& deviceMap = getDeviceMap(); TORCH_CHECK( - deviceMap.find(deviceType) == deviceMap.end(), - "Device interface already registered for ", - deviceType); - deviceMap.insert({deviceType, createInterface}); + deviceMap.find(key) == 
deviceMap.end(), + "Device interface already registered for device type ", + key.deviceType, + " variant '", + key.variant, + "'"); + deviceMap.insert({key, createInterface}); return true; } +// Backward-compatible registration function +bool registerDeviceInterface( + torch::DeviceType deviceType, + CreateDeviceInterfaceFn createInterface) { + return registerDeviceInterface( + DeviceInterfaceKey(deviceType), createInterface); +} + torch::Device createTorchDevice(const std::string device) { std::scoped_lock lock(g_interface_mutex); - std::string deviceType = getDeviceType(device); + + // Parse device string: "device_type:index:variant" or "device_type:index" or + // "device_type" + std::string deviceType; + std::string variant = "default"; + std::string torchDeviceString = + device; // What we'll pass to torch::Device constructor + + size_t firstColon = device.find(':'); + if (firstColon == std::string::npos) { + // Just device type (e.g., "cpu") + deviceType = device; + } else { + deviceType = device.substr(0, firstColon); + + // Check for second colon (variant) + size_t secondColon = device.find(':', firstColon + 1); + if (secondColon != std::string::npos) { + // Format: "device_type:index:variant" + variant = device.substr(secondColon + 1); + torchDeviceString = device.substr(0, secondColon); // Remove variant part + } + // else: Format: "device_type:index" (no variant) + } + DeviceInterfaceMap& deviceMap = getDeviceMap(); + // Find device interface that matches device type and variant + torch::DeviceType deviceTypeEnum = torch::Device(deviceType).type(); + auto deviceInterface = std::find_if( deviceMap.begin(), deviceMap.end(), - [&](const std::pair& arg) { - return device.rfind( - torch::DeviceTypeName(arg.first, /*lcase*/ true), 0) == 0; + [&](const std::pair& arg) { + return arg.first.deviceType == deviceTypeEnum && + arg.first.variant == variant; }); - TORCH_CHECK( - deviceInterface != deviceMap.end(), "Unsupported device: ", device); - return torch::Device(device); + // If variant-specific interface not found, try default variant + if (deviceInterface == deviceMap.end() && variant != "default") { + deviceInterface = std::find_if( + deviceMap.begin(), + deviceMap.end(), + [&](const std::pair& arg) { + return arg.first.deviceType == deviceTypeEnum && + arg.first.variant == "default"; + }); + } + + TORCH_CHECK( + deviceInterface != deviceMap.end(), + "Unsupported device: ", + device, + " (device type: ", + deviceType, + ", variant: ", + variant, + ")"); + + // Return torch::Device with just device type and index (no variant) + return torch::Device(torchDeviceString); } +// Creation function with variant support (default = "default" for backward +// compatibility) std::unique_ptr createDeviceInterface( - const torch::Device& device) { - auto deviceType = device.type(); + const torch::Device& device, + const std::string& variant) { + DeviceInterfaceKey key(device.type(), variant); std::scoped_lock lock(g_interface_mutex); DeviceInterfaceMap& deviceMap = getDeviceMap(); - TORCH_CHECK( - deviceMap.find(deviceType) != deviceMap.end(), - "Unsupported device: ", - device); + auto it = deviceMap.find(key); + if (it != deviceMap.end()) { + return std::unique_ptr(it->second(device)); + } - return std::unique_ptr(deviceMap[deviceType](device)); + // Fallback to default variant if specific variant not found + if (variant != "default") { + key.variant = "default"; + it = deviceMap.find(key); + if (it != deviceMap.end()) { + return std::unique_ptr(it->second(device)); + } + } + + TORCH_CHECK( + 
false, + "No device interface found for device type: ", + device.type(), + " variant: '", + variant, + "'"); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index 11a73b65..3c7ca9ae 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -17,6 +17,25 @@ namespace facebook::torchcodec { +// Key for device interface registration with device type + variant support +struct DeviceInterfaceKey { + torch::DeviceType deviceType; + std::string variant = "default"; // e.g., "default", "custom_nvdec", etc. + + bool operator<(const DeviceInterfaceKey& other) const { + if (deviceType != other.deviceType) { + return deviceType < other.deviceType; + } + return variant < other.variant; + } + + // Convenience constructors + DeviceInterfaceKey(torch::DeviceType type) : deviceType(type) {} + + DeviceInterfaceKey(torch::DeviceType type, const std::string& var) + : deviceType(type), variant(var) {} +}; + // Note that all these device functions should only be called if the device is // not a CPU device. CPU device functions are already implemented in the // SingleStreamDecoder implementation. @@ -48,6 +67,22 @@ class DeviceInterface { FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt) = 0; + // Extension points for custom decoding paths + // Override to return true if this device interface can decode packets + // directly + virtual bool canDecodePacketDirectly() const { + return false; + } + + // Override to decode AVPacket directly (bypassing FFmpeg codec) + // Only called if canDecodePacketDirectly() returns true + virtual UniqueAVFrame decodePacketDirectly(ReferenceAVPacket& /* packet */) { + TORCH_CHECK( + false, + "Direct packet decoding not implemented for this device interface"); + return UniqueAVFrame(nullptr); + } + protected: torch::Device device_; }; @@ -55,13 +90,22 @@ class DeviceInterface { using CreateDeviceInterfaceFn = std::function; +// New registration function with variant support +bool registerDeviceInterface( + const DeviceInterfaceKey& key, + const CreateDeviceInterfaceFn createInterface); + +// Backward-compatible registration function bool registerDeviceInterface( torch::DeviceType deviceType, const CreateDeviceInterfaceFn createInterface); torch::Device createTorchDevice(const std::string device); +// Creation function with variant support (default = "default" for backward +// compatibility) std::unique_ptr createDeviceInterface( - const torch::Device& device); + const torch::Device& device, + const std::string& variant = "default"); } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h index 38a0c099..540f0e95 100644 --- a/src/torchcodec/_core/FFMPEGCommon.h +++ b/src/torchcodec/_core/FFMPEGCommon.h @@ -12,6 +12,7 @@ extern "C" { #include +#include #include #include #include @@ -76,6 +77,8 @@ using UniqueSwrContext = std::unique_ptr>; using UniqueAVAudioFifo = std:: unique_ptr>; +using UniqueAVBSFContext = std:: + unique_ptr>; // These 2 classes share the same underlying AVPacket object. 
They are meant to // be used in tandem, like so: diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 9cfc652a..a5d08819 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -12,6 +12,7 @@ #include #include #include +#include "src/torchcodec/_core/CustomNvdecDeviceInterface.h" #include "torch/types.h" namespace facebook::torchcodec { @@ -391,7 +392,8 @@ void SingleStreamDecoder::addStream( int streamIndex, AVMediaType mediaType, const torch::Device& device, - std::optional ffmpegThreadCount) { + std::optional ffmpegThreadCount, + const std::string& deviceVariant) { TORCH_CHECK( activeStreamIndex_ == NO_ACTIVE_STREAM, "Can only add one single stream."); @@ -419,7 +421,8 @@ void SingleStreamDecoder::addStream( streamInfo.stream = formatContext_->streams[activeStreamIndex_]; streamInfo.avMediaType = mediaType; - deviceInterface_ = createDeviceInterface(device); + deviceVariant_ = deviceVariant; + deviceInterface_ = createDeviceInterface(device, deviceVariant); // This should never happen, checking just to be safe. TORCH_CHECK( @@ -460,6 +463,43 @@ void SingleStreamDecoder::addStream( TORCH_CHECK(retVal >= AVSUCCESS, getFFMPEGErrorStringFromErrorCode(retVal)); codecContext->time_base = streamInfo.stream->time_base; + + // Initialize bitstream filter for H.264 MP4 containers + if (mediaType == AVMEDIA_TYPE_VIDEO && + codecContext->codec_id == AV_CODEC_ID_H264) { + + // Check if this is an MP4-style container that needs bitstream filtering + const char* formatName = formatContext_->iformat->long_name; + bool isMP4Container = (strcmp(formatName, "QuickTime / MOV") == 0 || + strcmp(formatName, "FLV (Flash Video)") == 0 || + strcmp(formatName, "Matroska / WebM") == 0); + + if (isMP4Container) { + // printf("Initializing H.264 MP4 to Annex B bitstream filter for %s container\n", formatName); + + const AVBitStreamFilter* bsf = av_bsf_get_by_name("h264_mp4toannexb"); + TORCH_CHECK(bsf != nullptr, "Failed to find h264_mp4toannexb bitstream filter"); + + AVBSFContext* rawBsf = nullptr; + retVal = av_bsf_alloc(bsf, &rawBsf); + TORCH_CHECK(retVal >= AVSUCCESS, "Failed to allocate bitstream filter: ", + getFFMPEGErrorStringFromErrorCode(retVal)); + + streamInfo.bitstreamFilter.reset(rawBsf); + + retVal = avcodec_parameters_copy(streamInfo.bitstreamFilter->par_in, + streamInfo.stream->codecpar); + TORCH_CHECK(retVal >= AVSUCCESS, "Failed to copy codec parameters: ", + getFFMPEGErrorStringFromErrorCode(retVal)); + + retVal = av_bsf_init(streamInfo.bitstreamFilter.get()); + TORCH_CHECK(retVal >= AVSUCCESS, "Failed to initialize bitstream filter: ", + getFFMPEGErrorStringFromErrorCode(retVal)); + + streamInfo.needsBitstreamFiltering = true; + // printf("Successfully initialized bitstream filter\n"); + } + } containerMetadata_.allStreamMetadata[activeStreamIndex_].codecName = std::string(avcodec_get_name(codecContext->codec_id)); @@ -469,7 +509,7 @@ void SingleStreamDecoder::addStream( // important to discard/demux correctly in the inner decoding loop. 
for (unsigned int i = 0; i < formatContext_->nb_streams; ++i) { if (i != static_cast(activeStreamIndex_)) { - formatContext_->streams[i]->discard = AVDISCARD_ALL; + // formatContext_->streams[i]->discard = AVDISCARD_ALL; } } } @@ -477,12 +517,14 @@ void SingleStreamDecoder::addStream( void SingleStreamDecoder::addVideoStream( int streamIndex, const VideoStreamOptions& videoStreamOptions, - std::optional customFrameMappings) { + std::optional customFrameMappings, + const std::string& deviceVariant) { addStream( streamIndex, AVMEDIA_TYPE_VIDEO, videoStreamOptions.device, - videoStreamOptions.ffmpegThreadCount); + videoStreamOptions.ffmpegThreadCount, + deviceVariant); auto& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; @@ -1200,15 +1242,52 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( continue; } - // We got a valid packet. Send it to the decoder, and we'll receive it in - // the next iteration. - status = avcodec_send_packet(streamInfo.codecContext.get(), packet.get()); - TORCH_CHECK( - status >= AVSUCCESS, - "Could not push packet to decoder: ", - getFFMPEGErrorStringFromErrorCode(status)); + // Check if device interface can handle packet decoding directly + if (deviceInterface_ && deviceInterface_->canDecodePacketDirectly()) { + ReferenceAVPacket* packetToSend = &packet; + AutoAVPacket filteredAutoPacket; + ReferenceAVPacket filteredPacket(filteredAutoPacket); + + // Apply bitstream filtering if needed + if (streamInfo.needsBitstreamFiltering && streamInfo.bitstreamFilter) { + // printf("Applying bitstream filter to packet\n"); + + // Send packet to bitstream filter + int retVal = av_bsf_send_packet(streamInfo.bitstreamFilter.get(), packet.get()); + TORCH_CHECK(retVal >= AVSUCCESS, "Failed to send packet to bitstream filter: ", + getFFMPEGErrorStringFromErrorCode(retVal)); + + // Receive filtered packet + retVal = av_bsf_receive_packet(streamInfo.bitstreamFilter.get(), filteredPacket.get()); + TORCH_CHECK(retVal >= AVSUCCESS, "Failed to receive packet from bitstream filter: ", + getFFMPEGErrorStringFromErrorCode(retVal)); + + packetToSend = &filteredPacket; + // printf("Bitstream filtering complete: original size=%d, filtered size=%d\n", + // packet->size, filteredPacket->size); + } + + // Use custom packet decoding (e.g., direct NVDEC) + UniqueAVFrame decodedFrame = + deviceInterface_->decodePacketDirectly(*packetToSend); + if (decodedFrame && filterFunction(decodedFrame)) { + // We got the frame we're looking for from direct decoding + avFrame = std::move(decodedFrame); + decodeStats_.numPacketsSentToDecoder++; + break; + } + // If custom decoding didn't produce the desired frame, continue the loop + decodeStats_.numPacketsSentToDecoder++; + } else { + // Use standard FFmpeg decoding path + status = avcodec_send_packet(streamInfo.codecContext.get(), packet.get()); + TORCH_CHECK( + status >= AVSUCCESS, + "Could not push packet to decoder: ", + getFFMPEGErrorStringFromErrorCode(status)); - decodeStats_.numPacketsSentToDecoder++; + decodeStats_.numPacketsSentToDecoder++; + } } if (status < AVSUCCESS) { diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 027f52fc..1d5bde92 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -84,7 +84,8 @@ class SingleStreamDecoder { void addVideoStream( int streamIndex, const VideoStreamOptions& videoStreamOptions = VideoStreamOptions(), - std::optional customFrameMappings = std::nullopt); + std::optional 
customFrameMappings = std::nullopt, + const std::string& deviceVariant = "default"); void addAudioStream( int streamIndex, const AudioStreamOptions& audioStreamOptions = AudioStreamOptions()); @@ -237,6 +238,10 @@ class SingleStreamDecoder { // color-conversion fields. Only one of FilterGraphContext and // UniqueSwsContext should be non-null. UniqueSwrContext swrContext; + + // Bitstream filter for MP4 to Annex B conversion + UniqueAVBSFContext bitstreamFilter; + bool needsBitstreamFiltering = false; }; // -------------------------------------------------------------------------- @@ -318,7 +323,8 @@ class SingleStreamDecoder { int streamIndex, AVMediaType mediaType, const torch::Device& device = torch::kCPU, - std::optional ffmpegThreadCount = std::nullopt); + std::optional ffmpegThreadCount = std::nullopt, + const std::string& deviceVariant = "default"); // Returns the "best" stream index for a given media type. The "best" is // determined by various heuristics in FFMPEG. @@ -352,6 +358,7 @@ class SingleStreamDecoder { ContainerMetadata containerMetadata_; UniqueDecodingAVFormatContext formatContext_; std::unique_ptr deviceInterface_; + std::string deviceVariant_ = "default"; std::map streamInfos_; const int NO_ACTIVE_STREAM = -2; int activeStreamIndex_ = NO_ACTIVE_STREAM; diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index daec2010..8901f45d 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -260,8 +260,19 @@ void _add_video_stream( ". color_conversion_library must be either filtergraph or swscale."); } } + std::string deviceVariant = "default"; if (device.has_value()) { - videoStreamOptions.device = createTorchDevice(std::string(device.value())); + std::string deviceStr = std::string(device.value()); + videoStreamOptions.device = createTorchDevice(deviceStr); + + // Extract variant from device string (format: "device_type:index:variant") + size_t firstColon = deviceStr.find(':'); + if (firstColon != std::string::npos) { + size_t secondColon = deviceStr.find(':', firstColon + 1); + if (secondColon != std::string::npos) { + deviceVariant = deviceStr.substr(secondColon + 1); + } + } } std::optional converted_mappings = custom_frame_mappings.has_value() @@ -269,7 +280,10 @@ void _add_video_stream( : std::nullopt; auto videoDecoder = unwrapTensorToGetDecoder(decoder); videoDecoder->addVideoStream( - stream_index.value_or(-1), videoStreamOptions, converted_mappings); + stream_index.value_or(-1), + videoStreamOptions, + converted_mappings, + deviceVariant); } // Add a new video stream at `stream_index` using the provided options. diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index bb252fd6..57501419 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -53,6 +53,9 @@ class VideoDecoder: Passing 0 lets FFmpeg decide on the number of threads. Default: 1. device (str or torch.device, optional): The device to use for decoding. Default: "cpu". + device_variant (str, optional): The device interface variant to use. For CUDA devices, + specify "custom_nvdec" to use the custom NVDEC implementation instead of the default + FFmpeg NVDEC decoder. Default: None (uses default device interface). seek_mode (str, optional): Determines if frame access will be "exact" or "approximate". 
            Exact guarantees that requesting frame i will always return frame i, but doing so requires an initial :term:`scan` of the
@@ -78,6 +81,7 @@ def __init__(
         dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
         num_ffmpeg_threads: int = 1,
         device: Optional[Union[str, torch_device]] = "cpu",
+        device_variant: Optional[str] = None,
         seek_mode: Literal["exact", "approximate"] = "exact",
     ):
         allowed_seek_modes = ("exact", "approximate")
@@ -102,6 +106,15 @@ def __init__(
         if isinstance(device, torch_device):
             device = str(device)
 
+        # Handle device variant by extending device string
+        if device_variant is not None:
+            if device == "cpu":
+                # For CPU, variant format is "cpu:variant"
+                device = f"cpu:{device_variant}"
+            else:
+                # For other devices (e.g., "cuda:0"), append variant
+                device = f"{device}:{device_variant}"
+
         core.add_video_stream(
             self._decoder,
             stream_index=stream_index,
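
Example usage of the new device_variant argument and the NVDEC cache toggle, following the benchmark scripts above. This is a minimal sketch, not part of the diff: it assumes a CUDA build where the Video Codec SDK was found at configure time, and the video path is a placeholder.

import os
from torchcodec.decoders import VideoDecoder

# Route decoding to the interface registered under the "custom_nvdec" variant.
decoder = VideoDecoder(
    "video.mp4",  # placeholder path
    device="cuda:0",
    device_variant="custom_nvdec",
    seek_mode="approximate",
)
frames = decoder.get_frames_at([0, 10, 20])  # decoded frames live on cuda:0

# The CUvideodecoder cache added in CustomNvdecDeviceInterface.cpp can be
# disabled (e.g., for A/B benchmarking) via an environment variable:
os.environ["TORCHCODEC_DISABLE_NVDEC_CACHE"] = "1"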