diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index fff0fe9f..548f59c3 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -195,6 +195,8 @@ def get_frame_at(self, index: int) -> Frame: Returns: Frame: The frame at the given index. """ + if index < 0: + index += self._num_frames if not 0 <= index < self._num_frames: raise IndexError( @@ -218,6 +220,9 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch: Returns: FrameBatch: The frames at the given indices. """ + indices = [ + index if index >= 0 else index + self._num_frames for index in indices + ] data, pts_seconds, duration_seconds = core.get_frames_at_indices( self._decoder, frame_indices=indices diff --git a/test/test_decoders.py b/test/test_decoders.py index fbe55290..dcf9a158 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -328,6 +328,19 @@ def test_getitem_slice(self, device, seek_mode): ) assert_frames_equal(ref386_389, slice386_389) + # slices with upper bound greater than len(decoder) are supported + slice387_389 = decoder[-3:10000].to(device) + assert slice387_389.shape == torch.Size( + [ + 3, + NASA_VIDEO.num_color_channels, + NASA_VIDEO.height, + NASA_VIDEO.width, + ] + ) + ref387_389 = NASA_VIDEO.get_frame_data_by_range(387, 390).to(device) + assert_frames_equal(ref387_389, slice387_389) + # an empty range is valid! empty_frame = decoder[5:5] assert_frames_equal(empty_frame, NASA_VIDEO.empty_chw_tensor.to(device)) @@ -437,6 +450,11 @@ def test_get_frame_at(self, device, seek_mode): expected_frame_info.duration_seconds, rel=1e-3 ) + # test negative frame index + frame_minus1 = decoder.get_frame_at(-1) + ref_frame_minus1 = NASA_VIDEO.get_frame_data_by_index(389).to(device) + assert_frames_equal(ref_frame_minus1, frame_minus1.data) + # test numpy.int64 frame9 = decoder.get_frame_at(numpy.int64(9)) assert_frames_equal(ref_frame9, frame9.data) @@ -470,7 +488,7 @@ def test_get_frame_at_fails(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) with pytest.raises(IndexError, match="out of bounds"): - frame = decoder.get_frame_at(-1) # noqa + frame = decoder.get_frame_at(-10000) # noqa with pytest.raises(IndexError, match="out of bounds"): frame = decoder.get_frame_at(10000) # noqa @@ -480,7 +498,8 @@ def test_get_frame_at_fails(self, device, seek_mode): def test_get_frames_at(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) - frames = decoder.get_frames_at([35, 25]) + # test positive and negative frame index + frames = decoder.get_frames_at([35, 25, -1, -2]) assert isinstance(frames, FrameBatch) @@ -490,12 +509,20 @@ def test_get_frames_at(self, device, seek_mode): assert_frames_equal( frames[1].data, NASA_VIDEO.get_frame_data_by_index(25).to(device) ) + assert_frames_equal( + frames[2].data, NASA_VIDEO.get_frame_data_by_index(389).to(device) + ) + assert_frames_equal( + frames[3].data, NASA_VIDEO.get_frame_data_by_index(388).to(device) + ) assert frames.pts_seconds.device.type == "cpu" expected_pts_seconds = torch.tensor( [ NASA_VIDEO.get_frame_info(35).pts_seconds, NASA_VIDEO.get_frame_info(25).pts_seconds, + NASA_VIDEO.get_frame_info(389).pts_seconds, + NASA_VIDEO.get_frame_info(388).pts_seconds, ], dtype=torch.float64, ) @@ -508,6 +535,8 @@ def test_get_frames_at(self, device, seek_mode): [ NASA_VIDEO.get_frame_info(35).duration_seconds, NASA_VIDEO.get_frame_info(25).duration_seconds, + NASA_VIDEO.get_frame_info(389).duration_seconds, + NASA_VIDEO.get_frame_info(388).duration_seconds, ], dtype=torch.float64, ) @@ -520,8 +549,11 @@ def test_get_frames_at(self, device, seek_mode): def test_get_frames_at_fails(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) - with pytest.raises(RuntimeError, match="Invalid frame index=-1"): - decoder.get_frames_at([-1]) + expected_converted_index = -10000 + len(decoder) + with pytest.raises( + RuntimeError, match=f"Invalid frame index={expected_converted_index}" + ): + decoder.get_frames_at([-10000]) with pytest.raises(RuntimeError, match="Invalid frame index=390"): decoder.get_frames_at([390])