Skip to content

Support negative index in SimpleVideoDecoder #743

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/torchcodec/decoders/_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ def get_frame_at(self, index: int) -> Frame:
Returns:
Frame: The frame at the given index.
"""
if index < 0:
index += self._num_frames

if not 0 <= index < self._num_frames:
raise IndexError(
Expand All @@ -218,6 +220,9 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch:
Returns:
FrameBatch: The frames at the given indices.
"""
indices = [
index if index >= 0 else index + self._num_frames for index in indices
]

data, pts_seconds, duration_seconds = core.get_frames_at_indices(
self._decoder, frame_indices=indices
Expand Down
40 changes: 36 additions & 4 deletions test/test_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,19 @@ def test_getitem_slice(self, device, seek_mode):
)
assert_frames_equal(ref386_389, slice386_389)

# slices with upper bound greater than len(decoder) are supported
slice387_389 = decoder[-3:10000].to(device)
assert slice387_389.shape == torch.Size(
[
3,
NASA_VIDEO.num_color_channels,
NASA_VIDEO.height,
NASA_VIDEO.width,
]
)
ref387_389 = NASA_VIDEO.get_frame_data_by_range(387, 390).to(device)
assert_frames_equal(ref387_389, slice387_389)

# an empty range is valid!
empty_frame = decoder[5:5]
assert_frames_equal(empty_frame, NASA_VIDEO.empty_chw_tensor.to(device))
Expand Down Expand Up @@ -437,6 +450,11 @@ def test_get_frame_at(self, device, seek_mode):
expected_frame_info.duration_seconds, rel=1e-3
)

# test negative frame index
frame_minus1 = decoder.get_frame_at(-1)
ref_frame_minus1 = NASA_VIDEO.get_frame_data_by_index(389).to(device)
assert_frames_equal(ref_frame_minus1, frame_minus1.data)

# test numpy.int64
frame9 = decoder.get_frame_at(numpy.int64(9))
assert_frames_equal(ref_frame9, frame9.data)
Expand Down Expand Up @@ -470,7 +488,7 @@ def test_get_frame_at_fails(self, device, seek_mode):
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)

with pytest.raises(IndexError, match="out of bounds"):
frame = decoder.get_frame_at(-1) # noqa
frame = decoder.get_frame_at(-10000) # noqa

with pytest.raises(IndexError, match="out of bounds"):
frame = decoder.get_frame_at(10000) # noqa
Expand All @@ -480,7 +498,8 @@ def test_get_frame_at_fails(self, device, seek_mode):
def test_get_frames_at(self, device, seek_mode):
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)

frames = decoder.get_frames_at([35, 25])
# test positive and negative frame index
frames = decoder.get_frames_at([35, 25, -1, -2])

assert isinstance(frames, FrameBatch)

Expand All @@ -490,12 +509,20 @@ def test_get_frames_at(self, device, seek_mode):
assert_frames_equal(
frames[1].data, NASA_VIDEO.get_frame_data_by_index(25).to(device)
)
assert_frames_equal(
frames[2].data, NASA_VIDEO.get_frame_data_by_index(389).to(device)
)
assert_frames_equal(
frames[3].data, NASA_VIDEO.get_frame_data_by_index(388).to(device)
)

assert frames.pts_seconds.device.type == "cpu"
expected_pts_seconds = torch.tensor(
[
NASA_VIDEO.get_frame_info(35).pts_seconds,
NASA_VIDEO.get_frame_info(25).pts_seconds,
NASA_VIDEO.get_frame_info(389).pts_seconds,
NASA_VIDEO.get_frame_info(388).pts_seconds,
],
dtype=torch.float64,
)
Expand All @@ -508,6 +535,8 @@ def test_get_frames_at(self, device, seek_mode):
[
NASA_VIDEO.get_frame_info(35).duration_seconds,
NASA_VIDEO.get_frame_info(25).duration_seconds,
NASA_VIDEO.get_frame_info(389).duration_seconds,
NASA_VIDEO.get_frame_info(388).duration_seconds,
],
dtype=torch.float64,
)
Expand All @@ -520,8 +549,11 @@ def test_get_frames_at(self, device, seek_mode):
def test_get_frames_at_fails(self, device, seek_mode):
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)

with pytest.raises(RuntimeError, match="Invalid frame index=-1"):
decoder.get_frames_at([-1])
expected_converted_index = -10000 + len(decoder)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not necessary to get the expected index here, but I added it to match the other tests. Alternatively, the test can simply match the text: "Invalid frame index".

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's fine to do, but I usually default to just matching the text that doesn't change.

with pytest.raises(
RuntimeError, match=f"Invalid frame index={expected_converted_index}"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This error is from SingleStreamDecoder.cpp. Should we also be checking valid indices at the Python level, in _video_decoder.py?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good observation - at the Python level, we currently don't check the input when we get a list of indices (get_frames_at()) or a list of timestamps (get_frames_played_at()). The weak rationale was that we'd have to walk the list to do the check, and that seemed wasteful. Now that we're walking the list in get_frames_at() to deal with negative indices, we might as well do error checking at the same time.

The question, then, is what to do in get_frames_played_at()? For completeness, I think it does make sense to do error checking at the Python level there, but let's do that in a separate PR. We can also do some refactoring there, and make _getitem_slice() call VideoDecoder.get_frames_in_range() in order to get that error checking at the Python level as well.

):
decoder.get_frames_at([-10000])

with pytest.raises(RuntimeError, match="Invalid frame index=390"):
decoder.get_frames_at([390])
Expand Down
Loading