diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 2e027da3..61ffa39a 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -526,6 +526,12 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal( const auto& streamInfo = streamInfos_[activeStreamIndex_]; const auto& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; + + std::optional numFrames = getNumFrames(streamMetadata); + if (numFrames.has_value()) { + // If the frameIndex is negative, we convert it to a positive index + frameIndex = frameIndex >= 0 ? frameIndex : frameIndex + numFrames.value(); + } validateFrameIndex(streamMetadata, frameIndex); int64_t pts = getPts(frameIndex); @@ -568,8 +574,6 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices( auto indexInOutput = indicesAreSorted ? f : argsort[f]; auto indexInVideo = frameIndices[indexInOutput]; - validateFrameIndex(streamMetadata, indexInVideo); - if ((f > 0) && (indexInVideo == previousIndexInVideo)) { // Avoid decoding the same frame twice auto previousIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1]; @@ -1559,21 +1563,24 @@ void SingleStreamDecoder::validateScannedAllStreams(const std::string& msg) { void SingleStreamDecoder::validateFrameIndex( const StreamMetadata& streamMetadata, int64_t frameIndex) { - TORCH_CHECK( - frameIndex >= 0, - "Invalid frame index=" + std::to_string(frameIndex) + - " for streamIndex=" + std::to_string(streamMetadata.streamIndex) + - "; must be greater than or equal to 0"); + if (frameIndex < 0) { + throw std::out_of_range( + "Invalid frame index=" + std::to_string(frameIndex) + + " for streamIndex=" + std::to_string(streamMetadata.streamIndex) + + "; negative indices must have an absolute value less than the number of frames, " + "and the number of frames must be known."); + } // Note that if we do not have the number of frames available in our metadata, // then we assume that the frameIndex is valid. std::optional numFrames = getNumFrames(streamMetadata); if (numFrames.has_value()) { - TORCH_CHECK( - frameIndex < numFrames.value(), - "Invalid frame index=" + std::to_string(frameIndex) + - " for streamIndex=" + std::to_string(streamMetadata.streamIndex) + - "; must be less than " + std::to_string(numFrames.value())); + if (frameIndex >= numFrames.value()) { + throw std::out_of_range( + "Invalid frame index=" + std::to_string(frameIndex) + + " for streamIndex=" + std::to_string(streamMetadata.streamIndex) + + "; must be less than " + std::to_string(numFrames.value())); + } } } diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index d7bd7a04..0ae4c2b8 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -125,13 +125,6 @@ def __len__(self) -> int: def _getitem_int(self, key: int) -> Tensor: assert isinstance(key, int) - if key < 0: - key += self._num_frames - if key >= self._num_frames or key < 0: - raise IndexError( - f"Index {key} is out of bounds; length is {self._num_frames}" - ) - frame_data, *_ = core.get_frame_at_index(self._decoder, frame_index=key) return frame_data @@ -195,13 +188,6 @@ def get_frame_at(self, index: int) -> Frame: Returns: Frame: The frame at the given index. """ - if index < 0: - index += self._num_frames - - if not 0 <= index < self._num_frames: - raise IndexError( - f"Index {index} is out of bounds; must be in the range [0, {self._num_frames})." - ) data, pts_seconds, duration_seconds = core.get_frame_at_index( self._decoder, frame_index=index ) @@ -220,10 +206,6 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch: Returns: FrameBatch: The frames at the given indices. """ - indices = [ - index if index >= 0 else index + self._num_frames for index in indices - ] - data, pts_seconds, duration_seconds = core.get_frames_at_indices( self._decoder, frame_indices=indices ) @@ -247,16 +229,8 @@ def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatc Returns: FrameBatch: The frames within the specified range. """ - if not 0 <= start < self._num_frames: - raise IndexError( - f"Start index {start} is out of bounds; must be in the range [0, {self._num_frames})." - ) - if stop < start: - raise IndexError( - f"Stop index ({stop}) must not be less than the start index ({start})." - ) - if not step > 0: - raise IndexError(f"Step ({step}) must be greater than 0.") + # Adjust start / stop indices to enable indexing semantics, ex. [-10, 1000] returns the last 10 frames + start, stop, step = slice(start, stop, step).indices(self._num_frames) frames = core.get_frames_in_range( self._decoder, start=start, diff --git a/test/test_decoders.py b/test/test_decoders.py index dcf9a158..f1a32d0a 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -375,10 +375,10 @@ def test_device_instance(self): def test_getitem_fails(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) - with pytest.raises(IndexError, match="out of bounds"): + with pytest.raises(IndexError, match="Invalid frame index"): frame = decoder[1000] # noqa - with pytest.raises(IndexError, match="out of bounds"): + with pytest.raises(IndexError, match="Invalid frame index"): frame = decoder[-1000] # noqa with pytest.raises(TypeError, match="Unsupported key type"): @@ -487,10 +487,13 @@ def test_get_frame_at_tuple_unpacking(self, device): def test_get_frame_at_fails(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) - with pytest.raises(IndexError, match="out of bounds"): + with pytest.raises( + IndexError, + match="negative indices must have an absolute value less than the number of frames", + ): frame = decoder.get_frame_at(-10000) # noqa - with pytest.raises(IndexError, match="out of bounds"): + with pytest.raises(IndexError, match="must be less than"): frame = decoder.get_frame_at(10000) # noqa @pytest.mark.parametrize("device", cpu_and_cuda()) @@ -549,13 +552,13 @@ def test_get_frames_at(self, device, seek_mode): def test_get_frames_at_fails(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) - expected_converted_index = -10000 + len(decoder) with pytest.raises( - RuntimeError, match=f"Invalid frame index={expected_converted_index}" + IndexError, + match="negative indices must have an absolute value less than the number of frames", ): decoder.get_frames_at([-10000]) - with pytest.raises(RuntimeError, match="Invalid frame index=390"): + with pytest.raises(IndexError, match="Invalid frame index=390"): decoder.get_frames_at([390]) with pytest.raises(RuntimeError, match="Expected a value of type"): @@ -772,6 +775,58 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode): empty_frames.duration_seconds, NASA_VIDEO.empty_duration_seconds ) + @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frames_in_range_slice_indices_syntax(self, device, seek_mode): + decoder = VideoDecoder( + NASA_VIDEO.path, + stream_index=3, + device=device, + seek_mode=seek_mode, + ) + + # high range ends get capped to num_frames + frames387_389 = decoder.get_frames_in_range(start=387, stop=1000) + assert frames387_389.data.shape == torch.Size( + [ + 3, + NASA_VIDEO.get_num_color_channels(stream_index=3), + NASA_VIDEO.get_height(stream_index=3), + NASA_VIDEO.get_width(stream_index=3), + ] + ) + ref_frame387_389 = NASA_VIDEO.get_frame_data_by_range( + start=387, stop=390, stream_index=3 + ).to(device) + assert_frames_equal(frames387_389.data, ref_frame387_389) + + # negative indices are converted + frames387_389 = decoder.get_frames_in_range(start=-3, stop=1000) + assert frames387_389.data.shape == torch.Size( + [ + 3, + NASA_VIDEO.get_num_color_channels(stream_index=3), + NASA_VIDEO.get_height(stream_index=3), + NASA_VIDEO.get_width(stream_index=3), + ] + ) + assert_frames_equal(frames387_389.data, ref_frame387_389) + + # "None" as stop is treated as end of the video + frames387_None = decoder.get_frames_in_range(start=-3, stop=None) + assert frames387_None.data.shape == torch.Size( + [ + 3, + NASA_VIDEO.get_num_color_channels(stream_index=3), + NASA_VIDEO.get_height(stream_index=3), + NASA_VIDEO.get_width(stream_index=3), + ] + ) + reference_frame387_389 = NASA_VIDEO.get_frame_data_by_range( + start=387, stop=390, stream_index=3 + ).to(device) + assert_frames_equal(frames387_None.data, reference_frame387_389) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) @patch("torchcodec._core._metadata._get_stream_json_metadata") diff --git a/test/test_ops.py b/test/test_ops.py index 2f691615..24742124 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -125,6 +125,10 @@ def test_get_frame_at_index(self, device): INDEX_OF_FRAME_AT_6_SECONDS ) assert_frames_equal(frame6, reference_frame6.to(device)) + # Negative indices are supported + frame389 = get_frame_at_index(decoder, frame_index=-1) + reference_frame389 = NASA_VIDEO.get_frame_data_by_index(389) + assert_frames_equal(frame389[0], reference_frame389.to(device)) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_with_info_at_index(self, device): @@ -177,6 +181,32 @@ def test_get_frames_at_indices_unsorted_indices(self, device): with pytest.raises(AssertionError): assert_frames_equal(frames[0], frames[-1]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_get_frames_at_indices_negative_indices(self, device): + decoder = create_from_file(str(NASA_VIDEO.path)) + add_video_stream(decoder, device=device) + frames389and387and1, *_ = get_frames_at_indices( + decoder, frame_indices=[-1, -3, -389] + ) + reference_frame389 = NASA_VIDEO.get_frame_data_by_index(389) + reference_frame387 = NASA_VIDEO.get_frame_data_by_index(387) + reference_frame1 = NASA_VIDEO.get_frame_data_by_index(1) + assert_frames_equal(frames389and387and1[0], reference_frame389.to(device)) + assert_frames_equal(frames389and387and1[1], reference_frame387.to(device)) + assert_frames_equal(frames389and387and1[2], reference_frame1.to(device)) + + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_get_frames_at_indices_fail_on_invalid_negative_indices(self, device): + decoder = create_from_file(str(NASA_VIDEO.path)) + add_video_stream(decoder, device=device) + with pytest.raises( + IndexError, + match="negative indices must have an absolute value less than the number of frames", + ): + invalid_frames, *_ = get_frames_at_indices( + decoder, frame_indices=[-10000, -3000] + ) + @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frames_by_pts(self, device): decoder = create_from_file(str(NASA_VIDEO.path))