Skip to content

Commit 8b6c9a4

Browse files
Dan-FloresDaniel Flores
authored andcommitted
Update get_frames_at_fails regex, add get_frames_in_range tests
1 parent ef0ae7b commit 8b6c9a4

File tree

1 file changed

+62
-5
lines changed

1 file changed

+62
-5
lines changed

test/test_decoders.py

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -549,13 +549,10 @@ def test_get_frames_at(self, device, seek_mode):
549549
def test_get_frames_at_fails(self, device, seek_mode):
550550
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
551551

552-
expected_converted_index = -10000 + len(decoder)
553-
with pytest.raises(
554-
RuntimeError, match=f"Invalid frame index={expected_converted_index}"
555-
):
552+
with pytest.raises(IndexError, match="Index -\\d+ is out of bounds"):
556553
decoder.get_frames_at([-10000])
557554

558-
with pytest.raises(RuntimeError, match="Invalid frame index=390"):
555+
with pytest.raises(IndexError, match="Index 390 is out of bounds"):
559556
decoder.get_frames_at([390])
560557

561558
with pytest.raises(RuntimeError, match="Expected a value of type"):
@@ -772,6 +769,66 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode):
772769
empty_frames.duration_seconds, NASA_VIDEO.empty_duration_seconds
773770
)
774771

772+
@pytest.mark.parametrize("device", cpu_and_cuda())
773+
@pytest.mark.parametrize("stream_index", [3, None])
774+
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
775+
def test_get_frames_in_range_tensor_index_semantics(
776+
self, stream_index, device, seek_mode
777+
):
778+
decoder = VideoDecoder(
779+
NASA_VIDEO.path,
780+
stream_index=stream_index,
781+
device=device,
782+
seek_mode=seek_mode,
783+
)
784+
# slices with upper bound greater than len(decoder) are supported
785+
ref_frames387_389 = NASA_VIDEO.get_frame_data_by_range(
786+
start=387, stop=390, stream_index=stream_index
787+
).to(device)
788+
frames387_389 = decoder.get_frames_in_range(start=387, stop=1000)
789+
print(f"{frames387_389.data.shape=}")
790+
assert frames387_389.data.shape == torch.Size(
791+
[
792+
3,
793+
NASA_VIDEO.get_num_color_channels(stream_index=stream_index),
794+
NASA_VIDEO.get_height(stream_index=stream_index),
795+
NASA_VIDEO.get_width(stream_index=stream_index),
796+
]
797+
)
798+
assert_frames_equal(ref_frames387_389, frames387_389.data)
799+
800+
# test that negative values in the range are supported
801+
ref_frames386_389 = NASA_VIDEO.get_frame_data_by_range(
802+
start=386, stop=390, stream_index=stream_index
803+
).to(device)
804+
frames386_389 = decoder.get_frames_in_range(start=-4, stop=1000)
805+
assert frames386_389.data.shape == torch.Size(
806+
[
807+
4,
808+
NASA_VIDEO.get_num_color_channels(stream_index=stream_index),
809+
NASA_VIDEO.get_height(stream_index=stream_index),
810+
NASA_VIDEO.get_width(stream_index=stream_index),
811+
]
812+
)
813+
assert_frames_equal(ref_frames386_389, frames386_389.data)
814+
815+
@pytest.mark.parametrize("device", cpu_and_cuda())
816+
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
817+
def test_get_frames_in_range_fails(self, device, seek_mode):
818+
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
819+
820+
with pytest.raises(IndexError, match="Start index 1000 is out of bounds"):
821+
decoder.get_frames_in_range(start=1000, stop=10)
822+
823+
with pytest.raises(IndexError, match="Start index -\\d+ is out of bounds"):
824+
decoder.get_frames_in_range(start=-1000, stop=10)
825+
826+
with pytest.raises(
827+
IndexError,
828+
match="Stop index \\(-\\d+\\) must not be less than the start index",
829+
):
830+
decoder.get_frames_in_range(start=0, stop=-1000)
831+
775832
@pytest.mark.parametrize("device", cpu_and_cuda())
776833
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
777834
@patch("torchcodec._core._metadata._get_stream_json_metadata")

0 commit comments

Comments
 (0)