Skip to content

Commit b4b4a7c

Browse files
committed
Added tests for where and type utility functions
1 parent 1a996e9 commit b4b4a7c

File tree

2 files changed

+257
-22
lines changed

2 files changed

+257
-22
lines changed

dpctl/tests/test_type_utils.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import pytest
18+
19+
import dpctl.tensor as dpt
20+
from dpctl.tensor._type_utils import (
21+
_all_data_types,
22+
_can_cast,
23+
_is_maximal_inexact_type,
24+
)
25+
26+
27+
def test_all_data_types():
28+
fp16_fp64_types = set([dpt.float16, dpt.float64, dpt.complex128])
29+
fp64_types = set([dpt.float64, dpt.complex128])
30+
31+
all_dts = _all_data_types(True, True)
32+
assert fp16_fp64_types.issubset(all_dts)
33+
34+
all_dts = _all_data_types(True, False)
35+
assert dpt.float16 in all_dts
36+
assert not fp64_types.issubset(all_dts)
37+
38+
all_dts = _all_data_types(False, True)
39+
assert dpt.float16 not in all_dts
40+
assert fp64_types.issubset(all_dts)
41+
42+
all_dts = _all_data_types(False, False)
43+
assert not fp16_fp64_types.issubset(all_dts)
44+
45+
46+
@pytest.mark.parametrize("fp16", [True, False])
47+
@pytest.mark.parametrize("fp64", [True, False])
48+
def test_maximal_inexact_types(fp16, fp64):
49+
assert not _is_maximal_inexact_type(dpt.int32, fp16, fp64)
50+
assert fp64 == _is_maximal_inexact_type(dpt.float64, fp16, fp64)
51+
assert fp64 == _is_maximal_inexact_type(dpt.complex128, fp16, fp64)
52+
assert fp64 != _is_maximal_inexact_type(dpt.float32, fp16, fp64)
53+
assert fp64 != _is_maximal_inexact_type(dpt.complex64, fp16, fp64)
54+
55+
56+
def test_can_cast_device():
57+
assert _can_cast(dpt.int64, dpt.float64, True, True)
58+
# if f8 is available, can't cast i8 to f4
59+
assert not _can_cast(dpt.int64, dpt.float32, True, True)
60+
assert not _can_cast(dpt.int64, dpt.float32, False, True)
61+
# should be able to cast to f8 when f2 unavailable
62+
assert _can_cast(dpt.int64, dpt.float64, False, True)
63+
# casting to f4 acceptable when f8 unavailable
64+
assert _can_cast(dpt.int64, dpt.float32, True, False)
65+
assert _can_cast(dpt.int64, dpt.float32, False, False)
66+
# can't safely cast inexact type to inexact type of lesser precision
67+
assert not _can_cast(dpt.float32, dpt.float16, True, False)
68+
assert not _can_cast(dpt.float64, dpt.float32, False, True)

dpctl/tests/test_usm_ndarray_search_functions.py

Lines changed: 189 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Data Parallel Control (dpctl)
22
#
3-
# Copyright 2020-2022 Intel Corporation
3+
# Copyright 2020-2023 Intel Corporation
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -20,8 +20,12 @@
2020
from numpy.testing import assert_array_equal
2121

2222
import dpctl.tensor as dpt
23+
from dpctl.tensor._search_functions import _where_result_type
24+
from dpctl.tensor._type_utils import _all_data_types
25+
from dpctl.utils import ExecutionPlacementError
2326

2427
_all_dtypes = [
28+
"?",
2529
"u1",
2630
"i1",
2731
"u2",
@@ -38,6 +42,12 @@
3842
]
3943

4044

45+
class mock_device:
46+
def __init__(self, fp16, fp64):
47+
self.has_aspect_fp16 = fp16
48+
self.has_aspect_fp64 = fp64
49+
50+
4151
def test_where_basic():
4252
get_queue_or_skip()
4353

@@ -54,7 +64,16 @@ def test_where_basic():
5464
out_expected = dpt.asarray(
5565
[[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0], [1, 1, 1]]
5666
)
67+
assert (dpt.asnumpy(out) == dpt.asnumpy(out_expected)).all()
5768

69+
out = dpt.where(cond, dpt.ones(cond.shape), dpt.zeros(cond.shape))
70+
assert (dpt.asnumpy(out) == dpt.asnumpy(out_expected)).all()
71+
72+
out = dpt.where(
73+
cond,
74+
dpt.ones(cond.shape[0], dtype="i4")[:, dpt.newaxis],
75+
dpt.zeros(cond.shape[0], dtype="i4")[:, dpt.newaxis],
76+
)
5877
assert (dpt.asnumpy(out) == dpt.asnumpy(out_expected)).all()
5978

6079

@@ -72,38 +91,98 @@ def _dtype_all_close(x1, x2):
7291

7392
@pytest.mark.parametrize("dt1", _all_dtypes)
7493
@pytest.mark.parametrize("dt2", _all_dtypes)
75-
def test_where_all_dtypes(dt1, dt2):
94+
@pytest.mark.parametrize("fp16", [True, False])
95+
@pytest.mark.parametrize("fp64", [True, False])
96+
def test_where_result_types(dt1, dt2, fp16, fp64):
97+
dev = mock_device(fp16, fp64)
98+
99+
dt1 = dpt.dtype(dt1)
100+
dt2 = dpt.dtype(dt2)
101+
res_t = _where_result_type(dt1, dt2, dev)
102+
103+
if fp16 and fp64:
104+
assert res_t == dpt.result_type(dt1, dt2)
105+
else:
106+
if res_t:
107+
assert res_t.kind == dpt.result_type(dt1, dt2).kind
108+
else:
109+
# some illegal cases are covered above, but
110+
# this guarantees that _where_result_type
111+
# produces None only when one of the dtypes
112+
# is illegal given fp aspects of device
113+
all_dts = _all_data_types(fp16, fp64)
114+
assert dt1 not in all_dts or dt2 not in all_dts
115+
116+
117+
@pytest.mark.parametrize("dt", _all_dtypes)
118+
def test_where_all_dtypes(dt):
76119
q = get_queue_or_skip()
77-
skip_if_dtype_not_supported(dt1, q)
78-
skip_if_dtype_not_supported(dt2, q)
120+
skip_if_dtype_not_supported(dt, q)
79121

80-
cond = dpt.asarray([False, False, False, True, True], sycl_queue=q)
81-
x1 = dpt.asarray(2, sycl_queue=q)
82-
x2 = dpt.asarray(3, sycl_queue=q)
122+
# mask dtype changes
123+
cond = dpt.asarray([0, 1, 3, 0, 10], dtype=dt, sycl_queue=q)
124+
x1 = dpt.asarray(0, dtype="f", sycl_queue=q)
125+
x2 = dpt.asarray(1, dtype="f", sycl_queue=q)
126+
res = dpt.where(cond, x1, x2)
127+
128+
res_check = np.asarray([1, 0, 0, 1, 0], dtype=res.dtype)
129+
assert _dtype_all_close(dpt.asnumpy(res), res_check)
83130

131+
# contiguous cases
132+
x1 = dpt.full(cond.shape, 0, dtype="f4", sycl_queue=q)
133+
x2 = dpt.full(cond.shape, 1, dtype="f4", sycl_queue=q)
84134
res = dpt.where(cond, x1, x2)
85-
res_check = np.asarray([3, 3, 3, 2, 2], dtype=res.dtype)
135+
assert _dtype_all_close(dpt.asnumpy(res), res_check)
86136

87-
dev = q.sycl_device
137+
# input array dtype changes
138+
cond = dpt.asarray([False, True, True, False, True], sycl_queue=q)
139+
x1 = dpt.asarray(0, dtype=dt, sycl_queue=q)
140+
x2 = dpt.asarray(1, dtype=dt, sycl_queue=q)
141+
res = dpt.where(cond, x1, x2)
88142

89-
if not dev.has_aspect_fp16 or not dev.has_aspect_fp64:
90-
assert res.dtype.kind == dpt.result_type(x1.dtype, x2.dtype).kind
143+
res_check = np.asarray([1, 0, 0, 1, 0], dtype=res.dtype)
144+
assert _dtype_all_close(dpt.asnumpy(res), res_check)
91145

146+
# contiguous cases
147+
x1 = dpt.full(cond.shape, 0, dtype=dt, sycl_queue=q)
148+
x2 = dpt.full(cond.shape, 1, dtype=dt, sycl_queue=q)
149+
res = dpt.where(cond, x1, x2)
92150
assert _dtype_all_close(dpt.asnumpy(res), res_check)
93151

94152

153+
def test_where_nan_inf():
154+
get_queue_or_skip()
155+
156+
cond = dpt.asarray([True, False, True, False], dtype="?")
157+
x1 = dpt.asarray([np.nan, 2.0, np.inf, 3.0], dtype="f4")
158+
x2 = dpt.asarray([2.0, np.nan, 3.0, np.inf], dtype="f4")
159+
160+
cond_np = dpt.asnumpy(cond)
161+
x1_np = dpt.asnumpy(x1)
162+
x2_np = dpt.asnumpy(x2)
163+
164+
res = dpt.where(cond, x1, x2)
165+
res_np = np.where(cond_np, x1_np, x2_np)
166+
167+
assert np.allclose(dpt.asnumpy(res), res_np, equal_nan=True)
168+
169+
res = dpt.where(x1, cond, x2)
170+
res_np = np.where(x1_np, cond_np, x2_np)
171+
assert _dtype_all_close(dpt.asnumpy(res), res_np)
172+
173+
95174
def test_where_empty():
96175
# check that numpy returns same results when
97176
# handling empty arrays
98177
get_queue_or_skip()
99178

100-
empty = dpt.empty(0)
179+
empty = dpt.empty(0, dtype="i2")
101180
m = dpt.asarray(True)
102-
x1 = dpt.asarray(1)
103-
x2 = dpt.asarray(2)
181+
x1 = dpt.asarray(1, dtype="i2")
182+
x2 = dpt.asarray(2, dtype="i2")
104183
res = dpt.where(empty, x1, x2)
105184

106-
empty_np = np.empty(0)
185+
empty_np = np.empty(0, dtype="i2")
107186
m_np = dpt.asnumpy(m)
108187
x1_np = dpt.asnumpy(x1)
109188
x2_np = dpt.asnumpy(x2)
@@ -116,12 +195,14 @@ def test_where_empty():
116195

117196
assert_array_equal(dpt.asnumpy(res), res_np)
118197

198+
# check that broadcasting is performed
199+
with pytest.raises(ValueError):
200+
dpt.where(empty, x1, dpt.empty((1, 2)))
201+
119202

120-
@pytest.mark.parametrize("dt", _all_dtypes)
121203
@pytest.mark.parametrize("order", ["C", "F"])
122-
def test_where_contiguous(dt, order):
123-
q = get_queue_or_skip()
124-
skip_if_dtype_not_supported(dt, q)
204+
def test_where_contiguous(order):
205+
get_queue_or_skip()
125206

126207
cond = dpt.asarray(
127208
[
@@ -131,14 +212,100 @@ def test_where_contiguous(dt, order):
131212
[[False, False, False], [True, False, True]],
132213
[[True, True, True], [True, False, True]],
133214
],
134-
sycl_queue=q,
135215
order=order,
136216
)
137217

138-
x1 = dpt.full(cond.shape, 2, dtype=dt, order=order, sycl_queue=q)
139-
x2 = dpt.full(cond.shape, 3, dtype=dt, order=order, sycl_queue=q)
218+
x1 = dpt.full(cond.shape, 2, dtype="i4", order=order)
219+
x2 = dpt.full(cond.shape, 3, dtype="i4", order=order)
220+
expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2))
221+
res = dpt.where(cond, x1, x2)
222+
223+
assert _dtype_all_close(dpt.asnumpy(res), expected)
224+
225+
226+
def test_where_contiguous1D():
227+
get_queue_or_skip()
140228

229+
cond = dpt.asarray([True, False, True, False, False, True])
230+
231+
x1 = dpt.full(cond.shape, 2, dtype="i4")
232+
x2 = dpt.full(cond.shape, 3, dtype="i4")
141233
expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2))
142234
res = dpt.where(cond, x1, x2)
235+
assert_array_equal(dpt.asnumpy(res), expected)
143236

237+
# test with complex dtype (branch in kernel)
238+
x1 = dpt.astype(x1, dpt.complex64)
239+
x2 = dpt.astype(x2, dpt.complex64)
240+
expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2))
241+
res = dpt.where(cond, x1, x2)
144242
assert _dtype_all_close(dpt.asnumpy(res), expected)
243+
244+
245+
def test_where_strided():
246+
get_queue_or_skip()
247+
248+
s0, s1 = 4, 9
249+
cond = dpt.reshape(
250+
dpt.asarray(
251+
[True, False, False, False, True, True, False, True, False] * s0
252+
),
253+
(s0, s1),
254+
)[:, ::3]
255+
256+
x1 = dpt.reshape(
257+
dpt.arange(cond.shape[0] * cond.shape[1] * 2, dtype="i4"),
258+
(cond.shape[0], cond.shape[1] * 2),
259+
)[:, ::2]
260+
x2 = dpt.reshape(
261+
dpt.arange(cond.shape[0] * cond.shape[1] * 3, dtype="i4"),
262+
(cond.shape[0], cond.shape[1] * 3),
263+
)[:, ::3]
264+
expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2))
265+
res = dpt.where(cond, x1, x2)
266+
267+
assert_array_equal(dpt.asnumpy(res), expected)
268+
269+
# negative strides
270+
res = dpt.where(cond, dpt.flip(x1), x2)
271+
expected = np.where(
272+
dpt.asnumpy(cond), np.flip(dpt.asnumpy(x1)), dpt.asnumpy(x2)
273+
)
274+
assert_array_equal(dpt.asnumpy(res), expected)
275+
276+
res = dpt.where(dpt.flip(cond), x1, x2)
277+
expected = np.where(
278+
np.flip(dpt.asnumpy(cond)), dpt.asnumpy(x1), dpt.asnumpy(x2)
279+
)
280+
assert_array_equal(dpt.asnumpy(res), expected)
281+
282+
283+
def test_where_arg_validation():
284+
get_queue_or_skip()
285+
286+
check = dict()
287+
x1 = dpt.empty((1,), dtype="i4")
288+
x2 = dpt.empty((1,), dtype="i4")
289+
290+
with pytest.raises(TypeError):
291+
dpt.where(check, x1, x2)
292+
with pytest.raises(TypeError):
293+
dpt.where(x1, check, x2)
294+
with pytest.raises(TypeError):
295+
dpt.where(x1, x2, check)
296+
297+
298+
def test_where_compute_follows_data():
299+
q1 = get_queue_or_skip()
300+
q2 = get_queue_or_skip()
301+
q3 = get_queue_or_skip()
302+
303+
x1 = dpt.empty((1,), dtype="i4", sycl_queue=q1)
304+
x2 = dpt.empty((1,), dtype="i4", sycl_queue=q2)
305+
306+
with pytest.raises(ExecutionPlacementError):
307+
dpt.where(dpt.empty((1,), dtype="i4", sycl_queue=q1), x1, x2)
308+
with pytest.raises(ExecutionPlacementError):
309+
dpt.where(dpt.empty((1,), dtype="i4", sycl_queue=q3), x1, x2)
310+
with pytest.raises(ExecutionPlacementError):
311+
dpt.where(x1, x1, x2)

0 commit comments

Comments
 (0)