|
14 | 14 | # See the License for the specific language governing permissions and
|
15 | 15 | # limitations under the License.
|
16 | 16 |
|
| 17 | +import ctypes |
| 18 | +import itertools |
| 19 | + |
17 | 20 | import numpy as np
|
18 | 21 | import pytest
|
19 | 22 | from helper import get_queue_or_skip, skip_if_dtype_not_supported
|
@@ -522,3 +525,54 @@ def test_where_out_arg_validation():
|
522 | 525 | dpt.where(condition, x1, x2, out=out_wrong_shape)
|
523 | 526 | with pytest.raises(ValueError):
|
524 | 527 | dpt.where(condition, x1, x2, out=out_not_writable)
|
| 528 | + |
| 529 | + |
| 530 | +@pytest.mark.parametrize("arr_dt", _all_dtypes) |
| 531 | +def test_where_python_scalar(arr_dt): |
| 532 | + q = get_queue_or_skip() |
| 533 | + skip_if_dtype_not_supported(arr_dt, q) |
| 534 | + |
| 535 | + n1, n2 = 10, 10 |
| 536 | + condition = dpt.tile( |
| 537 | + dpt.reshape( |
| 538 | + dpt.asarray([True, False], dtype="?", sycl_queue=q), (1, 2) |
| 539 | + ), |
| 540 | + (n1, n2 // 2), |
| 541 | + ) |
| 542 | + x = dpt.zeros((n1, n2), dtype=arr_dt, sycl_queue=q) |
| 543 | + py_scalars = ( |
| 544 | + bool(0), |
| 545 | + int(0), |
| 546 | + float(0), |
| 547 | + complex(0), |
| 548 | + np.float32(0), |
| 549 | + ctypes.c_int(0), |
| 550 | + ) |
| 551 | + for sc in py_scalars: |
| 552 | + r = dpt.where(condition, x, sc) |
| 553 | + assert isinstance(r, dpt.usm_ndarray) |
| 554 | + r = dpt.where(condition, sc, x) |
| 555 | + assert isinstance(r, dpt.usm_ndarray) |
| 556 | + |
| 557 | + |
| 558 | +def test_where_two_python_scalars(): |
| 559 | + get_queue_or_skip() |
| 560 | + |
| 561 | + n1, n2 = 10, 10 |
| 562 | + condition = dpt.tile( |
| 563 | + dpt.reshape(dpt.asarray([True, False], dtype="?"), (1, 2)), |
| 564 | + (n1, n2 // 2), |
| 565 | + ) |
| 566 | + |
| 567 | + py_scalars = [ |
| 568 | + bool(0), |
| 569 | + int(0), |
| 570 | + float(0), |
| 571 | + complex(0), |
| 572 | + np.float32(0), |
| 573 | + ctypes.c_int(0), |
| 574 | + ] |
| 575 | + |
| 576 | + for sc1, sc2 in itertools.product(py_scalars, repeat=2): |
| 577 | + r = dpt.where(condition, sc1, sc2) |
| 578 | + assert isinstance(r, dpt.usm_ndarray) |
0 commit comments