Skip to content

Commit 02bf7d2

Browse files
Merge pull request #788 from IntelPython/expand_dims
Implementation of expand_dims function
2 parents a204eb2 + 6d84d64 commit 02bf7d2

File tree

3 files changed

+101
-1
lines changed

3 files changed

+101
-1
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from dpctl.tensor._ctors import asarray, empty
2626
from dpctl.tensor._device import Device
2727
from dpctl.tensor._dlpack import from_dlpack
28-
from dpctl.tensor._manipulation_functions import permute_dims
28+
from dpctl.tensor._manipulation_functions import expand_dims, permute_dims
2929
from dpctl.tensor._reshape import reshape
3030
from dpctl.tensor._usmarray import usm_ndarray
3131

@@ -38,6 +38,7 @@
3838
"empty",
3939
"reshape",
4040
"permute_dims",
41+
"expand_dims",
4142
"from_numpy",
4243
"to_numpy",
4344
"asnumpy",

dpctl/tensor/_manipulation_functions.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717

1818
import numpy as np
19+
from numpy.core.numeric import normalize_axis_tuple
1920

2021
import dpctl.tensor as dpt
2122

@@ -61,3 +62,25 @@ def permute_dims(X, axes):
6162
strides=newstrides,
6263
offset=X.__sycl_usm_array_interface__.get("offset", 0),
6364
)
65+
66+
67+
def expand_dims(X, axes):
68+
"""
69+
expand_dims(X: usm_ndarray, axes: int or tuple or list) -> usm_ndarray
70+
71+
Expands the shape of an array by inserting a new axis (dimension)
72+
of size one at the position specified by axes; returns a view, if possible,
73+
a copy otherwise with the number of dimensions increased.
74+
"""
75+
if not isinstance(X, dpt.usm_ndarray):
76+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
77+
if not isinstance(axes, (tuple, list)):
78+
axes = (axes,)
79+
80+
out_ndim = len(axes) + X.ndim
81+
axes = normalize_axis_tuple(axes, out_ndim)
82+
83+
shape_it = iter(X.shape)
84+
shape = tuple(1 if ax in axes else next(shape_it) for ax in range(out_ndim))
85+
86+
return dpt.reshape(X, shape)

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,79 @@ def test_permute_dims_2d_3d(shapes):
9090
Y = dpt.permute_dims(X, (2, 0, 1))
9191
Ynp = np.transpose(Xnp, (2, 0, 1))
9292
assert_array_equal(Ynp, dpt.asnumpy(Y))
93+
94+
95+
def test_expand_dims_incorrect_type():
96+
X_list = list([1, 2, 3, 4, 5])
97+
X_tuple = tuple(X_list)
98+
Xnp = np.array(X_list)
99+
100+
pytest.raises(TypeError, dpt.permute_dims, X_list, 1)
101+
pytest.raises(TypeError, dpt.permute_dims, X_tuple, 1)
102+
pytest.raises(TypeError, dpt.permute_dims, Xnp, 1)
103+
104+
105+
def test_expand_dims_0d():
106+
try:
107+
q = dpctl.SyclQueue()
108+
except dpctl.SyclQueueCreationError:
109+
pytest.skip("Queue could not be created")
110+
111+
Xnp = np.array(1, dtype="int64")
112+
X = dpt.asarray(Xnp, sycl_queue=q)
113+
Y = dpt.expand_dims(X, 0)
114+
Ynp = np.expand_dims(Xnp, 0)
115+
assert_array_equal(Ynp, dpt.asnumpy(Y))
116+
117+
Y = dpt.expand_dims(X, -1)
118+
Ynp = np.expand_dims(Xnp, -1)
119+
assert_array_equal(Ynp, dpt.asnumpy(Y))
120+
121+
pytest.raises(np.AxisError, dpt.expand_dims, X, 1)
122+
pytest.raises(np.AxisError, dpt.expand_dims, X, -2)
123+
124+
125+
@pytest.mark.parametrize("shapes", [(3,), (3, 3), (3, 3, 3)])
126+
def test_expand_dims_1d_3d(shapes):
127+
try:
128+
q = dpctl.SyclQueue()
129+
except dpctl.SyclQueueCreationError:
130+
pytest.skip("Queue could not be created")
131+
132+
Xnp_size = np.prod(shapes)
133+
134+
Xnp = np.random.randint(0, 2, size=Xnp_size, dtype="int64").reshape(shapes)
135+
X = dpt.asarray(Xnp, sycl_queue=q)
136+
shape_len = len(shapes)
137+
for axis in range(-shape_len - 1, shape_len):
138+
Y = dpt.expand_dims(X, axis)
139+
Ynp = np.expand_dims(Xnp, axis)
140+
assert_array_equal(Ynp, dpt.asnumpy(Y))
141+
142+
pytest.raises(np.AxisError, dpt.expand_dims, X, shape_len + 1)
143+
pytest.raises(np.AxisError, dpt.expand_dims, X, -shape_len - 2)
144+
145+
146+
@pytest.mark.parametrize(
147+
"axes", [(0, 1, 2), (0, -1, -2), (0, 3, 5), (0, -3, -5)]
148+
)
149+
def test_expand_dims_tuple(axes):
150+
try:
151+
q = dpctl.SyclQueue()
152+
except dpctl.SyclQueueCreationError:
153+
pytest.skip("Queue could not be created")
154+
155+
Xnp = np.empty((3, 3, 3))
156+
X = dpt.asarray(Xnp, sycl_queue=q)
157+
Y = dpt.expand_dims(X, axes)
158+
Ynp = np.expand_dims(Xnp, axes)
159+
assert_array_equal(Ynp, dpt.asnumpy(Y))
160+
161+
162+
def test_expand_dims_incorrect_tuple():
163+
164+
X = dpt.empty((3, 3, 3), dtype="i4")
165+
pytest.raises(np.AxisError, dpt.expand_dims, X, (0, -6))
166+
pytest.raises(np.AxisError, dpt.expand_dims, X, (0, 5))
167+
168+
pytest.raises(ValueError, dpt.expand_dims, X, (1, 1))

0 commit comments

Comments
 (0)