@@ -549,13 +549,10 @@ def test_get_frames_at(self, device, seek_mode):
549
549
def test_get_frames_at_fails (self , device , seek_mode ):
550
550
decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
551
551
552
- expected_converted_index = - 10000 + len (decoder )
553
- with pytest .raises (
554
- RuntimeError , match = f"Invalid frame index={ expected_converted_index } "
555
- ):
552
+ with pytest .raises (IndexError , match = "Index -\\ d+ is out of bounds" ):
556
553
decoder .get_frames_at ([- 10000 ])
557
554
558
- with pytest .raises (RuntimeError , match = "Invalid frame index= 390" ):
555
+ with pytest .raises (IndexError , match = "Index 390 is out of bounds " ):
559
556
decoder .get_frames_at ([390 ])
560
557
561
558
with pytest .raises (RuntimeError , match = "Expected a value of type" ):
@@ -772,6 +769,66 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode):
772
769
empty_frames .duration_seconds , NASA_VIDEO .empty_duration_seconds
773
770
)
774
771
772
+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
773
+ @pytest .mark .parametrize ("stream_index" , [3 , None ])
774
+ @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
775
+ def test_get_frames_in_range_tensor_index_semantics (
776
+ self , stream_index , device , seek_mode
777
+ ):
778
+ decoder = VideoDecoder (
779
+ NASA_VIDEO .path ,
780
+ stream_index = stream_index ,
781
+ device = device ,
782
+ seek_mode = seek_mode ,
783
+ )
784
+ # slices with upper bound greater than len(decoder) are supported
785
+ ref_frames387_389 = NASA_VIDEO .get_frame_data_by_range (
786
+ start = 387 , stop = 390 , stream_index = stream_index
787
+ ).to (device )
788
+ frames387_389 = decoder .get_frames_in_range (start = 387 , stop = 1000 )
789
+ print (f"{ frames387_389 .data .shape = } " )
790
+ assert frames387_389 .data .shape == torch .Size (
791
+ [
792
+ 3 ,
793
+ NASA_VIDEO .get_num_color_channels (stream_index = stream_index ),
794
+ NASA_VIDEO .get_height (stream_index = stream_index ),
795
+ NASA_VIDEO .get_width (stream_index = stream_index ),
796
+ ]
797
+ )
798
+ assert_frames_equal (ref_frames387_389 , frames387_389 .data )
799
+
800
+ # test that negative values in the range are supported
801
+ ref_frames386_389 = NASA_VIDEO .get_frame_data_by_range (
802
+ start = 386 , stop = 390 , stream_index = stream_index
803
+ ).to (device )
804
+ frames386_389 = decoder .get_frames_in_range (start = - 4 , stop = 1000 )
805
+ assert frames386_389 .data .shape == torch .Size (
806
+ [
807
+ 4 ,
808
+ NASA_VIDEO .get_num_color_channels (stream_index = stream_index ),
809
+ NASA_VIDEO .get_height (stream_index = stream_index ),
810
+ NASA_VIDEO .get_width (stream_index = stream_index ),
811
+ ]
812
+ )
813
+ assert_frames_equal (ref_frames386_389 , frames386_389 .data )
814
+
815
+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
816
+ @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
817
+ def test_get_frames_in_range_fails (self , device , seek_mode ):
818
+ decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
819
+
820
+ with pytest .raises (IndexError , match = "Start index 1000 is out of bounds" ):
821
+ decoder .get_frames_in_range (start = 1000 , stop = 10 )
822
+
823
+ with pytest .raises (IndexError , match = "Start index -\\ d+ is out of bounds" ):
824
+ decoder .get_frames_in_range (start = - 1000 , stop = 10 )
825
+
826
+ with pytest .raises (
827
+ IndexError ,
828
+ match = "Stop index \\ (-\\ d+\\ ) must not be less than the start index" ,
829
+ ):
830
+ decoder .get_frames_in_range (start = 0 , stop = - 1000 )
831
+
775
832
@pytest .mark .parametrize ("device" , cpu_and_cuda ())
776
833
@pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
777
834
@patch ("torchcodec._core._metadata._get_stream_json_metadata" )
0 commit comments