diff --git a/benchmarks/decoders/benchmark_decoders.py b/benchmarks/decoders/benchmark_decoders.py index 57669d12b..636326066 100644 --- a/benchmarks/decoders/benchmark_decoders.py +++ b/benchmarks/decoders/benchmark_decoders.py @@ -8,66 +8,18 @@ import importlib.resources import os import platform -import typing -from dataclasses import dataclass, field from pathlib import Path import torch from benchmark_decoders_library import ( - AbstractDecoder, - DecordAccurate, - DecordAccurateBatch, - OpenCVDecoder, + decoder_registry, plot_data, run_benchmarks, - TorchAudioDecoder, - TorchCodecCore, - TorchCodecCoreBatch, - TorchCodecCoreCompiled, - TorchCodecCoreNonBatch, - TorchCodecPublic, - TorchCodecPublicNonBatch, - TorchVision, + verify_outputs, ) -@dataclass -class DecoderKind: - display_name: str - kind: typing.Type[AbstractDecoder] - default_options: dict[str, str] = field(default_factory=dict) - - -decoder_registry = { - "decord": DecoderKind("DecordAccurate", DecordAccurate), - "decord_batch": DecoderKind("DecordAccurateBatch", DecordAccurateBatch), - "torchcodec_core": DecoderKind("TorchCodecCore", TorchCodecCore), - "torchcodec_core_batch": DecoderKind("TorchCodecCoreBatch", TorchCodecCoreBatch), - "torchcodec_core_nonbatch": DecoderKind( - "TorchCodecCoreNonBatch", TorchCodecCoreNonBatch - ), - "torchcodec_core_compiled": DecoderKind( - "TorchCodecCoreCompiled", TorchCodecCoreCompiled - ), - "torchcodec_public": DecoderKind("TorchCodecPublic", TorchCodecPublic), - "torchcodec_public_nonbatch": DecoderKind( - "TorchCodecPublicNonBatch", TorchCodecPublicNonBatch - ), - "torchvision": DecoderKind( - # We don't compare against TorchVision's "pyav" backend because it doesn't support - # accurate seeks. - "TorchVision[backend=video_reader]", - TorchVision, - {"backend": "video_reader"}, - ), - "torchaudio": DecoderKind("TorchAudio", TorchAudioDecoder), - "opencv": DecoderKind( - "OpenCV[backend=FFMPEG]", OpenCVDecoder, {"backend": "FFMPEG"} - ), -} - - def in_fbcode() -> bool: return "FB_PAR_RUNTIME_FILES" in os.environ @@ -148,6 +100,12 @@ def main() -> None: type=str, default="benchmarks.png", ) + parser.add_argument( + "--verify-outputs", + help="Verify that the outputs of the decoders are the same", + default=False, + action=argparse.BooleanOptionalAction, + ) args = parser.parse_args() specified_decoders = set(args.decoders.split(",")) @@ -177,29 +135,32 @@ def main() -> None: if entry.is_file() and entry.name.endswith(".mp4"): video_paths.append(entry.path) - results = run_benchmarks( - decoders_to_run, - video_paths, - num_uniform_samples, - num_sequential_frames_from_start=[1, 10, 100], - min_runtime_seconds=args.min_run_seconds, - benchmark_video_creation=args.bm_video_creation, - ) - data = { - "experiments": results, - "system_metadata": { - "cpu_count": os.cpu_count(), - "system": platform.system(), - "machine": platform.machine(), - "python_version": str(platform.python_version()), - "cuda": ( - torch.cuda.get_device_properties(0).name - if torch.cuda.is_available() - else "not available" - ), - }, - } - plot_data(data, args.plot_path) + if args.verify_outputs: + verify_outputs(decoders_to_run, video_paths, num_uniform_samples) + else: + results = run_benchmarks( + decoders_to_run, + video_paths, + num_uniform_samples, + num_sequential_frames_from_start=[1, 10, 100], + min_runtime_seconds=args.min_run_seconds, + benchmark_video_creation=args.bm_video_creation, + ) + data = { + "experiments": results, + "system_metadata": { + "cpu_count": os.cpu_count(), + "system": platform.system(), + "machine": platform.machine(), + "python_version": str(platform.python_version()), + "cuda": ( + torch.cuda.get_device_properties(0).name + if torch.cuda.is_available() + else "not available" + ), + }, + } + plot_data(data, args.plot_path) if __name__ == "__main__": diff --git a/benchmarks/decoders/benchmark_decoders_library.py b/benchmarks/decoders/benchmark_decoders_library.py index 9be263305..bd1c2f1c4 100644 --- a/benchmarks/decoders/benchmark_decoders_library.py +++ b/benchmarks/decoders/benchmark_decoders_library.py @@ -1,9 +1,10 @@ import abc import json import subprocess +import typing import urllib.request from concurrent.futures import ThreadPoolExecutor, wait -from dataclasses import dataclass +from dataclasses import dataclass, field from itertools import product from pathlib import Path @@ -23,6 +24,7 @@ get_next_frame, seek_to_pts, ) +from torchcodec._frame import FrameBatch from torchcodec.decoders import VideoDecoder, VideoStreamMetadata torch._dynamo.config.cache_size_limit = 100 @@ -824,6 +826,42 @@ def convert_result_to_df_item( return df_item +@dataclass +class DecoderKind: + display_name: str + kind: typing.Type[AbstractDecoder] + default_options: dict[str, str] = field(default_factory=dict) + + +decoder_registry = { + "decord": DecoderKind("DecordAccurate", DecordAccurate), + "decord_batch": DecoderKind("DecordAccurateBatch", DecordAccurateBatch), + "torchcodec_core": DecoderKind("TorchCodecCore", TorchCodecCore), + "torchcodec_core_batch": DecoderKind("TorchCodecCoreBatch", TorchCodecCoreBatch), + "torchcodec_core_nonbatch": DecoderKind( + "TorchCodecCoreNonBatch", TorchCodecCoreNonBatch + ), + "torchcodec_core_compiled": DecoderKind( + "TorchCodecCoreCompiled", TorchCodecCoreCompiled + ), + "torchcodec_public": DecoderKind("TorchCodecPublic", TorchCodecPublic), + "torchcodec_public_nonbatch": DecoderKind( + "TorchCodecPublicNonBatch", TorchCodecPublicNonBatch + ), + "torchvision": DecoderKind( + # We don't compare against TorchVision's "pyav" backend because it doesn't support + # accurate seeks. + "TorchVision[backend=video_reader]", + TorchVision, + {"backend": "video_reader"}, + ), + "torchaudio": DecoderKind("TorchAudio", TorchAudioDecoder), + "opencv": DecoderKind( + "OpenCV[backend=FFMPEG]", OpenCVDecoder, {"backend": "FFMPEG"} + ), +} + + def run_benchmarks( decoder_dict: dict[str, AbstractDecoder], video_files_paths: list[Path], @@ -986,3 +1024,77 @@ def run_benchmarks( compare = benchmark.Compare(results) compare.print() return df_data + + +def verify_outputs(decoders_to_run, video_paths, num_samples): + # Reuse TorchCodecPublic decoder stream_index option, if provided. + options = decoder_registry["torchcodec_public"].default_options + if torchcodec_decoder := next( + ( + decoder + for name, decoder in decoders_to_run.items() + if "TorchCodecPublic" in name + ), + None, + ): + options["stream_index"] = ( + str(torchcodec_decoder._stream_index) + if torchcodec_decoder._stream_index is not None + else "" + ) + # Create default TorchCodecPublic decoder to use as a baseline + torchcodec_public_decoder = TorchCodecPublic(**options) + + # Get frames using each decoder + for video_file_path in video_paths: + metadata = get_metadata(video_file_path) + metadata_label = f"{metadata.codec} {metadata.width}x{metadata.height}, {metadata.duration_seconds}s {metadata.average_fps}fps" + print(f"{metadata_label=}") + + # Generate uniformly spaced PTS + duration = metadata.duration_seconds + pts_list = [i * duration / num_samples for i in range(num_samples)] + + # Get the frames from TorchCodecPublic as the baseline + torchcodec_public_results = decode_and_adjust_frames( + torchcodec_public_decoder, + video_file_path, + num_samples=num_samples, + pts_list=pts_list, + ) + + for decoder_name, decoder in decoders_to_run.items(): + print(f"video={video_file_path}, decoder={decoder_name}") + + frames = decode_and_adjust_frames( + decoder, + video_file_path, + num_samples=num_samples, + pts_list=pts_list, + ) + for f1, f2 in zip(torchcodec_public_results, frames): + torch.testing.assert_close(f1, f2) + print(f"Results of baseline TorchCodecPublic and {decoder_name} match!") + + +def decode_and_adjust_frames( + decoder, video_file_path, *, num_samples: int, pts_list: list[float] +): + frames = [] + # Decode non-sequential frames using decode_frames function + non_seq_frames = decoder.decode_frames(video_file_path, pts_list) + # TorchCodec's batch APIs return a FrameBatch, so we need to extract the frames + if isinstance(non_seq_frames, FrameBatch): + non_seq_frames = non_seq_frames.data + frames.extend(non_seq_frames) + + # Decode sequential frames using decode_first_n_frames function + seq_frames = decoder.decode_first_n_frames(video_file_path, num_samples) + if isinstance(seq_frames, FrameBatch): + seq_frames = seq_frames.data + frames.extend(seq_frames) + + # Check and convert frames to C,H,W for consistency with other decoders. + if frames[0].shape[-1] == 3: + frames = [frame.permute(-1, *range(frame.dim() - 1)) for frame in frames] + return frames