From 0cd0d5ec37a5600fdde998c6c251441610a465ea Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 2 Dec 2024 09:05:20 -0600 Subject: [PATCH 1/2] Fix gh-1913 Check array_mask_nelems and early exit if zero. --- dpctl/tensor/_copy_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index e2f1bccac0..9dd53eb383 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -924,6 +924,8 @@ def _place_impl(ary, ary_mask, vals, axis=0): else: rhs = dpt.astype(vals, ary.dtype) rhs = dpt.broadcast_to(rhs, expected_vals_shape) + if mask_nelems == 0: + return dep_ev = _manager.submitted_events hev, pl_ev = ti._place( dst=ary, From 2a643dd69b4c8b8f421d719df94674650b40314b Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 2 Dec 2024 09:06:32 -0600 Subject: [PATCH 2/2] Add tests for setitem/getitem for boolean indexing with empty mask --- dpctl/tests/test_usm_ndarray_indexing.py | 25 ++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index 56289f0fc5..05b0b278fc 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -405,6 +405,31 @@ def test_boolean_indexing_validation(): x[ii[0, :]] +def test_boolean_indexing_getitem_empty_mask(): + get_queue_or_skip() + x = dpt.ones((2, 3, 4), dtype="i4") + ii = dpt.ones((0,), dtype="?") + assert x[ii].size == 0 + ii1 = dpt.ones((0, 3), dtype="?") + assert x[ii1].size == 0 + ii2 = dpt.ones((0, 3, 4), dtype="?") + assert x[ii2].size == 0 + + +def test_boolean_indexing_setitem_empty_mask(): + get_queue_or_skip() + x = dpt.ones((2, 3, 4), dtype="i4") + ii = dpt.ones((0,), dtype="?") + x[ii] = 0 + assert dpt.all(x == 1) + ii1 = dpt.ones((0, 3), dtype="?") + x[ii1] = 0 + assert dpt.all(x == 1) + ii2 = dpt.ones((0, 3, 4), dtype="?") + x[ii2] = 0 + assert dpt.all(x == 1) + + def test_integer_indexing_1d(): get_queue_or_skip() x = dpt.arange(10, dtype="i4")