Skip to content

Commit b59c43e

Browse files
committed
Implements put_along_axis
Also makes minor tweaks to `take_along_axis`
1 parent a077f42 commit b59c43e

File tree

3 files changed

+103
-15
lines changed

3 files changed

+103
-15
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
nonzero,
6666
place,
6767
put,
68+
put_along_axis,
6869
take,
6970
take_along_axis,
7071
)
@@ -384,4 +385,5 @@
384385
"diff",
385386
"count_nonzero",
386387
"take_along_axis",
388+
"put_along_axis",
387389
]

dpctl/tensor/_copy_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -938,13 +938,18 @@ def _place_impl(ary, ary_mask, vals, axis=0):
938938
return
939939

940940

941-
def _put_multi_index(ary, inds, p, vals):
941+
def _put_multi_index(ary, inds, p, vals, mode=0):
942942
if not isinstance(ary, dpt.usm_ndarray):
943943
raise TypeError(
944944
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
945945
)
946946
ary_nd = ary.ndim
947947
p = normalize_axis_index(operator.index(p), ary_nd)
948+
mode = operator.index(mode)
949+
if mode not in [0, 1]:
950+
raise ValueError(
951+
"Invalid value for mode keyword, only 0 or 1 is supported"
952+
)
948953
if isinstance(vals, dpt.usm_ndarray):
949954
queues_ = [ary.sycl_queue, vals.sycl_queue]
950955
usm_types_ = [ary.usm_type, vals.usm_type]
@@ -1018,7 +1023,7 @@ def _put_multi_index(ary, inds, p, vals):
10181023
ind=inds,
10191024
val=rhs,
10201025
axis_start=p,
1021-
mode=0,
1026+
mode=mode,
10221027
sycl_queue=exec_q,
10231028
depends=dep_ev,
10241029
)

dpctl/tensor/_indexing_functions.py

Lines changed: 94 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@
2121
import dpctl.tensor._tensor_impl as ti
2222
import dpctl.utils
2323

24-
from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index
24+
from ._copy_utils import (
25+
_extract_impl,
26+
_nonzero_impl,
27+
_put_multi_index,
28+
_take_multi_index,
29+
)
2530
from ._numpy_helper import normalize_axis_index
2631

2732

@@ -206,22 +211,18 @@ def put_vec_duplicates(vec, ind, vals):
206211
raise TypeError(
207212
"Expected instance of `dpt.usm_ndarray`, got `{}`.".format(type(x))
208213
)
209-
if isinstance(vals, dpt.usm_ndarray):
210-
queues_ = [x.sycl_queue, vals.sycl_queue]
211-
usm_types_ = [x.usm_type, vals.usm_type]
212-
else:
213-
queues_ = [
214-
x.sycl_queue,
215-
]
216-
usm_types_ = [
217-
x.usm_type,
218-
]
219214
if not isinstance(indices, dpt.usm_ndarray):
220215
raise TypeError(
221216
"`indices` expected `dpt.usm_ndarray`, got `{}`.".format(
222217
type(indices)
223218
)
224219
)
220+
if isinstance(vals, dpt.usm_ndarray):
221+
queues_ = [x.sycl_queue, indices.sycl_queue, vals.sycl_queue]
222+
usm_types_ = [x.usm_type, indices.usm_type, vals.usm_type]
223+
else:
224+
queues_ = [x.sycl_queue, indices.sycl_queue]
225+
usm_types_ = [x.usm_type, indices.usm_type]
225226
if indices.ndim != 1:
226227
raise ValueError(
227228
"`indices` expected a 1D array, got `{}`".format(indices.ndim)
@@ -232,7 +233,6 @@ def put_vec_duplicates(vec, ind, vals):
232233
indices.dtype
233234
)
234235
)
235-
queues_.append(indices.sycl_queue)
236236
usm_types_.append(indices.usm_type)
237237
exec_q = dpctl.utils.get_execution_queue(queues_)
238238
if exec_q is None:
@@ -491,8 +491,12 @@ def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"):
491491
"Execution placement can not be unambiguously inferred "
492492
"from input arguments. "
493493
)
494+
indexes_dt = indices.dtype
495+
if indexes_dt.kind not in "ui":
496+
raise IndexError(
497+
"`indices` expected integer data type, got `{}`".format(indexes_dt)
498+
)
494499
mode_i = _get_indexing_mode(mode)
495-
indexes_dt = ti.default_device_index_type(exec_q.sycl_device)
496500
_ind = tuple(
497501
(
498502
indices
@@ -502,3 +506,80 @@ def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"):
502506
for i in range(x_nd)
503507
)
504508
return _take_multi_index(x, _ind, 0, mode=mode_i)
509+
510+
511+
def put_along_axis(x, indices, vals, /, *, axis=-1, mode="wrap"):
    """
    Puts elements into an array at the one-dimensional indices specified by
    ``indices`` along a provided ``axis``.

    Args:
        x (usm_ndarray):
            array to put values into. Must be compatible with ``indices``,
            except for the axis (dimension) specified by ``axis``.
        indices (usm_ndarray):
            array indices. Must have the same rank (i.e., number of dimensions)
            as ``x``.
        vals (usm_ndarray or other value):
            values to put into ``x`` at the positions selected by
            ``indices``. When ``vals`` is a ``usm_ndarray`` its queue and
            USM type take part in execution placement; any other object is
            passed to the underlying put implementation unchanged.
            NOTE(review): shape/broadcasting requirements on ``vals`` are
            enforced by ``_put_multi_index`` and should be confirmed there.
        axis: int
            axis along which to put values. If ``axis`` is negative, the
            function determines the axis along which to put values by
            counting from the last dimension. Default: ``-1``.
        mode (str, optional):
            How out-of-bounds indices will be handled. Possible values
            are:

            - ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
              negative indices.
            - ``"clip"``: clips indices to (``0 <= i < n``).

            Default: ``"wrap"``.

    Returns:
        The result of the underlying ``_put_multi_index`` call is returned
        as is.
        NOTE(review): presumably ``x`` is modified in place and the return
        value is ``None`` -- confirm against ``_put_multi_index``.

    Raises:
        TypeError:
            if ``x`` or ``indices`` is not a ``usm_ndarray``.
        ValueError:
            if ``x`` and ``indices`` have different numbers of dimensions.
        dpctl.utils.ExecutionPlacementError:
            if a single execution queue can not be inferred from the
            array arguments.

    Note:
        Treatment of the out-of-bound indices in ``indices`` array is controlled
        by the value of ``mode`` keyword.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
    if not isinstance(indices, dpt.usm_ndarray):
        raise TypeError(
            f"Expected dpctl.tensor.usm_ndarray, got {type(indices)}"
        )
    x_nd = x.ndim
    if x_nd != indices.ndim:
        raise ValueError(
            "Number of dimensions in the first and the second "
            "argument arrays must be equal"
        )
    # Normalized (non-negative) position of the axis being indexed.
    pp = normalize_axis_index(operator.index(axis), x_nd)
    # ``vals`` only participates in queue / USM-type inference when it is
    # itself a usm_ndarray.
    if isinstance(vals, dpt.usm_ndarray):
        queues_ = [x.sycl_queue, indices.sycl_queue, vals.sycl_queue]
        usm_types_ = [x.usm_type, indices.usm_type, vals.usm_type]
    else:
        queues_ = [x.sycl_queue, indices.sycl_queue]
        usm_types_ = [x.usm_type, indices.usm_type]
    exec_q = dpctl.utils.get_execution_queue(queues_)
    if exec_q is None:
        raise dpctl.utils.ExecutionPlacementError(
            "Execution placement can not be unambiguously inferred "
            "from input arguments. "
        )
    out_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
    mode_i = _get_indexing_mode(mode)
    indexes_dt = ti.default_device_index_type(exec_q.sycl_device)
    # One index array per dimension of ``x``: the caller's ``indices`` on the
    # selected axis, and a ``_range``-generated index array on every other
    # axis.
    _ind = tuple(
        (
            indices
            if i == pp
            else _range(x.shape[i], i, x_nd, exec_q, out_usm_type, indexes_dt)
        )
        for i in range(x_nd)
    )
    return _put_multi_index(x, _ind, 0, vals, mode=mode_i)

0 commit comments

Comments
 (0)