diff --git a/CHANGELOG.md b/CHANGELOG.md index 82b68cc164..1d2fdd39a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +* Added `out` keyword to `tensor.take` [gh-2010](https://github.com/IntelPython/dpctl/pull/2010) + ### Changed ### Fixed diff --git a/dpctl/tensor/_indexing_functions.py b/dpctl/tensor/_indexing_functions.py index 0c15e863b4..4f04a6094c 100644 --- a/dpctl/tensor/_indexing_functions.py +++ b/dpctl/tensor/_indexing_functions.py @@ -40,8 +40,8 @@ def _get_indexing_mode(name): ) -def take(x, indices, /, *, axis=None, mode="wrap"): - """take(x, indices, axis=None, mode="wrap") +def take(x, indices, /, *, axis=None, out=None, mode="wrap"): + """take(x, indices, axis=None, out=None, mode="wrap") Takes elements from an array along a given axis at given indices. @@ -54,6 +54,9 @@ def take(x, indices, /, *, axis=None, mode="wrap"): The axis along which the values will be selected. If ``x`` is one-dimensional, this argument is optional. Default: ``None``. + out (Optional[usm_ndarray]): + Output array to populate. Array must have the correct + shape and the expected data type. mode (str, optional): How out-of-bounds indices will be handled. Possible values are: @@ -121,18 +124,53 @@ def take(x, indices, /, *, axis=None, mode="wrap"): raise ValueError("`axis` must be 0 for an array of dimension 0.") res_shape = indices.shape - res = dpt.empty( - res_shape, dtype=x.dtype, usm_type=res_usm_type, sycl_queue=exec_q - ) + dt = x.dtype + + orig_out = out + if out is not None: + if not isinstance(out, dpt.usm_ndarray): + raise TypeError( + f"output array must be of usm_ndarray type, got {type(out)}" + ) + if not out.flags.writable: + raise ValueError("provided `out` array is read-only") + + if out.shape != res_shape: + raise ValueError( + "The shape of input and output arrays are inconsistent. " + f"Expected output shape is {res_shape}, got {out.shape}" + ) + if dt != out.dtype: + raise ValueError( + f"Output array of type {dt} is needed, got {out.dtype}" + ) + if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None: + raise dpctl.utils.ExecutionPlacementError( + "Input and output allocation queues are not compatible" + ) + if ti._array_overlap(x, out): + out = dpt.empty_like(out) + else: + out = dpt.empty( + res_shape, dtype=dt, usm_type=res_usm_type, sycl_queue=exec_q + ) _manager = dpctl.utils.SequentialOrderManager[exec_q] deps_ev = _manager.submitted_events hev, take_ev = ti._take( - x, (indices,), res, axis, mode, sycl_queue=exec_q, depends=deps_ev + x, (indices,), out, axis, mode, sycl_queue=exec_q, depends=deps_ev ) _manager.add_event_pair(hev, take_ev) - return res + if not (orig_out is None or out is orig_out): + # Copy the out data from temporary buffer to original memory + ht_e_cpy, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, dst=orig_out, sycl_queue=exec_q, depends=[take_ev] + ) + _manager.add_event_pair(ht_e_cpy, cpy_ev) + out = orig_out + + return out def put(x, indices, vals, /, *, axis=None, mode="wrap"): diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index b2911b0b91..78501580a8 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -625,7 +625,7 @@ def test_put_0d_val(data_dt): skip_if_dtype_not_supported(data_dt, q) x = dpt.arange(5, dtype=data_dt, sycl_queue=q) - ind = dpt.asarray([0], dtype=np.intp, sycl_queue=q) + ind = dpt.asarray([0], dtype="i8", sycl_queue=q) val = dpt.asarray(2, dtype=x.dtype, sycl_queue=q) x[ind] = val assert_array_equal(np.asarray(2, dtype=data_dt), dpt.asnumpy(x[0])) @@ -644,7 +644,7 @@ def test_take_0d_data(data_dt): skip_if_dtype_not_supported(data_dt, q) x = dpt.asarray(0, dtype=data_dt, sycl_queue=q) - ind = dpt.arange(5, dtype=np.intp, sycl_queue=q) + ind = dpt.arange(5, dtype="i8", sycl_queue=q) y = dpt.take(x, ind) assert ( @@ -662,7 +662,7 @@ def test_put_0d_data(data_dt): skip_if_dtype_not_supported(data_dt, q) x = dpt.asarray(0, dtype=data_dt, sycl_queue=q) - ind = dpt.arange(5, dtype=np.intp, sycl_queue=q) + ind = dpt.arange(5, dtype="i8", sycl_queue=q) val = dpt.asarray(2, dtype=data_dt, sycl_queue=q) dpt.put(x, ind, val, axis=0) @@ -710,7 +710,7 @@ def test_take_strided_1d_source(data_dt): skip_if_dtype_not_supported(data_dt, q) x = dpt.arange(27, dtype=data_dt, sycl_queue=q) - ind = dpt.arange(4, 9, dtype=np.intp, sycl_queue=q) + ind = dpt.arange(4, 9, dtype="i8", sycl_queue=q) x_np = dpt.asnumpy(x) ind_np = dpt.asnumpy(ind) @@ -748,7 +748,7 @@ def test_take_strided(data_dt, order): skip_if_dtype_not_supported(data_dt, q) x = dpt.reshape(_make_3d(data_dt, q), (9, 3), order=order) - ind = dpt.arange(2, dtype=np.intp, sycl_queue=q) + ind = dpt.arange(2, dtype="i8", sycl_queue=q) x_np = dpt.asnumpy(x) ind_np = dpt.asnumpy(ind) @@ -781,7 +781,7 @@ def test_take_strided_1d_indices(ind_dt): ind = dpt.arange(12, 24, dtype=ind_dt, sycl_queue=q) x_np = dpt.asnumpy(x) - ind_np = dpt.asnumpy(ind).astype(np.intp) + ind_np = dpt.asnumpy(ind).astype("i8") for s in ( slice(None, None, 2), @@ -820,7 +820,7 @@ def test_take_strided_indices(ind_dt, order): ) x_np = dpt.asnumpy(x) - ind_np = dpt.asnumpy(ind).astype(np.intp) + ind_np = dpt.asnumpy(ind).astype("i8") for s in ( slice(None, None, 2), @@ -845,7 +845,7 @@ def test_put_strided_1d_destination(data_dt, order): skip_if_dtype_not_supported(data_dt, q) x = dpt.arange(27, dtype=data_dt, sycl_queue=q) - ind = dpt.arange(4, 9, dtype=np.intp, sycl_queue=q) + ind = dpt.arange(4, 9, dtype="i8", sycl_queue=q) val = dpt.asarray(9, dtype=x.dtype, sycl_queue=q) x_np = dpt.asnumpy(x) @@ -875,7 +875,7 @@ def test_put_strided_destination(data_dt, order): skip_if_dtype_not_supported(data_dt, q) x = dpt.reshape(_make_3d(data_dt, q), (9, 3), order=order) - ind = dpt.arange(2, dtype=np.intp, sycl_queue=q) + ind = dpt.arange(2, dtype="i8", sycl_queue=q) val = dpt.asarray(9, dtype=x.dtype, sycl_queue=q) x_np = dpt.asnumpy(x) @@ -924,7 +924,7 @@ def test_put_strided_1d_indices(ind_dt): val = dpt.asarray(-1, dtype=x.dtype, sycl_queue=q) x_np = dpt.asnumpy(x) - ind_np = dpt.asnumpy(ind).astype(np.intp) + ind_np = dpt.asnumpy(ind).astype("i8") val_np = dpt.asnumpy(val) for s in ( @@ -955,7 +955,7 @@ def test_put_strided_indices(ind_dt, order): val = dpt.asarray(-1, sycl_queue=q, dtype=x.dtype) x_np = dpt.asnumpy(x) - ind_np = dpt.asnumpy(ind).astype(np.intp) + ind_np = dpt.asnumpy(ind).astype("i8") val_np = dpt.asnumpy(val) for s in ( @@ -982,7 +982,7 @@ def test_integer_indexing_modes(): x_np = dpt.asnumpy(x) # wrapping negative indices - ind = dpt.asarray([-4, -3, 0, 2, 4], dtype=np.intp, sycl_queue=q) + ind = dpt.asarray([-4, -3, 0, 2, 4], dtype="i8", sycl_queue=q) res = dpt.take(x, ind, mode="wrap") expected_arr = np.take(x_np, dpt.asnumpy(ind), mode="raise") @@ -990,7 +990,7 @@ def test_integer_indexing_modes(): assert (dpt.asnumpy(res) == expected_arr).all() # clipping to 0 (disabling negative indices) - ind = dpt.asarray([-6, -3, 0, 2, 6], dtype=np.intp, sycl_queue=q) + ind = dpt.asarray([-6, -3, 0, 2, 6], dtype="i8", sycl_queue=q) res = dpt.take(x, ind, mode="clip") expected_arr = np.take(x_np, dpt.asnumpy(ind), mode="clip") @@ -1002,7 +1002,7 @@ def test_take_arg_validation(): q = get_queue_or_skip() x = dpt.arange(4, dtype="i4", sycl_queue=q) - ind0 = dpt.arange(4, dtype=np.intp, sycl_queue=q) + ind0 = dpt.arange(4, dtype="i8", sycl_queue=q) ind1 = dpt.arange(2.0, dtype="f", sycl_queue=q) with pytest.raises(TypeError): @@ -1034,7 +1034,7 @@ def test_put_arg_validation(): q = get_queue_or_skip() x = dpt.arange(4, dtype="i4", sycl_queue=q) - ind0 = dpt.arange(4, dtype=np.intp, sycl_queue=q) + ind0 = dpt.arange(4, dtype="i8", sycl_queue=q) ind1 = dpt.arange(2.0, dtype="f", sycl_queue=q) val = dpt.asarray(2, dtype=x.dtype, sycl_queue=q) @@ -1890,3 +1890,69 @@ def test_put_along_axis_uint64_indices(): dpt.put_along_axis(x, inds, dpt.asarray(2, dtype=x.dtype), axis=1) expected = dpt.tile(dpt.asarray([0, 2], dtype="i4"), (2, 5)) assert dpt.all(expected == x) + + +@pytest.mark.parametrize("data_dt", _all_dtypes) +@pytest.mark.parametrize("order", ["C", "F"]) +def test_take_out(data_dt, order): + q = get_queue_or_skip() + skip_if_dtype_not_supported(data_dt, q) + + axis = 0 + x = dpt.reshape(_make_3d(data_dt, q), (9, 3), order=order) + ind = dpt.arange(2, dtype="i8", sycl_queue=q) + out_sh = x.shape[:axis] + ind.shape + x.shape[axis + 1 :] + out = dpt.empty(out_sh, dtype=data_dt, sycl_queue=q) + + expected = dpt.take(x, ind, axis=axis) + + dpt.take(x, ind, axis=axis, out=out) + + assert dpt.all(out == expected) + + +@pytest.mark.parametrize("data_dt", _all_dtypes) +@pytest.mark.parametrize("order", ["C", "F"]) +def test_take_out_overlap(data_dt, order): + q = get_queue_or_skip() + skip_if_dtype_not_supported(data_dt, q) + + axis = 0 + x = dpt.reshape(_make_3d(data_dt, q), (9, 3), order=order) + ind = dpt.arange(2, dtype="i8", sycl_queue=q) + out = x[x.shape[axis] - ind.shape[axis] : x.shape[axis], :] + + expected = dpt.take(x, ind, axis=axis) + + dpt.take(x, ind, axis=axis, out=out) + + assert dpt.all(out == expected) + assert dpt.all(x[x.shape[0] - ind.shape[0] : x.shape[0], :] == out) + + +def test_take_out_errors(): + q1 = get_queue_or_skip() + q2 = get_queue_or_skip() + + x = dpt.arange(10, dtype="i4", sycl_queue=q1) + ind = dpt.arange(2, dtype="i4", sycl_queue=q1) + + with pytest.raises(TypeError): + dpt.take(x, ind, out=dict()) + + out_read_only = dpt.empty(ind.shape, dtype=x.dtype, sycl_queue=q1) + out_read_only.flags["W"] = False + with pytest.raises(ValueError): + dpt.take(x, ind, out=out_read_only) + + out_bad_shape = dpt.empty(0, dtype=x.dtype, sycl_queue=q1) + with pytest.raises(ValueError): + dpt.take(x, ind, out=out_bad_shape) + + out_bad_dt = dpt.empty(ind.shape, dtype="i8", sycl_queue=q1) + with pytest.raises(ValueError): + dpt.take(x, ind, out=out_bad_dt) + + out_bad_q = dpt.empty(ind.shape, dtype=x.dtype, sycl_queue=q2) + with pytest.raises(dpctl.utils.ExecutionPlacementError): + dpt.take(x, ind, out=out_bad_q)