Skip to content

Commit 7b51457

Browse files
Add tests for expand_dims func
1 parent f8e84b8 commit 7b51457

File tree

1 file changed

+76
-0
lines changed

1 file changed

+76
-0
lines changed

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,79 @@ def test_permute_dims_2d_3d(shapes):
9090
Y = dpt.permute_dims(X, (2, 0, 1))
9191
Ynp = np.transpose(Xnp, (2, 0, 1))
9292
assert_array_equal(Ynp, dpt.asnumpy(Y))
93+
94+
95+
def test_expand_dims_incorrect_type():
96+
X_list = list([1, 2, 3, 4, 5])
97+
X_tuple = tuple(X_list)
98+
Xnp = np.array(X_list)
99+
100+
pytest.raises(TypeError, dpt.permute_dims, X_list, 1)
101+
pytest.raises(TypeError, dpt.permute_dims, X_tuple, 1)
102+
pytest.raises(TypeError, dpt.permute_dims, Xnp, 1)
103+
104+
105+
def test_expand_dims_0d():
106+
try:
107+
q = dpctl.SyclQueue()
108+
except dpctl.SyclQueueCreationError:
109+
pytest.skip("Queue could not be created")
110+
111+
Xnp = np.array(1, dtype="int64")
112+
X = dpt.asarray(Xnp, sycl_queue=q)
113+
Y = dpt.expand_dims(X, 0)
114+
Ynp = np.expand_dims(Xnp, 0)
115+
assert_array_equal(Ynp, dpt.asnumpy(Y))
116+
117+
Y = dpt.expand_dims(X, -1)
118+
Ynp = np.expand_dims(Xnp, -1)
119+
assert_array_equal(Ynp, dpt.asnumpy(Y))
120+
121+
pytest.raises(np.AxisError, dpt.expand_dims, X, 1)
122+
pytest.raises(np.AxisError, dpt.expand_dims, X, -2)
123+
124+
125+
@pytest.mark.parametrize("shapes", [(3,), (3, 3), (3, 3, 3)])
126+
def test_expand_dims_1d_3d(shapes):
127+
try:
128+
q = dpctl.SyclQueue()
129+
except dpctl.SyclQueueCreationError:
130+
pytest.skip("Queue could not be created")
131+
132+
Xnp_size = np.prod(shapes)
133+
134+
Xnp = np.random.randint(0, 2, size=Xnp_size, dtype="int64").reshape(shapes)
135+
X = dpt.asarray(Xnp, sycl_queue=q)
136+
shape_len = len(shapes)
137+
for axis in range(-shape_len - 1, shape_len):
138+
Y = dpt.expand_dims(X, axis)
139+
Ynp = np.expand_dims(Xnp, axis)
140+
assert_array_equal(Ynp, dpt.asnumpy(Y))
141+
142+
pytest.raises(np.AxisError, dpt.expand_dims, X, shape_len + 1)
143+
pytest.raises(np.AxisError, dpt.expand_dims, X, -shape_len - 2)
144+
145+
146+
@pytest.mark.parametrize(
147+
"axes", [(0, 1, 2), (0, -1, -2), (0, 3, 5), (0, -3, -5)]
148+
)
149+
def test_expand_dims_tuple(axes):
150+
try:
151+
q = dpctl.SyclQueue()
152+
except dpctl.SyclQueueCreationError:
153+
pytest.skip("Queue could not be created")
154+
155+
Xnp = np.empty((3, 3, 3))
156+
X = dpt.asarray(Xnp, sycl_queue=q)
157+
Y = dpt.expand_dims(X, axes)
158+
Ynp = np.expand_dims(Xnp, axes)
159+
assert_array_equal(Ynp, dpt.asnumpy(Y))
160+
161+
162+
def test_expand_dims_incorrect_tuple():
163+
164+
X = dpt.empty((3, 3, 3), dtype="i4")
165+
pytest.raises(np.AxisError, dpt.expand_dims, X, (0, -6))
166+
pytest.raises(np.AxisError, dpt.expand_dims, X, (0, 5))
167+
168+
pytest.raises(ValueError, dpt.expand_dims, X, (1, 1))

0 commit comments

Comments
 (0)