Skip to content

Commit db43042

Browse files
committed
Added where test
- Asymmetric dtype test to improve coverage
1 parent d50d1c6 commit db43042

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

dpctl/tests/test_usm_ndarray_search_functions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,24 @@ def test_where_all_dtypes(dt):
150150
assert _dtype_all_close(dpt.asnumpy(res), res_check)
151151

152152

153+
def test_where_asymmetric_dtypes():
154+
q = get_queue_or_skip()
155+
156+
cond = dpt.asarray([0, 1, 3, 0, 10], dtype="?", sycl_queue=q)
157+
x1 = dpt.asarray(2, dtype="i4", sycl_queue=q)
158+
x2 = dpt.asarray(3, dtype="i8", sycl_queue=q)
159+
160+
res = dpt.where(cond, x1, x2)
161+
res_check = np.asarray([3, 2, 2, 3, 2], dtype=res.dtype)
162+
assert _dtype_all_close(dpt.asnumpy(res), res_check)
163+
164+
# flip order
165+
166+
res = dpt.where(cond, x2, x1)
167+
res_check = np.asarray([2, 3, 3, 2, 3], dtype=res.dtype)
168+
assert _dtype_all_close(dpt.asnumpy(res), res_check)
169+
170+
153171
def test_where_nan_inf():
154172
get_queue_or_skip()
155173

0 commit comments

Comments
 (0)