diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index 443d8184a2..5b0c00bbbe 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -563,6 +563,22 @@ cdef class usm_ndarray: @shape.setter def shape(self, new_shape): + """ + Modifies usm_ndarray instance in-place by changing its metadata + about the shape and the strides of the array, or raises + `AttributeError` exception if in-place change is not possible. + + Args: + new_shape: (tuple, int) + New shape. Only non-negative values are supported. + The new shape may not lead to the change in the + number of elements in the array. + + Whether the array can be reshape in-place depends on its + strides. Use :func:`dpctl.tensor.reshape` function which + always succeeds to reshape the array by performing a copy + if necessary. + """ cdef int new_nd = -1 cdef Py_ssize_t nelems = -1 cdef int err = 0 @@ -576,7 +592,11 @@ cdef class usm_ndarray: from ._reshape import reshaped_strides - new_nd = len(new_shape) + try: + new_nd = len(new_shape) + except TypeError: + new_nd = 1 + new_shape = (new_shape,) try: new_shape = tuple(operator.index(dim) for dim in new_shape) except TypeError: diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index 6fac581830..da3c9013e2 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -16,6 +16,7 @@ import ctypes import numbers +from math import prod import numpy as np import pytest @@ -1102,7 +1103,7 @@ def test_setitem_same_dtype(dtype, src_usm_type, dst_usm_type): skip_if_dtype_not_supported(dtype, q) shape = (2, 4, 3) Xnp = ( - np.random.randint(-10, 10, size=np.prod(shape)) + np.random.randint(-10, 10, size=prod(shape)) .astype(dtype) .reshape(shape) ) @@ -1307,6 +1308,10 @@ def relaxed_strides_equal(st1, st2, sh): X = dpt.usm_ndarray(sh_s, dtype="?") X.shape = sh_f assert relaxed_strides_equal(X.strides, cc_strides(sh_f), sh_f) + sz = X.size + X.shape = sz + assert X.shape == (sz,) + assert relaxed_strides_equal(X.strides, (1,), (sz,)) X = dpt.usm_ndarray(sh_s, dtype="u4") with pytest.raises(TypeError): @@ -2077,11 +2082,9 @@ def test_tril(dtype): skip_if_dtype_not_supported(dtype, q) shape = (2, 3, 4, 5, 5) - X = dpt.reshape( - dpt.arange(np.prod(shape), dtype=dtype, sycl_queue=q), shape - ) + X = dpt.reshape(dpt.arange(prod(shape), dtype=dtype, sycl_queue=q), shape) Y = dpt.tril(X) - Xnp = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + Xnp = np.arange(prod(shape), dtype=dtype).reshape(shape) Ynp = np.tril(Xnp) assert Y.dtype == Ynp.dtype assert np.array_equal(Ynp, dpt.asnumpy(Y)) @@ -2093,11 +2096,9 @@ def test_triu(dtype): skip_if_dtype_not_supported(dtype, q) shape = (4, 5) - X = dpt.reshape( - dpt.arange(np.prod(shape), dtype=dtype, sycl_queue=q), shape - ) + X = dpt.reshape(dpt.arange(prod(shape), dtype=dtype, sycl_queue=q), shape) Y = dpt.triu(X, k=1) - Xnp = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + Xnp = np.arange(prod(shape), dtype=dtype).reshape(shape) Ynp = np.triu(Xnp, k=1) assert Y.dtype == Ynp.dtype assert np.array_equal(Ynp, dpt.asnumpy(Y)) @@ -2110,7 +2111,7 @@ def test_tri_usm_type(tri_fn, usm_type): dtype = dpt.uint16 shape = (2, 3, 4, 5, 5) - size = np.prod(shape) + size = prod(shape) X = dpt.reshape( dpt.arange(size, dtype=dtype, usm_type=usm_type, sycl_queue=q), shape ) @@ -2129,11 +2130,11 @@ def test_tril_slice(): q = get_queue_or_skip() shape = (6, 10) - X = dpt.reshape( - dpt.arange(np.prod(shape), dtype="int", sycl_queue=q), shape - )[1:, ::-2] + X = dpt.reshape(dpt.arange(prod(shape), dtype="int", sycl_queue=q), shape)[ + 1:, ::-2 + ] Y = dpt.tril(X) - Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape)[1:, ::-2] + Xnp = np.arange(prod(shape), dtype="int").reshape(shape)[1:, ::-2] Ynp = np.tril(Xnp) assert Y.dtype == Ynp.dtype assert np.array_equal(Ynp, dpt.asnumpy(Y)) @@ -2144,14 +2145,12 @@ def test_triu_permute_dims(): shape = (2, 3, 4, 5) X = dpt.permute_dims( - dpt.reshape( - dpt.arange(np.prod(shape), dtype="int", sycl_queue=q), shape - ), + dpt.reshape(dpt.arange(prod(shape), dtype="int", sycl_queue=q), shape), (3, 2, 1, 0), ) Y = dpt.triu(X) Xnp = np.transpose( - np.arange(np.prod(shape), dtype="int").reshape(shape), (3, 2, 1, 0) + np.arange(prod(shape), dtype="int").reshape(shape), (3, 2, 1, 0) ) Ynp = np.triu(Xnp) assert Y.dtype == Ynp.dtype @@ -2189,12 +2188,12 @@ def test_triu_order_k(order, k): shape = (3, 3) X = dpt.reshape( - dpt.arange(np.prod(shape), dtype="int", sycl_queue=q), + dpt.arange(prod(shape), dtype="int", sycl_queue=q), shape, order=order, ) Y = dpt.triu(X, k=k) - Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape, order=order) + Xnp = np.arange(prod(shape), dtype="int").reshape(shape, order=order) Ynp = np.triu(Xnp, k=k) assert Y.dtype == Ynp.dtype assert X.flags == Y.flags @@ -2210,12 +2209,12 @@ def test_tril_order_k(order, k): pytest.skip("Queue could not be created") shape = (3, 3) X = dpt.reshape( - dpt.arange(np.prod(shape), dtype="int", sycl_queue=q), + dpt.arange(prod(shape), dtype="int", sycl_queue=q), shape, order=order, ) Y = dpt.tril(X, k=k) - Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape, order=order) + Xnp = np.arange(prod(shape), dtype="int").reshape(shape, order=order) Ynp = np.tril(Xnp, k=k) assert Y.dtype == Ynp.dtype assert X.flags == Y.flags