diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index 72b9e0a021..2900c9426a 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -514,13 +514,17 @@ def _place_impl(ary, ary_mask, vals, axis=0): raise TypeError( f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}" ) - if not isinstance(vals, dpt.usm_ndarray): - raise TypeError( - f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}" - ) exec_q = dpctl.utils.get_execution_queue( - (ary.sycl_queue, ary_mask.sycl_queue, vals.sycl_queue) + ( + ary.sycl_queue, + ary_mask.sycl_queue, + ) ) + if exec_q is not None: + if not isinstance(vals, dpt.usm_ndarray): + vals = dpt.asarray(vals, dtype=ary.dtype, sycl_queue=exec_q) + else: + exec_q = dpctl.utils.get_execution_queue((exec_q, vals.sycl_queue)) if exec_q is None: raise dpctl.utils.ExecutionPlacementError( "arrays have different associated queues. " diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index 7201357c7d..43aef1b404 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -1206,3 +1206,15 @@ def test_nonzero(): x = dpt.concat((dpt.zeros(3), dpt.ones(4), dpt.zeros(3))) (i,) = dpt.nonzero(x) assert (dpt.asnumpy(i) == np.array([3, 4, 5, 6])).all() + + +def test_assign_scalar(): + get_queue_or_skip() + x = dpt.arange(-5, 5, dtype="i8") + cond = dpt.asarray( + [True, True, True, True, True, False, False, False, False, False] + ) + x[cond] = 0 # no error expected + x[dpt.nonzero(cond)] = -1 + expected = np.array([-1, -1, -1, -1, -1, 0, 1, 2, 3, 4], dtype=x.dtype) + assert (dpt.asnumpy(x) == expected).all()