From e7a878cb50082b49b6e28389d5ed5ad31ae47ff3 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 9 Mar 2023 21:33:54 +0100 Subject: [PATCH 1/2] Update empty vals check in dpt.place --- dpctl/tensor/_indexing_functions.py | 2 +- dpctl/tests/test_usm_ndarray_indexing.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/dpctl/tensor/_indexing_functions.py b/dpctl/tensor/_indexing_functions.py index 18efa49519..f51045b01f 100644 --- a/dpctl/tensor/_indexing_functions.py +++ b/dpctl/tensor/_indexing_functions.py @@ -293,7 +293,7 @@ def place(arr, mask, vals): raise dpctl.utils.ExecutionPlacementError if arr.shape != mask.shape or vals.ndim != 1: raise ValueError("Array sizes are not as required") - if vals.size == 0: + if vals.size == 0 and dpt.nonzero(mask)[0].size: raise ValueError("Cannot insert from an empty array!") cumsum = dpt.empty(mask.size, dtype="i8", sycl_queue=exec_q) nz_count = ti.mask_positions(mask, cumsum, sycl_queue=exec_q) diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index a3a1c54f05..0367d45f7b 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -1211,6 +1211,16 @@ def test_place_empty_vals_error(): dpt.place(x, sel, y) +def test_place_empty_vals_full_false_mask(): + get_queue_or_skip() + x = dpt.ones(10, dtype="f4") + y = dpt.empty((0,), dtype=x.dtype) + sel = dpt.zeros(x.size, dtype="?") + expected = np.ones(10, dtype=x.dtype) + dpt.place(x, sel, y) + assert (dpt.asnumpy(x) == expected).all() + + def test_nonzero(): get_queue_or_skip() x = dpt.concat((dpt.zeros(3), dpt.ones(4), dpt.zeros(3))) From 10c20220bb04046d1f25ef5d8ae08d929db4d602 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 10 Mar 2023 11:31:30 +0100 Subject: [PATCH 2/2] Move check vals.size after getting nz_count --- dpctl/tensor/_indexing_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dpctl/tensor/_indexing_functions.py b/dpctl/tensor/_indexing_functions.py index f51045b01f..ee0587138d 100644 --- a/dpctl/tensor/_indexing_functions.py +++ b/dpctl/tensor/_indexing_functions.py @@ -293,12 +293,12 @@ def place(arr, mask, vals): raise dpctl.utils.ExecutionPlacementError if arr.shape != mask.shape or vals.ndim != 1: raise ValueError("Array sizes are not as required") - if vals.size == 0 and dpt.nonzero(mask)[0].size: - raise ValueError("Cannot insert from an empty array!") cumsum = dpt.empty(mask.size, dtype="i8", sycl_queue=exec_q) nz_count = ti.mask_positions(mask, cumsum, sycl_queue=exec_q) if nz_count == 0: return + if vals.size == 0: + raise ValueError("Cannot insert from an empty array!") if vals.dtype == arr.dtype: rhs = vals else: