-
Notifications
You must be signed in to change notification settings - Fork 48
Support negative index in SimpleVideoDecoder #743
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -328,6 +328,19 @@ def test_getitem_slice(self, device, seek_mode): | |
) | ||
assert_frames_equal(ref386_389, slice386_389) | ||
|
||
# slices with upper bound greater than len(decoder) are supported | ||
slice387_389 = decoder[-3:10000].to(device) | ||
assert slice387_389.shape == torch.Size( | ||
[ | ||
3, | ||
NASA_VIDEO.num_color_channels, | ||
NASA_VIDEO.height, | ||
NASA_VIDEO.width, | ||
] | ||
) | ||
ref387_389 = NASA_VIDEO.get_frame_data_by_range(387, 390).to(device) | ||
assert_frames_equal(ref387_389, slice387_389) | ||
|
||
# an empty range is valid! | ||
empty_frame = decoder[5:5] | ||
assert_frames_equal(empty_frame, NASA_VIDEO.empty_chw_tensor.to(device)) | ||
|
@@ -437,6 +450,11 @@ def test_get_frame_at(self, device, seek_mode): | |
expected_frame_info.duration_seconds, rel=1e-3 | ||
) | ||
|
||
# test negative frame index | ||
frame_minus1 = decoder.get_frame_at(-1) | ||
ref_frame_minus1 = NASA_VIDEO.get_frame_data_by_index(389).to(device) | ||
assert_frames_equal(ref_frame_minus1, frame_minus1.data) | ||
|
||
# test numpy.int64 | ||
frame9 = decoder.get_frame_at(numpy.int64(9)) | ||
assert_frames_equal(ref_frame9, frame9.data) | ||
|
@@ -470,7 +488,7 @@ def test_get_frame_at_fails(self, device, seek_mode): | |
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) | ||
|
||
with pytest.raises(IndexError, match="out of bounds"): | ||
frame = decoder.get_frame_at(-1) # noqa | ||
frame = decoder.get_frame_at(-10000) # noqa | ||
|
||
with pytest.raises(IndexError, match="out of bounds"): | ||
frame = decoder.get_frame_at(10000) # noqa | ||
|
@@ -480,7 +498,8 @@ def test_get_frame_at_fails(self, device, seek_mode): | |
def test_get_frames_at(self, device, seek_mode): | ||
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) | ||
|
||
frames = decoder.get_frames_at([35, 25]) | ||
# test positive and negative frame index | ||
frames = decoder.get_frames_at([35, 25, -1, -2]) | ||
|
||
assert isinstance(frames, FrameBatch) | ||
|
||
|
@@ -490,12 +509,20 @@ def test_get_frames_at(self, device, seek_mode): | |
assert_frames_equal( | ||
frames[1].data, NASA_VIDEO.get_frame_data_by_index(25).to(device) | ||
) | ||
assert_frames_equal( | ||
frames[2].data, NASA_VIDEO.get_frame_data_by_index(389).to(device) | ||
) | ||
assert_frames_equal( | ||
frames[3].data, NASA_VIDEO.get_frame_data_by_index(388).to(device) | ||
) | ||
|
||
assert frames.pts_seconds.device.type == "cpu" | ||
expected_pts_seconds = torch.tensor( | ||
[ | ||
NASA_VIDEO.get_frame_info(35).pts_seconds, | ||
NASA_VIDEO.get_frame_info(25).pts_seconds, | ||
NASA_VIDEO.get_frame_info(389).pts_seconds, | ||
NASA_VIDEO.get_frame_info(388).pts_seconds, | ||
], | ||
dtype=torch.float64, | ||
) | ||
|
@@ -508,6 +535,8 @@ def test_get_frames_at(self, device, seek_mode): | |
[ | ||
NASA_VIDEO.get_frame_info(35).duration_seconds, | ||
NASA_VIDEO.get_frame_info(25).duration_seconds, | ||
NASA_VIDEO.get_frame_info(389).duration_seconds, | ||
NASA_VIDEO.get_frame_info(388).duration_seconds, | ||
], | ||
dtype=torch.float64, | ||
) | ||
|
@@ -520,8 +549,11 @@ def test_get_frames_at(self, device, seek_mode): | |
def test_get_frames_at_fails(self, device, seek_mode): | ||
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) | ||
|
||
with pytest.raises(RuntimeError, match="Invalid frame index=-1"): | ||
decoder.get_frames_at([-1]) | ||
expected_converted_index = -10000 + len(decoder) | ||
with pytest.raises( | ||
RuntimeError, match=f"Invalid frame index={expected_converted_index}" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This error is from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good observation - at the Python level, we currently don't check the input when we get a list of indices ( The question, then, is what to do in |
||
): | ||
decoder.get_frames_at([-10000]) | ||
|
||
with pytest.raises(RuntimeError, match="Invalid frame index=390"): | ||
decoder.get_frames_at([390]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is not necessary to get the expected index here, but I added it to match the other tests. Alternatively, the test can simply match the text: "Invalid frame index".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's fine to do, but I usually default to just matching the text that doesn't change.