Skip to content

Commit 9b73305

Browse files
authored
Add implementation of dpnp.ndarray.view method (#2520)
This PR adds implementation of `dpnp.ndarray.view` method. All places in the code with connected TODO comments were updated properly.
1 parent d15d395 commit 9b73305

File tree

7 files changed

+171
-27
lines changed

7 files changed

+171
-27
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010

1111
* Added `--target-cuda[=ARCH]` option to replace the deprecated `--target=cuda`, allowing users to build for CUDA devices with optional architecture selection using [CodePlay oneAPI plug-in](https://developer.codeplay.com/products/oneapi/nvidia/home/) [#2478](https://github.com/IntelPython/dpnp/pull/2478)
1212
* Added several new `pre-commit` rules, including protection against direct commits to master/maintenance branches [#2500](https://github.com/IntelPython/dpnp/pull/2500)
13+
* Added implementation of `dpnp.ndarray.view` method [#2520](https://github.com/IntelPython/dpnp/pull/2520)
1314

1415
### Changed
1516

dpnp/dpnp_array.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
# *****************************************************************************
2626

2727
import dpctl.tensor as dpt
28+
import dpctl.tensor._type_utils as dtu
2829
from dpctl.tensor._numpy_helper import AxisError
2930

3031
import dpnp
@@ -1979,5 +1980,126 @@ def var(
19791980
correction=correction,
19801981
)
19811982

1983+
def view(self, dtype=None, *, type=None):
1984+
"""
1985+
New view of array with the same data.
1986+
1987+
For full documentation refer to :obj:`numpy.ndarray.view`.
1988+
1989+
Parameters
1990+
----------
1991+
dtype : {None, str, dtype object}, optional
1992+
The desired data type of the returned view, e.g. :obj:`dpnp.float32`
1993+
or :obj:`dpnp.int16`. By default, it results in the view having the
1994+
same data type.
1995+
1996+
Default: ``None``.
1997+
1998+
Notes
1999+
-----
2000+
Passing ``None`` for `dtype` is the same as omitting the parameter,
2001+
opposite to NumPy where they have different meaning.
2002+
2003+
``view(some_dtype)`` or ``view(dtype=some_dtype)`` constructs a view of
2004+
the array's memory with a different data type. This can cause a
2005+
reinterpretation of the bytes of memory.
2006+
2007+
Only the last axis has to be contiguous.
19822008
1983-
# 'view'
2009+
Limitations
2010+
-----------
2011+
Parameter `type` is supported only with default value ``None``.
2012+
Otherwise, the function raises ``NotImplementedError`` exception.
2013+
2014+
Examples
2015+
--------
2016+
>>> import dpnp as np
2017+
>>> x = np.ones((4,), dtype=np.float32)
2018+
>>> xv = x.view(dtype=np.int32)
2019+
>>> xv[:] = 0
2020+
>>> xv
2021+
array([0, 0, 0, 0], dtype=int32)
2022+
2023+
However, views that change dtype are totally fine for arrays with a
2024+
contiguous last axis, even if the rest of the axes are not C-contiguous:
2025+
2026+
>>> x = np.arange(2 * 3 * 4, dtype=np.int8).reshape(2, 3, 4)
2027+
>>> x.transpose(1, 0, 2).view(np.int16)
2028+
array([[[ 256, 770],
2029+
[3340, 3854]],
2030+
<BLANKLINE>
2031+
[[1284, 1798],
2032+
[4368, 4882]],
2033+
<BLANKLINE>
2034+
[[2312, 2826],
2035+
[5396, 5910]]], dtype=int16)
2036+
2037+
"""
2038+
2039+
if type is not None:
2040+
raise NotImplementedError(
2041+
"Keyword argument `type` is supported only with "
2042+
f"default value ``None``, but got {type}."
2043+
)
2044+
2045+
old_sh = self.shape
2046+
old_strides = self.strides
2047+
2048+
if dtype is None:
2049+
return dpnp_array(old_sh, buffer=self, strides=old_strides)
2050+
2051+
new_dt = dpnp.dtype(dtype)
2052+
new_dt = dtu._to_device_supported_dtype(new_dt, self.sycl_device)
2053+
2054+
new_itemsz = new_dt.itemsize
2055+
old_itemsz = self.dtype.itemsize
2056+
if new_itemsz == old_itemsz:
2057+
return dpnp_array(
2058+
old_sh, dtype=new_dt, buffer=self, strides=old_strides
2059+
)
2060+
2061+
ndim = self.ndim
2062+
if ndim == 0:
2063+
raise ValueError(
2064+
"Changing the dtype of a 0d array is only supported "
2065+
"if the itemsize is unchanged"
2066+
)
2067+
2068+
# resize on last axis only
2069+
axis = ndim - 1
2070+
if old_sh[axis] != 1 and self.size != 0 and old_strides[axis] != 1:
2071+
raise ValueError(
2072+
"To change to a dtype of a different size, "
2073+
"the last axis must be contiguous"
2074+
)
2075+
2076+
# normalize strides whenever itemsize changes
2077+
if old_itemsz > new_itemsz:
2078+
new_strides = list(
2079+
el * (old_itemsz // new_itemsz) for el in old_strides
2080+
)
2081+
else:
2082+
new_strides = list(
2083+
el // (new_itemsz // old_itemsz) for el in old_strides
2084+
)
2085+
new_strides[axis] = 1
2086+
new_strides = tuple(new_strides)
2087+
2088+
new_dim = old_sh[axis] * old_itemsz
2089+
if new_dim % new_itemsz != 0:
2090+
raise ValueError(
2091+
"When changing to a larger dtype, its size must be a divisor "
2092+
"of the total size in bytes of the last axis of the array"
2093+
)
2094+
2095+
# normalize shape whenever itemsize changes
2096+
new_sh = list(old_sh)
2097+
new_sh[axis] = new_dim // new_itemsz
2098+
new_sh = tuple(new_sh)
2099+
2100+
return dpnp_array(
2101+
new_sh,
2102+
dtype=new_dt,
2103+
buffer=self,
2104+
strides=new_strides,
2105+
)

dpnp/dpnp_utils/dpnp_utils_einsum.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -945,7 +945,6 @@ def _transpose_ex(a, axeses):
945945
stride = sum(a.strides[axis] for axis in axes)
946946
strides.append(stride)
947947

948-
# TODO: replace with a.view() when it is implemented in dpnp
949948
return dpnp_array(
950949
shape,
951950
dtype=a.dtype,
@@ -1151,8 +1150,7 @@ def dpnp_einsum(
11511150
operands[idx] = operands[idx].sum(axis=sum_axes, dtype=result_dtype)
11521151

11531152
if returns_view:
1154-
# TODO: replace with a.view() when it is implemented in dpnp
1155-
operands = [a for a in operands]
1153+
operands = [a.view() for a in operands]
11561154
else:
11571155
operands = [
11581156
dpnp.astype(a, result_dtype, copy=False, casting=casting)

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,15 +1290,8 @@ def _nrm2_last_axis(x):
12901290
"""
12911291

12921292
real_dtype = _real_type(x.dtype)
1293-
# TODO: use dpnp.sum(dpnp.square(dpnp.view(x)), axis=-1, dtype=real_dtype)
1294-
# w/a since dpnp.view() in not implemented yet
1295-
# Сalculate and sum the squares of both real and imaginary parts for
1296-
# compelex array.
1297-
if dpnp.issubdtype(x.dtype, dpnp.complexfloating):
1298-
y = dpnp.abs(x) ** 2
1299-
else:
1300-
y = dpnp.square(x)
1301-
return dpnp.sum(y, axis=-1, dtype=real_dtype)
1293+
x = dpnp.ascontiguousarray(x)
1294+
return dpnp.sum(dpnp.square(x.view(real_dtype)), axis=-1)
13021295

13031296

13041297
def _real_type(dtype, device=None):

dpnp/tests/test_ndarray.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,36 @@ def test_attributes(self):
7474
assert_equal(self.two.itemsize, self.two.dtype.itemsize)
7575

7676

77+
class TestView:
78+
def test_none_dtype(self):
79+
a = numpy.ones((1, 2, 4), dtype=numpy.int32)
80+
ia = dpnp.array(a)
81+
82+
expected = a.view()
83+
result = ia.view()
84+
assert_allclose(result, expected)
85+
86+
expected = a.view() # numpy returns dtype(None) otherwise
87+
result = ia.view(None)
88+
assert_allclose(result, expected)
89+
90+
@pytest.mark.parametrize("dt", [bool, int, float, complex])
91+
def test_python_types(self, dt):
92+
a = numpy.ones((8, 4), dtype=numpy.complex64)
93+
ia = dpnp.array(a)
94+
95+
result = ia.view(dt)
96+
if not has_support_aspect64() and dt in [float, complex]:
97+
dt = result.dtype
98+
expected = a.view(dt)
99+
assert_allclose(result, expected)
100+
101+
def test_type_error(self):
102+
x = dpnp.ones(4, dtype="i4")
103+
with pytest.raises(NotImplementedError):
104+
x.view("i2", type=dpnp.ndarray)
105+
106+
77107
@pytest.mark.parametrize(
78108
"arr",
79109
[

dpnp/tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ def get_strides(xp, a):
2525
return a.strides
2626

2727

28-
@pytest.mark.skip("'dpnp_array' object has no attribute 'view' yet")
2928
class TestView:
3029

3130
@testing.numpy_cupy_array_equal()
@@ -98,9 +97,9 @@ def test_view_relaxed_contiguous(self, xp, dtype):
9897
)
9998
@testing.numpy_cupy_equal()
10099
def test_view_flags_smaller(self, xp, order, shape):
101-
a = xp.zeros(shape, numpy.int32, order)
100+
a = xp.zeros(shape, dtype=numpy.int32, order=order)
102101
b = a.view(numpy.int16)
103-
return b.flags.c_contiguous, b.flags.f_contiguous, b.flags.owndata
102+
return b.flags.c_contiguous, b.flags.f_contiguous # , b.flags.owndata
104103

105104
@pytest.mark.parametrize(
106105
("order", "shape"),
@@ -112,7 +111,7 @@ def test_view_flags_smaller(self, xp, order, shape):
112111
@testing.with_requires("numpy>=1.23")
113112
def test_view_flags_smaller_invalid(self, order, shape):
114113
for xp in (numpy, cupy):
115-
a = xp.zeros(shape, numpy.int32, order)
114+
a = xp.zeros(shape, dtype=numpy.int32, order=order)
116115
with pytest.raises(ValueError):
117116
a.view(numpy.int16)
118117

@@ -121,17 +120,17 @@ def test_view_flags_smaller_invalid(self, order, shape):
121120
[
122121
("C", (6,)),
123122
("C", (3, 10)),
124-
("C", (0,)),
123+
# ("C", (0,)), # dpctl-2119
125124
("C", (1, 6)),
126125
("C", (3, 2)),
127126
],
128127
ids=str,
129128
)
130129
@testing.numpy_cupy_equal()
131130
def test_view_flags_larger(self, xp, order, shape):
132-
a = xp.zeros(shape, numpy.int16, order)
131+
a = xp.zeros(shape, dtype=numpy.int16, order=order)
133132
b = a.view(numpy.int32)
134-
return b.flags.c_contiguous, b.flags.f_contiguous, b.flags.owndata
133+
return b.flags.c_contiguous, b.flags.f_contiguous # , b.flags.owndata
135134

136135
@pytest.mark.parametrize(
137136
("order", "shape"),
@@ -144,7 +143,7 @@ def test_view_flags_larger(self, xp, order, shape):
144143
@testing.with_requires("numpy>=1.23")
145144
def test_view_flags_larger_invalid(self, order, shape):
146145
for xp in (numpy, cupy):
147-
a = xp.zeros(shape, numpy.int16, order)
146+
a = xp.zeros(shape, dtype=numpy.int16, order=order)
148147
with pytest.raises(ValueError):
149148
a.view(numpy.int32)
150149

@@ -161,7 +160,7 @@ def test_view_smaller_dtype_multiple(self, xp):
161160
@testing.numpy_cupy_array_equal()
162161
def test_view_smaller_dtype_multiple2(self, xp):
163162
# x is non-contiguous, and stride[-1] != 0
164-
x = xp.ones((3, 4), xp.int32)[:, :1:2]
163+
x = xp.ones((3, 4), dtype=xp.int32)[:, :1:2]
165164
return x.view(xp.int16)
166165

167166
@testing.with_requires("numpy>=1.23")
@@ -184,7 +183,7 @@ def test_view_non_c_contiguous(self, xp):
184183

185184
@testing.numpy_cupy_array_equal()
186185
def test_view_larger_dtype_zero_sized(self, xp):
187-
x = xp.ones((3, 20), xp.int16)[:0, ::2]
186+
x = xp.ones((3, 20), dtype=xp.int16)[:0, ::2]
188187
return x.view(xp.int32)
189188

190189

@@ -387,7 +386,7 @@ def test_astype_strides_broadcast(self, xp, src_dtype, dst_dtype):
387386
dst = astype_without_warning(src, dst_dtype, order="K")
388387
return get_strides(xp, dst)
389388

390-
@pytest.mark.skip("'dpnp_array' object has no attribute 'view' yet")
389+
@pytest.mark.skip("dpctl-2121")
391390
@testing.numpy_cupy_array_equal()
392391
def test_astype_boolean_view(self, xp):
393392
# See #4354
@@ -454,7 +453,7 @@ def __array_finalize__(self, obj):
454453
self.info = getattr(obj, "info", None)
455454

456455

457-
@pytest.mark.skip("'dpnp_array' object has no attribute 'view' yet")
456+
@pytest.mark.skip("subclass array is not supported")
458457
class TestSubclassArrayView:
459458

460459
def test_view_casting(self):

dpnp/tests/third_party/cupy/testing/_array.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,13 +171,14 @@ def assert_array_equal(
171171
)
172172

173173
if strides_check:
174-
if actual.strides != desired.strides:
174+
strides = tuple(el // desired.itemsize for el in desired.strides)
175+
if actual.strides != strides:
175176
msg = ["Strides are not equal:"]
176177
if err_msg:
177178
msg = [msg[0] + " " + err_msg]
178179
if verbose:
179180
msg.append(" x: {}".format(actual.strides))
180-
msg.append(" y: {}".format(desired.strides))
181+
msg.append(" y: {}".format(strides))
181182
raise AssertionError("\n".join(msg))
182183

183184

0 commit comments

Comments
 (0)