From 0a7ea0c2674362b6be70b0094d1ef4a395d2cea9 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 3 Mar 2023 17:21:59 -0800 Subject: [PATCH] dpt.take and dpt.put changes - Improved conformity to array API standard - Added docstrings --- dpctl/tensor/_indexing_functions.py | 178 +++++++++++++---------- dpctl/tests/test_usm_ndarray_indexing.py | 28 ++-- 2 files changed, 117 insertions(+), 89 deletions(-) diff --git a/dpctl/tensor/_indexing_functions.py b/dpctl/tensor/_indexing_functions.py index 6f19dc3bd4..c312d9e2b9 100644 --- a/dpctl/tensor/_indexing_functions.py +++ b/dpctl/tensor/_indexing_functions.py @@ -27,43 +27,56 @@ def take(x, indices, /, *, axis=None, mode="clip"): + """take(x, indices, axis=None, mode="clip") + + Takes elements from array along a given axis. + + Args: + x: usm_ndarray + The array that elements will be taken from. + indices: usm_ndarray + One-dimensional array of indices. + axis: + The axis over which the values will be selected. + If x is one-dimensional, this argument is optional. + mode: + How out-of-bounds indices will be handled. + "Clip" - clamps indices to (-n <= i < n), then wraps + negative indices. + "Wrap" - wraps both negative and positive indices. + + Returns: + out: usm_ndarray + Array with shape x.shape[:axis] + indices.shape + x.shape[axis + 1:] + filled with elements . + """ if not isinstance(x, dpt.usm_ndarray): raise TypeError( "Expected instance of `dpt.usm_ndarray`, got `{}`.".format(type(x)) ) - if not isinstance(indices, list) and not isinstance(indices, tuple): - indices = (indices,) - - queues_ = [ - x.sycl_queue, - ] - usm_types_ = [ - x.usm_type, - ] - - for i in indices: - if not isinstance(i, dpt.usm_ndarray): - raise TypeError( - "`indices` expected `dpt.usm_ndarray`, got `{}`.".format( - type(i) - ) + if not isinstance(indices, dpt.usm_ndarray): + raise TypeError( + "`indices` expected `dpt.usm_ndarray`, got `{}`.".format( + type(indices) ) - if not np.issubdtype(i.dtype, np.integer): - raise IndexError( - "`indices` expected integer data type, got `{}`".format(i.dtype) + ) + if not np.issubdtype(indices.dtype, np.integer): + raise IndexError( + "`indices` expected integer data type, got `{}`".format( + indices.dtype ) - queues_.append(i.sycl_queue) - usm_types_.append(i.usm_type) - exec_q = dpctl.utils.get_execution_queue(queues_) - if exec_q is None: - raise dpctl.utils.ExecutionPlacementError( - "Can not automatically determine where to allocate the " - "result or performance execution. " - "Use `usm_ndarray.to_device` method to migrate data to " - "be associated with the same queue." ) - res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_) + if indices.ndim != 1: + raise ValueError( + "`indices` expected a 1D array, got `{}`".format(indices.ndim) + ) + exec_q = dpctl.utils.get_execution_queue([x.sycl_queue, indices.sycl_queue]) + if exec_q is None: + raise dpctl.utils.ExecutionPlacementError + res_usm_type = dpctl.utils.get_coerced_usm_type( + [x.usm_type, indices.usm_type] + ) modes = {"clip": 0, "wrap": 1} try: @@ -81,27 +94,47 @@ def take(x, indices, /, *, axis=None, mode="clip"): ) axis = 0 - if len(indices) > 1: - indices = dpt.broadcast_arrays(*indices) if x_ndim > 0: axis = normalize_axis_index(operator.index(axis), x_ndim) - res_shape = ( - x.shape[:axis] + indices[0].shape + x.shape[axis + len(indices) :] - ) + res_shape = x.shape[:axis] + indices.shape + x.shape[axis + 1 :] else: - res_shape = indices[0].shape + if axis != 0: + raise ValueError("`axis` must be 0 for an array of dimension 0.") + res_shape = indices.shape res = dpt.empty( res_shape, dtype=x.dtype, usm_type=res_usm_type, sycl_queue=exec_q ) - hev, _ = ti._take(x, indices, res, axis, mode, sycl_queue=exec_q) + hev, _ = ti._take(x, (indices,), res, axis, mode, sycl_queue=exec_q) hev.wait() return res def put(x, indices, vals, /, *, axis=None, mode="clip"): + """put(x, indices, vals, axis=None, mode="clip") + + Puts values of an array into another array + along a given axis. + + Args: + x: usm_ndarray + The array the values will be put into. + indices: usm_ndarray + One-dimensional array of indices. + vals: + Array of values to be put into `x`. + Must be broadcastable to the shape of `indices`. + axis: + The axis over which the values will be placed. + If x is one-dimensional, this argument is optional. + mode: + How out-of-bounds indices will be handled. + "Clip" - clamps indices to (-axis_size <= i < axis_size), + then wraps negative indices. + "Wrap" - wraps both negative and positive indices. + """ if not isinstance(x, dpt.usm_ndarray): raise TypeError( "Expected instance of `dpt.usm_ndarray`, got `{}`.".format(type(x)) @@ -116,66 +149,61 @@ def put(x, indices, vals, /, *, axis=None, mode="clip"): usm_types_ = [ x.usm_type, ] - - if not isinstance(indices, list) and not isinstance(indices, tuple): - indices = (indices,) - - for i in indices: - if not isinstance(i, dpt.usm_ndarray): - raise TypeError( - "`indices` expected `dpt.usm_ndarray`, got `{}`.".format( - type(i) - ) + if not isinstance(indices, dpt.usm_ndarray): + raise TypeError( + "`indices` expected `dpt.usm_ndarray`, got `{}`.".format( + type(indices) ) - if not np.issubdtype(i.dtype, np.integer): - raise IndexError( - "`indices` expected integer data type, got `{}`".format(i.dtype) + ) + if indices.ndim != 1: + raise ValueError( + "`indices` expected a 1D array, got `{}`".format(indices.ndim) + ) + if not np.issubdtype(indices.dtype, np.integer): + raise IndexError( + "`indices` expected integer data type, got `{}`".format( + indices.dtype ) - queues_.append(i.sycl_queue) - usm_types_.append(i.usm_type) + ) + queues_.append(indices.sycl_queue) + usm_types_.append(indices.usm_type) exec_q = dpctl.utils.get_execution_queue(queues_) if exec_q is None: - raise dpctl.utils.ExecutionPlacementError( - "Can not automatically determine where to allocate the " - "result or performance execution. " - "Use `usm_ndarray.to_device` method to migrate data to " - "be associated with the same queue." - ) - val_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_) - + raise dpctl.utils.ExecutionPlacementError + vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_) modes = {"clip": 0, "wrap": 1} try: mode = modes[mode] except KeyError: - raise ValueError("`mode` must be `wrap`, or `clip`.") + raise ValueError("`mode` must be `clip` or `wrap`.") - # when axis is none, array is treated as 1D - if axis is None: - try: - x = dpt.reshape(x, (x.size,), copy=False) - axis = 0 - except ValueError: - raise ValueError("Cannot create 1D view of input array") - if len(indices) > 1: - indices = dpt.broadcast_arrays(*indices) x_ndim = x.ndim + if axis is None: + if x_ndim > 1: + raise ValueError( + "`axis` cannot be `None` for array of dimension `{}`".format( + x_ndim + ) + ) + axis = 0 + if x_ndim > 0: axis = normalize_axis_index(operator.index(axis), x_ndim) - val_shape = ( - x.shape[:axis] + indices[0].shape + x.shape[axis + len(indices) :] - ) + val_shape = x.shape[:axis] + indices.shape + x.shape[axis + 1 :] else: - val_shape = indices[0].shape + if axis != 0: + raise ValueError("`axis` must be 0 for an array of dimension 0.") + val_shape = indices.shape if not isinstance(vals, dpt.usm_ndarray): vals = dpt.asarray( - vals, dtype=x.dtype, usm_type=val_usm_type, sycl_queue=exec_q + vals, dtype=x.dtype, usm_type=vals_usm_type, sycl_queue=exec_q ) vals = dpt.broadcast_to(vals, val_shape) - hev, _ = ti._put(x, indices, vals, axis, mode, sycl_queue=exec_q) + hev, _ = ti._put(x, (indices,), vals, axis, mode, sycl_queue=exec_q) hev.wait() diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index bcc1fdbb60..7201357c7d 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -542,11 +542,11 @@ def test_put_0d_val(data_dt): x = dpt.arange(5, dtype=data_dt, sycl_queue=q) ind = dpt.asarray([0], dtype=np.intp, sycl_queue=q) - x[ind] = 2 + val = dpt.asarray(2, dtype=x.dtype, sycl_queue=q) + x[ind] = val assert_array_equal(np.asarray(2, dtype=data_dt), dpt.asnumpy(x[0])) x = dpt.asarray(5, dtype=data_dt, sycl_queue=q) - val = 2 dpt.put(x, ind, val) assert_array_equal(np.asarray(2, dtype=data_dt), dpt.asnumpy(x)) @@ -592,13 +592,13 @@ def test_put_0d_data(data_dt): "ind_dt", _all_int_dtypes, ) -def test_take_0d_ind(ind_dt): +def test_indexing_0d_ind(ind_dt): q = get_queue_or_skip() x = dpt.arange(5, dtype="i4", sycl_queue=q) ind = dpt.asarray(3, dtype=ind_dt, sycl_queue=q) - y = dpt.take(x, ind) + y = x[ind] assert dpt.asnumpy(x[3]) == dpt.asnumpy(y) @@ -613,7 +613,7 @@ def test_put_0d_ind(ind_dt): ind = dpt.asarray(3, dtype=ind_dt, sycl_queue=q) val = dpt.asarray(5, dtype=x.dtype, sycl_queue=q) - dpt.put(x, ind, val, axis=0) + x[ind] = val assert dpt.asnumpy(x[3]) == dpt.asnumpy(val) @@ -684,10 +684,6 @@ def test_take_strided(data_dt, order): np.take(xs_np, ind_np, axis=1), dpt.asnumpy(dpt.take(xs, ind, axis=1)), ) - assert_array_equal( - xs_np[ind_np, ind_np], - dpt.asnumpy(dpt.take(xs, [ind, ind], axis=0)), - ) @pytest.mark.parametrize( @@ -751,7 +747,7 @@ def test_take_strided_indices(ind_dt, order): inds_np = ind_np[s, ::sgn] assert_array_equal( np.take(x_np, inds_np, axis=0), - dpt.asnumpy(dpt.take(x, inds, axis=0)), + dpt.asnumpy(x[inds]), ) @@ -828,7 +824,7 @@ def test_put_strided_destination(data_dt, order): x_np1[ind_np, ind_np] = val_np x1 = dpt.copy(xs) - dpt.put(x1, [ind, ind], val, axis=0) + x1[ind, ind] = val assert_array_equal(x_np1, dpt.asnumpy(x1)) @@ -887,7 +883,7 @@ def test_put_strided_indices(ind_dt, order): inds_np = ind_np[s, ::sgn] x_copy = dpt.copy(x) - dpt.put(x_copy, inds, val, axis=0) + x_copy[inds] = val x_np_copy = x_np.copy() x_np_copy[inds_np] = val_np @@ -899,7 +895,7 @@ def test_take_arg_validation(): q = get_queue_or_skip() x = dpt.arange(4, dtype="i4", sycl_queue=q) - ind0 = dpt.arange(2, dtype=np.intp, sycl_queue=q) + ind0 = dpt.arange(4, dtype=np.intp, sycl_queue=q) ind1 = dpt.arange(2.0, dtype="f", sycl_queue=q) with pytest.raises(TypeError): @@ -919,13 +915,15 @@ def test_take_arg_validation(): dpt.take(x, ind0, mode=0) with pytest.raises(ValueError): dpt.take(dpt.reshape(x, (2, 2)), ind0, axis=None) + with pytest.raises(ValueError): + dpt.take(x, dpt.reshape(ind0, (2, 2))) def test_put_arg_validation(): q = get_queue_or_skip() x = dpt.arange(4, dtype="i4", sycl_queue=q) - ind0 = dpt.arange(2, dtype=np.intp, sycl_queue=q) + ind0 = dpt.arange(4, dtype=np.intp, sycl_queue=q) ind1 = dpt.arange(2.0, dtype="f", sycl_queue=q) val = dpt.asarray(2, x.dtype, sycl_queue=q) @@ -946,6 +944,8 @@ def test_put_arg_validation(): with pytest.raises(ValueError): dpt.put(x, ind0, val, mode=0) + with pytest.raises(ValueError): + dpt.put(x, dpt.reshape(ind0, (2, 2)), val) def test_advanced_indexing_compute_follows_data():