Skip to content

Commit cab0035

Browse files
Fixed tests for boolean indexing
1 parent f75723b commit cab0035

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -972,10 +972,8 @@ def test_advanced_indexing_compute_follows_data():
972972
x[ind0] = val1
973973

974974

975-
#######
976-
977-
978975
def test_extract_all_1d():
976+
get_queue_or_skip()
979977
x = dpt.arange(30, dtype="i4")
980978
sel = dpt.ones(30, dtype="?")
981979
sel[::2] = False
@@ -989,6 +987,7 @@ def test_extract_all_1d():
989987

990988

991989
def test_extract_all_2d():
990+
get_queue_or_skip()
992991
x = dpt.reshape(dpt.arange(30, dtype="i4"), (5, 6))
993992
sel = dpt.ones(30, dtype="?")
994993
sel[::2] = False
@@ -1003,6 +1002,7 @@ def test_extract_all_2d():
10031002

10041003

10051004
def test_extract_2D_axis0():
1005+
get_queue_or_skip()
10061006
x = dpt.reshape(dpt.arange(30, dtype="i4"), (5, 6))
10071007
sel = dpt.ones(x.shape[0], dtype="?")
10081008
sel[::2] = False
@@ -1013,6 +1013,7 @@ def test_extract_2D_axis0():
10131013

10141014

10151015
def test_extract_2D_axis1():
1016+
get_queue_or_skip()
10161017
x = dpt.reshape(dpt.arange(30, dtype="i4"), (5, 6))
10171018
sel = dpt.ones(x.shape[1], dtype="?")
10181019
sel[::2] = False
@@ -1023,6 +1024,7 @@ def test_extract_2D_axis1():
10231024

10241025

10251026
def test_extract_begin():
1027+
get_queue_or_skip()
10261028
x = dpt.reshape(dpt.arange(3 * 3 * 4 * 4, dtype="i2"), (3, 4, 3, 4))
10271029
y = dpt.permute_dims(x, (2, 0, 3, 1))
10281030
sel = dpt.zeros((3, 3), dtype="?")
@@ -1034,6 +1036,7 @@ def test_extract_begin():
10341036

10351037

10361038
def test_extract_end():
1039+
get_queue_or_skip()
10371040
x = dpt.reshape(dpt.arange(3 * 3 * 4 * 4, dtype="i2"), (3, 4, 3, 4))
10381041
y = dpt.permute_dims(x, (2, 0, 3, 1))
10391042
sel = dpt.zeros((4, 4), dtype="?")
@@ -1044,6 +1047,7 @@ def test_extract_end():
10441047

10451048

10461049
def test_extract_middle():
1050+
get_queue_or_skip()
10471051
x = dpt.reshape(dpt.arange(3 * 3 * 4 * 4, dtype="i2"), (3, 4, 3, 4))
10481052
y = dpt.permute_dims(x, (2, 0, 3, 1))
10491053
sel = dpt.zeros((3, 4), dtype="?")
@@ -1054,6 +1058,7 @@ def test_extract_middle():
10541058

10551059

10561060
def test_extract_empty_result():
1061+
get_queue_or_skip()
10571062
x = dpt.reshape(dpt.arange(3 * 3 * 4 * 4, dtype="i2"), (3, 4, 3, 4))
10581063
y = dpt.permute_dims(x, (2, 0, 3, 1))
10591064
sel = dpt.zeros((3, 4), dtype="?")
@@ -1066,17 +1071,19 @@ def test_extract_empty_result():
10661071

10671072

10681073
def test_place_all_1d():
1074+
get_queue_or_skip()
10691075
x = dpt.arange(10, dtype="i2")
10701076
sel = dpt.zeros(10, dtype="?")
10711077
sel[0::2] = True
10721078
val = dpt.zeros(5, dtype=x.dtype)
10731079
x[sel] = val
10741080
assert (dpt.asnumpy(x) == np.array([0, 1, 0, 3, 0, 5, 0, 7, 0, 9])).all()
1075-
dpt.place(x, sel, dpt.asarray(2))
1081+
dpt.place(x, sel, dpt.asarray([2]))
10761082
assert (dpt.asnumpy(x) == np.array([2, 1, 2, 3, 2, 5, 2, 7, 2, 9])).all()
10771083

10781084

10791085
def test_place_2d_axis0():
1086+
get_queue_or_skip()
10801087
x = dpt.reshape(dpt.arange(12, dtype="i2"), (3, 4))
10811088
sel = dpt.asarray([True, False, True])
10821089
val = dpt.zeros((2, 4), dtype=x.dtype)
@@ -1092,6 +1099,7 @@ def test_place_2d_axis0():
10921099

10931100

10941101
def test_place_2d_axis1():
1102+
get_queue_or_skip()
10951103
x = dpt.reshape(dpt.arange(12, dtype="i2"), (3, 4))
10961104
sel = dpt.asarray([True, False, True, False])
10971105
val = dpt.zeros((3, 2), dtype=x.dtype)
@@ -1103,6 +1111,7 @@ def test_place_2d_axis1():
11031111

11041112

11051113
def test_place_2d_axis1_scalar():
1114+
get_queue_or_skip()
11061115
x = dpt.reshape(dpt.arange(12, dtype="i2"), (3, 4))
11071116
sel = dpt.asarray([True, False, True, False])
11081117
val = dpt.zeros(tuple(), dtype=x.dtype)
@@ -1114,6 +1123,7 @@ def test_place_2d_axis1_scalar():
11141123

11151124

11161125
def test_place_all_slices():
1126+
get_queue_or_skip()
11171127
x = dpt.reshape(dpt.arange(12, dtype="i2"), (3, 4))
11181128
sel = dpt.asarray(
11191129
[
@@ -1128,6 +1138,7 @@ def test_place_all_slices():
11281138

11291139

11301140
def test_place_some_slices_begin():
1141+
get_queue_or_skip()
11311142
x = dpt.reshape(dpt.arange(3 * 3 * 4 * 4, dtype="i2"), (3, 4, 3, 4))
11321143
y = dpt.permute_dims(x, (2, 0, 3, 1))
11331144
sel = dpt.zeros((3, 3), dtype="?")
@@ -1139,6 +1150,7 @@ def test_place_some_slices_begin():
11391150

11401151

11411152
def test_place_some_slices_mid():
1153+
get_queue_or_skip()
11421154
x = dpt.reshape(dpt.arange(3 * 3 * 4 * 4, dtype="i2"), (3, 4, 3, 4))
11431155
y = dpt.permute_dims(x, (2, 0, 3, 1))
11441156
sel = dpt.zeros((3, 4), dtype="?")
@@ -1150,6 +1162,7 @@ def test_place_some_slices_mid():
11501162

11511163

11521164
def test_place_some_slices_end():
1165+
get_queue_or_skip()
11531166
x = dpt.reshape(dpt.arange(3 * 3 * 4 * 4, dtype="i2"), (3, 4, 3, 4))
11541167
y = dpt.permute_dims(x, (2, 0, 3, 1))
11551168
sel = dpt.zeros((4, 4), dtype="?")
@@ -1161,6 +1174,7 @@ def test_place_some_slices_end():
11611174

11621175

11631176
def test_place_cycling():
1177+
get_queue_or_skip()
11641178
x = dpt.zeros(10, dtype="f4")
11651179
y = dpt.asarray([2, 3])
11661180
sel = dpt.ones(x.size, dtype="?")
@@ -1177,16 +1191,18 @@ def test_place_cycling():
11771191

11781192

11791193
def test_place_subset():
1194+
get_queue_or_skip()
11801195
x = dpt.zeros(10, dtype="f4")
11811196
y = dpt.ones_like(x)
11821197
sel = dpt.ones(x.size, dtype="?")
11831198
sel[::2] = False
11841199
dpt.place(x, sel, y)
1185-
expected = np.array([1, 3, 5, 7, 9], dtype=x.dtype)
1200+
expected = np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1], dtype=x.dtype)
11861201
assert (dpt.asnumpy(x) == expected).all()
11871202

11881203

11891204
def test_nonzero():
1205+
get_queue_or_skip()
11901206
x = dpt.concat((dpt.zeros(3), dpt.ones(4), dpt.zeros(3)))
11911207
(i,) = dpt.nonzero(x)
1192-
assert dpt.asnumpy(i) == np.array([3, 4, 5, 6]).all()
1208+
assert (dpt.asnumpy(i) == np.array([3, 4, 5, 6])).all()

0 commit comments

Comments
 (0)