@@ -487,10 +487,13 @@ def test_get_frame_at_tuple_unpacking(self, device):
487
487
def test_get_frame_at_fails (self , device , seek_mode ):
488
488
decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
489
489
490
- with pytest .raises (IndexError , match = "out of bounds" ):
490
+ with pytest .raises (
491
+ IndexError ,
492
+ match = "must be greater than or equal to 0, or be a valid negative index" ,
493
+ ):
491
494
frame = decoder .get_frame_at (- 10000 ) # noqa
492
495
493
- with pytest .raises (IndexError , match = "out of bounds " ):
496
+ with pytest .raises (IndexError , match = "must be less than " ):
494
497
frame = decoder .get_frame_at (10000 ) # noqa
495
498
496
499
@pytest .mark .parametrize ("device" , cpu_and_cuda ())
@@ -549,10 +552,13 @@ def test_get_frames_at(self, device, seek_mode):
549
552
def test_get_frames_at_fails (self , device , seek_mode ):
550
553
decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
551
554
552
- with pytest .raises (IndexError , match = "Index -\\ d+ is out of bounds" ):
555
+ with pytest .raises (
556
+ IndexError ,
557
+ match = "must be greater than or equal to 0, or be a valid negative index" ,
558
+ ):
553
559
decoder .get_frames_at ([- 10000 ])
554
560
555
- with pytest .raises (IndexError , match = "Index 390 is out of bounds " ):
561
+ with pytest .raises (IndexError , match = "Invalid frame index=390 " ):
556
562
decoder .get_frames_at ([390 ])
557
563
558
564
with pytest .raises (RuntimeError , match = "Expected a value of type" ):
@@ -770,64 +776,56 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode):
770
776
)
771
777
772
778
@pytest .mark .parametrize ("device" , cpu_and_cuda ())
773
- @pytest .mark .parametrize ("stream_index" , [3 , None ])
774
779
@pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
775
- def test_get_frames_in_range_tensor_index_semantics (
776
- self , stream_index , device , seek_mode
777
- ):
780
+ def test_get_frames_in_range_slice_indices_syntax (self , device , seek_mode ):
778
781
decoder = VideoDecoder (
779
782
NASA_VIDEO .path ,
780
- stream_index = stream_index ,
783
+ stream_index = 3 ,
781
784
device = device ,
782
785
seek_mode = seek_mode ,
783
786
)
784
- # slices with upper bound greater than len(decoder) are supported
785
- ref_frames387_389 = NASA_VIDEO .get_frame_data_by_range (
786
- start = 387 , stop = 390 , stream_index = stream_index
787
- ).to (device )
787
+
788
+ # high range ends get capped to num_frames
788
789
frames387_389 = decoder .get_frames_in_range (start = 387 , stop = 1000 )
789
- print (f"{ frames387_389 .data .shape = } " )
790
790
assert frames387_389 .data .shape == torch .Size (
791
791
[
792
792
3 ,
793
- NASA_VIDEO .get_num_color_channels (stream_index = stream_index ),
794
- NASA_VIDEO .get_height (stream_index = stream_index ),
795
- NASA_VIDEO .get_width (stream_index = stream_index ),
793
+ NASA_VIDEO .get_num_color_channels (stream_index = 3 ),
794
+ NASA_VIDEO .get_height (stream_index = 3 ),
795
+ NASA_VIDEO .get_width (stream_index = 3 ),
796
796
]
797
797
)
798
- assert_frames_equal (ref_frames387_389 , frames387_389 .data )
799
-
800
- # test that negative values in the range are supported
801
- ref_frames386_389 = NASA_VIDEO .get_frame_data_by_range (
802
- start = 386 , stop = 390 , stream_index = stream_index
798
+ ref_frame387_389 = NASA_VIDEO .get_frame_data_by_range (
799
+ start = 387 , stop = 390 , stream_index = 3
803
800
).to (device )
804
- frames386_389 = decoder .get_frames_in_range (start = - 4 , stop = 1000 )
805
- assert frames386_389 .data .shape == torch .Size (
801
+ assert_frames_equal (frames387_389 .data , ref_frame387_389 )
802
+
803
+ # negative indices are converted
804
+ frames387_389 = decoder .get_frames_in_range (start = - 3 , stop = 1000 )
805
+ assert frames387_389 .data .shape == torch .Size (
806
806
[
807
- 4 ,
808
- NASA_VIDEO .get_num_color_channels (stream_index = stream_index ),
809
- NASA_VIDEO .get_height (stream_index = stream_index ),
810
- NASA_VIDEO .get_width (stream_index = stream_index ),
807
+ 3 ,
808
+ NASA_VIDEO .get_num_color_channels (stream_index = 3 ),
809
+ NASA_VIDEO .get_height (stream_index = 3 ),
810
+ NASA_VIDEO .get_width (stream_index = 3 ),
811
811
]
812
812
)
813
- assert_frames_equal (ref_frames386_389 , frames386_389 .data )
814
-
815
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
816
- @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
817
- def test_get_frames_in_range_fails (self , device , seek_mode ):
818
- decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
819
-
820
- with pytest .raises (IndexError , match = "Start index 1000 is out of bounds" ):
821
- decoder .get_frames_in_range (start = 1000 , stop = 10 )
813
+ assert_frames_equal (frames387_389 .data , ref_frame387_389 )
822
814
823
- with pytest .raises (IndexError , match = "Start index -\\ d+ is out of bounds" ):
824
- decoder .get_frames_in_range (start = - 1000 , stop = 10 )
825
-
826
- with pytest .raises (
827
- IndexError ,
828
- match = "Stop index \\ (-\\ d+\\ ) must not be less than the start index" ,
829
- ):
830
- decoder .get_frames_in_range (start = 0 , stop = - 1000 )
815
+ # "None" as stop is treated as end of the video
816
+ frames387_None = decoder .get_frames_in_range (start = - 3 , stop = None )
817
+ assert frames387_None .data .shape == torch .Size (
818
+ [
819
+ 3 ,
820
+ NASA_VIDEO .get_num_color_channels (stream_index = 3 ),
821
+ NASA_VIDEO .get_height (stream_index = 3 ),
822
+ NASA_VIDEO .get_width (stream_index = 3 ),
823
+ ]
824
+ )
825
+ reference_frame387_389 = NASA_VIDEO .get_frame_data_by_range (
826
+ start = 387 , stop = 390 , stream_index = 3
827
+ ).to (device )
828
+ assert_frames_equal (frames387_None .data , reference_frame387_389 )
831
829
832
830
@pytest .mark .parametrize ("device" , cpu_and_cuda ())
833
831
@pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
0 commit comments