From f0f23ac310eb291153e0e85689c4f455a09ef82c Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Sun, 13 Mar 2022 23:54:07 +0300 Subject: [PATCH 1/2] Add squeeze func --- dpctl/tensor/__init__.py | 7 ++++- dpctl/tensor/_manipulation_functions.py | 34 +++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index e57cdc1850..3e3cddda52 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -25,7 +25,11 @@ from dpctl.tensor._ctors import asarray, empty from dpctl.tensor._device import Device from dpctl.tensor._dlpack import from_dlpack -from dpctl.tensor._manipulation_functions import expand_dims, permute_dims +from dpctl.tensor._manipulation_functions import ( + expand_dims, + permute_dims, + squeeze, +) from dpctl.tensor._reshape import reshape from dpctl.tensor._usmarray import usm_ndarray @@ -39,6 +43,7 @@ "reshape", "permute_dims", "expand_dims", + "squeeze", "from_numpy", "to_numpy", "asnumpy", diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index a099ddc031..67b3bb1aca 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -68,3 +68,37 @@ def expand_dims(X, axes): shape = tuple(1 if ax in axes else next(shape_it) for ax in range(out_ndim)) return dpt.reshape(X, shape) + + +def squeeze(X, axes=None): + """ + squeeze(X: usm_ndarray, axes: int or tuple or list) -> usm_ndarray + + Removes singleton dimensions (axes) from X; returns a view, if possible, + a copy otherwise, but with all or a subset of the dimensions + of length 1 removed. + """ + if not isinstance(X, dpt.usm_ndarray): + raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") + X_shape = X.shape + if axes is not None: + if not isinstance(axes, (tuple, list)): + axes = (axes,) + axes = normalize_axis_tuple(axes, X.ndim if X.ndim != 0 else X.ndim + 1) + new_shape = [] + for i, x in enumerate(X_shape): + if i not in axes: + new_shape.append(x) + else: + if x != 1: + raise ValueError( + "Cannot select an axis to squeeze out " + "which has size not equal to one." + ) + new_shape = tuple(new_shape) + else: + new_shape = tuple(axis for axis in X_shape if axis != 1) + if new_shape == X.shape: + return X + else: + return dpt.reshape(X, new_shape) From 31a960ba6fa19aff6d65b48fb8c7d740c434d2d2 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Sun, 13 Mar 2022 23:55:58 +0300 Subject: [PATCH 2/2] Add tests for squeeze func --- dpctl/tests/test_usm_ndarray_manipulation.py | 99 ++++++++++++++++++++ 1 file changed, 99 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index 37af01f9ef..cc63e3d9a4 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -166,3 +166,102 @@ def test_expand_dims_incorrect_tuple(): pytest.raises(np.AxisError, dpt.expand_dims, X, (0, 5)) pytest.raises(ValueError, dpt.expand_dims, X, (1, 1)) + + +def test_squeeze_incorrect_type(): + X_list = list([1, 2, 3, 4, 5]) + X_tuple = tuple(X_list) + Xnp = np.array(X_list) + + pytest.raises(TypeError, dpt.permute_dims, X_list, 1) + pytest.raises(TypeError, dpt.permute_dims, X_tuple, 1) + pytest.raises(TypeError, dpt.permute_dims, Xnp, 1) + + +def test_squeeze_0d(): + try: + q = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Queue could not be created") + + Xnp = np.array(1) + X = dpt.asarray(Xnp, sycl_queue=q) + Y = dpt.squeeze(X) + Ynp = Xnp.squeeze() + assert_array_equal(Ynp, dpt.asnumpy(Y)) + + Y = dpt.squeeze(X, 0) + Ynp = Xnp.squeeze(0) + assert_array_equal(Ynp, dpt.asnumpy(Y)) + + Y = dpt.squeeze(X, (0)) + Ynp = Xnp.squeeze((0)) + assert_array_equal(Ynp, dpt.asnumpy(Y)) + + Y = dpt.squeeze(X, -1) + Ynp = Xnp.squeeze(-1) + assert_array_equal(Ynp, dpt.asnumpy(Y)) + + pytest.raises(np.AxisError, dpt.squeeze, X, 1) + pytest.raises(np.AxisError, dpt.squeeze, X, -2) + pytest.raises(np.AxisError, dpt.squeeze, X, (1)) + pytest.raises(np.AxisError, dpt.squeeze, X, (-2)) + pytest.raises(ValueError, dpt.squeeze, X, (0, 0)) + + +@pytest.mark.parametrize( + "shapes", + [ + (0), + (1), + (1, 2), + (2, 1), + (1, 1), + (2, 2), + (1, 0), + (0, 1), + (1, 2, 1), + (2, 1, 2), + (2, 2, 2), + (1, 1, 1), + (1, 0, 1), + (0, 1, 0), + ], +) +def test_squeeze_without_axes(shapes): + try: + q = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Queue could not be created") + + Xnp = np.empty(shapes) + X = dpt.asarray(Xnp, sycl_queue=q) + Y = dpt.squeeze(X) + Ynp = Xnp.squeeze() + assert_array_equal(Ynp, dpt.asnumpy(Y)) + + +@pytest.mark.parametrize("axes", [0, 2, (0), (2), (0, 2)]) +def test_squeeze_axes_arg(axes): + try: + q = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Queue could not be created") + + Xnp = np.array([[[1], [2], [3]]]) + X = dpt.asarray(Xnp, sycl_queue=q) + Y = dpt.squeeze(X, axes) + Ynp = Xnp.squeeze(axes) + assert_array_equal(Ynp, dpt.asnumpy(Y)) + + +@pytest.mark.parametrize("axes", [1, -2, (1), (-2), (0, 0), (1, 1)]) +def test_squeeze_axes_arg_error(axes): + try: + q = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Queue could not be created") + + Xnp = np.array([[[1], [2], [3]]]) + X = dpt.asarray(Xnp, sycl_queue=q) + pytest.raises(ValueError, dpt.squeeze, X, axes)