@@ -1485,3 +1485,35 @@ def test_tile_arg_validation():
1485
1485
x = dpt .empty (())
1486
1486
with pytest .raises (TypeError ):
1487
1487
dpt .tile (x , dict ())
1488
+
1489
+
1490
+ def test_repeat_0_size ():
1491
+ get_queue_or_skip ()
1492
+
1493
+ x = dpt .ones ((0 , 10 , 0 ), dtype = "i4" )
1494
+ repetitions = 2
1495
+ res = dpt .repeat (x , repetitions )
1496
+ assert res .shape == (0 ,)
1497
+ res = dpt .repeat (x , repetitions , axis = 2 )
1498
+ assert res .shape == x .shape
1499
+ res = dpt .repeat (x , repetitions , axis = 1 )
1500
+ axis_sz = x .shape [1 ] * repetitions
1501
+ assert res .shape == (0 , 20 , 0 )
1502
+
1503
+ repetitions = dpt .asarray (2 , dtype = "i4" )
1504
+ res = dpt .repeat (x , repetitions )
1505
+ assert res .shape == (0 ,)
1506
+ res = dpt .repeat (x , repetitions , axis = 2 )
1507
+ assert res .shape == x .shape
1508
+ res = dpt .repeat (x , repetitions , axis = 1 )
1509
+ assert res .shape == (0 , 20 , 0 )
1510
+
1511
+ repetitions = dpt .arange (10 , dtype = "i4" )
1512
+ res = dpt .repeat (x , repetitions , axis = 1 )
1513
+ axis_sz = dpt .sum (repetitions )
1514
+ assert res .shape == (0 , axis_sz , 0 )
1515
+
1516
+ repetitions = (2 ,) * 10
1517
+ res = dpt .repeat (x , repetitions , axis = 1 )
1518
+ axis_sz = 2 * x .shape [1 ]
1519
+ assert res .shape == (0 , axis_sz , 0 )
0 commit comments