diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 2e027da3..8174fd7d 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -198,6 +198,45 @@ int SingleStreamDecoder::getBestStreamIndex(AVMediaType mediaType) { // VIDEO METADATA QUERY API // -------------------------------------------------------------------------- +void SingleStreamDecoder::sortAllFrames() { + // Sort the allFrames and keyFrames vecs in each stream, and also sets + // additional fields of the FrameInfo entries like nextPts and frameIndex + // This is called at the end of a scan, or when setting a user-defined frame + // mapping. + for (auto& [streamIndex, streamInfo] : streamInfos_) { + std::sort( + streamInfo.keyFrames.begin(), + streamInfo.keyFrames.end(), + [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) { + return frameInfo1.pts < frameInfo2.pts; + }); + std::sort( + streamInfo.allFrames.begin(), + streamInfo.allFrames.end(), + [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) { + return frameInfo1.pts < frameInfo2.pts; + }); + + size_t keyFrameIndex = 0; + for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) { + streamInfo.allFrames[i].frameIndex = i; + if (streamInfo.allFrames[i].isKeyFrame) { + TORCH_CHECK( + keyFrameIndex < streamInfo.keyFrames.size(), + "The allFrames vec claims it has MORE keyFrames than the keyFrames vec. There's a bug in torchcodec."); + streamInfo.keyFrames[keyFrameIndex].frameIndex = i; + ++keyFrameIndex; + } + if (i + 1 < streamInfo.allFrames.size()) { + streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts; + } + } + TORCH_CHECK( + keyFrameIndex == streamInfo.keyFrames.size(), + "The allFrames vec claims it has LESS keyFrames than the keyFrames vec. There's a bug in torchcodec."); + } +} + void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() { if (scannedAllStreams_) { return; @@ -283,40 +322,46 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() { getFFMPEGErrorStringFromErrorCode(status)); // Sort all frames by their pts. - for (auto& [streamIndex, streamInfo] : streamInfos_) { - std::sort( - streamInfo.keyFrames.begin(), - streamInfo.keyFrames.end(), - [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) { - return frameInfo1.pts < frameInfo2.pts; - }); - std::sort( - streamInfo.allFrames.begin(), - streamInfo.allFrames.end(), - [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) { - return frameInfo1.pts < frameInfo2.pts; - }); + sortAllFrames(); + scannedAllStreams_ = true; +} - size_t keyFrameIndex = 0; - for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) { - streamInfo.allFrames[i].frameIndex = i; - if (streamInfo.allFrames[i].isKeyFrame) { - TORCH_CHECK( - keyFrameIndex < streamInfo.keyFrames.size(), - "The allFrames vec claims it has MORE keyFrames than the keyFrames vec. There's a bug in torchcodec."); - streamInfo.keyFrames[keyFrameIndex].frameIndex = i; - ++keyFrameIndex; - } - if (i + 1 < streamInfo.allFrames.size()) { - streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts; - } +void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex( + int streamIndex, + FrameMappings customFrameMappings) { + auto& all_frames = customFrameMappings.all_frames; + auto& is_key_frame = customFrameMappings.is_key_frame; + auto& duration = customFrameMappings.duration; + TORCH_CHECK( + all_frames.size(0) == is_key_frame.size(0) && + is_key_frame.size(0) == duration.size(0), + "all_frames, is_key_frame, and duration from custom_frame_mappings were not same size."); + + auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex]; + + streamMetadata.beginStreamPtsFromContent = all_frames[0].item(); + streamMetadata.endStreamPtsFromContent = + all_frames[-1].item() + duration[-1].item(); + + auto avStream = formatContext_->streams[streamIndex]; + streamMetadata.beginStreamPtsSecondsFromContent = + *streamMetadata.beginStreamPtsFromContent * av_q2d(avStream->time_base); + + streamMetadata.endStreamPtsSecondsFromContent = + *streamMetadata.endStreamPtsFromContent * av_q2d(avStream->time_base); + + streamMetadata.numFramesFromContent = all_frames.size(0); + for (int64_t i = 0; i < all_frames.size(0); ++i) { + FrameInfo frameInfo; + frameInfo.pts = all_frames[i].item(); + frameInfo.isKeyFrame = is_key_frame[i].item(); + streamInfos_[streamIndex].allFrames.push_back(frameInfo); + if (frameInfo.isKeyFrame) { + streamInfos_[streamIndex].keyFrames.push_back(frameInfo); } - TORCH_CHECK( - keyFrameIndex == streamInfo.keyFrames.size(), - "The allFrames vec claims it has LESS keyFrames than the keyFrames vec. There's a bug in torchcodec."); } - - scannedAllStreams_ = true; + // Sort all frames by their pts + sortAllFrames(); } ContainerMetadata SingleStreamDecoder::getContainerMetadata() const { @@ -431,7 +476,8 @@ void SingleStreamDecoder::addStream( void SingleStreamDecoder::addVideoStream( int streamIndex, - const VideoStreamOptions& videoStreamOptions) { + const VideoStreamOptions& videoStreamOptions, + std::optional customFrameMappings) { addStream( streamIndex, AVMEDIA_TYPE_VIDEO, @@ -456,6 +502,14 @@ void SingleStreamDecoder::addVideoStream( streamMetadata.height = streamInfo.codecContext->height; streamMetadata.sampleAspectRatio = streamInfo.codecContext->sample_aspect_ratio; + + if (seekMode_ == SeekMode::custom_frame_mappings) { + TORCH_CHECK( + customFrameMappings.has_value(), + "Please provide frame mappings when using custom_frame_mappings seek mode."); + readCustomFrameMappingsUpdateMetadataAndIndex( + streamIndex, customFrameMappings.value()); + } } void SingleStreamDecoder::addAudioStream( @@ -1407,6 +1461,7 @@ int SingleStreamDecoder::getKeyFrameIndexForPtsUsingScannedIndex( int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) { auto& streamInfo = streamInfos_[activeStreamIndex_]; switch (seekMode_) { + case SeekMode::custom_frame_mappings: case SeekMode::exact: { auto frame = std::lower_bound( streamInfo.allFrames.begin(), @@ -1434,6 +1489,7 @@ int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) { int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) { auto& streamInfo = streamInfos_[activeStreamIndex_]; switch (seekMode_) { + case SeekMode::custom_frame_mappings: case SeekMode::exact: { auto frame = std::upper_bound( streamInfo.allFrames.begin(), @@ -1461,6 +1517,7 @@ int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) { int64_t SingleStreamDecoder::getPts(int64_t frameIndex) { auto& streamInfo = streamInfos_[activeStreamIndex_]; switch (seekMode_) { + case SeekMode::custom_frame_mappings: case SeekMode::exact: return streamInfo.allFrames[frameIndex].pts; case SeekMode::approximate: { @@ -1485,6 +1542,7 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) { std::optional SingleStreamDecoder::getNumFrames( const StreamMetadata& streamMetadata) { switch (seekMode_) { + case SeekMode::custom_frame_mappings: case SeekMode::exact: return streamMetadata.numFramesFromContent.value(); case SeekMode::approximate: { @@ -1498,6 +1556,7 @@ std::optional SingleStreamDecoder::getNumFrames( double SingleStreamDecoder::getMinSeconds( const StreamMetadata& streamMetadata) { switch (seekMode_) { + case SeekMode::custom_frame_mappings: case SeekMode::exact: return streamMetadata.beginStreamPtsSecondsFromContent.value(); case SeekMode::approximate: @@ -1510,6 +1569,7 @@ double SingleStreamDecoder::getMinSeconds( std::optional SingleStreamDecoder::getMaxSeconds( const StreamMetadata& streamMetadata) { switch (seekMode_) { + case SeekMode::custom_frame_mappings: case SeekMode::exact: return streamMetadata.endStreamPtsSecondsFromContent.value(); case SeekMode::approximate: { @@ -1645,6 +1705,8 @@ SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) { return SingleStreamDecoder::SeekMode::exact; } else if (seekMode == "approximate") { return SingleStreamDecoder::SeekMode::approximate; + } else if (seekMode == "custom_frame_mappings") { + return SingleStreamDecoder::SeekMode::custom_frame_mappings; } else { TORCH_CHECK(false, "Invalid seek mode: " + std::string(seekMode)); } diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index dec102d1..027f52fc 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -29,7 +29,7 @@ class SingleStreamDecoder { // CONSTRUCTION API // -------------------------------------------------------------------------- - enum class SeekMode { exact, approximate }; + enum class SeekMode { exact, approximate, custom_frame_mappings }; // Creates a SingleStreamDecoder from the video at videoFilePath. explicit SingleStreamDecoder( @@ -53,6 +53,9 @@ class SingleStreamDecoder { // the allFrames and keyFrames vectors. void scanFileAndUpdateMetadataAndIndex(); + // Sorts the keyFrames and allFrames vectors in each StreamInfo by pts. + void sortAllFrames(); + // Returns the metadata for the container. ContainerMetadata getContainerMetadata() const; @@ -60,13 +63,28 @@ class SingleStreamDecoder { // int64 values, where each value is the frame index for a key frame. torch::Tensor getKeyFrameIndices(); + // FrameMappings is used for the custom_frame_mappings seek mode to store + // metadata of frames in a stream. The size of all tensors in this struct must + // match. + // -------------------------------------------------------------------------- // ADDING STREAMS API // -------------------------------------------------------------------------- + struct FrameMappings { + // 1D tensor of int64, each value is the PTS of a frame in timebase units. + torch::Tensor all_frames; + // 1D tensor of bool, each value indicates if the corresponding frame in + // all_frames is a key frame. + torch::Tensor is_key_frame; + // 1D tensor of int64, each value is the duration of the corresponding frame + // in all_frames in timebase units. + torch::Tensor duration; + }; void addVideoStream( int streamIndex, - const VideoStreamOptions& videoStreamOptions = VideoStreamOptions()); + const VideoStreamOptions& videoStreamOptions = VideoStreamOptions(), + std::optional customFrameMappings = std::nullopt); void addAudioStream( int streamIndex, const AudioStreamOptions& audioStreamOptions = AudioStreamOptions()); @@ -226,6 +244,13 @@ class SingleStreamDecoder { // -------------------------------------------------------------------------- void initializeDecoder(); + + // Reads the user provided frame index and updates each StreamInfo's index, + // i.e. the allFrames and keyFrames vectors, and + // endStreamPtsSecondsFromContent + void readCustomFrameMappingsUpdateMetadataAndIndex( + int streamIndex, + FrameMappings customFrameMappings); // -------------------------------------------------------------------------- // DECODING APIS AND RELATED UTILS // -------------------------------------------------------------------------- diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 4aa68a3b..2bae359a 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -36,9 +36,9 @@ TORCH_LIBRARY(torchcodec_ns, m) { "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); m.def("_convert_to_tensor(int decoder_ptr) -> Tensor"); m.def( - "_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, str? color_conversion_library=None) -> ()"); + "_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()"); m.def( - "add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None) -> ()"); + "add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None) -> ()"); m.def( "add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> ()"); m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()"); @@ -105,6 +105,14 @@ OpsFrameOutput makeOpsFrameOutput(FrameOutput& frame) { torch::tensor(frame.durationSeconds, torch::dtype(torch::kFloat64))); } +SingleStreamDecoder::FrameMappings makeFrameMappings( + std::tuple custom_frame_mappings) { + return SingleStreamDecoder::FrameMappings{ + std::get<0>(custom_frame_mappings), + std::get<1>(custom_frame_mappings), + std::get<2>(custom_frame_mappings)}; +} + // All elements of this tuple are tensors of the same leading dimension. The // tuple represents the frames for N total frames, where N is the dimension of // each stacked tensor. The elments are: @@ -223,6 +231,8 @@ void _add_video_stream( std::optional dimension_order = std::nullopt, std::optional stream_index = std::nullopt, std::optional device = std::nullopt, + std::optional> + custom_frame_mappings = std::nullopt, std::optional color_conversion_library = std::nullopt) { VideoStreamOptions videoStreamOptions; videoStreamOptions.width = width; @@ -253,9 +263,13 @@ void _add_video_stream( if (device.has_value()) { videoStreamOptions.device = createTorchDevice(std::string(device.value())); } - + std::optional converted_mappings = + custom_frame_mappings.has_value() + ? std::make_optional(makeFrameMappings(custom_frame_mappings.value())) + : std::nullopt; auto videoDecoder = unwrapTensorToGetDecoder(decoder); - videoDecoder->addVideoStream(stream_index.value_or(-1), videoStreamOptions); + videoDecoder->addVideoStream( + stream_index.value_or(-1), videoStreamOptions, converted_mappings); } // Add a new video stream at `stream_index` using the provided options. @@ -266,7 +280,9 @@ void add_video_stream( std::optional num_threads = std::nullopt, std::optional dimension_order = std::nullopt, std::optional stream_index = std::nullopt, - std::optional device = std::nullopt) { + std::optional device = std::nullopt, + std::optional> + custom_frame_mappings = std::nullopt) { _add_video_stream( decoder, width, @@ -274,7 +290,8 @@ void add_video_stream( num_threads, dimension_order, stream_index, - device); + device, + custom_frame_mappings); } void add_audio_stream( diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index a68b51e2..21046a33 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -205,6 +205,9 @@ def _add_video_stream_abstract( dimension_order: Optional[str] = None, stream_index: Optional[int] = None, device: Optional[str] = None, + custom_frame_mappings: Optional[ + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + ] = None, color_conversion_library: Optional[str] = None, ) -> None: return @@ -220,6 +223,9 @@ def add_video_stream_abstract( dimension_order: Optional[str] = None, stream_index: Optional[int] = None, device: Optional[str] = None, + custom_frame_mappings: Optional[ + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + ] = None, ) -> None: return diff --git a/test/test_metadata.py b/test/test_metadata.py index 3de7d377..9f6b91ca 100644 --- a/test/test_metadata.py +++ b/test/test_metadata.py @@ -10,6 +10,7 @@ import pytest from torchcodec._core import ( + add_video_stream, AudioStreamMetadata, create_from_file, get_container_metadata, @@ -19,7 +20,7 @@ ) from torchcodec.decoders import AudioDecoder, VideoDecoder -from .utils import NASA_AUDIO_MP3, NASA_VIDEO +from .utils import get_ffmpeg_major_version, NASA_AUDIO_MP3, NASA_VIDEO # TODO: Expected values in these tests should be based on the assets's @@ -28,6 +29,17 @@ def _get_container_metadata(path, seek_mode): decoder = create_from_file(str(path), seek_mode=seek_mode) + + # For custom_frame_mappings seek mode, add a video stream to update metadata + if seek_mode == "custom_frame_mappings": + custom_frame_mappings = NASA_VIDEO.get_custom_frame_mappings() + + # Add the best video stream (index 3 for NASA_VIDEO) + add_video_stream( + decoder, + stream_index=NASA_VIDEO.default_stream_index, + custom_frame_mappings=custom_frame_mappings, + ) return get_container_metadata(decoder) @@ -37,18 +49,26 @@ def _get_container_metadata(path, seek_mode): get_container_metadata_from_header, functools.partial(_get_container_metadata, seek_mode="approximate"), functools.partial(_get_container_metadata, seek_mode="exact"), + functools.partial(_get_container_metadata, seek_mode="custom_frame_mappings"), ), ) def test_get_metadata(metadata_getter): + seek_mode = ( + metadata_getter.keywords["seek_mode"] + if isinstance(metadata_getter, functools.partial) + else None + ) + if (seek_mode == "custom_frame_mappings") and get_ffmpeg_major_version() in (4, 5): + pytest.skip(reason="ffprobe isn't accurate on ffmpeg 4 and 5") + with_added_video_stream = seek_mode == "custom_frame_mappings" + metadata = metadata_getter(NASA_VIDEO.path) + with_scan = ( - metadata_getter.keywords["seek_mode"] == "exact" + (seek_mode == "exact" or seek_mode == "custom_frame_mappings") if isinstance(metadata_getter, functools.partial) else False ) - metadata = metadata_getter(NASA_VIDEO.path) - # metadata = metadata_getter(NASA_VIDEO.path) - assert len(metadata.streams) == 6 assert metadata.best_video_stream_index == 3 assert metadata.best_audio_stream_index == 4 @@ -82,7 +102,9 @@ def test_get_metadata(metadata_getter): assert best_video_stream_metadata.begin_stream_seconds_from_header == 0 assert best_video_stream_metadata.bit_rate == 128783 assert best_video_stream_metadata.average_fps == pytest.approx(29.97, abs=0.001) - assert best_video_stream_metadata.pixel_aspect_ratio is None + assert best_video_stream_metadata.pixel_aspect_ratio == ( + Fraction(1, 1) if with_added_video_stream else None + ) assert best_video_stream_metadata.codec == "h264" assert best_video_stream_metadata.num_frames_from_content == ( 390 if with_scan else None diff --git a/test/test_ops.py b/test/test_ops.py index 2f691615..2b0e7801 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -44,6 +44,7 @@ from .utils import ( assert_frames_equal, cpu_and_cuda, + get_ffmpeg_major_version, NASA_AUDIO, NASA_AUDIO_MP3, NASA_VIDEO, @@ -448,6 +449,73 @@ def test_frame_pts_equality(self): ) assert pts_is_equal + def test_seek_mode_custom_frame_mappings_fails(self): + decoder = create_from_file( + str(NASA_VIDEO.path), seek_mode="custom_frame_mappings" + ) + with pytest.raises( + RuntimeError, + match="Please provide frame mappings when using custom_frame_mappings seek mode.", + ): + add_video_stream(decoder, stream_index=0, custom_frame_mappings=None) + + decoder = create_from_file( + str(NASA_VIDEO.path), seek_mode="custom_frame_mappings" + ) + different_lengths = ( + torch.tensor([1, 2, 3]), + torch.tensor([1, 2]), + torch.tensor([1, 2, 3]), + ) + with pytest.raises( + RuntimeError, + match="all_frames, is_key_frame, and duration from custom_frame_mappings were not same size.", + ): + add_video_stream( + decoder, stream_index=0, custom_frame_mappings=different_lengths + ) + + @pytest.mark.skipif( + get_ffmpeg_major_version() in (4, 5), + reason="ffprobe isn't accurate on ffmpeg 4 and 5", + ) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_seek_mode_custom_frame_mappings(self, device): + stream_index = 3 # custom_frame_index seek mode requires a stream index + decoder = create_from_file( + str(NASA_VIDEO.path), seek_mode="custom_frame_mappings" + ) + add_video_stream( + decoder, + device=device, + stream_index=stream_index, + custom_frame_mappings=NASA_VIDEO.get_custom_frame_mappings( + stream_index=stream_index + ), + ) + + frame0, _, _ = get_next_frame(decoder) + reference_frame0 = NASA_VIDEO.get_frame_data_by_index( + 0, stream_index=stream_index + ) + assert_frames_equal(frame0, reference_frame0.to(device)) + + frame6, _, _ = get_frame_at_pts(decoder, 6.006) + reference_frame6 = NASA_VIDEO.get_frame_data_by_index( + INDEX_OF_FRAME_AT_6_SECONDS, stream_index=stream_index + ) + assert_frames_equal(frame6, reference_frame6.to(device)) + + frame6, _, _ = get_frame_at_index(decoder, frame_index=180) + reference_frame6 = NASA_VIDEO.get_frame_data_by_index( + INDEX_OF_FRAME_AT_6_SECONDS, stream_index=stream_index + ) + assert_frames_equal(frame6, reference_frame6.to(device)) + + ref_frames0_9 = NASA_VIDEO.get_frame_data_by_range(0, 9) + bulk_frames0_9, *_ = get_frames_in_range(decoder, start=0, stop=9) + assert_frames_equal(bulk_frames0_9, ref_frames0_9.to(device)) + @pytest.mark.parametrize("color_conversion_library", ("filtergraph", "swscale")) def test_color_conversion_library(self, color_conversion_library): decoder = create_from_file(str(NASA_VIDEO.path)) diff --git a/test/utils.py b/test/utils.py index e7ce12e5..76490fa0 100644 --- a/test/utils.py +++ b/test/utils.py @@ -2,6 +2,7 @@ import json import os import pathlib +import subprocess import sys from dataclasses import dataclass, field @@ -121,6 +122,9 @@ class TestContainerFile: default_stream_index: int stream_infos: Dict[int, Union[TestVideoStreamInfo, TestAudioStreamInfo]] frames: Dict[int, Dict[int, TestFrameInfo]] + _custom_frame_mappings_data: Dict[ + int, Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] + ] = field(default_factory=dict) def __post_init__(self): # We load the .frames attribute from the checked-in json files, if needed. @@ -223,6 +227,48 @@ def get_frame_info( return self.frames[stream_index][idx] + # This function is used to get the frame mappings for the custom_frame_mappings seek mode. + def get_custom_frame_mappings( + self, stream_index: Optional[int] = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if stream_index is None: + stream_index = self.default_stream_index + if self._custom_frame_mappings_data.get(stream_index) is None: + self.generate_custom_frame_mappings(stream_index) + return self._custom_frame_mappings_data[stream_index] + + def generate_custom_frame_mappings(self, stream_index: int) -> None: + result = json.loads( + subprocess.run( + [ + "ffprobe", + "-i", + f"{self.path}", + "-select_streams", + f"{stream_index}", + "-show_frames", + "-of", + "json", + ], + check=True, + capture_output=True, + text=True, + ).stdout + ) + all_frames = torch.tensor([float(frame["pts"]) for frame in result["frames"]]) + is_key_frame = torch.tensor([frame["key_frame"] for frame in result["frames"]]) + duration = torch.tensor( + [float(frame["duration"]) for frame in result["frames"]] + ) + assert ( + len(all_frames) == len(is_key_frame) == len(duration) + ), "Mismatched lengths in frame index data" + self._custom_frame_mappings_data[stream_index] = ( + all_frames, + is_key_frame, + duration, + ) + @property def empty_pts_seconds(self) -> torch.Tensor: return torch.empty([0], dtype=torch.float64)