Skip to content

Add out keyword to dpt.take #2010

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

* Added `out` keyword to `tensor.take` [gh-2010](https://github.com/IntelPython/dpctl/pull/2010)

### Changed

### Fixed
Expand Down
52 changes: 45 additions & 7 deletions dpctl/tensor/_indexing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def _get_indexing_mode(name):
)


def take(x, indices, /, *, axis=None, mode="wrap"):
"""take(x, indices, axis=None, mode="wrap")
def take(x, indices, /, *, axis=None, out=None, mode="wrap"):
"""take(x, indices, axis=None, out=None, mode="wrap")

Takes elements from an array along a given axis at given indices.

Expand All @@ -54,6 +54,9 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
The axis along which the values will be selected.
If ``x`` is one-dimensional, this argument is optional.
Default: ``None``.
out (Optional[usm_ndarray]):
Output array to populate. Array must have the correct
shape and the expected data type.
mode (str, optional):
How out-of-bounds indices will be handled. Possible values
are:
Expand Down Expand Up @@ -121,18 +124,53 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
raise ValueError("`axis` must be 0 for an array of dimension 0.")
res_shape = indices.shape

res = dpt.empty(
res_shape, dtype=x.dtype, usm_type=res_usm_type, sycl_queue=exec_q
)
dt = x.dtype

orig_out = out
if out is not None:
if not isinstance(out, dpt.usm_ndarray):
raise TypeError(
f"output array must be of usm_ndarray type, got {type(out)}"
)
if not out.flags.writable:
raise ValueError("provided `out` array is read-only")

if out.shape != res_shape:
raise ValueError(
"The shape of input and output arrays are inconsistent. "
f"Expected output shape is {res_shape}, got {out.shape}"
)
if dt != out.dtype:
raise ValueError(
f"Output array of type {dt} is needed, got {out.dtype}"
)
if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None:
raise dpctl.utils.ExecutionPlacementError(
"Input and output allocation queues are not compatible"
)
if ti._array_overlap(x, out):
out = dpt.empty_like(out)
else:
out = dpt.empty(
res_shape, dtype=dt, usm_type=res_usm_type, sycl_queue=exec_q
)

_manager = dpctl.utils.SequentialOrderManager[exec_q]
deps_ev = _manager.submitted_events
hev, take_ev = ti._take(
x, (indices,), res, axis, mode, sycl_queue=exec_q, depends=deps_ev
x, (indices,), out, axis, mode, sycl_queue=exec_q, depends=deps_ev
)
_manager.add_event_pair(hev, take_ev)

return res
if not (orig_out is None or out is orig_out):
# Copy the out data from temporary buffer to original memory
ht_e_cpy, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
src=out, dst=orig_out, sycl_queue=exec_q, depends=[take_ev]
)
_manager.add_event_pair(ht_e_cpy, cpy_ev)
out = orig_out

return out


def put(x, indices, vals, /, *, axis=None, mode="wrap"):
Expand Down
96 changes: 81 additions & 15 deletions dpctl/tests/test_usm_ndarray_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ def test_put_0d_val(data_dt):
skip_if_dtype_not_supported(data_dt, q)

x = dpt.arange(5, dtype=data_dt, sycl_queue=q)
ind = dpt.asarray([0], dtype=np.intp, sycl_queue=q)
ind = dpt.asarray([0], dtype="i8", sycl_queue=q)
val = dpt.asarray(2, dtype=x.dtype, sycl_queue=q)
x[ind] = val
assert_array_equal(np.asarray(2, dtype=data_dt), dpt.asnumpy(x[0]))
Expand All @@ -644,7 +644,7 @@ def test_take_0d_data(data_dt):
skip_if_dtype_not_supported(data_dt, q)

x = dpt.asarray(0, dtype=data_dt, sycl_queue=q)
ind = dpt.arange(5, dtype=np.intp, sycl_queue=q)
ind = dpt.arange(5, dtype="i8", sycl_queue=q)

y = dpt.take(x, ind)
assert (
Expand All @@ -662,7 +662,7 @@ def test_put_0d_data(data_dt):
skip_if_dtype_not_supported(data_dt, q)

x = dpt.asarray(0, dtype=data_dt, sycl_queue=q)
ind = dpt.arange(5, dtype=np.intp, sycl_queue=q)
ind = dpt.arange(5, dtype="i8", sycl_queue=q)
val = dpt.asarray(2, dtype=data_dt, sycl_queue=q)

dpt.put(x, ind, val, axis=0)
Expand Down Expand Up @@ -710,7 +710,7 @@ def test_take_strided_1d_source(data_dt):
skip_if_dtype_not_supported(data_dt, q)

x = dpt.arange(27, dtype=data_dt, sycl_queue=q)
ind = dpt.arange(4, 9, dtype=np.intp, sycl_queue=q)
ind = dpt.arange(4, 9, dtype="i8", sycl_queue=q)

x_np = dpt.asnumpy(x)
ind_np = dpt.asnumpy(ind)
Expand Down Expand Up @@ -748,7 +748,7 @@ def test_take_strided(data_dt, order):
skip_if_dtype_not_supported(data_dt, q)

x = dpt.reshape(_make_3d(data_dt, q), (9, 3), order=order)
ind = dpt.arange(2, dtype=np.intp, sycl_queue=q)
ind = dpt.arange(2, dtype="i8", sycl_queue=q)

x_np = dpt.asnumpy(x)
ind_np = dpt.asnumpy(ind)
Expand Down Expand Up @@ -781,7 +781,7 @@ def test_take_strided_1d_indices(ind_dt):
ind = dpt.arange(12, 24, dtype=ind_dt, sycl_queue=q)

x_np = dpt.asnumpy(x)
ind_np = dpt.asnumpy(ind).astype(np.intp)
ind_np = dpt.asnumpy(ind).astype("i8")

for s in (
slice(None, None, 2),
Expand Down Expand Up @@ -820,7 +820,7 @@ def test_take_strided_indices(ind_dt, order):
)

x_np = dpt.asnumpy(x)
ind_np = dpt.asnumpy(ind).astype(np.intp)
ind_np = dpt.asnumpy(ind).astype("i8")

for s in (
slice(None, None, 2),
Expand All @@ -845,7 +845,7 @@ def test_put_strided_1d_destination(data_dt, order):
skip_if_dtype_not_supported(data_dt, q)

x = dpt.arange(27, dtype=data_dt, sycl_queue=q)
ind = dpt.arange(4, 9, dtype=np.intp, sycl_queue=q)
ind = dpt.arange(4, 9, dtype="i8", sycl_queue=q)
val = dpt.asarray(9, dtype=x.dtype, sycl_queue=q)

x_np = dpt.asnumpy(x)
Expand Down Expand Up @@ -875,7 +875,7 @@ def test_put_strided_destination(data_dt, order):
skip_if_dtype_not_supported(data_dt, q)

x = dpt.reshape(_make_3d(data_dt, q), (9, 3), order=order)
ind = dpt.arange(2, dtype=np.intp, sycl_queue=q)
ind = dpt.arange(2, dtype="i8", sycl_queue=q)
val = dpt.asarray(9, dtype=x.dtype, sycl_queue=q)

x_np = dpt.asnumpy(x)
Expand Down Expand Up @@ -924,7 +924,7 @@ def test_put_strided_1d_indices(ind_dt):
val = dpt.asarray(-1, dtype=x.dtype, sycl_queue=q)

x_np = dpt.asnumpy(x)
ind_np = dpt.asnumpy(ind).astype(np.intp)
ind_np = dpt.asnumpy(ind).astype("i8")
val_np = dpt.asnumpy(val)

for s in (
Expand Down Expand Up @@ -955,7 +955,7 @@ def test_put_strided_indices(ind_dt, order):
val = dpt.asarray(-1, sycl_queue=q, dtype=x.dtype)

x_np = dpt.asnumpy(x)
ind_np = dpt.asnumpy(ind).astype(np.intp)
ind_np = dpt.asnumpy(ind).astype("i8")
val_np = dpt.asnumpy(val)

for s in (
Expand All @@ -982,15 +982,15 @@ def test_integer_indexing_modes():
x_np = dpt.asnumpy(x)

# wrapping negative indices
ind = dpt.asarray([-4, -3, 0, 2, 4], dtype=np.intp, sycl_queue=q)
ind = dpt.asarray([-4, -3, 0, 2, 4], dtype="i8", sycl_queue=q)

res = dpt.take(x, ind, mode="wrap")
expected_arr = np.take(x_np, dpt.asnumpy(ind), mode="raise")

assert (dpt.asnumpy(res) == expected_arr).all()

# clipping to 0 (disabling negative indices)
ind = dpt.asarray([-6, -3, 0, 2, 6], dtype=np.intp, sycl_queue=q)
ind = dpt.asarray([-6, -3, 0, 2, 6], dtype="i8", sycl_queue=q)

res = dpt.take(x, ind, mode="clip")
expected_arr = np.take(x_np, dpt.asnumpy(ind), mode="clip")
Expand All @@ -1002,7 +1002,7 @@ def test_take_arg_validation():
q = get_queue_or_skip()

x = dpt.arange(4, dtype="i4", sycl_queue=q)
ind0 = dpt.arange(4, dtype=np.intp, sycl_queue=q)
ind0 = dpt.arange(4, dtype="i8", sycl_queue=q)
ind1 = dpt.arange(2.0, dtype="f", sycl_queue=q)

with pytest.raises(TypeError):
Expand Down Expand Up @@ -1034,7 +1034,7 @@ def test_put_arg_validation():
q = get_queue_or_skip()

x = dpt.arange(4, dtype="i4", sycl_queue=q)
ind0 = dpt.arange(4, dtype=np.intp, sycl_queue=q)
ind0 = dpt.arange(4, dtype="i8", sycl_queue=q)
ind1 = dpt.arange(2.0, dtype="f", sycl_queue=q)
val = dpt.asarray(2, dtype=x.dtype, sycl_queue=q)

Expand Down Expand Up @@ -1890,3 +1890,69 @@ def test_put_along_axis_uint64_indices():
dpt.put_along_axis(x, inds, dpt.asarray(2, dtype=x.dtype), axis=1)
expected = dpt.tile(dpt.asarray([0, 2], dtype="i4"), (2, 5))
assert dpt.all(expected == x)


@pytest.mark.parametrize("data_dt", _all_dtypes)
@pytest.mark.parametrize("order", ["C", "F"])
def test_take_out(data_dt, order):
    """Check that ``dpt.take`` fills a caller-supplied ``out`` array.

    The result written into ``out`` is compared against the array that
    ``dpt.take`` returns when called without ``out``.
    """
    queue = get_queue_or_skip()
    skip_if_dtype_not_supported(data_dt, queue)

    axis = 0
    src = dpt.reshape(_make_3d(data_dt, queue), (9, 3), order=order)
    indices = dpt.arange(2, dtype="i8", sycl_queue=queue)

    # Expected result shape: the indexed axis of ``src`` is replaced by the
    # shape of the index array.
    leading = src.shape[:axis]
    trailing = src.shape[axis + 1 :]
    dest = dpt.empty(
        leading + indices.shape + trailing, dtype=data_dt, sycl_queue=queue
    )

    reference = dpt.take(src, indices, axis=axis)

    dpt.take(src, indices, axis=axis, out=dest)

    assert dpt.all(dest == reference)


@pytest.mark.parametrize("data_dt", _all_dtypes)
@pytest.mark.parametrize("order", ["C", "F"])
def test_take_out_overlap(data_dt, order):
    """Check ``dpt.take`` when ``out`` is a slice of the input array.

    ``out`` is taken from the trailing rows of the input itself, and the
    values written there must match a reference result computed before the
    input was modified.
    """
    queue = get_queue_or_skip()
    skip_if_dtype_not_supported(data_dt, queue)

    axis = 0
    src = dpt.reshape(_make_3d(data_dt, queue), (9, 3), order=order)
    indices = dpt.arange(2, dtype="i8", sycl_queue=queue)

    # Destination: the last ``len(indices)`` rows of the source array.
    row_count = src.shape[axis]
    tail_start = row_count - indices.shape[axis]
    dest = src[tail_start:row_count, :]

    # Reference must be computed first, since the call below writes into a
    # slice of ``src``.
    reference = dpt.take(src, indices, axis=axis)

    dpt.take(src, indices, axis=axis, out=dest)

    assert dpt.all(dest == reference)
    # Re-slice the source to confirm the same rows hold the result.
    tail_rows = src[src.shape[0] - indices.shape[0] : src.shape[0], :]
    assert dpt.all(tail_rows == dest)


def test_take_out_errors():
    """Check that ``dpt.take`` rejects invalid ``out`` arguments.

    Covers, in order: a non-array ``out``, a read-only ``out``, a wrong
    shape, a wrong data type, and an ``out`` allocated on a different queue.
    """
    exec_q = get_queue_or_skip()
    other_q = get_queue_or_skip()

    src = dpt.arange(10, dtype="i4", sycl_queue=exec_q)
    indices = dpt.arange(2, dtype="i4", sycl_queue=exec_q)

    # ``out`` that is not a usm_ndarray
    with pytest.raises(TypeError):
        dpt.take(src, indices, out={})

    # ``out`` with the writable flag cleared
    frozen_out = dpt.empty(indices.shape, dtype=src.dtype, sycl_queue=exec_q)
    frozen_out.flags["W"] = False
    with pytest.raises(ValueError):
        dpt.take(src, indices, out=frozen_out)

    # ``out`` whose shape differs from the expected result shape
    wrong_shape_out = dpt.empty(0, dtype=src.dtype, sycl_queue=exec_q)
    with pytest.raises(ValueError):
        dpt.take(src, indices, out=wrong_shape_out)

    # ``out`` whose data type differs from that of the source array
    wrong_dtype_out = dpt.empty(indices.shape, dtype="i8", sycl_queue=exec_q)
    with pytest.raises(ValueError):
        dpt.take(src, indices, out=wrong_dtype_out)

    # ``out`` associated with a queue other than the one used for the inputs
    foreign_out = dpt.empty(indices.shape, dtype=src.dtype, sycl_queue=other_q)
    with pytest.raises(dpctl.utils.ExecutionPlacementError):
        dpt.take(src, indices, out=foreign_out)
Loading