From 3e515c6b606630a9fa1c6b86ae7e3c53d1381bbe Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 12 Aug 2024 14:59:36 -0700 Subject: [PATCH 1/6] Implements `put_along_axis` --- dpctl/tensor/__init__.py | 2 + dpctl/tensor/_copy_utils.py | 9 ++- dpctl/tensor/_indexing_functions.py | 102 ++++++++++++++++++++++++---- 3 files changed, 99 insertions(+), 14 deletions(-) diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index 579b56d3a3..81f4c5801f 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -65,6 +65,7 @@ nonzero, place, put, + put_along_axis, take, take_along_axis, ) @@ -384,4 +385,5 @@ "diff", "count_nonzero", "take_along_axis", + "put_along_axis", ] diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index dc5e7268a4..e2f1bccac0 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -938,13 +938,18 @@ def _place_impl(ary, ary_mask, vals, axis=0): return -def _put_multi_index(ary, inds, p, vals): +def _put_multi_index(ary, inds, p, vals, mode=0): if not isinstance(ary, dpt.usm_ndarray): raise TypeError( f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}" ) ary_nd = ary.ndim p = normalize_axis_index(operator.index(p), ary_nd) + mode = operator.index(mode) + if mode not in [0, 1]: + raise ValueError( + "Invalid value for mode keyword, only 0 or 1 is supported" + ) if isinstance(vals, dpt.usm_ndarray): queues_ = [ary.sycl_queue, vals.sycl_queue] usm_types_ = [ary.usm_type, vals.usm_type] @@ -1018,7 +1023,7 @@ def _put_multi_index(ary, inds, p, vals): ind=inds, val=rhs, axis_start=p, - mode=0, + mode=mode, sycl_queue=exec_q, depends=dep_ev, ) diff --git a/dpctl/tensor/_indexing_functions.py b/dpctl/tensor/_indexing_functions.py index a0ac2bb6cb..511a5c57b9 100644 --- a/dpctl/tensor/_indexing_functions.py +++ b/dpctl/tensor/_indexing_functions.py @@ -21,7 +21,12 @@ import dpctl.tensor._tensor_impl as ti import dpctl.utils -from ._copy_utils import _extract_impl, 
_nonzero_impl, _take_multi_index +from ._copy_utils import ( + _extract_impl, + _nonzero_impl, + _put_multi_index, + _take_multi_index, +) from ._numpy_helper import normalize_axis_index @@ -206,22 +211,18 @@ def put_vec_duplicates(vec, ind, vals): raise TypeError( "Expected instance of `dpt.usm_ndarray`, got `{}`.".format(type(x)) ) - if isinstance(vals, dpt.usm_ndarray): - queues_ = [x.sycl_queue, vals.sycl_queue] - usm_types_ = [x.usm_type, vals.usm_type] - else: - queues_ = [ - x.sycl_queue, - ] - usm_types_ = [ - x.usm_type, - ] if not isinstance(indices, dpt.usm_ndarray): raise TypeError( "`indices` expected `dpt.usm_ndarray`, got `{}`.".format( type(indices) ) ) + if isinstance(vals, dpt.usm_ndarray): + queues_ = [x.sycl_queue, indices.sycl_queue, vals.sycl_queue] + usm_types_ = [x.usm_type, indices.usm_type, vals.usm_type] + else: + queues_ = [x.sycl_queue, indices.sycl_queue] + usm_types_ = [x.usm_type, indices.usm_type] if indices.ndim != 1: raise ValueError( "`indices` expected a 1D array, got `{}`".format(indices.ndim) @@ -232,7 +233,6 @@ def put_vec_duplicates(vec, ind, vals): indices.dtype ) ) - queues_.append(indices.sycl_queue) usm_types_.append(indices.usm_type) exec_q = dpctl.utils.get_execution_queue(queues_) if exec_q is None: @@ -502,3 +502,81 @@ def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"): for i in range(x_nd) ) return _take_multi_index(x, _ind, 0, mode=mode_i) + + +def put_along_axis(x, indices, vals, /, *, axis=-1, mode="wrap"): + """ + Puts elements into an array at the one-dimensional indices specified by + ``indices`` along a provided ``axis``. + + Args: + x (usm_ndarray): + input array. Must be compatible with ``indices``, except for the + axis (dimension) specified by ``axis``. + indices (usm_ndarray): + array indices. Must have the same rank (i.e., number of dimensions) + as ``x``. + vals (usm_ndarray): + Array of values to be put into ``x``. + Must be broadcastable to the shape of ``indices``. 
+ axis: int + axis along which to put values. If ``axis`` is negative, the + function determines the axis along which to put values by + counting from the last dimension. Default: ``-1``. + mode (str, optional): + How out-of-bounds indices will be handled. Possible values + are: + + - ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps + negative indices. + - ``"clip"``: clips indices to (``0 <= i < n``). + + Default: ``"wrap"``. + + .. note:: + + If input array ``indices`` contains duplicates, a race condition + occurs, and the value written into corresponding positions in ``x`` + may vary from run to run. Preserving sequential semantics in handling + the duplicates to achieve deterministic behavior requires additional + work. + + See :func:`dpctl.tensor.put` for an example. + """ + if not isinstance(x, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") + if not isinstance(indices, dpt.usm_ndarray): + raise TypeError( + f"Expected dpctl.tensor.usm_ndarray, got {type(indices)}" + ) + x_nd = x.ndim + if x_nd != indices.ndim: + raise ValueError( + "Number of dimensions in the first and the second " + "argument arrays must be equal" + ) + pp = normalize_axis_index(operator.index(axis), x_nd) + if isinstance(vals, dpt.usm_ndarray): + queues_ = [x.sycl_queue, indices.sycl_queue, vals.sycl_queue] + usm_types_ = [x.usm_type, indices.usm_type, vals.usm_type] + else: + queues_ = [x.sycl_queue, indices.sycl_queue] + usm_types_ = [x.usm_type, indices.usm_type] + exec_q = dpctl.utils.get_execution_queue(queues_) + if exec_q is None: + raise dpctl.utils.ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments. 
" + ) + out_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_) + mode_i = _get_indexing_mode(mode) + indexes_dt = ti.default_device_index_type(exec_q.sycl_device) + _ind = tuple( + ( + indices + if i == pp + else _range(x.shape[i], i, x_nd, exec_q, out_usm_type, indexes_dt) + ) + for i in range(x_nd) + ) + return _put_multi_index(x, _ind, 0, vals, mode=mode_i) From 819a8650050c0934cd93d7261b894e08717db993 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 12 Aug 2024 15:10:21 -0700 Subject: [PATCH 2/6] Remove unnecessary append in `put` --- dpctl/tensor/_indexing_functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dpctl/tensor/_indexing_functions.py b/dpctl/tensor/_indexing_functions.py index 511a5c57b9..4f5f480a96 100644 --- a/dpctl/tensor/_indexing_functions.py +++ b/dpctl/tensor/_indexing_functions.py @@ -233,7 +233,6 @@ def put_vec_duplicates(vec, ind, vals): indices.dtype ) ) - usm_types_.append(indices.usm_type) exec_q = dpctl.utils.get_execution_queue(queues_) if exec_q is None: raise dpctl.utils.ExecutionPlacementError From 4eb7f97ca1f3e93332d9737ad8fc9070453bd436 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 12 Aug 2024 15:51:35 -0700 Subject: [PATCH 3/6] Revise note on race conditions for `put_along_axis` Remove reference to example in `put` to avoid confusing users --- dpctl/tensor/_indexing_functions.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dpctl/tensor/_indexing_functions.py b/dpctl/tensor/_indexing_functions.py index 4f5f480a96..e346db525d 100644 --- a/dpctl/tensor/_indexing_functions.py +++ b/dpctl/tensor/_indexing_functions.py @@ -539,8 +539,6 @@ def put_along_axis(x, indices, vals, /, *, axis=-1, mode="wrap"): may vary from run to run. Preserving sequential semantics in handling the duplicates to achieve deterministic behavior requires additional work. - - See :func:`dpctl.tensor.put` for an example. 
""" if not isinstance(x, dpt.usm_ndarray): raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") From 7adcf67b74c90a62c287afb6e101701b5f4a3fb4 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Fri, 16 Aug 2024 09:57:49 -0700 Subject: [PATCH 4/6] Add tests for put_along_axis --- dpctl/tests/test_usm_ndarray_indexing.py | 71 +++++++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index e11eaba9a7..76773fabf6 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -1578,7 +1578,7 @@ def test_take_along_axis_validation(): def_dtypes = info_.default_dtypes(device=x_dev) ind_dt = def_dtypes["indexing"] ind = dpt.zeros(1, dtype=ind_dt) - # axis valudation + # axis validation with pytest.raises(ValueError): dpt.take_along_axis(x, ind, axis=1) # mode validation @@ -1594,6 +1594,71 @@ def test_take_along_axis_validation(): dpt.take_along_axis(x, ind2) +def test_put_along_axis(): + get_queue_or_skip() + + n0, n1, n2 = 3, 5, 7 + x = dpt.reshape(dpt.arange(n0 * n1 * n2), (n0, n1, n2)) + ind_dt = dpt.__array_namespace_info__().default_dtypes( + device=x.sycl_device + )["indexing"] + ind0 = dpt.ones((1, n1, n2), dtype=ind_dt) + ind1 = dpt.ones((n0, 1, n2), dtype=ind_dt) + ind2 = dpt.ones((n0, n1, 1), dtype=ind_dt) + + xc = dpt.copy(x) + vals = dpt.ones(ind0.shape, dtype=x.dtype) + dpt.put_along_axis(xc, ind0, vals, axis=0) + assert dpt.all(dpt.take_along_axis(xc, ind0, axis=0) == vals) + + xc = dpt.copy(x) + vals = dpt.ones(ind1.shape, dtype=x.dtype) + dpt.put_along_axis(xc, ind1, vals, axis=1) + assert dpt.all(dpt.take_along_axis(xc, ind1, axis=1) == vals) + + xc = dpt.copy(x) + vals = dpt.ones(ind2.shape, dtype=x.dtype) + dpt.put_along_axis(xc, ind2, vals, axis=2) + assert dpt.all(dpt.take_along_axis(xc, ind2, axis=2) == vals) + + xc = dpt.copy(x) + vals = dpt.ones(ind2.shape, dtype=x.dtype) + 
dpt.put_along_axis(xc, ind2, dpt.asnumpy(vals), axis=2) + assert dpt.all(dpt.take_along_axis(xc, ind2, axis=2) == vals) + + +def test_put_along_axis_validation(): + # type check on the first argument + with pytest.raises(TypeError): + dpt.put_along_axis(tuple(), list(), list()) + get_queue_or_skip() + n1, n2 = 2, 5 + x = dpt.ones(n1 * n2) + # type check on the second argument + with pytest.raises(TypeError): + dpt.put_along_axis(x, list(), list()) + x_dev = x.sycl_device + info_ = dpt.__array_namespace_info__() + def_dtypes = info_.default_dtypes(device=x_dev) + ind_dt = def_dtypes["indexing"] + ind = dpt.zeros(1, dtype=ind_dt) + vals = dpt.zeros(1, dtype=x.dtype) + # axis validation + with pytest.raises(ValueError): + dpt.put_along_axis(x, ind, vals, axis=1) + # mode validation + with pytest.raises(ValueError): + dpt.put_along_axis(x, ind, vals, axis=0, mode="invalid") + # same array-ranks validation + with pytest.raises(ValueError): + dpt.put_along_axis(dpt.reshape(x, (n1, n2)), ind, vals) + # check compute-follows-data + q2 = dpctl.SyclQueue(x_dev, property="enable_profiling") + ind2 = dpt.zeros(1, dtype=ind_dt, sycl_queue=q2) + with pytest.raises(ExecutionPlacementError): + dpt.put_along_axis(x, ind2, vals) + + def check__extract_impl_validation(fn): x = dpt.ones(10) ind = dpt.ones(10, dtype="?") @@ -1670,7 +1735,11 @@ def check__put_multi_index_validation(fn): with pytest.raises(ValueError): fn(x2, (ind1, ind2), 0, x2) with pytest.raises(TypeError): + # invalid index type fn(x2, (ind1, list()), 0, x2) + with pytest.raises(ValueError): + # invalid mode keyword value + fn(x, inds, 0, vals, mode=100) def test__copy_utils(): From 3b001e827c1771b9eb8c9815bd64da5d7a5fc905 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Fri, 16 Aug 2024 10:14:44 -0700 Subject: [PATCH 5/6] Add test based on application of put_along_axis Use put_along_axis to form 24 permutation matrices representing elements of S4 (group of permutations of 4 elements). 
Verify that every element raised to the 12th power gives the identity. --- dpctl/tests/test_usm_ndarray_indexing.py | 45 ++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index 76773fabf6..71ba738f4a 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -1659,6 +1659,51 @@ def test_put_along_axis_validation(): dpt.put_along_axis(x, ind2, vals) +def test_put_along_axis_application(): + get_queue_or_skip() + info_ = dpt.__array_namespace_info__() + def_dtypes = info_.default_dtypes(device=None) + ind_dt = def_dtypes["indexing"] + all_perms = dpt.asarray( + [ + [0, 1, 2, 3], + [0, 2, 1, 3], + [2, 0, 1, 3], + [2, 1, 0, 3], + [1, 0, 2, 3], + [1, 2, 0, 3], + [0, 1, 3, 2], + [0, 2, 3, 1], + [2, 0, 3, 1], + [2, 1, 3, 0], + [1, 0, 3, 2], + [1, 2, 3, 0], + [0, 3, 1, 2], + [0, 3, 2, 1], + [2, 3, 0, 1], + [2, 3, 1, 0], + [1, 3, 0, 2], + [1, 3, 2, 0], + [3, 0, 1, 2], + [3, 0, 2, 1], + [3, 2, 0, 1], + [3, 2, 1, 0], + [3, 1, 0, 2], + [3, 1, 2, 0], + ], + dtype=ind_dt, + ) + p_mats = dpt.zeros((24, 4, 4), dtype=dpt.int64) + vals = dpt.ones((24, 4, 1), dtype=p_mats.dtype) + # form 24 permutation matrices + dpt.put_along_axis(p_mats, all_perms[..., dpt.newaxis], vals, axis=2) + p2 = p_mats @ p_mats + p4 = p2 @ p2 + p8 = p4 @ p4 + expected = dpt.eye(4, dtype=p_mats.dtype)[dpt.newaxis, ...] 
+ assert dpt.all(p8 @ p4 == expected) + + def check__extract_impl_validation(fn): x = dpt.ones(10) ind = dpt.ones(10, dtype="?") From f81c107be8fb2558a5dd599526d8845d10cc61be Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 16 Aug 2024 12:57:09 -0700 Subject: [PATCH 6/6] Adds `put_along_axis` to dpctl.tensor docs --- .../api_reference/dpctl/tensor.indexing_functions.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/doc_sources/api_reference/dpctl/tensor.indexing_functions.rst b/docs/doc_sources/api_reference/dpctl/tensor.indexing_functions.rst index 09287ba49f..8c752b7036 100644 --- a/docs/doc_sources/api_reference/dpctl/tensor.indexing_functions.rst +++ b/docs/doc_sources/api_reference/dpctl/tensor.indexing_functions.rst @@ -14,5 +14,6 @@ by either integral arrays of indices or boolean mask arrays. extract place put + put_along_axis take take_along_axis