Skip to content

Commit d327b10

Browse files
author
Daniel Flores
committed
Convert negative indices in C++, handle slices in python
1 parent 8b6c9a4 commit d327b10

File tree

4 files changed

+93
-75
lines changed

4 files changed

+93
-75
lines changed

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,12 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
526526
const auto& streamInfo = streamInfos_[activeStreamIndex_];
527527
const auto& streamMetadata =
528528
containerMetadata_.allStreamMetadata[activeStreamIndex_];
529+
530+
std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
531+
if (numFrames.has_value()) {
532+
// If the frameIndex is negative, we convert it to a positive index
533+
frameIndex = frameIndex >= 0 ? frameIndex : frameIndex + numFrames.value();
534+
}
529535
validateFrameIndex(streamMetadata, frameIndex);
530536

531537
int64_t pts = getPts(frameIndex);
@@ -568,8 +574,6 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
568574
auto indexInOutput = indicesAreSorted ? f : argsort[f];
569575
auto indexInVideo = frameIndices[indexInOutput];
570576

571-
validateFrameIndex(streamMetadata, indexInVideo);
572-
573577
if ((f > 0) && (indexInVideo == previousIndexInVideo)) {
574578
// Avoid decoding the same frame twice
575579
auto previousIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1];
@@ -1559,21 +1563,23 @@ void SingleStreamDecoder::validateScannedAllStreams(const std::string& msg) {
15591563
void SingleStreamDecoder::validateFrameIndex(
15601564
const StreamMetadata& streamMetadata,
15611565
int64_t frameIndex) {
1562-
TORCH_CHECK(
1563-
frameIndex >= 0,
1564-
"Invalid frame index=" + std::to_string(frameIndex) +
1565-
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
1566-
"; must be greater than or equal to 0");
1566+
if (frameIndex < 0) {
1567+
throw std::out_of_range(
1568+
"Invalid frame index=" + std::to_string(frameIndex) +
1569+
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
1570+
"; must be greater than or equal to 0, or be a valid negative index");
1571+
}
15671572

15681573
// Note that if we do not have the number of frames available in our metadata,
15691574
// then we assume that the frameIndex is valid.
15701575
std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
15711576
if (numFrames.has_value()) {
1572-
TORCH_CHECK(
1573-
frameIndex < numFrames.value(),
1574-
"Invalid frame index=" + std::to_string(frameIndex) +
1575-
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
1576-
"; must be less than " + std::to_string(numFrames.value()));
1577+
if (frameIndex >= numFrames.value()) {
1578+
throw std::out_of_range(
1579+
"Invalid frame index=" + std::to_string(frameIndex) +
1580+
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
1581+
"; must be less than " + std::to_string(numFrames.value()));
1582+
}
15771583
}
15781584
}
15791585

src/torchcodec/decoders/_video_decoder.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -195,13 +195,6 @@ def get_frame_at(self, index: int) -> Frame:
195195
Returns:
196196
Frame: The frame at the given index.
197197
"""
198-
if index < 0:
199-
index += self._num_frames
200-
201-
if not 0 <= index < self._num_frames:
202-
raise IndexError(
203-
f"Index {index} is out of bounds; must be in the range [0, {self._num_frames})."
204-
)
205198
data, pts_seconds, duration_seconds = core.get_frame_at_index(
206199
self._decoder, frame_index=index
207200
)
@@ -220,15 +213,6 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch:
220213
Returns:
221214
FrameBatch: The frames at the given indices.
222215
"""
223-
for i, index in enumerate(indices):
224-
index = index if index >= 0 else index + self._num_frames
225-
if not 0 <= index < self._num_frames:
226-
raise IndexError(
227-
f"Index {index} is out of bounds; must be in the range [0, {self._num_frames})."
228-
)
229-
else:
230-
indices[i] = index
231-
232216
data, pts_seconds, duration_seconds = core.get_frames_at_indices(
233217
self._decoder, frame_indices=indices
234218
)
@@ -252,8 +236,8 @@ def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatc
252236
Returns:
253237
FrameBatch: The frames within the specified range.
254238
"""
255-
start = start if start >= 0 else start + self._num_frames
256-
stop = min(stop if stop >= 0 else stop + self._num_frames, self._num_frames)
239+
# Adjust start / stop indices to enable indexing semantics, ex. [-10, 1000] returns the last 10 frames
240+
start, stop, step = slice(start, stop, step).indices(self._num_frames)
257241
if not 0 <= start < self._num_frames:
258242
raise IndexError(
259243
f"Start index {start} is out of bounds; must be in the range [0, {self._num_frames})."

test/test_decoders.py

Lines changed: 43 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -487,10 +487,13 @@ def test_get_frame_at_tuple_unpacking(self, device):
487487
def test_get_frame_at_fails(self, device, seek_mode):
488488
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
489489

490-
with pytest.raises(IndexError, match="out of bounds"):
490+
with pytest.raises(
491+
IndexError,
492+
match="must be greater than or equal to 0, or be a valid negative index",
493+
):
491494
frame = decoder.get_frame_at(-10000) # noqa
492495

493-
with pytest.raises(IndexError, match="out of bounds"):
496+
with pytest.raises(IndexError, match="must be less than"):
494497
frame = decoder.get_frame_at(10000) # noqa
495498

496499
@pytest.mark.parametrize("device", cpu_and_cuda())
@@ -549,10 +552,13 @@ def test_get_frames_at(self, device, seek_mode):
549552
def test_get_frames_at_fails(self, device, seek_mode):
550553
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
551554

552-
with pytest.raises(IndexError, match="Index -\\d+ is out of bounds"):
555+
with pytest.raises(
556+
IndexError,
557+
match="must be greater than or equal to 0, or be a valid negative index",
558+
):
553559
decoder.get_frames_at([-10000])
554560

555-
with pytest.raises(IndexError, match="Index 390 is out of bounds"):
561+
with pytest.raises(IndexError, match="Invalid frame index=390"):
556562
decoder.get_frames_at([390])
557563

558564
with pytest.raises(RuntimeError, match="Expected a value of type"):
@@ -770,64 +776,56 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode):
770776
)
771777

772778
@pytest.mark.parametrize("device", cpu_and_cuda())
773-
@pytest.mark.parametrize("stream_index", [3, None])
774779
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
775-
def test_get_frames_in_range_tensor_index_semantics(
776-
self, stream_index, device, seek_mode
777-
):
780+
def test_get_frames_in_range_slice_indices_syntax(self, device, seek_mode):
778781
decoder = VideoDecoder(
779782
NASA_VIDEO.path,
780-
stream_index=stream_index,
783+
stream_index=3,
781784
device=device,
782785
seek_mode=seek_mode,
783786
)
784-
# slices with upper bound greater than len(decoder) are supported
785-
ref_frames387_389 = NASA_VIDEO.get_frame_data_by_range(
786-
start=387, stop=390, stream_index=stream_index
787-
).to(device)
787+
788+
# high range ends get capped to num_frames
788789
frames387_389 = decoder.get_frames_in_range(start=387, stop=1000)
789-
print(f"{frames387_389.data.shape=}")
790790
assert frames387_389.data.shape == torch.Size(
791791
[
792792
3,
793-
NASA_VIDEO.get_num_color_channels(stream_index=stream_index),
794-
NASA_VIDEO.get_height(stream_index=stream_index),
795-
NASA_VIDEO.get_width(stream_index=stream_index),
793+
NASA_VIDEO.get_num_color_channels(stream_index=3),
794+
NASA_VIDEO.get_height(stream_index=3),
795+
NASA_VIDEO.get_width(stream_index=3),
796796
]
797797
)
798-
assert_frames_equal(ref_frames387_389, frames387_389.data)
799-
800-
# test that negative values in the range are supported
801-
ref_frames386_389 = NASA_VIDEO.get_frame_data_by_range(
802-
start=386, stop=390, stream_index=stream_index
798+
ref_frame387_389 = NASA_VIDEO.get_frame_data_by_range(
799+
start=387, stop=390, stream_index=3
803800
).to(device)
804-
frames386_389 = decoder.get_frames_in_range(start=-4, stop=1000)
805-
assert frames386_389.data.shape == torch.Size(
801+
assert_frames_equal(frames387_389.data, ref_frame387_389)
802+
803+
# negative indices are converted
804+
frames387_389 = decoder.get_frames_in_range(start=-3, stop=1000)
805+
assert frames387_389.data.shape == torch.Size(
806806
[
807-
4,
808-
NASA_VIDEO.get_num_color_channels(stream_index=stream_index),
809-
NASA_VIDEO.get_height(stream_index=stream_index),
810-
NASA_VIDEO.get_width(stream_index=stream_index),
807+
3,
808+
NASA_VIDEO.get_num_color_channels(stream_index=3),
809+
NASA_VIDEO.get_height(stream_index=3),
810+
NASA_VIDEO.get_width(stream_index=3),
811811
]
812812
)
813-
assert_frames_equal(ref_frames386_389, frames386_389.data)
814-
815-
@pytest.mark.parametrize("device", cpu_and_cuda())
816-
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
817-
def test_get_frames_in_range_fails(self, device, seek_mode):
818-
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
819-
820-
with pytest.raises(IndexError, match="Start index 1000 is out of bounds"):
821-
decoder.get_frames_in_range(start=1000, stop=10)
813+
assert_frames_equal(frames387_389.data, ref_frame387_389)
822814

823-
with pytest.raises(IndexError, match="Start index -\\d+ is out of bounds"):
824-
decoder.get_frames_in_range(start=-1000, stop=10)
825-
826-
with pytest.raises(
827-
IndexError,
828-
match="Stop index \\(-\\d+\\) must not be less than the start index",
829-
):
830-
decoder.get_frames_in_range(start=0, stop=-1000)
815+
# "None" as stop is treated as end of the video
816+
frames387_None = decoder.get_frames_in_range(start=-3, stop=None)
817+
assert frames387_None.data.shape == torch.Size(
818+
[
819+
3,
820+
NASA_VIDEO.get_num_color_channels(stream_index=3),
821+
NASA_VIDEO.get_height(stream_index=3),
822+
NASA_VIDEO.get_width(stream_index=3),
823+
]
824+
)
825+
reference_frame387_389 = NASA_VIDEO.get_frame_data_by_range(
826+
start=387, stop=390, stream_index=3
827+
).to(device)
828+
assert_frames_equal(frames387_None.data, reference_frame387_389)
831829

832830
@pytest.mark.parametrize("device", cpu_and_cuda())
833831
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))

test/test_ops.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ def test_get_frame_at_index(self, device):
125125
INDEX_OF_FRAME_AT_6_SECONDS
126126
)
127127
assert_frames_equal(frame6, reference_frame6.to(device))
128+
# Negative indices are supported
129+
frame389 = get_frame_at_index(decoder, frame_index=-1)
130+
reference_frame389 = NASA_VIDEO.get_frame_data_by_index(389)
131+
assert_frames_equal(frame389[0], reference_frame389.to(device))
128132

129133
@pytest.mark.parametrize("device", cpu_and_cuda())
130134
def test_get_frame_with_info_at_index(self, device):
@@ -177,6 +181,32 @@ def test_get_frames_at_indices_unsorted_indices(self, device):
177181
with pytest.raises(AssertionError):
178182
assert_frames_equal(frames[0], frames[-1])
179183

184+
@pytest.mark.parametrize("device", cpu_and_cuda())
185+
def test_get_frames_at_indices_negative_indices(self, device):
186+
decoder = create_from_file(str(NASA_VIDEO.path))
187+
add_video_stream(decoder, device=device)
188+
frames389and387and1, *_ = get_frames_at_indices(
189+
decoder, frame_indices=[-1, -3, -389]
190+
)
191+
reference_frame389 = NASA_VIDEO.get_frame_data_by_index(389)
192+
reference_frame387 = NASA_VIDEO.get_frame_data_by_index(387)
193+
reference_frame1 = NASA_VIDEO.get_frame_data_by_index(1)
194+
assert_frames_equal(frames389and387and1[0], reference_frame389.to(device))
195+
assert_frames_equal(frames389and387and1[1], reference_frame387.to(device))
196+
assert_frames_equal(frames389and387and1[2], reference_frame1.to(device))
197+
198+
@pytest.mark.parametrize("device", cpu_and_cuda())
199+
def test_get_frames_at_indices_fail_on_invalid_negative_indices(self, device):
200+
decoder = create_from_file(str(NASA_VIDEO.path))
201+
add_video_stream(decoder, device=device)
202+
with pytest.raises(
203+
IndexError,
204+
match="must be greater than or equal to 0, or be a valid negative index",
205+
):
206+
invalid_frames, *_ = get_frames_at_indices(
207+
decoder, frame_indices=[-10000, -3000]
208+
)
209+
180210
@pytest.mark.parametrize("device", cpu_and_cuda())
181211
def test_get_frames_by_pts(self, device):
182212
decoder = create_from_file(str(NASA_VIDEO.path))

0 commit comments

Comments
 (0)