-
Notifications
You must be signed in to change notification settings - Fork 48
Convert negative frame indices in C++, convert slices in Python #746
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ef0ae7b
8b6c9a4
d327b10
16e5713
93b8548
a00c68a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -526,6 +526,12 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal( | |
const auto& streamInfo = streamInfos_[activeStreamIndex_]; | ||
const auto& streamMetadata = | ||
containerMetadata_.allStreamMetadata[activeStreamIndex_]; | ||
|
||
std::optional<int64_t> numFrames = getNumFrames(streamMetadata); | ||
if (numFrames.has_value()) { | ||
// If the frameIndex is negative, we convert it to a positive index | ||
frameIndex = frameIndex >= 0 ? frameIndex : frameIndex + numFrames.value(); | ||
} | ||
validateFrameIndex(streamMetadata, frameIndex); | ||
|
||
int64_t pts = getPts(frameIndex); | ||
|
@@ -568,8 +574,6 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices( | |
auto indexInOutput = indicesAreSorted ? f : argsort[f]; | ||
auto indexInVideo = frameIndices[indexInOutput]; | ||
|
||
validateFrameIndex(streamMetadata, indexInVideo); | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I believe this call to There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hmm. Yes, although I had to think about it. What gave me pause is the fact that we also use it in the |
||
if ((f > 0) && (indexInVideo == previousIndexInVideo)) { | ||
// Avoid decoding the same frame twice | ||
auto previousIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1]; | ||
|
@@ -1559,21 +1563,24 @@ void SingleStreamDecoder::validateScannedAllStreams(const std::string& msg) { | |
void SingleStreamDecoder::validateFrameIndex( | ||
const StreamMetadata& streamMetadata, | ||
int64_t frameIndex) { | ||
TORCH_CHECK( | ||
frameIndex >= 0, | ||
"Invalid frame index=" + std::to_string(frameIndex) + | ||
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) + | ||
"; must be greater than or equal to 0"); | ||
if (frameIndex < 0) { | ||
throw std::out_of_range( | ||
"Invalid frame index=" + std::to_string(frameIndex) + | ||
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) + | ||
"; negative indices must have an absolute value less than the number of frames, " | ||
"and the number of frames must be known."); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It could end up being wordy, but we might want to mention that we can't receive any negative index if we don't know the number of frames. Maybe something like: throw std::out_of_range(
"Invalid frame index=" + std::to_string(frameIndex) +
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
"; negative indices must have an absolute value that is less than the number of frames, "
"and the number of frames must be known.");
} |
||
|
||
// Note that if we do not have the number of frames available in our metadata, | ||
// then we assume that the frameIndex is valid. | ||
std::optional<int64_t> numFrames = getNumFrames(streamMetadata); | ||
if (numFrames.has_value()) { | ||
TORCH_CHECK( | ||
frameIndex < numFrames.value(), | ||
"Invalid frame index=" + std::to_string(frameIndex) + | ||
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) + | ||
"; must be less than " + std::to_string(numFrames.value())); | ||
if (frameIndex >= numFrames.value()) { | ||
throw std::out_of_range( | ||
"Invalid frame index=" + std::to_string(frameIndex) + | ||
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) + | ||
"; must be less than " + std::to_string(numFrames.value())); | ||
} | ||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -125,13 +125,6 @@ def __len__(self) -> int: | |
def _getitem_int(self, key: int) -> Tensor: | ||
assert isinstance(key, int) | ||
|
||
if key < 0: | ||
key += self._num_frames | ||
if key >= self._num_frames or key < 0: | ||
raise IndexError( | ||
f"Index {key} is out of bounds; length is {self._num_frames}" | ||
) | ||
|
||
frame_data, *_ = core.get_frame_at_index(self._decoder, frame_index=key) | ||
return frame_data | ||
|
||
|
@@ -195,13 +188,6 @@ def get_frame_at(self, index: int) -> Frame: | |
Returns: | ||
Frame: The frame at the given index. | ||
""" | ||
if index < 0: | ||
index += self._num_frames | ||
|
||
if not 0 <= index < self._num_frames: | ||
raise IndexError( | ||
f"Index {index} is out of bounds; must be in the range [0, {self._num_frames})." | ||
) | ||
data, pts_seconds, duration_seconds = core.get_frame_at_index( | ||
self._decoder, frame_index=index | ||
) | ||
|
@@ -220,10 +206,6 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch: | |
Returns: | ||
FrameBatch: The frames at the given indices. | ||
""" | ||
indices = [ | ||
index if index >= 0 else index + self._num_frames for index in indices | ||
] | ||
|
||
data, pts_seconds, duration_seconds = core.get_frames_at_indices( | ||
self._decoder, frame_indices=indices | ||
) | ||
|
@@ -247,16 +229,8 @@ def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatc | |
Returns: | ||
FrameBatch: The frames within the specified range. | ||
""" | ||
if not 0 <= start < self._num_frames: | ||
raise IndexError( | ||
f"Start index {start} is out of bounds; must be in the range [0, {self._num_frames})." | ||
) | ||
if stop < start: | ||
raise IndexError( | ||
f"Stop index ({stop}) must not be less than the start index ({start})." | ||
) | ||
if not step > 0: | ||
raise IndexError(f"Step ({step}) must be greater than 0.") | ||
# Adjust start / stop indices to enable indexing semantics, ex. [-10, 1000] returns the last 10 frames | ||
start, stop, step = slice(start, stop, step).indices(self._num_frames) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is how slices are handled in |
||
frames = core.get_frames_in_range( | ||
self._decoder, | ||
start=start, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -375,10 +375,10 @@ def test_device_instance(self): | |
def test_getitem_fails(self, device, seek_mode): | ||
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) | ||
|
||
with pytest.raises(IndexError, match="out of bounds"): | ||
with pytest.raises(IndexError, match="Invalid frame index"): | ||
frame = decoder[1000] # noqa | ||
|
||
with pytest.raises(IndexError, match="out of bounds"): | ||
with pytest.raises(IndexError, match="Invalid frame index"): | ||
frame = decoder[-1000] # noqa | ||
|
||
with pytest.raises(TypeError, match="Unsupported key type"): | ||
|
@@ -487,10 +487,13 @@ def test_get_frame_at_tuple_unpacking(self, device): | |
def test_get_frame_at_fails(self, device, seek_mode): | ||
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) | ||
|
||
with pytest.raises(IndexError, match="out of bounds"): | ||
with pytest.raises( | ||
IndexError, | ||
match="negative indices must have an absolute value less than the number of frames", | ||
): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The previous error check in I removed the single index error check in Python for consistency between the There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think it makes sense to do the error checking once, on the C++ side. But we still have the error checking on the Python side for There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, I'll remove those checks since |
||
frame = decoder.get_frame_at(-10000) # noqa | ||
|
||
with pytest.raises(IndexError, match="out of bounds"): | ||
with pytest.raises(IndexError, match="must be less than"): | ||
frame = decoder.get_frame_at(10000) # noqa | ||
|
||
@pytest.mark.parametrize("device", cpu_and_cuda()) | ||
|
@@ -549,13 +552,13 @@ def test_get_frames_at(self, device, seek_mode): | |
def test_get_frames_at_fails(self, device, seek_mode): | ||
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) | ||
|
||
expected_converted_index = -10000 + len(decoder) | ||
with pytest.raises( | ||
RuntimeError, match=f"Invalid frame index={expected_converted_index}" | ||
IndexError, | ||
match="negative indices must have an absolute value less than the number of frames", | ||
): | ||
decoder.get_frames_at([-10000]) | ||
|
||
with pytest.raises(RuntimeError, match="Invalid frame index=390"): | ||
with pytest.raises(IndexError, match="Invalid frame index=390"): | ||
decoder.get_frames_at([390]) | ||
|
||
with pytest.raises(RuntimeError, match="Expected a value of type"): | ||
|
@@ -772,6 +775,58 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode): | |
empty_frames.duration_seconds, NASA_VIDEO.empty_duration_seconds | ||
) | ||
|
||
@pytest.mark.parametrize("device", cpu_and_cuda()) | ||
@pytest.mark.parametrize("seek_mode", ("exact", "approximate")) | ||
def test_get_frames_in_range_slice_indices_syntax(self, device, seek_mode): | ||
decoder = VideoDecoder( | ||
NASA_VIDEO.path, | ||
stream_index=3, | ||
device=device, | ||
seek_mode=seek_mode, | ||
) | ||
|
||
# high range ends get capped to num_frames | ||
frames387_389 = decoder.get_frames_in_range(start=387, stop=1000) | ||
assert frames387_389.data.shape == torch.Size( | ||
[ | ||
3, | ||
NASA_VIDEO.get_num_color_channels(stream_index=3), | ||
NASA_VIDEO.get_height(stream_index=3), | ||
NASA_VIDEO.get_width(stream_index=3), | ||
] | ||
) | ||
ref_frame387_389 = NASA_VIDEO.get_frame_data_by_range( | ||
start=387, stop=390, stream_index=3 | ||
).to(device) | ||
assert_frames_equal(frames387_389.data, ref_frame387_389) | ||
|
||
# negative indices are converted | ||
frames387_389 = decoder.get_frames_in_range(start=-3, stop=1000) | ||
assert frames387_389.data.shape == torch.Size( | ||
[ | ||
3, | ||
NASA_VIDEO.get_num_color_channels(stream_index=3), | ||
NASA_VIDEO.get_height(stream_index=3), | ||
NASA_VIDEO.get_width(stream_index=3), | ||
] | ||
) | ||
assert_frames_equal(frames387_389.data, ref_frame387_389) | ||
|
||
# "None" as stop is treated as end of the video | ||
frames387_None = decoder.get_frames_in_range(start=-3, stop=None) | ||
assert frames387_None.data.shape == torch.Size( | ||
[ | ||
3, | ||
NASA_VIDEO.get_num_color_channels(stream_index=3), | ||
NASA_VIDEO.get_height(stream_index=3), | ||
NASA_VIDEO.get_width(stream_index=3), | ||
] | ||
) | ||
reference_frame387_389 = NASA_VIDEO.get_frame_data_by_range( | ||
start=387, stop=390, stream_index=3 | ||
).to(device) | ||
assert_frames_equal(frames387_None.data, reference_frame387_389) | ||
|
||
@pytest.mark.parametrize("device", cpu_and_cuda()) | ||
@pytest.mark.parametrize("seek_mode", ("exact", "approximate")) | ||
@patch("torchcodec._core._metadata._get_stream_json_metadata") | ||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If `numFrames` does not have a value, then I think we need to assert that `frameIndex` is positive, and throw a `std::range_error` if it's not. Otherwise, we'll end up using a negative `frameIndex` further down.
Ah, now I see that we actually do that in `validateFrameIndex()`, which we call right after. 👍