Skip to content

Convert negative frame indices in C++, convert slices in Python #746

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions src/torchcodec/_core/SingleStreamDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,12 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
const auto& streamInfo = streamInfos_[activeStreamIndex_];
const auto& streamMetadata =
containerMetadata_.allStreamMetadata[activeStreamIndex_];

std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
if (numFrames.has_value()) {
// If the frameIndex is negative, we convert it to a positive index
frameIndex = frameIndex >= 0 ? frameIndex : frameIndex + numFrames.value();
}
Copy link
Contributor

@scotts scotts Jul 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If numFrames does not have a value, then I think we need to assert that frameIndex is positive, and throw a std::range_error if it's not. Otherwise, we'll end up using a negative frameIndex further down.

Ah, now I see that we actually do that in validateFrameIndex(), which we call right after. 👍

validateFrameIndex(streamMetadata, frameIndex);

int64_t pts = getPts(frameIndex);
Expand Down Expand Up @@ -568,8 +574,6 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
auto indexInOutput = indicesAreSorted ? f : argsort[f];
auto indexInVideo = frameIndices[indexInOutput];

validateFrameIndex(streamMetadata, indexInVideo);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this call to validateFrameIndex is not needed, since getFrameAtIndexInternal will also call validateFrameIndex.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. Yes, although I had to think about it. What gave me pause is the fact that we also use it in the if branch, but if indexInVideo == previousIndexInVideo then we must have checked it during the previous iteration so we're safe.

if ((f > 0) && (indexInVideo == previousIndexInVideo)) {
// Avoid decoding the same frame twice
auto previousIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1];
Expand Down Expand Up @@ -1559,21 +1563,24 @@ void SingleStreamDecoder::validateScannedAllStreams(const std::string& msg) {
void SingleStreamDecoder::validateFrameIndex(
const StreamMetadata& streamMetadata,
int64_t frameIndex) {
TORCH_CHECK(
frameIndex >= 0,
"Invalid frame index=" + std::to_string(frameIndex) +
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
"; must be greater than or equal to 0");
if (frameIndex < 0) {
throw std::out_of_range(
"Invalid frame index=" + std::to_string(frameIndex) +
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
"; negative indices must have an absolute value less than the number of frames, "
"and the number of frames must be known.");
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could end up being wordy, but we might want to mention that we can't receive any negative index if we don't know the number of frames. Maybe something like:

throw std::out_of_range(
        "Invalid frame index=" + std::to_string(frameIndex) +
        " for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
        "; negative indices must have an absolute value that is less than the number of frames, "
        "and the number of frames must be known.");
  }


// Note that if we do not have the number of frames available in our metadata,
// then we assume that the frameIndex is valid.
std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
if (numFrames.has_value()) {
TORCH_CHECK(
frameIndex < numFrames.value(),
"Invalid frame index=" + std::to_string(frameIndex) +
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
"; must be less than " + std::to_string(numFrames.value()));
if (frameIndex >= numFrames.value()) {
throw std::out_of_range(
"Invalid frame index=" + std::to_string(frameIndex) +
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
"; must be less than " + std::to_string(numFrames.value()));
}
}
}

Expand Down
30 changes: 2 additions & 28 deletions src/torchcodec/decoders/_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,6 @@ def __len__(self) -> int:
def _getitem_int(self, key: int) -> Tensor:
assert isinstance(key, int)

if key < 0:
key += self._num_frames
if key >= self._num_frames or key < 0:
raise IndexError(
f"Index {key} is out of bounds; length is {self._num_frames}"
)

frame_data, *_ = core.get_frame_at_index(self._decoder, frame_index=key)
return frame_data

Expand Down Expand Up @@ -195,13 +188,6 @@ def get_frame_at(self, index: int) -> Frame:
Returns:
Frame: The frame at the given index.
"""
if index < 0:
index += self._num_frames

if not 0 <= index < self._num_frames:
raise IndexError(
f"Index {index} is out of bounds; must be in the range [0, {self._num_frames})."
)
data, pts_seconds, duration_seconds = core.get_frame_at_index(
self._decoder, frame_index=index
)
Expand All @@ -220,10 +206,6 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch:
Returns:
FrameBatch: The frames at the given indices.
"""
indices = [
index if index >= 0 else index + self._num_frames for index in indices
]

data, pts_seconds, duration_seconds = core.get_frames_at_indices(
self._decoder, frame_indices=indices
)
Expand All @@ -247,16 +229,8 @@ def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatc
Returns:
FrameBatch: The frames within the specified range.
"""
if not 0 <= start < self._num_frames:
raise IndexError(
f"Start index {start} is out of bounds; must be in the range [0, {self._num_frames})."
)
if stop < start:
raise IndexError(
f"Stop index ({stop}) must not be less than the start index ({start})."
)
if not step > 0:
raise IndexError(f"Step ({step}) must be greater than 0.")
# Adjust start / stop indices to enable indexing semantics, ex. [-10, 1000] returns the last 10 frames
start, stop, step = slice(start, stop, step).indices(self._num_frames)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is how slices are handled in VideoDecoder._getitem_slice, using it here ensures consistent behavior.

frames = core.get_frames_in_range(
self._decoder,
start=start,
Expand Down
69 changes: 62 additions & 7 deletions test/test_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,10 +375,10 @@ def test_device_instance(self):
def test_getitem_fails(self, device, seek_mode):
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)

with pytest.raises(IndexError, match="out of bounds"):
with pytest.raises(IndexError, match="Invalid frame index"):
frame = decoder[1000] # noqa

with pytest.raises(IndexError, match="out of bounds"):
with pytest.raises(IndexError, match="Invalid frame index"):
frame = decoder[-1000] # noqa

with pytest.raises(TypeError, match="Unsupported key type"):
Expand Down Expand Up @@ -487,10 +487,13 @@ def test_get_frame_at_tuple_unpacking(self, device):
def test_get_frame_at_fails(self, device, seek_mode):
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)

with pytest.raises(IndexError, match="out of bounds"):
with pytest.raises(
IndexError,
match="negative indices must have an absolute value less than the number of frames",
):
Copy link
Contributor Author

@Dan-Flores Dan-Flores Jul 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous error check in get_frame_at was removed from Python. Now, this error comes from validateFrameIndex.

I removed the single index error check in Python for consistency between the get_frame(s)... functions, but I am open to added them back if erroring earlier is preferred, please let me know if you have any preference @scotts @NicolasHug

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes sense to do the error checking once, on the C++ side. But we still have the error checking on the Python side for get_frames_in_range() - is that now also redundant?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'll remove those checks since .indices() will prevent most of them. It does not alter a negative step, but we check and error on that case in SingleStreamDecoder::getFramesInRange.

frame = decoder.get_frame_at(-10000) # noqa

with pytest.raises(IndexError, match="out of bounds"):
with pytest.raises(IndexError, match="must be less than"):
frame = decoder.get_frame_at(10000) # noqa

@pytest.mark.parametrize("device", cpu_and_cuda())
Expand Down Expand Up @@ -549,13 +552,13 @@ def test_get_frames_at(self, device, seek_mode):
def test_get_frames_at_fails(self, device, seek_mode):
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)

expected_converted_index = -10000 + len(decoder)
with pytest.raises(
RuntimeError, match=f"Invalid frame index={expected_converted_index}"
IndexError,
match="negative indices must have an absolute value less than the number of frames",
):
decoder.get_frames_at([-10000])

with pytest.raises(RuntimeError, match="Invalid frame index=390"):
with pytest.raises(IndexError, match="Invalid frame index=390"):
decoder.get_frames_at([390])

with pytest.raises(RuntimeError, match="Expected a value of type"):
Expand Down Expand Up @@ -772,6 +775,58 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode):
empty_frames.duration_seconds, NASA_VIDEO.empty_duration_seconds
)

@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_get_frames_in_range_slice_indices_syntax(self, device, seek_mode):
decoder = VideoDecoder(
NASA_VIDEO.path,
stream_index=3,
device=device,
seek_mode=seek_mode,
)

# high range ends get capped to num_frames
frames387_389 = decoder.get_frames_in_range(start=387, stop=1000)
assert frames387_389.data.shape == torch.Size(
[
3,
NASA_VIDEO.get_num_color_channels(stream_index=3),
NASA_VIDEO.get_height(stream_index=3),
NASA_VIDEO.get_width(stream_index=3),
]
)
ref_frame387_389 = NASA_VIDEO.get_frame_data_by_range(
start=387, stop=390, stream_index=3
).to(device)
assert_frames_equal(frames387_389.data, ref_frame387_389)

# negative indices are converted
frames387_389 = decoder.get_frames_in_range(start=-3, stop=1000)
assert frames387_389.data.shape == torch.Size(
[
3,
NASA_VIDEO.get_num_color_channels(stream_index=3),
NASA_VIDEO.get_height(stream_index=3),
NASA_VIDEO.get_width(stream_index=3),
]
)
assert_frames_equal(frames387_389.data, ref_frame387_389)

# "None" as stop is treated as end of the video
frames387_None = decoder.get_frames_in_range(start=-3, stop=None)
assert frames387_None.data.shape == torch.Size(
[
3,
NASA_VIDEO.get_num_color_channels(stream_index=3),
NASA_VIDEO.get_height(stream_index=3),
NASA_VIDEO.get_width(stream_index=3),
]
)
reference_frame387_389 = NASA_VIDEO.get_frame_data_by_range(
start=387, stop=390, stream_index=3
).to(device)
assert_frames_equal(frames387_None.data, reference_frame387_389)

@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
@patch("torchcodec._core._metadata._get_stream_json_metadata")
Expand Down
30 changes: 30 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ def test_get_frame_at_index(self, device):
INDEX_OF_FRAME_AT_6_SECONDS
)
assert_frames_equal(frame6, reference_frame6.to(device))
# Negative indices are supported
frame389 = get_frame_at_index(decoder, frame_index=-1)
reference_frame389 = NASA_VIDEO.get_frame_data_by_index(389)
assert_frames_equal(frame389[0], reference_frame389.to(device))

@pytest.mark.parametrize("device", cpu_and_cuda())
def test_get_frame_with_info_at_index(self, device):
Expand Down Expand Up @@ -177,6 +181,32 @@ def test_get_frames_at_indices_unsorted_indices(self, device):
with pytest.raises(AssertionError):
assert_frames_equal(frames[0], frames[-1])

@pytest.mark.parametrize("device", cpu_and_cuda())
def test_get_frames_at_indices_negative_indices(self, device):
decoder = create_from_file(str(NASA_VIDEO.path))
add_video_stream(decoder, device=device)
frames389and387and1, *_ = get_frames_at_indices(
decoder, frame_indices=[-1, -3, -389]
)
reference_frame389 = NASA_VIDEO.get_frame_data_by_index(389)
reference_frame387 = NASA_VIDEO.get_frame_data_by_index(387)
reference_frame1 = NASA_VIDEO.get_frame_data_by_index(1)
assert_frames_equal(frames389and387and1[0], reference_frame389.to(device))
assert_frames_equal(frames389and387and1[1], reference_frame387.to(device))
assert_frames_equal(frames389and387and1[2], reference_frame1.to(device))

@pytest.mark.parametrize("device", cpu_and_cuda())
def test_get_frames_at_indices_fail_on_invalid_negative_indices(self, device):
decoder = create_from_file(str(NASA_VIDEO.path))
add_video_stream(decoder, device=device)
with pytest.raises(
IndexError,
match="negative indices must have an absolute value less than the number of frames",
):
invalid_frames, *_ = get_frames_at_indices(
decoder, frame_indices=[-10000, -3000]
)

@pytest.mark.parametrize("device", cpu_and_cuda())
def test_get_frames_by_pts(self, device):
decoder = create_from_file(str(NASA_VIDEO.path))
Expand Down
Loading