From 3048f3e577ed6ae64ee82d247a5e9a72d6ecb3ae Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 1 Feb 2022 05:56:38 -0600 Subject: [PATCH 1/2] Fixes #729 The generated reshaped_strides routine is meant for non-empty arrays, and so empty ones must be handled differently. Tests added as well. --- dpctl/tensor/_reshape.py | 5 ++++- dpctl/tests/test_usm_ndarray_ctor.py | 26 ++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/dpctl/tensor/_reshape.py b/dpctl/tensor/_reshape.py index 8e04f1b20c..ffa60c3652 100644 --- a/dpctl/tensor/_reshape.py +++ b/dpctl/tensor/_reshape.py @@ -104,7 +104,10 @@ def reshape(X, newshape, order="C"): newshape = [v if d == -1 else d for d in newshape] if X.size != np.prod(newshape): raise ValueError("Can not reshape into {}".format(newshape)) - newsts = reshaped_strides(X.shape, X.strides, newshape, order=order) + if X.size: + newsts = reshaped_strides(X.shape, X.strides, newshape, order=order) + else: + newsts = (1,) * len(newshape) if newsts is None: # must perform a copy flat_res = dpt.usm_ndarray( diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index 410c3128e8..ca2dac41f8 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -816,6 +816,32 @@ def test_reshape(): Y = dpt.reshape(X, X.shape) assert Y.flags == X.flags + A = dpt.usm_ndarray((0,), "i4") + A1 = dpt.reshape(A, (0,)) + assert A1.shape == (0,) + A2 = dpt.reshape( + A, + ( + 2, + 0, + ), + ) + assert A2.shape == ( + 2, + 0, + ) + A3 = dpt.reshape(A, (0, 2)) + assert A3.shape == ( + 0, + 2, + ) + A4 = dpt.reshape(A, (1, 0, 2)) + assert A4.shape == ( + 1, + 0, + 2, + ) + def test_transpose(): n, m = 2, 3 From 479ed600e9c5cbcbd8fc74132d70d3025767382f Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 1 Feb 2022 06:26:50 -0600 Subject: [PATCH 2/2] reshaped_strides is also called from shape setter Special case setting shape for zero-element arrays --- dpctl/tensor/_usmarray.pyx | 17 +++++++++++------ dpctl/tests/test_usm_ndarray_ctor.py | 15 +++++++++++++++ 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index 959164dd3c..3d6d16c5b5 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -428,6 +428,7 @@ cdef class usm_ndarray: cdef int contig_flag = 0 cdef Py_ssize_t *shape_ptr = NULL cdef Py_ssize_t *strides_ptr = NULL + cdef Py_ssize_t size = -1 import operator from ._reshape import reshaped_strides @@ -439,15 +440,19 @@ cdef class usm_ndarray: raise TypeError( "Target shape must be a finite iterable of integers" ) - if not np.prod(new_shape) == shape_to_elem_count(self.nd_, self.shape_): + size = shape_to_elem_count(self.nd_, self.shape_) + if not np.prod(new_shape) == size: raise TypeError( f"Can not reshape array of size {self.size} into {new_shape}" ) - new_strides = reshaped_strides( - self.shape, - self.strides, - new_shape - ) + if size > 0: + new_strides = reshaped_strides( + self.shape, + self.strides, + new_shape + ) + else: + new_strides = (1,) * len(new_shape) if new_strides is None: raise AttributeError( "Incompatible shape for in-place modification. " diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index ca2dac41f8..e8ee79a391 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -713,6 +713,21 @@ def relaxed_strides_equal(st1, st2, sh): X = dpt.usm_ndarray((4, 4), dtype="d")[::2, ::2] with pytest.raises(AttributeError): X.shape = (4,) + X = dpt.usm_ndarray((0,), dtype="i4") + X.shape = (0,) + X.shape = ( + 2, + 0, + ) + X.shape = ( + 0, + 2, + ) + X.shape = ( + 1, + 0, + 1, + ) def test_len():