diff --git a/dpctl/tensor/_indexing_functions.py b/dpctl/tensor/_indexing_functions.py index 18efa49519..ee0587138d 100644 --- a/dpctl/tensor/_indexing_functions.py +++ b/dpctl/tensor/_indexing_functions.py @@ -293,12 +293,12 @@ def place(arr, mask, vals): raise dpctl.utils.ExecutionPlacementError if arr.shape != mask.shape or vals.ndim != 1: raise ValueError("Array sizes are not as required") - if vals.size == 0: - raise ValueError("Cannot insert from an empty array!") cumsum = dpt.empty(mask.size, dtype="i8", sycl_queue=exec_q) nz_count = ti.mask_positions(mask, cumsum, sycl_queue=exec_q) if nz_count == 0: return + if vals.size == 0: + raise ValueError("Cannot insert from an empty array!") if vals.dtype == arr.dtype: rhs = vals else: diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index a3a1c54f05..0367d45f7b 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -1211,6 +1211,16 @@ def test_place_empty_vals_error(): dpt.place(x, sel, y) +def test_place_empty_vals_full_false_mask(): + get_queue_or_skip() + x = dpt.ones(10, dtype="f4") + y = dpt.empty((0,), dtype=x.dtype) + sel = dpt.zeros(x.size, dtype="?") + expected = np.ones(10, dtype=x.dtype) + dpt.place(x, sel, y) + assert (dpt.asnumpy(x) == expected).all() + + def test_nonzero(): get_queue_or_skip() x = dpt.concat((dpt.zeros(3), dpt.ones(4), dpt.zeros(3)))