From 88cacb9967145ddb5d763511cf651d68d8647f45 Mon Sep 17 00:00:00 2001
From: Daniel Flores
Date: Thu, 10 Jul 2025 00:05:55 -0400
Subject: [PATCH 01/26] Add frame_index seek mode, read frame index and update metadata

---
 src/torchcodec/_core/SingleStreamDecoder.cpp | 55 ++++++++++++++++++--
 src/torchcodec/_core/SingleStreamDecoder.h   | 11 +++-
 src/torchcodec/_core/custom_ops.cpp          | 19 ++++---
 src/torchcodec/_core/ops.py                  |  3 ++
 4 files changed, 76 insertions(+), 12 deletions(-)

diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp
index 2e027da3..69da51f3 100644
--- a/src/torchcodec/_core/SingleStreamDecoder.cpp
+++ b/src/torchcodec/_core/SingleStreamDecoder.cpp
@@ -319,6 +319,41 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
   scannedAllStreams_ = true;
 }
 
+
+void SingleStreamDecoder::readFrameIndexUpdateMetadataAndIndex(int streamIndex, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> frameIndex){
+  if (readFrameIndex_) {
+    return;
+  }
+  auto& all_frames = std::get<0>(frameIndex);
+  auto& key_frames = std::get<1>(frameIndex);
+  auto& duration = std::get<2>(frameIndex);
+
+  auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex];
+
+  // Get the last index for key_frames and duration
+  auto last_idx = all_frames.size(0) - 1;
+  streamMetadata.beginStreamPtsFromContent = all_frames[0].item<int64_t>();
+  streamMetadata.endStreamPtsFromContent = all_frames[last_idx].item<int64_t>() + duration[last_idx].item<int64_t>();
+
+  auto avStream = formatContext_->streams[streamIndex];
+  streamMetadata.beginStreamPtsSecondsFromContent =
+      *streamMetadata.beginStreamPtsFromContent *
+      av_q2d(avStream->time_base);
+
+  streamMetadata.endStreamPtsSecondsFromContent =
+      *streamMetadata.endStreamPtsFromContent * av_q2d(avStream->time_base);
+
+  streamMetadata.numFramesFromContent = all_frames.size(0);
+  for (int64_t i = 0; i < all_frames.size(0); ++i) {
+    // FrameInfo struct utilizes PTS
+    FrameInfo frameInfo = {all_frames[i].item<int64_t>()};
+    frameInfo.isKeyFrame = (i < key_frames.size(0) && key_frames[i].item<int64_t>() == 1);
+    frameInfo.nextPts = (i + 1 < all_frames.size(0)) ?
all_frames[i + 1].item() : INT64_MAX; + streamInfos_[streamIndex].allFrames.push_back(frameInfo); + } + readFrameIndex_ = true; +} + ContainerMetadata SingleStreamDecoder::getContainerMetadata() const { return containerMetadata_; } @@ -431,7 +466,8 @@ void SingleStreamDecoder::addStream( void SingleStreamDecoder::addVideoStream( int streamIndex, - const VideoStreamOptions& videoStreamOptions) { + const VideoStreamOptions& videoStreamOptions, + std::optional> frameIndex) { addStream( streamIndex, AVMEDIA_TYPE_VIDEO, @@ -456,6 +492,11 @@ void SingleStreamDecoder::addVideoStream( streamMetadata.height = streamInfo.codecContext->height; streamMetadata.sampleAspectRatio = streamInfo.codecContext->sample_aspect_ratio; + + if (seekMode_ == SeekMode::frame_index) { + TORCH_CHECK(frameIndex.has_value(), "Please provide a frame index when using frame_index seek mode."); + readFrameIndexUpdateMetadataAndIndex(streamIndex, frameIndex.value()); + } } void SingleStreamDecoder::addAudioStream( @@ -597,7 +638,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange( int64_t stop, int64_t step) { validateActiveStream(AVMEDIA_TYPE_VIDEO); - + const auto& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; const auto& streamInfo = streamInfos_[activeStreamIndex_]; @@ -1407,6 +1448,7 @@ int SingleStreamDecoder::getKeyFrameIndexForPtsUsingScannedIndex( int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) { auto& streamInfo = streamInfos_[activeStreamIndex_]; switch (seekMode_) { + case SeekMode::frame_index: case SeekMode::exact: { auto frame = std::lower_bound( streamInfo.allFrames.begin(), @@ -1434,6 +1476,7 @@ int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) { int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) { auto& streamInfo = streamInfos_[activeStreamIndex_]; switch (seekMode_) { + case SeekMode::frame_index: case SeekMode::exact: { auto frame = std::upper_bound( streamInfo.allFrames.begin(), @@ -1444,7 +1487,7 @@ int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) { }); return frame - streamInfo.allFrames.begin(); - } + } case SeekMode::approximate: { auto& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; @@ -1461,6 +1504,7 @@ int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) { int64_t SingleStreamDecoder::getPts(int64_t frameIndex) { auto& streamInfo = streamInfos_[activeStreamIndex_]; switch (seekMode_) { + case SeekMode::frame_index: case SeekMode::exact: return streamInfo.allFrames[frameIndex].pts; case SeekMode::approximate: { @@ -1485,6 +1529,7 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) { std::optional SingleStreamDecoder::getNumFrames( const StreamMetadata& streamMetadata) { switch (seekMode_) { + case SeekMode::frame_index: case SeekMode::exact: return streamMetadata.numFramesFromContent.value(); case SeekMode::approximate: { @@ -1498,6 +1543,7 @@ std::optional SingleStreamDecoder::getNumFrames( double SingleStreamDecoder::getMinSeconds( const StreamMetadata& streamMetadata) { switch (seekMode_) { + case SeekMode::frame_index: case SeekMode::exact: return streamMetadata.beginStreamPtsSecondsFromContent.value(); case SeekMode::approximate: @@ -1510,6 +1556,7 @@ double SingleStreamDecoder::getMinSeconds( std::optional SingleStreamDecoder::getMaxSeconds( const StreamMetadata& streamMetadata) { switch (seekMode_) { + case SeekMode::frame_index: case SeekMode::exact: return streamMetadata.endStreamPtsSecondsFromContent.value(); case 
SeekMode::approximate: { @@ -1645,6 +1692,8 @@ SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) { return SingleStreamDecoder::SeekMode::exact; } else if (seekMode == "approximate") { return SingleStreamDecoder::SeekMode::approximate; + } else if (seekMode == "frame_index") { + return SingleStreamDecoder::SeekMode::frame_index; } else { TORCH_CHECK(false, "Invalid seek mode: " + std::string(seekMode)); } diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index dec102d1..58e0d2a5 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -29,7 +29,7 @@ class SingleStreamDecoder { // CONSTRUCTION API // -------------------------------------------------------------------------- - enum class SeekMode { exact, approximate }; + enum class SeekMode { exact, approximate, frame_index }; // Creates a SingleStreamDecoder from the video at videoFilePath. explicit SingleStreamDecoder( @@ -53,6 +53,10 @@ class SingleStreamDecoder { // the allFrames and keyFrames vectors. void scanFileAndUpdateMetadataAndIndex(); + // Reads the user provided frame index and updates each StreamInfo's index, i.e. + // the allFrames and keyFrames vectors, and the endStreamPtsSecondsFromContent + void readFrameIndexUpdateMetadataAndIndex(int streamIndex, std::tuple frameIndex); + // Returns the metadata for the container. ContainerMetadata getContainerMetadata() const; @@ -66,7 +70,8 @@ class SingleStreamDecoder { void addVideoStream( int streamIndex, - const VideoStreamOptions& videoStreamOptions = VideoStreamOptions()); + const VideoStreamOptions& videoStreamOptions = VideoStreamOptions(), + std::optional> frameIndex = std::nullopt); void addAudioStream( int streamIndex, const AudioStreamOptions& audioStreamOptions = AudioStreamOptions()); @@ -343,6 +348,8 @@ class SingleStreamDecoder { bool scannedAllStreams_ = false; // Tracks that we've already been initialized. bool initialized_ = false; + // Tracks that frame index has been ingested + bool readFrameIndex_ = false; }; // Prints the SingleStreamDecoder::DecodeStats to the ostream. diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 4aa68a3b..c6475ed5 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -36,9 +36,9 @@ TORCH_LIBRARY(torchcodec_ns, m) { "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); m.def("_convert_to_tensor(int decoder_ptr) -> Tensor"); m.def( - "_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, str? color_conversion_library=None) -> ()"); + "_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, str? color_conversion_library=None, (Tensor, Tensor, Tensor)? frame_index=None) -> ()"); m.def( - "add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None) -> ()"); + "add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, str? color_conversion_library=None, (Tensor, Tensor, Tensor)? frame_index=None) -> ()"); m.def( "add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? 
sample_rate=None, int? num_channels=None) -> ()"); m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()"); @@ -180,7 +180,7 @@ at::Tensor create_from_file( if (seek_mode.has_value()) { realSeek = seekModeFromString(seek_mode.value()); } - + std::unique_ptr uniqueDecoder = std::make_unique(filenameStr, realSeek); @@ -223,7 +223,8 @@ void _add_video_stream( std::optional dimension_order = std::nullopt, std::optional stream_index = std::nullopt, std::optional device = std::nullopt, - std::optional color_conversion_library = std::nullopt) { + std::optional color_conversion_library = std::nullopt, + std::optional> frame_index = std::nullopt) { VideoStreamOptions videoStreamOptions; videoStreamOptions.width = width; videoStreamOptions.height = height; @@ -255,7 +256,7 @@ void _add_video_stream( } auto videoDecoder = unwrapTensorToGetDecoder(decoder); - videoDecoder->addVideoStream(stream_index.value_or(-1), videoStreamOptions); + videoDecoder->addVideoStream(stream_index.value_or(-1), videoStreamOptions, frame_index); } // Add a new video stream at `stream_index` using the provided options. @@ -266,7 +267,9 @@ void add_video_stream( std::optional num_threads = std::nullopt, std::optional dimension_order = std::nullopt, std::optional stream_index = std::nullopt, - std::optional device = std::nullopt) { + std::optional device = std::nullopt, + std::optional color_conversion_library = std::nullopt, + std::optional> frame_index = std::nullopt) { _add_video_stream( decoder, width, @@ -274,7 +277,9 @@ void add_video_stream( num_threads, dimension_order, stream_index, - device); + device, + color_conversion_library, + frame_index); } void add_audio_stream( diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index a68b51e2..a1189d71 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -206,6 +206,7 @@ def _add_video_stream_abstract( stream_index: Optional[int] = None, device: Optional[str] = None, color_conversion_library: Optional[str] = None, + frame_index: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None ) -> None: return @@ -220,6 +221,8 @@ def add_video_stream_abstract( dimension_order: Optional[str] = None, stream_index: Optional[int] = None, device: Optional[str] = None, + color_conversion_library: Optional[str] = None, + frame_index: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None ) -> None: return From 4dfc581de3d621e157b6cfa91705f53341eedb52 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Thu, 10 Jul 2025 00:06:31 -0400 Subject: [PATCH 02/26] add get_frame tests for new seek mode --- test/test_metadata.py | 27 ++++++++++++--------------- test/test_ops.py | 34 ++++++++++++++++++++++++++++++++++ test/utils.py | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 79 insertions(+), 15 deletions(-) diff --git a/test/test_metadata.py b/test/test_metadata.py index 3de7d377..196c24c0 100644 --- a/test/test_metadata.py +++ b/test/test_metadata.py @@ -32,22 +32,19 @@ def _get_container_metadata(path, seek_mode): @pytest.mark.parametrize( - "metadata_getter", - ( - get_container_metadata_from_header, - functools.partial(_get_container_metadata, seek_mode="approximate"), - functools.partial(_get_container_metadata, seek_mode="exact"), - ), + "seek_mode", + ["approximate", "exact", "frame_index"] ) -def test_get_metadata(metadata_getter): - with_scan = ( - metadata_getter.keywords["seek_mode"] == "exact" - if isinstance(metadata_getter, functools.partial) - else False - ) +def test_get_metadata(seek_mode): + from 
torchcodec._core import add_video_stream + decoder = create_from_file(str(NASA_VIDEO.path), seek_mode=seek_mode) + # For frame_index seek mode, add a video stream to update metadata + frame_index = NASA_VIDEO.frame_index if seek_mode == "frame_index" else None + # Add the best video stream (index 3 for NASA_VIDEO) + add_video_stream(decoder, stream_index=NASA_VIDEO.default_stream_index, frame_index=frame_index) + metadata = get_container_metadata(decoder) - metadata = metadata_getter(NASA_VIDEO.path) - # metadata = metadata_getter(NASA_VIDEO.path) + with_scan = seek_mode == "exact" or seek_mode == "frame_index" assert len(metadata.streams) == 6 assert metadata.best_video_stream_index == 3 @@ -82,7 +79,7 @@ def test_get_metadata(metadata_getter): assert best_video_stream_metadata.begin_stream_seconds_from_header == 0 assert best_video_stream_metadata.bit_rate == 128783 assert best_video_stream_metadata.average_fps == pytest.approx(29.97, abs=0.001) - assert best_video_stream_metadata.pixel_aspect_ratio is None + assert best_video_stream_metadata.pixel_aspect_ratio == Fraction(1, 1) assert best_video_stream_metadata.codec == "h264" assert best_video_stream_metadata.num_frames_from_content == ( 390 if with_scan else None diff --git a/test/test_ops.py b/test/test_ops.py index 2f691615..a4094ddd 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -448,6 +448,40 @@ def test_frame_pts_equality(self): ) assert pts_is_equal + def test_seek_mode_frame_index_fails(self): + decoder = create_from_file(str(NASA_VIDEO.path), "frame_index") + with pytest.raises( + RuntimeError, + match="Please provide a frame index when using frame_index seek mode.", + ): + add_video_stream(decoder, stream_index=0, frame_index=None) + + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_seek_mode_frame_index(self, device): + stream_index = 3 # frame index seek mode requires a stream index + decoder = create_from_file(str(NASA_VIDEO.path), "frame_index") + add_video_stream(decoder, device=device, stream_index=stream_index, frame_index=NASA_VIDEO.frame_index) + + frame0, _, _ = get_next_frame(decoder) + reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0, stream_index=stream_index) + assert_frames_equal(frame0, reference_frame0.to(device)) + + frame6, _, _ = get_frame_at_pts(decoder, 6.006) + reference_frame6 = NASA_VIDEO.get_frame_data_by_index( + INDEX_OF_FRAME_AT_6_SECONDS, stream_index=stream_index + ) + assert_frames_equal(frame6, reference_frame6.to(device)) + + frame6, _, _ = get_frame_at_index(decoder, frame_index=180) + reference_frame6 = NASA_VIDEO.get_frame_data_by_index( + INDEX_OF_FRAME_AT_6_SECONDS, stream_index=stream_index + ) + assert_frames_equal(frame6, reference_frame6.to(device)) + + ref_frames0_9 = NASA_VIDEO.get_frame_data_by_range(0, 9) + bulk_frames0_9, *_ = get_frames_in_range(decoder, start=0, stop=9) + assert_frames_equal(bulk_frames0_9, ref_frames0_9.to(device)) + @pytest.mark.parametrize("color_conversion_library", ("filtergraph", "swscale")) def test_color_conversion_library(self, color_conversion_library): decoder = create_from_file(str(NASA_VIDEO.path)) diff --git a/test/utils.py b/test/utils.py index e7ce12e5..3b1cb031 100644 --- a/test/utils.py +++ b/test/utils.py @@ -2,6 +2,7 @@ import json import os import pathlib +import subprocess import sys from dataclasses import dataclass, field @@ -121,6 +122,7 @@ class TestContainerFile: default_stream_index: int stream_infos: Dict[int, Union[TestVideoStreamInfo, TestAudioStreamInfo]] frames: Dict[int, Dict[int, TestFrameInfo]] 
+ frame_index_data: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None def __post_init__(self): # We load the .frames attribute from the checked-in json files, if needed. @@ -222,6 +224,37 @@ def get_frame_info( stream_index = self.default_stream_index return self.frames[stream_index][idx] + + # This property is used to get the frame index metadata for the frame_index seek mode. + @property + def frame_index(self, stream_index: Optional[int] = None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if stream_index is None: + stream_index = self.default_stream_index + if self.frame_index_data is None: + self.get_frame_index(stream_index) + return self.frame_index_data + + def get_frame_index(self, stream_index: int) -> None: + show_frames_result = json.loads( + subprocess.run( + ["ffprobe", "-i", f"{self.path}","-select_streams",f"{stream_index}","-show_frames","-of","json"], + check=True, capture_output=True, text=True, + ).stdout + ) + frame_index_data = ([], [], []) + frames = show_frames_result["frames"] + for frame in frames: + frame_index_data[0].append(float(frame["pts"])) + frame_index_data[1].append(frame["key_frame"]) + frame_index_data[2].append(float(frame["duration"])) + + (pts_list, key_frame_list, duration_list) = frame_index_data + # Zip the lists together, sort by pts, then unzip + assert len(pts_list) == len(key_frame_list) == len(duration_list), "Mismatched lengths in frame index data" + combined = list(zip(pts_list, key_frame_list, duration_list)) + combined.sort(key=lambda x: x[0]) + pts_sorted, key_frame_sorted, duration_sorted = zip(*combined) + self.frame_index_data = (torch.tensor(pts_sorted), torch.tensor(key_frame_sorted), torch.tensor(duration_sorted)) @property def empty_pts_seconds(self) -> torch.Tensor: From ed3fdec4467e7ee400b28f19fb147af6b3ab8678 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Thu, 10 Jul 2025 00:36:27 -0400 Subject: [PATCH 03/26] lints --- src/torchcodec/_core/SingleStreamDecoder.cpp | 28 ++++++++++------ src/torchcodec/_core/SingleStreamDecoder.h | 12 ++++--- src/torchcodec/_core/custom_ops.cpp | 11 +++--- src/torchcodec/_core/ops.py | 4 +-- test/test_metadata.py | 10 +++--- test/test_ops.py | 17 +++++++--- test/utils.py | 35 +++++++++++++++----- 7 files changed, 78 insertions(+), 39 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 69da51f3..bcd979a0 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -319,8 +319,9 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() { scannedAllStreams_ = true; } - -void SingleStreamDecoder::readFrameIndexUpdateMetadataAndIndex(int streamIndex, std::tuple frameIndex){ +void SingleStreamDecoder::readFrameIndexUpdateMetadataAndIndex( + int streamIndex, + std::tuple frameIndex) { if (readFrameIndex_) { return; } @@ -333,12 +334,12 @@ void SingleStreamDecoder::readFrameIndexUpdateMetadataAndIndex(int streamIndex, // Get the last index for key_frames and duration auto last_idx = all_frames.size(0) - 1; streamMetadata.beginStreamPtsFromContent = all_frames[0].item(); - streamMetadata.endStreamPtsFromContent = all_frames[last_idx].item() + duration[last_idx].item(); + streamMetadata.endStreamPtsFromContent = + all_frames[last_idx].item() + duration[last_idx].item(); auto avStream = formatContext_->streams[streamIndex]; streamMetadata.beginStreamPtsSecondsFromContent = - *streamMetadata.beginStreamPtsFromContent * - av_q2d(avStream->time_base); + 
*streamMetadata.beginStreamPtsFromContent * av_q2d(avStream->time_base); streamMetadata.endStreamPtsSecondsFromContent = *streamMetadata.endStreamPtsFromContent * av_q2d(avStream->time_base); @@ -347,8 +348,11 @@ void SingleStreamDecoder::readFrameIndexUpdateMetadataAndIndex(int streamIndex, for (int64_t i = 0; i < all_frames.size(0); ++i) { // FrameInfo struct utilizes PTS FrameInfo frameInfo = {all_frames[i].item()}; - frameInfo.isKeyFrame = (i < key_frames.size(0) && key_frames[i].item() == 1); - frameInfo.nextPts = (i + 1 < all_frames.size(0)) ? all_frames[i + 1].item() : INT64_MAX; + frameInfo.isKeyFrame = + (i < key_frames.size(0) && key_frames[i].item() == 1); + frameInfo.nextPts = (i + 1 < all_frames.size(0)) + ? all_frames[i + 1].item() + : INT64_MAX; streamInfos_[streamIndex].allFrames.push_back(frameInfo); } readFrameIndex_ = true; @@ -466,7 +470,7 @@ void SingleStreamDecoder::addStream( void SingleStreamDecoder::addVideoStream( int streamIndex, - const VideoStreamOptions& videoStreamOptions, + const VideoStreamOptions& videoStreamOptions, std::optional> frameIndex) { addStream( streamIndex, @@ -494,7 +498,9 @@ void SingleStreamDecoder::addVideoStream( streamInfo.codecContext->sample_aspect_ratio; if (seekMode_ == SeekMode::frame_index) { - TORCH_CHECK(frameIndex.has_value(), "Please provide a frame index when using frame_index seek mode."); + TORCH_CHECK( + frameIndex.has_value(), + "Please provide a frame index when using frame_index seek mode."); readFrameIndexUpdateMetadataAndIndex(streamIndex, frameIndex.value()); } } @@ -638,7 +644,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange( int64_t stop, int64_t step) { validateActiveStream(AVMEDIA_TYPE_VIDEO); - + const auto& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; const auto& streamInfo = streamInfos_[activeStreamIndex_]; @@ -1487,7 +1493,7 @@ int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) { }); return frame - streamInfo.allFrames.begin(); - } + } case SeekMode::approximate: { auto& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 58e0d2a5..c814f89e 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -53,9 +53,12 @@ class SingleStreamDecoder { // the allFrames and keyFrames vectors. void scanFileAndUpdateMetadataAndIndex(); - // Reads the user provided frame index and updates each StreamInfo's index, i.e. - // the allFrames and keyFrames vectors, and the endStreamPtsSecondsFromContent - void readFrameIndexUpdateMetadataAndIndex(int streamIndex, std::tuple frameIndex); + // Reads the user provided frame index and updates each StreamInfo's index, + // i.e. the allFrames and keyFrames vectors, and the + // endStreamPtsSecondsFromContent + void readFrameIndexUpdateMetadataAndIndex( + int streamIndex, + std::tuple frameIndex); // Returns the metadata for the container. 
ContainerMetadata getContainerMetadata() const; @@ -71,7 +74,8 @@ class SingleStreamDecoder { void addVideoStream( int streamIndex, const VideoStreamOptions& videoStreamOptions = VideoStreamOptions(), - std::optional> frameIndex = std::nullopt); + std::optional> frameIndex = + std::nullopt); void addAudioStream( int streamIndex, const AudioStreamOptions& audioStreamOptions = AudioStreamOptions()); diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index c6475ed5..7af770bc 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -180,7 +180,7 @@ at::Tensor create_from_file( if (seek_mode.has_value()) { realSeek = seekModeFromString(seek_mode.value()); } - + std::unique_ptr uniqueDecoder = std::make_unique(filenameStr, realSeek); @@ -224,7 +224,8 @@ void _add_video_stream( std::optional stream_index = std::nullopt, std::optional device = std::nullopt, std::optional color_conversion_library = std::nullopt, - std::optional> frame_index = std::nullopt) { + std::optional> frame_index = + std::nullopt) { VideoStreamOptions videoStreamOptions; videoStreamOptions.width = width; videoStreamOptions.height = height; @@ -256,7 +257,8 @@ void _add_video_stream( } auto videoDecoder = unwrapTensorToGetDecoder(decoder); - videoDecoder->addVideoStream(stream_index.value_or(-1), videoStreamOptions, frame_index); + videoDecoder->addVideoStream( + stream_index.value_or(-1), videoStreamOptions, frame_index); } // Add a new video stream at `stream_index` using the provided options. @@ -269,7 +271,8 @@ void add_video_stream( std::optional stream_index = std::nullopt, std::optional device = std::nullopt, std::optional color_conversion_library = std::nullopt, - std::optional> frame_index = std::nullopt) { + std::optional> frame_index = + std::nullopt) { _add_video_stream( decoder, width, diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index a1189d71..7041fa49 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -206,7 +206,7 @@ def _add_video_stream_abstract( stream_index: Optional[int] = None, device: Optional[str] = None, color_conversion_library: Optional[str] = None, - frame_index: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None + frame_index: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, ) -> None: return @@ -222,7 +222,7 @@ def add_video_stream_abstract( stream_index: Optional[int] = None, device: Optional[str] = None, color_conversion_library: Optional[str] = None, - frame_index: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None + frame_index: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, ) -> None: return diff --git a/test/test_metadata.py b/test/test_metadata.py index 196c24c0..56bae1d7 100644 --- a/test/test_metadata.py +++ b/test/test_metadata.py @@ -31,17 +31,17 @@ def _get_container_metadata(path, seek_mode): return get_container_metadata(decoder) -@pytest.mark.parametrize( - "seek_mode", - ["approximate", "exact", "frame_index"] -) +@pytest.mark.parametrize("seek_mode", ["approximate", "exact", "frame_index"]) def test_get_metadata(seek_mode): from torchcodec._core import add_video_stream + decoder = create_from_file(str(NASA_VIDEO.path), seek_mode=seek_mode) # For frame_index seek mode, add a video stream to update metadata frame_index = NASA_VIDEO.frame_index if seek_mode == "frame_index" else None # Add the best video stream (index 3 for NASA_VIDEO) - add_video_stream(decoder, 
stream_index=NASA_VIDEO.default_stream_index, frame_index=frame_index) + add_video_stream( + decoder, stream_index=NASA_VIDEO.default_stream_index, frame_index=frame_index + ) metadata = get_container_metadata(decoder) with_scan = seek_mode == "exact" or seek_mode == "frame_index" diff --git a/test/test_ops.py b/test/test_ops.py index a4094ddd..4370c7d2 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -458,12 +458,19 @@ def test_seek_mode_frame_index_fails(self): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_seek_mode_frame_index(self, device): - stream_index = 3 # frame index seek mode requires a stream index + stream_index = 3 # frame index seek mode requires a stream index decoder = create_from_file(str(NASA_VIDEO.path), "frame_index") - add_video_stream(decoder, device=device, stream_index=stream_index, frame_index=NASA_VIDEO.frame_index) - + add_video_stream( + decoder, + device=device, + stream_index=stream_index, + frame_index=NASA_VIDEO.frame_index, + ) + frame0, _, _ = get_next_frame(decoder) - reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0, stream_index=stream_index) + reference_frame0 = NASA_VIDEO.get_frame_data_by_index( + 0, stream_index=stream_index + ) assert_frames_equal(frame0, reference_frame0.to(device)) frame6, _, _ = get_frame_at_pts(decoder, 6.006) @@ -481,7 +488,7 @@ def test_seek_mode_frame_index(self, device): ref_frames0_9 = NASA_VIDEO.get_frame_data_by_range(0, 9) bulk_frames0_9, *_ = get_frames_in_range(decoder, start=0, stop=9) assert_frames_equal(bulk_frames0_9, ref_frames0_9.to(device)) - + @pytest.mark.parametrize("color_conversion_library", ("filtergraph", "swscale")) def test_color_conversion_library(self, color_conversion_library): decoder = create_from_file(str(NASA_VIDEO.path)) diff --git a/test/utils.py b/test/utils.py index 3b1cb031..9ccf0570 100644 --- a/test/utils.py +++ b/test/utils.py @@ -224,21 +224,34 @@ def get_frame_info( stream_index = self.default_stream_index return self.frames[stream_index][idx] - + # This property is used to get the frame index metadata for the frame_index seek mode. 
- @property - def frame_index(self, stream_index: Optional[int] = None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + @property + def frame_index( + self, stream_index: Optional[int] = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if stream_index is None: stream_index = self.default_stream_index if self.frame_index_data is None: self.get_frame_index(stream_index) return self.frame_index_data - + def get_frame_index(self, stream_index: int) -> None: show_frames_result = json.loads( subprocess.run( - ["ffprobe", "-i", f"{self.path}","-select_streams",f"{stream_index}","-show_frames","-of","json"], - check=True, capture_output=True, text=True, + [ + "ffprobe", + "-i", + f"{self.path}", + "-select_streams", + f"{stream_index}", + "-show_frames", + "-of", + "json", + ], + check=True, + capture_output=True, + text=True, ).stdout ) frame_index_data = ([], [], []) @@ -250,11 +263,17 @@ def get_frame_index(self, stream_index: int) -> None: (pts_list, key_frame_list, duration_list) = frame_index_data # Zip the lists together, sort by pts, then unzip - assert len(pts_list) == len(key_frame_list) == len(duration_list), "Mismatched lengths in frame index data" + assert ( + len(pts_list) == len(key_frame_list) == len(duration_list) + ), "Mismatched lengths in frame index data" combined = list(zip(pts_list, key_frame_list, duration_list)) combined.sort(key=lambda x: x[0]) pts_sorted, key_frame_sorted, duration_sorted = zip(*combined) - self.frame_index_data = (torch.tensor(pts_sorted), torch.tensor(key_frame_sorted), torch.tensor(duration_sorted)) + self.frame_index_data = ( + torch.tensor(pts_sorted), + torch.tensor(key_frame_sorted), + torch.tensor(duration_sorted), + ) @property def empty_pts_seconds(self) -> torch.Tensor: From 6030f9e4c9f07cee6e67a445383d008c2ac6f08c Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Thu, 10 Jul 2025 08:52:23 -0400 Subject: [PATCH 04/26] replace union syntax with optional --- test/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/utils.py b/test/utils.py index 9ccf0570..1fc4332e 100644 --- a/test/utils.py +++ b/test/utils.py @@ -122,7 +122,7 @@ class TestContainerFile: default_stream_index: int stream_infos: Dict[int, Union[TestVideoStreamInfo, TestAudioStreamInfo]] frames: Dict[int, Dict[int, TestFrameInfo]] - frame_index_data: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None + frame_index_data: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None def __post_init__(self): # We load the .frames attribute from the checked-in json files, if needed. From 1361c0d670094d6514ab86e4b53f80812adda597 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Thu, 10 Jul 2025 11:08:15 -0400 Subject: [PATCH 05/26] remove color_conversion changes --- src/torchcodec/_core/custom_ops.cpp | 10 ++++------ src/torchcodec/_core/ops.py | 3 +-- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 7af770bc..01db3994 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -36,9 +36,9 @@ TORCH_LIBRARY(torchcodec_ns, m) { "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); m.def("_convert_to_tensor(int decoder_ptr) -> Tensor"); m.def( - "_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, str? color_conversion_library=None, (Tensor, Tensor, Tensor)? 
frame_index=None) -> ()"); + "_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? frame_index=None, str? color_conversion_library=None) -> ()"); m.def( - "add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, str? color_conversion_library=None, (Tensor, Tensor, Tensor)? frame_index=None) -> ()"); + "add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? frame_index=None) -> ()"); m.def( "add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> ()"); m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()"); @@ -223,9 +223,9 @@ void _add_video_stream( std::optional dimension_order = std::nullopt, std::optional stream_index = std::nullopt, std::optional device = std::nullopt, - std::optional color_conversion_library = std::nullopt, std::optional> frame_index = - std::nullopt) { + std::nullopt, + std::optional color_conversion_library = std::nullopt) { VideoStreamOptions videoStreamOptions; videoStreamOptions.width = width; videoStreamOptions.height = height; @@ -270,7 +270,6 @@ void add_video_stream( std::optional dimension_order = std::nullopt, std::optional stream_index = std::nullopt, std::optional device = std::nullopt, - std::optional color_conversion_library = std::nullopt, std::optional> frame_index = std::nullopt) { _add_video_stream( @@ -281,7 +280,6 @@ void add_video_stream( dimension_order, stream_index, device, - color_conversion_library, frame_index); } diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 7041fa49..82a7d1ef 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -205,8 +205,8 @@ def _add_video_stream_abstract( dimension_order: Optional[str] = None, stream_index: Optional[int] = None, device: Optional[str] = None, - color_conversion_library: Optional[str] = None, frame_index: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, + color_conversion_library: Optional[str] = None, ) -> None: return @@ -221,7 +221,6 @@ def add_video_stream_abstract( dimension_order: Optional[str] = None, stream_index: Optional[int] = None, device: Optional[str] = None, - color_conversion_library: Optional[str] = None, frame_index: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, ) -> None: return From 5bb23c20bc5e6ad498f25ae40014610e92c0302b Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Thu, 10 Jul 2025 11:42:35 -0400 Subject: [PATCH 06/26] rename is_key_frame, remove readFrameIndex_, check tensor lengths are equal --- src/torchcodec/_core/SingleStreamDecoder.cpp | 15 ++++++--------- src/torchcodec/_core/SingleStreamDecoder.h | 2 -- test/test_ops.py | 8 ++++++++ 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index bcd979a0..d65baaf4 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -322,20 +322,18 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() { void SingleStreamDecoder::readFrameIndexUpdateMetadataAndIndex( int streamIndex, std::tuple frameIndex) { - if (readFrameIndex_) { 
- return; - } auto& all_frames = std::get<0>(frameIndex); - auto& key_frames = std::get<1>(frameIndex); + auto& is_key_frame = std::get<1>(frameIndex); auto& duration = std::get<2>(frameIndex); + TORCH_CHECK( + all_frames.size(0) == is_key_frame.size(0) && is_key_frame.size(0) == duration.size(0), + "all_frames, is_key_frame, and duration from custom_frame_mappings were not same size."); auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex]; - // Get the last index for key_frames and duration - auto last_idx = all_frames.size(0) - 1; streamMetadata.beginStreamPtsFromContent = all_frames[0].item(); streamMetadata.endStreamPtsFromContent = - all_frames[last_idx].item() + duration[last_idx].item(); + all_frames[-1].item() + duration[-1].item(); auto avStream = formatContext_->streams[streamIndex]; streamMetadata.beginStreamPtsSecondsFromContent = @@ -349,13 +347,12 @@ void SingleStreamDecoder::readFrameIndexUpdateMetadataAndIndex( // FrameInfo struct utilizes PTS FrameInfo frameInfo = {all_frames[i].item()}; frameInfo.isKeyFrame = - (i < key_frames.size(0) && key_frames[i].item() == 1); + (is_key_frame[i].item() == 1); frameInfo.nextPts = (i + 1 < all_frames.size(0)) ? all_frames[i + 1].item() : INT64_MAX; streamInfos_[streamIndex].allFrames.push_back(frameInfo); } - readFrameIndex_ = true; } ContainerMetadata SingleStreamDecoder::getContainerMetadata() const { diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index c814f89e..03118c40 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -352,8 +352,6 @@ class SingleStreamDecoder { bool scannedAllStreams_ = false; // Tracks that we've already been initialized. bool initialized_ = false; - // Tracks that frame index has been ingested - bool readFrameIndex_ = false; }; // Prints the SingleStreamDecoder::DecodeStats to the ostream. 
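
For illustration, the length check added above expects the frame_index argument to be a tuple of three equal-length 1D tensors: per-frame pts, key-frame flags, and per-frame durations. A minimal sketch of building such a tuple is shown below; the int64 dtypes and the constant 512-tick duration are assumptions for illustration rather than values taken from these patches, and the real values are expected in the stream's time_base. The test added next exercises the failure path by passing tensors of different lengths.

    import torch

    # Three equal-length 1D tensors, expressed in the stream's time_base
    # (dtypes and values here are illustrative assumptions).
    all_frames = torch.tensor([0, 512, 1024, 1536], dtype=torch.int64)  # per-frame pts
    is_key_frame = torch.tensor([1, 0, 0, 1], dtype=torch.int64)        # 1 marks a key frame
    duration = torch.full((4,), 512, dtype=torch.int64)                 # per-frame duration

    # Mirrors the TORCH_CHECK above: all three tensors must have the same length.
    assert all_frames.shape == is_key_frame.shape == duration.shape
    frame_index = (all_frames, is_key_frame, duration)
    # frame_index can then be passed to add_video_stream(decoder, ..., frame_index=frame_index).
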
diff --git a/test/test_ops.py b/test/test_ops.py index 4370c7d2..64e40de7 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -456,6 +456,14 @@ def test_seek_mode_frame_index_fails(self): ): add_video_stream(decoder, stream_index=0, frame_index=None) + decoder = create_from_file(str(NASA_VIDEO.path), "frame_index") + different_lengths = ((torch.tensor([1, 2, 3]), torch.tensor([1, 2]), torch.tensor([1, 2, 3]))) + with pytest.raises( + RuntimeError, + match="all_frames, is_key_frame, and duration from custom_frame_mappings were not same size.", + ): + add_video_stream(decoder, stream_index=0, frame_index=different_lengths) + @pytest.mark.parametrize("device", cpu_and_cuda()) def test_seek_mode_frame_index(self, device): stream_index = 3 # frame index seek mode requires a stream index From 657394360337a49d1f9223a0be864cac9ac29df2 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Thu, 10 Jul 2025 12:12:17 -0400 Subject: [PATCH 07/26] rename frame_index to custom_frame_mappings --- src/torchcodec/_core/SingleStreamDecoder.cpp | 36 ++++++++++---------- src/torchcodec/_core/SingleStreamDecoder.h | 8 ++--- src/torchcodec/_core/custom_ops.cpp | 12 +++---- src/torchcodec/_core/ops.py | 4 +-- test/test_metadata.py | 10 +++--- test/test_ops.py | 18 +++++----- test/utils.py | 30 ++++++++-------- 7 files changed, 59 insertions(+), 59 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index d65baaf4..1b45347c 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -319,12 +319,12 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() { scannedAllStreams_ = true; } -void SingleStreamDecoder::readFrameIndexUpdateMetadataAndIndex( +void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex( int streamIndex, - std::tuple frameIndex) { - auto& all_frames = std::get<0>(frameIndex); - auto& is_key_frame = std::get<1>(frameIndex); - auto& duration = std::get<2>(frameIndex); + std::tuple customFrameMappings) { + auto& all_frames = std::get<0>(customFrameMappings); + auto& is_key_frame = std::get<1>(customFrameMappings); + auto& duration = std::get<2>(customFrameMappings); TORCH_CHECK( all_frames.size(0) == is_key_frame.size(0) && is_key_frame.size(0) == duration.size(0), "all_frames, is_key_frame, and duration from custom_frame_mappings were not same size."); @@ -468,7 +468,7 @@ void SingleStreamDecoder::addStream( void SingleStreamDecoder::addVideoStream( int streamIndex, const VideoStreamOptions& videoStreamOptions, - std::optional> frameIndex) { + std::optional> customFrameMappings) { addStream( streamIndex, AVMEDIA_TYPE_VIDEO, @@ -494,11 +494,11 @@ void SingleStreamDecoder::addVideoStream( streamMetadata.sampleAspectRatio = streamInfo.codecContext->sample_aspect_ratio; - if (seekMode_ == SeekMode::frame_index) { + if (seekMode_ == SeekMode::custom_frame_mappings) { TORCH_CHECK( - frameIndex.has_value(), - "Please provide a frame index when using frame_index seek mode."); - readFrameIndexUpdateMetadataAndIndex(streamIndex, frameIndex.value()); + customFrameMappings.has_value(), + "Please provide frame mappings when using custom_frame_mappings seek mode."); + readCustomFrameMappingsUpdateMetadataAndIndex(streamIndex, customFrameMappings.value()); } } @@ -1451,7 +1451,7 @@ int SingleStreamDecoder::getKeyFrameIndexForPtsUsingScannedIndex( int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) { auto& streamInfo = streamInfos_[activeStreamIndex_]; 
switch (seekMode_) { - case SeekMode::frame_index: + case SeekMode::custom_frame_mappings: case SeekMode::exact: { auto frame = std::lower_bound( streamInfo.allFrames.begin(), @@ -1479,7 +1479,7 @@ int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) { int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) { auto& streamInfo = streamInfos_[activeStreamIndex_]; switch (seekMode_) { - case SeekMode::frame_index: + case SeekMode::custom_frame_mappings: case SeekMode::exact: { auto frame = std::upper_bound( streamInfo.allFrames.begin(), @@ -1507,7 +1507,7 @@ int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) { int64_t SingleStreamDecoder::getPts(int64_t frameIndex) { auto& streamInfo = streamInfos_[activeStreamIndex_]; switch (seekMode_) { - case SeekMode::frame_index: + case SeekMode::custom_frame_mappings: case SeekMode::exact: return streamInfo.allFrames[frameIndex].pts; case SeekMode::approximate: { @@ -1532,7 +1532,7 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) { std::optional SingleStreamDecoder::getNumFrames( const StreamMetadata& streamMetadata) { switch (seekMode_) { - case SeekMode::frame_index: + case SeekMode::custom_frame_mappings: case SeekMode::exact: return streamMetadata.numFramesFromContent.value(); case SeekMode::approximate: { @@ -1546,7 +1546,7 @@ std::optional SingleStreamDecoder::getNumFrames( double SingleStreamDecoder::getMinSeconds( const StreamMetadata& streamMetadata) { switch (seekMode_) { - case SeekMode::frame_index: + case SeekMode::custom_frame_mappings: case SeekMode::exact: return streamMetadata.beginStreamPtsSecondsFromContent.value(); case SeekMode::approximate: @@ -1559,7 +1559,7 @@ double SingleStreamDecoder::getMinSeconds( std::optional SingleStreamDecoder::getMaxSeconds( const StreamMetadata& streamMetadata) { switch (seekMode_) { - case SeekMode::frame_index: + case SeekMode::custom_frame_mappings: case SeekMode::exact: return streamMetadata.endStreamPtsSecondsFromContent.value(); case SeekMode::approximate: { @@ -1695,8 +1695,8 @@ SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) { return SingleStreamDecoder::SeekMode::exact; } else if (seekMode == "approximate") { return SingleStreamDecoder::SeekMode::approximate; - } else if (seekMode == "frame_index") { - return SingleStreamDecoder::SeekMode::frame_index; + } else if (seekMode == "custom_frame_mappings") { + return SingleStreamDecoder::SeekMode::custom_frame_mappings; } else { TORCH_CHECK(false, "Invalid seek mode: " + std::string(seekMode)); } diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 03118c40..849d2484 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -29,7 +29,7 @@ class SingleStreamDecoder { // CONSTRUCTION API // -------------------------------------------------------------------------- - enum class SeekMode { exact, approximate, frame_index }; + enum class SeekMode { exact, approximate, custom_frame_mappings }; // Creates a SingleStreamDecoder from the video at videoFilePath. explicit SingleStreamDecoder( @@ -56,9 +56,9 @@ class SingleStreamDecoder { // Reads the user provided frame index and updates each StreamInfo's index, // i.e. 
the allFrames and keyFrames vectors, and the // endStreamPtsSecondsFromContent - void readFrameIndexUpdateMetadataAndIndex( + void readCustomFrameMappingsUpdateMetadataAndIndex( int streamIndex, - std::tuple frameIndex); + std::tuple customFrameMappings); // Returns the metadata for the container. ContainerMetadata getContainerMetadata() const; @@ -74,7 +74,7 @@ class SingleStreamDecoder { void addVideoStream( int streamIndex, const VideoStreamOptions& videoStreamOptions = VideoStreamOptions(), - std::optional> frameIndex = + std::optional> customFrameMappings = std::nullopt); void addAudioStream( int streamIndex, diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 01db3994..39b0a4fc 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -36,9 +36,9 @@ TORCH_LIBRARY(torchcodec_ns, m) { "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); m.def("_convert_to_tensor(int decoder_ptr) -> Tensor"); m.def( - "_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? frame_index=None, str? color_conversion_library=None) -> ()"); + "_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()"); m.def( - "add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? frame_index=None) -> ()"); + "add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None) -> ()"); m.def( "add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> ()"); m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()"); @@ -223,7 +223,7 @@ void _add_video_stream( std::optional dimension_order = std::nullopt, std::optional stream_index = std::nullopt, std::optional device = std::nullopt, - std::optional> frame_index = + std::optional> custom_frame_mappings = std::nullopt, std::optional color_conversion_library = std::nullopt) { VideoStreamOptions videoStreamOptions; @@ -258,7 +258,7 @@ void _add_video_stream( auto videoDecoder = unwrapTensorToGetDecoder(decoder); videoDecoder->addVideoStream( - stream_index.value_or(-1), videoStreamOptions, frame_index); + stream_index.value_or(-1), videoStreamOptions, custom_frame_mappings); } // Add a new video stream at `stream_index` using the provided options. 
@@ -270,7 +270,7 @@ void add_video_stream( std::optional dimension_order = std::nullopt, std::optional stream_index = std::nullopt, std::optional device = std::nullopt, - std::optional> frame_index = + std::optional> custom_frame_mappings = std::nullopt) { _add_video_stream( decoder, @@ -280,7 +280,7 @@ void add_video_stream( dimension_order, stream_index, device, - frame_index); + custom_frame_mappings); } void add_audio_stream( diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 82a7d1ef..3daa4370 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -205,7 +205,7 @@ def _add_video_stream_abstract( dimension_order: Optional[str] = None, stream_index: Optional[int] = None, device: Optional[str] = None, - frame_index: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, + custom_frame_mappings: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, color_conversion_library: Optional[str] = None, ) -> None: return @@ -221,7 +221,7 @@ def add_video_stream_abstract( dimension_order: Optional[str] = None, stream_index: Optional[int] = None, device: Optional[str] = None, - frame_index: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, + custom_frame_mappings: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, ) -> None: return diff --git a/test/test_metadata.py b/test/test_metadata.py index 56bae1d7..86d77612 100644 --- a/test/test_metadata.py +++ b/test/test_metadata.py @@ -31,20 +31,20 @@ def _get_container_metadata(path, seek_mode): return get_container_metadata(decoder) -@pytest.mark.parametrize("seek_mode", ["approximate", "exact", "frame_index"]) +@pytest.mark.parametrize("seek_mode", ["approximate", "exact", "custom_frame_mappings"]) def test_get_metadata(seek_mode): from torchcodec._core import add_video_stream decoder = create_from_file(str(NASA_VIDEO.path), seek_mode=seek_mode) - # For frame_index seek mode, add a video stream to update metadata - frame_index = NASA_VIDEO.frame_index if seek_mode == "frame_index" else None + # For custom_frame_mappings seek mode, add a video stream to update metadata + custom_frame_mappings = NASA_VIDEO.custom_frame_mappings if seek_mode == "custom_frame_mappings" else None # Add the best video stream (index 3 for NASA_VIDEO) add_video_stream( - decoder, stream_index=NASA_VIDEO.default_stream_index, frame_index=frame_index + decoder, stream_index=NASA_VIDEO.default_stream_index, custom_frame_mappings=custom_frame_mappings ) metadata = get_container_metadata(decoder) - with_scan = seek_mode == "exact" or seek_mode == "frame_index" + with_scan = seek_mode == "exact" or seek_mode == "custom_frame_mappings" assert len(metadata.streams) == 6 assert metadata.best_video_stream_index == 3 diff --git a/test/test_ops.py b/test/test_ops.py index 64e40de7..93946496 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -448,31 +448,31 @@ def test_frame_pts_equality(self): ) assert pts_is_equal - def test_seek_mode_frame_index_fails(self): - decoder = create_from_file(str(NASA_VIDEO.path), "frame_index") + def test_seek_mode_custom_frame_mappings_fails(self): + decoder = create_from_file(str(NASA_VIDEO.path), "custom_frame_mappings") with pytest.raises( RuntimeError, - match="Please provide a frame index when using frame_index seek mode.", + match="Please provide frame mappings when using custom_frame_mappings seek mode.", ): - add_video_stream(decoder, stream_index=0, frame_index=None) + add_video_stream(decoder, stream_index=0, custom_frame_mappings=None) - 
decoder = create_from_file(str(NASA_VIDEO.path), "frame_index") + decoder = create_from_file(str(NASA_VIDEO.path), "custom_frame_mappings") different_lengths = ((torch.tensor([1, 2, 3]), torch.tensor([1, 2]), torch.tensor([1, 2, 3]))) with pytest.raises( RuntimeError, match="all_frames, is_key_frame, and duration from custom_frame_mappings were not same size.", ): - add_video_stream(decoder, stream_index=0, frame_index=different_lengths) + add_video_stream(decoder, stream_index=0, custom_frame_mappings=different_lengths) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_seek_mode_frame_index(self, device): + def test_seek_mode_custom_frame_mappings(self, device): stream_index = 3 # frame index seek mode requires a stream index - decoder = create_from_file(str(NASA_VIDEO.path), "frame_index") + decoder = create_from_file(str(NASA_VIDEO.path), "custom_frame_mappings") add_video_stream( decoder, device=device, stream_index=stream_index, - frame_index=NASA_VIDEO.frame_index, + custom_frame_mappings=NASA_VIDEO.custom_frame_mappings, ) frame0, _, _ = get_next_frame(decoder) diff --git a/test/utils.py b/test/utils.py index 1fc4332e..7ee679c3 100644 --- a/test/utils.py +++ b/test/utils.py @@ -122,7 +122,7 @@ class TestContainerFile: default_stream_index: int stream_infos: Dict[int, Union[TestVideoStreamInfo, TestAudioStreamInfo]] frames: Dict[int, Dict[int, TestFrameInfo]] - frame_index_data: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None + custom_frame_mappings_data: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None def __post_init__(self): # We load the .frames attribute from the checked-in json files, if needed. @@ -225,18 +225,18 @@ def get_frame_info( return self.frames[stream_index][idx] - # This property is used to get the frame index metadata for the frame_index seek mode. + # This property is used to get the frame mappings for the custom_frame_mappings seek mode. 
@property - def frame_index( + def custom_frame_mappings( self, stream_index: Optional[int] = None ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if stream_index is None: stream_index = self.default_stream_index - if self.frame_index_data is None: - self.get_frame_index(stream_index) - return self.frame_index_data + if self.custom_frame_mappings_data is None: + self.get_custom_frame_mappings(stream_index) + return self.custom_frame_mappings_data - def get_frame_index(self, stream_index: int) -> None: + def get_custom_frame_mappings(self, stream_index: int) -> None: show_frames_result = json.loads( subprocess.run( [ @@ -254,24 +254,24 @@ def get_frame_index(self, stream_index: int) -> None: text=True, ).stdout ) - frame_index_data = ([], [], []) + custom_frame_mappings_data = ([], [], []) frames = show_frames_result["frames"] for frame in frames: - frame_index_data[0].append(float(frame["pts"])) - frame_index_data[1].append(frame["key_frame"]) - frame_index_data[2].append(float(frame["duration"])) + custom_frame_mappings_data[0].append(float(frame["pts"])) + custom_frame_mappings_data[1].append(frame["key_frame"]) + custom_frame_mappings_data[2].append(float(frame["duration"])) - (pts_list, key_frame_list, duration_list) = frame_index_data + (pts_list, key_frame_list, duration_list) = custom_frame_mappings_data # Zip the lists together, sort by pts, then unzip assert ( len(pts_list) == len(key_frame_list) == len(duration_list) ), "Mismatched lengths in frame index data" combined = list(zip(pts_list, key_frame_list, duration_list)) combined.sort(key=lambda x: x[0]) - pts_sorted, key_frame_sorted, duration_sorted = zip(*combined) - self.frame_index_data = ( + pts_sorted, is_key_frame_sorted, duration_sorted = zip(*combined) + self.custom_frame_mappings_data = ( torch.tensor(pts_sorted), - torch.tensor(key_frame_sorted), + torch.tensor(is_key_frame_sorted), torch.tensor(duration_sorted), ) From 3341dd835fcbde30e9de46a472c84410ba34f531 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Thu, 10 Jul 2025 12:17:39 -0400 Subject: [PATCH 08/26] lints --- src/torchcodec/_core/SingleStreamDecoder.cpp | 12 +++++++----- src/torchcodec/_core/SingleStreamDecoder.h | 4 ++-- src/torchcodec/_core/custom_ops.cpp | 8 ++++---- src/torchcodec/_core/ops.py | 8 ++++++-- test/test_metadata.py | 10 ++++++++-- test/test_ops.py | 10 ++++++++-- test/utils.py | 4 +++- 7 files changed, 38 insertions(+), 18 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 1b45347c..ef218145 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -326,7 +326,8 @@ void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex( auto& is_key_frame = std::get<1>(customFrameMappings); auto& duration = std::get<2>(customFrameMappings); TORCH_CHECK( - all_frames.size(0) == is_key_frame.size(0) && is_key_frame.size(0) == duration.size(0), + all_frames.size(0) == is_key_frame.size(0) && + is_key_frame.size(0) == duration.size(0), "all_frames, is_key_frame, and duration from custom_frame_mappings were not same size."); auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex]; @@ -346,8 +347,7 @@ void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex( for (int64_t i = 0; i < all_frames.size(0); ++i) { // FrameInfo struct utilizes PTS FrameInfo frameInfo = {all_frames[i].item()}; - frameInfo.isKeyFrame = - (is_key_frame[i].item() == 1); + frameInfo.isKeyFrame = 
(is_key_frame[i].item() == 1); frameInfo.nextPts = (i + 1 < all_frames.size(0)) ? all_frames[i + 1].item() : INT64_MAX; @@ -468,7 +468,8 @@ void SingleStreamDecoder::addStream( void SingleStreamDecoder::addVideoStream( int streamIndex, const VideoStreamOptions& videoStreamOptions, - std::optional> customFrameMappings) { + std::optional> + customFrameMappings) { addStream( streamIndex, AVMEDIA_TYPE_VIDEO, @@ -498,7 +499,8 @@ void SingleStreamDecoder::addVideoStream( TORCH_CHECK( customFrameMappings.has_value(), "Please provide frame mappings when using custom_frame_mappings seek mode."); - readCustomFrameMappingsUpdateMetadataAndIndex(streamIndex, customFrameMappings.value()); + readCustomFrameMappingsUpdateMetadataAndIndex( + streamIndex, customFrameMappings.value()); } } diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 849d2484..5854287c 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -74,8 +74,8 @@ class SingleStreamDecoder { void addVideoStream( int streamIndex, const VideoStreamOptions& videoStreamOptions = VideoStreamOptions(), - std::optional> customFrameMappings = - std::nullopt); + std::optional> + customFrameMappings = std::nullopt); void addAudioStream( int streamIndex, const AudioStreamOptions& audioStreamOptions = AudioStreamOptions()); diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 39b0a4fc..e1db50f2 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -223,8 +223,8 @@ void _add_video_stream( std::optional dimension_order = std::nullopt, std::optional stream_index = std::nullopt, std::optional device = std::nullopt, - std::optional> custom_frame_mappings = - std::nullopt, + std::optional> + custom_frame_mappings = std::nullopt, std::optional color_conversion_library = std::nullopt) { VideoStreamOptions videoStreamOptions; videoStreamOptions.width = width; @@ -270,8 +270,8 @@ void add_video_stream( std::optional dimension_order = std::nullopt, std::optional stream_index = std::nullopt, std::optional device = std::nullopt, - std::optional> custom_frame_mappings = - std::nullopt) { + std::optional> + custom_frame_mappings = std::nullopt) { _add_video_stream( decoder, width, diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 3daa4370..21046a33 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -205,7 +205,9 @@ def _add_video_stream_abstract( dimension_order: Optional[str] = None, stream_index: Optional[int] = None, device: Optional[str] = None, - custom_frame_mappings: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, + custom_frame_mappings: Optional[ + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + ] = None, color_conversion_library: Optional[str] = None, ) -> None: return @@ -221,7 +223,9 @@ def add_video_stream_abstract( dimension_order: Optional[str] = None, stream_index: Optional[int] = None, device: Optional[str] = None, - custom_frame_mappings: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, + custom_frame_mappings: Optional[ + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + ] = None, ) -> None: return diff --git a/test/test_metadata.py b/test/test_metadata.py index 86d77612..4f2a0eb5 100644 --- a/test/test_metadata.py +++ b/test/test_metadata.py @@ -37,10 +37,16 @@ def test_get_metadata(seek_mode): decoder = create_from_file(str(NASA_VIDEO.path), seek_mode=seek_mode) # For 
custom_frame_mappings seek mode, add a video stream to update metadata - custom_frame_mappings = NASA_VIDEO.custom_frame_mappings if seek_mode == "custom_frame_mappings" else None + custom_frame_mappings = ( + NASA_VIDEO.custom_frame_mappings + if seek_mode == "custom_frame_mappings" + else None + ) # Add the best video stream (index 3 for NASA_VIDEO) add_video_stream( - decoder, stream_index=NASA_VIDEO.default_stream_index, custom_frame_mappings=custom_frame_mappings + decoder, + stream_index=NASA_VIDEO.default_stream_index, + custom_frame_mappings=custom_frame_mappings, ) metadata = get_container_metadata(decoder) diff --git a/test/test_ops.py b/test/test_ops.py index 93946496..19d6638c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -457,12 +457,18 @@ def test_seek_mode_custom_frame_mappings_fails(self): add_video_stream(decoder, stream_index=0, custom_frame_mappings=None) decoder = create_from_file(str(NASA_VIDEO.path), "custom_frame_mappings") - different_lengths = ((torch.tensor([1, 2, 3]), torch.tensor([1, 2]), torch.tensor([1, 2, 3]))) + different_lengths = ( + torch.tensor([1, 2, 3]), + torch.tensor([1, 2]), + torch.tensor([1, 2, 3]), + ) with pytest.raises( RuntimeError, match="all_frames, is_key_frame, and duration from custom_frame_mappings were not same size.", ): - add_video_stream(decoder, stream_index=0, custom_frame_mappings=different_lengths) + add_video_stream( + decoder, stream_index=0, custom_frame_mappings=different_lengths + ) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_seek_mode_custom_frame_mappings(self, device): diff --git a/test/utils.py b/test/utils.py index 7ee679c3..1a0b960d 100644 --- a/test/utils.py +++ b/test/utils.py @@ -122,7 +122,9 @@ class TestContainerFile: default_stream_index: int stream_infos: Dict[int, Union[TestVideoStreamInfo, TestAudioStreamInfo]] frames: Dict[int, Dict[int, TestFrameInfo]] - custom_frame_mappings_data: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None + custom_frame_mappings_data: Optional[ + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + ] = None def __post_init__(self): # We load the .frames attribute from the checked-in json files, if needed. 
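Before the struct refactor in the next patch, it helps to spell out what custom_frame_mappings expects on the Python side: three same-length 1D tensors holding per-frame pts values, key-frame flags, and durations. Below is a minimal sketch of building them from ffprobe's JSON output, mirroring the test helper above; the exact ffprobe flags, the int64 dtypes, and the helper name are assumptions for illustration, not part of the patches.

    import json
    import subprocess

    import torch


    def frame_mappings_from_ffprobe(path, stream_index):
        # Ask ffprobe for per-frame metadata of one stream as JSON.
        probe = json.loads(
            subprocess.run(
                [
                    "ffprobe",
                    "-v", "error",
                    "-select_streams", str(stream_index),
                    "-show_frames",
                    "-of", "json",
                    str(path),
                ],
                check=True,
                capture_output=True,
                text=True,
            ).stdout
        )
        frames = probe["frames"]
        # pts and duration are in the stream's timebase; key_frame is 0/1.
        all_frames = torch.tensor([int(f["pts"]) for f in frames], dtype=torch.int64)
        is_key_frame = torch.tensor([bool(f["key_frame"]) for f in frames], dtype=torch.bool)
        duration = torch.tensor([int(f["duration"]) for f in frames], dtype=torch.int64)
        assert len(all_frames) == len(is_key_frame) == len(duration)
        return all_frames, is_key_frame, duration
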
From e91416003e1909171ecd2fa08aa725369ca81b18 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Thu, 10 Jul 2025 15:41:00 -0400 Subject: [PATCH 09/26] Use CustomFrameMappings struct --- src/torchcodec/_core/SingleStreamDecoder.cpp | 13 ++++----- src/torchcodec/_core/SingleStreamDecoder.h | 30 ++++++++++++++------ src/torchcodec/_core/custom_ops.cpp | 9 +++++- 3 files changed, 35 insertions(+), 17 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index ef218145..3f623784 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -321,10 +321,10 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() { void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex( int streamIndex, - std::tuple customFrameMappings) { - auto& all_frames = std::get<0>(customFrameMappings); - auto& is_key_frame = std::get<1>(customFrameMappings); - auto& duration = std::get<2>(customFrameMappings); + CustomFrameMappings customFrameMappings) { + auto& all_frames = customFrameMappings.all_frames; + auto& is_key_frame = customFrameMappings.is_key_frame; + auto& duration = customFrameMappings.duration; TORCH_CHECK( all_frames.size(0) == is_key_frame.size(0) && is_key_frame.size(0) == duration.size(0), @@ -347,7 +347,7 @@ void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex( for (int64_t i = 0; i < all_frames.size(0); ++i) { // FrameInfo struct utilizes PTS FrameInfo frameInfo = {all_frames[i].item()}; - frameInfo.isKeyFrame = (is_key_frame[i].item() == 1); + frameInfo.isKeyFrame = (is_key_frame[i].item() == true); frameInfo.nextPts = (i + 1 < all_frames.size(0)) ? all_frames[i + 1].item() : INT64_MAX; @@ -468,8 +468,7 @@ void SingleStreamDecoder::addStream( void SingleStreamDecoder::addVideoStream( int streamIndex, const VideoStreamOptions& videoStreamOptions, - std::optional> - customFrameMappings) { + std::optional customFrameMappings) { addStream( streamIndex, AVMEDIA_TYPE_VIDEO, diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 5854287c..abf53018 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -53,13 +53,6 @@ class SingleStreamDecoder { // the allFrames and keyFrames vectors. void scanFileAndUpdateMetadataAndIndex(); - // Reads the user provided frame index and updates each StreamInfo's index, - // i.e. the allFrames and keyFrames vectors, and the - // endStreamPtsSecondsFromContent - void readCustomFrameMappingsUpdateMetadataAndIndex( - int streamIndex, - std::tuple customFrameMappings); - // Returns the metadata for the container. ContainerMetadata getContainerMetadata() const; @@ -67,6 +60,26 @@ class SingleStreamDecoder { // int64 values, where each value is the frame index for a key frame. torch::Tensor getKeyFrameIndices(); + struct CustomFrameMappings { + // all_frames is a 1D tensor of int64 values, where each value is the PTS + // for a frame in the stream. + // The size of all tensors in this struct must match. + torch::Tensor all_frames; + // is_key_frame is a 1D tensor of bool values, and indicates + // whether the corresponding frame in all_frames is a key frame. + torch::Tensor is_key_frame; + // duration is a 1D tensor of int64 values, where each value is the duration + // of the corresponding frame in all_frames. 
+ torch::Tensor duration; + }; + + // Reads the user provided frame index and updates each StreamInfo's index, + // i.e. the allFrames and keyFrames vectors, and + // endStreamPtsSecondsFromContent + void readCustomFrameMappingsUpdateMetadataAndIndex( + int streamIndex, + CustomFrameMappings customFrameMappings); + // -------------------------------------------------------------------------- // ADDING STREAMS API // -------------------------------------------------------------------------- @@ -74,8 +87,7 @@ class SingleStreamDecoder { void addVideoStream( int streamIndex, const VideoStreamOptions& videoStreamOptions = VideoStreamOptions(), - std::optional> - customFrameMappings = std::nullopt); + std::optional customFrameMappings = std::nullopt); void addAudioStream( int streamIndex, const AudioStreamOptions& audioStreamOptions = AudioStreamOptions()); diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index e1db50f2..d083c7d0 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -255,10 +255,17 @@ void _add_video_stream( if (device.has_value()) { videoStreamOptions.device = createTorchDevice(std::string(device.value())); } + std::optional converted_mappings = + custom_frame_mappings.has_value() + ? std::make_optional(SingleStreamDecoder::CustomFrameMappings{ + std::get<0>(custom_frame_mappings.value()), + std::get<1>(custom_frame_mappings.value()), + std::get<2>(custom_frame_mappings.value())}) + : std::nullopt; auto videoDecoder = unwrapTensorToGetDecoder(decoder); videoDecoder->addVideoStream( - stream_index.value_or(-1), videoStreamOptions, custom_frame_mappings); + stream_index.value_or(-1), videoStreamOptions, converted_mappings); } // Add a new video stream at `stream_index` using the provided options. 
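With the tuple converted into a struct at the op boundary, the Python-facing flow is in place end to end. A rough usage sketch of the core ops with the new seek mode follows; the file path, stream index, and tensor values are placeholders, and only op names that appear in these patches are used.

    import torch

    from torchcodec._core import (
        add_video_stream,
        create_from_file,
        get_container_metadata,
        get_next_frame,
    )

    # Placeholder mappings: pts and duration in timebase units, one entry per frame.
    all_frames = torch.tensor([0, 512, 1024], dtype=torch.int64)
    is_key_frame = torch.tensor([True, False, False])
    duration = torch.tensor([512, 512, 512], dtype=torch.int64)

    decoder = create_from_file("video.mp4", seek_mode="custom_frame_mappings")
    add_video_stream(
        decoder,
        stream_index=0,
        custom_frame_mappings=(all_frames, is_key_frame, duration),
    )

    # Stream metadata (frame count, begin/end pts) now comes from the mappings
    # rather than from a full file scan.
    metadata = get_container_metadata(decoder)
    frame, pts_seconds, duration_seconds = get_next_frame(decoder)

The same tuple is what addVideoStream forwards to readCustomFrameMappingsUpdateMetadataAndIndex on the C++ side.
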
From 394ac700cdf362c392a18ab3cc03241dd9ec13a8 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Thu, 10 Jul 2025 15:45:47 -0400 Subject: [PATCH 10/26] use seek_mode keyword in test_ops --- test/test_ops.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 19d6638c..382f957f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -449,14 +449,18 @@ def test_frame_pts_equality(self): assert pts_is_equal def test_seek_mode_custom_frame_mappings_fails(self): - decoder = create_from_file(str(NASA_VIDEO.path), "custom_frame_mappings") + decoder = create_from_file( + str(NASA_VIDEO.path), seek_mode="custom_frame_mappings" + ) with pytest.raises( RuntimeError, match="Please provide frame mappings when using custom_frame_mappings seek mode.", ): add_video_stream(decoder, stream_index=0, custom_frame_mappings=None) - decoder = create_from_file(str(NASA_VIDEO.path), "custom_frame_mappings") + decoder = create_from_file( + str(NASA_VIDEO.path), seek_mode="custom_frame_mappings" + ) different_lengths = ( torch.tensor([1, 2, 3]), torch.tensor([1, 2]), @@ -473,7 +477,9 @@ def test_seek_mode_custom_frame_mappings_fails(self): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_seek_mode_custom_frame_mappings(self, device): stream_index = 3 # frame index seek mode requires a stream index - decoder = create_from_file(str(NASA_VIDEO.path), "custom_frame_mappings") + decoder = create_from_file( + str(NASA_VIDEO.path), seek_mode="custom_frame_mappings" + ) add_video_stream( decoder, device=device, From f13433c8e986ac0a9b691ae5f5cef0b9dbe406f7 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Thu, 10 Jul 2025 16:26:49 -0400 Subject: [PATCH 11/26] Turn custom_frame_mappings_data into dict, use list comprehension --- test/test_metadata.py | 2 +- test/test_ops.py | 4 +++- test/utils.py | 38 ++++++++++++++++---------------------- 3 files changed, 20 insertions(+), 24 deletions(-) diff --git a/test/test_metadata.py b/test/test_metadata.py index 4f2a0eb5..47ed80c5 100644 --- a/test/test_metadata.py +++ b/test/test_metadata.py @@ -38,7 +38,7 @@ def test_get_metadata(seek_mode): decoder = create_from_file(str(NASA_VIDEO.path), seek_mode=seek_mode) # For custom_frame_mappings seek mode, add a video stream to update metadata custom_frame_mappings = ( - NASA_VIDEO.custom_frame_mappings + NASA_VIDEO.get_custom_frame_mappings() if seek_mode == "custom_frame_mappings" else None ) diff --git a/test/test_ops.py b/test/test_ops.py index 382f957f..09d70f85 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -484,7 +484,9 @@ def test_seek_mode_custom_frame_mappings(self, device): decoder, device=device, stream_index=stream_index, - custom_frame_mappings=NASA_VIDEO.custom_frame_mappings, + custom_frame_mappings=NASA_VIDEO.get_custom_frame_mappings( + stream_index=stream_index + ), ) frame0, _, _ = get_next_frame(decoder) diff --git a/test/utils.py b/test/utils.py index 1a0b960d..cd466983 100644 --- a/test/utils.py +++ b/test/utils.py @@ -122,9 +122,9 @@ class TestContainerFile: default_stream_index: int stream_infos: Dict[int, Union[TestVideoStreamInfo, TestAudioStreamInfo]] frames: Dict[int, Dict[int, TestFrameInfo]] - custom_frame_mappings_data: Optional[ - tuple[torch.Tensor, torch.Tensor, torch.Tensor] - ] = None + custom_frame_mappings_data: Dict[ + int, Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] + ] = field(default_factory=dict) def __post_init__(self): # We load the .frames attribute from the checked-in json files, if needed. 
@@ -227,19 +227,18 @@ def get_frame_info( return self.frames[stream_index][idx] - # This property is used to get the frame mappings for the custom_frame_mappings seek mode. - @property - def custom_frame_mappings( + # This function is used to get the frame mappings for the custom_frame_mappings seek mode. + def get_custom_frame_mappings( self, stream_index: Optional[int] = None ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if stream_index is None: stream_index = self.default_stream_index - if self.custom_frame_mappings_data is None: - self.get_custom_frame_mappings(stream_index) - return self.custom_frame_mappings_data + if self.custom_frame_mappings_data.get(stream_index) is None: + self.generate_custom_frame_mappings(stream_index) + return self.custom_frame_mappings_data[stream_index] - def get_custom_frame_mappings(self, stream_index: int) -> None: - show_frames_result = json.loads( + def generate_custom_frame_mappings(self, stream_index: int) -> None: + result = json.loads( subprocess.run( [ "ffprobe", @@ -256,22 +255,17 @@ def get_custom_frame_mappings(self, stream_index: int) -> None: text=True, ).stdout ) - custom_frame_mappings_data = ([], [], []) - frames = show_frames_result["frames"] - for frame in frames: - custom_frame_mappings_data[0].append(float(frame["pts"])) - custom_frame_mappings_data[1].append(frame["key_frame"]) - custom_frame_mappings_data[2].append(float(frame["duration"])) - - (pts_list, key_frame_list, duration_list) = custom_frame_mappings_data + pts_list = [float(frame["pts"]) for frame in result["frames"]] + is_key_frame_list = [frame["key_frame"] for frame in result["frames"]] + duration_list = [float(frame["duration"]) for frame in result["frames"]] # Zip the lists together, sort by pts, then unzip assert ( - len(pts_list) == len(key_frame_list) == len(duration_list) + len(pts_list) == len(is_key_frame_list) == len(duration_list) ), "Mismatched lengths in frame index data" - combined = list(zip(pts_list, key_frame_list, duration_list)) + combined = list(zip(pts_list, is_key_frame_list, duration_list)) combined.sort(key=lambda x: x[0]) pts_sorted, is_key_frame_sorted, duration_sorted = zip(*combined) - self.custom_frame_mappings_data = ( + self.custom_frame_mappings_data[stream_index] = ( torch.tensor(pts_sorted), torch.tensor(is_key_frame_sorted), torch.tensor(duration_sorted), From 8971bd6afa3349bf08e44e1de4bbd07430944617 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Thu, 10 Jul 2025 16:59:24 -0400 Subject: [PATCH 12/26] Remove c++ struct CustomFrameMappings --- src/torchcodec/_core/SingleStreamDecoder.cpp | 11 +++---- src/torchcodec/_core/SingleStreamDecoder.h | 30 ++++++-------------- src/torchcodec/_core/custom_ops.cpp | 10 +------ 3 files changed, 16 insertions(+), 35 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 3f623784..83132d90 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -321,10 +321,10 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() { void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex( int streamIndex, - CustomFrameMappings customFrameMappings) { - auto& all_frames = customFrameMappings.all_frames; - auto& is_key_frame = customFrameMappings.is_key_frame; - auto& duration = customFrameMappings.duration; + std::tuple customFrameMappings) { + auto& all_frames = std::get<0>(customFrameMappings); + auto& is_key_frame = std::get<1>(customFrameMappings); + auto& 
duration = std::get<2>(customFrameMappings); TORCH_CHECK( all_frames.size(0) == is_key_frame.size(0) && is_key_frame.size(0) == duration.size(0), @@ -468,7 +468,8 @@ void SingleStreamDecoder::addStream( void SingleStreamDecoder::addVideoStream( int streamIndex, const VideoStreamOptions& videoStreamOptions, - std::optional customFrameMappings) { + std::optional> + customFrameMappings) { addStream( streamIndex, AVMEDIA_TYPE_VIDEO, diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index abf53018..622428cb 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -53,6 +53,13 @@ class SingleStreamDecoder { // the allFrames and keyFrames vectors. void scanFileAndUpdateMetadataAndIndex(); + // Reads the user provided frame index and updates each StreamInfo's index, + // i.e. the allFrames and keyFrames vectors, and + // endStreamPtsSecondsFromContent + void readCustomFrameMappingsUpdateMetadataAndIndex( + int streamIndex, + std::tuple customFrameMappings); + // Returns the metadata for the container. ContainerMetadata getContainerMetadata() const; @@ -60,26 +67,6 @@ class SingleStreamDecoder { // int64 values, where each value is the frame index for a key frame. torch::Tensor getKeyFrameIndices(); - struct CustomFrameMappings { - // all_frames is a 1D tensor of int64 values, where each value is the PTS - // for a frame in the stream. - // The size of all tensors in this struct must match. - torch::Tensor all_frames; - // is_key_frame is a 1D tensor of bool values, and indicates - // whether the corresponding frame in all_frames is a key frame. - torch::Tensor is_key_frame; - // duration is a 1D tensor of int64 values, where each value is the duration - // of the corresponding frame in all_frames. - torch::Tensor duration; - }; - - // Reads the user provided frame index and updates each StreamInfo's index, - // i.e. the allFrames and keyFrames vectors, and - // endStreamPtsSecondsFromContent - void readCustomFrameMappingsUpdateMetadataAndIndex( - int streamIndex, - CustomFrameMappings customFrameMappings); - // -------------------------------------------------------------------------- // ADDING STREAMS API // -------------------------------------------------------------------------- @@ -87,7 +74,8 @@ class SingleStreamDecoder { void addVideoStream( int streamIndex, const VideoStreamOptions& videoStreamOptions = VideoStreamOptions(), - std::optional customFrameMappings = std::nullopt); + std::optional> + customFrameMappings = std::nullopt); void addAudioStream( int streamIndex, const AudioStreamOptions& audioStreamOptions = AudioStreamOptions()); diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index d083c7d0..d394d8f0 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -255,17 +255,9 @@ void _add_video_stream( if (device.has_value()) { videoStreamOptions.device = createTorchDevice(std::string(device.value())); } - std::optional converted_mappings = - custom_frame_mappings.has_value() - ? 
std::make_optional(SingleStreamDecoder::CustomFrameMappings{ - std::get<0>(custom_frame_mappings.value()), - std::get<1>(custom_frame_mappings.value()), - std::get<2>(custom_frame_mappings.value())}) - : std::nullopt; - auto videoDecoder = unwrapTensorToGetDecoder(decoder); videoDecoder->addVideoStream( - stream_index.value_or(-1), videoStreamOptions, converted_mappings); + stream_index.value_or(-1), videoStreamOptions, custom_frame_mappings); } // Add a new video stream at `stream_index` using the provided options. From 5d66ef1b00ce10e2a2d43ea66002b0a7806327dd Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Mon, 14 Jul 2025 11:19:29 -0400 Subject: [PATCH 13/26] underscore custom_frame_mappings_data field --- test/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/utils.py b/test/utils.py index cd466983..232e045c 100644 --- a/test/utils.py +++ b/test/utils.py @@ -122,7 +122,7 @@ class TestContainerFile: default_stream_index: int stream_infos: Dict[int, Union[TestVideoStreamInfo, TestAudioStreamInfo]] frames: Dict[int, Dict[int, TestFrameInfo]] - custom_frame_mappings_data: Dict[ + _custom_frame_mappings_data: Dict[ int, Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] ] = field(default_factory=dict) @@ -233,9 +233,9 @@ def get_custom_frame_mappings( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if stream_index is None: stream_index = self.default_stream_index - if self.custom_frame_mappings_data.get(stream_index) is None: + if self._custom_frame_mappings_data.get(stream_index) is None: self.generate_custom_frame_mappings(stream_index) - return self.custom_frame_mappings_data[stream_index] + return self._custom_frame_mappings_data[stream_index] def generate_custom_frame_mappings(self, stream_index: int) -> None: result = json.loads( @@ -265,7 +265,7 @@ def generate_custom_frame_mappings(self, stream_index: int) -> None: combined = list(zip(pts_list, is_key_frame_list, duration_list)) combined.sort(key=lambda x: x[0]) pts_sorted, is_key_frame_sorted, duration_sorted = zip(*combined) - self.custom_frame_mappings_data[stream_index] = ( + self._custom_frame_mappings_data[stream_index] = ( torch.tensor(pts_sorted), torch.tensor(is_key_frame_sorted), torch.tensor(duration_sorted), From 4907e6d179a7ccdf11c5a3665724eb38ef03033d Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Mon, 14 Jul 2025 11:38:11 -0400 Subject: [PATCH 14/26] set keyFrames in readCustomFrameMappingsUpdateMetadataAndIndex --- src/torchcodec/_core/SingleStreamDecoder.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 83132d90..95a82d30 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -346,12 +346,15 @@ void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex( streamMetadata.numFramesFromContent = all_frames.size(0); for (int64_t i = 0; i < all_frames.size(0); ++i) { // FrameInfo struct utilizes PTS - FrameInfo frameInfo = {all_frames[i].item()}; + FrameInfo frameInfo = {.pts=all_frames[i].item()}; frameInfo.isKeyFrame = (is_key_frame[i].item() == true); frameInfo.nextPts = (i + 1 < all_frames.size(0)) ? 
all_frames[i + 1].item() : INT64_MAX; streamInfos_[streamIndex].allFrames.push_back(frameInfo); + if (frameInfo.isKeyFrame) { + streamInfos_[streamIndex].keyFrames.push_back(frameInfo); + } } } From d9cec941314edf4cb47f94c7c078e87bd6a524a4 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Mon, 14 Jul 2025 14:56:18 -0400 Subject: [PATCH 15/26] Restore parameterized metadata_getter --- src/torchcodec/_core/SingleStreamDecoder.cpp | 4 +- test/test_metadata.py | 53 +++++++++++++------- 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 95a82d30..b107cc87 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -346,7 +346,7 @@ void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex( streamMetadata.numFramesFromContent = all_frames.size(0); for (int64_t i = 0; i < all_frames.size(0); ++i) { // FrameInfo struct utilizes PTS - FrameInfo frameInfo = {.pts=all_frames[i].item()}; + FrameInfo frameInfo = {.pts = all_frames[i].item()}; frameInfo.isKeyFrame = (is_key_frame[i].item() == true); frameInfo.nextPts = (i + 1 < all_frames.size(0)) ? all_frames[i + 1].item() @@ -354,7 +354,7 @@ void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex( streamInfos_[streamIndex].allFrames.push_back(frameInfo); if (frameInfo.isKeyFrame) { streamInfos_[streamIndex].keyFrames.push_back(frameInfo); - } + } } } diff --git a/test/test_metadata.py b/test/test_metadata.py index 47ed80c5..8d045dc4 100644 --- a/test/test_metadata.py +++ b/test/test_metadata.py @@ -10,6 +10,7 @@ import pytest from torchcodec._core import ( + add_video_stream, AudioStreamMetadata, create_from_file, get_container_metadata, @@ -28,29 +29,43 @@ def _get_container_metadata(path, seek_mode): decoder = create_from_file(str(path), seek_mode=seek_mode) - return get_container_metadata(decoder) + # For custom_frame_mappings seek mode, add a video stream to update metadata + if seek_mode == "custom_frame_mappings": + custom_frame_mappings = NASA_VIDEO.get_custom_frame_mappings() + + # Add the best video stream (index 3 for NASA_VIDEO) + add_video_stream( + decoder, + stream_index=NASA_VIDEO.default_stream_index, + custom_frame_mappings=custom_frame_mappings, + ) + return get_container_metadata(decoder) -@pytest.mark.parametrize("seek_mode", ["approximate", "exact", "custom_frame_mappings"]) -def test_get_metadata(seek_mode): - from torchcodec._core import add_video_stream - decoder = create_from_file(str(NASA_VIDEO.path), seek_mode=seek_mode) - # For custom_frame_mappings seek mode, add a video stream to update metadata - custom_frame_mappings = ( - NASA_VIDEO.get_custom_frame_mappings() - if seek_mode == "custom_frame_mappings" +@pytest.mark.parametrize( + "metadata_getter", + ( + get_container_metadata_from_header, + functools.partial(_get_container_metadata, seek_mode="approximate"), + functools.partial(_get_container_metadata, seek_mode="exact"), + functools.partial(_get_container_metadata, seek_mode="custom_frame_mappings"), + ), +) +def test_get_metadata(metadata_getter): + seek_mode = ( + metadata_getter.keywords["seek_mode"] + if isinstance(metadata_getter, functools.partial) else None ) - # Add the best video stream (index 3 for NASA_VIDEO) - add_video_stream( - decoder, - stream_index=NASA_VIDEO.default_stream_index, - custom_frame_mappings=custom_frame_mappings, - ) - metadata = get_container_metadata(decoder) + with_added_video_stream = 
seek_mode == "custom_frame_mappings" + metadata = metadata_getter(NASA_VIDEO.path) - with_scan = seek_mode == "exact" or seek_mode == "custom_frame_mappings" + with_scan = ( + (seek_mode == "exact" or seek_mode == "custom_frame_mappings") + if isinstance(metadata_getter, functools.partial) + else False + ) assert len(metadata.streams) == 6 assert metadata.best_video_stream_index == 3 @@ -85,7 +100,9 @@ def test_get_metadata(seek_mode): assert best_video_stream_metadata.begin_stream_seconds_from_header == 0 assert best_video_stream_metadata.bit_rate == 128783 assert best_video_stream_metadata.average_fps == pytest.approx(29.97, abs=0.001) - assert best_video_stream_metadata.pixel_aspect_ratio == Fraction(1, 1) + assert best_video_stream_metadata.pixel_aspect_ratio == ( + Fraction(1, 1) if with_added_video_stream else None + ) assert best_video_stream_metadata.codec == "h264" assert best_video_stream_metadata.num_frames_from_content == ( 390 if with_scan else None From 72ae22dead509c5838398b569781e740e042665a Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Mon, 14 Jul 2025 15:48:44 -0400 Subject: [PATCH 16/26] Add frameMappings struct to wrap tensor tuple --- src/torchcodec/_core/Frame.h | 14 ++++++++++++++ src/torchcodec/_core/SingleStreamDecoder.cpp | 11 +++++------ src/torchcodec/_core/SingleStreamDecoder.h | 5 ++--- src/torchcodec/_core/custom_ops.cpp | 14 +++++++++++++- 4 files changed, 34 insertions(+), 10 deletions(-) diff --git a/src/torchcodec/_core/Frame.h b/src/torchcodec/_core/Frame.h index 84ccc728..014c954d 100644 --- a/src/torchcodec/_core/Frame.h +++ b/src/torchcodec/_core/Frame.h @@ -45,6 +45,20 @@ struct AudioFramesOutput { double ptsSeconds; }; +// FrameMappings is used for the custom_frame_mappings seek mode to store +// metadata of frames in a stream. The size of all tensors in this struct must +// match. +struct FrameMappings { + // 1D tensor of int64, each value is the PTS of a frame in timebase units. + torch::Tensor all_frames; + // 1D tensor of bool, each value indicates if the corresponding frame in + // all_frames is a key frame. + torch::Tensor is_key_frame; + // 1D tensor of int64, each value is the duration of the corresponding frame + // in all_frames in timebase units. 
+ torch::Tensor duration; +}; + // -------------------------------------------------------------------------- // FRAME TENSOR ALLOCATION APIs // -------------------------------------------------------------------------- diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index b107cc87..f835740d 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -321,10 +321,10 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() { void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex( int streamIndex, - std::tuple customFrameMappings) { - auto& all_frames = std::get<0>(customFrameMappings); - auto& is_key_frame = std::get<1>(customFrameMappings); - auto& duration = std::get<2>(customFrameMappings); + FrameMappings customFrameMappings) { + auto& all_frames = customFrameMappings.all_frames; + auto& is_key_frame = customFrameMappings.is_key_frame; + auto& duration = customFrameMappings.duration; TORCH_CHECK( all_frames.size(0) == is_key_frame.size(0) && is_key_frame.size(0) == duration.size(0), @@ -471,8 +471,7 @@ void SingleStreamDecoder::addStream( void SingleStreamDecoder::addVideoStream( int streamIndex, const VideoStreamOptions& videoStreamOptions, - std::optional> - customFrameMappings) { + std::optional customFrameMappings) { addStream( streamIndex, AVMEDIA_TYPE_VIDEO, diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 622428cb..453c8d68 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -58,7 +58,7 @@ class SingleStreamDecoder { // endStreamPtsSecondsFromContent void readCustomFrameMappingsUpdateMetadataAndIndex( int streamIndex, - std::tuple customFrameMappings); + FrameMappings customFrameMappings); // Returns the metadata for the container. ContainerMetadata getContainerMetadata() const; @@ -74,8 +74,7 @@ class SingleStreamDecoder { void addVideoStream( int streamIndex, const VideoStreamOptions& videoStreamOptions = VideoStreamOptions(), - std::optional> - customFrameMappings = std::nullopt); + std::optional customFrameMappings = std::nullopt); void addAudioStream( int streamIndex, const AudioStreamOptions& audioStreamOptions = AudioStreamOptions()); diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index d394d8f0..62637e6f 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -105,6 +105,14 @@ OpsFrameOutput makeOpsFrameOutput(FrameOutput& frame) { torch::tensor(frame.durationSeconds, torch::dtype(torch::kFloat64))); } +FrameMappings makeFrameMappings( + std::tuple custom_frame_mappings) { + return FrameMappings{ + std::get<0>(custom_frame_mappings), + std::get<1>(custom_frame_mappings), + std::get<2>(custom_frame_mappings)}; +} + // All elements of this tuple are tensors of the same leading dimension. The // tuple represents the frames for N total frames, where N is the dimension of // each stacked tensor. The elments are: @@ -255,9 +263,13 @@ void _add_video_stream( if (device.has_value()) { videoStreamOptions.device = createTorchDevice(std::string(device.value())); } + std::optional converted_mappings = + custom_frame_mappings.has_value() + ? 
std::make_optional(makeFrameMappings(custom_frame_mappings.value())) + : std::nullopt; auto videoDecoder = unwrapTensorToGetDecoder(decoder); videoDecoder->addVideoStream( - stream_index.value_or(-1), videoStreamOptions, custom_frame_mappings); + stream_index.value_or(-1), videoStreamOptions, converted_mappings); } // Add a new video stream at `stream_index` using the provided options. From 34eb7b44ff121e4586fc7e1b141a55dc6e99db84 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Mon, 14 Jul 2025 19:51:32 -0400 Subject: [PATCH 17/26] Extract sorting code to function --- src/torchcodec/_core/SingleStreamDecoder.cpp | 71 +++++++++++--------- src/torchcodec/_core/SingleStreamDecoder.h | 3 + test/test_ops.py | 2 +- test/utils.py | 20 +++--- 4 files changed, 52 insertions(+), 44 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index f835740d..abc6b803 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -198,6 +198,41 @@ int SingleStreamDecoder::getBestStreamIndex(AVMediaType mediaType) { // VIDEO METADATA QUERY API // -------------------------------------------------------------------------- +void SingleStreamDecoder::sortAllFrames() { + for (auto& [streamIndex, streamInfo] : streamInfos_) { + std::sort( + streamInfo.keyFrames.begin(), + streamInfo.keyFrames.end(), + [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) { + return frameInfo1.pts < frameInfo2.pts; + }); + std::sort( + streamInfo.allFrames.begin(), + streamInfo.allFrames.end(), + [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) { + return frameInfo1.pts < frameInfo2.pts; + }); + + size_t keyFrameIndex = 0; + for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) { + streamInfo.allFrames[i].frameIndex = i; + if (streamInfo.allFrames[i].isKeyFrame) { + TORCH_CHECK( + keyFrameIndex < streamInfo.keyFrames.size(), + "The allFrames vec claims it has MORE keyFrames than the keyFrames vec. There's a bug in torchcodec."); + streamInfo.keyFrames[keyFrameIndex].frameIndex = i; + ++keyFrameIndex; + } + if (i + 1 < streamInfo.allFrames.size()) { + streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts; + } + } + TORCH_CHECK( + keyFrameIndex == streamInfo.keyFrames.size(), + "The allFrames vec claims it has LESS keyFrames than the keyFrames vec. There's a bug in torchcodec."); + } +} + void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() { if (scannedAllStreams_) { return; @@ -283,39 +318,7 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() { getFFMPEGErrorStringFromErrorCode(status)); // Sort all frames by their pts. - for (auto& [streamIndex, streamInfo] : streamInfos_) { - std::sort( - streamInfo.keyFrames.begin(), - streamInfo.keyFrames.end(), - [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) { - return frameInfo1.pts < frameInfo2.pts; - }); - std::sort( - streamInfo.allFrames.begin(), - streamInfo.allFrames.end(), - [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) { - return frameInfo1.pts < frameInfo2.pts; - }); - - size_t keyFrameIndex = 0; - for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) { - streamInfo.allFrames[i].frameIndex = i; - if (streamInfo.allFrames[i].isKeyFrame) { - TORCH_CHECK( - keyFrameIndex < streamInfo.keyFrames.size(), - "The allFrames vec claims it has MORE keyFrames than the keyFrames vec. 
There's a bug in torchcodec."); - streamInfo.keyFrames[keyFrameIndex].frameIndex = i; - ++keyFrameIndex; - } - if (i + 1 < streamInfo.allFrames.size()) { - streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts; - } - } - TORCH_CHECK( - keyFrameIndex == streamInfo.keyFrames.size(), - "The allFrames vec claims it has LESS keyFrames than the keyFrames vec. There's a bug in torchcodec."); - } - + sortAllFrames(); scannedAllStreams_ = true; } @@ -356,6 +359,8 @@ void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex( streamInfos_[streamIndex].keyFrames.push_back(frameInfo); } } + // Sort all frames by their pts + sortAllFrames(); } ContainerMetadata SingleStreamDecoder::getContainerMetadata() const { diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 453c8d68..21925528 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -60,6 +60,9 @@ class SingleStreamDecoder { int streamIndex, FrameMappings customFrameMappings); + // Sorts the keyFrames and allFrames vectors in each StreamInfo by pts. + void sortAllFrames(); + // Returns the metadata for the container. ContainerMetadata getContainerMetadata() const; diff --git a/test/test_ops.py b/test/test_ops.py index 09d70f85..407e7c9f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -476,7 +476,7 @@ def test_seek_mode_custom_frame_mappings_fails(self): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_seek_mode_custom_frame_mappings(self, device): - stream_index = 3 # frame index seek mode requires a stream index + stream_index = 3 # custom_frame_index seek mode requires a stream index decoder = create_from_file( str(NASA_VIDEO.path), seek_mode="custom_frame_mappings" ) diff --git a/test/utils.py b/test/utils.py index 232e045c..43502c5c 100644 --- a/test/utils.py +++ b/test/utils.py @@ -255,20 +255,20 @@ def generate_custom_frame_mappings(self, stream_index: int) -> None: text=True, ).stdout ) - pts_list = [float(frame["pts"]) for frame in result["frames"]] - is_key_frame_list = [frame["key_frame"] for frame in result["frames"]] - duration_list = [float(frame["duration"]) for frame in result["frames"]] - # Zip the lists together, sort by pts, then unzip + pts_list = torch.tensor([float(frame["pts"]) for frame in result["frames"]]) + is_key_frame_list = torch.tensor( + [frame["key_frame"] for frame in result["frames"]] + ) + duration_list = torch.tensor( + [float(frame["duration"]) for frame in result["frames"]] + ) assert ( len(pts_list) == len(is_key_frame_list) == len(duration_list) ), "Mismatched lengths in frame index data" - combined = list(zip(pts_list, is_key_frame_list, duration_list)) - combined.sort(key=lambda x: x[0]) - pts_sorted, is_key_frame_sorted, duration_sorted = zip(*combined) self._custom_frame_mappings_data[stream_index] = ( - torch.tensor(pts_sorted), - torch.tensor(is_key_frame_sorted), - torch.tensor(duration_sorted), + pts_list, + is_key_frame_list, + duration_list, ) @property From 9a343b3b316bbc5df0c4ed95d3d89602d1e9bcd1 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Tue, 15 Jul 2025 09:26:25 -0400 Subject: [PATCH 18/26] move FrameMappings struct to singlestreamdecoder.h --- src/torchcodec/_core/Frame.h | 14 ----------- src/torchcodec/_core/SingleStreamDecoder.h | 28 ++++++++++++++++------ src/torchcodec/_core/custom_ops.cpp | 6 ++--- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/torchcodec/_core/Frame.h b/src/torchcodec/_core/Frame.h index 
014c954d..84ccc728 100644 --- a/src/torchcodec/_core/Frame.h +++ b/src/torchcodec/_core/Frame.h @@ -45,20 +45,6 @@ struct AudioFramesOutput { double ptsSeconds; }; -// FrameMappings is used for the custom_frame_mappings seek mode to store -// metadata of frames in a stream. The size of all tensors in this struct must -// match. -struct FrameMappings { - // 1D tensor of int64, each value is the PTS of a frame in timebase units. - torch::Tensor all_frames; - // 1D tensor of bool, each value indicates if the corresponding frame in - // all_frames is a key frame. - torch::Tensor is_key_frame; - // 1D tensor of int64, each value is the duration of the corresponding frame - // in all_frames in timebase units. - torch::Tensor duration; -}; - // -------------------------------------------------------------------------- // FRAME TENSOR ALLOCATION APIs // -------------------------------------------------------------------------- diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 21925528..39d294fe 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -53,13 +53,6 @@ class SingleStreamDecoder { // the allFrames and keyFrames vectors. void scanFileAndUpdateMetadataAndIndex(); - // Reads the user provided frame index and updates each StreamInfo's index, - // i.e. the allFrames and keyFrames vectors, and - // endStreamPtsSecondsFromContent - void readCustomFrameMappingsUpdateMetadataAndIndex( - int streamIndex, - FrameMappings customFrameMappings); - // Sorts the keyFrames and allFrames vectors in each StreamInfo by pts. void sortAllFrames(); @@ -70,6 +63,27 @@ class SingleStreamDecoder { // int64 values, where each value is the frame index for a key frame. torch::Tensor getKeyFrameIndices(); +// FrameMappings is used for the custom_frame_mappings seek mode to store +// metadata of frames in a stream. The size of all tensors in this struct must +// match. +struct FrameMappings { + // 1D tensor of int64, each value is the PTS of a frame in timebase units. + torch::Tensor all_frames; + // 1D tensor of bool, each value indicates if the corresponding frame in + // all_frames is a key frame. + torch::Tensor is_key_frame; + // 1D tensor of int64, each value is the duration of the corresponding frame + // in all_frames in timebase units. + torch::Tensor duration; +}; + + // Reads the user provided frame index and updates each StreamInfo's index, + // i.e. 
the allFrames and keyFrames vectors, and + // endStreamPtsSecondsFromContent + void readCustomFrameMappingsUpdateMetadataAndIndex( + int streamIndex, + FrameMappings customFrameMappings); + // -------------------------------------------------------------------------- // ADDING STREAMS API // -------------------------------------------------------------------------- diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 62637e6f..2bae359a 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -105,9 +105,9 @@ OpsFrameOutput makeOpsFrameOutput(FrameOutput& frame) { torch::tensor(frame.durationSeconds, torch::dtype(torch::kFloat64))); } -FrameMappings makeFrameMappings( +SingleStreamDecoder::FrameMappings makeFrameMappings( std::tuple custom_frame_mappings) { - return FrameMappings{ + return SingleStreamDecoder::FrameMappings{ std::get<0>(custom_frame_mappings), std::get<1>(custom_frame_mappings), std::get<2>(custom_frame_mappings)}; @@ -263,7 +263,7 @@ void _add_video_stream( if (device.has_value()) { videoStreamOptions.device = createTorchDevice(std::string(device.value())); } - std::optional converted_mappings = + std::optional converted_mappings = custom_frame_mappings.has_value() ? std::make_optional(makeFrameMappings(custom_frame_mappings.value())) : std::nullopt; From 7728895cd1bd3fb1bed84ee6962d441aff5610d8 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Tue, 15 Jul 2025 09:40:01 -0400 Subject: [PATCH 19/26] Rename variables in generate_custom_frame_mappings --- src/torchcodec/_core/SingleStreamDecoder.h | 26 +++++++++++----------- test/utils.py | 16 ++++++------- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 39d294fe..83dc92ae 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -63,19 +63,19 @@ class SingleStreamDecoder { // int64 values, where each value is the frame index for a key frame. torch::Tensor getKeyFrameIndices(); -// FrameMappings is used for the custom_frame_mappings seek mode to store -// metadata of frames in a stream. The size of all tensors in this struct must -// match. -struct FrameMappings { - // 1D tensor of int64, each value is the PTS of a frame in timebase units. - torch::Tensor all_frames; - // 1D tensor of bool, each value indicates if the corresponding frame in - // all_frames is a key frame. - torch::Tensor is_key_frame; - // 1D tensor of int64, each value is the duration of the corresponding frame - // in all_frames in timebase units. - torch::Tensor duration; -}; + // FrameMappings is used for the custom_frame_mappings seek mode to store + // metadata of frames in a stream. The size of all tensors in this struct must + // match. + struct FrameMappings { + // 1D tensor of int64, each value is the PTS of a frame in timebase units. + torch::Tensor all_frames; + // 1D tensor of bool, each value indicates if the corresponding frame in + // all_frames is a key frame. + torch::Tensor is_key_frame; + // 1D tensor of int64, each value is the duration of the corresponding frame + // in all_frames in timebase units. + torch::Tensor duration; + }; // Reads the user provided frame index and updates each StreamInfo's index, // i.e. 
the allFrames and keyFrames vectors, and diff --git a/test/utils.py b/test/utils.py index 43502c5c..76490fa0 100644 --- a/test/utils.py +++ b/test/utils.py @@ -255,20 +255,18 @@ def generate_custom_frame_mappings(self, stream_index: int) -> None: text=True, ).stdout ) - pts_list = torch.tensor([float(frame["pts"]) for frame in result["frames"]]) - is_key_frame_list = torch.tensor( - [frame["key_frame"] for frame in result["frames"]] - ) - duration_list = torch.tensor( + all_frames = torch.tensor([float(frame["pts"]) for frame in result["frames"]]) + is_key_frame = torch.tensor([frame["key_frame"] for frame in result["frames"]]) + duration = torch.tensor( [float(frame["duration"]) for frame in result["frames"]] ) assert ( - len(pts_list) == len(is_key_frame_list) == len(duration_list) + len(all_frames) == len(is_key_frame) == len(duration) ), "Mismatched lengths in frame index data" self._custom_frame_mappings_data[stream_index] = ( - pts_list, - is_key_frame_list, - duration_list, + all_frames, + is_key_frame, + duration, ) @property From dacd826fcf2d60cc43970eafa12d37ba36496ece Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Tue, 15 Jul 2025 09:47:07 -0400 Subject: [PATCH 20/26] remove duplicate logic --- src/torchcodec/_core/SingleStreamDecoder.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index abc6b803..7325207b 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -351,9 +351,6 @@ void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex( // FrameInfo struct utilizes PTS FrameInfo frameInfo = {.pts = all_frames[i].item()}; frameInfo.isKeyFrame = (is_key_frame[i].item() == true); - frameInfo.nextPts = (i + 1 < all_frames.size(0)) - ? 
all_frames[i + 1].item() - : INT64_MAX; streamInfos_[streamIndex].allFrames.push_back(frameInfo); if (frameInfo.isKeyFrame) { streamInfos_[streamIndex].keyFrames.push_back(frameInfo); From f90718d0912aee022accd6a2640155f8696f184d Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Tue, 15 Jul 2025 10:02:12 -0400 Subject: [PATCH 21/26] remove designated initializer --- src/torchcodec/_core/SingleStreamDecoder.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 7325207b..3c6c06f0 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -348,8 +348,8 @@ void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex( streamMetadata.numFramesFromContent = all_frames.size(0); for (int64_t i = 0; i < all_frames.size(0); ++i) { - // FrameInfo struct utilizes PTS - FrameInfo frameInfo = {.pts = all_frames[i].item()}; + FrameInfo frameInfo; + frameInfo.pts = all_frames[i].item(); frameInfo.isKeyFrame = (is_key_frame[i].item() == true); streamInfos_[streamIndex].allFrames.push_back(frameInfo); if (frameInfo.isKeyFrame) { From a24c79cd506b1edc5c49ab80192aeee2dfd251c6 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Tue, 15 Jul 2025 10:41:13 -0400 Subject: [PATCH 22/26] Check before accessing dict keys --- test/utils.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/test/utils.py b/test/utils.py index 76490fa0..744bee98 100644 --- a/test/utils.py +++ b/test/utils.py @@ -255,10 +255,18 @@ def generate_custom_frame_mappings(self, stream_index: int) -> None: text=True, ).stdout ) - all_frames = torch.tensor([float(frame["pts"]) for frame in result["frames"]]) - is_key_frame = torch.tensor([frame["key_frame"] for frame in result["frames"]]) + all_frames = torch.tensor( + [float(frame["pts"]) for frame in result["frames"] if "pts" in frame] + ) + is_key_frame = torch.tensor( + [frame["key_frame"] for frame in result["frames"] if "key_frame" in frame] + ) duration = torch.tensor( - [float(frame["duration"]) for frame in result["frames"]] + [ + float(frame["duration"]) + for frame in result["frames"] + if "duration" in frame + ] ) assert ( len(all_frames) == len(is_key_frame) == len(duration) From ee6e8675cc983807c334d026b90a6af578e3b251 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Tue, 15 Jul 2025 15:00:33 -0400 Subject: [PATCH 23/26] add show_entries arg to ffprobe --- test/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/utils.py b/test/utils.py index 744bee98..85b61b42 100644 --- a/test/utils.py +++ b/test/utils.py @@ -246,6 +246,8 @@ def generate_custom_frame_mappings(self, stream_index: int) -> None: f"{self.path}", "-select_streams", f"{stream_index}", + "-show_entries", + "frame=pts,duration,stream_index,key_frame", "-show_frames", "-of", "json", From 011c8c5e2846ebfc7ec6b5b4cf2378cbdb668e70 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Tue, 15 Jul 2025 23:51:54 -0400 Subject: [PATCH 24/26] skip test if ffmpeg 4 or 5, remove past debugging attempts --- test/utils.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/test/utils.py b/test/utils.py index 85b61b42..37c16d51 100644 --- a/test/utils.py +++ b/test/utils.py @@ -231,6 +231,11 @@ def get_frame_info( def get_custom_frame_mappings( self, stream_index: Optional[int] = None ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Ensure all tests using this function 
are skipped if the FFmpeg version is 4 or 5 + # FFprobe on FFmpeg 4 and 5 does not return complete metadata + if get_ffmpeg_major_version() == 4 or get_ffmpeg_major_version() == 5: + pytest.skip("FFprobe on FFmpeg 4 and 5 does not return complete metadata") + if stream_index is None: stream_index = self.default_stream_index if self._custom_frame_mappings_data.get(stream_index) is None: @@ -246,8 +251,6 @@ def generate_custom_frame_mappings(self, stream_index: int) -> None: f"{self.path}", "-select_streams", f"{stream_index}", - "-show_entries", - "frame=pts,duration,stream_index,key_frame", "-show_frames", "-of", "json", @@ -257,18 +260,10 @@ def generate_custom_frame_mappings(self, stream_index: int) -> None: text=True, ).stdout ) - all_frames = torch.tensor( - [float(frame["pts"]) for frame in result["frames"] if "pts" in frame] - ) - is_key_frame = torch.tensor( - [frame["key_frame"] for frame in result["frames"] if "key_frame" in frame] - ) + all_frames = torch.tensor([float(frame["pts"]) for frame in result["frames"]]) + is_key_frame = torch.tensor([frame["key_frame"] for frame in result["frames"]]) duration = torch.tensor( - [ - float(frame["duration"]) - for frame in result["frames"] - if "duration" in frame - ] + [float(frame["duration"]) for frame in result["frames"]] ) assert ( len(all_frames) == len(is_key_frame) == len(duration) From c5c5cd249d949b5bfbbe63c36589997b45f2c018 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Wed, 16 Jul 2025 09:20:58 -0400 Subject: [PATCH 25/26] Use pytest skip in tests --- test/test_metadata.py | 4 +++- test/test_ops.py | 2 ++ test/utils.py | 5 ----- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/test/test_metadata.py b/test/test_metadata.py index 8d045dc4..a78a7c94 100644 --- a/test/test_metadata.py +++ b/test/test_metadata.py @@ -20,7 +20,7 @@ ) from torchcodec.decoders import AudioDecoder, VideoDecoder -from .utils import NASA_AUDIO_MP3, NASA_VIDEO +from .utils import NASA_AUDIO_MP3, NASA_VIDEO, get_ffmpeg_major_version # TODO: Expected values in these tests should be based on the assets's @@ -58,6 +58,8 @@ def test_get_metadata(metadata_getter): if isinstance(metadata_getter, functools.partial) else None ) + if (seek_mode == "custom_frame_mappings") and get_ffmpeg_major_version() in (4, 5): + pytest.skip(reason="ffprobe isn't accurate on ffmpeg 4 and 5") with_added_video_stream = seek_mode == "custom_frame_mappings" metadata = metadata_getter(NASA_VIDEO.path) diff --git a/test/test_ops.py b/test/test_ops.py index 407e7c9f..0ef10dac 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -47,6 +47,7 @@ NASA_AUDIO, NASA_AUDIO_MP3, NASA_VIDEO, + get_ffmpeg_major_version, needs_cuda, SINE_MONO_S32, SINE_MONO_S32_44100, @@ -474,6 +475,7 @@ def test_seek_mode_custom_frame_mappings_fails(self): decoder, stream_index=0, custom_frame_mappings=different_lengths ) + @pytest.mark.skipif(get_ffmpeg_major_version() in (4, 5), reason="ffprobe isn't accurate on ffmpeg 4 and 5") @pytest.mark.parametrize("device", cpu_and_cuda()) def test_seek_mode_custom_frame_mappings(self, device): stream_index = 3 # custom_frame_index seek mode requires a stream index diff --git a/test/utils.py b/test/utils.py index 37c16d51..76490fa0 100644 --- a/test/utils.py +++ b/test/utils.py @@ -231,11 +231,6 @@ def get_frame_info( def get_custom_frame_mappings( self, stream_index: Optional[int] = None ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # Ensure all tests using this function are skipped if the FFmpeg version is 4 or 5 - # FFprobe on FFmpeg 4 and 
5 does not return complete metadata - if get_ffmpeg_major_version() == 4 or get_ffmpeg_major_version() == 5: - pytest.skip("FFprobe on FFmpeg 4 and 5 does not return complete metadata") - if stream_index is None: stream_index = self.default_stream_index if self._custom_frame_mappings_data.get(stream_index) is None: From 4c6aadd7447775139cb75b27068eb896d90668ee Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Wed, 16 Jul 2025 09:46:05 -0400 Subject: [PATCH 26/26] Reflect commented suggestions --- src/torchcodec/_core/SingleStreamDecoder.cpp | 6 +++++- src/torchcodec/_core/SingleStreamDecoder.h | 22 ++++++++++---------- test/test_metadata.py | 2 +- test/test_ops.py | 7 +++++-- 4 files changed, 22 insertions(+), 15 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 3c6c06f0..8174fd7d 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -199,6 +199,10 @@ int SingleStreamDecoder::getBestStreamIndex(AVMediaType mediaType) { // -------------------------------------------------------------------------- void SingleStreamDecoder::sortAllFrames() { + // Sort the allFrames and keyFrames vecs in each stream, and also sets + // additional fields of the FrameInfo entries like nextPts and frameIndex + // This is called at the end of a scan, or when setting a user-defined frame + // mapping. for (auto& [streamIndex, streamInfo] : streamInfos_) { std::sort( streamInfo.keyFrames.begin(), @@ -350,7 +354,7 @@ void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex( for (int64_t i = 0; i < all_frames.size(0); ++i) { FrameInfo frameInfo; frameInfo.pts = all_frames[i].item(); - frameInfo.isKeyFrame = (is_key_frame[i].item() == true); + frameInfo.isKeyFrame = is_key_frame[i].item(); streamInfos_[streamIndex].allFrames.push_back(frameInfo); if (frameInfo.isKeyFrame) { streamInfos_[streamIndex].keyFrames.push_back(frameInfo); diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 83dc92ae..027f52fc 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -66,6 +66,10 @@ class SingleStreamDecoder { // FrameMappings is used for the custom_frame_mappings seek mode to store // metadata of frames in a stream. The size of all tensors in this struct must // match. + + // -------------------------------------------------------------------------- + // ADDING STREAMS API + // -------------------------------------------------------------------------- struct FrameMappings { // 1D tensor of int64, each value is the PTS of a frame in timebase units. torch::Tensor all_frames; @@ -77,17 +81,6 @@ class SingleStreamDecoder { torch::Tensor duration; }; - // Reads the user provided frame index and updates each StreamInfo's index, - // i.e. 
the allFrames and keyFrames vectors, and - // endStreamPtsSecondsFromContent - void readCustomFrameMappingsUpdateMetadataAndIndex( - int streamIndex, - FrameMappings customFrameMappings); - - // -------------------------------------------------------------------------- - // ADDING STREAMS API - // -------------------------------------------------------------------------- - void addVideoStream( int streamIndex, const VideoStreamOptions& videoStreamOptions = VideoStreamOptions(), @@ -251,6 +244,13 @@ class SingleStreamDecoder { // -------------------------------------------------------------------------- void initializeDecoder(); + + // Reads the user provided frame index and updates each StreamInfo's index, + // i.e. the allFrames and keyFrames vectors, and + // endStreamPtsSecondsFromContent + void readCustomFrameMappingsUpdateMetadataAndIndex( + int streamIndex, + FrameMappings customFrameMappings); // -------------------------------------------------------------------------- // DECODING APIS AND RELATED UTILS // -------------------------------------------------------------------------- diff --git a/test/test_metadata.py b/test/test_metadata.py index a78a7c94..9f6b91ca 100644 --- a/test/test_metadata.py +++ b/test/test_metadata.py @@ -20,7 +20,7 @@ ) from torchcodec.decoders import AudioDecoder, VideoDecoder -from .utils import NASA_AUDIO_MP3, NASA_VIDEO, get_ffmpeg_major_version +from .utils import get_ffmpeg_major_version, NASA_AUDIO_MP3, NASA_VIDEO # TODO: Expected values in these tests should be based on the assets's diff --git a/test/test_ops.py b/test/test_ops.py index 0ef10dac..2b0e7801 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -44,10 +44,10 @@ from .utils import ( assert_frames_equal, cpu_and_cuda, + get_ffmpeg_major_version, NASA_AUDIO, NASA_AUDIO_MP3, NASA_VIDEO, - get_ffmpeg_major_version, needs_cuda, SINE_MONO_S32, SINE_MONO_S32_44100, @@ -475,7 +475,10 @@ def test_seek_mode_custom_frame_mappings_fails(self): decoder, stream_index=0, custom_frame_mappings=different_lengths ) - @pytest.mark.skipif(get_ffmpeg_major_version() in (4, 5), reason="ffprobe isn't accurate on ffmpeg 4 and 5") + @pytest.mark.skipif( + get_ffmpeg_major_version() in (4, 5), + reason="ffprobe isn't accurate on ffmpeg 4 and 5", + ) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_seek_mode_custom_frame_mappings(self, device): stream_index = 3 # custom_frame_index seek mode requires a stream index
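
Taken together, the series is meant to make custom_frame_mappings line up with the exact seek mode once correct mappings are supplied. A sketch of the kind of equivalence check the tests lean on, reusing the asset and helpers referenced above; the function name and the single-frame comparison are illustrative.

    from torchcodec._core import add_video_stream, create_from_file, get_next_frame

    from .utils import assert_frames_equal, get_ffmpeg_major_version, NASA_VIDEO


    def check_custom_mappings_match_exact():
        # ffprobe output is incomplete on FFmpeg 4 and 5, so the mappings
        # helper is only exercised on newer versions.
        if get_ffmpeg_major_version() in (4, 5):
            return

        stream_index = NASA_VIDEO.default_stream_index

        exact = create_from_file(str(NASA_VIDEO.path), seek_mode="exact")
        add_video_stream(exact, stream_index=stream_index)

        custom = create_from_file(str(NASA_VIDEO.path), seek_mode="custom_frame_mappings")
        add_video_stream(
            custom,
            stream_index=stream_index,
            custom_frame_mappings=NASA_VIDEO.get_custom_frame_mappings(
                stream_index=stream_index
            ),
        )

        frame_exact, _, _ = get_next_frame(exact)
        frame_custom, _, _ = get_next_frame(custom)
        assert_frames_equal(frame_exact, frame_custom)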