From f8e84b83a6760f6552c1d44b377369e34dd5b587 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Wed, 9 Mar 2022 17:15:10 +0300 Subject: [PATCH 1/3] Add expand_dims func --- dpctl/tensor/__init__.py | 3 ++- dpctl/tensor/_manipulation_functions.py | 23 +++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index 418abbfe27..e57cdc1850 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -25,7 +25,7 @@ from dpctl.tensor._ctors import asarray, empty from dpctl.tensor._device import Device from dpctl.tensor._dlpack import from_dlpack -from dpctl.tensor._manipulation_functions import permute_dims +from dpctl.tensor._manipulation_functions import expand_dims, permute_dims from dpctl.tensor._reshape import reshape from dpctl.tensor._usmarray import usm_ndarray @@ -38,6 +38,7 @@ "empty", "reshape", "permute_dims", + "expand_dims", "from_numpy", "to_numpy", "asnumpy", diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index 91effa01ce..a4c170a992 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -16,6 +16,7 @@ import numpy as np +from numpy.core.numeric import normalize_axis_tuple import dpctl.tensor as dpt @@ -61,3 +62,25 @@ def permute_dims(X, axes): strides=newstrides, offset=X.__sycl_usm_array_interface__.get("offset", 0), ) + + +def expand_dims(X, axes): + """ + expand_dims(X: usm_ndarray, axes: int or tuple or list) -> usm_ndarray + + Expands the shape of an array by inserting a new axis (dimension) + of size one at the position specified by axes; returns a view, if possible, + a copy otherwise with the number of dimensions increased. + """ + if not isinstance(X, dpt.usm_ndarray): + raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") + if not isinstance(axes, (tuple, list)): + axes = (axes,) + + out_ndim = len(axes) + X.ndim + axes = normalize_axis_tuple(axes, out_ndim) + + shape_it = iter(X.shape) + shape = [1 if ax in axes else next(shape_it) for ax in range(out_ndim)] + + return dpt.reshape(X, shape) From 7b514579d65f9911a9206f102c5d2e7851a190c7 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Wed, 9 Mar 2022 18:40:56 +0300 Subject: [PATCH 2/3] Add tests for expand_dims func --- dpctl/tests/test_usm_ndarray_manipulation.py | 76 ++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index 39be876a7b..bcbe093b23 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -90,3 +90,79 @@ def test_permute_dims_2d_3d(shapes): Y = dpt.permute_dims(X, (2, 0, 1)) Ynp = np.transpose(Xnp, (2, 0, 1)) assert_array_equal(Ynp, dpt.asnumpy(Y)) + + +def test_expand_dims_incorrect_type(): + X_list = list([1, 2, 3, 4, 5]) + X_tuple = tuple(X_list) + Xnp = np.array(X_list) + + pytest.raises(TypeError, dpt.permute_dims, X_list, 1) + pytest.raises(TypeError, dpt.permute_dims, X_tuple, 1) + pytest.raises(TypeError, dpt.permute_dims, Xnp, 1) + + +def test_expand_dims_0d(): + try: + q = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Queue could not be created") + + Xnp = np.array(1, dtype="int64") + X = dpt.asarray(Xnp, sycl_queue=q) + Y = dpt.expand_dims(X, 0) + Ynp = np.expand_dims(Xnp, 0) + assert_array_equal(Ynp, dpt.asnumpy(Y)) + + Y = dpt.expand_dims(X, -1) + Ynp = np.expand_dims(Xnp, -1) + assert_array_equal(Ynp, dpt.asnumpy(Y)) + + pytest.raises(np.AxisError, dpt.expand_dims, X, 1) + pytest.raises(np.AxisError, dpt.expand_dims, X, -2) + + +@pytest.mark.parametrize("shapes", [(3,), (3, 3), (3, 3, 3)]) +def test_expand_dims_1d_3d(shapes): + try: + q = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Queue could not be created") + + Xnp_size = np.prod(shapes) + + Xnp = np.random.randint(0, 2, size=Xnp_size, dtype="int64").reshape(shapes) + X = dpt.asarray(Xnp, sycl_queue=q) + shape_len = len(shapes) + for axis in range(-shape_len - 1, shape_len): + Y = dpt.expand_dims(X, axis) + Ynp = np.expand_dims(Xnp, axis) + assert_array_equal(Ynp, dpt.asnumpy(Y)) + + pytest.raises(np.AxisError, dpt.expand_dims, X, shape_len + 1) + pytest.raises(np.AxisError, dpt.expand_dims, X, -shape_len - 2) + + +@pytest.mark.parametrize( + "axes", [(0, 1, 2), (0, -1, -2), (0, 3, 5), (0, -3, -5)] +) +def test_expand_dims_tuple(axes): + try: + q = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Queue could not be created") + + Xnp = np.empty((3, 3, 3)) + X = dpt.asarray(Xnp, sycl_queue=q) + Y = dpt.expand_dims(X, axes) + Ynp = np.expand_dims(Xnp, axes) + assert_array_equal(Ynp, dpt.asnumpy(Y)) + + +def test_expand_dims_incorrect_tuple(): + + X = dpt.empty((3, 3, 3), dtype="i4") + pytest.raises(np.AxisError, dpt.expand_dims, X, (0, -6)) + pytest.raises(np.AxisError, dpt.expand_dims, X, (0, 5)) + + pytest.raises(ValueError, dpt.expand_dims, X, (1, 1)) From 6d84d64b741f3337b356c1da9f51675c7dfa44d6 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 9 Mar 2022 11:59:52 -0600 Subject: [PATCH 3/3] Update dpctl/tensor/_manipulation_functions.py --- dpctl/tensor/_manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index a4c170a992..e62ebce236 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -81,6 +81,6 @@ def expand_dims(X, axes): axes = normalize_axis_tuple(axes, out_ndim) shape_it = iter(X.shape) - shape = [1 if ax in axes else next(shape_it) for ax in range(out_ndim)] + shape = tuple(1 if ax in axes else next(shape_it) for ax in range(out_ndim)) return dpt.reshape(X, shape)