21
21
import dpctl .tensor ._tensor_impl as ti
22
22
import dpctl .utils
23
23
24
- from ._copy_utils import _extract_impl , _nonzero_impl , _take_multi_index
24
+ from ._copy_utils import (
25
+ _extract_impl ,
26
+ _nonzero_impl ,
27
+ _put_multi_index ,
28
+ _take_multi_index ,
29
+ )
25
30
from ._numpy_helper import normalize_axis_index
26
31
27
32
@@ -206,22 +211,18 @@ def put_vec_duplicates(vec, ind, vals):
206
211
raise TypeError (
207
212
"Expected instance of `dpt.usm_ndarray`, got `{}`." .format (type (x ))
208
213
)
209
- if isinstance (vals , dpt .usm_ndarray ):
210
- queues_ = [x .sycl_queue , vals .sycl_queue ]
211
- usm_types_ = [x .usm_type , vals .usm_type ]
212
- else :
213
- queues_ = [
214
- x .sycl_queue ,
215
- ]
216
- usm_types_ = [
217
- x .usm_type ,
218
- ]
219
214
if not isinstance (indices , dpt .usm_ndarray ):
220
215
raise TypeError (
221
216
"`indices` expected `dpt.usm_ndarray`, got `{}`." .format (
222
217
type (indices )
223
218
)
224
219
)
220
+ if isinstance (vals , dpt .usm_ndarray ):
221
+ queues_ = [x .sycl_queue , indices .sycl_queue , vals .sycl_queue ]
222
+ usm_types_ = [x .usm_type , indices .usm_type , vals .usm_type ]
223
+ else :
224
+ queues_ = [x .sycl_queue , indices .sycl_queue ]
225
+ usm_types_ = [x .usm_type , indices .usm_type ]
225
226
if indices .ndim != 1 :
226
227
raise ValueError (
227
228
"`indices` expected a 1D array, got `{}`" .format (indices .ndim )
@@ -232,7 +233,6 @@ def put_vec_duplicates(vec, ind, vals):
232
233
indices .dtype
233
234
)
234
235
)
235
- queues_ .append (indices .sycl_queue )
236
236
usm_types_ .append (indices .usm_type )
237
237
exec_q = dpctl .utils .get_execution_queue (queues_ )
238
238
if exec_q is None :
@@ -491,8 +491,12 @@ def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"):
491
491
"Execution placement can not be unambiguously inferred "
492
492
"from input arguments. "
493
493
)
494
+ indexes_dt = indices .dtype
495
+ if indexes_dt .kind not in "ui" :
496
+ raise IndexError (
497
+ "`indices` expected integer data type, got `{}`" .format (indexes_dt )
498
+ )
494
499
mode_i = _get_indexing_mode (mode )
495
- indexes_dt = ti .default_device_index_type (exec_q .sycl_device )
496
500
_ind = tuple (
497
501
(
498
502
indices
@@ -502,3 +506,80 @@ def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"):
502
506
for i in range (x_nd )
503
507
)
504
508
return _take_multi_index (x , _ind , 0 , mode = mode_i )
509
+
510
+
511
+ def put_along_axis (x , indices , vals , / , * , axis = - 1 , mode = "wrap" ):
512
+ """
513
+ Returns elements from an array at the one-dimensional indices specified
514
+ by ``indices`` along a provided ``axis``.
515
+
516
+ Args:
517
+ x (usm_ndarray):
518
+ input array. Must be compatible with ``indices``, except for the
519
+ axis (dimension) specified by ``axis``.
520
+ indices (usm_ndarray):
521
+ array indices. Must have the same rank (i.e., number of dimensions)
522
+ as ``x``.
523
+ axis: int
524
+ axis along which to select values. If ``axis`` is negative, the
525
+ function determines the axis along which to select values by
526
+ counting from the last dimension. Default: ``-1``.
527
+ mode (str, optional):
528
+ How out-of-bounds indices will be handled. Possible values
529
+ are:
530
+
531
+ - ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
532
+ negative indices.
533
+ - ``"clip"``: clips indices to (``0 <= i < n``).
534
+
535
+ Default: ``"wrap"``.
536
+
537
+ Returns:
538
+ usm_ndarray:
539
+ an array having the same data type as ``x``. The returned array has
540
+ the same rank (i.e., number of dimensions) as ``x`` and a shape
541
+ determined according to :ref:`broadcasting`, except for the axis
542
+ (dimension) specified by ``axis`` whose size must equal the size
543
+ of the corresponding axis (dimension) in ``indices``.
544
+
545
+ Note:
546
+ Treatment of the out-of-bound indices in ``indices`` array is controlled
547
+ by the value of ``mode`` keyword.
548
+ """
549
+ if not isinstance (x , dpt .usm_ndarray ):
550
+ raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
551
+ if not isinstance (indices , dpt .usm_ndarray ):
552
+ raise TypeError (
553
+ f"Expected dpctl.tensor.usm_ndarray, got { type (indices )} "
554
+ )
555
+ x_nd = x .ndim
556
+ if x_nd != indices .ndim :
557
+ raise ValueError (
558
+ "Number of dimensions in the first and the second "
559
+ "argument arrays must be equal"
560
+ )
561
+ pp = normalize_axis_index (operator .index (axis ), x_nd )
562
+ if isinstance (vals , dpt .usm_ndarray ):
563
+ queues_ = [x .sycl_queue , indices .sycl_queue , vals .sycl_queue ]
564
+ usm_types_ = [x .usm_type , indices .usm_type , vals .usm_type ]
565
+ else :
566
+ queues_ = [x .sycl_queue , indices .sycl_queue ]
567
+ usm_types_ = [x .usm_type , indices .usm_type ]
568
+ exec_q = dpctl .utils .get_execution_queue (queues_ )
569
+ if exec_q is None :
570
+ raise dpctl .utils .ExecutionPlacementError (
571
+ "Execution placement can not be unambiguously inferred "
572
+ "from input arguments. "
573
+ )
574
+ out_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
575
+ mode_i = _get_indexing_mode (mode )
576
+ indexes_dt = ti .default_device_index_type (exec_q .sycl_device )
577
+ _ind = tuple (
578
+ (
579
+ indices
580
+ if i == pp
581
+ else _range (x .shape [i ], i , x_nd , exec_q , out_usm_type , indexes_dt )
582
+ )
583
+ for i in range (x_nd )
584
+ )
585
+ return _put_multi_index (x , _ind , 0 , vals , mode = mode_i )
0 commit comments