From d6733b9ff8b850486d9d166cf887160856f267c8 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 6 Mar 2023 07:37:27 -0600 Subject: [PATCH 1/2] Allow x[cond] = non_usm_array This allows `x[x<0] = 0` to work. Previously, it had to be `x[x<0] = dpt.asarray(0)`. --- dpctl/tensor/_copy_utils.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index 72b9e0a021..2900c9426a 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -514,13 +514,17 @@ def _place_impl(ary, ary_mask, vals, axis=0): raise TypeError( f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}" ) - if not isinstance(vals, dpt.usm_ndarray): - raise TypeError( - f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}" - ) exec_q = dpctl.utils.get_execution_queue( - (ary.sycl_queue, ary_mask.sycl_queue, vals.sycl_queue) + ( + ary.sycl_queue, + ary_mask.sycl_queue, + ) ) + if exec_q is not None: + if not isinstance(vals, dpt.usm_ndarray): + vals = dpt.asarray(vals, dtype=ary.dtype, sycl_queue=exec_q) + else: + exec_q = dpctl.utils.get_execution_queue((exec_q, vals.sycl_queue)) if exec_q is None: raise dpctl.utils.ExecutionPlacementError( "arrays have different associated queues. " From c162c2a40376320528adaf72e7b25aa8e7a83aa5 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 6 Mar 2023 07:43:55 -0600 Subject: [PATCH 2/2] Added tests for assignment of Pythons scalar. --- dpctl/tests/test_usm_ndarray_indexing.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index 7201357c7d..43aef1b404 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -1206,3 +1206,15 @@ def test_nonzero(): x = dpt.concat((dpt.zeros(3), dpt.ones(4), dpt.zeros(3))) (i,) = dpt.nonzero(x) assert (dpt.asnumpy(i) == np.array([3, 4, 5, 6])).all() + + +def test_assign_scalar(): + get_queue_or_skip() + x = dpt.arange(-5, 5, dtype="i8") + cond = dpt.asarray( + [True, True, True, True, True, False, False, False, False, False] + ) + x[cond] = 0 # no error expected + x[dpt.nonzero(cond)] = -1 + expected = np.array([-1, -1, -1, -1, -1, 0, 1, 2, 3, 4], dtype=x.dtype) + assert (dpt.asnumpy(x) == expected).all()