diff --git a/CHANGELOG.md b/CHANGELOG.md index f1677aa78f7c..a8d4adca875a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Added `--target-cuda[=ARCH]` option to replace the deprecated `--target=cuda`, allowing users to build for CUDA devices with optional architecture selection using [CodePlay oneAPI plug-in](https://developer.codeplay.com/products/oneapi/nvidia/home/) [#2478](https://github.com/IntelPython/dpnp/pull/2478) * Added several new `pre-commit` rules, including protection against direct commits to master/maintenance branches [#2500](https://github.com/IntelPython/dpnp/pull/2500) +* Added implementation of `dpnp.ndarray.view` method [#2520](https://github.com/IntelPython/dpnp/pull/2520) ### Changed diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index 48df4acf3b81..f47383619dc2 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -25,6 +25,7 @@ # ***************************************************************************** import dpctl.tensor as dpt +import dpctl.tensor._type_utils as dtu from dpctl.tensor._numpy_helper import AxisError import dpnp @@ -1979,5 +1980,126 @@ def var( correction=correction, ) + def view(self, dtype=None, *, type=None): + """ + New view of array with the same data. + + For full documentation refer to :obj:`numpy.ndarray.view`. + + Parameters + ---------- + dtype : {None, str, dtype object}, optional + The desired data type of the returned view, e.g. :obj:`dpnp.float32` + or :obj:`dpnp.int16`. By default, it results in the view having the + same data type. + + Default: ``None``. + + Notes + ----- + Passing ``None`` for `dtype` is the same as omitting the parameter, + opposite to NumPy where they have different meaning. + + ``view(some_dtype)`` or ``view(dtype=some_dtype)`` constructs a view of + the array's memory with a different data type. This can cause a + reinterpretation of the bytes of memory. + + Only the last axis has to be contiguous. -# 'view' + Limitations + ----------- + Parameter `type` is supported only with default value ``None``. + Otherwise, the function raises ``NotImplementedError`` exception. + + Examples + -------- + >>> import dpnp as np + >>> x = np.ones((4,), dtype=np.float32) + >>> xv = x.view(dtype=np.int32) + >>> xv[:] = 0 + >>> xv + array([0, 0, 0, 0], dtype=int32) + + However, views that change dtype are totally fine for arrays with a + contiguous last axis, even if the rest of the axes are not C-contiguous: + + >>> x = np.arange(2 * 3 * 4, dtype=np.int8).reshape(2, 3, 4) + >>> x.transpose(1, 0, 2).view(np.int16) + array([[[ 256, 770], + [3340, 3854]], + + [[1284, 1798], + [4368, 4882]], + + [[2312, 2826], + [5396, 5910]]], dtype=int16) + + """ + + if type is not None: + raise NotImplementedError( + "Keyword argument `type` is supported only with " + f"default value ``None``, but got {type}." + ) + + old_sh = self.shape + old_strides = self.strides + + if dtype is None: + return dpnp_array(old_sh, buffer=self, strides=old_strides) + + new_dt = dpnp.dtype(dtype) + new_dt = dtu._to_device_supported_dtype(new_dt, self.sycl_device) + + new_itemsz = new_dt.itemsize + old_itemsz = self.dtype.itemsize + if new_itemsz == old_itemsz: + return dpnp_array( + old_sh, dtype=new_dt, buffer=self, strides=old_strides + ) + + ndim = self.ndim + if ndim == 0: + raise ValueError( + "Changing the dtype of a 0d array is only supported " + "if the itemsize is unchanged" + ) + + # resize on last axis only + axis = ndim - 1 + if old_sh[axis] != 1 and self.size != 0 and old_strides[axis] != 1: + raise ValueError( + "To change to a dtype of a different size, " + "the last axis must be contiguous" + ) + + # normalize strides whenever itemsize changes + if old_itemsz > new_itemsz: + new_strides = list( + el * (old_itemsz // new_itemsz) for el in old_strides + ) + else: + new_strides = list( + el // (new_itemsz // old_itemsz) for el in old_strides + ) + new_strides[axis] = 1 + new_strides = tuple(new_strides) + + new_dim = old_sh[axis] * old_itemsz + if new_dim % new_itemsz != 0: + raise ValueError( + "When changing to a larger dtype, its size must be a divisor " + "of the total size in bytes of the last axis of the array" + ) + + # normalize shape whenever itemsize changes + new_sh = list(old_sh) + new_sh[axis] = new_dim // new_itemsz + new_sh = tuple(new_sh) + + return dpnp_array( + new_sh, + dtype=new_dt, + buffer=self, + strides=new_strides, + ) diff --git a/dpnp/dpnp_utils/dpnp_utils_einsum.py b/dpnp/dpnp_utils/dpnp_utils_einsum.py index 322d7dd2c148..12baacac3dc1 100644 --- a/dpnp/dpnp_utils/dpnp_utils_einsum.py +++ b/dpnp/dpnp_utils/dpnp_utils_einsum.py @@ -945,7 +945,6 @@ def _transpose_ex(a, axeses): stride = sum(a.strides[axis] for axis in axes) strides.append(stride) - # TODO: replace with a.view() when it is implemented in dpnp return dpnp_array( shape, dtype=a.dtype, @@ -1151,8 +1150,7 @@ def dpnp_einsum( operands[idx] = operands[idx].sum(axis=sum_axes, dtype=result_dtype) if returns_view: - # TODO: replace with a.view() when it is implemented in dpnp - operands = [a for a in operands] + operands = [a.view() for a in operands] else: operands = [ dpnp.astype(a, result_dtype, copy=False, casting=casting) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index b694b730c97c..51cebb2815bb 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -1290,15 +1290,8 @@ def _nrm2_last_axis(x): """ real_dtype = _real_type(x.dtype) - # TODO: use dpnp.sum(dpnp.square(dpnp.view(x)), axis=-1, dtype=real_dtype) - # w/a since dpnp.view() in not implemented yet - # Сalculate and sum the squares of both real and imaginary parts for - # compelex array. - if dpnp.issubdtype(x.dtype, dpnp.complexfloating): - y = dpnp.abs(x) ** 2 - else: - y = dpnp.square(x) - return dpnp.sum(y, axis=-1, dtype=real_dtype) + x = dpnp.ascontiguousarray(x) + return dpnp.sum(dpnp.square(x.view(real_dtype)), axis=-1) def _real_type(dtype, device=None): diff --git a/dpnp/tests/test_ndarray.py b/dpnp/tests/test_ndarray.py index 0a4fea422fc9..eaccf689a795 100644 --- a/dpnp/tests/test_ndarray.py +++ b/dpnp/tests/test_ndarray.py @@ -74,6 +74,36 @@ def test_attributes(self): assert_equal(self.two.itemsize, self.two.dtype.itemsize) +class TestView: + def test_none_dtype(self): + a = numpy.ones((1, 2, 4), dtype=numpy.int32) + ia = dpnp.array(a) + + expected = a.view() + result = ia.view() + assert_allclose(result, expected) + + expected = a.view() # numpy returns dtype(None) otherwise + result = ia.view(None) + assert_allclose(result, expected) + + @pytest.mark.parametrize("dt", [bool, int, float, complex]) + def test_python_types(self, dt): + a = numpy.ones((8, 4), dtype=numpy.complex64) + ia = dpnp.array(a) + + result = ia.view(dt) + if not has_support_aspect64() and dt in [float, complex]: + dt = result.dtype + expected = a.view(dt) + assert_allclose(result, expected) + + def test_type_error(self): + x = dpnp.ones(4, dtype="i4") + with pytest.raises(NotImplementedError): + x.view("i2", type=dpnp.ndarray) + + @pytest.mark.parametrize( "arr", [ diff --git a/dpnp/tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py b/dpnp/tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py index eaf01d1b345c..25d30b69607c 100644 --- a/dpnp/tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py +++ b/dpnp/tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py @@ -25,7 +25,6 @@ def get_strides(xp, a): return a.strides -@pytest.mark.skip("'dpnp_array' object has no attribute 'view' yet") class TestView: @testing.numpy_cupy_array_equal() @@ -98,9 +97,9 @@ def test_view_relaxed_contiguous(self, xp, dtype): ) @testing.numpy_cupy_equal() def test_view_flags_smaller(self, xp, order, shape): - a = xp.zeros(shape, numpy.int32, order) + a = xp.zeros(shape, dtype=numpy.int32, order=order) b = a.view(numpy.int16) - return b.flags.c_contiguous, b.flags.f_contiguous, b.flags.owndata + return b.flags.c_contiguous, b.flags.f_contiguous # , b.flags.owndata @pytest.mark.parametrize( ("order", "shape"), @@ -112,7 +111,7 @@ def test_view_flags_smaller(self, xp, order, shape): @testing.with_requires("numpy>=1.23") def test_view_flags_smaller_invalid(self, order, shape): for xp in (numpy, cupy): - a = xp.zeros(shape, numpy.int32, order) + a = xp.zeros(shape, dtype=numpy.int32, order=order) with pytest.raises(ValueError): a.view(numpy.int16) @@ -121,7 +120,7 @@ def test_view_flags_smaller_invalid(self, order, shape): [ ("C", (6,)), ("C", (3, 10)), - ("C", (0,)), + # ("C", (0,)), # dpctl-2119 ("C", (1, 6)), ("C", (3, 2)), ], @@ -129,9 +128,9 @@ def test_view_flags_smaller_invalid(self, order, shape): ) @testing.numpy_cupy_equal() def test_view_flags_larger(self, xp, order, shape): - a = xp.zeros(shape, numpy.int16, order) + a = xp.zeros(shape, dtype=numpy.int16, order=order) b = a.view(numpy.int32) - return b.flags.c_contiguous, b.flags.f_contiguous, b.flags.owndata + return b.flags.c_contiguous, b.flags.f_contiguous # , b.flags.owndata @pytest.mark.parametrize( ("order", "shape"), @@ -144,7 +143,7 @@ def test_view_flags_larger(self, xp, order, shape): @testing.with_requires("numpy>=1.23") def test_view_flags_larger_invalid(self, order, shape): for xp in (numpy, cupy): - a = xp.zeros(shape, numpy.int16, order) + a = xp.zeros(shape, dtype=numpy.int16, order=order) with pytest.raises(ValueError): a.view(numpy.int32) @@ -161,7 +160,7 @@ def test_view_smaller_dtype_multiple(self, xp): @testing.numpy_cupy_array_equal() def test_view_smaller_dtype_multiple2(self, xp): # x is non-contiguous, and stride[-1] != 0 - x = xp.ones((3, 4), xp.int32)[:, :1:2] + x = xp.ones((3, 4), dtype=xp.int32)[:, :1:2] return x.view(xp.int16) @testing.with_requires("numpy>=1.23") @@ -184,7 +183,7 @@ def test_view_non_c_contiguous(self, xp): @testing.numpy_cupy_array_equal() def test_view_larger_dtype_zero_sized(self, xp): - x = xp.ones((3, 20), xp.int16)[:0, ::2] + x = xp.ones((3, 20), dtype=xp.int16)[:0, ::2] return x.view(xp.int32) @@ -387,7 +386,7 @@ def test_astype_strides_broadcast(self, xp, src_dtype, dst_dtype): dst = astype_without_warning(src, dst_dtype, order="K") return get_strides(xp, dst) - @pytest.mark.skip("'dpnp_array' object has no attribute 'view' yet") + @pytest.mark.skip("dpctl-2121") @testing.numpy_cupy_array_equal() def test_astype_boolean_view(self, xp): # See #4354 @@ -454,7 +453,7 @@ def __array_finalize__(self, obj): self.info = getattr(obj, "info", None) -@pytest.mark.skip("'dpnp_array' object has no attribute 'view' yet") +@pytest.mark.skip("subclass array is not supported") class TestSubclassArrayView: def test_view_casting(self): diff --git a/dpnp/tests/third_party/cupy/testing/_array.py b/dpnp/tests/third_party/cupy/testing/_array.py index beecaac16e58..f2f8d455dd8e 100644 --- a/dpnp/tests/third_party/cupy/testing/_array.py +++ b/dpnp/tests/third_party/cupy/testing/_array.py @@ -171,13 +171,14 @@ def assert_array_equal( ) if strides_check: - if actual.strides != desired.strides: + strides = tuple(el // desired.itemsize for el in desired.strides) + if actual.strides != strides: msg = ["Strides are not equal:"] if err_msg: msg = [msg[0] + " " + err_msg] if verbose: msg.append(" x: {}".format(actual.strides)) - msg.append(" y: {}".format(desired.strides)) + msg.append(" y: {}".format(strides)) raise AssertionError("\n".join(msg))