Skip to content

Commit 20929be

Browse files
committed
Adds tests for where behavior with scalars
1 parent 8ab41e9 commit 20929be

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

dpctl/tests/test_usm_ndarray_search_functions.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import ctypes
18+
import itertools
19+
1720
import numpy as np
1821
import pytest
1922
from helper import get_queue_or_skip, skip_if_dtype_not_supported
@@ -522,3 +525,54 @@ def test_where_out_arg_validation():
522525
dpt.where(condition, x1, x2, out=out_wrong_shape)
523526
with pytest.raises(ValueError):
524527
dpt.where(condition, x1, x2, out=out_not_writable)
528+
529+
530+
@pytest.mark.parametrize("arr_dt", _all_dtypes)
531+
def test_where_python_scalar(arr_dt):
532+
q = get_queue_or_skip()
533+
skip_if_dtype_not_supported(arr_dt, q)
534+
535+
n1, n2 = 10, 10
536+
condition = dpt.tile(
537+
dpt.reshape(
538+
dpt.asarray([True, False], dtype="?", sycl_queue=q), (1, 2)
539+
),
540+
(n1, n2 // 2),
541+
)
542+
x = dpt.zeros((n1, n2), dtype=arr_dt, sycl_queue=q)
543+
py_scalars = (
544+
bool(0),
545+
int(0),
546+
float(0),
547+
complex(0),
548+
np.float32(0),
549+
ctypes.c_int(0),
550+
)
551+
for sc in py_scalars:
552+
r = dpt.where(condition, x, sc)
553+
assert isinstance(r, dpt.usm_ndarray)
554+
r = dpt.where(condition, sc, x)
555+
assert isinstance(r, dpt.usm_ndarray)
556+
557+
558+
def test_where_two_python_scalars():
559+
get_queue_or_skip()
560+
561+
n1, n2 = 10, 10
562+
condition = dpt.tile(
563+
dpt.reshape(dpt.asarray([True, False], dtype="?"), (1, 2)),
564+
(n1, n2 // 2),
565+
)
566+
567+
py_scalars = [
568+
bool(0),
569+
int(0),
570+
float(0),
571+
complex(0),
572+
np.float32(0),
573+
ctypes.c_int(0),
574+
]
575+
576+
for sc1, sc2 in itertools.product(py_scalars, repeat=2):
577+
r = dpt.where(condition, sc1, sc2)
578+
assert isinstance(r, dpt.usm_ndarray)

0 commit comments

Comments
 (0)